diff --git a/source/pip/benchmarks/bench_qre.py b/source/pip/benchmarks/bench_qre.py index 6dd54f2af9..561aa2c0b4 100644 --- a/source/pip/benchmarks/bench_qre.py +++ b/source/pip/benchmarks/bench_qre.py @@ -3,6 +3,7 @@ import timeit from dataclasses import dataclass, KW_ONLY, field +from qsharp.qre.models import AQREGateBased, SurfaceCode from qsharp.qre._enumeration import _enumerate_instances @@ -35,30 +36,13 @@ def bench_enumerate_isas(): # Add the tests directory to sys.path to import test_qre # TODO: Remove this once the models in test_qre are moved to a proper module sys.path.append(os.path.join(os.path.dirname(__file__), "../tests")) - import test_qre # type: ignore + from test_qre import ExampleLogicalFactory, ExampleFactory # type: ignore - from qsharp.qre._isa_enumeration import ( - Context, - ISAQuery, - ProductNode, - ) - - ctx = Context(architecture=test_qre.ExampleArchitecture()) + ctx = AQREGateBased().context() # Hierarchical factory using from_components - query = ProductNode( - sources=[ - ISAQuery(test_qre.SurfaceCode), - ISAQuery( - test_qre.ExampleLogicalFactory, - source=ProductNode( - sources=[ - ISAQuery(test_qre.SurfaceCode), - ISAQuery(test_qre.ExampleFactory), - ] - ), - ), - ] + query = SurfaceCode.q() * ExampleLogicalFactory.q( + source=SurfaceCode.q() * ExampleFactory.q() ) number = 100 diff --git a/source/pip/qsharp/qre/__init__.py b/source/pip/qsharp/qre/__init__.py index 771a23ea14..d6dbb24e29 100644 --- a/source/pip/qsharp/qre/__init__.py +++ b/source/pip/qsharp/qre/__init__.py @@ -1,38 +1,61 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from ._application import Application, QSharpApplication +from ._architecture import Architecture +from ._estimation import estimate from ._instruction import ( LOGICAL, PHYSICAL, Encoding, + ISATransform, constraint, instruction, - ISATransform, ) +from ._isa_enumeration import ISAQuery, ISARefNode, ISA_ROOT from ._qre import ( ISA, + InstructionFrontier, Constraint, ConstraintBound, + EstimationResult, + FactoryResult, ISARequirements, + Block, + Trace, block_linear_function, constant_function, linear_function, ) -from ._architecture import Architecture +from ._trace import LatticeSurgery, PSSPC, TraceQuery __all__ = [ "block_linear_function", "constant_function", "constraint", + "estimate", "instruction", "linear_function", + "Application", "Architecture", + "Block", "Constraint", "ConstraintBound", "Encoding", + "EstimationResult", + "FactoryResult", + "InstructionFrontier", "ISA", + "ISA_ROOT", + "ISAQuery", + "ISARefNode", "ISARequirements", "ISATransform", + "LatticeSurgery", + "PSSPC", + "QSharpApplication", + "Trace", + "TraceQuery", "LOGICAL", "PHYSICAL", ] diff --git a/source/pip/qsharp/qre/_application.py b/source/pip/qsharp/qre/_application.py new file mode 100644 index 0000000000..43e81ea4eb --- /dev/null +++ b/source/pip/qsharp/qre/_application.py @@ -0,0 +1,139 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations + +import types +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import ( + Any, + Callable, + ClassVar, + Generic, + Protocol, + TypeVar, + Generator, + get_type_hints, + cast, +) + +from .._qsharp import logical_counts +from ..estimator import LogicalCounts +from ._enumeration import _enumerate_instances +from ._qre import Trace +from .instruction_ids import CCX, MEAS_Z, RZ, T + + +class DataclassProtocol(Protocol): + __dataclass_fields__: ClassVar[dict] + + +TraceParameters = TypeVar("TraceParameters", DataclassProtocol, types.NoneType) + + +class Application(ABC, Generic[TraceParameters]): + """ + An application defines a class of quantum computation problems along with a + method to generate traces for specific problem instances. + + We distinguish between application and trace parameters. The application + parameters define which particular instance of the application we want to + consider. The trace parameters define how to generate a trace. They + change the specific way in which we solve the problem, but not the problem + itself. + + For example, in quantum cryptography, the application parameters could + define the key size for an RSA prime product, while the trace parameters + define which algorithm to use to break the cryptography, as well as + parameters therein. + """ + + @abstractmethod + def get_trace(self, parameters: TraceParameters) -> Trace: + """Return the trace corresponding to this application.""" + + def context(self, **kwargs) -> _Context: + """Create a new enumeration context for this application.""" + return _Context(self, **kwargs) + + def enumerate_traces( + self, + **kwargs, + ) -> Generator[Trace, None, None]: + """Yields all traces of an application given its dataclass parameters.""" + + param_type = get_type_hints(self.__class__.get_trace).get("parameters") + if param_type is types.NoneType: + yield self.get_trace(None) # type: ignore + return + + if isinstance(param_type, TypeVar): + for c in param_type.__constraints__: + if c is not types.NoneType: + param_type = c + break + for parameters in _enumerate_instances(cast(type, param_type), **kwargs): + yield self.get_trace(parameters) + + +class _Context: + application: Application + kwargs: dict[str, Any] + + def __init__(self, application: Application, **kwargs): + self.application = application + self.kwargs = kwargs + + +@dataclass +class QSharpApplication(Application[None]): + def __init__(self, entry_expr: str | Callable | LogicalCounts): + self._entry_expr = entry_expr + + def get_trace(self, parameters: None = None) -> Trace: + if not isinstance(self._entry_expr, LogicalCounts): + self._counts = logical_counts(self._entry_expr) + else: + self._counts = self._entry_expr + return self._trace_from_logical_counts(self._counts) + + def _trace_from_logical_counts(self, counts: LogicalCounts) -> Trace: + ccx_count = counts.get("cczCount", 0) + counts.get("ccixCount", 0) + + trace = Trace(counts.get("numQubits", 0)) + + rotation_count = counts.get("rotationCount", 0) + rotation_depth = counts.get("rotationDepth", rotation_count) + + if rotation_count != 0: + if rotation_depth > 1: + rotations_per_layer = rotation_count // (rotation_depth - 1) + else: + rotations_per_layer = 0 + + last_layer = rotation_count - (rotations_per_layer * (rotation_depth - 1)) + + if rotations_per_layer != 0: + block = trace.add_block(repetitions=rotation_depth - 1) + for i in range(rotations_per_layer): + block.add_operation(RZ, [i]) + block = trace.add_block() + for i in range(last_layer): + block.add_operation(RZ, [i]) + + if t_count := counts.get("tCount", 0): + block = trace.add_block(repetitions=t_count) + block.add_operation(T, [0]) + + if ccx_count: + block = trace.add_block(repetitions=ccx_count) + block.add_operation(CCX, [0, 1, 2]) + + if meas_count := counts.get("measurementCount", 0): + block = trace.add_block(repetitions=meas_count) + block.add_operation(MEAS_Z, [0]) + + # TODO: handle memory qubits + + return trace diff --git a/source/pip/qsharp/qre/_architecture.py b/source/pip/qsharp/qre/_architecture.py index 0d95bb0a93..fe991aff42 100644 --- a/source/pip/qsharp/qre/_architecture.py +++ b/source/pip/qsharp/qre/_architecture.py @@ -1,7 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from __future__ import annotations + from abc import ABC, abstractmethod +from dataclasses import dataclass, field from ._qre import ISA @@ -10,3 +13,25 @@ class Architecture(ABC): @property @abstractmethod def provided_isa(self) -> ISA: ... + + def context(self) -> _Context: + """Create a new enumeration context for this architecture.""" + return _Context(self.provided_isa) + + +@dataclass(slots=True, frozen=True) +class _Context: + """ + Context passed through enumeration, holding shared state. + + Attributes: + root_isa: The root ISA for enumeration. + """ + + root_isa: ISA + _bindings: dict[str, ISA] = field(default_factory=dict, repr=False) + + def _with_binding(self, name: str, isa: ISA) -> _Context: + """Return a new context with an additional binding (internal use).""" + new_bindings = {**self._bindings, name: isa} + return _Context(self.root_isa, new_bindings) diff --git a/source/pip/qsharp/qre/_enumeration.py b/source/pip/qsharp/qre/_enumeration.py index 59eb1a9582..d41b279d0c 100644 --- a/source/pip/qsharp/qre/_enumeration.py +++ b/source/pip/qsharp/qre/_enumeration.py @@ -1,7 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Generator, Type, TypeVar, Literal, get_args, get_origin +from typing import ( + Generator, + Type, + TypeVar, + Literal, + get_args, + get_origin, + get_type_hints, +) from dataclasses import MISSING from itertools import product from enum import Enum @@ -57,8 +65,13 @@ class MyConfig: yield cls(**kwargs) return - for field in fields.values(): + # Resolve type hints to handle stringified types from __future__.annotations + type_hints = get_type_hints(cls) + + for field in fields.values(): # type: ignore name = field.name + # Get resolved type or fallback to field.type + current_type = type_hints.get(name, field.type) if name in kwargs: val = kwargs[name] @@ -83,16 +96,16 @@ class MyConfig: values.append(domain) continue - if field.type is bool: + if current_type is bool: values.append([True, False]) continue - if isinstance(field.type, type) and issubclass(field.type, Enum): - values.append(list(field.type)) + if isinstance(current_type, type) and issubclass(current_type, Enum): + values.append(list(current_type)) continue - if get_origin(field.type) is Literal: - values.append(list(get_args(field.type))) + if get_origin(current_type) is Literal: + values.append(list(get_args(current_type))) continue if field.default is not MISSING: diff --git a/source/pip/qsharp/qre/_estimation.py b/source/pip/qsharp/qre/_estimation.py new file mode 100644 index 0000000000..79b11b9eb7 --- /dev/null +++ b/source/pip/qsharp/qre/_estimation.py @@ -0,0 +1,52 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from ._application import Application +from ._architecture import Architecture +from ._qre import EstimationCollection, estimate_parallel +from ._trace import TraceQuery +from ._isa_enumeration import ISAQuery + + +def estimate( + application: Application, + architecture: Architecture, + trace_query: TraceQuery, + isa_query: ISAQuery, + *, + max_error: float = 1.0, +) -> EstimationCollection: + """ + Estimate the resource requirements for a given application instance and + architecture. + + The application instance might return multiple traces. Each of the traces + is transformed by the trace query, which applies several trace transforms in + sequence. Each transform may return multiple traces. Similarly, the + architecture's ISA is transformed by the ISA query, which applies several + ISA transforms in sequence, each of which may return multiple ISAs. The + estimation is performed for each combination of transformed trace and ISA. + The results are collected into an EstimationCollection and returned. + + The collection only contains the results that are optimal with respect to + the total number of qubits and the total runtime. + + Args: + application (Application): The quantum application to be estimated. + architecture (Architecture): The target quantum architecture. + trace_query (TraceQuery): The trace query to enumerate traces from the + application. + isa_query (ISAQuery): The ISA query to enumerate ISAs from the architecture. + + Returns: + EstimationCollection: A collection of estimation results. + """ + + app_ctx = application.context() + arch_ctx = architecture.context() + + return estimate_parallel( + list(trace_query.enumerate(app_ctx)), + list(isa_query.enumerate(arch_ctx)), + max_error, + ) diff --git a/source/pip/qsharp/qre/_instruction.py b/source/pip/qsharp/qre/_instruction.py index a74c97376b..9c4b24260e 100644 --- a/source/pip/qsharp/qre/_instruction.py +++ b/source/pip/qsharp/qre/_instruction.py @@ -6,7 +6,7 @@ from enum import IntEnum from ._enumeration import _enumerate_instances -from ._isa_enumeration import ISA_ROOT, BindingNode, ISAQuery, Node +from ._isa_enumeration import ISA_ROOT, _BindingNode, _ComponentQuery, ISAQuery from ._qre import ( ISA, Constraint, @@ -193,7 +193,7 @@ def enumerate_isas( yield from component.provided_isa(isa) @classmethod - def q(cls, *, source: Node | None = None, **kwargs) -> ISAQuery: + def q(cls, *, source: ISAQuery | None = None, **kwargs) -> ISAQuery: """ Creates an ISAQuery node for this transform. @@ -205,12 +205,12 @@ def q(cls, *, source: Node | None = None, **kwargs) -> ISAQuery: Returns: ISAQuery: An enumeration node representing this transform. """ - return ISAQuery( + return _ComponentQuery( cls, source=source if source is not None else ISA_ROOT, kwargs=kwargs ) @classmethod - def bind(cls, name: str, node: Node) -> BindingNode: + def bind(cls, name: str, node: ISAQuery) -> _BindingNode: """ Creates a BindingNode for this transform. diff --git a/source/pip/qsharp/qre/_isa_enumeration.py b/source/pip/qsharp/qre/_isa_enumeration.py index 54908aa9a6..0cfe5e5940 100644 --- a/source/pip/qsharp/qre/_isa_enumeration.py +++ b/source/pip/qsharp/qre/_isa_enumeration.py @@ -9,11 +9,11 @@ from dataclasses import dataclass, field from typing import Generator -from ._architecture import Architecture +from ._architecture import _Context from ._qre import ISA -class Node(ABC): +class ISAQuery(ABC): """ Abstract base class for all nodes in the ISA enumeration tree. @@ -24,7 +24,7 @@ class Node(ABC): """ @abstractmethod - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Yields all ISA instances represented by this enumeration node. @@ -37,7 +37,7 @@ def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: """ pass - def __add__(self, other: Node) -> SumNode: + def __add__(self, other: ISAQuery) -> _SumNode: """ Performs a union of two enumeration nodes. @@ -59,19 +59,19 @@ def __add__(self, other: Node) -> SumNode: for isa in SurfaceCode.q() + ColorCode.q(): ... """ - if isinstance(self, SumNode) and isinstance(other, SumNode): + if isinstance(self, _SumNode) and isinstance(other, _SumNode): sources = self.sources + other.sources - return SumNode(sources) - elif isinstance(self, SumNode): + return _SumNode(sources) + elif isinstance(self, _SumNode): sources = self.sources + [other] - return SumNode(sources) - elif isinstance(other, SumNode): + return _SumNode(sources) + elif isinstance(other, _SumNode): sources = [self] + other.sources - return SumNode(sources) + return _SumNode(sources) else: - return SumNode([self, other]) + return _SumNode([self, other]) - def __mul__(self, other: Node) -> ProductNode: + def __mul__(self, other: ISAQuery) -> _ProductNode: """ Performs the cross product of two enumeration nodes. @@ -97,19 +97,19 @@ def __mul__(self, other: Node) -> ProductNode: for isa in SurfaceCode.q() * Factory.q(): ... """ - if isinstance(self, ProductNode) and isinstance(other, ProductNode): + if isinstance(self, _ProductNode) and isinstance(other, _ProductNode): sources = self.sources + other.sources - return ProductNode(sources) - elif isinstance(self, ProductNode): + return _ProductNode(sources) + elif isinstance(self, _ProductNode): sources = self.sources + [other] - return ProductNode(sources) - elif isinstance(other, ProductNode): + return _ProductNode(sources) + elif isinstance(other, _ProductNode): sources = [self] + other.sources - return ProductNode(sources) + return _ProductNode(sources) else: - return ProductNode([self, other]) + return _ProductNode([self, other]) - def bind(self, name: str, node: Node) -> "BindingNode": + def bind(self, name: str, node: ISAQuery) -> "_BindingNode": """Create a BindingNode with this node as the component. Args: @@ -124,40 +124,17 @@ def bind(self, name: str, node: Node) -> "BindingNode": .. code-block:: python ExampleErrorCorrection.q().bind("c", ISARefNode("c") * ISARefNode("c")) """ - return BindingNode(name=name, component=self, node=node) + return _BindingNode(name=name, component=self, node=node) @dataclass -class Context: - """ - Context passed through enumeration, holding shared state. - - Attributes: - architecture: The base architecture for enumeration. - """ - - architecture: Architecture - _bindings: dict[str, ISA] = field(default_factory=dict, repr=False) - - @property - def root_isa(self) -> ISA: - """The architecture's provided ISA.""" - return self.architecture.provided_isa - - def _with_binding(self, name: str, isa: ISA) -> "Context": - """Return a new context with an additional binding (internal use).""" - new_bindings = {**self._bindings, name: isa} - return Context(self.architecture, new_bindings) - - -@dataclass -class RootNode(Node): +class RootNode(ISAQuery): """ Represents the architecture's base ISA. Reads from the context instead of holding a reference. """ - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Yields the architecture ISA from the context. @@ -175,7 +152,7 @@ def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: @dataclass -class ISAQuery(Node): +class _ComponentQuery(ISAQuery): """ Query node that enumerates ISAs based on a component type and source. @@ -191,10 +168,10 @@ class ISAQuery(Node): """ component: type - source: Node = field(default_factory=lambda: ISA_ROOT) + source: ISAQuery = field(default_factory=lambda: ISA_ROOT) kwargs: dict = field(default_factory=dict) - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Yields ISAs generated by the component from source ISAs. @@ -209,7 +186,7 @@ def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: @dataclass -class ProductNode(Node): +class _ProductNode(ISAQuery): """ Node representing the Cartesian product of multiple source nodes. @@ -217,9 +194,9 @@ class ProductNode(Node): sources: A list of source nodes to combine. """ - sources: list[Node] + sources: list[ISAQuery] - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Yields ISAs formed by combining ISAs from all source nodes. @@ -237,7 +214,7 @@ def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: @dataclass -class SumNode(Node): +class _SumNode(ISAQuery): """ Node representing the union of multiple source nodes. @@ -245,9 +222,9 @@ class SumNode(Node): sources: A list of source nodes to enumerate sequentially. """ - sources: list[Node] + sources: list[ISAQuery] - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Yields ISAs from each source node in sequence. @@ -262,7 +239,7 @@ def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: @dataclass -class ISARefNode(Node): +class ISARefNode(ISAQuery): """ A reference to a bound ISA in the enumeration context. @@ -274,7 +251,7 @@ class ISARefNode(Node): name: str - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Yields the bound ISA from the context. @@ -293,7 +270,7 @@ def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: @dataclass -class BindingNode(Node): +class _BindingNode(ISAQuery): """ Enumeration node that binds a component to a name. @@ -306,7 +283,7 @@ class BindingNode(Node): Args: name: The name to bind the component to. - component: An EnumerationNode (e.g., ISAQuery) that produces the bound ISAs. + component: An EnumerationNode (e.g., _ComponentQuery) that produces the bound ISAs. node: The child enumeration node that may contain ISARefNodes. Example: @@ -334,10 +311,10 @@ class BindingNode(Node): """ name: str - component: Node - node: Node + component: ISAQuery + node: ISAQuery - def enumerate(self, ctx: Context) -> Generator[ISA, None, None]: + def enumerate(self, ctx: _Context) -> Generator[ISA, None, None]: """ Enumerates child nodes with the bound component in context. diff --git a/source/pip/qsharp/qre/_qre.py b/source/pip/qsharp/qre/_qre.py index c01b87587b..3fdd913414 100644 --- a/source/pip/qsharp/qre/_qre.py +++ b/source/pip/qsharp/qre/_qre.py @@ -2,16 +2,27 @@ # Licensed under the MIT License. # flake8: noqa E402 +# pyright: reportAttributeAccessIssue=false from .._native import ( - ISA, + block_linear_function, + Block, + constant_function, Constraint, ConstraintBound, - Instruction, - ISARequirements, + estimate_parallel, + EstimationCollection, + EstimationResult, + FactoryResult, FloatFunction, + Instruction, + InstructionFrontier, IntFunction, - block_linear_function, - constant_function, + ISA, + ISARequirements, + Property, linear_function, + LatticeSurgery, + PSSPC, + Trace, ) diff --git a/source/pip/qsharp/qre/_qre.pyi b/source/pip/qsharp/qre/_qre.pyi index 01d999b49e..85be2b136e 100644 --- a/source/pip/qsharp/qre/_qre.pyi +++ b/source/pip/qsharp/qre/_qre.pyi @@ -2,7 +2,7 @@ # Licensed under the MIT License. from __future__ import annotations -from typing import Iterator, Optional, overload +from typing import Any, Iterator, Optional, overload class ISA: @overload @@ -44,6 +44,22 @@ class ISA: """ ... + def get( + self, id: int, default: Optional[Instruction] = None + ) -> Optional[Instruction]: + """ + Gets an instruction by its ID, or returns a default value if not found. + + Args: + id (int): The instruction ID. + default (Optional[Instruction]): The default value to return if the + instruction is not found. + + Returns: + Optional[Instruction]: The instruction, or the default value if not found. + """ + ... + def __len__(self) -> int: """ Returns the number of instructions in the ISA. @@ -422,3 +438,525 @@ def block_linear_function( IntFunction | FloatFunction: The block linear function. """ ... + +class Property: + def __new__(cls, value: Any) -> Property: + """ + Creates a property from a value. + + Args: + value (Any): The value. + """ + ... + + def as_bool(self) -> Optional[bool]: + """ + Returns the value as a boolean. + + Returns: + Optional[bool]: The value as a boolean, or None if it is not a boolean. + """ + ... + + def as_int(self) -> Optional[int]: + """ + Returns the value as an integer. + + Returns: + Optional[int]: The value as an integer, or None if it is not an integer. + """ + ... + + def as_float(self) -> Optional[float]: + """ + Returns the value as a float. + + Returns: + Optional[float]: The value as a float, or None if it is not a float. + """ + ... + + def as_str(self) -> Optional[str]: + """ + Returns the value as a string. + + Returns: + Optional[str]: The value as a string, or None if it is not a string. + """ + ... + + def is_bool(self) -> bool: + """ + Checks if the value is a boolean. + + Returns: + bool: True if the value is a boolean, False otherwise. + """ + ... + + def is_int(self) -> bool: + """ + Checks if the value is an integer. + + Returns: + bool: True if the value is an integer, False otherwise. + """ + ... + + def is_float(self) -> bool: + """ + Checks if the value is a float. + + Returns: + bool: True if the value is a float, False otherwise. + """ + ... + + def is_str(self) -> bool: + """ + Checks if the value is a string. + + Returns: + bool: True if the value is a string, False otherwise. + """ + ... + + def __str__(self) -> str: + """ + Returns a string representation of the property. + + Returns: + str: A string representation of the property. + """ + ... + +class EstimationResult: + """ + Represents the result of a resource estimation. + """ + + @property + def qubits(self) -> int: + """ + The number of logical qubits. + + Returns: + int: The number of logical qubits. + """ + ... + + @property + def runtime(self) -> int: + """ + The runtime in nanoseconds. + + Returns: + int: The runtime in nanoseconds. + """ + ... + + @property + def error(self) -> float: + """ + The error probability of the computation. + + Returns: + float: The error probability of the computation. + """ + ... + + @property + def factories(self) -> dict[int, FactoryResult]: + """ + The factory results. + + Returns: + dict[int, FactoryResult]: A dictionary mapping factory IDs to their results. + """ + ... + + def __str__(self) -> str: + """ + Returns a string representation of the estimation result. + + Returns: + str: A string representation of the estimation result. + """ + ... + +class EstimationCollection: + """ + Represents a collection of estimation results. Results are stored as a 2D + Pareto frontier with physical qubits and runtime as objectives. + """ + + def __new__(cls) -> EstimationCollection: + """ + Creates a new estimation collection. + + Returns: + EstimationCollection: The estimation collection. + """ + ... + + def insert(self, result: EstimationResult) -> None: + """ + Inserts an estimation result into the collection. + + Args: + result (EstimationResult): The estimation result to insert. + """ + ... + + def __len__(self) -> int: + """ + Returns the number of estimation results in the collection. + + Returns: + int: The number of estimation results. + """ + ... + + def __iter__(self) -> Iterator[EstimationResult]: + """ + Returns an iterator over the estimation results. + + Returns: + Iterator[EstimationResult]: The estimation result iterator. + """ + ... + +class FactoryResult: + """ + Represents the result of a factory used in resource estimation. + """ + + @property + def copies(self) -> int: + """ + The number of factory copies. + + Returns: + int: The number of factory copies. + """ + ... + + @property + def runs(self) -> int: + """ + The number of factory runs. + + Returns: + int: The number of factory runs. + """ + ... + + @property + def error_rate(self) -> float: + """ + The error rate of the factory. + + Returns: + float: The error rate of the factory. + """ + ... + + @property + def states(self) -> int: + """ + The number of states produced by the factory. + + Returns: + int: The number of states produced by the factory. + """ + ... + +class Trace: + """ + Represents a quantum program optimized for resource estimation. + + A trace originates from a quantum application and can be modified via trace + transformations. It consists of blocks of operations. + """ + + def __new__(cls, compute_qubits: int) -> Trace: + """ + Creates a new trace. + + Returns: + Trace: The trace. + """ + ... + + def clone_empty(self, compute_qubits: Optional[int] = None) -> Trace: + """ + Creates a new trace with the same metadata but empty block. + + Args: + compute_qubits (Optional[int]): The number of compute qubits. If None, + the number of compute qubits of the original trace is used. + + Returns: + Trace: The new trace. + """ + ... + + @property + def compute_qubits(self) -> int: + """ + The number of compute qubits. + + Returns: + int: The number of compute qubits. + """ + ... + + @property + def base_error(self) -> float: + """ + The base error of the trace. + + Returns: + float: The base error of the trace. + """ + ... + + def increment_base_error(self, amount: float) -> None: + """ + Increments the base error. + + Args: + amount (float): The amount to increment. + """ + ... + + def increment_resource_state(self, resource_id: int, amount: int) -> None: + """ + Increments a resource state count. + + Args: + resource_id (int): The resource state ID. + amount (int): The amount to increment. + """ + ... + + def set_property(self, key: str, value: Property) -> None: + """ + Sets a property. + + Args: + key (str): The property key. + value (Property): The property value. + """ + ... + + def get_property(self, key: str) -> Optional[Property]: + """ + Gets a property. + + Args: + key (str): The property key. + + Returns: + Optional[Property]: The property value, or None if not found. + """ + ... + + @property + def depth(self) -> int: + """ + The trace depth. + + Returns: + int: The trace depth. + """ + ... + + def estimate( + self, isa: ISA, max_error: Optional[float] = None + ) -> Optional[EstimationResult]: + """ + Estimates resources for the trace given a logical ISA. + + Args: + isa (ISA): The logical ISA. + max_error (Optional[float]): The maximum allowed error. If None, + Pareto points are computed. + + Returns: + Optional[EstimationResult]: The estimation result if max_error is + provided, otherwise valid Pareto points. + """ + ... # The implementation in Rust returns Option, so it fits + + @property + def resource_states(self) -> dict[int, int]: + """ + The resource states used in the trace. + + Returns: + dict[int, int]: A dictionary mapping instruction IDs to their counts. + """ + ... + + def add_operation( + self, id: int, qubits: list[int], params: list[float] = [] + ) -> None: + """ + Adds an operation to the trace. + + Args: + id (int): The operation ID. + qubits (list[int]): The qubits involved in the operation. + params (list[float]): The operation parameters. + """ + ... + + def add_block(self, repetitions: int = 1) -> Block: + """ + Adds a block to the trace. + + Args: + repetitions (int): The number of times the block is repeated. + + Returns: + Block: The block. + """ + ... + + def __str__(self) -> str: + """ + Returns a string representation of the trace. + + Returns: + str: A string representation of the trace. + """ + ... + +class Block: + """ + Represents a block of operations in a trace. + + An operation in a block can either refer to an instruction applied to some + qubits or can be another block to create a hierarchical structure. Blocks + can be repeated. + """ + + def add_operation( + self, id: int, qubits: list[int], params: list[float] = [] + ) -> None: + """ + Adds an operation to the block. + + Args: + id (int): The operation ID. + qubits (list[int]): The qubits involved in the operation. + params (list[float]): The operation parameters. + """ + ... + + def add_block(self, repetitions: int = 1) -> Block: + """ + Adds a nested block to the block. + + Args: + repetitions (int): The number of times the block is repeated. + + Returns: + Block: The block. + """ + ... + + def __str__(self) -> str: + """ + Returns a string representation of the block. + + Returns: + str: A string representation of the block. + """ + ... + +class PSSPC: + def __new__(cls, num_ts_per_rotation: int, ccx_magic_states: bool) -> PSSPC: ... + def transform(self, trace: Trace) -> Optional[Trace]: ... + +class LatticeSurgery: + def __new__(cls, slow_down_factor: float) -> LatticeSurgery: ... + def transform(self, trace: Trace) -> Optional[Trace]: ... + +class InstructionFrontier: + """ + Represents a Pareto frontier of instructions with space, time, and error + rates as objectives. + """ + + def __new__(cls) -> InstructionFrontier: + """ + Creates a new instruction frontier. + """ + ... + + def insert(self, point: Instruction): + """ + Inserts an instruction to the frontier. + + Args: + point (Instruction): The instruction to insert. + """ + ... + + def __len__(self) -> int: + """ + Returns the number of instructions in the frontier. + + Returns: + int: The number of instructions. + """ + ... + + def __iter__(self) -> Iterator[Instruction]: + """ + Returns an iterator over the instructions in the frontier. + + Returns: + Iterator[Instruction]: The iterator. + """ + ... + + @staticmethod + def load(filename: str) -> InstructionFrontier: + """ + Loads an instruction frontier from a file. + + Args: + filename (str): The file name. + + Returns: + InstructionFrontier: The loaded instruction frontier. + """ + ... + + def dump(self, filename: str) -> None: + """ + Dumps the instruction frontier to a file. + + Args: + filename (str): The file name. + """ + ... + +def estimate_parallel( + traces: list[Trace], isas: list[ISA], max_error: float = 1.0 +) -> EstimationCollection: + """ + Estimates resources for multiple traces and ISAs in parallel. + + Args: + traces (list[Trace]): The list of traces. + isas (list[ISA]): The list of ISAs. + max_error (float): The maximum allowed error. The default is 1.0. + + Returns: + EstimationCollection: The estimation collection. + """ + ... diff --git a/source/pip/qsharp/qre/_trace.py b/source/pip/qsharp/qre/_trace.py new file mode 100644 index 0000000000..ab1d49f6ce --- /dev/null +++ b/source/pip/qsharp/qre/_trace.py @@ -0,0 +1,89 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations +from abc import ABC, abstractmethod +from dataclasses import dataclass, KW_ONLY, field +from itertools import product +from typing import Any, Optional, Generator, Type +from ._application import _Context +from ._enumeration import _enumerate_instances +from ._qre import PSSPC as _PSSPC, LatticeSurgery as _LatticeSurgery, Trace + + +class TraceTransform(ABC): + @abstractmethod + def transform(self, trace: Trace) -> Optional[Trace]: ... + + @classmethod + def q(cls, **kwargs) -> TraceQuery: + return TraceQuery(cls, **kwargs) + + +@dataclass +class PSSPC(TraceTransform): + _: KW_ONLY + num_ts_per_rotation: int = field( + default=10, metadata={"domain": list(range(1, 21))} + ) + ccx_magic_states: bool = field(default=False) + + def __post_init__(self): + self._psspc = _PSSPC(self.num_ts_per_rotation, self.ccx_magic_states) + + def transform(self, trace: Trace) -> Optional[Trace]: + return self._psspc.transform(trace) + + +@dataclass +class LatticeSurgery(TraceTransform): + _: KW_ONLY + slow_down_factor: float = field(default=1.0, metadata={"domain": [1.0]}) + + def __post_init__(self): + self._lattice_surgery = _LatticeSurgery(self.slow_down_factor) + + def transform(self, trace: Trace) -> Optional[Trace]: + return self._lattice_surgery.transform(trace) + + +class _Node(ABC): + @abstractmethod + def enumerate(self, ctx: _Context) -> Generator[Trace, None, None]: ... + + +class RootNode(_Node): + # NOTE: this might be redundant with TransformationNode with an empty sequence + def enumerate(self, ctx: _Context) -> Generator[Trace, None, None]: + yield from ctx.application.enumerate_traces(**ctx.kwargs) + + +class TraceQuery(_Node): + sequence: list[tuple[Type, dict[str, Any]]] + + def __init__(self, t: Type, **kwargs): + self.sequence = [(t, kwargs)] + + def enumerate(self, ctx: _Context) -> Generator[Trace, None, None]: + for trace in ctx.application.enumerate_traces(**ctx.kwargs): + if not self.sequence: + yield trace + continue + + transformer_instances = [] + + for t, transformer_kwargs in self.sequence: + instances = _enumerate_instances(t, **transformer_kwargs) + transformer_instances.append(instances) + + # TODO: make parallel + for sequence in product(*transformer_instances): + transformed = trace + for transformer in sequence: + transformed = transformer.transform(transformed) + yield transformed + + def __mul__(self, other: TraceQuery) -> TraceQuery: + new_query = TraceQuery.__new__(TraceQuery) + new_query.sequence = self.sequence + other.sequence + return new_query diff --git a/source/pip/qsharp/qre/instruction_ids.py b/source/pip/qsharp/qre/instruction_ids.py index f89bcc6c5b..cec4a9c070 100644 --- a/source/pip/qsharp/qre/instruction_ids.py +++ b/source/pip/qsharp/qre/instruction_ids.py @@ -1,91 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# pyright: reportAttributeAccessIssue=false -################### -# Instruction IDs # -################### -# Paulis -PAULI_I = 0x0 -PAULI_X = 0x1 -PAULI_Y = 0x2 -PAULI_Z = 0x3 +from .._native import instruction_ids -# Clifford gates -H = H_XZ = 0x10 -H_XY = 0x11 -H_YZ = 0x12 -SQRT_X = 0x13 -SQRT_X_DAG = 0x14 -SQRT_Y = 0x15 -SQRT_Y_DAG = 0x16 -S = SQRT_Z = 0x17 -S_DAG = SQRT_Z_DAG = 0x18 -CNOT = CX = 0x19 -CY = 0x1A -CZ = 0x1B -SWAP = 0x1C - -# State preparation -PREP_X = 0x30 -PREP_Y = 0x31 -PREP_Z = 0x32 - -# Generic Cliffords -ONE_QUBIT_CLIFFORD = 0x50 -TWO_QUBIT_CLIFFORD = 0x51 -N_QUBIT_CLIFFORD = 0x52 - -# Measurements -MEAS_X = 0x100 -MEAS_Y = 0x101 -MEAS_Z = 0x102 -MEAS_RESET_X = 0x103 -MEAS_RESET_Y = 0x104 -MEAS_RESET_Z = 0x105 -MEAS_XX = 0x106 -MEAS_YY = 0x107 -MEAS_ZZ = 0x108 -MEAS_XZ = 0x109 -MEAS_XY = 0x10A -MEAS_YZ = 0x10B - -# Non-Clifford gates -SQRT_SQRT_X = 0x400 -SQRT_SQRT_X_DAG = 0x401 -SQRT_SQRT_Y = 0x402 -SQRT_SQRT_Y_DAG = 0x403 -SQRT_SQRT_Z = T = 0x404 -SQRT_SQRT_Z_DAG = T_DAG = 0x405 -CCX = 0x406 -CCY = 0x407 -CCZ = 0x408 -CSWAP = 0x409 -AND = 0x40A -AND_DAG = 0x40B -RX = 0x40C -RY = 0x40D -RZ = 0x40E -CRX = 0x40F -CRY = 0x410 -CRZ = 0x411 -RXX = 0x412 -RYY = 0x413 -RZZ = 0x414 - -# Multi-qubit Pauli measurement -MULTI_PAULI_MEAS = 0x1000 - -# Some generic logical instructions -LATTICE_SURGERY = 0x1100 - -# Memory/compute operations (used in compute parts of memory-compute layouts) -READ_FROM_MEMORY = 0x1200 -WRITE_TO_MEMORY = 0x1201 - -# Some special hardware physical instructions -CYCLIC_SHIFT = 0x1300 - -# Generic operation (for unified RE) -GENERIC = 0xFFFF +for name in instruction_ids.__all__: + globals()[name] = getattr(instruction_ids, name) diff --git a/source/pip/qsharp/qre/instruction_ids.pyi b/source/pip/qsharp/qre/instruction_ids.pyi new file mode 100644 index 0000000000..72934487f8 --- /dev/null +++ b/source/pip/qsharp/qre/instruction_ids.pyi @@ -0,0 +1,92 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# Paulis +PAULI_I: int +PAULI_X: int +PAULI_Y: int +PAULI_Z: int + +# Clifford gates +H: int +H_XZ: int +H_XY: int +H_YZ: int +SQRT_X: int +SQRT_X_DAG: int +SQRT_Y: int +SQRT_Y_DAG: int +S: int +SQRT_Z: int +S_DAG: int +SQRT_Z_DAG: int +CNOT: int +CX: int +CY: int +CZ: int +SWAP: int + +# State preparation +PREP_X: int +PREP_Y: int +PREP_Z: int + +# Generic Cliffords +ONE_QUBIT_CLIFFORD: int +TWO_QUBIT_CLIFFORD: int +N_QUBIT_CLIFFORD: int + +# Measurements +MEAS_X: int +MEAS_Y: int +MEAS_Z: int +MEAS_RESET_X: int +MEAS_RESET_Y: int +MEAS_RESET_Z: int +MEAS_XX: int +MEAS_YY: int +MEAS_ZZ: int +MEAS_XZ: int +MEAS_XY: int +MEAS_YZ: int + +# Non-Clifford gates +SQRT_SQRT_X: int +SQRT_SQRT_X_DAG: int +SQRT_SQRT_Y: int +SQRT_SQRT_Y_DAG: int +SQRT_SQRT_Z: int +T: int +SQRT_SQRT_Z_DAG: int +T_DAG: int +CCX: int +CCY: int +CCZ: int +CSWAP: int +AND: int +AND_DAG: int +RX: int +RY: int +RZ: int +CRX: int +CRY: int +CRZ: int +RXX: int +RYY: int +RZZ: int + +# Multi-qubit Pauli measurement +MULTI_PAULI_MEAS: int + +# Some generic logical instructions +LATTICE_SURGERY: int + +# Memory/compute operations (used in compute parts of memory-compute layouts) +READ_FROM_MEMORY: int +WRITE_TO_MEMORY: int + +# Some special hardware physical instructions +CYCLIC_SHIFT: int + +# Generic operation (for unified RE) +GENERIC: int diff --git a/source/pip/qsharp/qre/models/__init__.py b/source/pip/qsharp/qre/models/__init__.py new file mode 100644 index 0000000000..10a82c977e --- /dev/null +++ b/source/pip/qsharp/qre/models/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .qec import SurfaceCode +from .qubits import AQREGateBased + +__all__ = ["SurfaceCode", "AQREGateBased"] diff --git a/source/pip/qsharp/qre/models/qec/__init__.py b/source/pip/qsharp/qre/models/qec/__init__.py new file mode 100644 index 0000000000..c813df0dc4 --- /dev/null +++ b/source/pip/qsharp/qre/models/qec/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from ._surface_code import SurfaceCode + +__all__ = ["SurfaceCode"] diff --git a/source/pip/qsharp/qre/models/qec/_surface_code.py b/source/pip/qsharp/qre/models/qec/_surface_code.py new file mode 100644 index 0000000000..52bf94439f --- /dev/null +++ b/source/pip/qsharp/qre/models/qec/_surface_code.py @@ -0,0 +1,93 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from __future__ import annotations +from dataclasses import KW_ONLY, dataclass, field +from typing import Generator +from ..._instruction import ( + ISA, + ISARequirements, + ISATransform, + instruction, + constraint, + ConstraintBound, + LOGICAL, +) +from ..._qre import linear_function +from ...instruction_ids import CNOT, GENERIC, H, LATTICE_SURGERY, MEAS_Z + + +@dataclass +class SurfaceCode(ISATransform): + """ + Attributes: + crossing_prefactor: float + The prefactor for logical error rate due to error correction + crossings. (Default is 0.03, see Eq. (11) in arXiv:1208.0928) + error_correction_threshold: float + The error correction threshold for the surface code. Default is + 0.01 (1%), see arXiv:1009.3686. + + Hyper parameters: + distance: int + The code distance of the surface code. + + References: + - [arXiv:1208.0928](https://arxiv.org/abs/1208.0928) + - [arXiv:1009.3686](https://arxiv.org/abs/1009.3686) + """ + + crossing_prefactor: float = 0.03 + error_correction_threshold: float = 0.01 + _: KW_ONLY + distance: int = field(default=3, metadata={"domain": range(3, 26, 2)}) + + @staticmethod + def required_isa() -> ISARequirements: + return ISARequirements( + constraint(H, error_rate=ConstraintBound.lt(0.01)), + constraint(CNOT, arity=2, error_rate=ConstraintBound.lt(0.01)), + constraint(MEAS_Z, error_rate=ConstraintBound.lt(0.01)), + ) + + def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: + cnot_time = impl_isa[CNOT].expect_time() + h_time = impl_isa[H].expect_time() + meas_time = impl_isa[MEAS_Z].expect_time() + + physical_error_rate = max( + impl_isa[CNOT].expect_error_rate(), + impl_isa[H].expect_error_rate(), + impl_isa[MEAS_Z].expect_error_rate(), + ) + + space_formula = linear_function(2 * self.distance**2) + + time_value = (h_time + meas_time + cnot_time * 4) * self.distance + + error_formula = linear_function( + self.crossing_prefactor + * ( + (physical_error_rate / self.error_correction_threshold) + ** ((self.distance + 1) // 2) + ) + ) + + yield ISA( + instruction( + GENERIC, + encoding=LOGICAL, + arity=None, + space=space_formula, + time=time_value, + error_rate=error_formula, + ), + instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + space=space_formula, + time=time_value, + error_rate=error_formula, + ), + ) diff --git a/source/pip/qsharp/qre/models/qubits/__init__.py b/source/pip/qsharp/qre/models/qubits/__init__.py new file mode 100644 index 0000000000..f9907adbc3 --- /dev/null +++ b/source/pip/qsharp/qre/models/qubits/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from ._aqre import AQREGateBased + +__all__ = ["AQREGateBased"] diff --git a/source/pip/qsharp/qre/models/qubits/_aqre.py b/source/pip/qsharp/qre/models/qubits/_aqre.py new file mode 100644 index 0000000000..b6add8ae2d --- /dev/null +++ b/source/pip/qsharp/qre/models/qubits/_aqre.py @@ -0,0 +1,65 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from dataclasses import KW_ONLY, dataclass, field + +from ..._architecture import Architecture +from ...instruction_ids import CNOT, CZ, MEAS_Z, PAULI_I, H, T +from ..._instruction import ISA, Encoding, instruction + + +@dataclass +class AQREGateBased(Architecture): + """ + References: + - [arXiv:2211.07629](https://arxiv.org/abs/2211.07629) + """ + + _: KW_ONLY + error_rate: float = field(default=1e-4) + + @property + def provided_isa(self) -> ISA: + return ISA( + instruction( + PAULI_I, + encoding=Encoding.PHYSICAL, + arity=1, + time=50, + error_rate=self.error_rate, + ), + instruction( + CNOT, + encoding=Encoding.PHYSICAL, + arity=2, + time=50, + error_rate=self.error_rate, + ), + instruction( + CZ, + encoding=Encoding.PHYSICAL, + arity=2, + time=50, + error_rate=self.error_rate, + ), + instruction( + H, + encoding=Encoding.PHYSICAL, + arity=1, + time=50, + error_rate=self.error_rate, + ), + instruction( + MEAS_Z, + encoding=Encoding.PHYSICAL, + arity=1, + time=100, + error_rate=self.error_rate, + ), + instruction( + T, + encoding=Encoding.PHYSICAL, + time=50, + error_rate=self.error_rate, + ), + ) diff --git a/source/pip/src/qre.rs b/source/pip/src/qre.rs index fd8e80a5cd..d9e870990c 100644 --- a/source/pip/src/qre.rs +++ b/source/pip/src/qre.rs @@ -1,22 +1,48 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -use pyo3::{IntoPyObjectExt, prelude::*, types::PyTuple}; +use std::ptr::NonNull; + +use pyo3::{ + IntoPyObjectExt, + exceptions::{PyException, PyKeyError, PyTypeError}, + prelude::*, + types::{PyDict, PyTuple}, +}; +use qre::TraceTransform; +use serde::{Deserialize, Serialize}; pub(crate) fn register_qre_submodule(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(constant_function, m)?)?; m.add_function(wrap_pyfunction!(linear_function, m)?)?; m.add_function(wrap_pyfunction!(block_linear_function, m)?)?; + m.add_function(wrap_pyfunction!(estimate_parallel, m)?)?; + + m.add("EstimationError", m.py().get_type::())?; + + add_instruction_ids(m)?; + Ok(()) } +pyo3::create_exception!(qsharp.qre, EstimationError, PyException); + #[allow(clippy::upper_case_acronyms)] #[pyclass] pub struct ISA(qre::ISA); @@ -63,12 +89,20 @@ impl ISA { pub fn __getitem__(&self, id: u64) -> PyResult { match self.0.get(&id) { Some(instr) => Ok(Instruction(instr.clone())), - None => Err(PyErr::new::(format!( + None => Err(PyKeyError::new_err(format!( "Instruction with id {id} not found" ))), } } + #[pyo3(signature = (id, default=None))] + pub fn get(&self, id: u64, default: Option<&Instruction>) -> Option { + match self.0.get(&id) { + Some(instr) => Some(Instruction(instr.clone())), + None => default.cloned(), + } + } + #[allow(clippy::needless_pass_by_value)] pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { let iter = ISAIterator { @@ -129,7 +163,10 @@ impl ISARequirements { } } +#[allow(clippy::unsafe_derive_deserialize)] #[pyclass] +#[derive(Clone, Serialize, Deserialize)] +#[serde(transparent)] pub struct Instruction(qre::Instruction); #[pymethods] @@ -227,6 +264,24 @@ impl Instruction { } } +impl qre::ParetoItem3D for Instruction { + type Objective1 = u64; + type Objective2 = u64; + type Objective3 = f64; + + fn objective1(&self) -> Self::Objective1 { + self.0.expect_space(None) + } + + fn objective2(&self) -> Self::Objective2 { + self.0.expect_time(None) + } + + fn objective3(&self) -> Self::Objective3 { + self.0.expect_error_rate(None) + } +} + #[pyclass] pub struct Constraint(qre::InstructionConstraint); @@ -252,9 +307,7 @@ fn convert_encoding(encoding: u64) -> PyResult { match encoding { 0 => Ok(qre::Encoding::Physical), 1 => Ok(qre::Encoding::Logical), - _ => Err(PyErr::new::( - "Invalid encoding value", - )), + _ => Err(EstimationError::new_err("Invalid encoding value")), } } @@ -289,6 +342,61 @@ impl ConstraintBound { } } +#[pyclass] +pub struct Property(qre::Property); + +#[pymethods] +impl Property { + #[new] + pub fn new(value: &Bound<'_, PyAny>) -> PyResult { + if value.is_instance_of::() { + Ok(Property(qre::Property::new_bool(value.extract()?))) + } else if let Ok(i) = value.extract::() { + Ok(Property(qre::Property::new_int(i))) + } else if let Ok(f) = value.extract::() { + Ok(Property(qre::Property::new_float(f))) + } else { + Ok(Property(qre::Property::new_str(value.to_string()))) + } + } + + fn as_bool(&self) -> Option { + self.0.as_bool() + } + + fn as_int(&self) -> Option { + self.0.as_int() + } + + fn as_float(&self) -> Option { + self.0.as_float() + } + + fn as_str(&self) -> Option { + self.0.as_str().map(String::from) + } + + fn is_bool(&self) -> bool { + self.0.is_bool() + } + + fn is_int(&self) -> bool { + self.0.is_int() + } + + fn is_float(&self) -> bool { + self.0.is_float() + } + + fn is_str(&self) -> bool { + self.0.is_str() + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } +} + #[pyclass] pub struct IntFunction(qre::VariableArityFunction); @@ -303,7 +411,7 @@ pub fn constant_function<'py>(value: &Bound<'py, PyAny>) -> PyResult( + Err(PyTypeError::new_err( "Value must be either an integer or a float", )) } @@ -316,7 +424,7 @@ pub fn linear_function<'py>(slope: &Bound<'py, PyAny>) -> PyResult() { FloatFunction(qre::VariableArityFunction::linear(s)).into_bound_py_any(slope.py()) } else { - Err(PyErr::new::( + Err(PyTypeError::new_err( "Slope must be either an integer or a float", )) } @@ -334,8 +442,435 @@ pub fn block_linear_function<'py>( FloatFunction(qre::VariableArityFunction::block_linear(block_size, s)) .into_bound_py_any(slope.py()) } else { - Err(PyErr::new::( + Err(PyTypeError::new_err( "Slope must be either an integer or a float", )) } } + +#[derive(Default)] +#[pyclass] +pub struct EstimationCollection(qre::EstimationCollection); + +#[pymethods] +impl EstimationCollection { + #[new] + pub fn new() -> Self { + Self::default() + } + + pub fn insert(&mut self, result: &EstimationResult) { + self.0.insert(result.0.clone()); + } + + pub fn __len__(&self) -> usize { + self.0.len() + } + + #[allow(clippy::needless_pass_by_value)] + pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { + let iter = EstimationCollectionIterator { + iter: slf.0.iter().cloned().collect::>().into_iter(), + }; + Py::new(slf.py(), iter) + } +} + +#[pyclass] +pub struct EstimationCollectionIterator { + iter: std::vec::IntoIter, +} + +#[pymethods] +impl EstimationCollectionIterator { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>) -> Option { + slf.iter.next().map(EstimationResult) + } +} + +#[pyclass] +pub struct EstimationResult(qre::EstimationResult); + +#[pymethods] +impl EstimationResult { + #[getter] + pub fn qubits(&self) -> u64 { + self.0.qubits() + } + + #[getter] + pub fn runtime(&self) -> u64 { + self.0.runtime() + } + + #[getter] + pub fn error(&self) -> f64 { + self.0.error() + } + + #[allow(clippy::needless_pass_by_value)] + #[getter] + pub fn factories(self_: PyRef<'_, Self>) -> PyResult> { + let dict = PyDict::new(self_.py()); + + for (id, factory) in self_.0.factories() { + dict.set_item(id, FactoryResult(factory.clone()))?; + } + + Ok(dict) + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } +} + +#[pyclass] +pub struct FactoryResult(qre::FactoryResult); + +#[pymethods] +impl FactoryResult { + #[getter] + pub fn copies(&self) -> u64 { + self.0.copies() + } + + #[getter] + pub fn runs(&self) -> u64 { + self.0.runs() + } + + #[getter] + pub fn states(&self) -> u64 { + self.0.states() + } + + #[getter] + pub fn error_rate(&self) -> f64 { + self.0.error_rate() + } +} + +#[pyclass] +pub struct Trace(qre::Trace); + +#[pymethods] +impl Trace { + #[new] + pub fn new(compute_qubits: u64) -> Self { + Trace(qre::Trace::new(compute_qubits)) + } + + #[pyo3(signature = (compute_qubits = None))] + pub fn clone_empty(&self, compute_qubits: Option) -> Self { + Trace(self.0.clone_empty(compute_qubits)) + } + + #[getter] + pub fn compute_qubits(&self) -> u64 { + self.0.compute_qubits() + } + + #[getter] + pub fn base_error(&self) -> f64 { + self.0.base_error() + } + + pub fn increment_base_error(&mut self, amount: f64) { + self.0.increment_base_error(amount); + } + + pub fn set_property(&mut self, key: String, value: &Property) { + self.0.set_property(key, value.0.clone()); + } + + pub fn get_property(&self, key: &str) -> Option { + self.0.get_property(key).map(|p| Property(p.clone())) + } + + #[allow(clippy::needless_pass_by_value)] + #[getter] + pub fn resource_states(self_: PyRef<'_, Self>) -> PyResult> { + let dict = PyDict::new(self_.py()); + if let Some(resource_states) = self_.0.get_resource_states() { + for (resource_id, count) in resource_states { + if *count != 0 { + dict.set_item(resource_id, *count)?; + } + } + } + Ok(dict) + } + + #[getter] + pub fn depth(&self) -> u64 { + self.0.depth() + } + + #[pyo3(signature = (isa, max_error = None))] + pub fn estimate(&self, isa: &ISA, max_error: Option) -> Option { + self.0 + .estimate(&isa.0, max_error) + .map(EstimationResult) + .ok() + } + + #[pyo3(signature = (id, qubits, params = vec![]))] + pub fn add_operation(&mut self, id: u64, qubits: Vec, params: Vec) { + self.0.add_operation(id, qubits, params); + } + + #[pyo3(signature = (repetitions = 1))] + pub fn add_block(mut slf: PyRefMut<'_, Self>, repetitions: u64) -> PyResult { + let block = slf.0.add_block(repetitions); + let ptr = NonNull::from(block); + Ok(Block { + ptr, + parent: slf.into(), + }) + } + + pub fn increment_resource_state(&mut self, resource_id: u64, amount: u64) { + self.0.increment_resource_state(resource_id, amount); + } + + fn __str__(&self) -> String { + format!("{}", self.0) + } +} + +#[pyclass(unsendable)] +pub struct Block { + ptr: NonNull, + #[allow(dead_code)] + parent: Py, +} + +#[pymethods] +impl Block { + #[pyo3(signature = (id, qubits, params = vec![]))] + pub fn add_operation(&mut self, id: u64, qubits: Vec, params: Vec) { + unsafe { self.ptr.as_mut() }.add_operation(id, qubits, params); + } + + #[pyo3(signature = (repetitions = 1))] + pub fn add_block(&mut self, py: Python<'_>, repetitions: u64) -> PyResult { + let block = unsafe { self.ptr.as_mut() }.add_block(repetitions); + let ptr = NonNull::from(block); + Ok(Block { + ptr, + parent: self.parent.clone_ref(py), + }) + } + + fn __str__(&self) -> String { + format!("{}", unsafe { self.ptr.as_ref() }) + } +} + +#[allow(clippy::upper_case_acronyms)] +#[pyclass] +pub struct PSSPC(qre::PSSPC); + +#[pymethods] +impl PSSPC { + #[new] + pub fn new(num_ts_per_rotation: u64, ccx_magic_states: bool) -> Self { + PSSPC(qre::PSSPC::new(num_ts_per_rotation, ccx_magic_states)) + } + + pub fn transform(&self, trace: &Trace) -> PyResult { + self.0 + .transform(&trace.0) + .map(Trace) + .map_err(|e| EstimationError::new_err(format!("{e}"))) + } +} + +#[derive(Default)] +#[pyclass] +pub struct LatticeSurgery(qre::LatticeSurgery); + +#[pymethods] +impl LatticeSurgery { + #[new] + pub fn new(slow_down_factor: f64) -> Self { + Self(qre::LatticeSurgery::new(slow_down_factor)) + } + + pub fn transform(&self, trace: &Trace) -> PyResult { + self.0 + .transform(&trace.0) + .map(Trace) + .map_err(|e| EstimationError::new_err(format!("{e}"))) + } +} + +#[pyclass] +pub struct InstructionFrontier(qre::ParetoFrontier3D); + +impl Default for InstructionFrontier { + fn default() -> Self { + InstructionFrontier(qre::ParetoFrontier3D::new()) + } +} + +#[pymethods] +impl InstructionFrontier { + #[new] + pub fn new() -> Self { + Self::default() + } + + pub fn insert(&mut self, point: &Instruction) { + self.0.insert(point.clone()); + } + + pub fn __len__(&self) -> usize { + self.0.len() + } + + #[allow(clippy::needless_pass_by_value)] + pub fn __iter__(slf: PyRef<'_, Self>) -> PyResult> { + let iter = InstructionFrontierIterator { + iter: slf.0.iter().cloned().collect::>().into_iter(), + }; + Py::new(slf.py(), iter) + } + + #[staticmethod] + pub fn load(filename: &str) -> PyResult { + let content = std::fs::read_to_string(filename)?; + let frontier = + serde_json::from_str(&content).map_err(|e| EstimationError::new_err(format!("{e}")))?; + Ok(InstructionFrontier(frontier)) + } + + pub fn dump(&self, filename: &str) -> PyResult<()> { + let content = + serde_json::to_string(&self.0).map_err(|e| EstimationError::new_err(format!("{e}")))?; + Ok(std::fs::write(filename, content)?) + } +} + +#[pyclass] +pub struct InstructionFrontierIterator { + iter: std::vec::IntoIter, +} + +#[pymethods] +impl InstructionFrontierIterator { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + fn __next__(mut slf: PyRefMut<'_, Self>) -> Option { + slf.iter.next() + } +} + +#[allow(clippy::needless_pass_by_value)] +#[pyfunction(signature = (traces, isas, max_error = 1.0))] +pub fn estimate_parallel( + traces: Vec>, + isas: Vec>, + max_error: f64, +) -> EstimationCollection { + let traces: Vec<_> = traces.iter().map(|t| &t.0).collect(); + let isas: Vec<_> = isas.iter().map(|i| &i.0).collect(); + + let collection = qre::estimate_parallel(&traces, &isas, Some(max_error)); + EstimationCollection(collection) +} + +fn add_instruction_ids(m: &Bound<'_, PyModule>) -> PyResult<()> { + #[allow(clippy::wildcard_imports)] + use qre::instruction_ids::*; + + let instruction_ids = PyModule::new(m.py(), "instruction_ids")?; + + macro_rules! add_ids { + ($($name:ident),* $(,)?) => { + $(instruction_ids.add(stringify!($name), $name)?;)* + }; + } + + add_ids!( + PAULI_I, + PAULI_X, + PAULI_Y, + PAULI_Z, + H, + H_XZ, + H_XY, + H_YZ, + SQRT_X, + SQRT_X_DAG, + SQRT_Y, + SQRT_Y_DAG, + S, + SQRT_Z, + S_DAG, + SQRT_Z_DAG, + CNOT, + CX, + CY, + CZ, + SWAP, + PREP_X, + PREP_Y, + PREP_Z, + ONE_QUBIT_CLIFFORD, + TWO_QUBIT_CLIFFORD, + N_QUBIT_CLIFFORD, + MEAS_X, + MEAS_Y, + MEAS_Z, + MEAS_RESET_X, + MEAS_RESET_Y, + MEAS_RESET_Z, + MEAS_XX, + MEAS_YY, + MEAS_ZZ, + MEAS_XZ, + MEAS_XY, + MEAS_YZ, + SQRT_SQRT_X, + SQRT_SQRT_X_DAG, + SQRT_SQRT_Y, + SQRT_SQRT_Y_DAG, + SQRT_SQRT_Z, + T, + SQRT_SQRT_Z_DAG, + T_DAG, + CCX, + CCY, + CCZ, + CSWAP, + AND, + AND_DAG, + RX, + RY, + RZ, + CRX, + CRY, + CRZ, + RXX, + RYY, + RZZ, + MULTI_PAULI_MEAS, + LATTICE_SURGERY, + READ_FROM_MEMORY, + WRITE_TO_MEMORY, + CYCLIC_SHIFT, + GENERIC + ); + + m.add_submodule(&instruction_ids)?; + + Ok(()) +} diff --git a/source/pip/tests/test_qre.py b/source/pip/tests/test_qre.py index 90430f5167..98e1c9de59 100644 --- a/source/pip/tests/test_qre.py +++ b/source/pip/tests/test_qre.py @@ -5,33 +5,30 @@ from enum import Enum from typing import Generator +import qsharp from qsharp.qre import ( ISA, LOGICAL, - Architecture, - ConstraintBound, + PSSPC, + EstimationResult, ISARequirements, ISATransform, + LatticeSurgery, + QSharpApplication, + Trace, constraint, + estimate, instruction, linear_function, ) -from qsharp.qre._enumeration import _enumerate_instances +from qsharp.qre.models import SurfaceCode, AQREGateBased from qsharp.qre._isa_enumeration import ( - BindingNode, - Context, - ISAQuery, ISARefNode, - ProductNode, - SumNode, ) from qsharp.qre.instruction_ids import ( - CNOT, + CCX, GENERIC, LATTICE_SURGERY, - MEAS_Z, - TWO_QUBIT_CLIFFORD, - H, T, ) @@ -39,78 +36,6 @@ # pull requests and then moved out of the tests. -class ExampleArchitecture(Architecture): - @property - def provided_isa(self) -> ISA: - return ISA( - instruction(H, time=50, error_rate=1e-3), - instruction(CNOT, arity=2, time=50, error_rate=1e-3), - instruction(MEAS_Z, time=100, error_rate=1e-3), - instruction(TWO_QUBIT_CLIFFORD, arity=2, time=50, error_rate=1e-3), - instruction(GENERIC, time=50, error_rate=1e-4), - instruction(T, time=50, error_rate=1e-4), - ) - - -@dataclass -class SurfaceCode(ISATransform): - _: KW_ONLY - distance: int = field(default=3, metadata={"domain": range(3, 26, 2)}) - - @staticmethod - def required_isa() -> ISARequirements: - return ISARequirements( - constraint(H, error_rate=ConstraintBound.lt(0.01)), - constraint(CNOT, arity=2, error_rate=ConstraintBound.lt(0.01)), - constraint(MEAS_Z, error_rate=ConstraintBound.lt(0.01)), - ) - - def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: - crossing_prefactor: float = 0.03 - error_correction_threshold: float = 0.01 - - cnot_time = impl_isa[CNOT].expect_time() - h_time = impl_isa[H].expect_time() - meas_time = impl_isa[MEAS_Z].expect_time() - - physical_error_rate = max( - impl_isa[CNOT].expect_error_rate(), - impl_isa[H].expect_error_rate(), - impl_isa[MEAS_Z].expect_error_rate(), - ) - - space_formula = linear_function(2 * self.distance**2) - - time_value = (h_time + meas_time + cnot_time * 4) * self.distance - - error_formula = linear_function( - crossing_prefactor - * ( - (physical_error_rate / error_correction_threshold) - ** ((self.distance + 1) // 2) - ) - ) - - yield ISA( - instruction( - GENERIC, - encoding=LOGICAL, - arity=None, - space=space_formula, - time=time_value, - error_rate=error_formula, - ), - instruction( - LATTICE_SURGERY, - encoding=LOGICAL, - arity=None, - space=space_formula, - time=time_value, - error_rate=error_formula, - ), - ) - - @dataclass class ExampleFactory(ISATransform): _: KW_ONLY @@ -147,7 +72,7 @@ def provided_isa(self, impl_isa: ISA) -> Generator[ISA, None, None]: def test_isa_from_architecture(): - arch = ExampleArchitecture() + arch = AQREGateBased() code = SurfaceCode() # Verify that the architecture satisfies the code requirements @@ -162,6 +87,8 @@ def test_isa_from_architecture(): def test_enumerate_instances(): + from qsharp.qre._enumeration import _enumerate_instances + instances = list(_enumerate_instances(SurfaceCode)) # There are 12 instances with distances from 3 to 25 @@ -184,6 +111,8 @@ def test_enumerate_instances(): def test_enumerate_instances_bool(): + from qsharp.qre._enumeration import _enumerate_instances + @dataclass class BoolConfig: _: KW_ONLY @@ -196,6 +125,8 @@ class BoolConfig: def test_enumerate_instances_enum(): + from qsharp.qre._enumeration import _enumerate_instances + class Color(Enum): RED = 1 GREEN = 2 @@ -214,6 +145,8 @@ class EnumConfig: def test_enumerate_instances_failure(): + from qsharp.qre._enumeration import _enumerate_instances + import pytest @dataclass @@ -227,6 +160,8 @@ class InvalidConfig: def test_enumerate_instances_single(): + from qsharp.qre._enumeration import _enumerate_instances + @dataclass class SingleConfig: value: int = 42 @@ -237,6 +172,8 @@ class SingleConfig: def test_enumerate_instances_literal(): + from qsharp.qre._enumeration import _enumerate_instances + from typing import Literal @dataclass @@ -251,50 +188,32 @@ class LiteralConfig: def test_enumerate_isas(): - ctx = Context(architecture=ExampleArchitecture()) + ctx = AQREGateBased().context() # This will enumerate the 4 ISAs for the error correction code - count = sum(1 for _ in ISAQuery(SurfaceCode).enumerate(ctx)) + count = sum(1 for _ in SurfaceCode.q().enumerate(ctx)) assert count == 12 # This will enumerate the 2 ISAs for the error correction code when # restricting the domain - count = sum( - 1 for _ in ISAQuery(SurfaceCode, kwargs={"distance": [3, 5]}).enumerate(ctx) - ) + count = sum(1 for _ in SurfaceCode.q(distance=[3, 4]).enumerate(ctx)) assert count == 2 # This will enumerate the 3 ISAs for the factory - count = sum(1 for _ in ISAQuery(ExampleFactory).enumerate(ctx)) + count = sum(1 for _ in ExampleFactory.q().enumerate(ctx)) assert count == 3 # This will enumerate 36 ISAs for all products between the 12 error # correction code ISAs and the 3 factory ISAs - count = sum( - 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery(ExampleFactory), - ] - ).enumerate(ctx) - ) + count = sum(1 for _ in (SurfaceCode.q() * ExampleFactory.q()).enumerate(ctx)) assert count == 36 # When providing a list, components are chained (OR operation). This # enumerates ISAs from first factory instance OR second factory instance count = sum( 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - SumNode( - sources=[ - ISAQuery(ExampleFactory), - ISAQuery(ExampleFactory), - ] - ), - ] + for _ in ( + SurfaceCode.q() * (ExampleFactory.q() + ExampleFactory.q()) ).enumerate(ctx) ) assert count == 72 @@ -304,13 +223,9 @@ def test_enumerate_isas(): # factory instance count = sum( 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery(ExampleFactory), - ISAQuery(ExampleFactory), - ] - ).enumerate(ctx) + for _ in (SurfaceCode.q() * ExampleFactory.q() * ExampleFactory.q()).enumerate( + ctx + ) ) assert count == 108 @@ -318,62 +233,32 @@ def test_enumerate_isas(): # from the product of other components as its source count = sum( 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery( - ExampleLogicalFactory, - source=ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery(ExampleFactory), - ] - ), - ), - ] + for _ in ( + SurfaceCode.q() + * ExampleLogicalFactory.q(source=(SurfaceCode.q() * ExampleFactory.q())) ).enumerate(ctx) ) assert count == 1296 def test_binding_node(): - """Test BindingNode with ISARefNode for component bindings""" - ctx = Context(architecture=ExampleArchitecture()) + """Test binding nodes with ISARefNode for component bindings""" + ctx = AQREGateBased().context() # Test basic binding: same code used twice # Without binding: 12 codes × 12 codes = 144 combinations - count_without = sum( - 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery(SurfaceCode), - ] - ).enumerate(ctx) - ) + count_without = sum(1 for _ in (SurfaceCode.q() * SurfaceCode.q()).enumerate(ctx)) assert count_without == 144 # With binding: 12 codes (same instance used twice) count_with = sum( 1 - for _ in BindingNode( - name="c", - component=ISAQuery(SurfaceCode), - node=ProductNode( - sources=[ISARefNode("c"), ISARefNode("c")], - ), - ).enumerate(ctx) + for _ in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx) ) assert count_with == 12 # Verify the binding works: with binding, both should use same params - for isa in BindingNode( - name="c", - component=ISAQuery(SurfaceCode), - node=ProductNode( - sources=[ISARefNode("c"), ISARefNode("c")], - ), - ).enumerate(ctx): + for isa in SurfaceCode.bind("c", ISARefNode("c") * ISARefNode("c")).enumerate(ctx): logical_gates = [g for g in isa if g.encoding == LOGICAL] # Should have 2 logical gates (GENERIC and LATTICE_SURGERY) assert len(logical_gates) == 2 @@ -381,33 +266,19 @@ def test_binding_node(): # Test binding with factories (nested bindings) count_without = sum( 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery(ExampleFactory), - ISAQuery(SurfaceCode), - ISAQuery(ExampleFactory), - ] + for _ in ( + SurfaceCode.q() * ExampleFactory.q() * SurfaceCode.q() * ExampleFactory.q() ).enumerate(ctx) ) assert count_without == 1296 # 12 * 3 * 12 * 3 count_with = sum( 1 - for _ in BindingNode( - name="c", - component=ISAQuery(SurfaceCode), - node=BindingNode( - name="f", - component=ISAQuery(ExampleFactory), - node=ProductNode( - sources=[ - ISARefNode("c"), - ISARefNode("f"), - ISARefNode("c"), - ISARefNode("f"), - ], - ), + for _ in SurfaceCode.bind( + "c", + ExampleFactory.bind( + "f", + ISARefNode("c") * ISARefNode("f") * ISARefNode("c") * ISARefNode("f"), ), ).enumerate(ctx) ) @@ -417,19 +288,11 @@ def test_binding_node(): # Without binding: 4 outer codes × (4 inner codes × 3 factories × 3 levels) count_without = sum( 1 - for _ in ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery( - ExampleLogicalFactory, - source=ProductNode( - sources=[ - ISAQuery(SurfaceCode), - ISAQuery(ExampleFactory), - ] - ), - ), - ] + for _ in ( + SurfaceCode.q() + * ExampleLogicalFactory.q( + source=(SurfaceCode.q() * ExampleFactory.q()), + ) ).enumerate(ctx) ) assert count_without == 1296 # 12 * 12 * 3 * 3 @@ -437,22 +300,11 @@ def test_binding_node(): # With binding: 4 codes (same used twice) × 3 factories × 3 levels count_with = sum( 1 - for _ in BindingNode( - name="c", - component=ISAQuery(SurfaceCode), - node=ProductNode( - sources=[ - ISARefNode("c"), - ISAQuery( - ExampleLogicalFactory, - source=ProductNode( - sources=[ - ISARefNode("c"), - ISAQuery(ExampleFactory), - ] - ), - ), - ] + for _ in SurfaceCode.bind( + "c", + ISARefNode("c") + * ExampleLogicalFactory.q( + source=(ISARefNode("c") * ExampleFactory.q()), ), ).enumerate(ctx) ) @@ -461,44 +313,32 @@ def test_binding_node(): # Test binding with kwargs count_with_kwargs = sum( 1 - for _ in BindingNode( - name="c", - component=ISAQuery(SurfaceCode, kwargs={"distance": 5}), - node=ProductNode( - sources=[ISARefNode("c"), ISARefNode("c")], - ), - ).enumerate(ctx) + for _ in SurfaceCode.q(distance=5) + .bind("c", ISARefNode("c") * ISARefNode("c")) + .enumerate(ctx) ) assert count_with_kwargs == 1 # Only distance=5 # Verify kwargs are applied - for isa in BindingNode( - name="c", - component=ISAQuery(SurfaceCode, kwargs={"distance": 5}), - node=ProductNode( - sources=[ISARefNode("c"), ISARefNode("c")], - ), - ).enumerate(ctx): + for isa in ( + SurfaceCode.q(distance=5) + .bind("c", ISARefNode("c") * ISARefNode("c")) + .enumerate(ctx) + ): logical_gates = [g for g in isa if g.encoding == LOGICAL] assert all(g.space(1) == 50 for g in logical_gates) # Test multiple independent bindings (nested) count = sum( 1 - for _ in BindingNode( - name="c1", - component=ISAQuery(SurfaceCode), - node=BindingNode( - name="c2", - component=ISAQuery(ExampleFactory), - node=ProductNode( - sources=[ - ISARefNode("c1"), - ISARefNode("c1"), - ISARefNode("c2"), - ISARefNode("c2"), - ], - ), + for _ in SurfaceCode.bind( + "c1", + ExampleFactory.bind( + "c2", + ISARefNode("c1") + * ISARefNode("c1") + * ISARefNode("c2") + * ISARefNode("c2"), ), ).enumerate(ctx) ) @@ -507,8 +347,8 @@ def test_binding_node(): def test_binding_node_errors(): - """Test error handling for BindingNode""" - ctx = Context(architecture=ExampleArchitecture()) + """Test error handling for binding nodes""" + ctx = AQREGateBased().context() # Test ISARefNode enumerate with undefined binding raises ValueError try: @@ -519,64 +359,208 @@ def test_binding_node_errors(): def test_product_isa_enumeration_nodes(): - terminal = ISAQuery(SurfaceCode) + from qsharp.qre._isa_enumeration import _ComponentQuery, _ProductNode + + terminal = SurfaceCode.q() query = terminal * terminal # Multiplication should create ProductNode - assert isinstance(query, ProductNode) + assert isinstance(query, _ProductNode) assert len(query.sources) == 2 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) # Multiplying again should extend the sources query = query * terminal - assert isinstance(query, ProductNode) + assert isinstance(query, _ProductNode) assert len(query.sources) == 3 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) # Also from the other side query = terminal * query - assert isinstance(query, ProductNode) + assert isinstance(query, _ProductNode) assert len(query.sources) == 4 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) # Also for two ProductNodes query = query * query - assert isinstance(query, ProductNode) + assert isinstance(query, _ProductNode) assert len(query.sources) == 8 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) def test_sum_isa_enumeration_nodes(): - terminal = ISAQuery(SurfaceCode) + from qsharp.qre._isa_enumeration import _ComponentQuery, _SumNode + + terminal = SurfaceCode.q() query = terminal + terminal # Multiplication should create SumNode - assert isinstance(query, SumNode) + assert isinstance(query, _SumNode) assert len(query.sources) == 2 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) # Multiplying again should extend the sources query = query + terminal - assert isinstance(query, SumNode) + assert isinstance(query, _SumNode) assert len(query.sources) == 3 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) # Also from the other side query = terminal + query - assert isinstance(query, SumNode) + assert isinstance(query, _SumNode) assert len(query.sources) == 4 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) # Also for two SumNodes query = query + query - assert isinstance(query, SumNode) + assert isinstance(query, _SumNode) assert len(query.sources) == 8 for source in query.sources: - assert isinstance(source, ISAQuery) + assert isinstance(source, _ComponentQuery) + + +def test_qsharp_application(): + from qsharp.qre._enumeration import _enumerate_instances + + code = """ + {{ + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + }} + """ + + app = QSharpApplication(code) + trace = app.get_trace() + + assert trace.compute_qubits == 3 + assert trace.depth == 3 + assert trace.resource_states == {} + + isa = ISA( + instruction( + LATTICE_SURGERY, + encoding=LOGICAL, + arity=None, + time=1000, + error_rate=linear_function(1e-6), + space=linear_function(50), + ), + instruction(T, encoding=LOGICAL, time=1000, error_rate=1e-8, space=400), + instruction(CCX, encoding=LOGICAL, time=2000, error_rate=1e-10, space=800), + ) + + # Properties from the program + counts = qsharp.logical_counts(code) + num_ts = counts["tCount"] + num_ccx = counts["cczCount"] + num_rotations = counts["rotationCount"] + rotation_depth = counts["rotationDepth"] + + lattice_surgery = LatticeSurgery() + + counter = 0 + for psspc in _enumerate_instances(PSSPC): + counter += 1 + trace2 = psspc.transform(trace) + assert trace2 is not None + trace2 = lattice_surgery.transform(trace2) + assert trace2 is not None + assert trace2.compute_qubits == 12 + assert ( + trace2.depth + == num_ts + + num_ccx * 3 + + num_rotations + + rotation_depth * psspc.num_ts_per_rotation + ) + if psspc.ccx_magic_states: + assert trace2.resource_states == { + T: num_ts + psspc.num_ts_per_rotation * num_rotations, + CCX: num_ccx, + } + else: + assert trace2.resource_states == { + T: num_ts + psspc.num_ts_per_rotation * num_rotations + 4 * num_ccx + } + result = trace2.estimate(isa, max_error=float("inf")) + assert result is not None + _assert_estimation_result(trace2, result, isa) + assert counter == 40 + + +def test_trace_enumeration(): + code = """ + {{ + use (a, b, c) = (Qubit(), Qubit(), Qubit()); + T(a); + CCNOT(a, b, c); + Rz(1.2345, a); + }} + """ + + app = QSharpApplication(code) + + from qsharp.qre._trace import RootNode + + ctx = app.context() + root = RootNode() + assert sum(1 for _ in root.enumerate(ctx)) == 1 + + assert sum(1 for _ in PSSPC.q().enumerate(ctx)) == 40 + + assert sum(1 for _ in LatticeSurgery.q().enumerate(ctx)) == 1 + + q = PSSPC.q() * LatticeSurgery.q() + assert sum(1 for _ in q.enumerate(ctx)) == 40 + + +def test_estimation_max_error(): + from qsharp.estimator import LogicalCounts + + app = QSharpApplication(LogicalCounts({"numQubits": 100, "measurementCount": 100})) + arch = AQREGateBased() + + for max_error in [1e-1, 1e-2, 1e-3, 1e-4]: + results = estimate( + app, + arch, + PSSPC.q() * LatticeSurgery.q(), + SurfaceCode.q() * ExampleFactory.q(), + max_error=max_error, + ) + + assert len(results) == 1 + assert next(iter(results)).error <= max_error + + +def _assert_estimation_result(trace: Trace, result: EstimationResult, isa: ISA): + actual_qubits = ( + isa[LATTICE_SURGERY].expect_space(trace.compute_qubits) + + isa[T].expect_space() * result.factories[T].copies + ) + if CCX in trace.resource_states: + actual_qubits += isa[CCX].expect_space() * result.factories[CCX].copies + assert result.qubits == actual_qubits + + assert ( + result.runtime + == isa[LATTICE_SURGERY].expect_time(trace.compute_qubits) * trace.depth + ) + + actual_error = ( + trace.base_error + + isa[LATTICE_SURGERY].expect_error_rate(trace.compute_qubits) * trace.depth + + isa[T].expect_error_rate() * result.factories[T].states + ) + if CCX in trace.resource_states: + actual_error += isa[CCX].expect_error_rate() * result.factories[CCX].states + assert abs(result.error - actual_error) <= 1e-8 diff --git a/source/qre/src/trace.rs b/source/qre/src/trace.rs index 0193b3c9db..c557251534 100644 --- a/source/qre/src/trace.rs +++ b/source/qre/src/trace.rs @@ -552,20 +552,28 @@ fn get_error_rate_by_id(isa: &ISA, id: u64) -> Result { .ok_or(Error::CannotExtractErrorRate(id)) } -fn estimate_chunks<'a>(traces: &[&'a Trace], isas: &[&'a ISA]) -> Vec { - let mut local_collection = Vec::new(); - for trace in traces { - for isa in isas { - if let Ok(estimation) = trace.estimate(isa, None) { - local_collection.push(estimation); +#[must_use] +pub fn estimate_parallel<'a>( + traces: &[&'a Trace], + isas: &[&'a ISA], + max_error: Option, +) -> EstimationCollection { + fn estimate_chunks<'a>( + traces: &[&'a Trace], + isas: &[&'a ISA], + max_error: Option, + ) -> Vec { + let mut local_collection = Vec::new(); + for trace in traces { + for isa in isas { + if let Ok(estimation) = trace.estimate(isa, max_error) { + local_collection.push(estimation); + } } } + local_collection } - local_collection -} -#[must_use] -pub fn estimate_parallel<'a>(traces: &[&'a Trace], isas: &[&'a ISA]) -> EstimationCollection { let mut collection = EstimationCollection::new(); std::thread::scope(|scope| { let num_threads = std::thread::available_parallelism() @@ -577,7 +585,7 @@ pub fn estimate_parallel<'a>(traces: &[&'a Trace], isas: &[&'a ISA]) -> Estimati for chunk in traces.chunks(chunk_size) { let tx = tx.clone(); - scope.spawn(move || tx.send(estimate_chunks(chunk, isas))); + scope.spawn(move || tx.send(estimate_chunks(chunk, isas, max_error))); } drop(tx); diff --git a/source/qre/src/trace/tests.rs b/source/qre/src/trace/tests.rs index 6509b30048..57c422c8a4 100644 --- a/source/qre/src/trace/tests.rs +++ b/source/qre/src/trace/tests.rs @@ -144,7 +144,7 @@ fn test_lattice_surgery_transform() { assert_eq!(trace.depth(), 2); - let ls = LatticeSurgery::new(); + let ls = LatticeSurgery::default(); let transformed = ls.transform(&trace).expect("Transformation failed"); assert_eq!(transformed.compute_qubits(), 3); diff --git a/source/qre/src/trace/transforms/lattice_surgery.rs b/source/qre/src/trace/transforms/lattice_surgery.rs index 425606b99d..fd3ff45f72 100644 --- a/source/qre/src/trace/transforms/lattice_surgery.rs +++ b/source/qre/src/trace/transforms/lattice_surgery.rs @@ -4,21 +4,36 @@ use crate::trace::TraceTransform; use crate::{Error, Trace, instruction_ids}; -#[derive(Default)] -pub struct LatticeSurgery; +pub struct LatticeSurgery { + slow_down_factor: f64, +} + +impl Default for LatticeSurgery { + fn default() -> Self { + Self { + slow_down_factor: 1.0, + } + } +} impl LatticeSurgery { #[must_use] - pub fn new() -> Self { - Self + pub fn new(slow_down_factor: f64) -> Self { + Self { slow_down_factor } } } impl TraceTransform for LatticeSurgery { + #[allow( + clippy::cast_possible_truncation, + clippy::cast_precision_loss, + clippy::cast_sign_loss + )] fn transform(&self, trace: &Trace) -> Result { let mut transformed = trace.clone_empty(None); - let block = transformed.add_block(trace.depth()); + let block = + transformed.add_block((trace.depth() as f64 * self.slow_down_factor).ceil() as u64); block.add_operation( instruction_ids::LATTICE_SURGERY, (0..trace.compute_qubits()).collect(),