From a32f4105f8a5e5bb5921d22729ca1566616243e0 Mon Sep 17 00:00:00 2001 From: Ben Howe Date: Mon, 23 Feb 2026 14:51:30 -0800 Subject: [PATCH 1/4] PyMatching: decode to observables Signed-off-by: Ben Howe --- .../plugins/pymatching/pymatching.cpp | 63 +++++++++++++++---- libs/qec/python/tests/test_dem.py | 45 ++++++++++++- 2 files changed, 96 insertions(+), 12 deletions(-) diff --git a/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp b/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp index 514adfb6..28fbde3b 100644 --- a/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp +++ b/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp @@ -31,6 +31,8 @@ class pymatching : public decoder { // efficient. std::map, size_t> edge2col_idx; + bool decode_to_observables = false; + // Helper function to make a canonical edge from two nodes. std::pair make_canonical_edge(int64_t node1, int64_t node2) { @@ -77,10 +79,34 @@ class pymatching : public decoder { } } + std::vector> errs2observables(block_size); + if (params.contains("O")) { + auto O = params.get>("O"); + int num_observables = O.size() / block_size; + uint8_t *O_ptr = O.data(); + if (O.size() % block_size != 0) { + throw std::runtime_error( + "O must be of size num_observables * block_size"); + } + // Convert O to a sparse matrix and save it. 
+ std::vector> O_sparse; + for (size_t i = 0; i < num_observables; i++) { + O_sparse.emplace_back(); + const auto *row = &O_ptr[i * block_size]; + for (size_t j = 0; j < block_size; j++) { + if (row[j] > 0) { + O_sparse.back().push_back(j); + errs2observables[j].push_back(i); + } + } + } + this->set_O_sparse(O_sparse); + decode_to_observables = true; + } + user_graph = pm::UserGraph(H.shape()[0]); auto sparse = cudaq::qec::dense_to_sparse(H); - std::vector observables; std::size_t col_idx = 0; for (auto &col : sparse) { double weight = 1.0; @@ -90,12 +116,14 @@ class pymatching : public decoder { } if (col.size() == 2) { edge2col_idx[make_canonical_edge(col[0], col[1])] = col_idx; - user_graph.add_or_merge_edge(col[0], col[1], observables, weight, 0.0, + user_graph.add_or_merge_edge(col[0], col[1], + errs2observables.at(col_idx), weight, 0.0, merge_strategy_enum); } else if (col.size() == 1) { edge2col_idx[make_canonical_edge(col[0], -1)] = col_idx; - user_graph.add_or_merge_boundary_edge(col[0], observables, weight, 0.0, - merge_strategy_enum); + user_graph.add_or_merge_boundary_edge(col[0], + errs2observables.at(col_idx), + weight, 0.0, merge_strategy_enum); } else { throw std::runtime_error( "Invalid column in H: " + std::to_string(col_idx) + " has " + @@ -119,13 +147,26 @@ class pymatching : public decoder { for (size_t i = 0; i < syndrome.size(); i++) if (syndrome[i] > 0.5) detection_events.push_back(i); - pm::decode_detection_events_to_edges(mwpm, detection_events, edges); - // Loop over the edge pairs - assert(edges.size() % 2 == 0); - for (size_t i = 0; i < edges.size(); i += 2) { - auto edge = make_canonical_edge(edges.at(i), edges.at(i + 1)); - auto col_idx = edge2col_idx.at(edge); - result.result[col_idx] = 1.0; + if (decode_to_observables) { + assert(O_sparse.size() == mwpm.flooder.graph.num_observables); + pm::total_weight_int weight = 0; + std::vector obs(mwpm.flooder.graph.num_observables, 0); + obs.resize(mwpm.flooder.graph.num_observables); + 
pm::decode_detection_events(mwpm, detection_events, obs.data(), weight, + /*edge_correlations=*/false); + result.result.resize(mwpm.flooder.graph.num_observables); + for (size_t i = 0; i < mwpm.flooder.graph.num_observables; i++) { + result.result[i] = static_cast(obs[i]); + } + } else { + pm::decode_detection_events_to_edges(mwpm, detection_events, edges); + // Loop over the edge pairs to reconstruct errors. + assert(edges.size() % 2 == 0); + for (size_t i = 0; i < edges.size(); i += 2) { + auto edge = make_canonical_edge(edges.at(i), edges.at(i + 1)); + auto col_idx = edge2col_idx.at(edge); + result.result[col_idx] = 1.0; + } } // An exception is thrown if no matching solution is found, so we can just // set converged to true. diff --git a/libs/qec/python/tests/test_dem.py b/libs/qec/python/tests/test_dem.py index 48c3ee6d..69f57a77 100644 --- a/libs/qec/python/tests/test_dem.py +++ b/libs/qec/python/tests/test_dem.py @@ -1,5 +1,5 @@ # ============================================================================ # -# Copyright (c) 2024 - 2025 NVIDIA Corporation & Affiliates. # +# Copyright (c) 2024 - 2026 NVIDIA Corporation & Affiliates. # # All rights reserved. 
# # # # This source code and the accompanying materials are made available under # @@ -275,6 +275,49 @@ def test_decoding_from_surface_code_dem_from_memory_circuit( assert nLogicalErrorsWithDecoding < nLogicalErrorsWithoutDecoding +def test_pymatching_decode_to_observable_surface_code_dem(): + """Test PyMatching with O (observables) matrix: decoder returns observable + flips directly (see pymatching.cpp).""" + cudaq.set_random_seed(13) + code = qec.get_code('surface_code', distance=5) + Lz = code.get_observables_z() + p = 0.003 + noise = cudaq.NoiseModel() + noise.add_all_qubit_channel("x", cudaq.Depolarization2(p), 1) + statePrep = qec.operation.prep0 + nRounds = 5 + nShots = 2000 + + syndromes, data = qec.sample_memory_circuit(code, statePrep, nShots, + nRounds, noise) + + logical_measurements = (Lz @ data.transpose()) % 2 + logical_measurements = logical_measurements.flatten() + + syndromes = syndromes.reshape((nShots, nRounds, -1)) + syndromes = syndromes[:, :, :syndromes.shape[2] // 2] + syndromes = syndromes.reshape((nShots, -1)) + + dem = qec.z_dem_from_memory_circuit(code, statePrep, nRounds, noise) + + decoder = qec.get_decoder( + 'pymatching', + dem.detector_error_matrix, + O=dem.observables_flips_matrix, + error_rate_vec=np.array(dem.error_rates), + ) + + dr = decoder.decode_batch(syndromes) + # Because O was passed to the decoder, each e.result holds observable flips + # (length num_observables), not error predictions. + obs_per_shot = np.array([e.result for e in dr], dtype=np.float64) + data_predictions = np.round(obs_per_shot).astype(np.uint8).T + + nLogicalErrorsWithoutDecoding = np.sum(logical_measurements) + nLogicalErrorsWithDecoding = np.sum(data_predictions ^ logical_measurements) + assert nLogicalErrorsWithDecoding < nLogicalErrorsWithoutDecoding + + def test_pcm_extend_to_n_rounds(): # This test independently compares the functionality of dem_from_memory_circuit # (of two different numbers of rounds) to pcm_extend_to_n_rounds. 
From 8be40db5fb3a0366caab28c4dc1ed567a7d81d6d Mon Sep 17 00:00:00 2001 From: Ben Howe Date: Mon, 23 Feb 2026 15:07:29 -0800 Subject: [PATCH 2/4] Use cudaqx::tensor Signed-off-by: Ben Howe --- libs/core/include/cuda-qx/core/kwargs_utils.h | 62 ++++++++++++++----- .../plugins/pymatching/pymatching.cpp | 23 ++++--- 2 files changed, 61 insertions(+), 24 deletions(-) diff --git a/libs/core/include/cuda-qx/core/kwargs_utils.h b/libs/core/include/cuda-qx/core/kwargs_utils.h index ce50f3cc..f0b0048b 100644 --- a/libs/core/include/cuda-qx/core/kwargs_utils.h +++ b/libs/core/include/cuda-qx/core/kwargs_utils.h @@ -75,22 +75,54 @@ inline heterogeneous_map hetMapFromKwargs(const py::kwargs &kwargs) { } else if (py::isinstance(value)) { py::array np_array = value.cast(); py::buffer_info info = np_array.request(); - auto insert_vector = [&](auto type_tag) { - using T = decltype(type_tag); - std::vector vec(static_cast(info.ptr), - static_cast(info.ptr) + info.size); - result.insert(key, std::move(vec)); - }; - if (info.format == py::format_descriptor::format()) { - insert_vector(double{}); - } else if (info.format == py::format_descriptor::format()) { - insert_vector(float{}); - } else if (info.format == py::format_descriptor::format()) { - insert_vector(int{}); - } else if (info.format == py::format_descriptor::format()) { - insert_vector(uint8_t{}); + if (info.ndim >= 2) { + if (info.strides[0] == static_cast(info.itemsize)) { + throw std::runtime_error( + "Array in kwargs must be in row-major order, but " + "column-major order was detected."); + } + std::vector shape(static_cast(info.ndim), + std::size_t(0)); + for (py::ssize_t d = 0; d < info.ndim; d++) + shape[d] = static_cast(info.shape[d]); + + auto insert_tensor = [&](auto type_tag) { + using T = decltype(type_tag); + cudaqx::tensor ten(shape); + ten.borrow(static_cast(info.ptr), shape); + result.insert(key, std::move(ten)); + }; + if (info.format == py::format_descriptor::format()) { + insert_tensor(double{}); + } else 
if (info.format == py::format_descriptor::format()) { + insert_tensor(float{}); + } else if (info.format == py::format_descriptor::format()) { + insert_tensor(int{}); + } else if (info.format == py::format_descriptor::format()) { + insert_tensor(uint8_t{}); + } else { + throw std::runtime_error("Unsupported array data type in kwargs."); + } } else { - throw std::runtime_error("Unsupported array data type in kwargs."); + // 1D array: keep as flattened vector for backward compatibility + // (e.g. error_rate_vec used by decoders). + auto insert_vector = [&](auto type_tag) { + using T = decltype(type_tag); + std::vector vec(static_cast(info.ptr), + static_cast(info.ptr) + info.size); + result.insert(key, std::move(vec)); + }; + if (info.format == py::format_descriptor::format()) { + insert_vector(double{}); + } else if (info.format == py::format_descriptor::format()) { + insert_vector(float{}); + } else if (info.format == py::format_descriptor::format()) { + insert_vector(int{}); + } else if (info.format == py::format_descriptor::format()) { + insert_vector(uint8_t{}); + } else { + throw std::runtime_error("Unsupported array data type in kwargs."); + } } } else { throw std::runtime_error( diff --git a/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp b/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp index 28fbde3b..e976a9ba 100644 --- a/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp +++ b/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp @@ -81,22 +81,27 @@ class pymatching : public decoder { std::vector> errs2observables(block_size); if (params.contains("O")) { - auto O = params.get>("O"); - int num_observables = O.size() / block_size; - uint8_t *O_ptr = O.data(); - if (O.size() % block_size != 0) { + auto O = params.get>("O"); + if (O.rank() != 2) { throw std::runtime_error( - "O must be of size num_observables * block_size"); + "O must be a 2-dimensional tensor (num_observables x block_size)"); + } + const size_t num_observables = 
O.shape()[0]; + if (O.shape()[1] != block_size) { + throw std::runtime_error( + "O must be of shape (num_observables, block_size); got second " + "dimension " + + std::to_string(O.shape()[1]) + ", block_size " + + std::to_string(block_size)); } - // Convert O to a sparse matrix and save it. std::vector> O_sparse; for (size_t i = 0; i < num_observables; i++) { O_sparse.emplace_back(); - const auto *row = &O_ptr[i * block_size]; + auto *row = &O.at({i, 0}); for (size_t j = 0; j < block_size; j++) { if (row[j] > 0) { - O_sparse.back().push_back(j); - errs2observables[j].push_back(i); + O_sparse.back().push_back(static_cast(j)); + errs2observables[j].push_back(static_cast(i)); } } } From 4ea6a7097fe77bfb3e5b93d6ea2c1e63cffc972f Mon Sep 17 00:00:00 2001 From: Ben Howe Date: Thu, 9 Apr 2026 22:34:57 -0700 Subject: [PATCH 3/4] Optimize decoding to observables and add timing - If < 64 observables, then use decode_detection_events_for_up_to_64_observables - Move where `result` is initialized to avoid unnecessary work when decoding to observables. - Introduced a debug option to measure decode times using the PERFORM_TIMING macro. - Added timing measurements at various stages of the decoding process. - Updated the destructor to print decode times for analysis. This enhancement allows for better performance analysis of the decoding process in the PyMatching library. Signed-off-by: Ben Howe --- .../plugins/pymatching/pymatching.cpp | 86 ++++++++++++++++--- 1 file changed, 72 insertions(+), 14 deletions(-) diff --git a/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp b/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp index e976a9ba..f11e45c5 100644 --- a/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp +++ b/libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp @@ -15,6 +15,9 @@ #include #include +// Enable this to debug decode times. 
+#define PERFORM_TIMING 0 + namespace cudaq::qec { /// @brief This is a wrapper around the PyMatching library that implements the @@ -22,6 +25,7 @@ namespace cudaq::qec { class pymatching : public decoder { private: pm::UserGraph user_graph; + pm::Mwpm *mwpm = nullptr; // Input parameters std::vector error_rate_vec; @@ -39,6 +43,11 @@ class pymatching : public decoder { return std::make_pair(std::min(node1, node2), std::max(node1, node2)); } +#if PERFORM_TIMING + static constexpr size_t NUM_TIMING_STEPS = 4; + std::array decode_times; +#endif + public: pymatching(const cudaqx::tensor &H, const cudaqx::heterogeneous_map &params) @@ -136,6 +145,12 @@ class pymatching : public decoder { } col_idx++; } + this->mwpm = decode_to_observables + ? &user_graph.get_mwpm() + : &user_graph.get_mwpm_with_search_graph(); +#if PERFORM_TIMING + std::fill(decode_times.begin(), decode_times.end(), 0.0); +#endif } /// @brief Decode the syndrome using the MWPM decoder. @@ -144,27 +159,48 @@ class pymatching : public decoder { /// @throws std::runtime_error if no matching solution is found, or /// std::out_of_range if an edge is not found in the edge2col_idx map. 
virtual decoder_result decode(const std::vector &syndrome) { - decoder_result result{false, std::vector(block_size, 0.0)}; - auto &mwpm = user_graph.get_mwpm_with_search_graph(); - std::vector edges; +#if PERFORM_TIMING + auto t0 = std::chrono::high_resolution_clock::now(); +#endif + decoder_result result{false, std::vector()}; +#if PERFORM_TIMING + auto t1 = std::chrono::high_resolution_clock::now(); +#endif + std::vector detection_events; detection_events.reserve(syndrome.size()); for (size_t i = 0; i < syndrome.size(); i++) if (syndrome[i] > 0.5) detection_events.push_back(i); +#if PERFORM_TIMING + auto t2 = std::chrono::high_resolution_clock::now(); +#endif if (decode_to_observables) { - assert(O_sparse.size() == mwpm.flooder.graph.num_observables); - pm::total_weight_int weight = 0; - std::vector obs(mwpm.flooder.graph.num_observables, 0); - obs.resize(mwpm.flooder.graph.num_observables); - pm::decode_detection_events(mwpm, detection_events, obs.data(), weight, - /*edge_correlations=*/false); - result.result.resize(mwpm.flooder.graph.num_observables); - for (size_t i = 0; i < mwpm.flooder.graph.num_observables; i++) { - result.result[i] = static_cast(obs[i]); + if (mwpm->flooder.graph.num_observables < 64) { + result.result.resize(mwpm->flooder.graph.num_observables); + auto res = pm::decode_detection_events_for_up_to_64_observables( + *mwpm, detection_events, /*edge_correlations=*/false); + for (size_t i = 0; i < mwpm->flooder.graph.num_observables; i++) { + result.result[i] = + static_cast(res.obs_mask & (1 << i) ? 
1.0 : 0.0); + } + } else { + result.result.resize(mwpm->flooder.graph.num_observables); + assert(O_sparse.size() == mwpm.flooder.graph.num_observables); + pm::total_weight_int weight = 0; + std::vector obs(mwpm->flooder.graph.num_observables, 0); + obs.resize(mwpm->flooder.graph.num_observables); + pm::decode_detection_events(*mwpm, detection_events, obs.data(), weight, + /*edge_correlations=*/false); + result.result.resize(mwpm->flooder.graph.num_observables); + for (size_t i = 0; i < mwpm->flooder.graph.num_observables; i++) { + result.result[i] = static_cast(obs[i]); + } } } else { - pm::decode_detection_events_to_edges(mwpm, detection_events, edges); + std::vector edges; + result.result.resize(block_size); + pm::decode_detection_events_to_edges(*mwpm, detection_events, edges); // Loop over the edge pairs to reconstruct errors. assert(edges.size() % 2 == 0); for (size_t i = 0; i < edges.size(); i += 2) { @@ -176,10 +212,32 @@ class pymatching : public decoder { // An exception is thrown if no matching solution is found, so we can just // set converged to true. 
result.converged = true; +#if PERFORM_TIMING + auto t3 = std::chrono::high_resolution_clock::now(); + decode_times[0] += + std::chrono::duration_cast(t1 - t0).count() / + 1e6; + decode_times[1] += + std::chrono::duration_cast(t2 - t1).count() / + 1e6; + decode_times[2] += + std::chrono::duration_cast(t3 - t2).count() / + 1e6; + decode_times[3] += + std::chrono::duration_cast(t3 - t0).count() / + 1e6; +#endif return result; } - virtual ~pymatching() {} + virtual ~pymatching() { +#if PERFORM_TIMING + for (int i = 0; i < NUM_TIMING_STEPS; i++) { + std::cout << "Decode time[" << i << "]: " << decode_times[i] << " seconds" + << std::endl; + } +#endif + } CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION( pymatching, static std::unique_ptr create( From 2c821c178db122ab1eef5485c482bb4559f5621e Mon Sep 17 00:00:00 2001 From: Ben Howe Date: Thu, 9 Apr 2026 22:51:15 -0700 Subject: [PATCH 4/4] Update test_realtime_predecoder_w_pymatching.cpp to decode to observables Signed-off-by: Ben Howe --- .../test_realtime_predecoder_w_pymatching.cpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/libs/qec/unittests/realtime/test_realtime_predecoder_w_pymatching.cpp b/libs/qec/unittests/realtime/test_realtime_predecoder_w_pymatching.cpp index 3c41e33a..94500588 100644 --- a/libs/qec/unittests/realtime/test_realtime_predecoder_w_pymatching.cpp +++ b/libs/qec/unittests/realtime/test_realtime_predecoder_w_pymatching.cpp @@ -258,8 +258,10 @@ int main(int argc, char *argv[]) { if (!stim.priors.empty() && stim.priors.size() == stim.H.ncols) pm_params.insert("error_rate_vec", stim.priors); - if (stim.O.loaded()) + if (stim.O.loaded()) { obs_row = stim.O.row_dense(0); + pm_params.insert("O", stim.O.to_dense()); + } std::cout << "[Setup] Creating " << config.num_decode_workers << " PyMatching decoders (full H)...\n"; @@ -473,12 +475,9 @@ int main(int argc, char *argv[]) { decoder_ctx.num_residual_detectors); auto result = my_decoder->decode(syndrome_tensor); all_converged = 
result.converged; - if (!obs_row.empty() && obs_row.size() == result.result.size()) { - int obs_parity = 0; - for (size_t e = 0; e < result.result.size(); ++e) - if (result.result[e] > 0.5 && obs_row[e]) - obs_parity ^= 1; - total_corrections += obs_parity; + if (!obs_row.empty() && !result.result.empty()) { + if (result.result[0] > 0.5) + total_corrections++; } else { for (auto v : result.result) if (v > 0.5)