Skip to content

Commit 3230244

Browse files
ChanceSiyuanclaude
andcommitted
Add detector error model generation (Issue #4)
- Add dem.py module for DEM extraction and manipulation - Extract DEM from circuits with decomposition support - Save/load DEMs in stim native format - Convert DEM to JSON for analysis - Build parity check matrix H for BP decoding - Integrate DEM generation into CLI with --generate-dem flag - Add comprehensive test suite for DEM operations - Add make generate-dem target Features: - Extract detector error models from circuits - Save in .dem format (stim native) - Export to JSON with structured error information - Build parity check matrix for BP decoder - CLI integration for automated workflow Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent e258fd3 commit 3230244

File tree

4 files changed

+415
-1
lines changed

4 files changed

+415
-1
lines changed

Makefile

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
.PHONY: help install setup test test-cov generate-dataset generate-syndromes clean
1+
.PHONY: help install setup test test-cov generate-dataset generate-dem generate-syndromes clean
22

33
help:
44
@echo "Available targets:"
55
@echo " install - Install uv package manager"
66
@echo " setup - Set up development environment with uv"
77
@echo " generate-dataset - Generate noisy circuit dataset"
8+
@echo " generate-dem - Generate detector error models"
89
@echo " generate-syndromes - Generate syndrome database (1000 shots)"
910
@echo " test - Run tests"
1011
@echo " test-cov - Run tests with coverage report"
@@ -22,6 +23,9 @@ setup: install
2223
generate-dataset:
2324
uv run generate-noisy-circuits --distance 3 --p 0.01 --rounds 3 5 7 --task z --output datasets/noisy_circuits
2425

26+
generate-dem:
27+
uv run generate-noisy-circuits --distance 3 --p 0.01 --rounds 3 5 7 --task z --output datasets/noisy_circuits --generate-dem
28+
2529
generate-syndromes:
2630
uv run generate-noisy-circuits --distance 3 --p 0.01 --rounds 3 5 7 --task z --output datasets/noisy_circuits --generate-syndromes 1000
2731

src/bpdecoderplus/cli.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
run_smoke_test,
1616
write_circuit,
1717
)
18+
from bpdecoderplus.dem import generate_dem_from_circuit
1819
from bpdecoderplus.syndrome import generate_syndrome_database_from_circuit
1920

2021

@@ -70,6 +71,11 @@ def create_parser() -> argparse.ArgumentParser:
7071
metavar="NUM_SHOTS",
7172
help="Generate syndrome database with specified number of shots",
7273
)
74+
parser.add_argument(
75+
"--generate-dem",
76+
action="store_true",
77+
help="Generate detector error model (.dem file)",
78+
)
7379
return parser
7480

7581

@@ -117,6 +123,11 @@ def main(argv: list[str] | None = None) -> int:
117123
write_circuit(circuit, output_path)
118124
print(f"Wrote {output_path}")
119125

126+
# Generate DEM if requested
127+
if args.generate_dem:
128+
dem_path = generate_dem_from_circuit(output_path)
129+
print(f"Wrote {dem_path}")
130+
120131
# Generate syndrome database if requested
121132
if args.generate_syndromes:
122133
syndrome_path = generate_syndrome_database_from_circuit(

src/bpdecoderplus/dem.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
"""
2+
Detector Error Model (DEM) extraction module for noisy circuits.
3+
4+
This module provides functions to extract and save Detector Error Models
5+
from Stim circuits for use in decoder implementations.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import json
11+
import pathlib
12+
from typing import Any
13+
14+
import numpy as np
15+
import stim
16+
17+
18+
def extract_dem(
19+
circuit: stim.Circuit,
20+
decompose_errors: bool = True,
21+
) -> stim.DetectorErrorModel:
22+
"""
23+
Extract Detector Error Model from a circuit.
24+
25+
Args:
26+
circuit: Stim circuit to extract DEM from.
27+
decompose_errors: Whether to decompose errors into components.
28+
29+
Returns:
30+
Detector Error Model describing error mechanisms.
31+
"""
32+
return circuit.detector_error_model(decompose_errors=decompose_errors)
33+
34+
35+
def save_dem(
36+
dem: stim.DetectorErrorModel,
37+
output_path: pathlib.Path,
38+
) -> None:
39+
"""
40+
Save Detector Error Model to file in stim format.
41+
42+
Args:
43+
dem: Detector Error Model to save.
44+
output_path: Path to save the DEM (.dem file).
45+
"""
46+
output_path.write_text(str(dem))
47+
48+
49+
def load_dem(input_path: pathlib.Path) -> stim.DetectorErrorModel:
50+
"""
51+
Load Detector Error Model from file.
52+
53+
Args:
54+
input_path: Path to the DEM file (.dem).
55+
56+
Returns:
57+
Loaded Detector Error Model.
58+
"""
59+
return stim.DetectorErrorModel.from_file(str(input_path))
60+
61+
62+
def dem_to_dict(dem: stim.DetectorErrorModel) -> dict[str, Any]:
63+
"""
64+
Convert DEM to dictionary with structured information.
65+
66+
Args:
67+
dem: Detector Error Model to convert.
68+
69+
Returns:
70+
Dictionary with DEM statistics and error information.
71+
"""
72+
errors = []
73+
for inst in dem.flattened():
74+
if inst.type == "error":
75+
prob = inst.args_copy()[0]
76+
targets = inst.targets_copy()
77+
detectors = [t.val for t in targets if t.is_relative_detector_id()]
78+
observables = [t.val for t in targets if t.is_logical_observable_id()]
79+
80+
errors.append({
81+
"probability": float(prob),
82+
"detectors": detectors,
83+
"observables": observables,
84+
})
85+
86+
return {
87+
"num_detectors": dem.num_detectors,
88+
"num_observables": dem.num_observables,
89+
"num_errors": len(errors),
90+
"errors": errors,
91+
}
92+
93+
94+
def save_dem_json(
95+
dem: stim.DetectorErrorModel,
96+
output_path: pathlib.Path,
97+
) -> None:
98+
"""
99+
Save DEM as JSON for easier analysis.
100+
101+
Args:
102+
dem: Detector Error Model to save.
103+
output_path: Path to save the JSON file.
104+
"""
105+
dem_dict = dem_to_dict(dem)
106+
with open(output_path, "w") as f:
107+
json.dump(dem_dict, f, indent=2)
108+
109+
110+
def build_parity_check_matrix(
111+
dem: stim.DetectorErrorModel,
112+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
113+
"""
114+
Build parity check matrix H from DEM for BP decoding.
115+
116+
Args:
117+
dem: Detector Error Model.
118+
119+
Returns:
120+
Tuple of (H, priors, obs_flip) where:
121+
- H: Parity check matrix, shape (num_detectors, num_errors)
122+
- priors: Prior error probabilities, shape (num_errors,)
123+
- obs_flip: Observable flip indicators, shape (num_errors,)
124+
"""
125+
errors = []
126+
for inst in dem.flattened():
127+
if inst.type == "error":
128+
prob = inst.args_copy()[0]
129+
targets = inst.targets_copy()
130+
detectors = [t.val for t in targets if t.is_relative_detector_id()]
131+
observables = [t.val for t in targets if t.is_logical_observable_id()]
132+
errors.append({
133+
"prob": prob,
134+
"detectors": detectors,
135+
"observables": observables,
136+
})
137+
138+
n_detectors = dem.num_detectors
139+
n_errors = len(errors)
140+
141+
H = np.zeros((n_detectors, n_errors), dtype=np.uint8)
142+
priors = np.zeros(n_errors, dtype=np.float64)
143+
obs_flip = np.zeros(n_errors, dtype=np.uint8)
144+
145+
for j, e in enumerate(errors):
146+
priors[j] = e["prob"]
147+
for d in e["detectors"]:
148+
H[d, j] = 1
149+
if e["observables"]:
150+
obs_flip[j] = 1
151+
152+
return H, priors, obs_flip
153+
154+
155+
def generate_dem_from_circuit(
156+
circuit_path: pathlib.Path,
157+
output_path: pathlib.Path | None = None,
158+
decompose_errors: bool = True,
159+
) -> pathlib.Path:
160+
"""
161+
Generate and save DEM from a circuit file.
162+
163+
Args:
164+
circuit_path: Path to the circuit file (.stim).
165+
output_path: Optional output path. If None, uses circuit_path with .dem extension.
166+
decompose_errors: Whether to decompose errors into components.
167+
168+
Returns:
169+
Path to the saved DEM file.
170+
"""
171+
circuit = stim.Circuit.from_file(str(circuit_path))
172+
173+
if output_path is None:
174+
output_path = circuit_path.with_suffix(".dem")
175+
176+
dem = extract_dem(circuit, decompose_errors=decompose_errors)
177+
save_dem(dem, output_path)
178+
179+
return output_path

0 commit comments

Comments
 (0)