Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/bloqade/stim/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from ._wrappers import (
h as h,
s as s,
t as t,
x as x,
y as y,
z as z,
Expand All @@ -15,6 +16,7 @@
rx as rx,
ry as ry,
rz as rz,
u3 as u3,
mpp as mpp,
mxx as mxx,
myy as myy,
Expand All @@ -32,6 +34,9 @@
detector as detector,
identity as identity,
qubit_loss as qubit_loss,
rotation_x as rotation_x,
rotation_y as rotation_y,
rotation_z as rotation_z,
depolarize1 as depolarize1,
depolarize2 as depolarize2,
pauli_string as pauli_string,
Expand Down
160 changes: 124 additions & 36 deletions src/bloqade/stim/_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,68 +8,147 @@
# dialect:: gate
## 1q
@wraps(gate.X)
def x(targets: tuple[int, ...], dagger: bool = False) -> None: ...
def x(
targets: tuple[int, ...], dagger: bool = False, tag: str | None = None
) -> None: ...


@wraps(gate.Y)
def y(targets: tuple[int, ...], dagger: bool = False) -> None: ...
def y(
targets: tuple[int, ...], dagger: bool = False, tag: str | None = None
) -> None: ...


@wraps(gate.Z)
def z(targets: tuple[int, ...], dagger: bool = False) -> None: ...
def z(
targets: tuple[int, ...], dagger: bool = False, tag: str | None = None
) -> None: ...


@wraps(gate.Identity)
def identity(targets: tuple[int, ...], dagger: bool = False) -> None: ...
def identity(
targets: tuple[int, ...], dagger: bool = False, tag: str | None = None
) -> None: ...


@wraps(gate.H)
def h(targets: tuple[int, ...], dagger: bool = False) -> None: ...
def h(
targets: tuple[int, ...], dagger: bool = False, tag: str | None = None
) -> None: ...


@wraps(gate.S)
def s(targets: tuple[int, ...], dagger: bool = False) -> None: ...
def s(
targets: tuple[int, ...], dagger: bool = False, tag: str | None = None
) -> None: ...


@wraps(gate.SqrtX)
def sqrt_x(targets: tuple[int, ...], dagger: bool = False) -> None: ...
def sqrt_x(
targets: tuple[int, ...], dagger: bool = False, tag: str | None = None
) -> None: ...


@wraps(gate.SqrtY)
def sqrt_y(targets: tuple[int, ...], dagger: bool = False) -> None: ...
def sqrt_y(
targets: tuple[int, ...], dagger: bool = False, tag: str | None = None
) -> None: ...


@wraps(gate.SqrtZ)
def sqrt_z(targets: tuple[int, ...], dagger: bool = False) -> None: ...
def sqrt_z(
targets: tuple[int, ...], dagger: bool = False, tag: str | None = None
) -> None: ...


## clif 2q
@wraps(gate.Swap)
def swap(targets: tuple[int, ...], dagger: bool = False) -> None: ...
def swap(
targets: tuple[int, ...], dagger: bool = False, tag: str | None = None
) -> None: ...


## ctrl 2q
@wraps(gate.CX)
def cx(
controls: tuple[int, ...], targets: tuple[int, ...], dagger: bool = False
controls: tuple[int, ...],
targets: tuple[int, ...],
dagger: bool = False,
tag: str | None = None,
) -> None: ...


@wraps(gate.CY)
def cy(
controls: tuple[int, ...], targets: tuple[int, ...], dagger: bool = False
controls: tuple[int, ...],
targets: tuple[int, ...],
dagger: bool = False,
tag: str | None = None,
) -> None: ...


@wraps(gate.CZ)
def cz(
controls: tuple[int, ...], targets: tuple[int, ...], dagger: bool = False
controls: tuple[int, ...],
targets: tuple[int, ...],
dagger: bool = False,
tag: str | None = None,
) -> None: ...


## pp
@wraps(gate.SPP)
def spp(targets: tuple[auxiliary.PauliString, ...], dagger=False) -> None: ...
def spp(
targets: tuple[auxiliary.PauliString, ...],
dagger: bool = False,
tag: str | None = None,
) -> None: ...


## Non-Clifford


@wraps(gate.T)
def t(
targets: tuple[int, ...], dagger: bool = False, tag: str | None = None
) -> None: ...


@wraps(gate.Rx)
def rotation_x(
angle: float,
targets: tuple[int, ...],
dagger: bool = False,
tag: str | None = None,
) -> None: ...


@wraps(gate.Ry)
def rotation_y(
angle: float,
targets: tuple[int, ...],
dagger: bool = False,
tag: str | None = None,
) -> None: ...


@wraps(gate.Rz)
def rotation_z(
angle: float,
targets: tuple[int, ...],
dagger: bool = False,
tag: str | None = None,
) -> None: ...


@wraps(gate.U3)
def u3(
theta: float,
phi: float,
lam: float,
targets: tuple[int, ...],
tag: str | None = None,
) -> None: ...


# dialect:: aux
Expand All @@ -79,18 +158,20 @@ def rec(id: int) -> auxiliary.RecordResult: ...

@wraps(auxiliary.Detector)
def detector(
coord: tuple[Union[int, float], ...], targets: tuple[auxiliary.RecordResult, ...]
coord: tuple[Union[int, float], ...],
targets: tuple[auxiliary.RecordResult, ...],
tag: str | None = None,
) -> None: ...


@wraps(auxiliary.ObservableInclude)
def observable_include(
idx: int, targets: tuple[auxiliary.RecordResult, ...]
idx: int, targets: tuple[auxiliary.RecordResult, ...], tag: str | None = None
) -> None: ...


@wraps(auxiliary.Tick)
def tick() -> None: ...
def tick(tag: str | None = None) -> None: ...


@wraps(auxiliary.NewPauliString)
Expand All @@ -100,62 +181,66 @@ def pauli_string(


@wraps(auxiliary.QubitCoordinates)
def qubit_coords(coord: tuple[Union[int, float], ...], target: int) -> None: ...
def qubit_coords(
coord: tuple[Union[int, float], ...], target: int, tag: str | None = None
) -> None: ...


# dialect:: collapse
@wraps(collapse.MZ)
def mz(p: float, targets: tuple[int, ...]) -> None: ...
def mz(p: float, targets: tuple[int, ...], tag: str | None = None) -> None: ...


@wraps(collapse.MY)
def my(p: float, targets: tuple[int, ...]) -> None: ...
def my(p: float, targets: tuple[int, ...], tag: str | None = None) -> None: ...


@wraps(collapse.MX)
def mx(p: float, targets: tuple[int, ...]) -> None: ...
def mx(p: float, targets: tuple[int, ...], tag: str | None = None) -> None: ...


@wraps(collapse.MZZ)
def mzz(p: float, targets: tuple[int, ...]) -> None: ...
def mzz(p: float, targets: tuple[int, ...], tag: str | None = None) -> None: ...


@wraps(collapse.MYY)
def myy(p: float, targets: tuple[int, ...]) -> None: ...
def myy(p: float, targets: tuple[int, ...], tag: str | None = None) -> None: ...


@wraps(collapse.MXX)
def mxx(p: float, targets: tuple[int, ...]) -> None: ...
def mxx(p: float, targets: tuple[int, ...], tag: str | None = None) -> None: ...


@wraps(collapse.PPMeasurement)
def mpp(p: float, targets: tuple[auxiliary.PauliString, ...]) -> None: ...
def mpp(
p: float, targets: tuple[auxiliary.PauliString, ...], tag: str | None = None
) -> None: ...


@wraps(collapse.RZ)
def rz(targets: tuple[int, ...]) -> None: ...
def rz(targets: tuple[int, ...], tag: str | None = None) -> None: ...


@wraps(collapse.RY)
def ry(targets: tuple[int, ...]) -> None: ...
def ry(targets: tuple[int, ...], tag: str | None = None) -> None: ...


@wraps(collapse.RX)
def rx(targets: tuple[int, ...]) -> None: ...
def rx(targets: tuple[int, ...], tag: str | None = None) -> None: ...


# dialect:: noise
@wraps(noise.Depolarize1)
def depolarize1(p: float, targets: tuple[int, ...]) -> None: ...
def depolarize1(p: float, targets: tuple[int, ...], tag: str | None = None) -> None: ...


@wraps(noise.Depolarize2)
def depolarize2(p: float, targets: tuple[int, ...]) -> None: ...
def depolarize2(p: float, targets: tuple[int, ...], tag: str | None = None) -> None: ...


@wraps(noise.PauliChannel1)
def pauli_channel1(
px: float, py: float, pz: float, targets: tuple[int, ...]
px: float, py: float, pz: float, targets: tuple[int, ...], tag: str | None = None
) -> None: ...


Expand All @@ -177,26 +262,29 @@ def pauli_channel2(
pzy: float,
pzz: float,
targets: tuple[int, ...],
tag: str | None = None,
) -> None: ...


@wraps(noise.XError)
def x_error(p: float, targets: tuple[int, ...]) -> None: ...
def x_error(p: float, targets: tuple[int, ...], tag: str | None = None) -> None: ...


@wraps(noise.YError)
def y_error(p: float, targets: tuple[int, ...]) -> None: ...
def y_error(p: float, targets: tuple[int, ...], tag: str | None = None) -> None: ...


@wraps(noise.ZError)
def z_error(p: float, targets: tuple[int, ...]) -> None: ...
def z_error(p: float, targets: tuple[int, ...], tag: str | None = None) -> None: ...


@wraps(noise.QubitLoss)
def qubit_loss(probs: tuple[float, ...], targets: tuple[int, ...]) -> None: ...
def qubit_loss(
probs: tuple[float, ...], targets: tuple[int, ...], tag: str | None = None
) -> None: ...


@wraps(noise.CorrelatedQubitLoss)
def correlated_qubit_loss(
probs: tuple[float, ...], targets: tuple[int, ...]
probs: tuple[float, ...], targets: tuple[int, ...], tag: str | None = None
) -> None: ...
21 changes: 15 additions & 6 deletions src/bloqade/stim/dialects/auxiliary/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
@dialect.register(key="emit.stim")
class EmitStimAuxMethods(MethodTable):

def _format_with_tag(self, name: str, tag: str | None) -> str:
"""Format instruction name with optional tag annotation."""
if tag:
return f"{name}[{tag}]"
return name

@impl(stmts.ConstInt)
def const_int(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.ConstInt):

Expand Down Expand Up @@ -57,8 +63,8 @@ def get_rec(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.GetRecor

@impl(stmts.Tick)
def tick(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.Tick):

frame.write_line("TICK")
name = self._format_with_tag("TICK", stmt.tag)
frame.write_line(name)

return ()

Expand All @@ -67,13 +73,14 @@ def detector(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: stmts.Detecto

coords: tuple[str, ...] = frame.get_values(stmt.coord)
targets: tuple[str, ...] = frame.get_values(stmt.targets)
name = self._format_with_tag("DETECTOR", stmt.tag)

coord_str: str = ", ".join(coords)
target_str: str = " ".join(targets)
if len(coords):
frame.write_line(f"DETECTOR({coord_str}) {target_str}")
frame.write_line(f"{name}({coord_str}) {target_str}")
else:
frame.write_line(f"DETECTOR {target_str}")
frame.write_line(f"{name} {target_str}")
return ()

@impl(stmts.ObservableInclude)
Expand All @@ -83,9 +90,10 @@ def obs_include(

idx: str = frame.get(stmt.idx)
targets: tuple[str, ...] = frame.get_values(stmt.targets)
name = self._format_with_tag("OBSERVABLE_INCLUDE", stmt.tag)

target_str: str = " ".join(targets)
frame.write_line(f"OBSERVABLE_INCLUDE({idx}) {target_str}")
frame.write_line(f"{name}({idx}) {target_str}")

return ()

Expand All @@ -111,8 +119,9 @@ def qubit_coordinates(

coords: tuple[str, ...] = frame.get_values(stmt.coord)
target: str = frame.get(stmt.target)
name = self._format_with_tag("QUBIT_COORDS", stmt.tag)

coord_str: str = ", ".join(coords)
frame.write_line(f"QUBIT_COORDS({coord_str}) {target}")
frame.write_line(f"{name}({coord_str}) {target}")

return ()
Loading
Loading