Skip to content

Commit 42c7ff9

Browse files
authored
Improved exception handling (#91)
Slight normalization issues no longer raise an error but a warning. These can occur for circuits with arbitrary rotation gates, where fp32 is used instead of an exact representation. Parsing errors for invalid Stim circuits now raise useful exceptions.
1 parent 88e152f commit 42c7ff9

7 files changed

Lines changed: 123 additions & 50 deletions

File tree

src/tsim/circuit.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
from tsim.noise.dem import get_detector_error_model
1717
from tsim.utils.clifford import parametric_to_clifford_gates
1818
from tsim.utils.diagram import render_svg
19-
from tsim.utils.program_text import shorthand_to_stim, stim_to_shorthand
19+
from tsim.utils.program_text import (
20+
enriched_stim_error,
21+
shorthand_to_stim,
22+
stim_to_shorthand,
23+
)
2024

2125

2226
class Circuit:
@@ -42,7 +46,11 @@ def __init__(self, stim_program_text: str = ""):
4246
empty circuit.
4347
4448
"""
45-
self._stim_circ = stim.Circuit(shorthand_to_stim(stim_program_text))
49+
converted = shorthand_to_stim(stim_program_text)
50+
try:
51+
self._stim_circ = stim.Circuit(converted)
52+
except ValueError as exc:
53+
raise enriched_stim_error(exc, converted) from None
4654

4755
@classmethod
4856
def from_stim_program(cls, stim_circuit: stim.Circuit) -> Circuit:
@@ -64,9 +72,11 @@ def append_from_stim_program_text(self, stim_program_text: str) -> None:
6472
6573
Supports the same shorthand syntax as the constructor.
6674
"""
67-
self._stim_circ.append_from_stim_program_text(
68-
shorthand_to_stim(stim_program_text)
69-
)
75+
converted = shorthand_to_stim(stim_program_text)
76+
try:
77+
self._stim_circ.append_from_stim_program_text(converted)
78+
except ValueError as exc:
79+
raise enriched_stim_error(exc, converted) from None
7080

7181
@overload
7282
def append(
@@ -141,18 +151,21 @@ def append(
141151
name = "S_DAG"
142152
tag = "T"
143153
elif name in ("R_X", "R_Y", "R_Z"):
144-
assert arg is not None, f"For {name} gates, an angle must be provided."
154+
if arg is None:
155+
raise ValueError(f"For {name} gates, an angle must be provided.")
145156
args = list(arg) if isinstance(arg, Iterable) else [arg]
146-
assert (
147-
len(args) == 1
148-
), f"For {name} gates, a single angle must be provided."
157+
if len(args) != 1:
158+
raise ValueError(
159+
f"For {name} gates, a single angle must be provided."
160+
)
149161
tag = f"{name}(theta={args[0]}*pi)"
150162
name = "I"
151163
arg = None
152164
elif name == "U3":
153-
assert arg is not None and (
154-
isinstance(arg, Iterable) and len(list(arg)) == 3
155-
), f"For U3 gates, three rotation angles must be provided."
165+
if arg is None or not isinstance(arg, Iterable) or len(list(arg)) != 3:
166+
raise ValueError(
167+
"For U3 gates, three rotation angles must be provided."
168+
)
156169
theta, phi, lam = list(arg)
157170
tag = f"U3(theta={theta}*pi, phi={phi}*pi, lambda={lam}*pi)"
158171
name = "I"
@@ -175,7 +188,11 @@ def from_file(cls, filename: str) -> Circuit:
175188
"""
176189
with open(filename, "r", encoding="utf-8") as f:
177190
stim_program_text = f.read()
178-
stim_circ = stim.Circuit(shorthand_to_stim(stim_program_text))
191+
converted = shorthand_to_stim(stim_program_text)
192+
try:
193+
stim_circ = stim.Circuit(converted)
194+
except ValueError as exc:
195+
raise enriched_stim_error(exc, converted) from None
179196
return cls.from_stim_program(stim_circ)
180197

181198
def __repr__(self) -> str:

src/tsim/core/parse.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,10 @@ def parse_parametric_tag(tag: str) -> tuple[str, dict[str, Fraction]] | None:
4747
param = param.strip()
4848
if not param:
4949
continue
50-
# Match param=value*pi (value can be negative/decimal)
51-
param_match = re.match(r"^(\w+)=([-+]?[\d.]+)\*pi$", param)
50+
# Match param=value*pi (value can be negative/decimal/scientific)
51+
param_match = re.match(
52+
r"^(\w+)=([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)\*pi$", param
53+
)
5254
if not param_match:
5355
return None
5456
param_name = param_match.group(1)

src/tsim/sampler.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import warnings
56
from math import ceil
67
from typing import TYPE_CHECKING, Literal, overload
78

@@ -133,12 +134,17 @@ def sample_program(
133134

134135
for component in program.components:
135136
samples, key, max_norm_deviation = sample_component(component, f_params, key)
137+
if np.isclose(max_norm_deviation, 1):
138+
raise ValueError(
139+
"A vanishing marginal probability distributionwas encountered (normalization 0). "
140+
"This is likely the result of an underflow error. Please report this "
141+
"as a bug at https://github.com/QuEraComputing/tsim/issues/new."
142+
) # pragma: no cover
136143
if max_norm_deviation > 1e-5:
137-
raise AssertionError(
144+
warnings.warn(
138145
"A marginal probability was not normalized correctly "
139146
f"(normalization deviated from 1 by {max_norm_deviation:.1e}). "
140-
"This is likely the result of an underflow error. Please report this "
141-
"as a bug at https://github.com/QuEraComputing/tsim/issues/new."
147+
"This is likely a floating point precision issue."
142148
)
143149
results.append(samples)
144150

src/tsim/utils/encoder.py

Lines changed: 6 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@ def _transform_circuit(
2121
*,
2222
stride: int,
2323
offsets: list[int],
24-
gate_expansions: dict[str, list[str]] | None = None,
2524
used_qubits: set[int] | None = None,
2625
stabilizer_generators: list[list[int]] | None = None,
2726
observables: list[list[int]] | None = None,
2827
) -> stim.Circuit:
29-
"""Expand and duplicate instructions with broadcast targets for encoding."""
28+
"""Duplicate instructions with broadcast targets for encoding."""
3029
stim_circ = tsim.Circuit(program_text)._stim_circ.flattened()
3130
mod_circ = stim.Circuit()
3231

@@ -71,19 +70,12 @@ def _transform_circuit(
7170
instr.target_groups(), stride=stride, offsets=offsets
7271
)
7372

74-
gate_seq = (
75-
gate_expansions.get(instr.name, [instr.name])
76-
if gate_expansions
77-
else [instr.name]
73+
mod_circ.append(
74+
instr.name,
75+
new_ts,
76+
instr.gate_args_copy(),
77+
tag=instr.tag,
7878
)
79-
80-
for g in gate_seq:
81-
mod_circ.append(
82-
g,
83-
new_ts,
84-
instr.gate_args_copy(),
85-
tag=instr.tag,
86-
)
8779
return mod_circ
8880

8981

@@ -100,15 +92,13 @@ def __init__(
10092
encoding_program_text: str | None,
10193
stabilizer_generators: list[list[int]],
10294
observables: list[list[int]],
103-
logical_gate_expansions: dict[str, list[str]] | None = None,
10495
):
10596
"""Initialize the transversal encoder with code parameters."""
10697
self.n = n
10798
self.encoding_qubit = encoding_qubit
10899
self.circuit = tsim.Circuit()
109100
self.used_qubits: set[int] = set()
110101
self.encoding_program_text = encoding_program_text
111-
self.logical_gate_expansions = logical_gate_expansions or {}
112102
self.stabilizer_generators = stabilizer_generators
113103
self.observables = observables
114104

@@ -169,7 +159,6 @@ def encode_transversally(self, program_text: str) -> None:
169159
program_text,
170160
stride=self.n,
171161
offsets=list(range(self.n)),
172-
gate_expansions=self.logical_gate_expansions,
173162
stabilizer_generators=self.stabilizer_generators,
174163
observables=self.observables,
175164
)
@@ -215,12 +204,6 @@ def __init__(self):
215204
n=7,
216205
encoding_qubit=6,
217206
encoding_program_text=encoding_program,
218-
logical_gate_expansions={
219-
"SQRT_X": ["SQRT_X", "X"],
220-
"SQRT_X_DAG": ["SQRT_X_DAG", "X"],
221-
"S": ["S", "Z"],
222-
"S_DAG": ["S_DAG", "Z"],
223-
},
224207
stabilizer_generators=[[0, 1, 2, 3], [1, 2, 4, 5], [2, 3, 4, 6]],
225208
observables=[[0, 1, 5]],
226209
)

src/tsim/utils/program_text.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,30 @@
22

33
import re
44

5+
# Matches valid numeric literals including scientific notation (e.g. 0.5, 4e-4, 1.2e3)
6+
_FLOAT_RE = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
7+
8+
_TSIM_GATES = {"R_X", "R_Y", "R_Z", "U3"}
9+
_GATE_NOT_FOUND_RE = re.compile(r"Gate not found: '(\w+)'")
10+
_GATE_USAGE_RE = re.compile(r"(?<!\[)\b(R_[A-Z]\([^)]*\)|R_[XYZ]\b|U3\([^)]*\)|U3\b)")
11+
12+
13+
def enriched_stim_error(exc: ValueError, converted_text: str) -> ValueError:
14+
"""Improve stim parse errors for tsim-specific gates.
15+
16+
When stim raises a 'Gate not found' error for a gate that should have been
17+
converted by shorthand_to_stim, this searches the converted text for the
18+
unconverted usage and returns a more helpful error message.
19+
"""
20+
m = _GATE_NOT_FOUND_RE.search(str(exc))
21+
if not m or m.group(1) not in _TSIM_GATES:
22+
return exc
23+
# Successfully converted gates live inside brackets (e.g. I[R_Z(...)]) and won't match.
24+
usage = _GATE_USAGE_RE.search(converted_text)
25+
if not usage:
26+
return exc
27+
return ValueError(f"Could not parse '{usage.group()}' in program text.")
28+
529

630
def shorthand_to_stim(text: str) -> str:
731
"""Convert tsim shorthand syntax to valid stim instructions.
@@ -13,27 +37,25 @@ def shorthand_to_stim(text: str) -> str:
1337
R_X(0.25) 0 → I[R_X(theta=0.25*pi)] 0
1438
R_Y(-0.5) 0 → I[R_Y(theta=-0.5*pi)] 0
1539
U3(0.3, 0.24, 0.49) 0 → I[U3(theta=0.3*pi, phi=0.24*pi, lambda=0.49*pi)] 0
40+
1641
"""
1742
# T_DAG must come before T to avoid partial matches
1843
# (?<!\[) ensures we don't match T inside [T]
1944
text = re.sub(r"(?<!\[)\bT_DAG\b(?!\[)", "S_DAG[T]", text)
2045
text = re.sub(r"(?<!\[)\bT\b(?!\[)", "S[T]", text)
2146

22-
# R_Z(angle), R_X(angle), R_Y(angle)
2347
def replace_rotation(m: re.Match) -> str:
2448
axis = m.group(1)
25-
angle = m.group(2)
26-
return f"I[R_{axis}(theta={angle}*pi)]"
49+
return f"I[R_{axis}(theta={float(m.group(2))}*pi)]"
2750

28-
text = re.sub(r"\bR_([XYZ])\(([-+]?[\d.]+)\)", replace_rotation, text)
51+
text = re.sub(rf"\bR_([XYZ])\(({_FLOAT_RE})\)", replace_rotation, text)
2952

30-
# U3(theta, phi, lambda)
3153
def replace_u3(m: re.Match) -> str:
32-
theta, phi, lam = m.group(1), m.group(2), m.group(3)
54+
theta, phi, lam = float(m.group(1)), float(m.group(2)), float(m.group(3))
3355
return f"I[U3(theta={theta}*pi, phi={phi}*pi, lambda={lam}*pi)]"
3456

3557
text = re.sub(
36-
r"\bU3\(([-+]?[\d.]+)\s*,\s*([-+]?[\d.]+)\s*,\s*([-+]?[\d.]+)\)",
58+
rf"\bU3\(({_FLOAT_RE})\s*,\s*({_FLOAT_RE})\s*,\s*({_FLOAT_RE})\)",
3759
replace_u3,
3860
text,
3961
)

test/integration/test_sampler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from test.helpers.gen import gen_stim_circuit
2+
23
import jax
34
import jax.numpy as jnp
45
import numpy as np
@@ -137,7 +138,7 @@ def fake_sample_component(component, f_params, key):
137138

138139
monkeypatch.setattr(sampler_module, "sample_component", fake_sample_component)
139140

140-
with pytest.raises(AssertionError, match="underflow error"):
141+
with pytest.warns(UserWarning, match="not normalized"):
141142
sampler_module.sample_program(
142143
program,
143144
jnp.zeros((1, 0), dtype=jnp.bool_),

test/unit/utils/test_program_text.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,12 @@
1-
from tsim.utils.program_text import shorthand_to_stim, stim_to_shorthand
1+
import re
2+
3+
import pytest
4+
5+
from tsim import Circuit
6+
from tsim.utils.program_text import (
7+
shorthand_to_stim,
8+
stim_to_shorthand,
9+
)
210

311

412
def test_shorthand_to_stim_t_and_t_dag():
@@ -41,3 +49,37 @@ def test_stim_to_shorthand_rotations_and_u3():
4149
def test_shorthand_roundtrip():
4250
text = "T 0\nR_X(0.5) 1\nU3(0.1, 0.2, 0.3) 2"
4351
assert stim_to_shorthand(shorthand_to_stim(text)) == text
52+
53+
54+
def test_shorthand_scientific_notation():
55+
result = shorthand_to_stim("R_Z(4e-4) 0")
56+
assert "I[R_Z(theta=0.0004*pi)]" in result
57+
58+
59+
def test_shorthand_scientific_notation_u3():
60+
result = shorthand_to_stim("U3(1e-2, 2.5e1, 3e-3) 0")
61+
assert "I[U3(" in result
62+
63+
64+
def test_circuit_scientific_notation():
65+
c = Circuit("R_Z(4e-4) 0")
66+
assert len(c) == 1
67+
68+
69+
@pytest.mark.parametrize(
70+
"text, snippet",
71+
[
72+
("R_Z(a) 0", "R_Z(a)"),
73+
("R_Z(pi) 0", "R_Z(pi)"),
74+
("R_Z(1/3) 0", "R_Z(1/3)"),
75+
("R_Z() 0", "R_Z()"),
76+
("R_Z 0", "R_Z"),
77+
("R_Z(0.5, 0.3) 0", "R_Z(0.5, 0.3)"),
78+
("R_X(abc) 0", "R_X(abc)"),
79+
("U3(0.1, 0.2) 0", "U3(0.1, 0.2)"),
80+
("U3(0.1, 0.2, 0.3, 0.4) 0", "U3(0.1, 0.2, 0.3, 0.4)"),
81+
],
82+
)
83+
def test_circuit_parse_error_shows_snippet(text, snippet):
84+
with pytest.raises(ValueError, match=re.escape(snippet)):
85+
Circuit(text)

0 commit comments

Comments
 (0)