Skip to content

Commit e258fd3

Browse files
ChanceSiyuanclaude
andcommitted
Add syndrome database generation (Issue #5)
- Add syndrome.py module for sampling and saving syndromes - Integrate syndrome generation into CLI with --generate-syndromes flag - Add comprehensive test suite for syndrome operations - Add make generate-syndromes target for easy database creation - Support npz format with metadata for efficient storage Features: - Sample detection events from circuits - Save/load syndrome databases with metadata - Generate databases directly from circuit files - CLI integration for automated workflow Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 86dad3b commit e258fd3

File tree

4 files changed

+337
-8
lines changed

4 files changed

+337
-8
lines changed

Makefile

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
1-
.PHONY: help install setup test test-cov generate-dataset clean
1+
.PHONY: help install setup test test-cov generate-dataset generate-syndromes clean
22

33
help:
44
@echo "Available targets:"
5-
@echo " install - Install uv package manager"
6-
@echo " setup - Set up development environment with uv"
7-
@echo " generate-dataset - Generate noisy circuit dataset"
8-
@echo " test - Run tests"
9-
@echo " test-cov - Run tests with coverage report"
10-
@echo " clean - Remove generated files and caches"
5+
@echo " install - Install uv package manager"
6+
@echo " setup - Set up development environment with uv"
7+
@echo " generate-dataset - Generate noisy circuit dataset"
8+
@echo " generate-syndromes - Generate syndrome database (1000 shots)"
9+
@echo " test - Run tests"
10+
@echo " test-cov - Run tests with coverage report"
11+
@echo " clean - Remove generated files and caches"
1112

1213
install:
1314
@command -v uv >/dev/null 2>&1 || { \
@@ -21,6 +22,9 @@ setup: install
2122
generate-dataset:
2223
uv run generate-noisy-circuits --distance 3 --p 0.01 --rounds 3 5 7 --task z --output datasets/noisy_circuits
2324

25+
generate-syndromes:
26+
uv run generate-noisy-circuits --distance 3 --p 0.01 --rounds 3 5 7 --task z --output datasets/noisy_circuits --generate-syndromes 1000
27+
2428
test:
2529
uv run pytest
2630

src/bpdecoderplus/cli.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""
2-
Command-line interface for generating noisy surface-code circuits.
2+
Command-line interface for generating noisy surface-code circuits and syndrome databases.
33
"""
44

55
from __future__ import annotations
@@ -15,6 +15,7 @@
1515
run_smoke_test,
1616
write_circuit,
1717
)
18+
from bpdecoderplus.syndrome import generate_syndrome_database_from_circuit
1819

1920

2021
def create_parser() -> argparse.ArgumentParser:
@@ -63,6 +64,12 @@ def create_parser() -> argparse.ArgumentParser:
6364
action="store_true",
6465
help="Skip compiling and sampling for quick validation",
6566
)
67+
parser.add_argument(
68+
"--generate-syndromes",
69+
type=int,
70+
metavar="NUM_SHOTS",
71+
help="Generate syndrome database with specified number of shots",
72+
)
6673
return parser
6774

6875

@@ -110,6 +117,13 @@ def main(argv: list[str] | None = None) -> int:
110117
write_circuit(circuit, output_path)
111118
print(f"Wrote {output_path}")
112119

120+
# Generate syndrome database if requested
121+
if args.generate_syndromes:
122+
syndrome_path = generate_syndrome_database_from_circuit(
123+
output_path, args.generate_syndromes
124+
)
125+
print(f"Wrote {syndrome_path}")
126+
113127
print("Done.")
114128
return 0
115129

src/bpdecoderplus/syndrome.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""
2+
Syndrome database generation module for noisy surface-code circuits.
3+
4+
This module provides functions to sample detection events (syndromes) from
5+
circuits and save them in a structured format for decoder training/testing.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import json
11+
import pathlib
12+
from typing import Any
13+
14+
import numpy as np
15+
import stim
16+
17+
18+
def sample_syndromes(
    circuit: stim.Circuit,
    num_shots: int,
    include_observables: bool = True,
) -> tuple[np.ndarray, np.ndarray | None]:
    """
    Sample detection events (syndromes) from a circuit.

    Args:
        circuit: Stim circuit to sample from.
        num_shots: Number of syndrome samples to generate.
        include_observables: Whether to include observable flip outcomes.

    Returns:
        Tuple of (syndromes, observables) where:
        - syndromes: Array of shape (num_shots, num_detectors)
        - observables: None if include_observables is False. Otherwise an
          array of shape (num_shots,) when the circuit declares exactly one
          observable, or (num_shots, num_observables) for any other count.
    """
    sampler = circuit.compile_detector_sampler()
    samples = sampler.sample(num_shots, append_observables=include_observables)

    if not include_observables:
        return samples, None

    # Observable columns are appended after the detector columns, so split on
    # the detector count instead of assuming exactly one trailing column.
    # That assumption leaks observables into the syndromes for
    # multi-observable circuits and steals the last detector when the circuit
    # has no observables at all.
    num_detectors = circuit.num_detectors
    syndromes = samples[:, :num_detectors]
    observables = samples[:, num_detectors:]

    # Keep the historical 1-D shape for the common single-observable case.
    if observables.shape[1] == 1:
        observables = observables[:, 0]
    return syndromes, observables
46+
47+
def save_syndrome_database(
48+
syndromes: np.ndarray,
49+
observables: np.ndarray | None,
50+
output_path: pathlib.Path,
51+
metadata: dict[str, Any] | None = None,
52+
) -> None:
53+
"""
54+
Save syndrome database to disk in npz format.
55+
56+
Args:
57+
syndromes: Array of detection events, shape (num_shots, num_detectors).
58+
observables: Array of observable flips, shape (num_shots,), or None.
59+
output_path: Path to save the database (.npz file).
60+
metadata: Optional metadata dictionary to save alongside the data.
61+
"""
62+
save_dict = {"syndromes": syndromes}
63+
64+
if observables is not None:
65+
save_dict["observables"] = observables
66+
67+
if metadata is not None:
68+
# Save metadata as JSON string in the npz file
69+
save_dict["metadata"] = np.array([json.dumps(metadata)])
70+
71+
np.savez_compressed(output_path, **save_dict)
72+
73+
74+
def load_syndrome_database(
    input_path: pathlib.Path,
) -> tuple[np.ndarray, np.ndarray | None, dict[str, Any] | None]:
    """
    Load syndrome database from disk.

    Args:
        input_path: Path to the database file (.npz).

    Returns:
        Tuple of (syndromes, observables, metadata) where:
        - syndromes: Array stored under the "syndromes" key
        - observables: Array stored under the "observables" key, or None if
          the archive has no such entry
        - metadata: Dictionary decoded from the JSON string stored under the
          "metadata" key, or None if the archive has no such entry

    Raises:
        KeyError: If the archive has no "syndromes" entry.
        ValueError: If a requested entry is a pickled object array.
    """
    # Databases written by save_syndrome_database hold only plain arrays and
    # a JSON string, so pickle support is never needed. Keeping it disabled
    # prevents a crafted .npz file from executing code on load.
    # The context manager closes the underlying zip file once the arrays
    # have been read into memory.
    with np.load(input_path, allow_pickle=False) as data:
        syndromes = data["syndromes"]
        observables = data["observables"] if "observables" in data else None

        metadata = None
        if "metadata" in data:
            metadata = json.loads(str(data["metadata"][0]))

    return syndromes, observables, metadata
99+
100+
101+
def generate_syndrome_database_from_circuit(
    circuit_path: pathlib.Path,
    num_shots: int,
    output_path: pathlib.Path | None = None,
) -> pathlib.Path:
    """
    Generate and save syndrome database from a circuit file.

    Args:
        circuit_path: Path to the circuit file (.stim).
        num_shots: Number of syndrome samples to generate.
        output_path: Optional output path. If None, uses circuit_path with
            .npz extension. If given without a ".npz" ending, ".npz" is
            appended so the returned path names the file actually written.

    Returns:
        Path to the saved database file.
    """
    # Accept plain strings as well as Path objects.
    circuit_path = pathlib.Path(circuit_path)

    # Load circuit
    circuit = stim.Circuit.from_file(str(circuit_path))

    # Resolve the output path
    if output_path is None:
        output_path = circuit_path.with_suffix(".npz")
    else:
        output_path = pathlib.Path(output_path)
        # np.savez_compressed appends ".npz" to names that lack it; mirror
        # that here so the returned path points at the real file.
        if not output_path.name.endswith(".npz"):
            output_path = output_path.with_name(output_path.name + ".npz")

    # Sample syndromes
    syndromes, observables = sample_syndromes(circuit, num_shots, include_observables=True)

    # Create metadata. The counts are read straight from the circuit, which
    # avoids building a detector error model just to query two integers.
    metadata = {
        "circuit_file": str(circuit_path.name),
        "num_shots": num_shots,
        "num_detectors": circuit.num_detectors,
        "num_observables": circuit.num_observables,
    }

    # Save database
    save_syndrome_database(syndromes, observables, output_path, metadata)

    return output_path

tests/test_syndrome.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
"""Tests for syndrome database generation module."""
2+
3+
from __future__ import annotations
4+
5+
import pathlib
6+
import tempfile
7+
8+
import numpy as np
9+
import pytest
10+
import stim
11+
12+
from bpdecoderplus.circuit import generate_circuit
13+
from bpdecoderplus.syndrome import (
14+
generate_syndrome_database_from_circuit,
15+
load_syndrome_database,
16+
sample_syndromes,
17+
save_syndrome_database,
18+
)
19+
20+
21+
class TestSampleSyndromes:
    """Exercise sample_syndromes on a generated distance-3, 3-round circuit."""

    def test_basic_sampling(self):
        """Default call yields uint8 arrays with one row/entry per shot."""
        shots = 10
        circ = generate_circuit(distance=3, rounds=3, p=0.01, task="z")

        dets, obs = sample_syndromes(circ, num_shots=shots)

        assert dets.shape[0] == shots
        assert obs.shape == (shots,)
        for sampled in (dets, obs):
            assert sampled.dtype == np.uint8

    def test_without_observables(self):
        """Disabling observables returns None in the second slot."""
        shots = 10
        circ = generate_circuit(distance=3, rounds=3, p=0.01, task="z")

        dets, obs = sample_syndromes(
            circ, num_shots=shots, include_observables=False
        )

        assert dets.shape[0] == shots
        assert obs is None

    def test_num_detectors(self):
        """Syndrome width equals the detector count of the circuit's DEM."""
        circ = generate_circuit(distance=3, rounds=3, p=0.01, task="z")
        expected_width = circ.detector_error_model().num_detectors

        dets, _ = sample_syndromes(circ, num_shots=5)

        assert dets.shape[1] == expected_width
51+
52+
53+
class TestSaveSyndromeDatabase:
    """Check that save_syndrome_database produces a file on disk."""

    def test_save_with_observables(self):
        """Saving syndromes together with observables creates the file."""
        dets = np.random.randint(0, 2, size=(10, 24), dtype=np.uint8)
        obs = np.random.randint(0, 2, size=10, dtype=np.uint8)

        with tempfile.TemporaryDirectory() as workdir:
            target = pathlib.Path(workdir) / "test.npz"
            save_syndrome_database(dets, obs, target)

            assert target.exists()

    def test_save_without_observables(self):
        """Saving with observables=None still creates the file."""
        dets = np.random.randint(0, 2, size=(10, 24), dtype=np.uint8)

        with tempfile.TemporaryDirectory() as workdir:
            target = pathlib.Path(workdir) / "test.npz"
            save_syndrome_database(dets, None, target)

            assert target.exists()

    def test_save_with_metadata(self):
        """Saving with a metadata dictionary creates the file."""
        dets = np.random.randint(0, 2, size=(10, 24), dtype=np.uint8)
        obs = np.random.randint(0, 2, size=10, dtype=np.uint8)
        info = {"distance": 3, "rounds": 3, "p": 0.01}

        with tempfile.TemporaryDirectory() as workdir:
            target = pathlib.Path(workdir) / "test.npz"
            save_syndrome_database(dets, obs, target, info)

            assert target.exists()
88+
89+
90+
class TestLoadSyndromeDatabase:
    """Round-trip checks: save a database, load it back, compare."""

    def test_load_with_observables(self):
        """Syndromes and observables survive a save/load round trip."""
        dets = np.random.randint(0, 2, size=(10, 24), dtype=np.uint8)
        obs = np.random.randint(0, 2, size=10, dtype=np.uint8)

        with tempfile.TemporaryDirectory() as workdir:
            target = pathlib.Path(workdir) / "test.npz"
            save_syndrome_database(dets, obs, target)

            restored = load_syndrome_database(target)

            np.testing.assert_array_equal(restored[0], dets)
            np.testing.assert_array_equal(restored[1], obs)

    def test_load_without_observables(self):
        """A database saved without observables loads them as None."""
        dets = np.random.randint(0, 2, size=(10, 24), dtype=np.uint8)

        with tempfile.TemporaryDirectory() as workdir:
            target = pathlib.Path(workdir) / "test.npz"
            save_syndrome_database(dets, None, target)

            restored = load_syndrome_database(target)

            np.testing.assert_array_equal(restored[0], dets)
            assert restored[1] is None

    def test_load_with_metadata(self):
        """The metadata dictionary survives a save/load round trip."""
        dets = np.random.randint(0, 2, size=(10, 24), dtype=np.uint8)
        obs = np.random.randint(0, 2, size=10, dtype=np.uint8)
        info = {"distance": 3, "rounds": 3, "p": 0.01}

        with tempfile.TemporaryDirectory() as workdir:
            target = pathlib.Path(workdir) / "test.npz"
            save_syndrome_database(dets, obs, target, info)

            restored_info = load_syndrome_database(target)[2]

            assert restored_info == info
133+
134+
135+
class TestGenerateSyndromeDatabaseFromCircuit:
    """End-to-end checks for generate_syndrome_database_from_circuit."""

    def test_generate_from_circuit_file(self):
        """A circuit file on disk yields a loadable .npz database."""
        shots = 20
        circ = generate_circuit(distance=3, rounds=3, p=0.01, task="z")

        with tempfile.TemporaryDirectory() as workdir:
            stim_file = pathlib.Path(workdir) / "test.stim"
            stim_file.write_text(str(circ))

            written = generate_syndrome_database_from_circuit(
                stim_file, num_shots=shots
            )

            assert written.exists()
            assert written.suffix == ".npz"

            # Read the database back and check its contents.
            dets, obs, info = load_syndrome_database(written)
            assert dets.shape[0] == shots
            assert obs.shape == (shots,)
            assert info["num_shots"] == shots
            assert info["circuit_file"] == "test.stim"

    def test_custom_output_path(self):
        """An explicit output_path is used and returned unchanged."""
        circ = generate_circuit(distance=3, rounds=3, p=0.01, task="z")

        with tempfile.TemporaryDirectory() as workdir:
            base = pathlib.Path(workdir)
            stim_file = base / "test.stim"
            stim_file.write_text(str(circ))
            wanted = base / "custom_db.npz"

            written = generate_syndrome_database_from_circuit(
                stim_file, num_shots=15, output_path=wanted
            )

            assert written == wanted
            assert written.exists()

0 commit comments

Comments
 (0)