diff --git a/src/bloqade/analysis/address/__init__.py b/src/bloqade/analysis/address/__init__.py index 18938f823..20f026d91 100644 --- a/src/bloqade/analysis/address/__init__.py +++ b/src/bloqade/analysis/address/__init__.py @@ -12,4 +12,4 @@ UnknownQubit as UnknownQubit, PartialLambda as PartialLambda, ) -from .analysis import AddressAnalysis as AddressAnalysis +from .analysis import AddressFrame as AddressFrame, AddressAnalysis as AddressAnalysis diff --git a/src/bloqade/analysis/address/analysis.py b/src/bloqade/analysis/address/analysis.py index 6fefa08e4..18b8b5e77 100644 --- a/src/bloqade/analysis/address/analysis.py +++ b/src/bloqade/analysis/address/analysis.py @@ -1,16 +1,83 @@ from typing import Any, Type, TypeVar -from dataclasses import field +from contextlib import contextmanager +from dataclasses import field, dataclass from kirin import ir, types, interp -from kirin.analysis import Forward, const +from kirin.analysis import ForwardExtra, const +from kirin.dialects import func from kirin.dialects.ilist import IList from kirin.analysis.forward import ForwardFrame from kirin.analysis.const.lattice import PartialLambda from .lattice import Address, AddressReg, ConstResult, PartialIList, PartialTuple +InvokeKey = tuple[ir.Statement, tuple[int, ...]] + + +@dataclass +class AddressFrame(ForwardFrame[Address]): + _current_invoke_key: InvokeKey | None = None + _invoke_addresses: dict[InvokeKey, dict[ir.SSAValue, Address]] = field( + init=False, default_factory=dict + ) + + def collect_invoke_addresses( + self, call_frame: "AddressFrame", node: func.Invoke | None = None + ): + if node is not None: + inputs = self.get_values(node.inputs) + input_ids = tuple(map(id, inputs)) + key = (node, input_ids) + + # collect the addresses found in the function body + data = self._invoke_addresses.get(key, dict()) + data.update(call_frame.entries) + self._invoke_addresses[key] = data + + # collect nested invokes + self._invoke_addresses.update(call_frame._invoke_addresses) + + @contextmanager + def invoke_addresses(self, node: func.Invoke): + inputs = [self.get_or_fallback_to_invoke(input_) for input_ in node.inputs] + input_ids = tuple(map(id, inputs)) + context_key = (node, input_ids) + + reset_invoke_key = self._current_invoke_key + self._current_invoke_key = context_key + try: + yield self + finally: + self._current_invoke_key = reset_invoke_key + + def get_or_fallback_to_invoke(self, key: ir.SSAValue): + """Modified frame.get method that also checks addresses collected from + function invokes. + """ + value = self.entries.get(key, interp.Undefined) + + if not interp.is_undefined(value): + return value + + if self._current_invoke_key is not None: + additional_entries = self._invoke_addresses.get( + self._current_invoke_key, dict() + ) + value = additional_entries.get(key, interp.Undefined) + + if not interp.is_undefined(value): + return value -class AddressAnalysis(Forward[Address]): + if self.has_parent_access and self.parent: + if isinstance(self.parent, AddressFrame): + self.parent.get_or_fallback_to_invoke(key) + else: + return self.parent.get(key) + + raise interp.InterpreterError(f"SSAValue {key} not found") + + +class AddressAnalysis(ForwardExtra[AddressFrame, Address]): """ This analysis pass can be used to track the global addresses of qubits and wires. """ @@ -36,6 +103,11 @@ def initialize(self): self._const_prop.initialize() return self + def initialize_frame( + self, node: ir.Statement, *, has_parent_access: bool = False + ) -> AddressFrame: + return AddressFrame(node, has_parent_access=has_parent_access) + @property def qubit_count(self) -> int: """Total number of qubits found by the analysis.""" diff --git a/src/bloqade/analysis/address/impls.py b/src/bloqade/analysis/address/impls.py index a88f87872..242fd89db 100644 --- a/src/bloqade/analysis/address/impls.py +++ b/src/bloqade/analysis/address/impls.py @@ -15,7 +15,7 @@ PartialTuple, PartialLambda, ) -from .analysis import AddressAnalysis +from .analysis import AddressFrame, AddressAnalysis @py.constant.dialect.register(key="qubit.address") @@ -177,15 +177,17 @@ def return_( def invoke( self, interp_: AddressAnalysis, - frame: ForwardFrame[Address], + frame: AddressFrame, stmt: func.Invoke, ): - _, ret = interp_.call( + call_frame, ret = interp_.call( stmt.callee.code, interp_.method_self(stmt.callee), *frame.get_values(stmt.inputs), ) + frame.collect_invoke_addresses(call_frame, stmt) + return (ret,) @interp.impl(func.Lambda) @@ -307,7 +309,7 @@ def yield_( def ifelse( self, interp_: AddressAnalysis, - frame: ForwardFrame[Address], + frame: AddressFrame, stmt: scf.IfElse, ): address_cond = frame.get(stmt.cond) @@ -318,6 +320,7 @@ def ifelse( body = stmt.then_body if const_cond.data else stmt.else_body with interp_.new_frame(stmt, has_parent_access=True) as body_frame: ret = interp_.frame_call_region(body_frame, stmt, body, address_cond) + frame.collect_invoke_addresses(body_frame) # interp_.set_values(frame, body_frame.entries.keys(), body_frame.entries.values()) return ret else: @@ -330,6 +333,7 @@ def ifelse( address_cond, ) frame.set_values(then_frame.entries.keys(), then_frame.entries.values()) + frame.collect_invoke_addresses(then_frame) with interp_.new_frame(stmt, has_parent_access=True) as else_frame: else_results = interp_.frame_call_region( @@ -339,6 +343,7 @@ def ifelse( address_cond, ) frame.set_values(else_frame.entries.keys(), else_frame.entries.values()) + frame.collect_invoke_addresses(else_frame) # TODO: pick the non-return value if isinstance(then_results, interp.ReturnValue) and isinstance( else_results, interp.ReturnValue diff --git a/src/bloqade/analysis/validation/nocloning/__init__.py b/src/bloqade/analysis/validation/nocloning/__init__.py new file mode 100644 index 000000000..a61f8ba07 --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/__init__.py @@ -0,0 +1,2 @@ +from . import impls as impls +from .analysis import NoCloningValidation as NoCloningValidation diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py new file mode 100644 index 000000000..bfc99944c --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -0,0 +1,180 @@ +from typing import Any +from dataclasses import dataclass + +from kirin import ir +from kirin.analysis import Forward +from kirin.ir.exception import ( + ValidationError, + DefiniteValidationError, + PotentialValidationError, +) +from kirin.analysis.forward import ForwardFrame +from kirin.validation.validationpass import ValidationPass + +from bloqade.analysis.address import AddressFrame, AddressAnalysis + +from .lattice import May, Must, Bottom, QubitValidation + + +class QubitValidationError(DefiniteValidationError): + """ValidationError for definite (Must) violations with concrete qubit addresses.""" + + qubit_id: int + gate_name: str + + def __init__(self, node: ir.IRNode, qubit_id: int, gate_name: str): + super().__init__(node, f"Qubit[{qubit_id}] cloned at {gate_name} gate.") + self.qubit_id = qubit_id + self.gate_name = gate_name + + +class PotentialQubitValidationError(PotentialValidationError): + """ValidationError for potential (May) violations with unknown addresses.""" + + gate_name: str + condition: str + + def __init__(self, node: ir.IRNode, gate_name: str, condition: str): + super().__init__(node, f"Potential cloning at {gate_name} gate{condition}.") + self.gate_name = gate_name + self.condition = condition + + +@dataclass +class _NoCloningAnalysis(Forward[QubitValidation]): + """Internal forward analysis for tracking qubit cloning violations.""" + + keys = ("validate.nocloning",) + lattice = QubitValidation + _address_frame: AddressFrame | None = None + + def method_self(self, method: ir.Method) -> QubitValidation: + return self.lattice.bottom() + + def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidation): + if self._address_frame is None: + addr_analysis = AddressAnalysis(self.dialects) + addr_analysis.initialize() + self._address_frame, _ = addr_analysis.run(method) + return super().run(method, *args, **kwargs) + + def eval_fallback( + self, frame: ForwardFrame[QubitValidation], node: ir.Statement + ) -> tuple[QubitValidation, ...]: + """Check for qubit usage violations and return lattice values.""" + return tuple(Bottom() for _ in node.results) + + def _get_source_name(self, value: ir.SSAValue) -> str: + """Trace back to get the source variable name.""" + from kirin.dialects.py.indexing import GetItem + + if isinstance(value, ir.ResultValue) and isinstance(value.stmt, GetItem): + index_arg = value.stmt.args[1] + return self._get_source_name(index_arg) + + if isinstance(value, ir.BlockArgument): + return value.name or f"arg{value.index}" + + if hasattr(value, "name") and value.name: + return value.name + + return str(value) + + def extract_errors_from_frame( + self, frame: ForwardFrame[QubitValidation] + ) -> list[ValidationError]: + """Extract validation errors from final lattice values. + + Only extracts errors from top-level statements (not nested in regions). + """ + errors = [] + seen_statements = set() + + for node, value in frame.entries.items(): + if isinstance(node, ir.ResultValue): + stmt = node.stmt + elif isinstance(node, ir.Statement): + stmt = node + else: + continue + if stmt in seen_statements: + continue + seen_statements.add(stmt) + if isinstance(value, Must): + for qubit_id, gate_name in value.violations: + errors.append(QubitValidationError(stmt, qubit_id, gate_name)) + elif isinstance(value, May): + for gate_name, condition in value.violations: + errors.append( + PotentialQubitValidationError(stmt, gate_name, condition) + ) + return errors + + def count_violations(self, frame: Any) -> int: + """Count individual violations from the frame, same as test helper.""" + from .lattice import May, Must + + total = 0 + for node, value in frame.entries.items(): + if isinstance(value, Must): + total += len(value.violations) + elif isinstance(value, May): + total += len(value.violations) + return total + + +class NoCloningValidation(ValidationPass): + """Validates the no-cloning theorem by tracking qubit addresses.""" + + def __init__(self): + self._analysis: _NoCloningAnalysis | None = None + self._cached_address_frame = None + + def name(self) -> str: + return "No-Cloning Validation" + + def get_required_analyses(self) -> list[type]: + """Declare dependency on AddressAnalysis.""" + return [AddressAnalysis] + + def set_analysis_cache(self, cache: dict[type, Any]) -> None: + """Use cached AddressAnalysis result.""" + self._cached_address_frame = cache.get(AddressAnalysis) + + def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]: + """Run the no-cloning validation analysis.""" + if self._analysis is None: + self._analysis = _NoCloningAnalysis(method.dialects) + + self._analysis.initialize() + if self._cached_address_frame is not None: + self._analysis._address_frame = self._cached_address_frame + + frame, _ = self._analysis.run(method) + errors = self._analysis.extract_errors_from_frame(frame) + + return frame, errors + + def print_validation_errors(self): + """Print all collected errors with formatted snippets.""" + if self._analysis is None: + return + + if self._analysis.state._current_frame: + frame = self._analysis.state._current_frame + errors = self._analysis.extract_errors_from_frame(frame) + + for err in errors: + if isinstance(err, QubitValidationError): + print( + f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate" + ) + elif isinstance(err, PotentialQubitValidationError): + print( + f"\n\033[33mWarning\033[0m: Potential cloning at {err.gate_name} gate{err.condition}" + ) + else: + print( + f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}" + ) + print(err.hint()) diff --git a/src/bloqade/analysis/validation/nocloning/impls.py b/src/bloqade/analysis/validation/nocloning/impls.py new file mode 100644 index 000000000..2b68a0700 --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/impls.py @@ -0,0 +1,158 @@ +from kirin import interp +from kirin.analysis import ForwardFrame +from kirin.dialects import scf, func + +from bloqade.analysis.address.lattice import ( + Unknown, + AddressReg, + UnknownReg, + AddressQubit, + PartialIList, + PartialTuple, + UnknownQubit, +) + +from .lattice import May, Top, Must, Bottom, QubitValidation +from .analysis import _NoCloningAnalysis + + +@scf.dialect.register(key="validate.nocloning") +class Scf(interp.MethodTable): + @interp.impl(scf.IfElse) + def if_else( + self, + interp_: _NoCloningAnalysis, + frame: ForwardFrame[QubitValidation], + stmt: scf.IfElse, + ): + try: + cond_validation = frame.get(stmt.cond) + except Exception: + cond_validation = Top() + + with interp_.new_frame(stmt, has_parent_access=True) as then_frame: + interp_.frame_call_region(then_frame, stmt, stmt.then_body, cond_validation) + + then_state = Bottom() + for node, val in then_frame.entries.items(): + if isinstance(val, (Must, May)): + then_state = then_state.join(val) + + else_state = Bottom() + if stmt.else_body: + with interp_.new_frame(stmt, has_parent_access=True) as else_frame: + interp_.frame_call_region( + else_frame, stmt, stmt.else_body, cond_validation + ) + + for node, val in else_frame.entries.items(): + if isinstance(val, (Must, May)): + else_state = else_state.join(val) + + merged = then_state.join(else_state) + + if isinstance(merged, May): + then_has = not isinstance(then_state, Bottom) + else_has = not isinstance(else_state, Bottom) + + if then_has and not else_has: + new_violations = frozenset( + (gate, ", when condition is true") for gate, _ in merged.violations + ) + merged = May(violations=new_violations) + elif else_has and not then_has: + new_violations = frozenset( + (gate, ", when condition is false") for gate, _ in merged.violations + ) + merged = May(violations=new_violations) + + return (merged,) + + +@func.dialect.register(key="validate.nocloning") +class Func(interp.MethodTable): + @interp.impl(func.Invoke) + def invoke_( + self, + interp_: _NoCloningAnalysis, + frame: ForwardFrame[QubitValidation], + stmt: func.Invoke, + ): + if not isinstance(stmt, func.Invoke): + return tuple(Bottom() for _ in stmt.results) + + address_frame = interp_._address_frame + if address_frame is None: + return tuple(Top() for _ in stmt.results) + + concrete_addrs: list[int] = [] + has_unknown = False + has_qubit_args = False + unknown_arg_names: list[str] = [] + + for arg in stmt.args: + addr = address_frame.get_or_fallback_to_invoke(arg) + match addr: + case AddressQubit(data=qubit_addr): + has_qubit_args = True + concrete_addrs.append(qubit_addr) + case AddressReg(data=addrs): + has_qubit_args = True + concrete_addrs.extend(addrs) + case ( + UnknownQubit() + | UnknownReg() + | PartialIList() + | PartialTuple() + | Unknown() + ): + has_qubit_args = True + has_unknown = True + arg_name = interp_._get_source_name(arg) + unknown_arg_names.append(arg_name) + case _: + pass + + if not has_qubit_args: + return tuple(Bottom() for _ in stmt.results) + + seen: set[int] = set() + violations: set[tuple[int, str]] = set() + s_name = getattr(stmt.callee, "sym_name", "") + gate_name = s_name.upper() + + for qubit_addr in concrete_addrs: + if qubit_addr in seen: + violations.add((qubit_addr, gate_name)) + seen.add(qubit_addr) + + if violations: + current_errors = interp_.get_validation_errors() + # NOTE: verify violation by stepping into the function + # need local addresses to be available then + assert interp_._address_frame is not None + with interp_._address_frame.invoke_addresses(stmt): + _ = interp_.call( + stmt.callee.code, + interp_.method_self(stmt.callee), + *frame.get_values(stmt.inputs), + ) + + if len(interp_.get_validation_errors()) > len(current_errors): + # NOTE: there was a new error added + usage = Must(violations=frozenset(violations)) + else: + # NOTE: we're actually fine + usage = Bottom() + elif has_unknown: + args_str = " == ".join(unknown_arg_names) + if len(unknown_arg_names) > 1: + condition = f", when {args_str}" + else: + condition = f", with unknown argument {args_str}" + + usage = May(violations=frozenset([(gate_name, condition)])) + else: + usage = Bottom() + + return tuple(usage for _ in stmt.results) diff --git a/src/bloqade/analysis/validation/nocloning/lattice.py b/src/bloqade/analysis/validation/nocloning/lattice.py new file mode 100644 index 000000000..a8ca50492 --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/lattice.py @@ -0,0 +1,174 @@ +from abc import abstractmethod +from typing import FrozenSet, final +from dataclasses import field, dataclass + +from kirin.lattice import SingletonMeta, BoundedLattice + + +@dataclass +class QubitValidation(BoundedLattice["QubitValidation"]): + r"""Base class for qubit-cloning validation lattice. + + Semantics for control flow: + - Bottom: proven safe / never occurs + - Must: definitely occurs on ALL paths + - May: possibly occurs on SOME paths + - Top: unknown / no information + + Lattice ordering (more precise --> less precise): + Bottom ⊑ Must ⊑ May ⊑ Top + Bottom ⊑ May ⊑ Top + + Key insight: Must ⊔ Bottom = May (happens on some paths, not all) + """ + + @classmethod + def bottom(cls) -> "Bottom": + return Bottom() + + @classmethod + def top(cls) -> "Top": + return Top() + + @abstractmethod + def is_subseteq(self, other: "QubitValidation") -> bool: ... + + @abstractmethod + def join(self, other: "QubitValidation") -> "QubitValidation": ... + + @abstractmethod + def meet(self, other: "QubitValidation") -> "QubitValidation": ... + + +@final +class Bottom(QubitValidation, metaclass=SingletonMeta): + def is_subseteq(self, other: QubitValidation) -> bool: + return True + + def join(self, other: QubitValidation) -> QubitValidation: + return other + + def meet(self, other: QubitValidation) -> QubitValidation: + return self + + def __repr__(self) -> str: + return "⊥ (No Errors)" + + +@final +class Top(QubitValidation, metaclass=SingletonMeta): + def is_subseteq(self, other: QubitValidation) -> bool: + return isinstance(other, Top) + + def join(self, other: QubitValidation) -> QubitValidation: + return self + + def meet(self, other: QubitValidation) -> QubitValidation: + return other + + def __repr__(self) -> str: + return "⊤ (Unknown)" + + +@final +@dataclass +class Must(QubitValidation): + """Definite violations with concrete qubit IDs and gate names.""" + + violations: FrozenSet[tuple[int, str]] = field(default_factory=frozenset) + """Set of (qubit_id, gate_name) tuples""" + + def is_subseteq(self, other: QubitValidation) -> bool: + match other: + case Bottom(): + return False + case Must(violations=ov): + return self.violations.issubset(ov) + case May() | Top(): + return True + return False + + def join(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + may_violations = frozenset((gate, "") for _, gate in self.violations) + return May(violations=may_violations) + case Must(violations=ov): + merged = self.violations | ov + return Must(violations=merged) + case May(violations=ov): + may_viols = frozenset((gate, "") for _, gate in self.violations) + return May(violations=may_viols | ov) + case Top(): + return other + return Top() + + def meet(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return other + case Must(violations=ov): + inter = self.violations & ov + return Must(violations=inter) if inter else Bottom() + case May(): + return self + case Top(): + return self + return Bottom() + + def __repr__(self) -> str: + if not self.violations: + return "Must(∅)" + viols = ", ".join(f"Qubit[{qid}] at {gate}" for qid, gate in self.violations) + return f"Must({{{viols}}})" + + +@final +@dataclass +class May(QubitValidation): + """Potential violations with gate names and conditions.""" + + violations: FrozenSet[tuple[str, str]] = field(default_factory=frozenset) + """Set of (gate_name, condition) tuples""" + + def is_subseteq(self, other: QubitValidation) -> bool: + match other: + case Bottom() | Must(): + return False + case May(violations=ov): + return self.violations.issubset(ov) + case Top(): + return True + return False + + def join(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return self + case Must(violations=ov): + may_viols = frozenset((gate, "") for _, gate in ov) + return May(violations=self.violations | may_viols) + case May(violations=ov): + return May(violations=self.violations | ov) + case Top(): + return other + return Top() + + def meet(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return other + case Must(): + return other + case May(violations=ov): + inter = self.violations & ov + return May(violations=inter) if inter else Bottom() + case Top(): + return self + return Bottom() + + def __repr__(self) -> str: + if not self.violations: + return "May(∅)" + viols = ", ".join(f"{gate}{cond}" for gate, cond in self.violations) + return f"May({{{viols}}})" diff --git a/src/bloqade/squin/__init__.py b/src/bloqade/squin/__init__.py index 6be327a8e..a5c3c8f8b 100644 --- a/src/bloqade/squin/__init__.py +++ b/src/bloqade/squin/__init__.py @@ -53,6 +53,7 @@ single_qubit_pauli_channel as single_qubit_pauli_channel, ) from .analysis.fidelity import impls as impls +from .analysis.validation.nocloning import impls as nocloning_impls # noqa: F401 # NOTE: it's important to keep these imports here since they import squin.kernel # we skip isort here diff --git a/src/bloqade/squin/analysis/validation/nocloning/__init__.py b/src/bloqade/squin/analysis/validation/nocloning/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/bloqade/squin/analysis/validation/nocloning/impls.py b/src/bloqade/squin/analysis/validation/nocloning/impls.py new file mode 100644 index 000000000..d0f952852 --- /dev/null +++ b/src/bloqade/squin/analysis/validation/nocloning/impls.py @@ -0,0 +1,92 @@ +from typing import Sequence + +from kirin import ir, interp +from kirin.analysis import ForwardFrame + +from bloqade.squin import gate +from bloqade.analysis.address import AddressReg +from bloqade.analysis.validation.nocloning.lattice import Must, Bottom, QubitValidation +from bloqade.analysis.validation.nocloning.analysis import _NoCloningAnalysis + + +@gate.dialect.register(key="validate.nocloning") +class GateMethods(interp.MethodTable): + + @interp.impl(gate.stmts.X) + @interp.impl(gate.stmts.Y) + @interp.impl(gate.stmts.Z) + @interp.impl(gate.stmts.H) + @interp.impl(gate.stmts.S) + @interp.impl(gate.stmts.T) + @interp.impl(gate.stmts.SqrtX) + @interp.impl(gate.stmts.SqrtY) + @interp.impl(gate.stmts.Rx) + @interp.impl(gate.stmts.Ry) + @interp.impl(gate.stmts.Rz) + @interp.impl(gate.stmts.U3) + def single_qubit_gate( + self, + interp_: _NoCloningAnalysis, + frame: ForwardFrame[QubitValidation], + stmt: gate.stmts.SingleQubitGate, + ): + if interp_._address_frame is None: + return + + addr = interp_._address_frame.get_or_fallback_to_invoke(stmt.qubits) + + if not isinstance(addr, AddressReg): + raise NotImplementedError("TODO: handle unknowns") + + return self._check_addresses_for_overlap(addr.data, interp_, stmt) + + @interp.impl(gate.stmts.CX) + @interp.impl(gate.stmts.CY) + @interp.impl(gate.stmts.CZ) + def controlled_gate( + self, + interp_: _NoCloningAnalysis, + frame: ForwardFrame[QubitValidation], + stmt: gate.stmts.ControlledGate, + ): + if interp_._address_frame is None: + return + + ctrl_addr = interp_._address_frame.get_or_fallback_to_invoke(stmt.controls) + target_addr = interp_._address_frame.get_or_fallback_to_invoke(stmt.targets) + + if not isinstance(ctrl_addr, AddressReg) or not isinstance( + target_addr, AddressReg + ): + raise NotImplementedError("TODO: handle unknowns") + + data = list(ctrl_addr.data) + list(target_addr.data) + + return self._check_addresses_for_overlap(data, interp_, stmt) + + def _check_addresses_for_overlap( + self, data: Sequence[int], interp_: _NoCloningAnalysis, stmt: ir.Statement + ): + gate_name = stmt.name.upper() + seen = set() + violations = set() + errors = [] + for qubit_addr in data: + if qubit_addr in seen: + errors.append( + ir.ValidationError( + stmt, + f"Gate {gate_name} applies to qubit {qubit_addr} more than once.", + ) + ) + violations.add((qubit_addr, gate_name)) + seen.add(qubit_addr) + + if violations: + usage = Must(violations=frozenset(violations)) + for error in errors: + interp_.add_validation_error(stmt, error) + else: + usage = Bottom() + + return tuple(usage for _ in stmt.results) diff --git a/test/analysis/validation/nocloning/test_no_cloning.py b/test/analysis/validation/nocloning/test_no_cloning.py new file mode 100644 index 000000000..ddd791c3a --- /dev/null +++ b/test/analysis/validation/nocloning/test_no_cloning.py @@ -0,0 +1,213 @@ +from typing import Any, TypeVar + +import pytest +from kirin import ir +from kirin.dialects.ilist.runtime import IList + +from bloqade import squin +from bloqade.types import Qubit +from bloqade.analysis.validation.nocloning.lattice import May, Must +from bloqade.analysis.validation.nocloning.analysis import NoCloningValidation + +T = TypeVar("T", bound=Must | May) + + +def collect_errors_from_validation( + validation: NoCloningValidation, + frame, +) -> tuple[int, int]: + """Count Must (definite) and May (potential) errors from the validation pass. + + Returns: + (must_count, may_count) - number of definite and potential violations + """ + must_count = 0 + may_count = 0 + + if validation._analysis is None: + return (must_count, may_count) + + for node, value in frame.entries.items(): + if isinstance(value, Must): + must_count += len(value.violations) + elif isinstance(value, May): + may_count += len(value.violations) + + return must_count, may_count + + +@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +def test_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def bad_control(): + q = squin.qalloc(1) + control_gate(q[0], q[0]) + + validation = NoCloningValidation() + + frame, errors = validation.run(bad_control) + print() + bad_control.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 1 + assert may_count == 0 + + +@pytest.mark.parametrize("control_gate", [squin.cx]) +def test_conditionals_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def bad_control(cond: bool): + q = squin.qalloc(10) + if cond: + control_gate(q[0], q[0]) + else: + control_gate(q[0], q[1]) + squin.cx(q[1], q[1]) + + validation = NoCloningValidation() + frame, errors = validation.run(bad_control) + print() + bad_control.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 1 # squin.cx(q[1], q[1]) outside conditional + assert may_count == 1 # control_gate(q[0], q[0]) inside conditional + + +@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +def test_pass(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def test(): + q = squin.qalloc(3) + control_gate(q[0], q[1]) + squin.rx(1.57, q[0]) + squin.measure(q[0]) + control_gate(q[0], q[2]) + + validation = NoCloningValidation() + frame, errors = validation.run(test) + print() + test.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 0 + assert may_count == 0 + + +def test_fail_2(): + @squin.kernel + def good_kernel(): + q = squin.qalloc(2) + a = 1 + squin.cx(q[0], q[1]) + squin.cy(q[1], q[a]) + + validation = NoCloningValidation() + frame, errors = validation.run(good_kernel) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 1 + assert may_count == 0 + + +def test_parallel_fail(): + @squin.kernel + def bad_kernel(): + q = squin.qalloc(5) + squin.broadcast.cx(IList([q[0], q[1], q[2]]), IList([q[1], q[2], q[3]])) + + validation = NoCloningValidation() + frame, errors = validation.run(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 2 + assert may_count == 0 + + +def test_potential_fail(): + @squin.kernel + def bad_kernel(a: int, b: int): + q = squin.qalloc(5) + squin.cx(q[a], q[2]) + + validation = NoCloningValidation() + frame, errors = validation.run(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 0 + assert may_count == 1 + + +def test_potential_parallel_fail(): + @squin.kernel + def bad_kernel(a: IList): + q = squin.qalloc(5) + squin.broadcast.cx(a, IList([q[2], q[3], q[4]])) + + validation = NoCloningValidation() + frame, errors = validation.run(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 0 + assert may_count == 1 + + +def test_single_qubit_gate(): + @squin.kernel + def bad_kernel(): + q = squin.qalloc(1) + squin.broadcast.x([q[0], q[0]]) + + validation = NoCloningValidation() + frame, errors = validation.run(bad_kernel) + + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 1 + assert may_count == 0 + + +def test_custom_subroutines(): + + @squin.kernel + def custom_subroutine(q1: Qubit, q2: Qubit): + squin.x(q1) + squin.y(q2) + + @squin.kernel + def good_kernel(): + q = squin.qalloc(1) + custom_subroutine(q[0], q[0]) + + validation = NoCloningValidation() + frame, errors = validation.run(good_kernel) + + good_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == may_count == 0 + + +def test_same_gate_different_errors(): + @squin.kernel + def bad_kernel(a: int): + q = squin.qalloc(1) + squin.cx(q[0], q[0]) + squin.cx(q[0], q[a]) + + validation = NoCloningValidation() + frame, errors = validation.run(bad_kernel) + + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation, frame) + assert must_count == 1 + assert may_count == 1 diff --git a/test/analysis/validation/test_compose_validation.py b/test/analysis/validation/test_compose_validation.py new file mode 100644 index 000000000..7294e349f --- /dev/null +++ b/test/analysis/validation/test_compose_validation.py @@ -0,0 +1,51 @@ +import pytest +from kirin.ir.exception import ValidationErrorGroup +from kirin.validation.validationpass import ValidationSuite + +from bloqade import squin +from bloqade.analysis.validation.nocloning import NoCloningValidation + + +def test_validation_suite(): + @squin.kernel + def bad_kernel(a: int): + q = squin.qalloc(2) + squin.cx(q[0], q[0]) # definite cloning error + squin.cx(q[a], q[1]) # potential cloning error + + # Running no-cloning validation multiple times + suite = ValidationSuite( + [ + NoCloningValidation, + NoCloningValidation, + NoCloningValidation, + ] + ) + result = suite.validate(bad_kernel) + + assert not result.is_valid + assert ( + result.error_count() == 2 + ) # Report 2 errors, even when validated multiple times + with pytest.raises(ValidationErrorGroup) as exc_info: + result.raise_if_invalid() + print(f"{exc_info.value}") + + +def test_validation_suite2(): + @squin.kernel + def good_kernel(): + q = squin.qalloc(2) + squin.cx(q[0], q[1]) + + suite = ValidationSuite( + [ + NoCloningValidation, + ], + fail_fast=True, + ) + result = suite.validate(good_kernel) + + assert result.is_valid + assert result.error_count() == 0 + result.raise_if_invalid()