From f87a7a0bdf065cb18796f9cddcfbf230dea85638 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Wed, 18 Mar 2026 11:20:48 -0700 Subject: [PATCH 01/20] adding dependencies for decoder_ablation workflow Signed-off-by: Sachin Pisal --- code/requirements_public_train.txt | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/code/requirements_public_train.txt b/code/requirements_public_train.txt index 3f70df4..2a3cc54 100644 --- a/code/requirements_public_train.txt +++ b/code/requirements_public_train.txt @@ -14,3 +14,7 @@ -r requirements_public_inference.txt tensorboard torchinfo +# decoder_ablation workflow +scipy +ldpc +beliefmatching From ebe77de0bce72a2241b7baf8dec7aa11b82d683a Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Wed, 18 Mar 2026 11:23:25 -0700 Subject: [PATCH 02/20] adding failure_analysis containing the decoder helpers, decoder ablation, and plotting helpers Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 597 ++++++++++++++++++++++++++++ 1 file changed, 597 insertions(+) create mode 100644 code/evaluation/failure_analysis.py diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py new file mode 100644 index 0000000..8edd31a --- /dev/null +++ b/code/evaluation/failure_analysis.py @@ -0,0 +1,597 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +""" +Decoder ablation study: apply multiple global decoders of varying complexity +to the same pre-decoder residual syndromes and compare logical error rates. +""" +import inspect +import os +import random + +import numpy as np +import torch + +from evaluation.logical_error_rate import ( + _build_stab_maps, + _decode_batch, + map_grid_to_stabilizer_tensor, + sample_predictions, +) + + +def _build_ldpc_decoders(det_model): + """ + Convert a Stim DetectorErrorModel to an H matrix and build ldpc decoders. + Returns dict of {name: (decoder, L_dense)} where L_dense is (num_obs, num_mechanisms). + """ + import scipy.sparse as sp + from beliefmatching.belief_matching import detector_error_model_to_check_matrices + from ldpc.bp_decoder import BpDecoder + from ldpc.bplsd_decoder import BpLsdDecoder + from ldpc.union_find_decoder import UnionFindDecoder + + matrices = detector_error_model_to_check_matrices(det_model) + H = sp.csc_matrix(matrices.check_matrix) + L = matrices.observables_matrix + priors = np.array(matrices.priors, dtype=np.float64) + L_dense = np.asarray(L.toarray(), dtype=np.uint8) + + # Clamp priors away from 0/1 for BP stability + priors = np.clip(priors, 1e-9, 1.0 - 1e-9) + + decoders = {} + decoders["Union-Find"] = (UnionFindDecoder(H, uf_method="peeling"), L_dense) + decoders["BP-only"] = ( + BpDecoder(H, error_channel=priors, bp_method="product_sum", max_iter=10, schedule="parallel"), + L_dense, + ) + decoders["BP+LSD-0"] = ( + BpLsdDecoder( + H, + error_channel=priors, + bp_method="product_sum", + max_iter=10, + schedule="parallel", + lsd_method="lsd_cs", + lsd_order=0, + ), + L_dense, + ) + return decoders + + +def _decode_ldpc_batch(decoder, L_dense, syndromes_np): + """ + Decode a batch of syndromes with an ldpc decoder (single-shot loop). + Returns observable predictions as np.ndarray of shape (B,). + """ + B = syndromes_np.shape[0] + obs = np.zeros(B, dtype=np.uint8) + for i in range(B): + correction = decoder.decode(syndromes_np[i]) + obs[i] = ( + int((L_dense @ correction).item() % 2) + if L_dense.shape[0] == 1 + else int((L_dense @ correction)[0] % 2) + ) + return obs + + +@torch.inference_mode() +def decoder_ablation_study(model, device, dist, cfg): + """ + Run the pre-decoder on the test set, then apply multiple global decoders + of varying complexity to the same residual syndromes. + Measures LER per decoder, residual weight distribution, and decoder agreement. + + Uses Stim datapipe (with boundary detectors) for baseline, ground truth, and + DEM/matcher construction — matching the reference implementation in + logical_error_rate.py for apples-to-apples comparison. + """ + import time as _time + from copy import deepcopy + + import pymatching + + from data.factory import DatapipeFactory + + th_data = float(getattr(cfg.test, "th_data", 0.0)) + th_syn = float(getattr(cfg.test, "th_syn", 0.0)) + sampling_mode = str(getattr(cfg.test, "sampling_mode", "threshold")).lower() + temperature = float(getattr(cfg.test, "temperature", 1.0)) + temperature_data = getattr(cfg.test, "temperature_data", None) + temperature_syn = getattr(cfg.test, "temperature_syn", None) + temperature_data = float(temperature_data) if temperature_data is not None else temperature + temperature_syn = float(temperature_syn) if temperature_syn is not None else temperature + + model.eval() + enable_correlated = getattr(cfg.data, "enable_correlated_pymatching", False) + basis = str(getattr(cfg.test, "meas_basis_test", "X")).upper() + if basis not in ("X", "Z"): + basis = "X" + + # --- Create Stim datapipe (with boundary detectors) --- + total_samples = int(cfg.test.num_samples) + samples_per_gpu = total_samples // max(1, dist.world_size) + + torch_state = torch.get_rng_state() + cuda_state = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None + np_state = np.random.get_state() + py_state = random.getstate() + try: + rank_seed = 12345 + dist.rank * 1000 + torch.manual_seed(rank_seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(rank_seed) + np.random.seed(rank_seed) + random.seed(rank_seed) + cfg_copy = deepcopy(cfg) + cfg_copy.test.num_samples = samples_per_gpu + test_dataset = DatapipeFactory.create_datapipe_inference(cfg_copy) + finally: + torch.set_rng_state(torch_state) + if cuda_state is not None: + torch.cuda.set_rng_state_all(cuda_state) + np.random.set_state(np_state) + random.setstate(py_state) + + circuit = test_dataset.circ.stim_circuit + num_obs = circuit.num_observables + assert num_obs == 1 + + # DEM and matchers from Stim circuit (includes boundary detectors) + det_model = circuit.detector_error_model(decompose_errors=True, approximate_disjoint_errors=True) + _supports_corr = "enable_correlations" in inspect.signature( + pymatching.Matching.from_detector_error_model + ).parameters + + if _supports_corr: + matcher_corr = pymatching.Matching.from_detector_error_model(det_model, enable_correlations=True) + matcher_uncorr = pymatching.Matching.from_detector_error_model(det_model, enable_correlations=False) + else: + matcher_corr = pymatching.Matching.from_detector_error_model(det_model) + matcher_uncorr = matcher_corr + + # Build ldpc decoders from the same DEM (with boundary detectors) + ldpc_decoders = _build_ldpc_decoders(det_model) + + # Stim baseline detectors and ground truth observables + stim_dets = np.asarray(test_dataset.dets_and_obs[:, :-num_obs], dtype=np.uint8) + assert stim_dets.shape[1] == det_model.num_detectors, \ + f"Stim dets width {stim_dets.shape[1]} != DEM {det_model.num_detectors}" + stim_obs = np.asarray(test_dataset.dets_and_obs[:, -num_obs:], dtype=np.uint8) + + # Number of boundary detectors + surface_code = test_dataset.circ.code + num_boundary_dets = surface_code.hx.shape[0] if basis == 'X' else surface_code.hz.shape[0] + + D = cfg.distance + code_rotation = getattr(cfg.data, "code_rotation", "XV") + maps = _build_stab_maps(D, code_rotation) + Hx_idx = maps["Hx_idx"].to(device=device, dtype=torch.long) + Hz_idx = maps["Hz_idx"].to(device=device, dtype=torch.long) + Hx_mask = maps["Hx_mask"].to(device=device, dtype=torch.bool) + Hz_mask = maps["Hz_mask"].to(device=device, dtype=torch.bool) + stab_indices_x = maps["stab_x"].to(device=device, dtype=torch.long) + stab_indices_z = maps["stab_z"].to(device=device, dtype=torch.long) + Kx, Kz = maps["Kx"], maps["Kz"] + D2 = D * D + if code_rotation.upper() in ("XV", "ZH"): + Lx = torch.zeros((1, D2), dtype=torch.int32, device=device) + Lx[0, :D] = 1 + Lz = torch.zeros((1, D2), dtype=torch.int32, device=device) + Lz[0, ::D] = 1 + else: + Lx = torch.zeros((1, D2), dtype=torch.int32, device=device) + Lx[0, ::D] = 1 + Lz = torch.zeros((1, D2), dtype=torch.int32, device=device) + Lz[0, :D] = 1 + + if dist.rank == 0: + print( + f"\n[Decoder Ablation] basis={basis}, d={D}, r={cfg.n_rounds}," + f" p={getattr(cfg.test, 'p_error', 0.003)}" + ) + print( + f"[Decoder Ablation] Using Stim datapipe (with boundary detectors)" + f" for apples-to-apples comparison" + ) + print( + f"[Decoder Ablation] DEM detectors: {det_model.num_detectors}" + f" (incl. {num_boundary_dets} boundary)" + ) + print( + f"[Decoder Ablation] Decoders: No-op, Union-Find, BP+LSD-0," + f" Uncorr PM, Corr PM, + Baseline PM" + ) + + batch_size = int(getattr(cfg.test.dataloader, "batch_size", 2048)) + N = len(test_dataset) + num_batches = (N + batch_size - 1) // batch_size + + decoder_names = ["No-op", "Union-Find", "BP-only", "BP+LSD-0", "Uncorr-PM", "Corr-PM"] + total_scanned = 0 + baseline_errors = 0 + decoder_errors = {name: 0 for name in decoder_names} + all_residual_weights = [] + weight_bucket_stats = {} + n_all_agree = 0 + + _timing = { + "collate": 0.0, + "baseline_pm": 0.0, + "model_fwd": 0.0, + "residual_build": 0.0, + "uf_decode": 0.0, + "bp_only_decode": 0.0, + "bplsd_decode": 0.0, + "uncorr_pm": 0.0, + "corr_pm": 0.0, + "bookkeeping": 0.0, + } + + for batch_idx in range(num_batches): + start = batch_idx * batch_size + end = min(start + batch_size, N) + B = end - start + + # Collate batch from dataset items + _t0 = _time.perf_counter() + items = [test_dataset[i] for i in range(start, end)] + x_syn_diff = torch.stack([it["x_syn_diff"] for it in items]).to(device=device, dtype=torch.int32) + z_syn_diff = torch.stack([it["z_syn_diff"] for it in items]).to(device=device, dtype=torch.int32) + trainX = torch.stack([it["trainX"] for it in items]).to(device=device) + _timing["collate"] += _time.perf_counter() - _t0 + + _, n_x, T = x_syn_diff.shape + if T < 2: + continue + + # --- Baseline: Stim detectors (with boundary dets), Stim ground truth --- + baseline_detectors_batch = stim_dets[start:end] + gt_obs_batch = stim_obs[start:end] + + _t0 = _time.perf_counter() + baseline_pred_obs = _decode_batch(matcher_corr, baseline_detectors_batch, True) + baseline_pred_obs = np.asarray(baseline_pred_obs, dtype=np.uint8).reshape(-1, num_obs) + baseline_errors += int((baseline_pred_obs != gt_obs_batch).sum()) + _timing["baseline_pm"] += _time.perf_counter() - _t0 + + gt_obs_np = gt_obs_batch.reshape(-1).astype(np.int64) + + # Model forward + _t0 = _time.perf_counter() + with torch.amp.autocast( + device_type=device.type if hasattr(device, "type") else "cuda", + enabled=getattr(cfg, "enable_fp16", False), + ): + logits = model(trainX) + z_data_corr = sample_predictions(logits[:, 0], th_data, sampling_mode, temperature_data) + x_data_corr = sample_predictions(logits[:, 1], th_data, sampling_mode, temperature_data) + syn_x_grid = sample_predictions(logits[:, 2], th_syn, sampling_mode, temperature_syn) + syn_z_grid = sample_predictions(logits[:, 3], th_syn, sampling_mode, temperature_syn) + + z_flat = z_data_corr.permute(0, 2, 3, 1).contiguous().view(B, D2, T).to(torch.int32) + x_flat = x_data_corr.permute(0, 2, 3, 1).contiguous().view(B, D2, T).to(torch.int32) + z_exp = z_flat.unsqueeze(2).expand(B, D2, Kx, T) + hx_idx_e = Hx_idx.clamp_min(0).view(1, -1, Kx, 1).expand(B, -1, -1, T) + g_x = z_exp.gather(1, hx_idx_e) + m_x = Hx_mask.view(1, -1, Kx, 1).expand_as(g_x) + S_X = (g_x.masked_fill(~m_x, 0).sum(dim=2) & 1) + x_exp = x_flat.unsqueeze(2).expand(B, D2, Kz, T) + hz_idx_e = Hz_idx.clamp_min(0).view(1, -1, Kz, 1).expand(B, -1, -1, T) + g_z = x_exp.gather(1, hz_idx_e) + m_z = Hz_mask.view(1, -1, Kz, 1).expand_as(g_z) + S_Z = (g_z.masked_fill(~m_z, 0).sum(dim=2) & 1) + + syn_x_flat = map_grid_to_stabilizer_tensor(syn_x_grid, stab_indices_x).to(torch.int32) + syn_z_flat = map_grid_to_stabilizer_tensor(syn_z_grid, stab_indices_z).to(torch.int32) + R_X = torch.empty_like(x_syn_diff, dtype=torch.uint8) + R_X[:, :, 0] = (x_syn_diff[:, :, 0] + syn_x_flat[:, :, 0] + S_X[:, :, 0]) & 1 + if T > 1: + R_X[:, :, 1:] = ( + x_syn_diff[:, :, 1:] + syn_x_flat[:, :, 1:] + syn_x_flat[:, :, :-1] + S_X[:, :, 1:] + ) & 1 + R_Z = torch.empty_like(z_syn_diff, dtype=torch.uint8) + R_Z[:, :, 0] = (z_syn_diff[:, :, 0] + syn_z_flat[:, :, 0] + S_Z[:, :, 0]) & 1 + if T > 1: + R_Z[:, :, 1:] = ( + z_syn_diff[:, :, 1:] + syn_z_flat[:, :, 1:] + syn_z_flat[:, :, :-1] + S_Z[:, :, 1:] + ) & 1 + + # Logical frame from data corrections + if basis == "X": + pre_L_t = torch.einsum( + "ld,bdt->blt", Lx.to(torch.float32), z_flat.to(torch.float32) + ).remainder_(2).to(torch.int32) + else: + pre_L_t = torch.einsum( + "ld,bdt->blt", Lz.to(torch.float32), x_flat.to(torch.float32) + ).remainder_(2).to(torch.int32) + pre_L = pre_L_t.sum(dim=2).remainder_(2).view(-1) + + # Build residual detectors (matching logical_error_rate.py exactly) + if basis == "X": + initial_detectors = R_X[:, :, 0].view(B, -1) + else: + initial_detectors = R_Z[:, :, 0].view(B, -1) + R_X_rest = R_X[:, :, 1:] + R_Z_rest = R_Z[:, :, 1:] + R_cat_rest = torch.cat([R_X_rest, R_Z_rest], dim=1) + rest_flat = R_cat_rest.permute(0, 2, 1).contiguous().view(B, -1) + residual = torch.cat([initial_detectors, rest_flat], dim=1).to(torch.uint8) + + # Append boundary detectors from Stim (unchanged by pre-decoder) + boundary_dets_batch = baseline_detectors_batch[:, -num_boundary_dets:] + residual = torch.cat( + [residual, torch.from_numpy(boundary_dets_batch).to(residual.device)], dim=1 + ) + + if residual.shape[1] != det_model.num_detectors: + raise ValueError( + f"Residual shape {residual.shape} != DEM detectors {det_model.num_detectors}. " + f"Check interleave order for basis '{basis}' and time slicing." + ) + + if device.type == "cuda": + torch.cuda.synchronize() + _timing["residual_build"] += _time.perf_counter() - _t0 + + residual_np = residual.cpu().numpy() + pre_L_np = pre_L.cpu().numpy() + + weights = residual_np.sum(axis=1) + all_residual_weights.extend(weights.tolist()) + + # --- Run all decoders --- + # 1. No-op: pred_obs = 0 + noop_final = pre_L_np % 2 + + # 2. Union-Find (ldpc) + _t0 = _time.perf_counter() + uf_dec, uf_L = ldpc_decoders["Union-Find"] + uf_obs = _decode_ldpc_batch(uf_dec, uf_L, residual_np) + uf_final = (pre_L_np + uf_obs) % 2 + _timing["uf_decode"] += _time.perf_counter() - _t0 + + # 3. BP-only (no LSD fallback) + _t0 = _time.perf_counter() + bp_dec, bp_L = ldpc_decoders["BP-only"] + bp_obs = _decode_ldpc_batch(bp_dec, bp_L, residual_np) + bp_final = (pre_L_np + bp_obs) % 2 + _timing["bp_only_decode"] += _time.perf_counter() - _t0 + + # 4. BP+LSD-0 (ldpc) + _t0 = _time.perf_counter() + bplsd_dec, bplsd_L = ldpc_decoders["BP+LSD-0"] + bplsd_obs = _decode_ldpc_batch(bplsd_dec, bplsd_L, residual_np) + bplsd_final = (pre_L_np + bplsd_obs) % 2 + _timing["bplsd_decode"] += _time.perf_counter() - _t0 + + # 5. Uncorrelated PyMatching + _t0 = _time.perf_counter() + uncorr_pred = _decode_batch(matcher_uncorr, residual_np, False) + uncorr_pred = np.asarray(uncorr_pred, dtype=np.int64).reshape(-1) + uncorr_final = (pre_L_np + uncorr_pred) % 2 + _timing["uncorr_pm"] += _time.perf_counter() - _t0 + + # 6. Correlated PyMatching + _t0 = _time.perf_counter() + corr_pred = _decode_batch(matcher_corr, residual_np, True) + corr_pred = np.asarray(corr_pred, dtype=np.int64).reshape(-1) + corr_final = (pre_L_np + corr_pred) % 2 + _timing["corr_pm"] += _time.perf_counter() - _t0 + + _t0 = _time.perf_counter() + all_finals = { + "No-op": noop_final, + "Union-Find": uf_final, + "BP-only": bp_final, + "BP+LSD-0": bplsd_final, + "Uncorr-PM": uncorr_final, + "Corr-PM": corr_final, + } + + for name in decoder_names: + fails = all_finals[name] != gt_obs_np + decoder_errors[name] += int(fails.sum()) + + stacked = np.stack([all_finals[n] for n in decoder_names], axis=0) # (n_decoders, B) + agree = np.all(stacked == stacked[0:1], axis=0) # (B,) + n_all_agree += int(agree.sum()) + + for i in range(B): + w = int(weights[i]) + bucket = w if w <= 6 else 7 # 0-6, 7+ + if bucket not in weight_bucket_stats: + weight_bucket_stats[bucket] = {n: [0, 0] for n in decoder_names} + weight_bucket_stats[bucket]["_total"] = weight_bucket_stats[bucket].get("_total", 0) + 1 + for name in decoder_names: + if bucket not in weight_bucket_stats or name not in weight_bucket_stats[bucket]: + weight_bucket_stats[bucket][name] = [0, 0] + weight_bucket_stats[bucket][name][1] += 1 + if all_finals[name][i] != gt_obs_np[i]: + weight_bucket_stats[bucket][name][0] += 1 + + _timing["bookkeeping"] += _time.perf_counter() - _t0 + + total_scanned += B + if dist.rank == 0 and (batch_idx + 1) % 5 == 0: + print(f" [Ablation] Processed {total_scanned} samples...") + + # --- Print timing breakdown --- + if dist.rank == 0: + _total_time = sum(_timing.values()) + print(f"\n{'='*60}") + print(f"TIMING BREAKDOWN (total loop = {_total_time:.2f}s)") + print(f"{'='*60}") + for k, v in sorted(_timing.items(), key=lambda x: -x[1]): + pct = v / max(_total_time, 1e-9) * 100 + print(f" {k:<20s} {v:8.2f}s ({pct:5.1f}%)") + print(f"{'='*60}") + + # --- Print summary --- + if dist.rank == 0: + print(f"\n{'='*70}") + print( + f"DECODER ABLATION STUDY | basis={basis} d={D} r={cfg.n_rounds}" + f" p={getattr(cfg.test, 'p_error', 0.003)}" + ) + print(f"{'='*70}") + print(f"Total samples: {total_scanned}") + + baseline_ler = baseline_errors / max(1, total_scanned) + print(f"\n--- Logical Error Rates ---") + print( + f" {'Baseline (no pre-dec)':<25s} LER = {baseline_ler:.6f}" + f" ({baseline_errors} errors)" + ) + for name in decoder_names: + ler = decoder_errors[name] / max(1, total_scanned) + print(f" {name:<25s} LER = {ler:.6f} ({decoder_errors[name]} errors)") + + agreement_rate = n_all_agree / max(1, total_scanned) + print(f"\n--- Decoder Agreement ---") + print( + f" All {len(decoder_names)} decoders agree:" + f" {agreement_rate*100:.2f}% ({n_all_agree}/{total_scanned})" + ) + + weights_arr = np.array(all_residual_weights) + print(f"\n--- Residual Weight Distribution ---") + for w in sorted(weight_bucket_stats.keys()): + label = f"{w}+" if w == 7 else str(w) + count = weight_bucket_stats[w].get("_total", 0) + pct = count / max(1, total_scanned) * 100 + print(f" Weight {label:>3s}: {count:>7d} samples ({pct:6.2f}%)") + print(f" Mean weight: {weights_arr.mean():.3f}, Max: {int(weights_arr.max())}") + + print(f"\n--- Conditional LER by Residual Weight ---") + header = f" {'Weight':>7s}" + for name in decoder_names: + header += f" {name:>12s}" + print(header) + for w in sorted(weight_bucket_stats.keys()): + label = f"{w}+" if w == 7 else str(w) + row = f" {label:>7s}" + for name in decoder_names: + n_err, n_tot = weight_bucket_stats[w].get(name, [0, 0]) + if n_tot > 0: + row += f" {n_err/n_tot:>12.6f}" + else: + row += f" {'N/A':>12s}" + print(row) + print(f"{'='*70}") + + # --- Plots --- + if dist.rank == 0: + _plot_residual_weight_histogram(all_residual_weights, basis, cfg) + _plot_conditional_ler(weight_bucket_stats, decoder_names, basis, cfg) + + return ( + { + "total_samples": total_scanned, + "baseline_errors": baseline_errors, + "decoder_errors": decoder_errors, + "residual_weights": all_residual_weights, + "weight_bucket_stats": weight_bucket_stats, + "agreement_count": n_all_agree, + } + if dist.rank == 0 + else {} + ) + + +def _plot_residual_weight_histogram(weights, basis, cfg): + """Plot and save residual weight histogram.""" + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + weights_arr = np.array(weights) + max_w = min(int(weights_arr.max()) + 1, 20) + + fig, ax = plt.subplots(figsize=(8, 5)) + bins = np.arange(-0.5, max_w + 1.5, 1) + ax.hist(weights_arr, bins=bins, edgecolor="black", alpha=0.7, color="#4C72B0") + ax.set_xlabel("Residual Weight (# non-zero detectors)", fontsize=12) + ax.set_ylabel("Count", fontsize=12) + ax.set_title( + f"Residual Syndrome Weight Distribution\n" + f"basis={basis} d={cfg.distance} r={cfg.n_rounds}" + f" p={getattr(cfg.test, 'p_error', 0.003)} N={len(weights)}", + fontsize=11, + ) + ax.set_yscale("log") + n_zero = int((weights_arr == 0).sum()) + pct_zero = n_zero / max(1, len(weights_arr)) * 100 + ax.axvline(x=0, color="red", linestyle="--", alpha=0.5) + ax.text(0.5, 0.95, f"Weight-0: {pct_zero:.1f}%", transform=ax.transAxes, + fontsize=11, verticalalignment="top", color="red") + plt.tight_layout() + output_dir = os.path.join(cfg.output, "plots") + os.makedirs(output_dir, exist_ok=True) + path = os.path.join(output_dir, f"residual_weight_hist_{basis}.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f"Saved: {path}") + + +def _plot_conditional_ler(weight_bucket_stats, decoder_names, basis, cfg): + """Plot conditional LER by residual weight for each decoder.""" + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + buckets = sorted(weight_bucket_stats.keys()) + labels = [f"{w}+" if w == 7 else str(w) for w in buckets] + + fig, ax = plt.subplots(figsize=(9, 5)) + colors = ["#999999", "#E24A33", "#348ABD", "#FBC15E", "#8EBA42"] + markers = ["x", "s", "D", "^", "o"] + for idx, name in enumerate(decoder_names): + lers = [] + x_pos = [] + for i, w in enumerate(buckets): + n_err, n_tot = weight_bucket_stats[w].get(name, [0, 0]) + if n_tot >= 10: + lers.append(n_err / n_tot) + x_pos.append(i) + if lers: + ax.plot( + x_pos, + lers, + marker=markers[idx % len(markers)], + color=colors[idx % len(colors)], + label=name, + linewidth=1.5, + markersize=6, + ) + + ax.set_xticks(range(len(labels))) + ax.set_xticklabels(labels) + ax.set_xlabel("Residual Weight (# non-zero detectors)", fontsize=12) + ax.set_ylabel("Logical Error Rate", fontsize=12) + ax.set_title( + f"Conditional LER by Residual Weight\n" + f"basis={basis} d={cfg.distance} r={cfg.n_rounds}" + f" p={getattr(cfg.test, 'p_error', 0.003)}", + fontsize=11, + ) + ax.legend(fontsize=9) + ax.set_ylim(bottom=-0.02) + ax.grid(True, alpha=0.3) + plt.tight_layout() + output_dir = os.path.join(cfg.output, "plots") + os.makedirs(output_dir, exist_ok=True) + path = os.path.join(output_dir, f"conditional_ler_{basis}.png") + plt.savefig(path, dpi=150, bbox_inches="tight") + plt.close() + print(f"Saved: {path}") From 01b9fc03a42001b0d367862e750c1b3859f7d38b Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Wed, 18 Mar 2026 13:42:56 -0700 Subject: [PATCH 03/20] adding modified code/evaluation/failure_analysis.py Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 42 +++++++++++++++------------ code/evaluation/logical_error_rate.py | 14 +++++++++ 2 files changed, 37 insertions(+), 19 deletions(-) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index 8edd31a..ac3092f 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -11,7 +11,6 @@ Decoder ablation study: apply multiple global decoders of varying complexity to the same pre-decoder residual syndromes and compare logical error rates. """ -import inspect import os import random @@ -21,10 +20,17 @@ from evaluation.logical_error_rate import ( _build_stab_maps, _decode_batch, + _PYMATCHING_SUPPORTS_CORRELATIONS, map_grid_to_stabilizer_tensor, sample_predictions, ) +# LDPC-based decoders built by _build_ldpc_decoders. +LDPC_DECODER_NAMES = ("Union-Find", "BP-only", "BP+LSD-0") + +# Ordered names of all decoders run by decoder_ablation_study. +DECODER_NAMES = ("No-op",) + LDPC_DECODER_NAMES + ("Uncorr-PM", "Corr-PM") + def _build_ldpc_decoders(det_model): """ @@ -46,13 +52,14 @@ def _build_ldpc_decoders(det_model): # Clamp priors away from 0/1 for BP stability priors = np.clip(priors, 1e-9, 1.0 - 1e-9) + _uf, _bp, _bplsd = LDPC_DECODER_NAMES decoders = {} - decoders["Union-Find"] = (UnionFindDecoder(H, uf_method="peeling"), L_dense) - decoders["BP-only"] = ( + decoders[_uf] = (UnionFindDecoder(H, uf_method="peeling"), L_dense) + decoders[_bp] = ( BpDecoder(H, error_channel=priors, bp_method="product_sum", max_iter=10, schedule="parallel"), L_dense, ) - decoders["BP+LSD-0"] = ( + decoders[_bplsd] = ( BpLsdDecoder( H, error_channel=priors, @@ -148,11 +155,7 @@ def decoder_ablation_study(model, device, dist, cfg): # DEM and matchers from Stim circuit (includes boundary detectors) det_model = circuit.detector_error_model(decompose_errors=True, approximate_disjoint_errors=True) - _supports_corr = "enable_correlations" in inspect.signature( - pymatching.Matching.from_detector_error_model - ).parameters - - if _supports_corr: + if _PYMATCHING_SUPPORTS_CORRELATIONS: matcher_corr = pymatching.Matching.from_detector_error_model(det_model, enable_correlations=True) matcher_uncorr = pymatching.Matching.from_detector_error_model(det_model, enable_correlations=False) else: @@ -216,7 +219,7 @@ def decoder_ablation_study(model, device, dist, cfg): N = len(test_dataset) num_batches = (N + batch_size - 1) // batch_size - decoder_names = ["No-op", "Union-Find", "BP-only", "BP+LSD-0", "Uncorr-PM", "Corr-PM"] + decoder_names = list(DECODER_NAMES) total_scanned = 0 baseline_errors = 0 decoder_errors = {name: 0 for name in decoder_names} @@ -355,22 +358,23 @@ def decoder_ablation_study(model, device, dist, cfg): noop_final = pre_L_np % 2 # 2. Union-Find (ldpc) + _uf, _bp, _bplsd = LDPC_DECODER_NAMES _t0 = _time.perf_counter() - uf_dec, uf_L = ldpc_decoders["Union-Find"] + uf_dec, uf_L = ldpc_decoders[_uf] uf_obs = _decode_ldpc_batch(uf_dec, uf_L, residual_np) uf_final = (pre_L_np + uf_obs) % 2 _timing["uf_decode"] += _time.perf_counter() - _t0 # 3. BP-only (no LSD fallback) _t0 = _time.perf_counter() - bp_dec, bp_L = ldpc_decoders["BP-only"] + bp_dec, bp_L = ldpc_decoders[_bp] bp_obs = _decode_ldpc_batch(bp_dec, bp_L, residual_np) bp_final = (pre_L_np + bp_obs) % 2 _timing["bp_only_decode"] += _time.perf_counter() - _t0 # 4. BP+LSD-0 (ldpc) _t0 = _time.perf_counter() - bplsd_dec, bplsd_L = ldpc_decoders["BP+LSD-0"] + bplsd_dec, bplsd_L = ldpc_decoders[_bplsd] bplsd_obs = _decode_ldpc_batch(bplsd_dec, bplsd_L, residual_np) bplsd_final = (pre_L_np + bplsd_obs) % 2 _timing["bplsd_decode"] += _time.perf_counter() - _t0 @@ -391,12 +395,12 @@ def decoder_ablation_study(model, device, dist, cfg): _t0 = _time.perf_counter() all_finals = { - "No-op": noop_final, - "Union-Find": uf_final, - "BP-only": bp_final, - "BP+LSD-0": bplsd_final, - "Uncorr-PM": uncorr_final, - "Corr-PM": corr_final, + DECODER_NAMES[0]: noop_final, + _uf: uf_final, + _bp: bp_final, + _bplsd: bplsd_final, + DECODER_NAMES[4]: uncorr_final, + DECODER_NAMES[5]: corr_final, } for name in decoder_names: diff --git a/code/evaluation/logical_error_rate.py b/code/evaluation/logical_error_rate.py index 015d6e4..77c1d46 100644 --- a/code/evaluation/logical_error_rate.py +++ b/code/evaluation/logical_error_rate.py @@ -8,6 +8,8 @@ # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. +import inspect + import pymatching import numpy as np import torch @@ -17,6 +19,18 @@ from enum import IntEnum from typing import Optional +_PYMATCHING_SUPPORTS_CORRELATIONS = "enable_correlations" in inspect.signature( + pymatching.Matching.from_detector_error_model +).parameters + + +def _decode_batch(matcher, detectors, enable_correlated): + """Wrapper for decode_batch that handles older pymatching versions.""" + if _PYMATCHING_SUPPORTS_CORRELATIONS: + return matcher.decode_batch(detectors, enable_correlations=enable_correlated) + else: + return matcher.decode_batch(detectors) + class OnnxWorkflow(IntEnum): """ONNX_WORKFLOW env: 0=torch only, 1=export ONNX only, 2=export ONNX and use TensorRT, 3=use engine file only.""" From 0a85a624221b923c8dad6d24f82d7595aeb36bb1 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Wed, 18 Mar 2026 13:43:35 -0700 Subject: [PATCH 04/20] adding tests for failure analysis Signed-off-by: Sachin Pisal --- code/tests/test_failure_analysis.py | 243 ++++++++++++++++++++++++++++ 1 file changed, 243 insertions(+) create mode 100644 code/tests/test_failure_analysis.py diff --git a/code/tests/test_failure_analysis.py b/code/tests/test_failure_analysis.py new file mode 100644 index 0000000..dc27ce2 --- /dev/null +++ b/code/tests/test_failure_analysis.py @@ -0,0 +1,243 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual +# property and proprietary rights in and to this material, related +# documentation and any modifications thereto. Any use, reproduction, +# disclosure or distribution of this material and related documentation +# without an express license agreement from NVIDIA CORPORATION or +# its affiliates is strictly prohibited. +import sys +import tempfile +import types +import unittest +from pathlib import Path +from unittest.mock import patch + +import numpy as np +import torch + +_repo_code = Path(__file__).resolve().parent.parent +if str(_repo_code) not in sys.path: + sys.path.insert(0, str(_repo_code)) + +try: + import ldpc + import beliefmatching + import scipy + _HAS_LDPC_DEPS = True +except ImportError: + _HAS_LDPC_DEPS = False + +_skip_ldpc = unittest.skipUnless(_HAS_LDPC_DEPS, "ldpc/beliefmatching/scipy not installed") + + +def _make_tiny_dem(distance=3, n_rounds=3, basis="X", code_rotation="XV"): + """Build a minimal surface-code DEM (with boundary detectors) for testing.""" + from qec.surface_code.memory_circuit import MemoryCircuit + mc = MemoryCircuit( + distance=distance, + idle_error=0.01, + sqgate_error=0.01, + tqgate_error=0.01, + spam_error=0.007, + n_rounds=n_rounds, + basis=basis, + code_rotation=code_rotation, + add_boundary_detectors=True, + ) + mc.set_error_rates() + return mc.stim_circuit.detector_error_model( + decompose_errors=True, approximate_disjoint_errors=True + ) + + +def _make_cfg(output_dir, distance=3, n_rounds=3, basis="X", n_samples=8): + """Build a minimal cfg SimpleNamespace for decoder_ablation_study.""" + test_ns = types.SimpleNamespace( + th_data=0.0, + th_syn=0.0, + sampling_mode="threshold", + temperature=1.0, + temperature_data=None, + temperature_syn=None, + meas_basis_test=basis, + num_samples=n_samples, + p_error=0.01, + dataloader=types.SimpleNamespace(batch_size=n_samples), + use_model_checkpoint=-1, + ) + data_ns = types.SimpleNamespace( + enable_correlated_pymatching=False, + code_rotation="XV", + ) + return types.SimpleNamespace( + test=test_ns, + data=data_ns, + distance=distance, + n_rounds=n_rounds, + enable_fp16=False, + output=output_dir, + ) + + +class _ZeroModel(torch.nn.Module): + """Model that always returns zero logits (same shape as input).""" + + def forward(self, x): + return torch.zeros_like(x) + + +class _FakeDist: + rank = 0 + world_size = 1 + local_rank = 0 + device = torch.device("cpu") + + +@_skip_ldpc +class TestBuildLdpcDecoders(unittest.TestCase): + """_build_ldpc_decoders must return correctly keyed decoder objects with consistent shapes.""" + + def setUp(self): + from evaluation.failure_analysis import _build_ldpc_decoders + self.det_model = _make_tiny_dem() + self.decoders = _build_ldpc_decoders(self.det_model) + + def test_expected_decoder_names_present(self): + from evaluation.failure_analysis import LDPC_DECODER_NAMES + for name in LDPC_DECODER_NAMES: + self.assertIn(name, self.decoders) + + def test_each_entry_is_decoder_and_l_dense_pair(self): + for name, (dec, L_dense) in self.decoders.items(): + with self.subTest(decoder=name): + self.assertIsInstance(L_dense, np.ndarray) + self.assertEqual(L_dense.dtype, np.uint8) + # rows = num_observables (1 for surface code), cols = num error mechanisms + self.assertEqual(L_dense.shape[0], self.det_model.num_observables) + self.assertGreater(L_dense.shape[1], 0) + self.assertTrue(hasattr(dec, "decode"), f"{name} decoder has no .decode()") + + def test_l_dense_columns_consistent_across_decoders(self): + widths = [v[1].shape[1] for v in self.decoders.values()] + self.assertEqual(len(set(widths)), 1, "All L_dense must have the same column count") + + +@_skip_ldpc +class TestDecodeLdpcBatch(unittest.TestCase): + """_decode_ldpc_batch must return correct shape/dtype; zero syndrome decodes to 0.""" + + def setUp(self): + from evaluation.failure_analysis import _build_ldpc_decoders, _decode_ldpc_batch + self._fn = _decode_ldpc_batch + det_model = _make_tiny_dem() + self.decoders = _build_ldpc_decoders(det_model) + self.num_detectors = det_model.num_detectors + + def test_zero_syndrome_gives_zero_observable(self): + B = 4 + syndromes = np.zeros((B, self.num_detectors), dtype=np.uint8) + for name, (dec, L_dense) in self.decoders.items(): + with self.subTest(decoder=name): + obs = self._fn(dec, L_dense, syndromes) + np.testing.assert_array_equal( + obs, + np.zeros(B, dtype=np.uint8), + err_msg=f"{name}: zero syndrome should give zero observable", + ) + + def test_output_shape_is_batch_size(self): + for B in (1, 6): + syndromes = np.zeros((B, self.num_detectors), dtype=np.uint8) + for name, (dec, L_dense) in self.decoders.items(): + with self.subTest(decoder=name, B=B): + obs = self._fn(dec, L_dense, syndromes) + self.assertEqual(obs.shape, (B,)) + self.assertEqual(obs.dtype, np.uint8) + + def test_output_values_are_binary(self): + """Observable must be 0 or 1; use sparse single-bit syndromes (fast for all decoders).""" + B = min(4, self.num_detectors) + syndromes = np.zeros((B, self.num_detectors), dtype=np.uint8) + for i in range(B): + syndromes[i, i] = 1 # one detector fired per sample + for name, (dec, L_dense) in self.decoders.items(): + with self.subTest(decoder=name): + obs = self._fn(dec, L_dense, syndromes) + self.assertTrue( + np.all((obs == 0) | (obs == 1)), + f"{name}: output contains values other than 0/1", + ) + + +@_skip_ldpc +class TestDecoderAblationStudy(unittest.TestCase): + """ + Smoke test: decoder_ablation_study must complete, return expected keys, + and report the correct sample count. + """ + + _D = 3 + _T = 3 + _N = 8 + + def _build_datapipe(self, basis): + from data.datapipe_stim import QCDataPipePreDecoder_Memory_inference + return QCDataPipePreDecoder_Memory_inference( + distance=self._D, + n_rounds=self._T, + num_samples=self._N, + error_mode="circuit_level_surface_custom", + p_error=0.01, + measure_basis=basis, + code_rotation="XV", + ) + + def _run(self, basis): + from evaluation.failure_analysis import decoder_ablation_study + real_ds = self._build_datapipe(basis) + with tempfile.TemporaryDirectory() as tmpdir: + cfg = _make_cfg(tmpdir, distance=self._D, n_rounds=self._T, basis=basis, + n_samples=self._N) + with patch("data.factory.DatapipeFactory") as mock_factory: + mock_factory.create_datapipe_inference.return_value = real_ds + result = decoder_ablation_study( + _ZeroModel(), _FakeDist.device, _FakeDist(), cfg + ) + return result + + def test_return_keys_present(self): + result = self._run("X") + for key in ( + "total_samples", "baseline_errors", "decoder_errors", + "residual_weights", "weight_bucket_stats", "agreement_count", + ): + self.assertIn(key, result, f"Missing key in result: {key}") + + def test_total_samples_matches_dataset_size(self): + result = self._run("X") + self.assertEqual(result["total_samples"], self._N) + + def test_decoder_errors_has_all_six_decoders(self): + from evaluation.failure_analysis import DECODER_NAMES + result = self._run("X") + self.assertEqual(set(result["decoder_errors"].keys()), set(DECODER_NAMES)) + + def test_residual_weights_length_matches_total_samples(self): + result = self._run("X") + self.assertEqual(len(result["residual_weights"]), result["total_samples"]) + + def test_agreement_count_within_bounds(self): + result = self._run("X") + self.assertGreaterEqual(result["agreement_count"], 0) + self.assertLessEqual(result["agreement_count"], result["total_samples"]) + + def test_z_basis_runs_and_returns_correct_structure(self): + result = self._run("Z") + self.assertEqual(result["total_samples"], self._N) + self.assertIn("decoder_errors", result) + + +if __name__ == "__main__": + unittest.main() From fe87b094f6dbb0330df455910a09a71ef87c31a2 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Wed, 18 Mar 2026 13:47:08 -0700 Subject: [PATCH 05/20] adding decoder_ablation as a workflow task Signed-off-by: Sachin Pisal --- code/workflows/run.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/code/workflows/run.py b/code/workflows/run.py index a1a87e2..0e88dcc 100644 --- a/code/workflows/run.py +++ b/code/workflows/run.py @@ -81,10 +81,16 @@ def run_surface(cfg: DictConfig): train_loader, _ = DatapipeFactory.create_dataloader(cfg, dist.world_size, dist.rank) for j, dl in enumerate(train_loader): print(f"Batch {j}: syndrome_shape: {dl['syndrome'].shape}") + elif cfg.workflow.task == "decoder_ablation": + from evaluation.failure_analysis import decoder_ablation_study + DistributedManager.initialize() + dist = DistributedManager() + model = _load_model(cfg, dist) + decoder_ablation_study(model, dist.device, dist, cfg) elif cfg.workflow.task in ("sampling", "visualize"): raise ValueError( f"workflow.task={cfg.workflow.task!r} is not supported in the early-access public release. " - "Supported workflows: train, inference." + "Supported workflows: train, inference, decoder_ablation." ) From a7f988136f5a9d63d73e7f43952ccb9b329c967b Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Wed, 18 Mar 2026 13:59:55 -0700 Subject: [PATCH 06/20] overriding the config copy with the resolved single basis Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 1 + 1 file changed, 1 insertion(+) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index ac3092f..36dbce5 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -141,6 +141,7 @@ def decoder_ablation_study(model, device, dist, cfg): random.seed(rank_seed) cfg_copy = deepcopy(cfg) cfg_copy.test.num_samples = samples_per_gpu + cfg_copy.test.meas_basis_test = basis test_dataset = DatapipeFactory.create_datapipe_inference(cfg_copy) finally: torch.set_rng_state(torch_state) From 9519f1bd20d07b2cfbe926e787a1dedcc8868ff5 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Wed, 18 Mar 2026 14:03:29 -0700 Subject: [PATCH 07/20] formatting Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 52 ++++++++++++++++++----------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index 36dbce5..19ffd31 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -56,7 +56,9 @@ def _build_ldpc_decoders(det_model): decoders = {} decoders[_uf] = (UnionFindDecoder(H, uf_method="peeling"), L_dense) decoders[_bp] = ( - BpDecoder(H, error_channel=priors, bp_method="product_sum", max_iter=10, schedule="parallel"), + BpDecoder( + H, error_channel=priors, bp_method="product_sum", max_iter=10, schedule="parallel" + ), L_dense, ) decoders[_bplsd] = ( @@ -84,9 +86,8 @@ def _decode_ldpc_batch(decoder, L_dense, syndromes_np): for i in range(B): correction = decoder.decode(syndromes_np[i]) obs[i] = ( - int((L_dense @ correction).item() % 2) - if L_dense.shape[0] == 1 - else int((L_dense @ correction)[0] % 2) + int((L_dense @ correction).item() % + 2) if L_dense.shape[0] == 1 else int((L_dense @ correction)[0] % 2) ) return obs @@ -155,10 +156,16 @@ def decoder_ablation_study(model, device, dist, cfg): assert num_obs == 1 # DEM and matchers from Stim circuit (includes boundary detectors) - det_model = circuit.detector_error_model(decompose_errors=True, approximate_disjoint_errors=True) + det_model = circuit.detector_error_model( + decompose_errors=True, approximate_disjoint_errors=True + ) if _PYMATCHING_SUPPORTS_CORRELATIONS: - matcher_corr = pymatching.Matching.from_detector_error_model(det_model, enable_correlations=True) - matcher_uncorr = pymatching.Matching.from_detector_error_model(det_model, enable_correlations=False) + matcher_corr = pymatching.Matching.from_detector_error_model( + det_model, enable_correlations=True + ) + matcher_uncorr = pymatching.Matching.from_detector_error_model( + det_model, enable_correlations=False + ) else: matcher_corr = pymatching.Matching.from_detector_error_model(det_model) matcher_uncorr = matcher_corr @@ -249,8 +256,10 @@ def decoder_ablation_study(model, device, dist, cfg): # Collate batch from dataset items _t0 = _time.perf_counter() items = [test_dataset[i] for i in range(start, end)] - x_syn_diff = torch.stack([it["x_syn_diff"] for it in items]).to(device=device, dtype=torch.int32) - z_syn_diff = torch.stack([it["z_syn_diff"] for it in items]).to(device=device, dtype=torch.int32) + x_syn_diff = torch.stack([it["x_syn_diff"] for it in items] + ).to(device=device, dtype=torch.int32) + z_syn_diff = torch.stack([it["z_syn_diff"] for it in items] + ).to(device=device, dtype=torch.int32) trainX = torch.stack([it["trainX"] for it in items]).to(device=device) _timing["collate"] += _time.perf_counter() - _t0 @@ -312,13 +321,11 @@ def decoder_ablation_study(model, device, dist, cfg): # Logical frame from data corrections if basis == "X": - pre_L_t = torch.einsum( - "ld,bdt->blt", Lx.to(torch.float32), z_flat.to(torch.float32) - ).remainder_(2).to(torch.int32) + pre_L_t = torch.einsum("ld,bdt->blt", Lx.to(torch.float32), + z_flat.to(torch.float32)).remainder_(2).to(torch.int32) else: - pre_L_t = torch.einsum( - "ld,bdt->blt", Lz.to(torch.float32), x_flat.to(torch.float32) - ).remainder_(2).to(torch.int32) + pre_L_t = torch.einsum("ld,bdt->blt", Lz.to(torch.float32), + x_flat.to(torch.float32)).remainder_(2).to(torch.int32) pre_L = pre_L_t.sum(dim=2).remainder_(2).view(-1) # Build residual detectors (matching logical_error_rate.py exactly) @@ -508,9 +515,7 @@ def decoder_ablation_study(model, device, dist, cfg): "residual_weights": all_residual_weights, "weight_bucket_stats": weight_bucket_stats, "agreement_count": n_all_agree, - } - if dist.rank == 0 - else {} + } if dist.rank == 0 else {} ) @@ -538,8 +543,15 @@ def _plot_residual_weight_histogram(weights, basis, cfg): n_zero = int((weights_arr == 0).sum()) pct_zero = n_zero / max(1, len(weights_arr)) * 100 ax.axvline(x=0, color="red", linestyle="--", alpha=0.5) - ax.text(0.5, 0.95, f"Weight-0: {pct_zero:.1f}%", transform=ax.transAxes, - fontsize=11, verticalalignment="top", color="red") + ax.text( + 0.5, + 0.95, + f"Weight-0: {pct_zero:.1f}%", + transform=ax.transAxes, + fontsize=11, + verticalalignment="top", + color="red" + ) plt.tight_layout() output_dir = os.path.join(cfg.output, "plots") os.makedirs(output_dir, exist_ok=True) From 834da192ee39cc555e46ae3b16ccb6de30213335 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Wed, 18 Mar 2026 14:31:59 -0700 Subject: [PATCH 08/20] formatting Signed-off-by: Sachin Pisal --- code/tests/test_failure_analysis.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/code/tests/test_failure_analysis.py b/code/tests/test_failure_analysis.py index dc27ce2..7bef3fd 100644 --- a/code/tests/test_failure_analysis.py +++ b/code/tests/test_failure_analysis.py @@ -198,20 +198,23 @@ def _run(self, basis): from evaluation.failure_analysis import decoder_ablation_study real_ds = self._build_datapipe(basis) with tempfile.TemporaryDirectory() as tmpdir: - cfg = _make_cfg(tmpdir, distance=self._D, n_rounds=self._T, basis=basis, - n_samples=self._N) + cfg = _make_cfg( + tmpdir, distance=self._D, n_rounds=self._T, basis=basis, n_samples=self._N + ) with patch("data.factory.DatapipeFactory") as mock_factory: mock_factory.create_datapipe_inference.return_value = real_ds - result = decoder_ablation_study( - _ZeroModel(), _FakeDist.device, _FakeDist(), cfg - ) + result = decoder_ablation_study(_ZeroModel(), _FakeDist.device, _FakeDist(), cfg) return result def test_return_keys_present(self): result = self._run("X") for key in ( - "total_samples", "baseline_errors", "decoder_errors", - "residual_weights", "weight_bucket_stats", "agreement_count", + "total_samples", + "baseline_errors", + "decoder_errors", + "residual_weights", + "weight_bucket_stats", + "agreement_count", ): self.assertIn(key, result, f"Missing key in result: {key}") From a32b0c313f885959a6408dba559daee45c6f0834 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Wed, 18 Mar 2026 16:39:30 -0700 Subject: [PATCH 09/20] adding CUDA-Q nv-qldpc-decoder from internal repo and tests Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 260 +++++++++++++++++++++++++++- code/tests/test_failure_analysis.py | 248 +++++++++++++++++++++++++- 2 files changed, 499 insertions(+), 9 deletions(-) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index 19ffd31..2aedf64 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -32,6 +32,161 @@ DECODER_NAMES = ("No-op",) + LDPC_DECODER_NAMES + ("Uncorr-PM", "Corr-PM") +def _build_cudaq_decoders(det_model): + """ + Build GPU-accelerated cudaq-qec nv-qldpc-decoder instances from a Stim DEM. + Returns dict of {name: (decoder, L_dense)} mirroring _build_ldpc_decoders. + + Decoder variants: + - "cudaq-BP": sum-product BP (bp_method=0), no OSD + - "cudaq-MinSum": min-sum BP (bp_method=1), no OSD + - "cudaq-BP+OSD-0": sum-product BP + OSD order 0 + - "cudaq-BP+OSD-7": sum-product BP + OSD order 7 + - "cudaq-MemBP": min-sum+mem BP (bp_method=2, uniform gamma) + - "cudaq-MemBP+OSD": min-sum+mem BP + OSD order 7 + - "cudaq-RelayBP": sequential relay (composition=1, bp_method=3) + """ + import cudaq_qec + import scipy.sparse as sp + from beliefmatching.belief_matching import detector_error_model_to_check_matrices + + matrices = detector_error_model_to_check_matrices(det_model) + H_sparse = sp.csc_matrix(matrices.check_matrix) + L = matrices.observables_matrix + priors = np.array(matrices.priors, dtype=np.float64) + L_dense = np.asarray(L.toarray(), dtype=np.uint8) + + # cudaq-qec expects a dense row-major (C-contiguous) H matrix (uint8) + H_dense = np.ascontiguousarray(H_sparse.toarray(), dtype=np.uint8) + + # Per-edge priors clamped for numerical stability + priors_list = np.clip(priors, 1e-9, 1.0 - 1e-9).tolist() + + # Enable num_iter reporting in opt_results for all decoders + opt_res = {"num_iter": True} + + # max_iterations=50 for standard BP/MinSum/OSD + bp_kwargs = dict(max_iterations=50, error_rate_vec=priors_list, opt_results=opt_res) + # max_iterations=100 for MemBP and RelayBP (need more iterations to converge) + mem_kwargs = dict(max_iterations=100, error_rate_vec=priors_list, opt_results=opt_res) + + decoders = {} + + # --- Standard BP variants (max_iterations=10) --- + # Sum-product BP (no OSD) + decoders["cudaq-BP"] = ( + cudaq_qec.get_decoder("nv-qldpc-decoder", H_dense, bp_method=0, use_osd=0, **bp_kwargs), + L_dense, + ) + # Min-sum BP (no OSD) + decoders["cudaq-MinSum"] = ( + cudaq_qec.get_decoder("nv-qldpc-decoder", H_dense, bp_method=1, use_osd=0, **bp_kwargs), + L_dense, + ) + # Sum-product BP + OSD-0 + decoders["cudaq-BP+OSD-0"] = ( + cudaq_qec.get_decoder( + "nv-qldpc-decoder", H_dense, bp_method=0, use_osd=1, osd_order=0, **bp_kwargs + ), + L_dense, + ) + # Sum-product BP + OSD-7 + decoders["cudaq-BP+OSD-7"] = ( + cudaq_qec.get_decoder( + "nv-qldpc-decoder", H_dense, bp_method=0, use_osd=1, osd_order=7, **bp_kwargs + ), + L_dense, + ) + + # --- Memory BP variants (max_iterations=100) --- + try: + decoders["cudaq-MemBP"] = ( + cudaq_qec.get_decoder( + "nv-qldpc-decoder", + H_dense, + bp_method=2, + use_sparsity=True, + gamma0=0.5, + use_osd=0, + **mem_kwargs + ), + L_dense, + ) + decoders["cudaq-MemBP+OSD"] = ( + cudaq_qec.get_decoder( + "nv-qldpc-decoder", + H_dense, + bp_method=2, + use_sparsity=True, + gamma0=0.5, + use_osd=1, + osd_order=7, + **mem_kwargs + ), + L_dense, + ) + except Exception as e: + import warnings + warnings.warn(f"cudaq-qec MemBP unavailable: {e}") + + # --- RelayBP (max_iterations=100) --- + # composition=1 (sequential relay), bp_method=3 (min-sum+dmem) + # gamma_dist=[-0.254, 0.985] optimized for surface codes + try: + srelay_cfg = { + "pre_iter": 10, + "num_sets": 5, + "stopping_criterion": "FirstConv", + } + # Note: opt_results num_iter not supported for composition=1 per docs + relay_kwargs = dict(max_iterations=100, error_rate_vec=priors_list) + decoders["cudaq-RelayBP"] = ( + cudaq_qec.get_decoder( + "nv-qldpc-decoder", + H_dense, + composition=1, + bp_method=3, + use_sparsity=True, + gamma0=0.5, + gamma_dist=[-0.254, 0.985], + srelay_config=srelay_cfg, + **relay_kwargs + ), + L_dense, + ) + except Exception as e: + import warnings + warnings.warn(f"cudaq-qec RelayBP unavailable: {e}") + + return decoders + + +def _decode_cudaq_batch(decoder, L_dense, syndromes_np): + """ + Decode a batch of syndromes with a cudaq-qec nv-qldpc-decoder (single-shot loop). + Returns (obs, stats) where: + - obs: observable predictions as np.ndarray of shape (B,) + - stats: dict with per-sample convergence flags, iteration counts + The decoder.decode() takes list[float] and returns DecoderResult with .result (list[float]). + """ + B = syndromes_np.shape[0] + obs = np.zeros(B, dtype=np.uint8) + converged_flags = np.zeros(B, dtype=bool) + iter_counts = np.zeros(B, dtype=np.int32) + for i in range(B): + syndrome_list = syndromes_np[i].astype(np.float64).tolist() + result = decoder.decode(syndrome_list) + correction = np.array(result.result, dtype=np.uint8) + obs[i] = int((L_dense @ correction).item() % + 2) if L_dense.shape[0] == 1 else int((L_dense @ correction)[0] % 2) + converged_flags[i] = result.converged + # Collect iteration count if available via opt_results + opt = getattr(result, 'opt_results', None) + if opt and isinstance(opt, dict) and 'num_iter' in opt: + iter_counts[i] = opt['num_iter'] + return obs, {"converged_flags": converged_flags, "iter_counts": iter_counts} + + def _build_ldpc_decoders(det_model): """ Convert a Stim DetectorErrorModel to an H matrix and build ldpc decoders. @@ -173,6 +328,16 @@ def decoder_ablation_study(model, device, dist, cfg): # Build ldpc decoders from the same DEM (with boundary detectors) ldpc_decoders = _build_ldpc_decoders(det_model) + # Build cudaq-qec GPU-accelerated decoders + cudaq_decoders = {} + try: + cudaq_decoders = _build_cudaq_decoders(det_model) + if dist.rank == 0: + print(f"[Decoder Ablation] cudaq-qec decoders loaded: {list(cudaq_decoders.keys())}") + except Exception as e: + if dist.rank == 0: + print(f"[Decoder Ablation] cudaq-qec decoders unavailable: {e}") + # Stim baseline detectors and ground truth observables stim_dets = np.asarray(test_dataset.dets_and_obs[:, :-num_obs], dtype=np.uint8) assert stim_dets.shape[1] == det_model.num_detectors, \ @@ -218,9 +383,9 @@ def decoder_ablation_study(model, device, dist, cfg): f"[Decoder Ablation] DEM detectors: {det_model.num_detectors}" f" (incl. {num_boundary_dets} boundary)" ) + cudaq_names_str = ", ".join(cudaq_decoders.keys()) if cudaq_decoders else "(none)" print( - f"[Decoder Ablation] Decoders: No-op, Union-Find, BP+LSD-0," - f" Uncorr PM, Corr PM, + Baseline PM" + f"[Decoder Ablation] Decoders: No-op, Union-Find, BP+LSD-0, Uncorr PM, Corr PM, {cudaq_names_str}, + Baseline PM" ) batch_size = int(getattr(cfg.test.dataloader, "batch_size", 2048)) @@ -228,6 +393,10 @@ def decoder_ablation_study(model, device, dist, cfg): num_batches = (N + batch_size - 1) // batch_size decoder_names = list(DECODER_NAMES) + # Append cudaq decoder names dynamically + cudaq_decoder_names = sorted(cudaq_decoders.keys()) + decoder_names.extend(cudaq_decoder_names) + total_scanned = 0 baseline_errors = 0 decoder_errors = {name: 0 for name in decoder_names} @@ -247,6 +416,10 @@ def decoder_ablation_study(model, device, dist, cfg): "corr_pm": 0.0, "bookkeeping": 0.0, } + _cudaq_stats = {} + for cn in cudaq_decoder_names: + _timing[f"{cn}_decode"] = 0.0 + _cudaq_stats[cn] = {"converged_flags": [], "iter_counts": [], "error_flags": []} for batch_idx in range(num_batches): start = batch_idx * batch_size @@ -401,6 +574,23 @@ def decoder_ablation_study(model, device, dist, cfg): corr_final = (pre_L_np + corr_pred) % 2 _timing["corr_pm"] += _time.perf_counter() - _t0 + # 7. cudaq-qec GPU-accelerated decoders + cudaq_finals = {} + for cn in cudaq_decoder_names: + _t0 = _time.perf_counter() + cdec, cL = cudaq_decoders[cn] + c_obs, c_stats = _decode_cudaq_batch(cdec, cL, residual_np) + c_final = (pre_L_np + c_obs) % 2 + cudaq_finals[cn] = c_final + _timing[f"{cn}_decode"] += _time.perf_counter() - _t0 + # Accumulate per-sample convergence, iteration, and error stats + conv_flags = c_stats["converged_flags"] + iters = c_stats["iter_counts"] + fails = (c_final != gt_obs_np) + _cudaq_stats[cn]["converged_flags"].append(conv_flags) + _cudaq_stats[cn]["iter_counts"].append(iters) + _cudaq_stats[cn]["error_flags"].append(fails) + _t0 = _time.perf_counter() all_finals = { DECODER_NAMES[0]: noop_final, @@ -410,6 +600,7 @@ def decoder_ablation_study(model, device, dist, cfg): DECODER_NAMES[4]: uncorr_final, DECODER_NAMES[5]: corr_final, } + all_finals.update(cudaq_finals) for name in decoder_names: fails = all_finals[name] != gt_obs_np @@ -469,6 +660,61 @@ def decoder_ablation_study(model, device, dist, cfg): ler = decoder_errors[name] / max(1, total_scanned) print(f" {name:<25s} LER = {ler:.6f} ({decoder_errors[name]} errors)") + # cudaq decoder convergence and iteration stats + if _cudaq_stats: + print(f"\n--- cudaq-qec BP Convergence & Iteration Breakdown ---") + print( + f" {'Decoder':<20s} {'Conv%':>7s} {'AvgIt':>6s} " + f"{'Conv.It':>8s} {'Conv.LER':>9s} {'Conv.Err':>9s} " + f"{'!Conv.It':>8s} {'!Conv.LER':>10s} {'!Conv.Err':>10s}" + ) + for cn in cudaq_decoder_names: + st = _cudaq_stats[cn] + conv_all = np.concatenate(st["converged_flags"]) + iters_all = np.concatenate(st["iter_counts"]) + errs_all = np.concatenate(st["error_flags"]) + N = len(conv_all) + n_conv = int(conv_all.sum()) + n_noconv = N - n_conv + conv_pct = n_conv / max(1, N) * 100 + has_iters = iters_all.sum() > 0 + + # Converged subset + if n_conv > 0 and has_iters: + conv_avg_it = iters_all[conv_all].mean() + conv_ler = errs_all[conv_all].mean() + conv_errs = int(errs_all[conv_all].sum()) + else: + conv_avg_it = conv_ler = 0.0 + conv_errs = 0 + + # Non-converged subset + if n_noconv > 0 and has_iters: + noconv_avg_it = iters_all[~conv_all].mean() + noconv_ler = errs_all[~conv_all].mean() + noconv_errs = int(errs_all[~conv_all].sum()) + else: + noconv_avg_it = noconv_ler = 0.0 + noconv_errs = 0 + + if has_iters: + avg_it_str = f"{iters_all.mean():5.1f}" + conv_it_str = f"{conv_avg_it:7.1f}" + noconv_it_str = f"{noconv_avg_it:7.1f}" if n_noconv > 0 else " N/A" + else: + avg_it_str = " N/A" + conv_it_str = " N/A" + noconv_it_str = " N/A" + + noconv_ler_str = f"{noconv_ler:9.6f}" if n_noconv > 0 else " N/A" + noconv_err_str = f"{noconv_errs:>9d}" if n_noconv > 0 else " N/A" + + print( + f" {cn:<20s} {conv_pct:>6.1f}% {avg_it_str} " + f"{conv_it_str} {conv_ler:>9.6f} {conv_errs:>9d} " + f"{noconv_it_str} {noconv_ler_str} {noconv_err_str}" + ) + agreement_rate = n_all_agree / max(1, total_scanned) print(f"\n--- Decoder Agreement ---") print( @@ -570,9 +816,13 @@ def _plot_conditional_ler(weight_bucket_stats, decoder_names, basis, cfg): buckets = sorted(weight_bucket_stats.keys()) labels = [f"{w}+" if w == 7 else str(w) for w in buckets] - fig, ax = plt.subplots(figsize=(9, 5)) - colors = ["#999999", "#E24A33", "#348ABD", "#FBC15E", "#8EBA42"] - markers = ["x", "s", "D", "^", "o"] + fig, ax = plt.subplots(figsize=(10, 5)) + colors = [ + "#999999", "#E24A33", "#348ABD", "#FBC15E", "#8EBA42", "#988ED5", "#777B7E", "#76B900", + "#FF6F61", "#2CA02C", "#D62728", "#9467BD", "#17BECF" + ] + markers = ["x", "s", "D", "^", "o", "v", "P", "*", "h", "d", "<", ">", "X"] + for idx, name in enumerate(decoder_names): lers = [] x_pos = [] diff --git a/code/tests/test_failure_analysis.py b/code/tests/test_failure_analysis.py index 7bef3fd..669b2f6 100644 --- a/code/tests/test_failure_analysis.py +++ b/code/tests/test_failure_analysis.py @@ -88,7 +88,7 @@ def forward(self, x): return torch.zeros_like(x) -class _FakeDist: +class _DummyDist: rank = 0 world_size = 1 local_rank = 0 @@ -203,7 +203,7 @@ def _run(self, basis): ) with patch("data.factory.DatapipeFactory") as mock_factory: mock_factory.create_datapipe_inference.return_value = real_ds - result = decoder_ablation_study(_ZeroModel(), _FakeDist.device, _FakeDist(), cfg) + result = decoder_ablation_study(_ZeroModel(), _DummyDist.device, _DummyDist(), cfg) return result def test_return_keys_present(self): @@ -222,10 +222,15 @@ def test_total_samples_matches_dataset_size(self): result = self._run("X") self.assertEqual(result["total_samples"], self._N) - def test_decoder_errors_has_all_six_decoders(self): + def test_decoder_errors_contains_all_base_decoders(self): + # DECODER_NAMES is the fixed set; cudaq decoders may add more keys when available. from evaluation.failure_analysis import DECODER_NAMES result = self._run("X") - self.assertEqual(set(result["decoder_errors"].keys()), set(DECODER_NAMES)) + self.assertTrue( + set(DECODER_NAMES).issubset(set(result["decoder_errors"].keys())), + f"Missing base decoder keys in result: " + f"{set(DECODER_NAMES) - set(result['decoder_errors'].keys())}", + ) def test_residual_weights_length_matches_total_samples(self): result = self._run("X") @@ -242,5 +247,240 @@ def test_z_basis_runs_and_returns_correct_structure(self): self.assertIn("decoder_errors", result) +class _DummyCudaqResult: + """Minimal DecoderResult lookalike returned by a mock cudaq-qec decoder""" + + def __init__(self, correction, converged=True, num_iter=10): + self.result = list(correction.astype(float)) + self.converged = converged + self.opt_results = {"num_iter": num_iter} + + +class _DummyCudaqDecoder: + """Mock cudaq-qec decoder that always returns the zero correction vector""" + + def __init__(self, n_bits): + self._n_bits = n_bits + + def decode(self, syndrome): + return _DummyCudaqResult(np.zeros(self._n_bits, dtype=np.float64)) + + +@_skip_ldpc +class TestDecodeCudaqBatch(unittest.TestCase): + """_decode_cudaq_batch must return correct shape/dtype and collect stats""" + + def setUp(self): + from evaluation.failure_analysis import _decode_cudaq_batch + self._fn = _decode_cudaq_batch + self.det_model = _make_tiny_dem() + self.n_bits = 20 # arbitrary correction vector length + self.n_dets = self.det_model.num_detectors + + def _make_decoder_and_L(self, n_bits=None): + if n_bits is None: + n_bits = self.n_bits + L_dense = np.zeros((1, n_bits), dtype=np.uint8) + decoder = _DummyCudaqDecoder(n_bits) + return decoder, L_dense + + def test_zero_syndrome_gives_zero_observable(self): + B = 4 + decoder, L_dense = self._make_decoder_and_L() + syndromes = np.zeros((B, self.n_dets), dtype=np.uint8) + obs, _ = self._fn(decoder, L_dense, syndromes) + np.testing.assert_array_equal(obs, np.zeros(B, dtype=np.uint8)) + + def test_output_shape_is_batch_size(self): + for B in (1, 5): + decoder, L_dense = self._make_decoder_and_L() + syndromes = np.zeros((B, self.n_dets), dtype=np.uint8) + obs, stats = self._fn(decoder, L_dense, syndromes) + self.assertEqual(obs.shape, (B,)) + self.assertEqual(obs.dtype, np.uint8) + self.assertEqual(stats["converged_flags"].shape, (B,)) + self.assertEqual(stats["iter_counts"].shape, (B,)) + + def test_output_values_are_binary(self): + B = 4 + decoder, L_dense = self._make_decoder_and_L() + syndromes = np.zeros((B, self.n_dets), dtype=np.uint8) + obs, _ = self._fn(decoder, L_dense, syndromes) + self.assertTrue(np.all((obs == 0) | (obs == 1))) + + def test_convergence_flags_collected(self): + B = 3 + decoder, L_dense = self._make_decoder_and_L() + syndromes = np.zeros((B, self.n_dets), dtype=np.uint8) + _, stats = self._fn(decoder, L_dense, syndromes) + self.assertTrue(np.all(stats["converged_flags"])) + + def test_iter_counts_collected(self): + B = 3 + decoder, L_dense = self._make_decoder_and_L() + syndromes = np.zeros((B, self.n_dets), dtype=np.uint8) + _, stats = self._fn(decoder, L_dense, syndromes) + np.testing.assert_array_equal(stats["iter_counts"], np.full(B, 10, dtype=np.int32)) + + def test_multi_observable_uses_first_row(self): + """L_dense with 2 observable rows: result must still be 0/1""" + B = 3 + n_bits = 10 + L_dense = np.zeros((2, n_bits), dtype=np.uint8) + decoder = _DummyCudaqDecoder(n_bits) + syndromes = np.zeros((B, self.n_dets), dtype=np.uint8) + obs, _ = self._fn(decoder, L_dense, syndromes) + self.assertEqual(obs.shape, (B,)) + self.assertTrue(np.all((obs == 0) | (obs == 1))) + + +@_skip_ldpc +class TestBuildCudaqDecoders(unittest.TestCase): + """_build_cudaq_decoders must return correctly keyed entries when cudaq_qec is available""" + + def _make_mock_cudaq_qec(self, n_bits): + """Return a mock cudaq_qec module whose get_decoder always succeeds""" + mock_module = types.ModuleType("cudaq_qec") + mock_module.get_decoder = lambda name, H, **kw: _DummyCudaqDecoder(H.shape[1]) + return mock_module + + def test_standard_bp_decoders_present(self): + from evaluation.failure_analysis import _build_cudaq_decoders + det_model = _make_tiny_dem() + mock_cudaq = self._make_mock_cudaq_qec(n_bits=10) + with patch.dict("sys.modules", {"cudaq_qec": mock_cudaq}): + decoders = _build_cudaq_decoders(det_model) + for name in ("cudaq-BP", "cudaq-MinSum", "cudaq-BP+OSD-0", "cudaq-BP+OSD-7"): + self.assertIn(name, decoders, f"Missing decoder key: {name}") + + def test_each_entry_is_decoder_and_l_dense_pair(self): + from evaluation.failure_analysis import _build_cudaq_decoders + det_model = _make_tiny_dem() + mock_cudaq = self._make_mock_cudaq_qec(n_bits=10) + with patch.dict("sys.modules", {"cudaq_qec": mock_cudaq}): + decoders = _build_cudaq_decoders(det_model) + for name, (dec, L_dense) in decoders.items(): + with self.subTest(decoder=name): + self.assertTrue(hasattr(dec, "decode"), f"{name} has no .decode()") + self.assertIsInstance(L_dense, np.ndarray) + self.assertEqual(L_dense.dtype, np.uint8) + self.assertEqual(L_dense.shape[0], det_model.num_observables) + + def test_l_dense_columns_consistent_across_decoders(self): + from evaluation.failure_analysis import _build_cudaq_decoders + det_model = _make_tiny_dem() + mock_cudaq = self._make_mock_cudaq_qec(n_bits=10) + with patch.dict("sys.modules", {"cudaq_qec": mock_cudaq}): + decoders = _build_cudaq_decoders(det_model) + widths = [v[1].shape[1] for v in decoders.values()] + self.assertEqual(len(set(widths)), 1, "All L_dense must have the same column count") + + def test_gracefully_skips_failing_variants(self): + """MemBP/RelayBP builders that raise must not abort the whole build""" + from evaluation.failure_analysis import _build_cudaq_decoders + det_model = _make_tiny_dem() + call_count = {"n": 0} + + def flaky_get_decoder(name, H, **kw): + call_count["n"] += 1 + bp_method = kw.get("bp_method", 0) + if bp_method in (2, 3): # MemBP / RelayBP + raise RuntimeError("Not supported in this build") + return _DummyCudaqDecoder(H.shape[1]) + + mock_cudaq = types.ModuleType("cudaq_qec") + mock_cudaq.get_decoder = flaky_get_decoder + with patch.dict("sys.modules", {"cudaq_qec": mock_cudaq}): + import warnings + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + decoders = _build_cudaq_decoders(det_model) + # At minimum the 4 standard decoders should be present + self.assertGreaterEqual(len(decoders), 4) + for name in ("cudaq-BP", "cudaq-MinSum", "cudaq-BP+OSD-0", "cudaq-BP+OSD-7"): + self.assertIn(name, decoders) + + +@_skip_ldpc +class TestDecoderAblationStudyWithCudaq(unittest.TestCase): + """ + Smoke test: decoder_ablation_study must include cudaq decoder keys in results + when mocked cudaq decoders are injected + """ + + _D = 3 + _T = 3 + _N = 8 + + def _build_datapipe(self, basis): + from data.datapipe_stim import QCDataPipePreDecoder_Memory_inference + return QCDataPipePreDecoder_Memory_inference( + distance=self._D, + n_rounds=self._T, + num_samples=self._N, + error_mode="circuit_level_surface_custom", + p_error=0.01, + measure_basis=basis, + code_rotation="XV", + ) + + def test_cudaq_decoder_keys_appear_in_results_when_available(self): + from evaluation.failure_analysis import decoder_ablation_study, DECODER_NAMES + real_ds = self._build_datapipe("X") + + # Build a dummy cudaq decoder dict that matches what _build_cudaq_decoders returns + from beliefmatching.belief_matching import detector_error_model_to_check_matrices + det_model = _make_tiny_dem(distance=self._D, n_rounds=self._T) + matrices = detector_error_model_to_check_matrices(det_model) + import scipy.sparse as sp + L_dense = np.asarray(sp.csc_matrix(matrices.observables_matrix).toarray(), dtype=np.uint8) + n_bits = L_dense.shape[1] + dummy_cudaq_decoders = { + "cudaq-BP": (_DummyCudaqDecoder(n_bits), L_dense), + "cudaq-MinSum": (_DummyCudaqDecoder(n_bits), L_dense), + } + + with tempfile.TemporaryDirectory() as tmpdir: + cfg = _make_cfg( + tmpdir, distance=self._D, n_rounds=self._T, basis="X", n_samples=self._N + ) + with patch("data.factory.DatapipeFactory") as mock_factory, \ + patch("evaluation.failure_analysis._build_cudaq_decoders", + return_value=dummy_cudaq_decoders): + mock_factory.create_datapipe_inference.return_value = real_ds + result = decoder_ablation_study(_ZeroModel(), _DummyDist.device, _DummyDist(), cfg) + + # All base decoder names must still be present + self.assertTrue(set(DECODER_NAMES).issubset(set(result["decoder_errors"].keys()))) + # Injected cudaq keys must also appear + for name in dummy_cudaq_decoders: + self.assertIn(name, result["decoder_errors"], f"Missing cudaq key: {name}") + + def test_cudaq_error_counts_are_non_negative(self): + from evaluation.failure_analysis import decoder_ablation_study + real_ds = self._build_datapipe("X") + + from beliefmatching.belief_matching import detector_error_model_to_check_matrices + import scipy.sparse as sp + det_model = _make_tiny_dem(distance=self._D, n_rounds=self._T) + matrices = detector_error_model_to_check_matrices(det_model) + L_dense = np.asarray(sp.csc_matrix(matrices.observables_matrix).toarray(), dtype=np.uint8) + n_bits = L_dense.shape[1] + dummy_cudaq_decoders = {"cudaq-BP": (_DummyCudaqDecoder(n_bits), L_dense)} + + with tempfile.TemporaryDirectory() as tmpdir: + cfg = _make_cfg( + tmpdir, distance=self._D, n_rounds=self._T, basis="X", n_samples=self._N + ) + with patch("data.factory.DatapipeFactory") as mock_factory, \ + patch("evaluation.failure_analysis._build_cudaq_decoders", + return_value=dummy_cudaq_decoders): + mock_factory.create_datapipe_inference.return_value = real_ds + result = decoder_ablation_study(_ZeroModel(), _DummyDist.device, _DummyDist(), cfg) + + self.assertGreaterEqual(result["decoder_errors"]["cudaq-BP"], 0) + self.assertLessEqual(result["decoder_errors"]["cudaq-BP"], result["total_samples"]) + + if __name__ == "__main__": unittest.main() From 604e539bb7d40e6786b8d7c04bd4a07205e96c49 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Thu, 19 Mar 2026 14:30:33 -0700 Subject: [PATCH 10/20] removing _PYMATCHING_SUPPORTS_CORRELATIONS and inspect Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 17 ++++++----------- code/evaluation/logical_error_rate.py | 13 +------------ 2 files changed, 7 insertions(+), 23 deletions(-) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index 2aedf64..413a9db 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -20,7 +20,6 @@ from evaluation.logical_error_rate import ( _build_stab_maps, _decode_batch, - _PYMATCHING_SUPPORTS_CORRELATIONS, map_grid_to_stabilizer_tensor, sample_predictions, ) @@ -314,16 +313,12 @@ def decoder_ablation_study(model, device, dist, cfg): det_model = circuit.detector_error_model( decompose_errors=True, approximate_disjoint_errors=True ) - if _PYMATCHING_SUPPORTS_CORRELATIONS: - matcher_corr = pymatching.Matching.from_detector_error_model( - det_model, enable_correlations=True - ) - matcher_uncorr = pymatching.Matching.from_detector_error_model( - det_model, enable_correlations=False - ) - else: - matcher_corr = pymatching.Matching.from_detector_error_model(det_model) - matcher_uncorr = matcher_corr + matcher_corr = pymatching.Matching.from_detector_error_model( + det_model, enable_correlations=True + ) + matcher_uncorr = pymatching.Matching.from_detector_error_model( + det_model, enable_correlations=False + ) # Build ldpc decoders from the same DEM (with boundary detectors) ldpc_decoders = _build_ldpc_decoders(det_model) diff --git a/code/evaluation/logical_error_rate.py b/code/evaluation/logical_error_rate.py index 77c1d46..8b81522 100644 --- a/code/evaluation/logical_error_rate.py +++ b/code/evaluation/logical_error_rate.py @@ -8,8 +8,6 @@ # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. -import inspect - import pymatching import numpy as np import torch @@ -19,17 +17,8 @@ from enum import IntEnum from typing import Optional -_PYMATCHING_SUPPORTS_CORRELATIONS = "enable_correlations" in inspect.signature( - pymatching.Matching.from_detector_error_model -).parameters - - def _decode_batch(matcher, detectors, enable_correlated): - """Wrapper for decode_batch that handles older pymatching versions.""" - if _PYMATCHING_SUPPORTS_CORRELATIONS: - return matcher.decode_batch(detectors, enable_correlations=enable_correlated) - else: - return matcher.decode_batch(detectors) + return matcher.decode_batch(detectors, enable_correlations=enable_correlated) class OnnxWorkflow(IntEnum): From c86eb80703cb8da78c308d34f953a213e8f66eba Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Thu, 19 Mar 2026 14:32:23 -0700 Subject: [PATCH 11/20] formatting Signed-off-by: Sachin Pisal --- code/evaluation/logical_error_rate.py | 1 + 1 file changed, 1 insertion(+) diff --git a/code/evaluation/logical_error_rate.py b/code/evaluation/logical_error_rate.py index 8b81522..4bbef46 100644 --- a/code/evaluation/logical_error_rate.py +++ b/code/evaluation/logical_error_rate.py @@ -17,6 +17,7 @@ from enum import IntEnum from typing import Optional + def _decode_batch(matcher, detectors, enable_correlated): return matcher.decode_batch(detectors, enable_correlations=enable_correlated) From cbb581b401c06310c20f71d666cdb88c1e891747 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Thu, 19 Mar 2026 16:09:07 -0700 Subject: [PATCH 12/20] removing unconditional imports Signed-off-by: Sachin Pisal --- code/tests/test_failure_analysis.py | 18 +++--------------- 1 file changed, 3 insertions(+), 15 deletions(-) diff --git a/code/tests/test_failure_analysis.py b/code/tests/test_failure_analysis.py index 669b2f6..1dfb347 100644 --- a/code/tests/test_failure_analysis.py +++ b/code/tests/test_failure_analysis.py @@ -21,15 +21,9 @@ if str(_repo_code) not in sys.path: sys.path.insert(0, str(_repo_code)) -try: - import ldpc - import beliefmatching - import scipy - _HAS_LDPC_DEPS = True -except ImportError: - _HAS_LDPC_DEPS = False - -_skip_ldpc = unittest.skipUnless(_HAS_LDPC_DEPS, "ldpc/beliefmatching/scipy not installed") +import ldpc +import beliefmatching +import scipy def _make_tiny_dem(distance=3, n_rounds=3, basis="X", code_rotation="XV"): @@ -95,7 +89,6 @@ class _DummyDist: device = torch.device("cpu") -@_skip_ldpc class TestBuildLdpcDecoders(unittest.TestCase): """_build_ldpc_decoders must return correctly keyed decoder objects with consistent shapes.""" @@ -124,7 +117,6 @@ def test_l_dense_columns_consistent_across_decoders(self): self.assertEqual(len(set(widths)), 1, "All L_dense must have the same column count") -@_skip_ldpc class TestDecodeLdpcBatch(unittest.TestCase): """_decode_ldpc_batch must return correct shape/dtype; zero syndrome decodes to 0.""" @@ -171,7 +163,6 @@ def test_output_values_are_binary(self): ) -@_skip_ldpc class TestDecoderAblationStudy(unittest.TestCase): """ Smoke test: decoder_ablation_study must complete, return expected keys, @@ -266,7 +257,6 @@ def decode(self, syndrome): return _DummyCudaqResult(np.zeros(self._n_bits, dtype=np.float64)) -@_skip_ldpc class TestDecodeCudaqBatch(unittest.TestCase): """_decode_cudaq_batch must return correct shape/dtype and collect stats""" @@ -334,7 +324,6 @@ def test_multi_observable_uses_first_row(self): self.assertTrue(np.all((obs == 0) | (obs == 1))) -@_skip_ldpc class TestBuildCudaqDecoders(unittest.TestCase): """_build_cudaq_decoders must return correctly keyed entries when cudaq_qec is available""" @@ -401,7 +390,6 @@ def flaky_get_decoder(name, H, **kw): self.assertIn(name, decoders) -@_skip_ldpc class TestDecoderAblationStudyWithCudaq(unittest.TestCase): """ Smoke test: decoder_ablation_study must include cudaq decoder keys in results From 8264cbbc0c9d09b55bec193a7e6e695e4b693086 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Thu, 19 Mar 2026 16:52:52 -0700 Subject: [PATCH 13/20] adding a test to check predecoder actually modifies residual syndromes Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 3 +++ code/tests/test_failure_analysis.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index 413a9db..7aef2ee 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -396,6 +396,7 @@ def decoder_ablation_study(model, device, dist, cfg): baseline_errors = 0 decoder_errors = {name: 0 for name in decoder_names} all_residual_weights = [] + all_baseline_weights = [] weight_bucket_stats = {} n_all_agree = 0 @@ -438,6 +439,7 @@ def decoder_ablation_study(model, device, dist, cfg): # --- Baseline: Stim detectors (with boundary dets), Stim ground truth --- baseline_detectors_batch = stim_dets[start:end] gt_obs_batch = stim_obs[start:end] + all_baseline_weights.extend(baseline_detectors_batch.sum(axis=1).tolist()) _t0 = _time.perf_counter() baseline_pred_obs = _decode_batch(matcher_corr, baseline_detectors_batch, True) @@ -754,6 +756,7 @@ def decoder_ablation_study(model, device, dist, cfg): "baseline_errors": baseline_errors, "decoder_errors": decoder_errors, "residual_weights": all_residual_weights, + "baseline_weights": all_baseline_weights, "weight_bucket_stats": weight_bucket_stats, "agreement_count": n_all_agree, } if dist.rank == 0 else {} diff --git a/code/tests/test_failure_analysis.py b/code/tests/test_failure_analysis.py index 1dfb347..72fb09c 100644 --- a/code/tests/test_failure_analysis.py +++ b/code/tests/test_failure_analysis.py @@ -232,6 +232,25 @@ def test_agreement_count_within_bounds(self): self.assertGreaterEqual(result["agreement_count"], 0) self.assertLessEqual(result["agreement_count"], result["total_samples"]) + def test_predecoder_changes_residual_syndromes(self): + """ + Residual syndromes must differ from the baseline Stim syndromes when the + pre-decoder applies non-trivial corrections. + """ + result = self._run("X") + self.assertIn("baseline_weights", result) + self.assertIn("residual_weights", result) + + self.assertEqual(len(result["baseline_weights"]), result["total_samples"]) + self.assertEqual(len(result["residual_weights"]), result["total_samples"]) + + self.assertNotEqual( + result["residual_weights"], + result["baseline_weights"], + "Pre-decoder with all-ones corrections produced identical residual " + "and baseline syndrome weights - transformation is likely a no-op.", + ) + def test_z_basis_runs_and_returns_correct_structure(self): result = self._run("Z") self.assertEqual(result["total_samples"], self._N) From 2883c4878d58dfcfa1116c7ed0da4a22b6c86aca Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Thu, 19 Mar 2026 17:05:11 -0700 Subject: [PATCH 14/20] adding comments for _decode_ldpc_batch Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index 7aef2ee..316a641 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -238,7 +238,10 @@ def _decode_ldpc_batch(decoder, L_dense, syndromes_np): B = syndromes_np.shape[0] obs = np.zeros(B, dtype=np.uint8) for i in range(B): + # Get the most-likely error configuration from the decoder for this syndrome. correction = decoder.decode(syndromes_np[i]) + # Project the correction onto the logical observable via L_dense (mod 2). + # L_dense has shape (num_obs, num_errors); the first observable row is used. obs[i] = ( int((L_dense @ correction).item() % 2) if L_dense.shape[0] == 1 else int((L_dense @ correction)[0] % 2) From a261d4db2bffb195f9fd496e1d0aa6ca8e8d3176 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Thu, 19 Mar 2026 19:53:26 -0700 Subject: [PATCH 15/20] refactoring decoder_ablation_study Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 855 ++++++++++++++++------------ 1 file changed, 505 insertions(+), 350 deletions(-) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index 316a641..f40816d 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -249,6 +249,412 @@ def _decode_ldpc_batch(decoder, L_dense, syndromes_np): return obs +def _build_all_decoders(det_model, dist): + """Build all decoders (PyMatching, LDPC, cudaq-qec) from the DEM""" + import pymatching + matcher_corr = pymatching.Matching.from_detector_error_model( + det_model, enable_correlations=True + ) + matcher_uncorr = pymatching.Matching.from_detector_error_model( + det_model, enable_correlations=False + ) + ldpc_decoders = _build_ldpc_decoders(det_model) + cudaq_decoders = {} + try: + cudaq_decoders = _build_cudaq_decoders(det_model) + if dist.rank == 0: + print(f"[Decoder Ablation] cudaq-qec decoders loaded: {list(cudaq_decoders.keys())}") + except Exception as e: + if dist.rank == 0: + print(f"[Decoder Ablation] cudaq-qec decoders unavailable: {e}") + return matcher_corr, matcher_uncorr, ldpc_decoders, cudaq_decoders + + +def _build_logical_operators(D, code_rotation, device): + """Build parity-check index tensors and logical operator masks for the surface code""" + maps = _build_stab_maps(D, code_rotation) + Hx_idx = maps["Hx_idx"].to(device=device, dtype=torch.long) + Hz_idx = maps["Hz_idx"].to(device=device, dtype=torch.long) + Hx_mask = maps["Hx_mask"].to(device=device, dtype=torch.bool) + Hz_mask = maps["Hz_mask"].to(device=device, dtype=torch.bool) + stab_indices_x = maps["stab_x"].to(device=device, dtype=torch.long) + stab_indices_z = maps["stab_z"].to(device=device, dtype=torch.long) + Kx, Kz = maps["Kx"], maps["Kz"] + D2 = D * D + if code_rotation.upper() in ("XV", "ZH"): + Lx = torch.zeros((1, D2), dtype=torch.int32, device=device) + Lx[0, :D] = 1 + Lz = torch.zeros((1, D2), dtype=torch.int32, device=device) + Lz[0, ::D] = 1 + else: + Lx = torch.zeros((1, D2), dtype=torch.int32, device=device) + Lx[0, ::D] = 1 + Lz = torch.zeros((1, D2), dtype=torch.int32, device=device) + Lz[0, :D] = 1 + return Hx_idx, Hz_idx, Hx_mask, Hz_mask, stab_indices_x, stab_indices_z, Kx, Kz, Lx, Lz + + +def _model_forward_and_residual( + model, + trainX, + x_syn_diff, + z_syn_diff, + basis, + B, + D2, + T, + Hx_idx, + Hz_idx, + Hx_mask, + Hz_mask, + Kx, + Kz, + stab_indices_x, + stab_indices_z, + Lx, + Lz, + th_data, + th_syn, + sampling_mode, + temperature_data, + temperature_syn, + cfg, + device, + num_boundary_dets, + baseline_detectors_batch, + det_model, +): + """ + Run the pre-decoder model on one batch and build the residual syndrome. + + Returns: + residual_np: (B, num_detectors) uint8 array - residual syndromes for global decoders. + pre_L_np: (B,) int64 array - logical frame contribution from data corrections. + """ + with torch.amp.autocast( + device_type=device.type if hasattr(device, "type") else "cuda", + enabled=getattr(cfg, "enable_fp16", False), + ): + logits = model(trainX) + z_data_corr = sample_predictions(logits[:, 0], th_data, sampling_mode, temperature_data) + x_data_corr = sample_predictions(logits[:, 1], th_data, sampling_mode, temperature_data) + syn_x_grid = sample_predictions(logits[:, 2], th_syn, sampling_mode, temperature_syn) + syn_z_grid = sample_predictions(logits[:, 3], th_syn, sampling_mode, temperature_syn) + + z_flat = z_data_corr.permute(0, 2, 3, 1).contiguous().view(B, D2, T).to(torch.int32) + x_flat = x_data_corr.permute(0, 2, 3, 1).contiguous().view(B, D2, T).to(torch.int32) + z_exp = z_flat.unsqueeze(2).expand(B, D2, Kx, T) + hx_idx_e = Hx_idx.clamp_min(0).view(1, -1, Kx, 1).expand(B, -1, -1, T) + g_x = z_exp.gather(1, hx_idx_e) + m_x = Hx_mask.view(1, -1, Kx, 1).expand_as(g_x) + S_X = (g_x.masked_fill(~m_x, 0).sum(dim=2) & 1) + x_exp = x_flat.unsqueeze(2).expand(B, D2, Kz, T) + hz_idx_e = Hz_idx.clamp_min(0).view(1, -1, Kz, 1).expand(B, -1, -1, T) + g_z = x_exp.gather(1, hz_idx_e) + m_z = Hz_mask.view(1, -1, Kz, 1).expand_as(g_z) + S_Z = (g_z.masked_fill(~m_z, 0).sum(dim=2) & 1) + + syn_x_flat = map_grid_to_stabilizer_tensor(syn_x_grid, stab_indices_x).to(torch.int32) + syn_z_flat = map_grid_to_stabilizer_tensor(syn_z_grid, stab_indices_z).to(torch.int32) + R_X = torch.empty_like(x_syn_diff, dtype=torch.uint8) + R_X[:, :, 0] = (x_syn_diff[:, :, 0] + syn_x_flat[:, :, 0] + S_X[:, :, 0]) & 1 + if T > 1: + R_X[:, :, 1:] = ( + x_syn_diff[:, :, 1:] + syn_x_flat[:, :, 1:] + syn_x_flat[:, :, :-1] + S_X[:, :, 1:] + ) & 1 + R_Z = torch.empty_like(z_syn_diff, dtype=torch.uint8) + R_Z[:, :, 0] = (z_syn_diff[:, :, 0] + syn_z_flat[:, :, 0] + S_Z[:, :, 0]) & 1 + if T > 1: + R_Z[:, :, 1:] = ( + z_syn_diff[:, :, 1:] + syn_z_flat[:, :, 1:] + syn_z_flat[:, :, :-1] + S_Z[:, :, 1:] + ) & 1 + + # Logical frame from data corrections + if basis == "X": + pre_L_t = torch.einsum("ld,bdt->blt", Lx.to(torch.float32), + z_flat.to(torch.float32)).remainder_(2).to(torch.int32) + else: + pre_L_t = torch.einsum("ld,bdt->blt", Lz.to(torch.float32), + x_flat.to(torch.float32)).remainder_(2).to(torch.int32) + pre_L = pre_L_t.sum(dim=2).remainder_(2).view(-1) + + # Build residual detectors (matching logical_error_rate.py exactly) + if basis == "X": + initial_detectors = R_X[:, :, 0].view(B, -1) + else: + initial_detectors = R_Z[:, :, 0].view(B, -1) + R_X_rest = R_X[:, :, 1:] + R_Z_rest = R_Z[:, :, 1:] + R_cat_rest = torch.cat([R_X_rest, R_Z_rest], dim=1) + rest_flat = R_cat_rest.permute(0, 2, 1).contiguous().view(B, -1) + residual = torch.cat([initial_detectors, rest_flat], dim=1).to(torch.uint8) + + # Append boundary detectors from Stim (unchanged by pre-decoder) + boundary_dets_batch = baseline_detectors_batch[:, -num_boundary_dets:] + residual = torch.cat( + [residual, torch.from_numpy(boundary_dets_batch).to(residual.device)], dim=1 + ) + + if residual.shape[1] != det_model.num_detectors: + raise ValueError( + f"Residual shape {residual.shape} != DEM detectors {det_model.num_detectors}. " + f"Check interleave order for basis '{basis}' and time slicing." + ) + + return residual.cpu().numpy(), pre_L.cpu().numpy() + + +def _run_decoders_on_batch( + residual_np, + pre_L_np, + weights, + ldpc_decoders, + cudaq_decoders, + matcher_uncorr, + matcher_corr, + cudaq_decoder_names, + decoder_names, + gt_obs_np, + _timing, + _cudaq_stats, + weight_bucket_stats, +): + """ + Run all configured decoders on one batch of residual syndromes. + + Mutates _timing, _cudaq_stats, and weight_bucket_stats in-place. + Returns: + all_finals: dict mapping decoder name -> (B,) int array of final observable predictions. + n_agree: number of samples where all decoders agreed. + """ + import time as _t + + B = residual_np.shape[0] + + # 1. No-op: pred_obs = 0 + noop_final = pre_L_np % 2 + + # 2. Union-Find (ldpc) + _uf, _bp, _bplsd = LDPC_DECODER_NAMES + _t0 = _t.perf_counter() + uf_dec, uf_L = ldpc_decoders[_uf] + uf_obs = _decode_ldpc_batch(uf_dec, uf_L, residual_np) + uf_final = (pre_L_np + uf_obs) % 2 + _timing["uf_decode"] += _t.perf_counter() - _t0 + + # 3. BP-only (no LSD fallback) + _t0 = _t.perf_counter() + bp_dec, bp_L = ldpc_decoders[_bp] + bp_obs = _decode_ldpc_batch(bp_dec, bp_L, residual_np) + bp_final = (pre_L_np + bp_obs) % 2 + _timing["bp_only_decode"] += _t.perf_counter() - _t0 + + # 4. BP+LSD-0 (ldpc) + _t0 = _t.perf_counter() + bplsd_dec, bplsd_L = ldpc_decoders[_bplsd] + bplsd_obs = _decode_ldpc_batch(bplsd_dec, bplsd_L, residual_np) + bplsd_final = (pre_L_np + bplsd_obs) % 2 + _timing["bplsd_decode"] += _t.perf_counter() - _t0 + + # 5. Uncorrelated PyMatching + _t0 = _t.perf_counter() + uncorr_pred = _decode_batch(matcher_uncorr, residual_np, False) + uncorr_pred = np.asarray(uncorr_pred, dtype=np.int64).reshape(-1) + uncorr_final = (pre_L_np + uncorr_pred) % 2 + _timing["uncorr_pm"] += _t.perf_counter() - _t0 + + # 6. Correlated PyMatching + _t0 = _t.perf_counter() + corr_pred = _decode_batch(matcher_corr, residual_np, True) + corr_pred = np.asarray(corr_pred, dtype=np.int64).reshape(-1) + corr_final = (pre_L_np + corr_pred) % 2 + _timing["corr_pm"] += _t.perf_counter() - _t0 + + # 7. cudaq-qec GPU-accelerated decoders + cudaq_finals = {} + for cn in cudaq_decoder_names: + _t0 = _t.perf_counter() + cdec, cL = cudaq_decoders[cn] + c_obs, c_stats = _decode_cudaq_batch(cdec, cL, residual_np) + c_final = (pre_L_np + c_obs) % 2 + cudaq_finals[cn] = c_final + _timing[f"{cn}_decode"] += _t.perf_counter() - _t0 + # Accumulate per-sample convergence, iteration, and error stats + conv_flags = c_stats["converged_flags"] + iters = c_stats["iter_counts"] + fails = (c_final != gt_obs_np) + _cudaq_stats[cn]["converged_flags"].append(conv_flags) + _cudaq_stats[cn]["iter_counts"].append(iters) + _cudaq_stats[cn]["error_flags"].append(fails) + + _t0 = _t.perf_counter() + all_finals = { + DECODER_NAMES[0]: noop_final, + _uf: uf_final, + _bp: bp_final, + _bplsd: bplsd_final, + DECODER_NAMES[4]: uncorr_final, + DECODER_NAMES[5]: corr_final, + } + all_finals.update(cudaq_finals) + + stacked = np.stack([all_finals[n] for n in decoder_names], axis=0) # (n_decoders, B) + agree = np.all(stacked == stacked[0:1], axis=0) # (B,) + + for i in range(B): + w = int(weights[i]) + bucket = w if w <= 6 else 7 # 0-6, 7+ + if bucket not in weight_bucket_stats: + weight_bucket_stats[bucket] = {n: [0, 0] for n in decoder_names} + weight_bucket_stats[bucket]["_total"] = weight_bucket_stats[bucket].get("_total", 0) + 1 + for name in decoder_names: + if bucket not in weight_bucket_stats or name not in weight_bucket_stats[bucket]: + weight_bucket_stats[bucket][name] = [0, 0] + weight_bucket_stats[bucket][name][1] += 1 + if all_finals[name][i] != gt_obs_np[i]: + weight_bucket_stats[bucket][name][0] += 1 + + _timing["bookkeeping"] += _t.perf_counter() - _t0 + + return all_finals, int(agree.sum()) + + +def _print_ablation_results( + basis, + D, + cfg, + total_scanned, + baseline_errors, + decoder_errors, + decoder_names, + cudaq_decoder_names, + _cudaq_stats, + n_all_agree, + all_residual_weights, + weight_bucket_stats, + _timing, +): + """Print timing breakdown, LER summary, convergence stats, and generate plots.""" + _total_time = sum(_timing.values()) + print(f"\n{'='*60}") + print(f"TIMING BREAKDOWN (total loop = {_total_time:.2f}s)") + print(f"{'='*60}") + for k, v in sorted(_timing.items(), key=lambda x: -x[1]): + pct = v / max(_total_time, 1e-9) * 100 + print(f" {k:<20s} {v:8.2f}s ({pct:5.1f}%)") + print(f"{'='*60}") + + print(f"\n{'='*70}") + print( + f"DECODER ABLATION STUDY | basis={basis} d={D} r={cfg.n_rounds}" + f" p={getattr(cfg.test, 'p_error', 0.003)}" + ) + print(f"{'='*70}") + print(f"Total samples: {total_scanned}") + + baseline_ler = baseline_errors / max(1, total_scanned) + print(f"\n--- Logical Error Rates ---") + print( + f" {'Baseline (no pre-dec)':<25s} LER = {baseline_ler:.6f}" + f" ({baseline_errors} errors)" + ) + for name in decoder_names: + ler = decoder_errors[name] / max(1, total_scanned) + print(f" {name:<25s} LER = {ler:.6f} ({decoder_errors[name]} errors)") + + # cudaq decoder convergence and iteration stats + if _cudaq_stats: + print(f"\n--- cudaq-qec BP Convergence & Iteration Breakdown ---") + print( + f" {'Decoder':<20s} {'Conv%':>7s} {'AvgIt':>6s} " + f"{'Conv.It':>8s} {'Conv.LER':>9s} {'Conv.Err':>9s} " + f"{'!Conv.It':>8s} {'!Conv.LER':>10s} {'!Conv.Err':>10s}" + ) + for cn in cudaq_decoder_names: + st = _cudaq_stats[cn] + conv_all = np.concatenate(st["converged_flags"]) + iters_all = np.concatenate(st["iter_counts"]) + errs_all = np.concatenate(st["error_flags"]) + N = len(conv_all) + n_conv = int(conv_all.sum()) + n_noconv = N - n_conv + conv_pct = n_conv / max(1, N) * 100 + has_iters = iters_all.sum() > 0 + + # Converged subset + if n_conv > 0 and has_iters: + conv_avg_it = iters_all[conv_all].mean() + conv_ler = errs_all[conv_all].mean() + conv_errs = int(errs_all[conv_all].sum()) + else: + conv_avg_it = conv_ler = 0.0 + conv_errs = 0 + + # Non-converged subset + if n_noconv > 0 and has_iters: + noconv_avg_it = iters_all[~conv_all].mean() + noconv_ler = errs_all[~conv_all].mean() + noconv_errs = int(errs_all[~conv_all].sum()) + else: + noconv_avg_it = noconv_ler = 0.0 + noconv_errs = 0 + + if has_iters: + avg_it_str = f"{iters_all.mean():5.1f}" + conv_it_str = f"{conv_avg_it:7.1f}" + noconv_it_str = f"{noconv_avg_it:7.1f}" if n_noconv > 0 else " N/A" + else: + avg_it_str = " N/A" + conv_it_str = " N/A" + noconv_it_str = " N/A" + + noconv_ler_str = f"{noconv_ler:9.6f}" if n_noconv > 0 else " N/A" + noconv_err_str = f"{noconv_errs:>9d}" if n_noconv > 0 else " N/A" + + print( + f" {cn:<20s} {conv_pct:>6.1f}% {avg_it_str} " + f"{conv_it_str} {conv_ler:>9.6f} {conv_errs:>9d} " + f"{noconv_it_str} {noconv_ler_str} {noconv_err_str}" + ) + + agreement_rate = n_all_agree / max(1, total_scanned) + print(f"\n--- Decoder Agreement ---") + print( + f" All {len(decoder_names)} decoders agree:" + f" {agreement_rate*100:.2f}% ({n_all_agree}/{total_scanned})" + ) + + weights_arr = np.array(all_residual_weights) + print(f"\n--- Residual Weight Distribution ---") + for w in sorted(weight_bucket_stats.keys()): + label = f"{w}+" if w == 7 else str(w) + count = weight_bucket_stats[w].get("_total", 0) + pct = count / max(1, total_scanned) * 100 + print(f" Weight {label:>3s}: {count:>7d} samples ({pct:6.2f}%)") + print(f" Mean weight: {weights_arr.mean():.3f}, Max: {int(weights_arr.max())}") + + print(f"\n--- Conditional LER by Residual Weight ---") + header = f" {'Weight':>7s}" + for name in decoder_names: + header += f" {name:>12s}" + print(header) + for w in sorted(weight_bucket_stats.keys()): + label = f"{w}+" if w == 7 else str(w) + row = f" {label:>7s}" + for name in decoder_names: + n_err, n_tot = weight_bucket_stats[w].get(name, [0, 0]) + if n_tot > 0: + row += f" {n_err/n_tot:>12.6f}" + else: + row += f" {'N/A':>12s}" + print(row) + print(f"{'='*70}") + + # --- Plots --- + _plot_residual_weight_histogram(all_residual_weights, basis, cfg) + _plot_conditional_ler(weight_bucket_stats, decoder_names, basis, cfg) + + @torch.inference_mode() def decoder_ablation_study(model, device, dist, cfg): """ @@ -263,10 +669,7 @@ def decoder_ablation_study(model, device, dist, cfg): import time as _time from copy import deepcopy - import pymatching - - from data.factory import DatapipeFactory - + # --- Config --- th_data = float(getattr(cfg.test, "th_data", 0.0)) th_syn = float(getattr(cfg.test, "th_syn", 0.0)) sampling_mode = str(getattr(cfg.test, "sampling_mode", "threshold")).lower() @@ -277,14 +680,14 @@ def decoder_ablation_study(model, device, dist, cfg): temperature_syn = float(temperature_syn) if temperature_syn is not None else temperature model.eval() - enable_correlated = getattr(cfg.data, "enable_correlated_pymatching", False) basis = str(getattr(cfg.test, "meas_basis_test", "X")).upper() if basis not in ("X", "Z"): basis = "X" - # --- Create Stim datapipe (with boundary detectors) --- + # --- Dataset --- total_samples = int(cfg.test.num_samples) samples_per_gpu = total_samples // max(1, dist.world_size) + from data.factory import DatapipeFactory torch_state = torch.get_rng_state() cuda_state = torch.cuda.get_rng_state_all() if torch.cuda.is_available() else None @@ -311,62 +714,32 @@ def decoder_ablation_study(model, device, dist, cfg): circuit = test_dataset.circ.stim_circuit num_obs = circuit.num_observables assert num_obs == 1 - - # DEM and matchers from Stim circuit (includes boundary detectors) det_model = circuit.detector_error_model( decompose_errors=True, approximate_disjoint_errors=True ) - matcher_corr = pymatching.Matching.from_detector_error_model( - det_model, enable_correlations=True - ) - matcher_uncorr = pymatching.Matching.from_detector_error_model( - det_model, enable_correlations=False - ) - - # Build ldpc decoders from the same DEM (with boundary detectors) - ldpc_decoders = _build_ldpc_decoders(det_model) - # Build cudaq-qec GPU-accelerated decoders - cudaq_decoders = {} - try: - cudaq_decoders = _build_cudaq_decoders(det_model) - if dist.rank == 0: - print(f"[Decoder Ablation] cudaq-qec decoders loaded: {list(cudaq_decoders.keys())}") - except Exception as e: - if dist.rank == 0: - print(f"[Decoder Ablation] cudaq-qec decoders unavailable: {e}") + # --- Decoders --- + matcher_corr, matcher_uncorr, ldpc_decoders, cudaq_decoders = _build_all_decoders( + det_model, dist + ) + cudaq_decoder_names = sorted(cudaq_decoders.keys()) + decoder_names = list(DECODER_NAMES) + cudaq_decoder_names - # Stim baseline detectors and ground truth observables + # --- Baseline data --- stim_dets = np.asarray(test_dataset.dets_and_obs[:, :-num_obs], dtype=np.uint8) assert stim_dets.shape[1] == det_model.num_detectors, \ f"Stim dets width {stim_dets.shape[1]} != DEM {det_model.num_detectors}" stim_obs = np.asarray(test_dataset.dets_and_obs[:, -num_obs:], dtype=np.uint8) - # Number of boundary detectors surface_code = test_dataset.circ.code num_boundary_dets = surface_code.hx.shape[0] if basis == 'X' else surface_code.hz.shape[0] + # --- Logical operators --- D = cfg.distance code_rotation = getattr(cfg.data, "code_rotation", "XV") - maps = _build_stab_maps(D, code_rotation) - Hx_idx = maps["Hx_idx"].to(device=device, dtype=torch.long) - Hz_idx = maps["Hz_idx"].to(device=device, dtype=torch.long) - Hx_mask = maps["Hx_mask"].to(device=device, dtype=torch.bool) - Hz_mask = maps["Hz_mask"].to(device=device, dtype=torch.bool) - stab_indices_x = maps["stab_x"].to(device=device, dtype=torch.long) - stab_indices_z = maps["stab_z"].to(device=device, dtype=torch.long) - Kx, Kz = maps["Kx"], maps["Kz"] + Hx_idx, Hz_idx, Hx_mask, Hz_mask, stab_indices_x, stab_indices_z, Kx, Kz, Lx, Lz = \ + _build_logical_operators(D, code_rotation, device) D2 = D * D - if code_rotation.upper() in ("XV", "ZH"): - Lx = torch.zeros((1, D2), dtype=torch.int32, device=device) - Lx[0, :D] = 1 - Lz = torch.zeros((1, D2), dtype=torch.int32, device=device) - Lz[0, ::D] = 1 - else: - Lx = torch.zeros((1, D2), dtype=torch.int32, device=device) - Lx[0, ::D] = 1 - Lz = torch.zeros((1, D2), dtype=torch.int32, device=device) - Lz[0, :D] = 1 if dist.rank == 0: print( @@ -383,18 +756,15 @@ def decoder_ablation_study(model, device, dist, cfg): ) cudaq_names_str = ", ".join(cudaq_decoders.keys()) if cudaq_decoders else "(none)" print( - f"[Decoder Ablation] Decoders: No-op, Union-Find, BP+LSD-0, Uncorr PM, Corr PM, {cudaq_names_str}, + Baseline PM" + f"[Decoder Ablation] Decoders: No-op, Union-Find, BP+LSD-0," + f" Uncorr PM, Corr PM, {cudaq_names_str}, + Baseline PM" ) + # --- Batch loop --- batch_size = int(getattr(cfg.test.dataloader, "batch_size", 2048)) N = len(test_dataset) num_batches = (N + batch_size - 1) // batch_size - decoder_names = list(DECODER_NAMES) - # Append cudaq decoder names dynamically - cudaq_decoder_names = sorted(cudaq_decoders.keys()) - decoder_names.extend(cudaq_decoder_names) - total_scanned = 0 baseline_errors = 0 decoder_errors = {name: 0 for name in decoder_names} @@ -404,28 +774,34 @@ def decoder_ablation_study(model, device, dist, cfg): n_all_agree = 0 _timing = { - "collate": 0.0, - "baseline_pm": 0.0, - "model_fwd": 0.0, - "residual_build": 0.0, - "uf_decode": 0.0, - "bp_only_decode": 0.0, - "bplsd_decode": 0.0, - "uncorr_pm": 0.0, - "corr_pm": 0.0, - "bookkeeping": 0.0, + k: 0.0 for k in ( + "collate", + "baseline_pm", + "model_fwd", + "residual_build", + "uf_decode", + "bp_only_decode", + "bplsd_decode", + "uncorr_pm", + "corr_pm", + "bookkeeping", + ) } - _cudaq_stats = {} for cn in cudaq_decoder_names: _timing[f"{cn}_decode"] = 0.0 - _cudaq_stats[cn] = {"converged_flags": [], "iter_counts": [], "error_flags": []} + _cudaq_stats = { + cn: { + "converged_flags": [], + "iter_counts": [], + "error_flags": [] + } for cn in cudaq_decoder_names + } for batch_idx in range(num_batches): start = batch_idx * batch_size end = min(start + batch_size, N) B = end - start - # Collate batch from dataset items _t0 = _time.perf_counter() items = [test_dataset[i] for i in range(start, end)] x_syn_diff = torch.stack([it["x_syn_diff"] for it in items] @@ -435,11 +811,11 @@ def decoder_ablation_study(model, device, dist, cfg): trainX = torch.stack([it["trainX"] for it in items]).to(device=device) _timing["collate"] += _time.perf_counter() - _t0 - _, n_x, T = x_syn_diff.shape + _, _, T = x_syn_diff.shape if T < 2: continue - # --- Baseline: Stim detectors (with boundary dets), Stim ground truth --- + # Baseline: raw Stim syndromes + ground truth baseline_detectors_batch = stim_dets[start:end] gt_obs_batch = stim_obs[start:end] all_baseline_weights.extend(baseline_detectors_batch.sum(axis=1).tolist()) @@ -452,307 +828,86 @@ def decoder_ablation_study(model, device, dist, cfg): gt_obs_np = gt_obs_batch.reshape(-1).astype(np.int64) - # Model forward + # Pre-decoder forward pass + residual syndrome construction _t0 = _time.perf_counter() - with torch.amp.autocast( - device_type=device.type if hasattr(device, "type") else "cuda", - enabled=getattr(cfg, "enable_fp16", False), - ): - logits = model(trainX) - z_data_corr = sample_predictions(logits[:, 0], th_data, sampling_mode, temperature_data) - x_data_corr = sample_predictions(logits[:, 1], th_data, sampling_mode, temperature_data) - syn_x_grid = sample_predictions(logits[:, 2], th_syn, sampling_mode, temperature_syn) - syn_z_grid = sample_predictions(logits[:, 3], th_syn, sampling_mode, temperature_syn) - - z_flat = z_data_corr.permute(0, 2, 3, 1).contiguous().view(B, D2, T).to(torch.int32) - x_flat = x_data_corr.permute(0, 2, 3, 1).contiguous().view(B, D2, T).to(torch.int32) - z_exp = z_flat.unsqueeze(2).expand(B, D2, Kx, T) - hx_idx_e = Hx_idx.clamp_min(0).view(1, -1, Kx, 1).expand(B, -1, -1, T) - g_x = z_exp.gather(1, hx_idx_e) - m_x = Hx_mask.view(1, -1, Kx, 1).expand_as(g_x) - S_X = (g_x.masked_fill(~m_x, 0).sum(dim=2) & 1) - x_exp = x_flat.unsqueeze(2).expand(B, D2, Kz, T) - hz_idx_e = Hz_idx.clamp_min(0).view(1, -1, Kz, 1).expand(B, -1, -1, T) - g_z = x_exp.gather(1, hz_idx_e) - m_z = Hz_mask.view(1, -1, Kz, 1).expand_as(g_z) - S_Z = (g_z.masked_fill(~m_z, 0).sum(dim=2) & 1) - - syn_x_flat = map_grid_to_stabilizer_tensor(syn_x_grid, stab_indices_x).to(torch.int32) - syn_z_flat = map_grid_to_stabilizer_tensor(syn_z_grid, stab_indices_z).to(torch.int32) - R_X = torch.empty_like(x_syn_diff, dtype=torch.uint8) - R_X[:, :, 0] = (x_syn_diff[:, :, 0] + syn_x_flat[:, :, 0] + S_X[:, :, 0]) & 1 - if T > 1: - R_X[:, :, 1:] = ( - x_syn_diff[:, :, 1:] + syn_x_flat[:, :, 1:] + syn_x_flat[:, :, :-1] + S_X[:, :, 1:] - ) & 1 - R_Z = torch.empty_like(z_syn_diff, dtype=torch.uint8) - R_Z[:, :, 0] = (z_syn_diff[:, :, 0] + syn_z_flat[:, :, 0] + S_Z[:, :, 0]) & 1 - if T > 1: - R_Z[:, :, 1:] = ( - z_syn_diff[:, :, 1:] + syn_z_flat[:, :, 1:] + syn_z_flat[:, :, :-1] + S_Z[:, :, 1:] - ) & 1 - - # Logical frame from data corrections - if basis == "X": - pre_L_t = torch.einsum("ld,bdt->blt", Lx.to(torch.float32), - z_flat.to(torch.float32)).remainder_(2).to(torch.int32) - else: - pre_L_t = torch.einsum("ld,bdt->blt", Lz.to(torch.float32), - x_flat.to(torch.float32)).remainder_(2).to(torch.int32) - pre_L = pre_L_t.sum(dim=2).remainder_(2).view(-1) - - # Build residual detectors (matching logical_error_rate.py exactly) - if basis == "X": - initial_detectors = R_X[:, :, 0].view(B, -1) - else: - initial_detectors = R_Z[:, :, 0].view(B, -1) - R_X_rest = R_X[:, :, 1:] - R_Z_rest = R_Z[:, :, 1:] - R_cat_rest = torch.cat([R_X_rest, R_Z_rest], dim=1) - rest_flat = R_cat_rest.permute(0, 2, 1).contiguous().view(B, -1) - residual = torch.cat([initial_detectors, rest_flat], dim=1).to(torch.uint8) - - # Append boundary detectors from Stim (unchanged by pre-decoder) - boundary_dets_batch = baseline_detectors_batch[:, -num_boundary_dets:] - residual = torch.cat( - [residual, torch.from_numpy(boundary_dets_batch).to(residual.device)], dim=1 + residual_np, pre_L_np = _model_forward_and_residual( + model, + trainX, + x_syn_diff, + z_syn_diff, + basis, + B, + D2, + T, + Hx_idx, + Hz_idx, + Hx_mask, + Hz_mask, + Kx, + Kz, + stab_indices_x, + stab_indices_z, + Lx, + Lz, + th_data, + th_syn, + sampling_mode, + temperature_data, + temperature_syn, + cfg, + device, + num_boundary_dets, + baseline_detectors_batch, + det_model, ) - - if residual.shape[1] != det_model.num_detectors: - raise ValueError( - f"Residual shape {residual.shape} != DEM detectors {det_model.num_detectors}. " - f"Check interleave order for basis '{basis}' and time slicing." - ) - if device.type == "cuda": torch.cuda.synchronize() _timing["residual_build"] += _time.perf_counter() - _t0 - residual_np = residual.cpu().numpy() - pre_L_np = pre_L.cpu().numpy() - weights = residual_np.sum(axis=1) all_residual_weights.extend(weights.tolist()) - # --- Run all decoders --- - # 1. No-op: pred_obs = 0 - noop_final = pre_L_np % 2 - - # 2. Union-Find (ldpc) - _uf, _bp, _bplsd = LDPC_DECODER_NAMES - _t0 = _time.perf_counter() - uf_dec, uf_L = ldpc_decoders[_uf] - uf_obs = _decode_ldpc_batch(uf_dec, uf_L, residual_np) - uf_final = (pre_L_np + uf_obs) % 2 - _timing["uf_decode"] += _time.perf_counter() - _t0 - - # 3. BP-only (no LSD fallback) - _t0 = _time.perf_counter() - bp_dec, bp_L = ldpc_decoders[_bp] - bp_obs = _decode_ldpc_batch(bp_dec, bp_L, residual_np) - bp_final = (pre_L_np + bp_obs) % 2 - _timing["bp_only_decode"] += _time.perf_counter() - _t0 - - # 4. BP+LSD-0 (ldpc) - _t0 = _time.perf_counter() - bplsd_dec, bplsd_L = ldpc_decoders[_bplsd] - bplsd_obs = _decode_ldpc_batch(bplsd_dec, bplsd_L, residual_np) - bplsd_final = (pre_L_np + bplsd_obs) % 2 - _timing["bplsd_decode"] += _time.perf_counter() - _t0 - - # 5. Uncorrelated PyMatching - _t0 = _time.perf_counter() - uncorr_pred = _decode_batch(matcher_uncorr, residual_np, False) - uncorr_pred = np.asarray(uncorr_pred, dtype=np.int64).reshape(-1) - uncorr_final = (pre_L_np + uncorr_pred) % 2 - _timing["uncorr_pm"] += _time.perf_counter() - _t0 - - # 6. Correlated PyMatching - _t0 = _time.perf_counter() - corr_pred = _decode_batch(matcher_corr, residual_np, True) - corr_pred = np.asarray(corr_pred, dtype=np.int64).reshape(-1) - corr_final = (pre_L_np + corr_pred) % 2 - _timing["corr_pm"] += _time.perf_counter() - _t0 - - # 7. cudaq-qec GPU-accelerated decoders - cudaq_finals = {} - for cn in cudaq_decoder_names: - _t0 = _time.perf_counter() - cdec, cL = cudaq_decoders[cn] - c_obs, c_stats = _decode_cudaq_batch(cdec, cL, residual_np) - c_final = (pre_L_np + c_obs) % 2 - cudaq_finals[cn] = c_final - _timing[f"{cn}_decode"] += _time.perf_counter() - _t0 - # Accumulate per-sample convergence, iteration, and error stats - conv_flags = c_stats["converged_flags"] - iters = c_stats["iter_counts"] - fails = (c_final != gt_obs_np) - _cudaq_stats[cn]["converged_flags"].append(conv_flags) - _cudaq_stats[cn]["iter_counts"].append(iters) - _cudaq_stats[cn]["error_flags"].append(fails) - - _t0 = _time.perf_counter() - all_finals = { - DECODER_NAMES[0]: noop_final, - _uf: uf_final, - _bp: bp_final, - _bplsd: bplsd_final, - DECODER_NAMES[4]: uncorr_final, - DECODER_NAMES[5]: corr_final, - } - all_finals.update(cudaq_finals) - + # All decoder runs + all_finals, n_agree = _run_decoders_on_batch( + residual_np, + pre_L_np, + weights, + ldpc_decoders, + cudaq_decoders, + matcher_uncorr, + matcher_corr, + cudaq_decoder_names, + decoder_names, + gt_obs_np, + _timing, + _cudaq_stats, + weight_bucket_stats, + ) for name in decoder_names: - fails = all_finals[name] != gt_obs_np - decoder_errors[name] += int(fails.sum()) - - stacked = np.stack([all_finals[n] for n in decoder_names], axis=0) # (n_decoders, B) - agree = np.all(stacked == stacked[0:1], axis=0) # (B,) - n_all_agree += int(agree.sum()) - - for i in range(B): - w = int(weights[i]) - bucket = w if w <= 6 else 7 # 0-6, 7+ - if bucket not in weight_bucket_stats: - weight_bucket_stats[bucket] = {n: [0, 0] for n in decoder_names} - weight_bucket_stats[bucket]["_total"] = weight_bucket_stats[bucket].get("_total", 0) + 1 - for name in decoder_names: - if bucket not in weight_bucket_stats or name not in weight_bucket_stats[bucket]: - weight_bucket_stats[bucket][name] = [0, 0] - weight_bucket_stats[bucket][name][1] += 1 - if all_finals[name][i] != gt_obs_np[i]: - weight_bucket_stats[bucket][name][0] += 1 - - _timing["bookkeeping"] += _time.perf_counter() - _t0 + decoder_errors[name] += int((all_finals[name] != gt_obs_np).sum()) + n_all_agree += n_agree total_scanned += B if dist.rank == 0 and (batch_idx + 1) % 5 == 0: print(f" [Ablation] Processed {total_scanned} samples...") - # --- Print timing breakdown --- - if dist.rank == 0: - _total_time = sum(_timing.values()) - print(f"\n{'='*60}") - print(f"TIMING BREAKDOWN (total loop = {_total_time:.2f}s)") - print(f"{'='*60}") - for k, v in sorted(_timing.items(), key=lambda x: -x[1]): - pct = v / max(_total_time, 1e-9) * 100 - print(f" {k:<20s} {v:8.2f}s ({pct:5.1f}%)") - print(f"{'='*60}") - - # --- Print summary --- if dist.rank == 0: - print(f"\n{'='*70}") - print( - f"DECODER ABLATION STUDY | basis={basis} d={D} r={cfg.n_rounds}" - f" p={getattr(cfg.test, 'p_error', 0.003)}" - ) - print(f"{'='*70}") - print(f"Total samples: {total_scanned}") - - baseline_ler = baseline_errors / max(1, total_scanned) - print(f"\n--- Logical Error Rates ---") - print( - f" {'Baseline (no pre-dec)':<25s} LER = {baseline_ler:.6f}" - f" ({baseline_errors} errors)" - ) - for name in decoder_names: - ler = decoder_errors[name] / max(1, total_scanned) - print(f" {name:<25s} LER = {ler:.6f} ({decoder_errors[name]} errors)") - - # cudaq decoder convergence and iteration stats - if _cudaq_stats: - print(f"\n--- cudaq-qec BP Convergence & Iteration Breakdown ---") - print( - f" {'Decoder':<20s} {'Conv%':>7s} {'AvgIt':>6s} " - f"{'Conv.It':>8s} {'Conv.LER':>9s} {'Conv.Err':>9s} " - f"{'!Conv.It':>8s} {'!Conv.LER':>10s} {'!Conv.Err':>10s}" - ) - for cn in cudaq_decoder_names: - st = _cudaq_stats[cn] - conv_all = np.concatenate(st["converged_flags"]) - iters_all = np.concatenate(st["iter_counts"]) - errs_all = np.concatenate(st["error_flags"]) - N = len(conv_all) - n_conv = int(conv_all.sum()) - n_noconv = N - n_conv - conv_pct = n_conv / max(1, N) * 100 - has_iters = iters_all.sum() > 0 - - # Converged subset - if n_conv > 0 and has_iters: - conv_avg_it = iters_all[conv_all].mean() - conv_ler = errs_all[conv_all].mean() - conv_errs = int(errs_all[conv_all].sum()) - else: - conv_avg_it = conv_ler = 0.0 - conv_errs = 0 - - # Non-converged subset - if n_noconv > 0 and has_iters: - noconv_avg_it = iters_all[~conv_all].mean() - noconv_ler = errs_all[~conv_all].mean() - noconv_errs = int(errs_all[~conv_all].sum()) - else: - noconv_avg_it = noconv_ler = 0.0 - noconv_errs = 0 - - if has_iters: - avg_it_str = f"{iters_all.mean():5.1f}" - conv_it_str = f"{conv_avg_it:7.1f}" - noconv_it_str = f"{noconv_avg_it:7.1f}" if n_noconv > 0 else " N/A" - else: - avg_it_str = " N/A" - conv_it_str = " N/A" - noconv_it_str = " N/A" - - noconv_ler_str = f"{noconv_ler:9.6f}" if n_noconv > 0 else " N/A" - noconv_err_str = f"{noconv_errs:>9d}" if n_noconv > 0 else " N/A" - - print( - f" {cn:<20s} {conv_pct:>6.1f}% {avg_it_str} " - f"{conv_it_str} {conv_ler:>9.6f} {conv_errs:>9d} " - f"{noconv_it_str} {noconv_ler_str} {noconv_err_str}" - ) - - agreement_rate = n_all_agree / max(1, total_scanned) - print(f"\n--- Decoder Agreement ---") - print( - f" All {len(decoder_names)} decoders agree:" - f" {agreement_rate*100:.2f}% ({n_all_agree}/{total_scanned})" + _print_ablation_results( + basis, + D, + cfg, + total_scanned, + baseline_errors, + decoder_errors, + decoder_names, + cudaq_decoder_names, + _cudaq_stats, + n_all_agree, + all_residual_weights, + weight_bucket_stats, + _timing, ) - weights_arr = np.array(all_residual_weights) - print(f"\n--- Residual Weight Distribution ---") - for w in sorted(weight_bucket_stats.keys()): - label = f"{w}+" if w == 7 else str(w) - count = weight_bucket_stats[w].get("_total", 0) - pct = count / max(1, total_scanned) * 100 - print(f" Weight {label:>3s}: {count:>7d} samples ({pct:6.2f}%)") - print(f" Mean weight: {weights_arr.mean():.3f}, Max: {int(weights_arr.max())}") - - print(f"\n--- Conditional LER by Residual Weight ---") - header = f" {'Weight':>7s}" - for name in decoder_names: - header += f" {name:>12s}" - print(header) - for w in sorted(weight_bucket_stats.keys()): - label = f"{w}+" if w == 7 else str(w) - row = f" {label:>7s}" - for name in decoder_names: - n_err, n_tot = weight_bucket_stats[w].get(name, [0, 0]) - if n_tot > 0: - row += f" {n_err/n_tot:>12.6f}" - else: - row += f" {'N/A':>12s}" - print(row) - print(f"{'='*70}") - - # --- Plots --- - if dist.rank == 0: - _plot_residual_weight_histogram(all_residual_weights, basis, cfg) - _plot_conditional_ler(weight_bucket_stats, decoder_names, basis, cfg) - return ( { "total_samples": total_scanned, From be74a671dffec3f88956c6bd7847583e7822eb01 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Thu, 19 Mar 2026 20:08:41 -0700 Subject: [PATCH 16/20] adding unittests for refactored functions Signed-off-by: Sachin Pisal --- code/tests/test_failure_analysis.py | 349 ++++++++++++++++++++++++++++ 1 file changed, 349 insertions(+) diff --git a/code/tests/test_failure_analysis.py b/code/tests/test_failure_analysis.py index 72fb09c..e715a93 100644 --- a/code/tests/test_failure_analysis.py +++ b/code/tests/test_failure_analysis.py @@ -163,6 +163,355 @@ def test_output_values_are_binary(self): ) +class TestBuildAllDecoders(unittest.TestCase): + """_build_all_decoders must return correctly typed decoder objects.""" + + def setUp(self): + from evaluation.failure_analysis import _build_all_decoders, LDPC_DECODER_NAMES + self.det_model = _make_tiny_dem() + self.result = _build_all_decoders(self.det_model, _DummyDist()) + self.LDPC_DECODER_NAMES = LDPC_DECODER_NAMES + + def test_returns_four_values(self): + self.assertEqual(len(self.result), 4) + + def test_matchers_have_decode_method(self): + matcher_corr, matcher_uncorr, _, _ = self.result + self.assertTrue(hasattr(matcher_corr, "decode")) + self.assertTrue(hasattr(matcher_uncorr, "decode")) + + def test_ldpc_decoders_contains_all_names(self): + _, _, ldpc_decoders, _ = self.result + for name in self.LDPC_DECODER_NAMES: + self.assertIn(name, ldpc_decoders) + + def test_cudaq_decoders_is_dict(self): + _, _, _, cudaq_decoders = self.result + self.assertIsInstance(cudaq_decoders, dict) + + +class TestBuildLogicalOperators(unittest.TestCase): + """_build_logical_operators must return tensors of the correct shape and values.""" + + _D = 3 + + def setUp(self): + from evaluation.failure_analysis import _build_logical_operators + self.ops = _build_logical_operators(self._D, "XV", torch.device("cpu")) + self.Hx_idx, self.Hz_idx, self.Hx_mask, self.Hz_mask, \ + self.stab_x, self.stab_z, self.Kx, self.Kz, self.Lx, self.Lz = self.ops + + def test_returns_ten_values(self): + self.assertEqual(len(self.ops), 10) + + def test_logical_operator_shapes(self): + D2 = self._D * self._D + self.assertEqual(self.Lx.shape, (1, D2)) + self.assertEqual(self.Lz.shape, (1, D2)) + + def test_logical_operators_are_binary(self): + for L in (self.Lx, self.Lz): + vals = L.unique().tolist() + self.assertTrue(all(v in (0, 1) for v in vals)) + + def test_xv_rotation_lx_row_pattern(self): + # XV rotation: Lx[0, :D] = 1, rest 0 + self.assertEqual(int(self.Lx[0, :self._D].sum()), self._D) + self.assertEqual(int(self.Lx[0, self._D:].sum()), 0) + + def test_xv_rotation_lz_column_pattern(self): + # XV rotation: Lz[0, ::D] = 1 (first column of D×D grid) + self.assertEqual(int(self.Lz[0, ::self._D].sum()), self._D) + + def test_kx_kz_are_positive_ints(self): + self.assertIsInstance(self.Kx, int) + self.assertIsInstance(self.Kz, int) + self.assertGreater(self.Kx, 0) + self.assertGreater(self.Kz, 0) + + def test_index_tensors_are_long(self): + self.assertEqual(self.Hx_idx.dtype, torch.long) + self.assertEqual(self.Hz_idx.dtype, torch.long) + + def test_mask_tensors_are_bool(self): + self.assertEqual(self.Hx_mask.dtype, torch.bool) + self.assertEqual(self.Hz_mask.dtype, torch.bool) + + +class TestModelForwardAndResidual(unittest.TestCase): + """_model_forward_and_residual must return binary arrays of the expected shape.""" + + _D = 3 + _T = 3 + _B = 4 + + def _build_inputs(self, basis="X"): + from data.datapipe_stim import QCDataPipePreDecoder_Memory_inference + from evaluation.failure_analysis import _build_logical_operators + ds = QCDataPipePreDecoder_Memory_inference( + distance=self._D, + n_rounds=self._T, + num_samples=self._B, + error_mode="circuit_level_surface_custom", + p_error=0.01, + measure_basis=basis, + code_rotation="XV", + ) + items = [ds[i] for i in range(self._B)] + x_syn_diff = torch.stack([it["x_syn_diff"] for it in items]).to(torch.int32) + z_syn_diff = torch.stack([it["z_syn_diff"] for it in items]).to(torch.int32) + trainX = torch.stack([it["trainX"] for it in items]) + + det_model = ds.circ.stim_circuit.detector_error_model( + decompose_errors=True, approximate_disjoint_errors=True + ) + surface_code = ds.circ.code + num_boundary_dets = surface_code.hx.shape[0] if basis == "X" else surface_code.hz.shape[0] + stim_dets = np.asarray(ds.dets_and_obs[:, :-1], dtype=np.uint8) + baseline_detectors_batch = stim_dets[:self._B] + + ops = _build_logical_operators(self._D, "XV", torch.device("cpu")) + Hx_idx, Hz_idx, Hx_mask, Hz_mask, stab_x, stab_z, Kx, Kz, Lx, Lz = ops + return dict( + x_syn_diff=x_syn_diff, + z_syn_diff=z_syn_diff, + trainX=trainX, + det_model=det_model, + num_boundary_dets=num_boundary_dets, + baseline_detectors_batch=baseline_detectors_batch, + Hx_idx=Hx_idx, + Hz_idx=Hz_idx, + Hx_mask=Hx_mask, + Hz_mask=Hz_mask, + stab_x=stab_x, + stab_z=stab_z, + Kx=Kx, + Kz=Kz, + Lx=Lx, + Lz=Lz, + ) + + def _call(self, basis="X"): + import types + from evaluation.failure_analysis import _model_forward_and_residual + inp = self._build_inputs(basis) + _, _, T = inp["x_syn_diff"].shape + cfg = types.SimpleNamespace(enable_fp16=False) + device = torch.device("cpu") + return _model_forward_and_residual( + _ZeroModel(), + inp["trainX"], + inp["x_syn_diff"], + inp["z_syn_diff"], + basis, + self._B, + self._D * self._D, + T, + inp["Hx_idx"], + inp["Hz_idx"], + inp["Hx_mask"], + inp["Hz_mask"], + inp["Kx"], + inp["Kz"], + inp["stab_x"], + inp["stab_z"], + inp["Lx"], + inp["Lz"], + 0.0, + 0.0, + "threshold", + 1.0, + 1.0, + cfg, + device, + inp["num_boundary_dets"], + inp["baseline_detectors_batch"], + inp["det_model"], + ) + + def test_output_shapes(self): + inp = self._build_inputs() + residual_np, pre_L_np = self._call() + self.assertEqual(residual_np.shape, (self._B, inp["det_model"].num_detectors)) + self.assertEqual(pre_L_np.shape, (self._B,)) + + def test_residual_is_binary_uint8(self): + residual_np, _ = self._call() + self.assertEqual(residual_np.dtype, np.uint8) + self.assertTrue(np.all((residual_np == 0) | (residual_np == 1))) + + def test_pre_l_is_binary(self): + _, pre_L_np = self._call() + self.assertTrue(np.all((pre_L_np == 0) | (pre_L_np == 1))) + + def test_z_basis_output_shapes(self): + inp = self._build_inputs("Z") + residual_np, pre_L_np = self._call("Z") + self.assertEqual(residual_np.shape, (self._B, inp["det_model"].num_detectors)) + self.assertEqual(pre_L_np.shape, (self._B,)) + + +class TestRunDecodersOnBatch(unittest.TestCase): + """_run_decoders_on_batch must return binary finals for every decoder and a valid agreement count.""" + + _D = 3 + _T = 3 + _B = 4 + + def setUp(self): + from evaluation.failure_analysis import ( + _build_all_decoders, + _build_logical_operators, + _model_forward_and_residual, + _run_decoders_on_batch, + DECODER_NAMES, + ) + import types + + det_model = _make_tiny_dem() + matcher_corr, matcher_uncorr, ldpc_decoders, cudaq_decoders = _build_all_decoders( + det_model, _DummyDist() + ) + self.decoder_names = list(DECODER_NAMES) + self.cudaq_decoder_names = sorted(cudaq_decoders.keys()) + self.decoder_names += self.cudaq_decoder_names + + from data.datapipe_stim import QCDataPipePreDecoder_Memory_inference + ds = QCDataPipePreDecoder_Memory_inference( + distance=self._D, + n_rounds=self._T, + num_samples=self._B, + error_mode="circuit_level_surface_custom", + p_error=0.01, + measure_basis="X", + code_rotation="XV", + ) + items = [ds[i] for i in range(self._B)] + x_syn_diff = torch.stack([it["x_syn_diff"] for it in items]).to(torch.int32) + z_syn_diff = torch.stack([it["z_syn_diff"] for it in items]).to(torch.int32) + trainX = torch.stack([it["trainX"] for it in items]) + stim_dets = np.asarray(ds.dets_and_obs[:, :-1], dtype=np.uint8) + stim_obs = np.asarray(ds.dets_and_obs[:, -1:], dtype=np.uint8) + baseline_detectors_batch = stim_dets[:self._B] + num_boundary_dets = ds.circ.code.hx.shape[0] + _, _, T = x_syn_diff.shape + ops = _build_logical_operators(self._D, "XV", torch.device("cpu")) + Hx_idx, Hz_idx, Hx_mask, Hz_mask, stab_x, stab_z, Kx, Kz, Lx, Lz = ops + cfg = types.SimpleNamespace(enable_fp16=False) + device = torch.device("cpu") + residual_np, pre_L_np = _model_forward_and_residual( + _ZeroModel(), + trainX, + x_syn_diff, + z_syn_diff, + "X", + self._B, + self._D * self._D, + T, + Hx_idx, + Hz_idx, + Hx_mask, + Hz_mask, + Kx, + Kz, + stab_x, + stab_z, + Lx, + Lz, + 0.0, + 0.0, + "threshold", + 1.0, + 1.0, + cfg, + device, + num_boundary_dets, + baseline_detectors_batch, + det_model, + ) + self.residual_np = residual_np + self.pre_L_np = pre_L_np + self.weights = residual_np.sum(axis=1) + self.gt_obs_np = stim_obs[:self._B].reshape(-1).astype(np.int64) + self.ldpc_decoders = ldpc_decoders + self.cudaq_decoders = cudaq_decoders + self.matcher_uncorr = matcher_uncorr + self.matcher_corr = matcher_corr + self._fn = _run_decoders_on_batch + + def _run(self): + _timing = { + k: 0.0 for k in ( + "uf_decode", + "bp_only_decode", + "bplsd_decode", + "uncorr_pm", + "corr_pm", + "bookkeeping", + ) + } + for cn in self.cudaq_decoder_names: + _timing[f"{cn}_decode"] = 0.0 + _cudaq_stats = { + cn: { + "converged_flags": [], + "iter_counts": [], + "error_flags": [] + } for cn in self.cudaq_decoder_names + } + weight_bucket_stats = {} + all_finals, n_agree = self._fn( + self.residual_np, + self.pre_L_np, + self.weights, + self.ldpc_decoders, + self.cudaq_decoders, + self.matcher_uncorr, + self.matcher_corr, + self.cudaq_decoder_names, + self.decoder_names, + self.gt_obs_np, + _timing, + _cudaq_stats, + weight_bucket_stats, + ) + return all_finals, n_agree, _timing, weight_bucket_stats + + def test_all_decoder_keys_present(self): + all_finals, _, _, _ = self._run() + for name in self.decoder_names: + self.assertIn(name, all_finals) + + def test_finals_are_binary(self): + all_finals, _, _, _ = self._run() + for name, arr in all_finals.items(): + with self.subTest(decoder=name): + self.assertTrue(np.all((arr == 0) | (arr == 1))) + + def test_finals_have_correct_shape(self): + all_finals, _, _, _ = self._run() + for name, arr in all_finals.items(): + with self.subTest(decoder=name): + self.assertEqual(arr.shape, (self._B,)) + + def test_n_agree_within_bounds(self): + _, n_agree, _, _ = self._run() + self.assertGreaterEqual(n_agree, 0) + self.assertLessEqual(n_agree, self._B) + + def test_timing_keys_populated(self): + _, _, _timing, _ = self._run() + for key in ("uf_decode", "bp_only_decode", "bplsd_decode", "uncorr_pm", "corr_pm"): + self.assertGreaterEqual(_timing[key], 0.0) + + def test_weight_bucket_stats_populated(self): + _, _, _, weight_bucket_stats = self._run() + self.assertGreater(len(weight_bucket_stats), 0) + for bucket, stats in weight_bucket_stats.items(): + self.assertIn("_total", stats) + self.assertGreater(stats["_total"], 0) + + class TestDecoderAblationStudy(unittest.TestCase): """ Smoke test: decoder_ablation_study must complete, return expected keys, From a6a52da7f5b2fed7dec393f2e78bacf1ea936f5d Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Thu, 19 Mar 2026 20:29:02 -0700 Subject: [PATCH 17/20] tracking unavailable decoders Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 24 ++++++++++++++++++------ code/tests/test_failure_analysis.py | 29 +++++++++++++++++------------ 2 files changed, 35 insertions(+), 18 deletions(-) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index f40816d..4cf4a6d 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -70,6 +70,8 @@ def _build_cudaq_decoders(det_model): mem_kwargs = dict(max_iterations=100, error_rate_vec=priors_list, opt_results=opt_res) decoders = {} + # list of cudaq decoder names that failed to initialize + unavailable = [] # --- Standard BP variants (max_iterations=10) --- # Sum-product BP (no OSD) @@ -127,6 +129,7 @@ def _build_cudaq_decoders(det_model): except Exception as e: import warnings warnings.warn(f"cudaq-qec MemBP unavailable: {e}") + unavailable.extend(["cudaq-MemBP", "cudaq-MemBP+OSD"]) # --- RelayBP (max_iterations=100) --- # composition=1 (sequential relay), bp_method=3 (min-sum+dmem) @@ -156,8 +159,9 @@ def _build_cudaq_decoders(det_model): except Exception as e: import warnings warnings.warn(f"cudaq-qec RelayBP unavailable: {e}") + unavailable.append("cudaq-RelayBP") - return decoders + return decoders, unavailable def _decode_cudaq_batch(decoder, L_dense, syndromes_np): @@ -260,14 +264,17 @@ def _build_all_decoders(det_model, dist): ) ldpc_decoders = _build_ldpc_decoders(det_model) cudaq_decoders = {} + unavailable_decoders = [] try: - cudaq_decoders = _build_cudaq_decoders(det_model) + cudaq_decoders, unavailable_decoders = _build_cudaq_decoders(det_model) if dist.rank == 0: print(f"[Decoder Ablation] cudaq-qec decoders loaded: {list(cudaq_decoders.keys())}") + if unavailable_decoders: + print(f"[Decoder Ablation] cudaq-qec decoders unavailable: {unavailable_decoders}") except Exception as e: if dist.rank == 0: print(f"[Decoder Ablation] cudaq-qec decoders unavailable: {e}") - return matcher_corr, matcher_uncorr, ldpc_decoders, cudaq_decoders + return matcher_corr, matcher_uncorr, ldpc_decoders, cudaq_decoders, unavailable_decoders def _build_logical_operators(D, code_rotation, device): @@ -528,6 +535,7 @@ def _print_ablation_results( decoder_errors, decoder_names, cudaq_decoder_names, + unavailable_decoders, _cudaq_stats, n_all_agree, all_residual_weights, @@ -561,6 +569,9 @@ def _print_ablation_results( for name in decoder_names: ler = decoder_errors[name] / max(1, total_scanned) print(f" {name:<25s} LER = {ler:.6f} ({decoder_errors[name]} errors)") + if unavailable_decoders: + for name in unavailable_decoders: + print(f" {name:<25s} LER = {'N/A':>13s} (unavailable)") # cudaq decoder convergence and iteration stats if _cudaq_stats: @@ -719,9 +730,8 @@ def decoder_ablation_study(model, device, dist, cfg): ) # --- Decoders --- - matcher_corr, matcher_uncorr, ldpc_decoders, cudaq_decoders = _build_all_decoders( - det_model, dist - ) + matcher_corr, matcher_uncorr, ldpc_decoders, cudaq_decoders, unavailable_decoders = \ + _build_all_decoders(det_model, dist) cudaq_decoder_names = sorted(cudaq_decoders.keys()) decoder_names = list(DECODER_NAMES) + cudaq_decoder_names @@ -901,6 +911,7 @@ def decoder_ablation_study(model, device, dist, cfg): decoder_errors, decoder_names, cudaq_decoder_names, + unavailable_decoders, _cudaq_stats, n_all_agree, all_residual_weights, @@ -917,6 +928,7 @@ def decoder_ablation_study(model, device, dist, cfg): "baseline_weights": all_baseline_weights, "weight_bucket_stats": weight_bucket_stats, "agreement_count": n_all_agree, + "unavailable_decoders": unavailable_decoders, } if dist.rank == 0 else {} ) diff --git a/code/tests/test_failure_analysis.py b/code/tests/test_failure_analysis.py index e715a93..99fae37 100644 --- a/code/tests/test_failure_analysis.py +++ b/code/tests/test_failure_analysis.py @@ -172,23 +172,27 @@ def setUp(self): self.result = _build_all_decoders(self.det_model, _DummyDist()) self.LDPC_DECODER_NAMES = LDPC_DECODER_NAMES - def test_returns_four_values(self): - self.assertEqual(len(self.result), 4) + def test_returns_five_values(self): + self.assertEqual(len(self.result), 5) def test_matchers_have_decode_method(self): - matcher_corr, matcher_uncorr, _, _ = self.result + matcher_corr, matcher_uncorr, _, _, _ = self.result self.assertTrue(hasattr(matcher_corr, "decode")) self.assertTrue(hasattr(matcher_uncorr, "decode")) def test_ldpc_decoders_contains_all_names(self): - _, _, ldpc_decoders, _ = self.result + _, _, ldpc_decoders, _, _ = self.result for name in self.LDPC_DECODER_NAMES: self.assertIn(name, ldpc_decoders) def test_cudaq_decoders_is_dict(self): - _, _, _, cudaq_decoders = self.result + _, _, _, cudaq_decoders, _ = self.result self.assertIsInstance(cudaq_decoders, dict) + def test_unavailable_decoders_is_list(self): + _, _, _, _, unavailable = self.result + self.assertIsInstance(unavailable, list) + class TestBuildLogicalOperators(unittest.TestCase): """_build_logical_operators must return tensors of the correct shape and values.""" @@ -369,7 +373,7 @@ def setUp(self): import types det_model = _make_tiny_dem() - matcher_corr, matcher_uncorr, ldpc_decoders, cudaq_decoders = _build_all_decoders( + matcher_corr, matcher_uncorr, ldpc_decoders, cudaq_decoders, _ = _build_all_decoders( det_model, _DummyDist() ) self.decoder_names = list(DECODER_NAMES) @@ -555,6 +559,7 @@ def test_return_keys_present(self): "residual_weights", "weight_bucket_stats", "agreement_count", + "unavailable_decoders", ): self.assertIn(key, result, f"Missing key in result: {key}") @@ -706,7 +711,7 @@ def test_standard_bp_decoders_present(self): det_model = _make_tiny_dem() mock_cudaq = self._make_mock_cudaq_qec(n_bits=10) with patch.dict("sys.modules", {"cudaq_qec": mock_cudaq}): - decoders = _build_cudaq_decoders(det_model) + decoders, _ = _build_cudaq_decoders(det_model) for name in ("cudaq-BP", "cudaq-MinSum", "cudaq-BP+OSD-0", "cudaq-BP+OSD-7"): self.assertIn(name, decoders, f"Missing decoder key: {name}") @@ -715,7 +720,7 @@ def test_each_entry_is_decoder_and_l_dense_pair(self): det_model = _make_tiny_dem() mock_cudaq = self._make_mock_cudaq_qec(n_bits=10) with patch.dict("sys.modules", {"cudaq_qec": mock_cudaq}): - decoders = _build_cudaq_decoders(det_model) + decoders, _ = _build_cudaq_decoders(det_model) for name, (dec, L_dense) in decoders.items(): with self.subTest(decoder=name): self.assertTrue(hasattr(dec, "decode"), f"{name} has no .decode()") @@ -728,7 +733,7 @@ def test_l_dense_columns_consistent_across_decoders(self): det_model = _make_tiny_dem() mock_cudaq = self._make_mock_cudaq_qec(n_bits=10) with patch.dict("sys.modules", {"cudaq_qec": mock_cudaq}): - decoders = _build_cudaq_decoders(det_model) + decoders, _ = _build_cudaq_decoders(det_model) widths = [v[1].shape[1] for v in decoders.values()] self.assertEqual(len(set(widths)), 1, "All L_dense must have the same column count") @@ -751,7 +756,7 @@ def flaky_get_decoder(name, H, **kw): import warnings with warnings.catch_warnings(record=True): warnings.simplefilter("always") - decoders = _build_cudaq_decoders(det_model) + decoders, _ = _build_cudaq_decoders(det_model) # At minimum the 4 standard decoders should be present self.assertGreaterEqual(len(decoders), 4) for name in ("cudaq-BP", "cudaq-MinSum", "cudaq-BP+OSD-0", "cudaq-BP+OSD-7"): @@ -802,7 +807,7 @@ def test_cudaq_decoder_keys_appear_in_results_when_available(self): ) with patch("data.factory.DatapipeFactory") as mock_factory, \ patch("evaluation.failure_analysis._build_cudaq_decoders", - return_value=dummy_cudaq_decoders): + return_value=(dummy_cudaq_decoders, [])): mock_factory.create_datapipe_inference.return_value = real_ds result = decoder_ablation_study(_ZeroModel(), _DummyDist.device, _DummyDist(), cfg) @@ -830,7 +835,7 @@ def test_cudaq_error_counts_are_non_negative(self): ) with patch("data.factory.DatapipeFactory") as mock_factory, \ patch("evaluation.failure_analysis._build_cudaq_decoders", - return_value=dummy_cudaq_decoders): + return_value=(dummy_cudaq_decoders, [])): mock_factory.create_datapipe_inference.return_value = real_ds result = decoder_ablation_study(_ZeroModel(), _DummyDist.device, _DummyDist(), cfg) From 443dfa880e7d0ff3c176384eacc4edbf883c69f7 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Fri, 20 Mar 2026 14:18:07 -0700 Subject: [PATCH 18/20] adding BP variants in try/except block Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 53 ++++++++++++++++------------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index 4cf4a6d..f973b90 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -74,30 +74,35 @@ def _build_cudaq_decoders(det_model): unavailable = [] # --- Standard BP variants (max_iterations=10) --- - # Sum-product BP (no OSD) - decoders["cudaq-BP"] = ( - cudaq_qec.get_decoder("nv-qldpc-decoder", H_dense, bp_method=0, use_osd=0, **bp_kwargs), - L_dense, - ) - # Min-sum BP (no OSD) - decoders["cudaq-MinSum"] = ( - cudaq_qec.get_decoder("nv-qldpc-decoder", H_dense, bp_method=1, use_osd=0, **bp_kwargs), - L_dense, - ) - # Sum-product BP + OSD-0 - decoders["cudaq-BP+OSD-0"] = ( - cudaq_qec.get_decoder( - "nv-qldpc-decoder", H_dense, bp_method=0, use_osd=1, osd_order=0, **bp_kwargs - ), - L_dense, - ) - # Sum-product BP + OSD-7 - decoders["cudaq-BP+OSD-7"] = ( - cudaq_qec.get_decoder( - "nv-qldpc-decoder", H_dense, bp_method=0, use_osd=1, osd_order=7, **bp_kwargs - ), - L_dense, - ) + try: + # Sum-product BP (no OSD) + decoders["cudaq-BP"] = ( + cudaq_qec.get_decoder("nv-qldpc-decoder", H_dense, bp_method=0, use_osd=0, **bp_kwargs), + L_dense, + ) + # Min-sum BP (no OSD) + decoders["cudaq-MinSum"] = ( + cudaq_qec.get_decoder("nv-qldpc-decoder", H_dense, bp_method=1, use_osd=0, **bp_kwargs), + L_dense, + ) + # Sum-product BP + OSD-0 + decoders["cudaq-BP+OSD-0"] = ( + cudaq_qec.get_decoder( + "nv-qldpc-decoder", H_dense, bp_method=0, use_osd=1, osd_order=0, **bp_kwargs + ), + L_dense, + ) + # Sum-product BP + OSD-7 + decoders["cudaq-BP+OSD-7"] = ( + cudaq_qec.get_decoder( + "nv-qldpc-decoder", H_dense, bp_method=0, use_osd=1, osd_order=7, **bp_kwargs + ), + L_dense, + ) + except Exception as e: + import warnings + warnings.warn(f"cudaq-qec BP unavailable: {e}") + unavailable.extend(["cudaq-BP", "cudaq-MinSum", "cudaq-BP+OSD-0", "cudaq-BP+OSD-7"]) # --- Memory BP variants (max_iterations=100) --- try: From 98371430c8dda8f6c192b9431d19c5dab017466f Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Fri, 20 Mar 2026 14:27:18 -0700 Subject: [PATCH 19/20] removing redundant check Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index f973b90..770fb4a 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -520,7 +520,7 @@ def _run_decoders_on_batch( weight_bucket_stats[bucket] = {n: [0, 0] for n in decoder_names} weight_bucket_stats[bucket]["_total"] = weight_bucket_stats[bucket].get("_total", 0) + 1 for name in decoder_names: - if bucket not in weight_bucket_stats or name not in weight_bucket_stats[bucket]: + if name not in weight_bucket_stats[bucket]: weight_bucket_stats[bucket][name] = [0, 0] weight_bucket_stats[bucket][name][1] += 1 if all_finals[name][i] != gt_obs_np[i]: From eaed4cc816a2a44f9dd5e0a4dc2f77c5a8c75757 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Fri, 20 Mar 2026 18:43:34 -0700 Subject: [PATCH 20/20] adding modules to install to requirements Signed-off-by: Sachin Pisal --- code/requirements_public_inference.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/code/requirements_public_inference.txt b/code/requirements_public_inference.txt index 20c7301..a86d3c2 100644 --- a/code/requirements_public_inference.txt +++ b/code/requirements_public_inference.txt @@ -19,6 +19,9 @@ stim pymatching matplotlib safetensors>=0.4.0 +scipy +ldpc +beliefmatching # Optional GPU-only prerequisite (not pip-installed here due to size and CUDA dependency): # tensorrt -- required for ONNX_WORKFLOW=2 (EXPORT_AND_USE_TRT) and ONNX_WORKFLOW=3 # (USE_ENGINE_ONLY). Install via: pip install tensorrt