Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 47 additions & 15 deletions libs/core/include/cuda-qx/core/kwargs_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,22 +75,54 @@ inline heterogeneous_map hetMapFromKwargs(const py::kwargs &kwargs) {
} else if (py::isinstance<py::array>(value)) {
py::array np_array = value.cast<py::array>();
py::buffer_info info = np_array.request();
auto insert_vector = [&](auto type_tag) {
using T = decltype(type_tag);
std::vector<T> vec(static_cast<T *>(info.ptr),
static_cast<T *>(info.ptr) + info.size);
result.insert(key, std::move(vec));
};
if (info.format == py::format_descriptor<double>::format()) {
insert_vector(double{});
} else if (info.format == py::format_descriptor<float>::format()) {
insert_vector(float{});
} else if (info.format == py::format_descriptor<int>::format()) {
insert_vector(int{});
} else if (info.format == py::format_descriptor<uint8_t>::format()) {
insert_vector(uint8_t{});
if (info.ndim >= 2) {
if (info.strides[0] == static_cast<py::ssize_t>(info.itemsize)) {
throw std::runtime_error(
"Array in kwargs must be in row-major order, but "
"column-major order was detected.");
}
std::vector<std::size_t> shape(static_cast<std::size_t>(info.ndim),
std::size_t(0));
for (py::ssize_t d = 0; d < info.ndim; d++)
shape[d] = static_cast<std::size_t>(info.shape[d]);

auto insert_tensor = [&](auto type_tag) {
using T = decltype(type_tag);
cudaqx::tensor<T> ten(shape);
ten.borrow(static_cast<T *>(info.ptr), shape);
result.insert(key, std::move(ten));
};
if (info.format == py::format_descriptor<double>::format()) {
insert_tensor(double{});
} else if (info.format == py::format_descriptor<float>::format()) {
insert_tensor(float{});
} else if (info.format == py::format_descriptor<int>::format()) {
insert_tensor(int{});
} else if (info.format == py::format_descriptor<uint8_t>::format()) {
insert_tensor(uint8_t{});
} else {
throw std::runtime_error("Unsupported array data type in kwargs.");
}
} else {
throw std::runtime_error("Unsupported array data type in kwargs.");
// 1D array: keep as flattened vector for backward compatibility
// (e.g. error_rate_vec used by decoders).
auto insert_vector = [&](auto type_tag) {
using T = decltype(type_tag);
std::vector<T> vec(static_cast<T *>(info.ptr),
static_cast<T *>(info.ptr) + info.size);
result.insert(key, std::move(vec));
};
if (info.format == py::format_descriptor<double>::format()) {
insert_vector(double{});
} else if (info.format == py::format_descriptor<float>::format()) {
insert_vector(float{});
} else if (info.format == py::format_descriptor<int>::format()) {
insert_vector(int{});
} else if (info.format == py::format_descriptor<uint8_t>::format()) {
insert_vector(uint8_t{});
} else {
throw std::runtime_error("Unsupported array data type in kwargs.");
}
}
} else {
throw std::runtime_error(
Expand Down
134 changes: 119 additions & 15 deletions libs/qec/lib/decoders/plugins/pymatching/pymatching.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,17 @@
#include <map>
#include <vector>

// Enable this to debug decode times.
#define PERFORM_TIMING 0

namespace cudaq::qec {

/// @brief This is a wrapper around the PyMatching library that implements the
/// MWPM decoder.
class pymatching : public decoder {
private:
pm::UserGraph user_graph;
pm::Mwpm *mwpm = nullptr;

// Input parameters
std::vector<double> error_rate_vec;
Expand All @@ -31,12 +35,19 @@ class pymatching : public decoder {
// efficient.
std::map<std::pair<int64_t, int64_t>, size_t> edge2col_idx;

bool decode_to_observables = false;

// Helper function to make a canonical edge from two nodes.
std::pair<int64_t, int64_t> make_canonical_edge(int64_t node1,
int64_t node2) {
return std::make_pair(std::min(node1, node2), std::max(node1, node2));
}

#if PERFORM_TIMING
static constexpr size_t NUM_TIMING_STEPS = 4;
std::array<double, NUM_TIMING_STEPS> decode_times;
#endif

public:
pymatching(const cudaqx::tensor<uint8_t> &H,
const cudaqx::heterogeneous_map &params)
Expand Down Expand Up @@ -77,10 +88,39 @@ class pymatching : public decoder {
}
}

std::vector<std::vector<size_t>> errs2observables(block_size);
if (params.contains("O")) {
auto O = params.get<cudaqx::tensor<uint8_t>>("O");
if (O.rank() != 2) {
throw std::runtime_error(
"O must be a 2-dimensional tensor (num_observables x block_size)");
}
const size_t num_observables = O.shape()[0];
if (O.shape()[1] != block_size) {
throw std::runtime_error(
"O must be of shape (num_observables, block_size); got second "
"dimension " +
std::to_string(O.shape()[1]) + ", block_size " +
std::to_string(block_size));
}
std::vector<std::vector<uint32_t>> O_sparse;
for (size_t i = 0; i < num_observables; i++) {
O_sparse.emplace_back();
auto *row = &O.at({i, 0});
for (size_t j = 0; j < block_size; j++) {
if (row[j] > 0) {
O_sparse.back().push_back(static_cast<uint32_t>(j));
errs2observables[j].push_back(static_cast<uint32_t>(i));
}
}
}
this->set_O_sparse(O_sparse);
decode_to_observables = true;
}

user_graph = pm::UserGraph(H.shape()[0]);

auto sparse = cudaq::qec::dense_to_sparse(H);
std::vector<size_t> observables;
std::size_t col_idx = 0;
for (auto &col : sparse) {
double weight = 1.0;
Expand All @@ -90,19 +130,27 @@ class pymatching : public decoder {
}
if (col.size() == 2) {
edge2col_idx[make_canonical_edge(col[0], col[1])] = col_idx;
user_graph.add_or_merge_edge(col[0], col[1], observables, weight, 0.0,
user_graph.add_or_merge_edge(col[0], col[1],
errs2observables.at(col_idx), weight, 0.0,
merge_strategy_enum);
} else if (col.size() == 1) {
edge2col_idx[make_canonical_edge(col[0], -1)] = col_idx;
user_graph.add_or_merge_boundary_edge(col[0], observables, weight, 0.0,
merge_strategy_enum);
user_graph.add_or_merge_boundary_edge(col[0],
errs2observables.at(col_idx),
weight, 0.0, merge_strategy_enum);
} else {
throw std::runtime_error(
"Invalid column in H: " + std::to_string(col_idx) + " has " +
std::to_string(col.size()) + " ones. Must have 1 or 2 ones.");
}
col_idx++;
}
this->mwpm = decode_to_observables
? &user_graph.get_mwpm()
: &user_graph.get_mwpm_with_search_graph();
#if PERFORM_TIMING
std::fill(decode_times.begin(), decode_times.end(), 0.0);
#endif
}

/// @brief Decode the syndrome using the MWPM decoder.
Expand All @@ -111,29 +159,85 @@ class pymatching : public decoder {
/// @throws std::runtime_error if no matching solution is found, or
/// std::out_of_range if an edge is not found in the edge2col_idx map.
virtual decoder_result decode(const std::vector<float_t> &syndrome) {
decoder_result result{false, std::vector<float_t>(block_size, 0.0)};
auto &mwpm = user_graph.get_mwpm_with_search_graph();
std::vector<int64_t> edges;
#if PERFORM_TIMING
auto t0 = std::chrono::high_resolution_clock::now();
#endif
decoder_result result{false, std::vector<float_t>()};
#if PERFORM_TIMING
auto t1 = std::chrono::high_resolution_clock::now();
#endif

std::vector<uint64_t> detection_events;
detection_events.reserve(syndrome.size());
for (size_t i = 0; i < syndrome.size(); i++)
if (syndrome[i] > 0.5)
detection_events.push_back(i);
pm::decode_detection_events_to_edges(mwpm, detection_events, edges);
// Loop over the edge pairs
assert(edges.size() % 2 == 0);
for (size_t i = 0; i < edges.size(); i += 2) {
auto edge = make_canonical_edge(edges.at(i), edges.at(i + 1));
auto col_idx = edge2col_idx.at(edge);
result.result[col_idx] = 1.0;
#if PERFORM_TIMING
auto t2 = std::chrono::high_resolution_clock::now();
#endif
if (decode_to_observables) {
if (mwpm->flooder.graph.num_observables < 64) {
result.result.resize(mwpm->flooder.graph.num_observables);
auto res = pm::decode_detection_events_for_up_to_64_observables(
*mwpm, detection_events, /*edge_correlations=*/false);
for (size_t i = 0; i < mwpm->flooder.graph.num_observables; i++) {
result.result[i] =
static_cast<float_t>(res.obs_mask & (1 << i) ? 1.0 : 0.0);
}
} else {
result.result.resize(mwpm->flooder.graph.num_observables);
assert(O_sparse.size() == mwpm.flooder.graph.num_observables);
pm::total_weight_int weight = 0;
std::vector<uint8_t> obs(mwpm->flooder.graph.num_observables, 0);
obs.resize(mwpm->flooder.graph.num_observables);
pm::decode_detection_events(*mwpm, detection_events, obs.data(), weight,
/*edge_correlations=*/false);
result.result.resize(mwpm->flooder.graph.num_observables);
for (size_t i = 0; i < mwpm->flooder.graph.num_observables; i++) {
result.result[i] = static_cast<float_t>(obs[i]);
}
}
} else {
std::vector<int64_t> edges;
result.result.resize(block_size);
pm::decode_detection_events_to_edges(*mwpm, detection_events, edges);
// Loop over the edge pairs to reconstruct errors.
assert(edges.size() % 2 == 0);
for (size_t i = 0; i < edges.size(); i += 2) {
auto edge = make_canonical_edge(edges.at(i), edges.at(i + 1));
auto col_idx = edge2col_idx.at(edge);
result.result[col_idx] = 1.0;
}
}
// An exception is thrown if no matching solution is found, so we can just
// set converged to true.
result.converged = true;
#if PERFORM_TIMING
auto t3 = std::chrono::high_resolution_clock::now();
decode_times[0] +=
std::chrono::duration_cast<std::chrono::microseconds>(t1 - t0).count() /
1e6;
decode_times[1] +=
std::chrono::duration_cast<std::chrono::microseconds>(t2 - t1).count() /
1e6;
decode_times[2] +=
std::chrono::duration_cast<std::chrono::microseconds>(t3 - t2).count() /
1e6;
decode_times[3] +=
std::chrono::duration_cast<std::chrono::microseconds>(t3 - t0).count() /
1e6;
#endif
return result;
}

virtual ~pymatching() {}
virtual ~pymatching() {
#if PERFORM_TIMING
for (int i = 0; i < NUM_TIMING_STEPS; i++) {
std::cout << "Decode time[" << i << "]: " << decode_times[i] << " seconds"
<< std::endl;
}
#endif
}

CUDAQ_EXTENSION_CUSTOM_CREATOR_FUNCTION(
pymatching, static std::unique_ptr<decoder> create(
Expand Down
45 changes: 44 additions & 1 deletion libs/qec/python/tests/test_dem.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ============================================================================ #
# Copyright (c) 2024 - 2025 NVIDIA Corporation & Affiliates. #
# Copyright (c) 2024 - 2026 NVIDIA Corporation & Affiliates. #
# All rights reserved. #
# #
# This source code and the accompanying materials are made available under #
Expand Down Expand Up @@ -275,6 +275,49 @@ def test_decoding_from_surface_code_dem_from_memory_circuit(
assert nLogicalErrorsWithDecoding < nLogicalErrorsWithoutDecoding


def test_pymatching_decode_to_observable_surface_code_dem():
"""Test PyMatching with O (observables) matrix: decoder returns observable
flips directly.cpp)."""
cudaq.set_random_seed(13)
code = qec.get_code('surface_code', distance=5)
Lz = code.get_observables_z()
p = 0.003
noise = cudaq.NoiseModel()
noise.add_all_qubit_channel("x", cudaq.Depolarization2(p), 1)
statePrep = qec.operation.prep0
nRounds = 5
nShots = 2000

syndromes, data = qec.sample_memory_circuit(code, statePrep, nShots,
nRounds, noise)

logical_measurements = (Lz @ data.transpose()) % 2
logical_measurements = logical_measurements.flatten()

syndromes = syndromes.reshape((nShots, nRounds, -1))
syndromes = syndromes[:, :, :syndromes.shape[2] // 2]
syndromes = syndromes.reshape((nShots, -1))

dem = qec.z_dem_from_memory_circuit(code, statePrep, nRounds, noise)

decoder = qec.get_decoder(
'pymatching',
dem.detector_error_matrix,
O=dem.observables_flips_matrix,
error_rate_vec=np.array(dem.error_rates),
)

dr = decoder.decode_batch(syndromes)
# With decode_to_observables=True, each e.result is observable flips
# (length num_observables), not error predictions.
obs_per_shot = np.array([e.result for e in dr], dtype=np.float64)
data_predictions = np.round(obs_per_shot).astype(np.uint8).T

nLogicalErrorsWithoutDecoding = np.sum(logical_measurements)
nLogicalErrorsWithDecoding = np.sum(data_predictions ^ logical_measurements)
assert nLogicalErrorsWithDecoding < nLogicalErrorsWithoutDecoding


def test_pcm_extend_to_n_rounds():
# This test independently compares the functionality of dem_from_memory_circuit
# (of two different numbers of rounds) to pcm_extend_to_n_rounds.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,10 @@ int main(int argc, char *argv[]) {
if (!stim.priors.empty() && stim.priors.size() == stim.H.ncols)
pm_params.insert("error_rate_vec", stim.priors);

if (stim.O.loaded())
if (stim.O.loaded()) {
obs_row = stim.O.row_dense(0);
pm_params.insert("O", stim.O.to_dense());
}

std::cout << "[Setup] Creating " << config.num_decode_workers
<< " PyMatching decoders (full H)...\n";
Expand Down Expand Up @@ -473,12 +475,9 @@ int main(int argc, char *argv[]) {
decoder_ctx.num_residual_detectors);
auto result = my_decoder->decode(syndrome_tensor);
all_converged = result.converged;
if (!obs_row.empty() && obs_row.size() == result.result.size()) {
int obs_parity = 0;
for (size_t e = 0; e < result.result.size(); ++e)
if (result.result[e] > 0.5 && obs_row[e])
obs_parity ^= 1;
total_corrections += obs_parity;
if (!obs_row.empty() && !result.result.empty()) {
if (result.result[0] > 0.5)
total_corrections++;
} else {
for (auto v : result.result)
if (v > 0.5)
Expand Down
Loading