diff --git a/CHANGELOG.md b/CHANGELOG.md index dad0d3abf3..22616216d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel ### Added +- ✨ Add Dead Gate Elimination Pattern ([#1755]) ([**@DRovara**]) - 🚸 Add [CMake presets] to provide a standardized and reproducible way to configure builds ([#1660]) ([**@denialhaag**]) - ✨ Add a `quantum-loop-unroll` pass for unrolling for-loop operations containing quantum operations ([#1718]) ([**@MatthiasReumann**]) - ✨ Add a `hadamard-lifting` pass for lifting Hadamard gates above Pauli gates ([#1605]) ([**@lirem101**], [**@burgholzer**]) @@ -402,6 +403,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool +[#1755]: https://github.com/munich-quantum-toolkit/core/pull/1755 [#1749]: https://github.com/munich-quantum-toolkit/core/pull/1749 [#1748]: https://github.com/munich-quantum-toolkit/core/pull/1748 [#1737]: https://github.com/munich-quantum-toolkit/core/pull/1737 diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 772ea1eba2..92e75be2f7 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -78,7 +78,8 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { //===--------------------------------------------------------------------===// /** - * @brief Initialize the builder and prepare for program construction + * @brief Initialize the builder and prepare for program construction, with + * a default return type of i64. * * @details * Creates a main function with an entry_point attribute. Must be called @@ -86,6 +87,17 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ void initialize(); + /** + * @brief Initialize the builder and prepare for program construction + * with specified return types. + * @param returnTypes The return types for the main function + * + * @details + * Creates a main function with an entry_point attribute. Must be called + * before adding operations. + */ + void initialize(TypeRange returnTypes); + //===--------------------------------------------------------------------===// // Constants //===--------------------------------------------------------------------===// @@ -105,6 +117,21 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ Value intConstant(int64_t value); + /** + * @brief Create a constant boolean value + * @param value The value to store in the constant + * @return The value produced by the constant operation + * + * @par Example: + * ```c++ + * auto c = builder.boolConstant(true); + * ``` + * ```mlir + * %c = arith.constant 1 : i1 + * ``` + */ + Value boolConstant(bool value); + //===--------------------------------------------------------------------===// // Memory Management //===--------------------------------------------------------------------===// @@ -1375,6 +1402,25 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ OwningOpRef finalize(); + /** + * @brief Finalize the program with a given exit code and return the + * constructed module + * @param returnValues Values representing the exit code to return + * + * @details + * Automatically deallocates all remaining valid qubits and tensors of qubits, + * adds a return statement with a given exit code, + * and transfers ownership of the module to the caller. The builder should not + * be used after calling this method. + * + * The return values must have the types indicated by the function signature + * of the main function, which returns an `i64` by default and can be + * modified by passing different arguments to the `initialize()` method. + * + * @return OwningOpRef containing the constructed quantum program module + */ + OwningOpRef finalize(ValueRange returnValues); + /** * @brief Convenience method for building quantum programs * @param context The MLIR context to use for building the program diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCODialect.td b/mlir/include/mlir/Dialect/QCO/IR/QCODialect.td index 62d1f5bba4..e7d98a0367 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCODialect.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCODialect.td @@ -38,6 +38,8 @@ def QCODialect : Dialect { let cppNamespace = "::mlir::qco"; let useDefaultTypePrinterParser = 1; + + let hasCanonicalizer = 1; } #endif // MLIR_DIALECT_QCO_IR_QCODIALECT_TD diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index a5bbfb7f51..4b05cd1442 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -96,7 +96,7 @@ def StaticOp : QCOOp<"static", [Pure]> { // Measurement and Reset Operations //===----------------------------------------------------------------------===// -def MeasureOp : QCOOp<"measure"> { +def MeasureOp : QCOOp<"measure", [Pure]> { let summary = "Measure a qubit in the computational basis"; let description = [{ Measures a qubit in the computational (Z) basis, collapsing the state @@ -119,8 +119,7 @@ def MeasureOp : QCOOp<"measure"> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, OptionalAttr:$register_name, OptionalAttr>:$register_size, OptionalAttr>:$register_index); @@ -136,9 +135,10 @@ def MeasureOp : QCOOp<"measure"> { }]>]; let hasVerifier = 1; + let hasCanonicalizer = 1; } -def ResetOp : QCOOp<"reset", [Idempotent, SameOperandsAndResultType]> { +def ResetOp : QCOOp<"reset", [Idempotent, SameOperandsAndResultType, Pure]> { let summary = "Reset a qubit to |0⟩ state"; let description = [{ Resets a qubit to the |0⟩ state, regardless of its current state, @@ -150,8 +150,7 @@ def ResetOp : QCOOp<"reset", [Idempotent, SameOperandsAndResultType]> { ``` }]; - let arguments = - (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -208,7 +207,8 @@ def GPhaseOp let hasCanonicalizer = 1; } -def IdOp : QCOOp<"id", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def IdOp + : QCOOp<"id", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply an Id gate to a qubit"; let description = [{ Applies an Id gate to a qubit and returns the transformed qubit. @@ -219,7 +219,7 @@ def IdOp : QCOOp<"id", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -232,7 +232,8 @@ def IdOp : QCOOp<"id", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def XOp : QCOOp<"x", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def XOp + : QCOOp<"x", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply an X gate to a qubit"; let description = [{ Applies an X gate to a qubit and returns the transformed qubit. @@ -243,7 +244,7 @@ def XOp : QCOOp<"x", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -256,7 +257,8 @@ def XOp : QCOOp<"x", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def YOp : QCOOp<"y", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def YOp + : QCOOp<"y", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply a Y gate to a qubit"; let description = [{ Applies a Y gate to a qubit and returns the transformed qubit. @@ -267,7 +269,7 @@ def YOp : QCOOp<"y", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -280,7 +282,8 @@ def YOp : QCOOp<"y", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def ZOp : QCOOp<"z", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def ZOp + : QCOOp<"z", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply a Z gate to a qubit"; let description = [{ Applies a Z gate to a qubit and returns the transformed qubit. @@ -291,7 +294,7 @@ def ZOp : QCOOp<"z", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -304,7 +307,8 @@ def ZOp : QCOOp<"z", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def HOp : QCOOp<"h", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def HOp + : QCOOp<"h", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply a H gate to a qubit"; let description = [{ Applies a H gate to a qubit and returns the transformed qubit. @@ -315,7 +319,7 @@ def HOp : QCOOp<"h", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -328,7 +332,8 @@ def HOp : QCOOp<"h", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def SOp : QCOOp<"s", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def SOp + : QCOOp<"s", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply an S gate to a qubit"; let description = [{ Applies an S gate to a qubit and returns the transformed qubit. @@ -339,7 +344,7 @@ def SOp : QCOOp<"s", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -352,8 +357,8 @@ def SOp : QCOOp<"s", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def SdgOp - : QCOOp<"sdg", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def SdgOp : QCOOp<"sdg", + traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply an Sdg gate to a qubit"; let description = [{ Applies an Sdg gate to a qubit and returns the transformed qubit. @@ -364,7 +369,7 @@ def SdgOp ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -377,7 +382,8 @@ def SdgOp let hasCanonicalizer = 1; } -def TOp : QCOOp<"t", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def TOp + : QCOOp<"t", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply a T gate to a qubit"; let description = [{ Applies a T gate to a qubit and returns the transformed qubit. @@ -388,7 +394,7 @@ def TOp : QCOOp<"t", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -401,8 +407,8 @@ def TOp : QCOOp<"t", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def TdgOp - : QCOOp<"tdg", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def TdgOp : QCOOp<"tdg", + traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply a Tdg gate to a qubit"; let description = [{ Applies a Tdg gate to a qubit and returns the transformed qubit. @@ -413,7 +419,7 @@ def TdgOp ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -426,7 +432,8 @@ def TdgOp let hasCanonicalizer = 1; } -def SXOp : QCOOp<"sx", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def SXOp + : QCOOp<"sx", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply an SX gate to a qubit"; let description = [{ Applies an SX gate to a qubit and returns the transformed qubit. @@ -437,7 +444,7 @@ def SXOp : QCOOp<"sx", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -450,8 +457,8 @@ def SXOp : QCOOp<"sx", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def SXdgOp - : QCOOp<"sxdg", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def SXdgOp : QCOOp<"sxdg", traits = [UnitaryOpInterface, OneTargetZeroParameter, + Pure]> { let summary = "Apply an SXdg gate to a qubit"; let description = [{ Applies an SXdg gate to a qubit and returns the transformed qubit. @@ -462,7 +469,7 @@ def SXdgOp ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -475,7 +482,8 @@ def SXdgOp let hasCanonicalizer = 1; } -def RXOp : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def RXOp + : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter, Pure]> { let summary = "Apply an RX gate to a qubit"; let description = [{ Applies an RX gate to a qubit and returns the transformed qubit. @@ -486,7 +494,7 @@ def RXOp : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$theta); let results = (outs QubitType:$qubit_out); let assemblyFormat = "`(` $theta `)` $qubit_in attr-dict `:` type($qubit_in) " @@ -504,7 +512,8 @@ def RXOp : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def RYOp : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def RYOp + : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter, Pure]> { let summary = "Apply an RY gate to a qubit"; let description = [{ Applies an RY gate to a qubit and returns the transformed qubit. @@ -515,7 +524,7 @@ def RYOp : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$theta); let results = (outs QubitType:$qubit_out); let assemblyFormat = "`(` $theta `)` $qubit_in attr-dict `:` type($qubit_in) " @@ -533,7 +542,8 @@ def RYOp : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def RZOp : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def RZOp + : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter, Pure]> { let summary = "Apply an RZ gate to a qubit"; let description = [{ Applies an RZ gate to a qubit and returns the transformed qubit. @@ -544,7 +554,7 @@ def RZOp : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$theta); let results = (outs QubitType:$qubit_out); let assemblyFormat = "`(` $theta `)` $qubit_in attr-dict `:` type($qubit_in) " @@ -562,7 +572,8 @@ def RZOp : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def POp : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def POp + : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter, Pure]> { let summary = "Apply a P gate to a qubit"; let description = [{ Applies a P gate to a qubit and returns the transformed qubit. @@ -573,7 +584,7 @@ def POp : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$theta); let results = (outs QubitType:$qubit_out); let assemblyFormat = "`(` $theta `)` $qubit_in attr-dict `:` type($qubit_in) " @@ -591,7 +602,8 @@ def POp : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def ROp : QCOOp<"r", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { +def ROp + : QCOOp<"r", traits = [UnitaryOpInterface, OneTargetTwoParameter, Pure]> { let summary = "Apply an R gate to a qubit"; let description = [{ Applies an R gate to a qubit and returns the transformed qubit. @@ -602,7 +614,7 @@ def ROp : QCOOp<"r", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$theta, Arg:$phi); let results = (outs QubitType:$qubit_out); @@ -621,7 +633,8 @@ def ROp : QCOOp<"r", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { let hasCanonicalizer = 1; } -def U2Op : QCOOp<"u2", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { +def U2Op + : QCOOp<"u2", traits = [UnitaryOpInterface, OneTargetTwoParameter, Pure]> { let summary = "Apply a U2 gate to a qubit"; let description = [{ Applies a U2 gate to a qubit and returns the transformed qubit. @@ -632,7 +645,7 @@ def U2Op : QCOOp<"u2", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$phi, Arg:$lambda); let results = (outs QubitType:$qubit_out); @@ -651,7 +664,8 @@ def U2Op : QCOOp<"u2", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { let hasCanonicalizer = 1; } -def UOp : QCOOp<"u", traits = [UnitaryOpInterface, OneTargetThreeParameter]> { +def UOp + : QCOOp<"u", traits = [UnitaryOpInterface, OneTargetThreeParameter, Pure]> { let summary = "Apply a U gate to a qubit"; let description = [{ Applies a U gate to a qubit and returns the transformed qubit. @@ -662,7 +676,7 @@ def UOp : QCOOp<"u", traits = [UnitaryOpInterface, OneTargetThreeParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$theta, Arg:$phi, Arg:$lambda); @@ -683,8 +697,8 @@ def UOp : QCOOp<"u", traits = [UnitaryOpInterface, OneTargetThreeParameter]> { let hasCanonicalizer = 1; } -def SWAPOp - : QCOOp<"swap", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { +def SWAPOp : QCOOp<"swap", traits = [UnitaryOpInterface, TwoTargetZeroParameter, + Pure]> { let summary = "Apply a SWAP gate to two qubits"; let description = [{ Applies a SWAP gate to two qubits and returns the transformed qubits. @@ -695,9 +709,8 @@ def SWAPOp ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "$qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) `,` " @@ -711,8 +724,8 @@ def SWAPOp let hasCanonicalizer = 1; } -def iSWAPOp - : QCOOp<"iswap", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { +def iSWAPOp : QCOOp<"iswap", traits = [UnitaryOpInterface, + TwoTargetZeroParameter, Pure]> { let summary = "Apply a iSWAP gate to two qubits"; let description = [{ Applies a iSWAP gate to two qubits and returns the transformed qubits. @@ -723,9 +736,8 @@ def iSWAPOp ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "$qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) `,` " @@ -737,8 +749,8 @@ def iSWAPOp }]; } -def DCXOp - : QCOOp<"dcx", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { +def DCXOp : QCOOp<"dcx", + traits = [UnitaryOpInterface, TwoTargetZeroParameter, Pure]> { let summary = "Apply a DCX gate to two qubits"; let description = [{ Applies a DCX gate to two qubits and returns the transformed qubits. @@ -749,9 +761,8 @@ def DCXOp ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "$qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) `,` " @@ -765,8 +776,8 @@ def DCXOp let hasCanonicalizer = 1; } -def ECROp - : QCOOp<"ecr", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { +def ECROp : QCOOp<"ecr", + traits = [UnitaryOpInterface, TwoTargetZeroParameter, Pure]> { let summary = "Apply an ECR gate to two qubits"; let description = [{ Applies an ECR gate to two qubits and returns the transformed qubits. @@ -777,9 +788,8 @@ def ECROp ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "$qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) `,` " @@ -793,7 +803,8 @@ def ECROp let hasCanonicalizer = 1; } -def RXXOp : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RXXOp + : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter, Pure]> { let summary = "Apply an RXX gate to two qubits"; let description = [{ Applies an RXX gate to two qubits and returns the transformed qubits. @@ -804,10 +815,9 @@ def RXXOp : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in, - Arg:$theta); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in, + Arg:$theta); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "`(` $theta `)` $qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) " @@ -825,7 +835,8 @@ def RXXOp : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def RYYOp : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RYYOp + : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter, Pure]> { let summary = "Apply an RYY gate to two qubits"; let description = [{ Applies an RYY gate to two qubits and returns the transformed qubits. @@ -836,10 +847,9 @@ def RYYOp : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in, - Arg:$theta); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in, + Arg:$theta); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "`(` $theta `)` $qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) " @@ -857,7 +867,8 @@ def RYYOp : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def RZXOp : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RZXOp + : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter, Pure]> { let summary = "Apply an RZX gate to two qubits"; let description = [{ Applies an RZX gate to two qubits and returns the transformed qubits. @@ -868,10 +879,9 @@ def RZXOp : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in, - Arg:$theta); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in, + Arg:$theta); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "`(` $theta `)` $qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) " @@ -889,7 +899,8 @@ def RZXOp : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def RZZOp : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RZZOp + : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter, Pure]> { let summary = "Apply an RZZ gate to two qubits"; let description = [{ Applies an RZZ gate to two qubits and returns the transformed qubits. @@ -900,10 +911,9 @@ def RZZOp : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in, - Arg:$theta); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in, + Arg:$theta); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "`(` $theta `)` $qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) " @@ -921,8 +931,8 @@ def RZZOp : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def XXPlusYYOp : QCOOp<"xx_plus_yy", - traits = [UnitaryOpInterface, TwoTargetTwoParameter]> { +def XXPlusYYOp : QCOOp<"xx_plus_yy", traits = [UnitaryOpInterface, + TwoTargetTwoParameter, Pure]> { let summary = "Apply an XX+YY gate to two qubits"; let description = [{ Applies an XX+YY gate to two qubits and returns the transformed qubits. @@ -933,11 +943,10 @@ def XXPlusYYOp : QCOOp<"xx_plus_yy", ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in, - Arg:$theta, - Arg:$beta); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in, + Arg:$theta, + Arg:$beta); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "`(` $theta `,` $beta `)` $qubit0_in `,` $qubit1_in " "attr-dict `:` type($qubit0_in) `,` type($qubit1_in) " @@ -956,8 +965,8 @@ def XXPlusYYOp : QCOOp<"xx_plus_yy", let hasCanonicalizer = 1; } -def XXMinusYYOp : QCOOp<"xx_minus_yy", - traits = [UnitaryOpInterface, TwoTargetTwoParameter]> { +def XXMinusYYOp : QCOOp<"xx_minus_yy", traits = [UnitaryOpInterface, + TwoTargetTwoParameter, Pure]> { let summary = "Apply an XX-YY gate to two qubits"; let description = [{ Applies an XX-YY gate to two qubits and returns the transformed qubits. @@ -968,11 +977,10 @@ def XXMinusYYOp : QCOOp<"xx_minus_yy", ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in, - Arg:$theta, - Arg:$beta); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in, + Arg:$theta, + Arg:$beta); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "`(` $theta `,` $beta `)` $qubit0_in `,` $qubit1_in " "attr-dict `:` type($qubit0_in) `,` type($qubit1_in) " @@ -991,7 +999,7 @@ def XXMinusYYOp : QCOOp<"xx_minus_yy", let hasCanonicalizer = 1; } -def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface]> { +def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface, Pure]> { let summary = "Apply a barrier gate to a set of qubits"; let description = [{ Applies a barrier gate to a set of qubits and returns the transformed qubits. @@ -1003,7 +1011,7 @@ def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface]> { }]; let arguments = - (ins Arg, "the target qubits", [MemRead]>:$qubits_in); + (ins Arg, "the target qubits">:$qubits_in); let results = (outs Variadic:$qubits_out); let assemblyFormat = "$qubits_in attr-dict `:` type($qubits_in) `->` type($qubits_out)"; @@ -1041,7 +1049,7 @@ def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface]> { // Modifiers //===----------------------------------------------------------------------===// -def YieldOp : QCOOp<"yield", traits = [Terminator, ReturnLike]> { +def YieldOp : QCOOp<"yield", traits = [Terminator, ReturnLike, Pure]> { let summary = "Yield from a modifier region"; let description = [{ Terminates a modifier region, yielding the transformed target qubit and qtensor values back to the enclosing modifier operation. @@ -1066,7 +1074,7 @@ def CtrlOp AttrSizedResultSegments, SameOperandsAndResultType, SameOperandsAndResultShape, SingleBlockImplicitTerminator<"::mlir::qco::YieldOp">, - RecursiveMemoryEffects]> { + Pure, RecursiveMemoryEffects]> { let summary = "Add control qubits to a unitary operation"; let description = [{ A modifier operation that adds control qubits to the unitary operation @@ -1085,9 +1093,9 @@ def CtrlOp ``` }]; - let arguments = (ins Arg, - "the control qubits", [MemRead]>:$controls_in, - Arg, "the target qubits", [MemRead]>:$targets_in); + let arguments = + (ins Arg, "the control qubits">:$controls_in, + Arg, "the target qubits">:$targets_in); let results = (outs Variadic:$controls_out, Variadic:$targets_out); let regions = (region SizedRegion<1>:$region); @@ -1143,7 +1151,7 @@ def InvOp : QCOOp<"inv", traits = [UnitaryOpInterface, SingleBlockImplicitTerminator<"::mlir::qco::YieldOp">, - RecursiveMemoryEffects]> { + Pure, RecursiveMemoryEffects]> { let summary = "Invert a unitary operation"; let description = [{ A modifier operation that inverts the unitary operation defined in its body @@ -1159,9 +1167,8 @@ def InvOp ``` }]; - let arguments = - (ins Arg, - "the qubits involved in the operation", [MemRead]>:$qubits_in); + let arguments = (ins Arg, + "the qubits involved in the operation">:$qubits_in); let results = (outs Variadic:$qubits_out); let regions = (region SizedRegion<1>:$region); let assemblyFormat = [{ @@ -1221,7 +1228,7 @@ def IfOp "getRegionInvocationBounds", "getEntrySuccessorRegions"]>, SingleBlock, - SingleBlockImplicitTerminator<"::mlir::qco::YieldOp">, + SingleBlockImplicitTerminator<"::mlir::qco::YieldOp">, Pure, RecursiveMemoryEffects]> { let summary = "If-then-else operation for linear (qubit) types"; diff --git a/mlir/include/mlir/Dialect/QCO/QCOUtils.h b/mlir/include/mlir/Dialect/QCO/QCOUtils.h index 489fceb00e..9c6af7a8f4 100644 --- a/mlir/include/mlir/Dialect/QCO/QCOUtils.h +++ b/mlir/include/mlir/Dialect/QCO/QCOUtils.h @@ -237,4 +237,42 @@ mergeTwoTargetOneParameterWithSwappedTargets(OpType op, return success(); } +/** + * @brief Check if given quantum operation is unused (i.e., only used by + * sinks) and remove it if so. + * + * @param op The operation to check. + * @param rewriter The pattern rewriter. + * @return LogicalResult Success or failure of the removal. + */ +inline LogicalResult checkAndRemoveDeadGate(Operation* op, + PatternRewriter& rewriter) { + if (llvm::all_of(op->getUsers(), + [](Operation* user) { return isa(user); })) { + // If the operation is only used by sinks, we can safely remove it. + if (auto u = dyn_cast(op)) { + // We specifically have to replace the output *qubits* with the input + // *qubits* to ignore parameters. + rewriter.replaceOp(op, u.getInputQubits()); + return success(); + } else if (auto m = dyn_cast(op)) { + // We specifically have to replace the output *qubits* with the input + // *qubits* to ignore the classical outcome. + rewriter.replaceAllUsesWith(m.getQubitOut(), m.getQubitIn()); + rewriter.eraseOp(op); + return success(); + } else if (auto i = dyn_cast(op)) { + // We specifically have to replace the output *qubits* with the input + // *qubits* to ignore the condition. + rewriter.replaceOp(op, i.getQubits()); + return success(); + } else { + // This currently only includes the `Reset` operation. + rewriter.replaceOp(op, op->getOperands()); + return success(); + } + } + return failure(); +} + } // namespace mlir::qco diff --git a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td index 32f678924e..bad2fb4f9c 100644 --- a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td @@ -174,4 +174,38 @@ def HadamardLifting : Pass<"hadamard-lifting", "mlir::ModuleOp"> { }]; } +def MeasurementLifting : Pass<"measurement-lifting", "mlir::ModuleOp"> { + let dependentDialects = ["mlir::qco::QCODialect", + "::mlir::arith::ArithDialect", + ]; + let summary = "This pass attempts to move measurements as far up as " + "possible, shifting them above gates that commute with them. " + "This is done to enable qubit reuse and other optimizations."; + let description = [{ + This pass applies measurement lifting, moving measurements up the code as far as possible. + Measurement lifting is a subroutine of the qubit reuse routine. The goal is to measure qubits earlier in the + circuit to reuse them and to potentially remove some quantum gates. + + Measurement lifting uses the following commutation rules: + ┌──────┐ ┌──────┐ + ──■──┤ Meas │────── ─┤ Meas ├──■─── + │ └──────┘ └──────┘ │ + ┌─┴─┐ => ┌─┴─┐ + ┤ U ├──────────────── ─────────┤ U ├─ + └───┘ └───┘ + (Where U is any (controlled) unitary gate) + + ┌───┐┌──────┐ ┌──────┐┌───┐ + ┤ P ├┤ Meas ├ => ┤ Meas ├┤ P ├ + └───┘└──────┘ └──────┘└───┘ + (Where P is any diagonal gate, e.g., `z`, `s`, ...) + + ┌───┐┌──────┐ ┌───────┐┌───┐ + ┤ X ├┤ Meas ├ => ┤ Meas* ├┤ X ├ + └───┘└──────┘ └───────┘└───┘ + (Where Meas* is a measurement after which the outcome is classically negated) + + }]; +} + #endif // MLIR_DIALECT_QCO_TRANSFORMS_PASSES_TD diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index a07c52aa0f..22c45132d4 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -52,12 +52,14 @@ QCOProgramBuilder::QCOProgramBuilder(MLIRContext* context) ctx->loadDialect(); } -void QCOProgramBuilder::initialize() { +void QCOProgramBuilder::initialize() { initialize({getI64Type()}); } + +void QCOProgramBuilder::initialize(TypeRange returnTypes) { // Set insertion point to the module body setInsertionPointToStart(mlir::cast(module).getBody()); // Create main function as entry point - auto funcType = getFunctionType({}, {getI64Type()}); + auto funcType = getFunctionType({}, returnTypes); auto mainFunc = func::FuncOp::create(*this, "main", funcType); // Add entry_point attribute to identify the main function @@ -74,6 +76,11 @@ Value QCOProgramBuilder::intConstant(const int64_t value) { return arith::ConstantOp::create(*this, getI64IntegerAttr(value)).getResult(); } +Value QCOProgramBuilder::boolConstant(const bool value) { + checkFinalized(); + return arith::ConstantOp::create(*this, getBoolAttr(value)).getResult(); +} + Value& QCOProgramBuilder::QubitRegister::operator[](const size_t index) { if (index >= qubits.size()) { llvm::reportFatalUsageError("Qubit index out of bounds"); @@ -1098,6 +1105,13 @@ void QCOProgramBuilder::ensureAllocationMode( OwningOpRef QCOProgramBuilder::finalize() { checkFinalized(); + auto exitCode = intConstant(0); + return finalize({exitCode}); +} + +OwningOpRef QCOProgramBuilder::finalize(ValueRange returnValues) { + checkFinalized(); + // Ensure that main function exists and insertion point is valid auto* insertionBlock = getInsertionBlock(); func::FuncOp mainFunc = nullptr; @@ -1146,11 +1160,8 @@ OwningOpRef QCOProgramBuilder::finalize() { validQubits.clear(); validTensors.clear(); - // Create constant 0 for successful exit code - auto exitCode = intConstant(0); - // Add return statement with exit code 0 to the main function - func::ReturnOp::create(*this, exitCode); + func::ReturnOp::create(*this, returnValues); // Invalidate context to prevent use-after-finalize ctx = nullptr; diff --git a/mlir/lib/Dialect/QCO/IR/Operations/MeasureOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/MeasureOp.cpp index b76a2d0471..5d0462b6d7 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/MeasureOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/MeasureOp.cpp @@ -9,12 +9,31 @@ */ #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/QCOUtils.h" +#include +#include #include using namespace mlir; using namespace mlir::qco; +namespace { + +/** + * @brief Remove dead measurements. + */ +struct DeadMeasurementRemoval final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MeasureOp op, + PatternRewriter& rewriter) const override { + return checkAndRemoveDeadGate(op, rewriter); + } +}; + +} // namespace + LogicalResult MeasureOp::verify() { const auto registerName = getRegisterName(); const auto registerSize = getRegisterSize(); @@ -37,3 +56,8 @@ LogicalResult MeasureOp::verify() { } return success(); } + +void MeasureOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +} diff --git a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp index 8832162354..462a28108d 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp @@ -9,6 +9,7 @@ */ #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/QCOUtils.h" #include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include "mlir/Dialect/QTensor/IR/QTensorUtils.h" @@ -101,6 +102,18 @@ struct RemoveResetAfterExtract final : OpRewritePattern { } }; +/** + * @brief Remove dead resets. + */ +struct DeadResetRemoval final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ResetOp op, + PatternRewriter& rewriter) const override { + return checkAndRemoveDeadGate(op, rewriter); + } +}; + } // namespace OpFoldResult ResetOp::fold(FoldAdaptor /*adaptor*/) { @@ -114,4 +127,5 @@ OpFoldResult ResetOp::fold(FoldAdaptor /*adaptor*/) { void ResetOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.add(context); + results.add(context); } diff --git a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp index a3ce816081..e49d00da5a 100644 --- a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp +++ b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp @@ -11,14 +11,19 @@ #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" // IWYU pragma: associated +#include "mlir/Dialect/QCO/IR/QCOInterfaces.h" +#include "mlir/Dialect/QCO/QCOUtils.h" #include #include +#include #include #include #include +#include #include #include +#include #include // The following headers are needed for some template instantiations. @@ -30,6 +35,34 @@ using namespace mlir; using namespace mlir::qco; +//===----------------------------------------------------------------------===// +// Dialect-Level Canonicalizers +//===----------------------------------------------------------------------===// + +namespace { + +/** + * @brief Remove dead gates. + */ +struct DeadGateElimination final + : public OpInterfaceRewritePattern { + + explicit DeadGateElimination(MLIRContext* context) + : OpInterfaceRewritePattern(context) {} + + LogicalResult matchAndRewrite(UnitaryOpInterface op, + PatternRewriter& rewriter) const override { + if (!isMemoryEffectFree(op)) { + // This effectively ignores the GPhase operation and variants such as its + // inverse or `ctrl` ops containing it, which should never be considered + // dead. + return failure(); + } + return checkAndRemoveDeadGate(op.getOperation(), rewriter); + } +}; +} // namespace + //===----------------------------------------------------------------------===// // Custom Parsers //===----------------------------------------------------------------------===// @@ -258,6 +291,10 @@ void QCODialect::initialize() { >(); } +void QCODialect::getCanonicalizationPatterns(RewritePatternSet& results) const { + results.add(getContext()); +} + //===----------------------------------------------------------------------===// // Types //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp b/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp index 4ac2d98faa..6589fbbad7 100644 --- a/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp @@ -9,6 +9,7 @@ */ #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/QCOUtils.h" #include #include @@ -25,6 +26,7 @@ #include #include #include +#include #include #include @@ -234,11 +236,28 @@ struct ConditionPropagation : public OpRewritePattern { return success(changed); } }; + +/** + * @brief Remove dead `IfOp` instructions. + */ +struct DeadIfRemoval final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IfOp op, + PatternRewriter& rewriter) const override { + if (!isMemoryEffectFree(op)) { + // This effectively ignores `IfOp`s with memory effects. + return failure(); + } + return checkAndRemoveDeadGate(op, rewriter); + } +}; } // namespace void IfOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.add(context); + results.add(context); populateRegionBranchOpInterfaceCanonicalizationPatterns( results, IfOp::getOperationName()); } diff --git a/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp b/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp new file mode 100644 index 0000000000..1fb173a348 --- /dev/null +++ b/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp @@ -0,0 +1,237 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +// +// Created by damian on 5/13/26. +// + +#include "mlir/Dialect/QCO/IR/QCOInterfaces.h" +#include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace mlir::qco { + +#define GEN_PASS_DEF_MEASUREMENTLIFTING +#include "mlir/Dialect/QCO/Transforms/Passes.h.inc" + +/** + * @brief Checks if the given operation is an inverting gate. + * @param op The operation to check. + * @return True if the operation is an inverting gate, false otherwise. + */ +static bool isInverting(Operation* op) { return isa(op); } + +/** + * @brief Checks if the given operation is a diagonal gate. + * @param op The operation to check. + * @return True if the operation is a diagonal gate, false otherwise. + */ +static bool isDiagonal(Operation* op) { + if (auto c = dyn_cast(op)) { + return isDiagonal(c.getBodyUnitary()); + } + if (auto i = dyn_cast(op)) { + return isDiagonal(i.getBodyUnitary()); + } + return isa(op); +} + +/** + * @brief This method swaps a gate with a measurement. + * @param gate The gate to swap. + * @param measurement The measurement to swap. + * @param rewriter The used rewriter. + */ +static void swapGateWithMeasurement(UnitaryOpInterface gate, + MeasureOp measurement, + mlir::PatternRewriter& rewriter) { + auto measurementInput = measurement.getQubitIn(); + auto gateInput = gate.getInputForOutput(measurementInput); + rewriter.replaceUsesWithIf(measurementInput, gateInput, + [&](mlir::OpOperand& operand) { + // We only replace the single use by the + // measure op + return operand.getOwner() == measurement; + }); + rewriter.replaceUsesWithIf(gateInput, measurement.getQubitOut(), + [&](mlir::OpOperand& operand) { + // We only replace the single use by the + // predecessor + return operand.getOwner() == gate; + }); + rewriter.replaceUsesWithIf(measurement.getQubitOut(), measurementInput, + [&](mlir::OpOperand& operand) { + // All further uses of the measurement output now + // use the gate output + return operand.getOwner() != gate; + }); + rewriter.moveOpBefore(measurement, gate); +} + +namespace { +/** + * @brief This pattern is responsible for lifting measurements above any phase + * gates. + */ +struct LiftMeasurementsAbovePhaseGatesPattern final + : mlir::OpRewritePattern { + + explicit LiftMeasurementsAbovePhaseGatesPattern(mlir::MLIRContext* context) + : OpRewritePattern(context) {} + + mlir::LogicalResult + matchAndRewrite(MeasureOp op, + mlir::PatternRewriter& rewriter) const override { + const auto qubitVariable = op.getQubitIn(); + auto* predecessor = qubitVariable.getDefiningOp(); + + auto predecessorUnitary = mlir::dyn_cast(predecessor); + + if (!predecessorUnitary) { + return mlir::failure(); + } + + if (isDiagonal(predecessor)) { + swapGateWithMeasurement(predecessorUnitary, op, rewriter); + return mlir::success(); + } + + return mlir::failure(); + } +}; + +/** + * @brief This pattern is responsible for lifting measurements above any + * non-phase gates. + */ +struct LiftMeasurementsAboveInvertingGatesPattern final + : mlir::OpRewritePattern { + + explicit LiftMeasurementsAboveInvertingGatesPattern( + mlir::MLIRContext* context) + : OpRewritePattern(context) {} + + /** + * @brief Checks if the given qubit is not used anymore. + * @param outQubit The output qubit to check. + * @return True if all users are resets/deallocs, false otherwise. + */ + static bool outputQubitRemainsUnused(mlir::Value outQubit) { + return llvm::all_of(outQubit.getUsers(), [](mlir::Operation* user) { + return mlir::isa(user) || mlir::isa(user); + }); + } + + mlir::LogicalResult + matchAndRewrite(MeasureOp op, + mlir::PatternRewriter& rewriter) const override { + const auto qubitVariable = op.getQubitIn(); + auto* predecessor = qubitVariable.getDefiningOp(); + + auto predecessorUnitary = mlir::dyn_cast(predecessor); + + if (!predecessorUnitary) { + return mlir::failure(); + } + + if (isInverting(predecessor) && + predecessorUnitary.getInputQubits().size() == 1) { + swapGateWithMeasurement(predecessorUnitary, op, rewriter); + rewriter.setInsertionPointAfter(op); + const mlir::Value trueConstant = rewriter.create( + op.getLoc(), rewriter.getBoolAttr(true)); + auto inversion = rewriter.create( + op.getLoc(), op.getResult(), trueConstant); + // We need `replaceUsesWithIf` so that we can replace all uses except for + // the one use that defines the inverted bit. + rewriter.replaceUsesWithIf(op.getResult(), inversion.getResult(), + [&](mlir::OpOperand& operand) { + return operand.getOwner() != inversion; + }); + return mlir::success(); + } + + return mlir::failure(); + } +}; + +/** + * @brief This pattern is responsible for applying the "deferred measurement + * principle", lifting measurements above controls. + */ +struct LiftMeasurementsAboveControlsPattern final + : mlir::OpRewritePattern { + + explicit LiftMeasurementsAboveControlsPattern(mlir::MLIRContext* context) + : OpRewritePattern(context) {} + + mlir::LogicalResult + matchAndRewrite(MeasureOp op, + mlir::PatternRewriter& rewriter) const override { + const auto qubitVariable = op.getQubitIn(); + auto* predecessor = qubitVariable.getDefiningOp(); + auto predecessorCtrl = mlir::dyn_cast(predecessor); + + if (!predecessorCtrl) { + return mlir::failure(); + } + + if (llvm::find(predecessorCtrl.getControlsOut(), qubitVariable) == + predecessorCtrl.getControlsOut().end()) { + // The measured qubit is a target, not a control of the gate. + return mlir::failure(); + } + + swapGateWithMeasurement(predecessorCtrl, op, rewriter); + + return mlir::success(); + } +}; + +/** + * @brief Pass raises Measurements above controlled and uncontrolled gates + * gates. + */ +struct MeasurementLifting final + : impl::MeasurementLiftingBase { + using MeasurementLiftingBase::MeasurementLiftingBase; + +protected: + void runOnOperation() override { + const auto op = getOperation(); + auto* ctx = &getContext(); + + // Define the set of patterns to use. + RewritePatternSet patterns(ctx); + patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); + patterns.add(patterns.getContext()); + + // Apply patterns in an iterative and greedy manner. + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace mlir::qco diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 413f29336d..3af2feb196 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -114,6 +114,114 @@ TEST_F(QCOTest, BuilderRejectsMixedStaticAndDynamicQubitAllocationModes) { "Cannot mix dynamic and static qubit allocation modes"); } +TEST_F(QCOTest, CheckDeadGateElimination) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto q0S0 = builder.allocQubit(); + auto q1S0 = builder.allocQubit(); + auto q0S1 = builder.h(q0S0); + auto [q0S2, q1S1] = builder.cx(q0S1, q1S0); + auto [q1S2, c1] = builder.measure(q1S1); + builder.sink(q0S2); + builder.sink(q1S2); + auto module = builder.finalize(); + + QCOProgramBuilder reference(context.get()); + reference.initialize(); + auto r0 = reference.allocQubit(); + auto r1 = reference.allocQubit(); + reference.sink(r0); + reference.sink(r1); + auto refModule = reference.finalize(); + + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(module.get()).succeeded()); + EXPECT_TRUE(verify(*module).succeeded()); + + ASSERT_TRUE(refModule); + EXPECT_TRUE(verify(*refModule).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(refModule.get()).succeeded()); + EXPECT_TRUE(verify(*refModule).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), refModule.get())); +} + +TEST_F(QCOTest, CheckIfOpDeadGateElimination) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto q0S0 = builder.allocQubit(); + auto q1S0 = builder.allocQubit(); + auto q0S1 = builder.h(q0S0); + auto [q0S2, c0] = builder.measure(q0S1); + + // This is an `if` with memory effects - it can't be removed. + auto q1S1 = builder.qcoIf( + c0, {q1S0}, + [&](ValueRange qubits) -> SmallVector { + auto q1Then = builder.x(qubits[0]); + builder.gphase(0.5); + return SmallVector{q1Then}; + }, + [&](ValueRange qubits) -> SmallVector { + auto q1Else = builder.h(qubits[0]); + return SmallVector{q1Else}; + })[0]; + + // This is an `if` without memory effects - it can be removed. + auto q1S2 = builder.qcoIf( + c0, {q1S1}, + [&](ValueRange qubits) -> SmallVector { + auto q1Then = builder.x(qubits[0]); + return SmallVector{q1Then}; + }, + [&](ValueRange qubits) -> SmallVector { + auto q1Else = builder.h(qubits[0]); + return SmallVector{q1Else}; + })[0]; + builder.sink(q0S2); + builder.sink(q1S2); + auto module = builder.finalize(); + + QCOProgramBuilder reference(context.get()); + reference.initialize(); + auto r0S0 = reference.allocQubit(); + auto r1S0 = reference.allocQubit(); + auto r0S1 = reference.h(r0S0); + auto [r0S2, cr0] = reference.measure(r0S1); + + // This is an `if` with memory effects - it can't be removed. + auto r1S1 = reference.qcoIf( + cr0, {r1S0}, + [&](ValueRange qubits) -> SmallVector { + auto q1Then = reference.x(qubits[0]); + reference.gphase(0.5); + return SmallVector{q1Then}; + }, + [&](ValueRange qubits) -> SmallVector { + auto q1Else = reference.h(qubits[0]); + return SmallVector{q1Else}; + })[0]; + + reference.sink(r0S2); + reference.sink(r1S1); + auto refModule = reference.finalize(); + + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(module.get()).succeeded()); + EXPECT_TRUE(verify(*module).succeeded()); + + ASSERT_TRUE(refModule); + EXPECT_TRUE(verify(*refModule).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(refModule.get()).succeeded()); + EXPECT_TRUE(verify(*refModule).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), refModule.get())); +} + TEST_F(QCOTest, DirectIfBuilder) { // Test If construction directly QCOProgramBuilder builder(context.get()); diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt index e80e957680..bdc78401c8 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt @@ -7,7 +7,7 @@ # Licensed under the MIT License set(target_name mqt-core-mlir-unittest-optimizations) -add_executable(${target_name} test_qco_hadamard_lifting.cpp +add_executable(${target_name} test_qco_hadamard_lifting.cpp test_qco_measurement_lifting.cpp test_qco_merge_single_qubit_rotation.cpp test_quantum_loop_unroll.cpp) target_link_libraries( diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp new file mode 100644 index 0000000000..f197a467f7 --- /dev/null +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp @@ -0,0 +1,448 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +// +// Created by damian on 5/21/26. +// + +#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" +#include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" +#include "mlir/Support/IRVerification.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace { + +using namespace mlir; +using namespace mlir::qco; + +class QCOMeasurementLiftingTest : public testing::Test { + +protected: + MLIRContext context; + QCOProgramBuilder programBuilder; + QCOProgramBuilder referenceBuilder; + OwningOpRef module; + OwningOpRef reference; + + QCOMeasurementLiftingTest() + : programBuilder(&context), referenceBuilder(&context) {} + + void SetUp() override { + // Register all necessary dialects + DialectRegistry registry; + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + } + + /** + * @brief Adds the measurementLiftingPass to the current context and runs it. + */ + static LogicalResult runMeasurementLiftingPass(ModuleOp module) { + PassManager pm(module.getContext()); + pm.addPass(createMeasurementLifting()); + pm.addPass(createCanonicalizerPass()); + return pm.run(module); + } + + /** + * @brief Adds the canonicalizerPass to the current context and runs it. + */ + static LogicalResult runCanonicalizerPass(ModuleOp module) { + PassManager pm(module.getContext()); + pm.addPass(createCanonicalizerPass()); + return pm.run(module); + } +}; + +} // namespace + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPositiveControl) { + programBuilder.initialize( + {programBuilder.getI1Type(), programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + + auto [q1S1, q0S1] = programBuilder.cx(q1S0, q0S0); + auto [q0S2, q1S2] = programBuilder.ch(q0S1, q1S1); + auto [q0S3, q1S3] = programBuilder.cx(q0S2, q1S2); + + auto [q0S4, c0] = programBuilder.measure(q0S3); + auto [q1S4, c1] = programBuilder.measure(q1S3); + + programBuilder.sink(q0S4); + programBuilder.sink(q1S4); + module = programBuilder.finalize({c0, c1}); + + referenceBuilder.initialize( + {referenceBuilder.getI1Type(), referenceBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + + auto [r1S1, r0S1] = referenceBuilder.cx(r1S0, r0S0); + auto [r0S2, cr0] = referenceBuilder.measure(r0S1); + auto [r0S3, r1S2] = referenceBuilder.ch(r0S2, r1S1); + auto [r0S4, r1S3] = referenceBuilder.cx(r0S3, r1S2); + + auto [r1S4, cr1] = referenceBuilder.measure(r1S3); + + referenceBuilder.sink(r0S4); + referenceBuilder.sink(r1S4); + reference = referenceBuilder.finalize({cr0, cr1}); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverOneOfMultipleControls) { + programBuilder.initialize({programBuilder.getI1Type(), + programBuilder.getI1Type(), + programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + auto q2S0 = programBuilder.allocQubit(); + + auto [q12_0, q0S1] = + programBuilder.ctrl({q1S0, q2S0}, {q0S0}, [&](const ValueRange target) { + return SmallVector{programBuilder.x(target[0])}; + }); + auto [q12_1, q0S2] = programBuilder.ctrl( + {q12_0[1], q12_0[0]}, q0S1, [&](const ValueRange target) { + return SmallVector{programBuilder.h(target[0])}; + }); + auto [q12_2, q0S3] = programBuilder.ctrl( + {q12_1[1], q12_1[0]}, q0S2, [&](const ValueRange target) { + return SmallVector{programBuilder.x(target[0])}; + }); + + auto [q1S4, c1] = programBuilder.measure(q12_2[0]); + + auto q0S4 = programBuilder.h(q0S3[0]); + auto q2S4 = programBuilder.h(q12_2[1]); + + auto [q0S5, c0] = programBuilder.measure(q0S4); + auto [q2S5, c2] = programBuilder.measure(q2S4); + + programBuilder.sink(q0S5); + programBuilder.sink(q1S4); + programBuilder.sink(q2S5); + + module = programBuilder.finalize({c0, c1, c2}); + + referenceBuilder.initialize({referenceBuilder.getI1Type(), + referenceBuilder.getI1Type(), + referenceBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + auto r2S0 = referenceBuilder.allocQubit(); + + auto [r1S1, cr1] = referenceBuilder.measure(r1S0); + + auto [r12_0, r0S1] = + referenceBuilder.ctrl({r1S1, r2S0}, {r0S0}, [&](const ValueRange target) { + return SmallVector{referenceBuilder.x(target[0])}; + }); + auto [r12_1, r0S2] = referenceBuilder.ctrl( + {r12_0[1], r12_0[0]}, r0S1, [&](const ValueRange target) { + return SmallVector{referenceBuilder.h(target[0])}; + }); + auto [r12_2, r0S3] = referenceBuilder.ctrl( + {r12_1[1], r12_1[0]}, r0S2, [&](const ValueRange target) { + return SmallVector{referenceBuilder.x(target[0])}; + }); + + auto r0S4 = referenceBuilder.h(r0S3[0]); + auto r2S4 = referenceBuilder.h(r12_2[1]); + + auto [r0S5, cr0] = referenceBuilder.measure(r0S4); + auto [r2S5, cr2] = referenceBuilder.measure(r2S4); + + referenceBuilder.sink(r0S5); + referenceBuilder.sink(r12_2[0]); + referenceBuilder.sink(r2S5); + + reference = referenceBuilder.finalize({cr0, cr1, cr2}); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, + liftMeasurementMultipleOverOneControlledGate) { + + programBuilder.initialize( + {programBuilder.getI1Type(), programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + auto q2S0 = programBuilder.allocQubit(); + + auto [q12_0, q0S1] = + programBuilder.ctrl({q1S0, q2S0}, {q0S0}, [&](const ValueRange target) { + return SmallVector{programBuilder.x(target[0])}; + }); + + auto [q1S1, c1] = programBuilder.measure(q12_0[0]); + auto [q2S1, c2] = programBuilder.measure(q12_0[1]); + + programBuilder.sink(q0S1[0]); + programBuilder.sink(q1S1); + programBuilder.sink(q2S1); + module = programBuilder.finalize({c1, c2}); + + referenceBuilder.initialize( + {referenceBuilder.getI1Type(), referenceBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + auto r2S0 = referenceBuilder.allocQubit(); + + auto [r1S1, cr1] = referenceBuilder.measure(r1S0); + auto [r2S1, cr2] = referenceBuilder.measure(r2S0); + + auto [r12_0, r0S1] = + referenceBuilder.ctrl({r1S1, r2S1}, {r0S0}, [&](const ValueRange target) { + return SmallVector{referenceBuilder.x(target[0])}; + }); + + referenceBuilder.sink(r0S1[0]); + referenceBuilder.sink(r12_0[0]); + referenceBuilder.sink(r12_0[1]); + reference = referenceBuilder.finalize({cr1, cr2}); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, + liftMeasurementOverControlledParametrizedGate) { + programBuilder.initialize( + {programBuilder.getI1Type(), programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + + auto [q0S1, q1S1] = programBuilder.crx(std::numbers::pi / 2, q0S0, q1S0); + + auto [q0S2, c0] = programBuilder.measure(q0S1); + auto [q1S2, c1] = programBuilder.measure(q1S1); + + programBuilder.sink(q0S2); + programBuilder.sink(q1S2); + module = programBuilder.finalize({c0, c1}); + + referenceBuilder.initialize( + {referenceBuilder.getI1Type(), referenceBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); + + auto [r0S2, r1S1] = referenceBuilder.crx(std::numbers::pi / 2, r0S1, r1S0); + + auto [r1S2, cr1] = referenceBuilder.measure(r1S1); + + referenceBuilder.sink(r0S2); + referenceBuilder.sink(r1S2); + reference = referenceBuilder.finalize({cr0, cr1}); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleX) { + + programBuilder.initialize({programBuilder.getI1Type()}); + auto q0 = programBuilder.allocQubit(); + auto q1 = programBuilder.x(q0); + auto [q2, c] = programBuilder.measure(q1); + programBuilder.sink(q2); + module = programBuilder.finalize(c); + + referenceBuilder.initialize({referenceBuilder.getI1Type()}); + auto r0 = referenceBuilder.allocQubit(); + auto trueConstant = referenceBuilder.boolConstant(true); + auto [r1, cr] = referenceBuilder.measure(r0); + + auto xorOp = arith::XOrIOp::create( + referenceBuilder, referenceBuilder.getLoc(), cr, trueConstant); + referenceBuilder.sink(r1); + reference = referenceBuilder.finalize(xorOp.getResult()); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleY) { + programBuilder.initialize({programBuilder.getI1Type()}); + auto q0 = programBuilder.allocQubit(); + auto q1 = programBuilder.y(q0); + auto [q2, c] = programBuilder.measure(q1); + programBuilder.sink(q2); + module = programBuilder.finalize({c}); + + referenceBuilder.initialize({referenceBuilder.getI1Type()}); + auto r0 = referenceBuilder.allocQubit(); + auto trueConstant = referenceBuilder.boolConstant(true); + auto [r1, cr] = referenceBuilder.measure(r0); + auto xorOp = arith::XOrIOp::create( + referenceBuilder, referenceBuilder.getLoc(), cr, trueConstant); + referenceBuilder.sink(r1); + reference = referenceBuilder.finalize({xorOp.getResult()}); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPhaseGates) { + programBuilder.initialize({programBuilder.getI1Type()}); + auto q0 = programBuilder.allocQubit(); + auto q1 = programBuilder.id(q0); + auto q2 = programBuilder.z(q1); + auto q3 = programBuilder.s(q2); + auto q4 = programBuilder.sdg(q3); + auto q5 = programBuilder.t(q4); + auto q6 = programBuilder.tdg(q5); + auto q7 = programBuilder.p(std::numbers::pi / 2, q6); + auto q8 = programBuilder.rz(std::numbers::pi / 2, q7); + auto [q9, c] = programBuilder.measure(q8); + programBuilder.sink(q9); + module = programBuilder.finalize({c}); + + referenceBuilder.initialize({referenceBuilder.getI1Type()}); + auto r0 = referenceBuilder.allocQubit(); + auto [r1, cr] = referenceBuilder.measure(r0); + referenceBuilder.sink(r1); + reference = referenceBuilder.finalize({cr}); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverMultipleXY) { + programBuilder.initialize({programBuilder.getI1Type()}); + auto q0 = programBuilder.allocQubit(); + auto q1 = programBuilder.x(q0); + auto q2 = programBuilder.y(q1); + auto [q3, c] = programBuilder.measure(q2); + programBuilder.sink(q3); + module = programBuilder.finalize({c}); + + referenceBuilder.initialize({referenceBuilder.getI1Type()}); + auto r0 = referenceBuilder.allocQubit(); + auto [r1, cr] = referenceBuilder.measure(r0); + referenceBuilder.sink(r1); + reference = referenceBuilder.finalize({cr}); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverXAndControlledGates) { + programBuilder.initialize({programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + + auto [q0S1, q1S1] = programBuilder.cy(q0S0, q1S0); + auto q0S2 = programBuilder.x(q0S1); + auto [q0S3, q1S2] = programBuilder.cy(q0S2, q1S1); + auto q0S4 = programBuilder.x(q0S3); + + auto [q0S5, c0] = programBuilder.measure(q0S4); + + programBuilder.sink(q0S5); + programBuilder.sink(q1S2); + module = programBuilder.finalize({c0}); + + referenceBuilder.initialize({referenceBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); + + auto [r0S2, r1S1] = referenceBuilder.cx(r0S1, r1S0); + auto r0S3 = referenceBuilder.x(r0S2); + auto [r0S4, r1S2] = referenceBuilder.cx(r0S3, r1S1); + + referenceBuilder.sink(r0S4); + referenceBuilder.sink(r1S2); + reference = referenceBuilder.finalize({cr0}); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverDiagonalGateInControl) { + programBuilder.initialize({programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + + auto [q0S1, q1S1] = programBuilder.cz(q0S0, q1S0); + + auto [q0S2, c0] = programBuilder.measure(q0S1); + + programBuilder.sink(q0S2); + programBuilder.sink(q1S1); + module = programBuilder.finalize({c0}); + + referenceBuilder.initialize({referenceBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); + + referenceBuilder.sink(r0S1); + referenceBuilder.sink(r1S0); + reference = referenceBuilder.finalize({cr0}); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +}