From b32c8a51cf9badac3d255d800ab68de39ffbe7a7 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Wed, 25 Mar 2026 14:37:04 -0700 Subject: [PATCH 1/3] adding decode_batch path in failure_analysis and vectorizing observable projection Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 49 +++++++++++--------- code/tests/test_failure_analysis.py | 69 +++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+), 20 deletions(-) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index 4f9a4b1..c6614fe 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -179,27 +179,39 @@ def _build_cudaq_decoders(det_model): def _decode_cudaq_batch(decoder, L_dense, syndromes_np): """ - Decode a batch of syndromes with a cudaq-qec nv-qldpc-decoder (single-shot loop). + Decode a batch of syndromes with a cudaq-qec nv-qldpc-decoder. Returns (obs, stats) where: - obs: observable predictions as np.ndarray of shape (B,) - stats: dict with per-sample convergence flags, iteration counts - The decoder.decode() takes list[float] and returns DecoderResult with .result (list[float]). """ B = syndromes_np.shape[0] - obs = np.zeros(B, dtype=np.uint8) + n_bits = L_dense.shape[1] converged_flags = np.zeros(B, dtype=bool) iter_counts = np.zeros(B, dtype=np.int32) - for i in range(B): - syndrome_list = syndromes_np[i].astype(np.float64).tolist() - result = decoder.decode(syndrome_list) - correction = np.array(result.result, dtype=np.uint8) - obs[i] = int((L_dense @ correction).item() % - 2) if L_dense.shape[0] == 1 else int((L_dense @ correction)[0] % 2) + corrections = np.empty((B, n_bits), dtype=np.uint8) + syndromes_f64 = np.ascontiguousarray(syndromes_np, dtype=np.float64) + + def _unpack(i, result): + corrections[i] = np.array(result.result, dtype=np.uint8) converged_flags[i] = result.converged - # Collect iteration count if available via opt_results opt = getattr(result, 'opt_results', None) if opt and isinstance(opt, dict) and 'num_iter' in opt: iter_counts[i] = opt['num_iter'] + + def _loop_decode(): + for i in range(B): + _unpack(i, decoder.decode(syndromes_f64[i].tolist())) + + if hasattr(decoder, 'decode_batch'): + try: + for i, result in enumerate(decoder.decode_batch(syndromes_f64)): + _unpack(i, result) + except Exception: + _loop_decode() + else: + _loop_decode() + + obs = ((corrections.astype(np.int32) @ L_dense.T.astype(np.int32))[:, 0] % 2).astype(np.uint8) return obs, {"converged_flags": converged_flags, "iter_counts": iter_counts} @@ -249,20 +261,17 @@ def _build_ldpc_decoders(det_model): def _decode_ldpc_batch(decoder, L_dense, syndromes_np): """ - Decode a batch of syndromes with an ldpc decoder (single-shot loop). + Decode a batch of syndromes with an ldpc decoder. Returns observable predictions as np.ndarray of shape (B,). """ B = syndromes_np.shape[0] - obs = np.zeros(B, dtype=np.uint8) + n_bits = L_dense.shape[1] + syndromes_c = np.ascontiguousarray(syndromes_np, dtype=np.uint8) + corrections = np.empty((B, n_bits), dtype=np.uint8) for i in range(B): - # Get the most-likely error configuration from the decoder for this syndrome. - correction = decoder.decode(syndromes_np[i]) - # Project the correction onto the logical observable via L_dense (mod 2). - # L_dense has shape (num_obs, num_errors); the first observable row is used. - obs[i] = ( - int((L_dense @ correction).item() % - 2) if L_dense.shape[0] == 1 else int((L_dense @ correction)[0] % 2) - ) + corrections[i] = decoder.decode(syndromes_c[i]) + + obs = ((corrections.astype(np.int32) @ L_dense.T.astype(np.int32))[:, 0] % 2).astype(np.uint8) return obs diff --git a/code/tests/test_failure_analysis.py b/code/tests/test_failure_analysis.py index 36ff37c..48707b8 100644 --- a/code/tests/test_failure_analysis.py +++ b/code/tests/test_failure_analysis.py @@ -635,6 +635,21 @@ def decode(self, syndrome): return _DummyCudaqResult(np.zeros(self._n_bits, dtype=np.float64)) +class _DummyCudaqDecoderBatch: + """Mock cudaq-qec decoder that exposes decode_batch() for the fast path""" + + def __init__(self, n_bits): + self._n_bits = n_bits + + def decode(self, syndrome): + return _DummyCudaqResult(np.zeros(self._n_bits, dtype=np.float64)) + + def decode_batch(self, syndromes): + """Accept (B, n_dets) float64 array, return list of DecoderResults""" + B = syndromes.shape[0] + return [_DummyCudaqResult(np.zeros(self._n_bits, dtype=np.float64)) for _ in range(B)] + + class TestDecodeCudaqBatch(unittest.TestCase): """_decode_cudaq_batch must return correct shape/dtype and collect stats""" @@ -701,6 +716,60 @@ def test_multi_observable_uses_first_row(self): self.assertEqual(obs.shape, (B,)) self.assertTrue(np.all((obs == 0) | (obs == 1))) + def test_decode_batch_fast_path_zero_syndrome(self): + B = 4 + decoder = _DummyCudaqDecoderBatch(self.n_bits) + L_dense = np.zeros((1, self.n_bits), dtype=np.uint8) + syndromes = np.zeros((B, self.n_dets), dtype=np.uint8) + obs, _ = self._fn(decoder, L_dense, syndromes) + np.testing.assert_array_equal(obs, np.zeros(B, dtype=np.uint8)) + + def test_decode_batch_fast_path_output_shape_and_dtype(self): + for B in (1, 5): + decoder = _DummyCudaqDecoderBatch(self.n_bits) + L_dense = np.zeros((1, self.n_bits), dtype=np.uint8) + syndromes = np.zeros((B, self.n_dets), dtype=np.uint8) + obs, stats = self._fn(decoder, L_dense, syndromes) + self.assertEqual(obs.shape, (B,)) + self.assertEqual(obs.dtype, np.uint8) + self.assertEqual(stats["converged_flags"].shape, (B,)) + self.assertEqual(stats["iter_counts"].shape, (B,)) + + def test_decode_batch_fast_path_convergence_flags(self): + B = 3 + decoder = _DummyCudaqDecoderBatch(self.n_bits) + L_dense = np.zeros((1, self.n_bits), dtype=np.uint8) + syndromes = np.zeros((B, self.n_dets), dtype=np.uint8) + _, stats = self._fn(decoder, L_dense, syndromes) + self.assertTrue(np.all(stats["converged_flags"])) + np.testing.assert_array_equal(stats["iter_counts"], np.full(B, 10, dtype=np.int32)) + + def test_decode_batch_and_loop_paths_agree(self): + B = 4 + n_bits = self.n_bits + L_dense = np.zeros((1, n_bits), dtype=np.uint8) + syndromes = np.zeros((B, self.n_dets), dtype=np.uint8) + + loop_decoder = _DummyCudaqDecoder(n_bits) + batch_decoder = _DummyCudaqDecoderBatch(n_bits) + + obs_loop, stats_loop = self._fn(loop_decoder, L_dense, syndromes) + obs_batch, stats_batch = self._fn(batch_decoder, L_dense, syndromes) + + np.testing.assert_array_equal(obs_loop, obs_batch) + np.testing.assert_array_equal(stats_loop["converged_flags"], stats_batch["converged_flags"]) + np.testing.assert_array_equal(stats_loop["iter_counts"], stats_batch["iter_counts"]) + + def test_decode_batch_called_not_decode(self): + from unittest.mock import patch + B = 3 + decoder = _DummyCudaqDecoderBatch(self.n_bits) + L_dense = np.zeros((1, self.n_bits), dtype=np.uint8) + syndromes = np.zeros((B, self.n_dets), dtype=np.uint8) + with patch.object(decoder, 'decode', wraps=decoder.decode) as mock_decode: + self._fn(decoder, L_dense, syndromes) + mock_decode.assert_not_called() + class TestBuildCudaqDecoders(unittest.TestCase): """_build_cudaq_decoders must return correctly keyed entries when cudaq_qec is available""" From f737989494fe5a2ee77ebcd296290d51f8a075d5 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Thu, 2 Apr 2026 21:20:19 -0700 Subject: [PATCH 2/3] pass syndromes as list-of-lists to cudaq decode_batch Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 2 +- code/tests/test_failure_analysis.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index c6614fe..8fc215a 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -204,7 +204,7 @@ def _loop_decode(): if hasattr(decoder, 'decode_batch'): try: - for i, result in enumerate(decoder.decode_batch(syndromes_f64)): + for i, result in enumerate(decoder.decode_batch(syndromes_f64.tolist())): _unpack(i, result) except Exception: _loop_decode() diff --git a/code/tests/test_failure_analysis.py b/code/tests/test_failure_analysis.py index 48707b8..bb72e18 100644 --- a/code/tests/test_failure_analysis.py +++ b/code/tests/test_failure_analysis.py @@ -645,8 +645,8 @@ def decode(self, syndrome): return _DummyCudaqResult(np.zeros(self._n_bits, dtype=np.float64)) def decode_batch(self, syndromes): - """Accept (B, n_dets) float64 array, return list of DecoderResults""" - B = syndromes.shape[0] + """Accept list-of-lists of float64, return list of DecoderResults""" + B = len(syndromes) return [_DummyCudaqResult(np.zeros(self._n_bits, dtype=np.float64)) for _ in range(B)] From 7459688a2a735893352f564353780a7e86cdc416 Mon Sep 17 00:00:00 2001 From: Sachin Pisal Date: Fri, 3 Apr 2026 13:47:43 -0700 Subject: [PATCH 3/3] implementing feedback Signed-off-by: Sachin Pisal --- code/evaluation/failure_analysis.py | 15 ++++++------ code/tests/test_failure_analysis.py | 36 +++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 7 deletions(-) diff --git a/code/evaluation/failure_analysis.py b/code/evaluation/failure_analysis.py index 8fc215a..56e476a 100644 --- a/code/evaluation/failure_analysis.py +++ b/code/evaluation/failure_analysis.py @@ -18,6 +18,7 @@ """ import os import random +import warnings import numpy as np import torch @@ -202,14 +203,14 @@ def _loop_decode(): for i in range(B): _unpack(i, decoder.decode(syndromes_f64[i].tolist())) - if hasattr(decoder, 'decode_batch'): - try: - for i, result in enumerate(decoder.decode_batch(syndromes_f64.tolist())): - _unpack(i, result) - except Exception: - _loop_decode() - else: + try: + results = decoder.decode_batch(syndromes_f64.tolist()) + except Exception as exc: + warnings.warn(f"decode_batch failed ({exc}); falling back to per-sample loop") _loop_decode() + else: + for i, result in enumerate(results): + _unpack(i, result) obs = ((corrections.astype(np.int32) @ L_dense.T.astype(np.int32))[:, 0] % 2).astype(np.uint8) return obs, {"converged_flags": converged_flags, "iter_counts": iter_counts} diff --git a/code/tests/test_failure_analysis.py b/code/tests/test_failure_analysis.py index bb72e18..a29e6d5 100644 --- a/code/tests/test_failure_analysis.py +++ b/code/tests/test_failure_analysis.py @@ -770,6 +770,42 @@ def test_decode_batch_called_not_decode(self): self._fn(decoder, L_dense, syndromes) mock_decode.assert_not_called() + def test_decode_batch_exception_falls_back_to_loop(self): + """If decode_batch raises, per-sample decode is used and a warning is emitted.""" + import warnings + from unittest.mock import patch + B = 3 + decoder = _DummyCudaqDecoderBatch(self.n_bits) + L_dense = np.zeros((1, self.n_bits), dtype=np.uint8) + syndromes = np.zeros((B, self.n_dets), dtype=np.uint8) + with patch.object(decoder, 'decode_batch', side_effect=RuntimeError("gpu unavailable")): + with patch.object(decoder, 'decode', wraps=decoder.decode) as mock_decode: + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + obs, stats = self._fn(decoder, L_dense, syndromes) + self.assertEqual(mock_decode.call_count, B) + self.assertEqual(obs.shape, (B,)) + self.assertEqual(len(caught), 1) + self.assertIn("gpu unavailable", str(caught[0].message)) + self.assertIn("falling back", str(caught[0].message)) + + def test_no_decode_batch_attribute_uses_loop(self): + """Decoder without decode_batch falls back to per-sample loop via AttributeError.""" + import warnings + from unittest.mock import patch + B = 3 + decoder = _DummyCudaqDecoder(self.n_bits) # no decode_batch + L_dense = np.zeros((1, self.n_bits), dtype=np.uint8) + syndromes = np.zeros((B, self.n_dets), dtype=np.uint8) + with patch.object(decoder, 'decode', wraps=decoder.decode) as mock_decode: + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + obs, stats = self._fn(decoder, L_dense, syndromes) + self.assertEqual(mock_decode.call_count, B) + self.assertEqual(obs.shape, (B,)) + self.assertEqual(len(caught), 1) + self.assertIn("falling back", str(caught[0].message)) + class TestBuildCudaqDecoders(unittest.TestCase): """_build_cudaq_decoders must return correctly keyed entries when cudaq_qec is available"""