Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
4a1f708
feat: add Stim REPEAT support for scf.For loops
johnzl-777 Mar 24, 2026
f3cb43b
Let claude work out this hintlen problem
johnzl-777 Mar 26, 2026
d5a81b6
Merge branch 'main' into john/complete-repeat-support
johnzl-777 Mar 26, 2026
c6cecec
get claude to clean things up
johnzl-777 Mar 26, 2026
3232f35
fix validation vs. flatten ordering, I want validation to happen befo…
johnzl-777 Mar 26, 2026
eb5ac9a
remove the address as attribute system, keep logic simpler by just pa…
johnzl-777 Mar 26, 2026
ef091e8
use early terminating constprop
johnzl-777 Mar 27, 2026
02532a3
make the loop hint const infra more idiomatic
johnzl-777 Mar 28, 2026
ff1f81e
clarify early termination in constprop behavior
johnzl-777 Mar 28, 2026
d788336
cut down on unnecessary code
johnzl-777 Mar 28, 2026
61cce14
get rid of awkward in-function lib imports
johnzl-777 Mar 28, 2026
9446db8
make unit tests more robust, assert that all measurements accumulated…
johnzl-777 Mar 28, 2026
98b4d01
allow for standard empty list syntax as opposed to less idiomatic mea…
johnzl-777 Mar 28, 2026
facc82b
simplify empty ilist check even further
johnzl-777 Mar 28, 2026
eccd0db
Merge branch 'main' into john/complete-repeat-support
johnzl-777 Mar 30, 2026
45253cb
make squin noise rewrite more robust
johnzl-777 Mar 30, 2026
5562f27
get rid of non idiomatic statement deletion handling
johnzl-777 Mar 30, 2026
455cb42
remove monkeypatching
johnzl-777 Mar 30, 2026
fe6eba4
opt for analysis level solution to support appending to empty lists
johnzl-777 Mar 31, 2026
2fa1060
make count argument to repeat an attribute
johnzl-777 Apr 1, 2026
7277919
Have default factory that creates region w/ block
johnzl-777 Apr 1, 2026
283c636
Ensure conformity with is_subseteq behavior
johnzl-777 Apr 1, 2026
5713c2d
accidentally applied suggestion to wrong type, fixed here
johnzl-777 Apr 1, 2026
2d286a4
Merge branch 'main' into john/complete-repeat-support
johnzl-777 Apr 1, 2026
7473a8a
remove unnecessary inline predicate
johnzl-777 Apr 1, 2026
3275b15
make non stim cleanup safer
johnzl-777 Apr 1, 2026
3a45843
rework RemoveDeadNonStimStatements rule
johnzl-777 Apr 1, 2026
3a4fa24
rework RemoveDeadNonStimStatements rule (let DCE handle more of the p…
johnzl-777 Apr 1, 2026
1eb1b92
remove unnecessary dangling register cleaning (superceded by cleanup_…
johnzl-777 Apr 1, 2026
da78e0a
split and rename qubit to stim rewrite rule
johnzl-777 Apr 3, 2026
af6e883
remove early return support
johnzl-777 Apr 3, 2026
fce0f7e
revert removal of early return handling in AddressAnalysis loop
johnzl-777 Apr 3, 2026
45f603a
lean more towards inlining then validating
johnzl-777 Apr 3, 2026
86c337b
avoid monkeypatching altogether on constprop fix
johnzl-777 Apr 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions src/bloqade/analysis/measure_id/impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
RawMeasureId,
MeasureIdBool,
MeasureIdTuple,
ConstantCarrier,
InvalidMeasureId,
)
from .analysis import MeasureIDFrame, MeasurementIDAnalysis
Expand Down Expand Up @@ -183,6 +184,18 @@ def getitem(
return (InvalidMeasureId(),)


@py.constant.dialect.register(key="measure_id")
class PyConstant(interp.MethodTable):
@interp.impl(py.Constant)
def constant(
self,
interp: MeasurementIDAnalysis,
frame: MeasureIDFrame,
stmt: py.Constant,
):
return (ConstantCarrier(data=stmt.value.unwrap()),)


@py.assign.dialect.register(key="measure_id")
class PyAssign(interp.MethodTable):
@interp.impl(py.Alias)
Expand All @@ -202,14 +215,20 @@ def add(self, interp: MeasurementIDAnalysis, frame: MeasureIDFrame, stmt: py.Add
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)

# Unwrap constant carriers holding empty ILists into empty MeasureIdTuples
if isinstance(lhs, ConstantCarrier) and isinstance(lhs.data, ilist.IList):
lhs = MeasureIdTuple(data=(), obj_type=ilist.IList)
if isinstance(rhs, ConstantCarrier) and isinstance(rhs.data, ilist.IList):
rhs = MeasureIdTuple(data=(), obj_type=ilist.IList)

if (
isinstance(lhs, MeasureIdTuple)
and isinstance(rhs, MeasureIdTuple)
and lhs.obj_type is rhs.obj_type
):
return (MeasureIdTuple(data=lhs.data + rhs.data, obj_type=lhs.obj_type),)
else:
return (InvalidMeasureId(),)

return (InvalidMeasureId(),)


@func.dialect.register(key="measure_id")
Expand Down Expand Up @@ -270,6 +289,36 @@ def if_else(
case _:
return interp_.join_results(then_results, else_results)

@interp.impl(scf.For)
def for_loop(
self,
interp_: MeasurementIDAnalysis,
frame: MeasureIDFrame,
stmt: scf.For,
):
hint = stmt.iterable.hints.get("const")
if not isinstance(hint, const.Value):
return interp_.eval_fallback(frame, stmt)

loop_vars = frame.get_values(stmt.initializers)
iterable = hint.data

body_values = {}
for value in iterable:
with interp_.new_frame(stmt, has_parent_access=True) as body_frame:
loop_vars = interp_.frame_call_region(
body_frame, stmt, stmt.body, NotMeasureId(), *loop_vars
)

for ssa, val in body_frame.entries.items():
body_values[ssa] = body_values.setdefault(ssa, val).join(val)

if loop_vars is None:
loop_vars = ()

frame.set_values(body_values.keys(), body_values.values())
return loop_vars


@record_idx_helper_dialect.register(key="measure_id")
class RecordIdxHelperAnalysis(interp.MethodTable):
Expand Down
20 changes: 19 additions & 1 deletion src/bloqade/analysis/measure_id/lattice.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Type, final
from typing import Any, Type, final
from dataclasses import dataclass

from kirin.lattice import (
Expand Down Expand Up @@ -70,6 +70,24 @@ def is_subseteq(self, other: MeasureId) -> bool:
return isinstance(other, NotMeasureId)


@final
@dataclass
class ConstantCarrier(MeasureId):
"""Carries a constant Python value through the MeasureID lattice.

When ConstantFold replaces an IR statement with py.Constant, the original
statement's MeasureID handler is lost. This element preserves the constant
value so downstream operations (e.g. Add) can interpret it.
"""

data: Any

def is_subseteq(self, other: MeasureId) -> bool:
if isinstance(other, AnyMeasureId):
return True
return isinstance(other, ConstantCarrier) and self.data == other.data
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an equality, not subseteq. It means that self.is_subseteq(AnyMeasureId) will return False, but every element must be subseteq(Top()).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This equality is still a little weird. Could someone draw this lattice for me? 😄 At which point does the ConstantCarrier fit into the hierarchy?

I also noticed now that ConcreteMeasureId has a similar issue in that it implements is_subseteq as equality, meaning it's not subseteq to AnyMeasureId. But that's a pre-existing bug.

@weinbe58 I think you originally implemented this lattice, correct? What's the reasoning behind this?



@dataclass
class ConcreteMeasureId(MeasureId):
"""Base class of lattice elements that must be structurally equal to be subseteq."""
Expand Down
5 changes: 0 additions & 5 deletions src/bloqade/squin/rewrite/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1 @@
from .wrap_analysis import (
AddressAttribute as AddressAttribute,
WrapAddressAnalysis as WrapAddressAnalysis,
)
from .U3_to_clifford import SquinU3ToClifford as SquinU3ToClifford
from .remove_dangling_qubits import RemoveDeadRegister as RemoveDeadRegister
19 changes: 0 additions & 19 deletions src/bloqade/squin/rewrite/remove_dangling_qubits.py

This file was deleted.

56 changes: 0 additions & 56 deletions src/bloqade/squin/rewrite/wrap_analysis.py

This file was deleted.

3 changes: 1 addition & 2 deletions src/bloqade/stim/analysis/from_squin_validation/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from bloqade.qubit import stmts as qubit_stmts
from bloqade.squin import gate
from bloqade.types import MeasurementResultType
from bloqade.qubit._dialect import dialect as qubit_dialect

PauliGateType = (gate.stmts.X, gate.stmts.Y, gate.stmts.Z)

Expand Down Expand Up @@ -107,7 +106,7 @@ def return_stmt(
)


@qubit_dialect.register(key="stim.validate.from_squin")
@qubit_stmts.dialect.register(key="stim.validate.from_squin")
class _QubitMethods(interp.MethodTable):

@interp.impl(qubit_stmts.IsZero)
Expand Down
8 changes: 7 additions & 1 deletion src/bloqade/stim/dialects/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
from . import gate as gate, noise as noise, collapse as collapse, auxiliary as auxiliary
from . import (
gate as gate,
noise as noise,
stim_cf as stim_cf,
collapse as collapse,
auxiliary as auxiliary,
)
3 changes: 3 additions & 0 deletions src/bloqade/stim/dialects/stim_cf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .emit import EmitStimCfMethods as EmitStimCfMethods
from .stmts import Repeat as Repeat
from ._dialect import dialect as dialect
3 changes: 3 additions & 0 deletions src/bloqade/stim/dialects/stim_cf/_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from kirin import ir

dialect = ir.Dialect("stim.cf")
28 changes: 28 additions & 0 deletions src/bloqade/stim/dialects/stim_cf/emit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from kirin.interp import MethodTable, impl

from bloqade.stim.emit.stim_str import EmitStimMain, EmitStimFrame

from .stmts import Repeat
from ._dialect import dialect


@dialect.register(key="emit.stim")
class EmitStimCfMethods(MethodTable):

@impl(Repeat)
def emit_repeat(self, emit: EmitStimMain, frame: EmitStimFrame, stmt: Repeat):
frame.write_line(f"REPEAT {stmt.count} {{")
frame._indent += 1

for block in stmt.body.blocks:
frame.current_block = block
for body_stmt in block.stmts:
frame.current_stmt = body_stmt
res = emit.frame_eval(frame, body_stmt)
if isinstance(res, tuple):
frame.set_values(body_stmt.results, res)

frame._indent -= 1
frame.write_line("}")

return ()
14 changes: 14 additions & 0 deletions src/bloqade/stim/dialects/stim_cf/stmts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from kirin import ir
from kirin.decl import info, statement

from ._dialect import dialect


@statement(dialect=dialect)
class Repeat(ir.Statement):
name = "REPEAT"
traits = frozenset({ir.HasCFG(), ir.SSACFG()})
count: int = info.attribute()
body: ir.Region = info.region(
multi=False, default_factory=lambda: ir.Region(ir.Block())
)
3 changes: 2 additions & 1 deletion src/bloqade/stim/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from kirin.passes import Fold, TypeInfer
from kirin.dialects import func, debug, ssacfg, lowering

from .dialects import gate, noise, collapse, auxiliary
from .dialects import gate, noise, stim_cf, collapse, auxiliary


@ir.dialect_group(
Expand All @@ -11,6 +11,7 @@
gate,
auxiliary,
collapse,
stim_cf,
func,
lowering.func,
lowering.call,
Expand Down
61 changes: 61 additions & 0 deletions src/bloqade/stim/passes/cleanup_non_stim.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Remove leftover impure non-stim statements after conversion.

After the full squin-to-stim conversion pipeline, some impure non-stim
statements survive because DCE only removes pure ops. This pass deletes
dead impure statements from an explicit list of expected leftovers and
warns if an unexpected impure non-stim statement survives.
"""

import warnings

from kirin import ir
from kirin.rewrite.abc import RewriteRule, RewriteResult

from bloqade.qubit import stmts as qubit_stmts

# Impure statements expected to be left over after squin-to-stim conversion.
# These don't have stim equivalents — their side effects are subsumed by
# the stim statements that replaced their consumers (e.g. qubit.New
# becomes irrelevant once all gates using that qubit are converted).
EXPECTED_IMPURE_DEAD: tuple[type[ir.Statement], ...] = (
qubit_stmts.New,
qubit_stmts.Measure,
)


class RemoveDeadNonStimStatements(RewriteRule):
"""Remove dead impure non-stim statements after conversion.

- Dead impure statements in EXPECTED_IMPURE_DEAD: deleted.
- Dead impure statements NOT in the list: warned about (likely a missed rewrite).
- Pure statements are left for DCE to handle.
"""

def __init__(self, keep: ir.DialectGroup):
self.keep = keep

def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
if not isinstance(node, EXPECTED_IMPURE_DEAD):
# Warn about unexpected dead impure non-stim statements.
if (
node.dialect is not None
and node.dialect not in self.keep
and not node.regions
and not node.has_trait(ir.IsTerminator)
and not node.has_trait(ir.Pure)
and not (
(trait := node.get_trait(ir.MaybePure)) and trait.is_pure(node)
)
and all(len(r.uses) == 0 for r in node.results)
):
warnings.warn(
f"Unexpected non-stim statement survived conversion: "
f"{type(node).__name__} from {node.dialect}",
)
return RewriteResult()

if any(len(r.uses) > 0 for r in node.results):
return RewriteResult()

node.delete()
return RewriteResult(has_done_something=True)
Loading
Loading