Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 30 additions & 20 deletions code/evaluation/failure_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""
import os
import random
import warnings

import numpy as np
import torch
Expand Down Expand Up @@ -179,27 +180,39 @@ def _build_cudaq_decoders(det_model):

def _decode_cudaq_batch(decoder, L_dense, syndromes_np):
"""
Decode a batch of syndromes with a cudaq-qec nv-qldpc-decoder (single-shot loop).
Decode a batch of syndromes with a cudaq-qec nv-qldpc-decoder.
Returns (obs, stats) where:
- obs: observable predictions as np.ndarray of shape (B,)
- stats: dict with per-sample convergence flags, iteration counts
The decoder.decode() takes list[float] and returns DecoderResult with .result (list[float]).
"""
B = syndromes_np.shape[0]
obs = np.zeros(B, dtype=np.uint8)
n_bits = L_dense.shape[1]
converged_flags = np.zeros(B, dtype=bool)
iter_counts = np.zeros(B, dtype=np.int32)
for i in range(B):
syndrome_list = syndromes_np[i].astype(np.float64).tolist()
result = decoder.decode(syndrome_list)
correction = np.array(result.result, dtype=np.uint8)
obs[i] = int((L_dense @ correction).item() %
2) if L_dense.shape[0] == 1 else int((L_dense @ correction)[0] % 2)
corrections = np.empty((B, n_bits), dtype=np.uint8)
syndromes_f64 = np.ascontiguousarray(syndromes_np, dtype=np.float64)

def _unpack(i, result):
corrections[i] = np.array(result.result, dtype=np.uint8)
converged_flags[i] = result.converged
# Collect iteration count if available via opt_results
opt = getattr(result, 'opt_results', None)
if opt and isinstance(opt, dict) and 'num_iter' in opt:
iter_counts[i] = opt['num_iter']

def _loop_decode():
for i in range(B):
_unpack(i, decoder.decode(syndromes_f64[i].tolist()))

try:
results = decoder.decode_batch(syndromes_f64.tolist())
except Exception as exc:
warnings.warn(f"decode_batch failed ({exc}); falling back to per-sample loop")
_loop_decode()
Comment thread
sacpis marked this conversation as resolved.
else:
for i, result in enumerate(results):
_unpack(i, result)

obs = ((corrections.astype(np.int32) @ L_dense.T.astype(np.int32))[:, 0] % 2).astype(np.uint8)
return obs, {"converged_flags": converged_flags, "iter_counts": iter_counts}


Expand Down Expand Up @@ -249,20 +262,17 @@ def _build_ldpc_decoders(det_model):

def _decode_ldpc_batch(decoder, L_dense, syndromes_np):
"""
Decode a batch of syndromes with an ldpc decoder (single-shot loop).
Decode a batch of syndromes with an ldpc decoder.
Returns observable predictions as np.ndarray of shape (B,).
"""
B = syndromes_np.shape[0]
obs = np.zeros(B, dtype=np.uint8)
n_bits = L_dense.shape[1]
syndromes_c = np.ascontiguousarray(syndromes_np, dtype=np.uint8)
corrections = np.empty((B, n_bits), dtype=np.uint8)
for i in range(B):
# Get the most-likely error configuration from the decoder for this syndrome.
correction = decoder.decode(syndromes_np[i])
# Project the correction onto the logical observable via L_dense (mod 2).
# L_dense has shape (num_obs, num_errors); the first observable row is used.
obs[i] = (
int((L_dense @ correction).item() %
2) if L_dense.shape[0] == 1 else int((L_dense @ correction)[0] % 2)
)
corrections[i] = decoder.decode(syndromes_c[i])

obs = ((corrections.astype(np.int32) @ L_dense.T.astype(np.int32))[:, 0] % 2).astype(np.uint8)
return obs


Expand Down
105 changes: 105 additions & 0 deletions code/tests/test_failure_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,21 @@ def decode(self, syndrome):
return _DummyCudaqResult(np.zeros(self._n_bits, dtype=np.float64))


class _DummyCudaqDecoderBatch:
"""Mock cudaq-qec decoder that exposes decode_batch() for the fast path"""

def __init__(self, n_bits):
self._n_bits = n_bits

def decode(self, syndrome):
return _DummyCudaqResult(np.zeros(self._n_bits, dtype=np.float64))

def decode_batch(self, syndromes):
"""Accept list-of-lists of float64, return list of DecoderResults"""
B = len(syndromes)
return [_DummyCudaqResult(np.zeros(self._n_bits, dtype=np.float64)) for _ in range(B)]


class TestDecodeCudaqBatch(unittest.TestCase):
"""_decode_cudaq_batch must return correct shape/dtype and collect stats"""

Expand Down Expand Up @@ -701,6 +716,96 @@ def test_multi_observable_uses_first_row(self):
self.assertEqual(obs.shape, (B,))
self.assertTrue(np.all((obs == 0) | (obs == 1)))

def test_decode_batch_fast_path_zero_syndrome(self):
B = 4
decoder = _DummyCudaqDecoderBatch(self.n_bits)
L_dense = np.zeros((1, self.n_bits), dtype=np.uint8)
syndromes = np.zeros((B, self.n_dets), dtype=np.uint8)
obs, _ = self._fn(decoder, L_dense, syndromes)
np.testing.assert_array_equal(obs, np.zeros(B, dtype=np.uint8))

def test_decode_batch_fast_path_output_shape_and_dtype(self):
for B in (1, 5):
decoder = _DummyCudaqDecoderBatch(self.n_bits)
L_dense = np.zeros((1, self.n_bits), dtype=np.uint8)
syndromes = np.zeros((B, self.n_dets), dtype=np.uint8)
obs, stats = self._fn(decoder, L_dense, syndromes)
self.assertEqual(obs.shape, (B,))
self.assertEqual(obs.dtype, np.uint8)
self.assertEqual(stats["converged_flags"].shape, (B,))
self.assertEqual(stats["iter_counts"].shape, (B,))

def test_decode_batch_fast_path_convergence_flags(self):
B = 3
decoder = _DummyCudaqDecoderBatch(self.n_bits)
L_dense = np.zeros((1, self.n_bits), dtype=np.uint8)
syndromes = np.zeros((B, self.n_dets), dtype=np.uint8)
_, stats = self._fn(decoder, L_dense, syndromes)
self.assertTrue(np.all(stats["converged_flags"]))
np.testing.assert_array_equal(stats["iter_counts"], np.full(B, 10, dtype=np.int32))

def test_decode_batch_and_loop_paths_agree(self):
B = 4
n_bits = self.n_bits
L_dense = np.zeros((1, n_bits), dtype=np.uint8)
syndromes = np.zeros((B, self.n_dets), dtype=np.uint8)

loop_decoder = _DummyCudaqDecoder(n_bits)
batch_decoder = _DummyCudaqDecoderBatch(n_bits)

obs_loop, stats_loop = self._fn(loop_decoder, L_dense, syndromes)
obs_batch, stats_batch = self._fn(batch_decoder, L_dense, syndromes)

np.testing.assert_array_equal(obs_loop, obs_batch)
np.testing.assert_array_equal(stats_loop["converged_flags"], stats_batch["converged_flags"])
np.testing.assert_array_equal(stats_loop["iter_counts"], stats_batch["iter_counts"])

def test_decode_batch_called_not_decode(self):
from unittest.mock import patch
B = 3
decoder = _DummyCudaqDecoderBatch(self.n_bits)
L_dense = np.zeros((1, self.n_bits), dtype=np.uint8)
syndromes = np.zeros((B, self.n_dets), dtype=np.uint8)
with patch.object(decoder, 'decode', wraps=decoder.decode) as mock_decode:
self._fn(decoder, L_dense, syndromes)
mock_decode.assert_not_called()

def test_decode_batch_exception_falls_back_to_loop(self):
"""If decode_batch raises, per-sample decode is used and a warning is emitted."""
import warnings
from unittest.mock import patch
B = 3
decoder = _DummyCudaqDecoderBatch(self.n_bits)
L_dense = np.zeros((1, self.n_bits), dtype=np.uint8)
syndromes = np.zeros((B, self.n_dets), dtype=np.uint8)
with patch.object(decoder, 'decode_batch', side_effect=RuntimeError("gpu unavailable")):
with patch.object(decoder, 'decode', wraps=decoder.decode) as mock_decode:
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
obs, stats = self._fn(decoder, L_dense, syndromes)
self.assertEqual(mock_decode.call_count, B)
self.assertEqual(obs.shape, (B,))
self.assertEqual(len(caught), 1)
self.assertIn("gpu unavailable", str(caught[0].message))
self.assertIn("falling back", str(caught[0].message))

def test_no_decode_batch_attribute_uses_loop(self):
"""Decoder without decode_batch falls back to per-sample loop via AttributeError."""
import warnings
from unittest.mock import patch
B = 3
decoder = _DummyCudaqDecoder(self.n_bits) # no decode_batch
L_dense = np.zeros((1, self.n_bits), dtype=np.uint8)
syndromes = np.zeros((B, self.n_dets), dtype=np.uint8)
with patch.object(decoder, 'decode', wraps=decoder.decode) as mock_decode:
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
obs, stats = self._fn(decoder, L_dense, syndromes)
self.assertEqual(mock_decode.call_count, B)
self.assertEqual(obs.shape, (B,))
self.assertEqual(len(caught), 1)
self.assertIn("falling back", str(caught[0].message))


class TestBuildCudaqDecoders(unittest.TestCase):
"""_build_cudaq_decoders must return correctly keyed entries when cudaq_qec is available"""
Expand Down
Loading