From 7436f34d6073054792dce8ffd9a730ec3e218497 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Wed, 13 May 2026 16:25:30 +0200 Subject: [PATCH 01/21] feat(mlir): :sparkles: add pass and patterns for measurement lifting --- .../mlir/Dialect/QCO/Transforms/Passes.td | 34 +++ .../Optimizations/MeasurementLifting.cpp | 236 ++++++++++++++++++ 2 files changed, 270 insertions(+) create mode 100644 mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp diff --git a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td index 54fccffdc7..c697b7dad4 100644 --- a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td @@ -158,4 +158,38 @@ def HadamardLifting : Pass<"hadamard-lifting", "mlir::ModuleOp"> { }]; } +def MeasurementLifting : Pass<"measurement-lifting", "mlir::ModuleOp"> { + let dependentDialects = ["mlir::qco::QCODialect", + "::mlir::arith::ArithDialect", + ]; + let summary = "This pass attempts to move measurements as far up as" + "possible, shiftling them above gates that commute with them." + "This is done to enable qubit reuse and other optimizations."; + let description = [{ + This pass lifts measurements gates away from the measurements in order to apply measurement lifting more effectively. + Measurement lifting is a subroutine of the qubit reuse routine. The goal is to measure qubits earlier in the + circuit to reuse them and to potentially remove some quantum gates. + + Measurement lifting uses the following commutation rules: + ┌──────┐ ┌──────┐ + ──■──┤ Meas │────── ─┤ Meas ├──■─── + │ └──────┘ └──────┘ │ + ┌─┴─┐ => ┌─┴─┐ + ┤ U ├──────────────── ─────────┤ U ├─ + └───┘ └───┘ + (Where U is any (controlled) unitary gate) + + ┌───┐┌──────┐ ┌──────┐┌───┐ + ┤ P ├┤ Meas ├ => ┤ Meas ├┤ P ├ + └───┘└──────┘ └──────┘└───┘ + (Where P is any diagonal gate, e.g., `z`, `s`, ...) + + ┌───┐┌──────┐ ┌───────┐┌───┐ + ┤ X ├┤ Meas ├ => ┤ Meas* ├┤ X ├ + └───┘└──────┘ └───────┘└───┘ + (Where Meas* is a measurement after which the outcome is classically negated) + + }]; +} + #endif // MLIR_DIALECT_QCO_TRANSFORMS_PASSES_TD diff --git a/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp b/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp new file mode 100644 index 0000000000..02589d814d --- /dev/null +++ b/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp @@ -0,0 +1,236 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +// +// Created by damian on 5/13/26. +// + +#include "mlir/Dialect/QCO/IR/QCOInterfaces.h" +#include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace mlir::qco { + +#define GEN_PASS_DEF_MEASUREMENTLIFTING +#include "mlir/Dialect/QCO/Transforms/Passes.h.inc" + +namespace { + +/** + * @brief Checks if the given operation is an inverting gate. + * @param op The operation to check. + * @return True if the operation is an inverting gate, false otherwise. + */ +bool isInverting(Operation* op) { return isa(op); } + +/** + * @brief Checks if the given operation is a diagonal gate. + * @param op The operation to check. + * @return True if the operation is a diagonal gate, false otherwise. + */ +bool isDiagonal(Operation* op) { return isa(op); } + +/** + * @brief This method swaps a gate with a measurement. + * @param gate The gate to swap. + * @param measurement The measurement to swap. + * @param rewriter The used rewriter. + */ +void swapGateWithMeasurement(UnitaryOpInterface gate, MeasureOp measurement, + mlir::PatternRewriter& rewriter) { + auto measurementInput = measurement.getQubitIn(); + auto gateInput = gate.getInputForOutput(measurementInput); + rewriter.replaceUsesWithIf(measurementInput, gateInput, + [&](mlir::OpOperand& operand) { + // We only replace the single use by the + // measure op + return operand.getOwner() == measurement; + }); + rewriter.replaceUsesWithIf(gateInput, measurement.getQubitOut(), + [&](mlir::OpOperand& operand) { + // We only replace the single use by the + // predecessor + return operand.getOwner() == gate; + }); + rewriter.replaceUsesWithIf(measurement.getQubitOut(), measurementInput, + [&](mlir::OpOperand& operand) { + // All further uses of the measurement output now + // use the gate output + return operand.getOwner() != gate; + }); + rewriter.moveOpBefore(measurement, gate); +} + +/** + * @brief This pattern is responsible for lifting measurements above any phase + * gates. + */ +struct LiftMeasurementsAbovePhaseGatesPattern final + : mlir::OpRewritePattern { + + explicit LiftMeasurementsAbovePhaseGatesPattern(mlir::MLIRContext* context) + : OpRewritePattern(context) {} + + mlir::LogicalResult + matchAndRewrite(MeasureOp op, + mlir::PatternRewriter& rewriter) const override { + const auto qubitVariable = op.getQubitIn(); + auto* predecessor = qubitVariable.getDefiningOp(); + + auto predecessorUnitary = mlir::dyn_cast(predecessor); + + if (!predecessorUnitary) { + return mlir::failure(); + } + + if (isDiagonal(predecessor)) { + swapGateWithMeasurement(predecessorUnitary, op, rewriter); + return mlir::success(); + } + + return mlir::failure(); + } +}; + +/** + * @brief This pattern is responsible for lifting measurements above any + * non-phase gates. + */ +struct LiftMeasurementsAboveInvertingGatesPattern final + : mlir::OpRewritePattern { + + explicit LiftMeasurementsAboveInvertingGatesPattern( + mlir::MLIRContext* context) + : OpRewritePattern(context) {} + + /** + * @brief Checks if the given qubit is not used anymore. + * @param outQubit The output qubit to check. + * @return True if all users are resets/deallocs, false otherwise. + */ + static bool outputQubitRemainsUnused(mlir::Value outQubit) { + return llvm::all_of(outQubit.getUsers(), [](mlir::Operation* user) { + return mlir::isa(user) || mlir::isa(user); + }); + } + + mlir::LogicalResult + matchAndRewrite(MeasureOp op, + mlir::PatternRewriter& rewriter) const override { + if (!outputQubitRemainsUnused(op.getQubitOut())) { + return mlir::failure(); // if the qubit is still used after the + // measurement, we cannot lift it above the gate. + } + const auto qubitVariable = op.getQubitIn(); + auto* predecessor = qubitVariable.getDefiningOp(); + + auto predecessorUnitary = mlir::dyn_cast(predecessor); + + if (!predecessorUnitary) { + return mlir::failure(); + } + + if (isInverting(predecessor) && + predecessorUnitary.getInputQubits().size() == 1) { + swapGateWithMeasurement(predecessorUnitary, op, rewriter); + rewriter.setInsertionPointAfter(op); + const mlir::Value trueConstant = rewriter.create( + op.getLoc(), rewriter.getBoolAttr(true)); + auto inversion = rewriter.create( + op.getLoc(), op.getResult(), trueConstant); + // We need `replaceUsesWithIf` so that we can replace all uses except for + // the one use that defines the inverted bit. + rewriter.replaceUsesWithIf(op.getResult(), inversion.getResult(), + [&](mlir::OpOperand& operand) { + return operand.getOwner() != inversion; + }); + return mlir::success(); + } + + return mlir::failure(); + } +}; + +/** + * @brief This pattern is responsible for applying the "deferred measurement + * principle", lifting measurements above controls. + */ +struct LiftMeasurementsAboveControlsPattern final + : mlir::OpRewritePattern { + + explicit LiftMeasurementsAboveControlsPattern(mlir::MLIRContext* context) + : OpRewritePattern(context) {} + + mlir::LogicalResult + matchAndRewrite(MeasureOp op, + mlir::PatternRewriter& rewriter) const override { + const auto qubitVariable = op.getQubitIn(); + auto* predecessor = qubitVariable.getDefiningOp(); + auto predecessorUnitary = mlir::dyn_cast(predecessor); + + if (!predecessorUnitary) { + return mlir::failure(); + } + + if (llvm::find(predecessorUnitary.getOutputQubits(), qubitVariable) != + predecessorUnitary.getOutputQubits().end()) { + // The measured qubit is a target, not a control of the gate. + return mlir::failure(); + } + + swapGateWithMeasurement(predecessorUnitary, op, rewriter); + + return mlir::success(); + } +}; + +/** + * @brief Pass raises Measurements above controlled and uncontrolled gates + * gates. + */ +struct MeasurementLifting final + : impl::MeasurementLiftingBase { + using MeasurementLiftingBase::MeasurementLiftingBase; + +protected: + void runOnOperation() override { + const auto op = getOperation(); + auto* ctx = &getContext(); + + // Define the set of patterns to use. + RewritePatternSet patterns(ctx); + patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); + patterns.add(patterns.getContext()); + + // Apply patterns in an iterative and greedy manner. + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace mlir::qco From b4d8e0ea1a9fdabb94b6e1e0cbcf2c0eadba511b Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Mon, 1 Jun 2026 13:51:43 +0200 Subject: [PATCH 02/21] test(mlir/hybrid-opt): :construction: set up tests for measurement lifting --- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 30 ++ .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 13 +- .../Transforms/Optimizations/CMakeLists.txt | 2 +- .../test_qco_measurement_lifting.cpp | 428 ++++++++++++++++++ 4 files changed, 469 insertions(+), 4 deletions(-) create mode 100644 mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 772ea1eba2..600e1c1ffb 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -105,6 +105,21 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ Value intConstant(int64_t value); + /** + * @brief Create a constant boolean value + * @param value The value to store in the constant + * @return The value produced by the constant operation + * + * @par Example: + * ```c++ + * auto c = builder.boolConstant(true); + * ``` + * ```mlir + * %c = arith.constant 1 : i1 + * ``` + */ + Value boolConstant(bool value); + //===--------------------------------------------------------------------===// // Memory Management //===--------------------------------------------------------------------===// @@ -1375,6 +1390,21 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ OwningOpRef finalize(); + /** + * @brief Finalize the program with a given exit code and return the + * constructed module + * @param exitCode Value representing the exit code to return + * + * @details + * Automatically deallocates all remaining valid qubits and tensors of qubits, + * adds a return statement with a given exit code, + * and transfers ownership of the module to the caller. The builder should not + * be used after calling this method. + * + * @return OwningOpRef containing the constructed quantum program module + */ + OwningOpRef finalize(Value exitCode); + /** * @brief Convenience method for building quantum programs * @param context The MLIR context to use for building the program diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index a07c52aa0f..5d9b243352 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -74,6 +74,11 @@ Value QCOProgramBuilder::intConstant(const int64_t value) { return arith::ConstantOp::create(*this, getI64IntegerAttr(value)).getResult(); } +Value QCOProgramBuilder::boolConstant(const bool value) { + checkFinalized(); + return arith::ConstantOp::create(*this, getBoolAttr(value)).getResult(); +} + Value& QCOProgramBuilder::QubitRegister::operator[](const size_t index) { if (index >= qubits.size()) { llvm::reportFatalUsageError("Qubit index out of bounds"); @@ -1096,6 +1101,11 @@ void QCOProgramBuilder::ensureAllocationMode( } OwningOpRef QCOProgramBuilder::finalize() { + auto exitCode = intConstant(0); + return finalize(exitCode); +} + +OwningOpRef QCOProgramBuilder::finalize(Value exitCode) { checkFinalized(); // Ensure that main function exists and insertion point is valid @@ -1146,9 +1156,6 @@ OwningOpRef QCOProgramBuilder::finalize() { validQubits.clear(); validTensors.clear(); - // Create constant 0 for successful exit code - auto exitCode = intConstant(0); - // Add return statement with exit code 0 to the main function func::ReturnOp::create(*this, exitCode); diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt index b785e5a400..9e49278edb 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt @@ -7,7 +7,7 @@ # Licensed under the MIT License set(target_name mqt-core-mlir-unittest-optimizations) -add_executable(${target_name} test_qco_hadamard_lifting.cpp +add_executable(${target_name} test_qco_hadamard_lifting.cpp test_qco_measurement_lifting.cpp test_qco_merge_single_qubit_rotation.cpp) target_link_libraries( diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp new file mode 100644 index 0000000000..f0373ebfd5 --- /dev/null +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp @@ -0,0 +1,428 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +// +// Created by damian on 5/21/26. +// + +#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" +#include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" +#include "mlir/Support/IRVerification.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace { + +using namespace mlir; +using namespace mlir::qco; + +class QCOMeasurementLiftingTest : public testing::Test { + +protected: + MLIRContext context; + QCOProgramBuilder programBuilder; + QCOProgramBuilder referenceBuilder; + OwningOpRef module; + OwningOpRef reference; + + QCOMeasurementLiftingTest() + : programBuilder(&context), referenceBuilder(&context) {} + + void SetUp() override { + // Register all necessary dialects + DialectRegistry registry; + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + + programBuilder.initialize(); + referenceBuilder.initialize(); + } + + /** + * @brief Adds the measurementLiftingPass to the current context and runs it. + */ + static LogicalResult runMeasurementLiftingPass(ModuleOp module) { + PassManager pm(module.getContext()); + pm.addPass(createMeasurementLifting()); + pm.addPass(createCanonicalizerPass()); + return pm.run(module); + } + + /** + * @brief Adds the canonicalizerPass to the current context and runs it. + */ + static LogicalResult runCanonicalizerPass(ModuleOp module) { + PassManager pm(module.getContext()); + pm.addPass(createCanonicalizerPass()); + return pm.run(module); + } + + static Value i1ToI64(Value i1Value, ImplicitLocOpBuilder& builder) { + return arith::ExtUIOp::create(builder, builder.getI64Type(), i1Value) + .getResult(); + } +}; + +} // namespace + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPositiveControl) { + auto q0_0 = programBuilder.allocQubit(); + auto q1_0 = programBuilder.allocQubit(); + + auto [q1_1, q0_1] = programBuilder.cx(q1_0, q0_0); + auto [q0_2, q1_2] = programBuilder.ch(q0_1, q1_1); + auto [q0_3, q1_3] = programBuilder.cx(q0_2, q1_2); + + auto [q0_4, c0] = programBuilder.measure(q0_3); + auto [q1_4, c1] = programBuilder.measure(q1_3); + + programBuilder.sink(q0_4); + programBuilder.sink(q1_4); + module = programBuilder.finalize(); + + auto r0_0 = referenceBuilder.allocQubit(); + auto r1_0 = referenceBuilder.allocQubit(); + + auto [r1_1, r0_1] = referenceBuilder.cx(r1_0, r0_0); + auto [r0_2, cr0] = referenceBuilder.measure(r0_1); + auto [r0_3, r1_2] = referenceBuilder.ch(r0_2, r1_1); + auto [r0_4, r1_3] = referenceBuilder.cx(r0_3, r1_2); + + auto [r1_4, cr1] = referenceBuilder.measure(r1_3); + + referenceBuilder.sink(r0_4); + referenceBuilder.sink(r1_4); + reference = referenceBuilder.finalize(); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverOneOfMultipleControls) { + auto q0_0 = programBuilder.allocQubit(); + auto q1_0 = programBuilder.allocQubit(); + auto q2_0 = programBuilder.allocQubit(); + + auto [q12_0, q0_1] = + programBuilder.ctrl({q1_0, q2_0}, {q0_0}, [&](const ValueRange target) { + return SmallVector{programBuilder.x(target[0])}; + }); + auto [q12_1, q0_2] = programBuilder.ctrl( + {q12_0[1], q12_0[0]}, q0_1, [&](const ValueRange target) { + return SmallVector{programBuilder.h(target[0])}; + }); + auto [q12_2, q0_3] = programBuilder.ctrl( + {q12_1[1], q12_1[0]}, q0_2, [&](const ValueRange target) { + return SmallVector{programBuilder.x(target[0])}; + }); + + auto [q1_4, c1] = programBuilder.measure(q12_2[0]); + + auto q0_4 = programBuilder.h(q0_3[0]); + auto q2_4 = programBuilder.h(q12_2[1]); + + auto [q0_5, c0] = programBuilder.measure(q0_4); + auto [q2_5, c2] = programBuilder.measure(q2_4); + + programBuilder.sink(q0_5); + programBuilder.sink(q1_4); + programBuilder.sink(q2_5); + + module = programBuilder.finalize(); + + auto r0_0 = referenceBuilder.allocQubit(); + auto r1_0 = referenceBuilder.allocQubit(); + auto r2_0 = referenceBuilder.allocQubit(); + + auto [r1_1, cr1] = referenceBuilder.measure(r1_0); + + auto [r12_0, r0_1] = + referenceBuilder.ctrl({r1_1, r2_0}, {r0_0}, [&](const ValueRange target) { + return SmallVector{referenceBuilder.x(target[0])}; + }); + auto [r12_1, r0_2] = referenceBuilder.ctrl( + {r12_0[1], r12_0[0]}, r0_1, [&](const ValueRange target) { + return SmallVector{referenceBuilder.h(target[0])}; + }); + auto [r12_2, r0_3] = referenceBuilder.ctrl( + {r12_1[1], r12_1[0]}, r0_2, [&](const ValueRange target) { + return SmallVector{referenceBuilder.x(target[0])}; + }); + + auto r0_4 = referenceBuilder.h(r0_3[0]); + auto r2_4 = referenceBuilder.h(r12_2[1]); + + auto [r0_5, cr0] = referenceBuilder.measure(r0_4); + auto [r2_5, cr2] = referenceBuilder.measure(r2_4); + + referenceBuilder.sink(r0_5); + referenceBuilder.sink(r12_2[0]); + referenceBuilder.sink(r2_5); + + reference = referenceBuilder.finalize(); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, + liftMeasurementMultipleOverOneControlledGate) { + auto q0_0 = programBuilder.allocQubit(); + auto q1_0 = programBuilder.allocQubit(); + auto q2_0 = programBuilder.allocQubit(); + + auto [q12_0, q0_1] = + programBuilder.ctrl({q1_0, q2_0}, {q0_0}, [&](const ValueRange target) { + return SmallVector{programBuilder.x(target[0])}; + }); + + auto [q1_1, c1] = programBuilder.measure(q12_0[0]); + auto [q2_1, c2] = programBuilder.measure(q12_0[1]); + + programBuilder.sink(q0_1[0]); + programBuilder.sink(q1_1); + programBuilder.sink(q2_1); + module = programBuilder.finalize(); + + auto r0_0 = referenceBuilder.allocQubit(); + auto r1_0 = referenceBuilder.allocQubit(); + auto r2_0 = referenceBuilder.allocQubit(); + + auto [r1_1, cr1] = programBuilder.measure(r1_0); + auto [r2_1, cr2] = programBuilder.measure(r2_0); + + auto [r12_0, r0_1] = + programBuilder.ctrl({r1_1, r2_1}, {r0_0}, [&](const ValueRange target) { + return SmallVector{programBuilder.x(target[0])}; + }); + + referenceBuilder.sink(r0_1[0]); + referenceBuilder.sink(r12_0[0]); + referenceBuilder.sink(r12_0[1]); + module = referenceBuilder.finalize(); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, + liftMeasurementOverControlledParametrizedGate) { + auto q0_0 = programBuilder.allocQubit(); + auto q1_0 = programBuilder.allocQubit(); + + auto [q0_1, q1_1] = programBuilder.crx(std::numbers::pi / 2, q0_0, q1_0); + + auto [q0_2, c0] = programBuilder.measure(q0_1); + auto [q1_2, c1] = programBuilder.measure(q1_1); + + programBuilder.sink(q0_2); + programBuilder.sink(q1_2); + module = programBuilder.finalize(); + + auto r0_0 = referenceBuilder.allocQubit(); + auto r1_0 = referenceBuilder.allocQubit(); + + auto [r0_1, cr0] = referenceBuilder.measure(r0_0); + + auto [r0_2, r1_1] = referenceBuilder.crx(std::numbers::pi / 2, r0_1, r1_0); + + auto [r1_2, cr1] = referenceBuilder.measure(r1_1); + + referenceBuilder.sink(r0_2); + referenceBuilder.sink(r1_2); + reference = referenceBuilder.finalize(); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleX) { + auto q_0 = programBuilder.allocQubit(); + auto q_1 = programBuilder.x(q_0); + auto [q_2, c] = programBuilder.measure(q_1); + programBuilder.sink(q_2); + module = programBuilder.finalize(i1ToI64(c, programBuilder)); + + auto r_0 = referenceBuilder.allocQubit(); + auto true_constant = referenceBuilder.boolConstant(true); + auto [r_1, cr] = referenceBuilder.measure(r_0); + + auto xorOp = arith::XOrIOp::create( + referenceBuilder, referenceBuilder.getLoc(), cr, true_constant); + referenceBuilder.sink(r_1); + reference = + referenceBuilder.finalize(i1ToI64(xorOp.getResult(), referenceBuilder)); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + reference.get().dump(); + module.get().dump(); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleY) { + auto q_0 = programBuilder.allocQubit(); + auto q_1 = programBuilder.y(q_0); + auto [q_2, c] = programBuilder.measure(q_1); + programBuilder.sink(q_2); + module = programBuilder.finalize(); + + auto r_0 = referenceBuilder.allocQubit(); + auto true_constant = referenceBuilder.boolConstant(true); + auto [r_1, cr] = referenceBuilder.measure(r_0); + referenceBuilder.insert(arith::XOrIOp::create( + referenceBuilder, referenceBuilder.getLoc(), cr, true_constant)); + referenceBuilder.sink(r_1); + reference = referenceBuilder.finalize(); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPhaseGates) { + auto q_0 = programBuilder.allocQubit(); + auto q_1 = programBuilder.id(q_0); + auto q_2 = programBuilder.z(q_1); + auto q_3 = programBuilder.s(q_2); + auto q_4 = programBuilder.sdg(q_3); + auto q_5 = programBuilder.t(q_4); + auto q_6 = programBuilder.tdg(q_5); + auto q_7 = programBuilder.p(std::numbers::pi / 2, q_6); + auto q_8 = programBuilder.rz(std::numbers::pi / 2, q_7); + auto [q_9, c] = programBuilder.measure(q_8); + programBuilder.sink(q_9); + module = programBuilder.finalize(); + + auto r_0 = referenceBuilder.allocQubit(); + auto [r_1, cr] = referenceBuilder.measure(r_0); + referenceBuilder.sink(r_1); + reference = referenceBuilder.finalize(); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverMultipleXY) { + auto q_0 = programBuilder.allocQubit(); + auto q_1 = programBuilder.x(q_0); + auto q_2 = programBuilder.y(q_1); + auto [q_3, c] = programBuilder.measure(q_2); + programBuilder.sink(q_3); + module = programBuilder.finalize(); + + auto r_0 = referenceBuilder.allocQubit(); + auto [r_1, cr] = referenceBuilder.measure(r_0); + referenceBuilder.sink(r_1); + reference = referenceBuilder.finalize(); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverXAndControlledGates) { + auto q0_0 = programBuilder.allocQubit(); + auto q1_0 = programBuilder.allocQubit(); + + auto [q0_1, q1_1] = programBuilder.cy(q0_0, q1_0); + auto q0_2 = programBuilder.x(q0_1); + auto [q0_3, q1_2] = programBuilder.cy(q0_2, q1_1); + auto q0_4 = programBuilder.x(q0_3); + + auto [q0_5, c0] = programBuilder.measure(q0_4); + + programBuilder.sink(q0_5); + programBuilder.sink(q1_2); + module = programBuilder.finalize(); + + auto r0_0 = referenceBuilder.allocQubit(); + auto r1_0 = referenceBuilder.allocQubit(); + + auto [r0_1, cr0] = referenceBuilder.measure(r0_0); + + auto [r0_2, r1_1] = referenceBuilder.cx(r0_1, r1_0); + auto r0_3 = referenceBuilder.x(r0_2); + auto [r0_4, r1_2] = referenceBuilder.cx(r0_3, r1_1); + + referenceBuilder.sink(r0_4); + referenceBuilder.sink(r1_2); + reference = referenceBuilder.finalize(); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverDiagonalGateInControl) { + auto q0_0 = programBuilder.allocQubit(); + auto q1_0 = programBuilder.allocQubit(); + + auto [q0_1, q1_1] = programBuilder.cz(q0_0, q1_0); + + auto [q0_2, c0] = programBuilder.measure(q0_1); + + programBuilder.sink(q0_2); + programBuilder.sink(q1_1); + module = programBuilder.finalize(); + + auto r0_0 = referenceBuilder.allocQubit(); + auto r1_0 = referenceBuilder.allocQubit(); + + auto [r0_1, cr0] = referenceBuilder.measure(r0_0); + + referenceBuilder.sink(r0_1); + referenceBuilder.sink(r1_0); + reference = referenceBuilder.finalize(); + + ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} From 91b3ef949edc12e84f376ace12a3a2ee5b2c64ab Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Mon, 1 Jun 2026 14:22:26 +0200 Subject: [PATCH 03/21] feat(mlir): :sparkles: implement dead gate elimination canonicalization patterns --- .../include/mlir/Dialect/QCO/IR/QCODialect.td | 2 ++ mlir/include/mlir/Dialect/QCO/IR/QCOOps.td | 1 + mlir/include/mlir/Dialect/QCO/QCOUtils.h | 18 +++++++++++++++ .../Dialect/QCO/IR/Operations/MeasureOp.cpp | 22 +++++++++++++++++++ .../lib/Dialect/QCO/IR/Operations/ResetOp.cpp | 14 ++++++++++++ mlir/lib/Dialect/QCO/IR/QCOOps.cpp | 20 +++++++++++++++++ mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp | 14 ++++++++++++ 7 files changed, 91 insertions(+) diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCODialect.td b/mlir/include/mlir/Dialect/QCO/IR/QCODialect.td index 62d1f5bba4..e7d98a0367 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCODialect.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCODialect.td @@ -38,6 +38,8 @@ def QCODialect : Dialect { let cppNamespace = "::mlir::qco"; let useDefaultTypePrinterParser = 1; + + let hasCanonicalizer = 1; } #endif // MLIR_DIALECT_QCO_IR_QCODIALECT_TD diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index a5bbfb7f51..04cda7df4a 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -136,6 +136,7 @@ def MeasureOp : QCOOp<"measure"> { }]>]; let hasVerifier = 1; + let hasCanonicalizer = 1; } def ResetOp : QCOOp<"reset", [Idempotent, SameOperandsAndResultType]> { diff --git a/mlir/include/mlir/Dialect/QCO/QCOUtils.h b/mlir/include/mlir/Dialect/QCO/QCOUtils.h index 489fceb00e..5a21b5091c 100644 --- a/mlir/include/mlir/Dialect/QCO/QCOUtils.h +++ b/mlir/include/mlir/Dialect/QCO/QCOUtils.h @@ -237,4 +237,22 @@ mergeTwoTargetOneParameterWithSwappedTargets(OpType op, return success(); } +/** + * @brief Check if given quantum operation is unused (i.e., only used by + * deallocations) and remove it if so. + * + * @param op The operation to check. + * @param rewriter The pattern rewriter. + * @return LogicalResult Success or failure of the removal. + */ +LogicalResult checkAndRemoveDeadGate(Operation* op, PatternRewriter& rewriter) { + if (std::all_of(op->getUsers().begin(), op->getUsers().end(), + [](Operation* user) { return isa(user); })) { + // If the operation is only used by deallocs, we can safely remove it. + rewriter.replaceOp(op, op->getOperands()); + return success(); + } + return failure(); +} + } // namespace mlir::qco diff --git a/mlir/lib/Dialect/QCO/IR/Operations/MeasureOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/MeasureOp.cpp index b76a2d0471..a71412a671 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/MeasureOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/MeasureOp.cpp @@ -9,12 +9,29 @@ */ #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/QCOUtils.h" #include using namespace mlir; using namespace mlir::qco; +namespace { + +/** + * @brief Remove dead measurements. + */ +struct DeadMeasurementRemoval final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MeasureOp op, + PatternRewriter& rewriter) const override { + return checkAndRemoveDeadGate(op, rewriter); + } +}; + +} // namespace + LogicalResult MeasureOp::verify() { const auto registerName = getRegisterName(); const auto registerSize = getRegisterSize(); @@ -37,3 +54,8 @@ LogicalResult MeasureOp::verify() { } return success(); } + +void MeasureOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +} diff --git a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp index 8832162354..462a28108d 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/ResetOp.cpp @@ -9,6 +9,7 @@ */ #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/QCOUtils.h" #include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include "mlir/Dialect/QTensor/IR/QTensorUtils.h" @@ -101,6 +102,18 @@ struct RemoveResetAfterExtract final : OpRewritePattern { } }; +/** + * @brief Remove dead resets. + */ +struct DeadResetRemoval final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ResetOp op, + PatternRewriter& rewriter) const override { + return checkAndRemoveDeadGate(op, rewriter); + } +}; + } // namespace OpFoldResult ResetOp::fold(FoldAdaptor /*adaptor*/) { @@ -114,4 +127,5 @@ OpFoldResult ResetOp::fold(FoldAdaptor /*adaptor*/) { void ResetOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.add(context); + results.add(context); } diff --git a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp index a3ce816081..b2d75a7cad 100644 --- a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp +++ b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" // IWYU pragma: associated +#include "mlir/Dialect/QCO/QCOUtils.h" #include #include @@ -30,6 +31,21 @@ using namespace mlir; using namespace mlir::qco; +//===----------------------------------------------------------------------===// +// Dialect-Level Canonicalizers +//===----------------------------------------------------------------------===// + +namespace { +struct DeadGateElimination + : public mlir::OpInterfaceRewritePattern { + + LogicalResult matchAndRewrite(UnitaryOpInterface op, + PatternRewriter& rewriter) const override { + return checkAndRemoveDeadGate(op.getOperation(), rewriter); + } +}; +} // namespace + //===----------------------------------------------------------------------===// // Custom Parsers //===----------------------------------------------------------------------===// @@ -258,6 +274,10 @@ void QCODialect::initialize() { >(); } +void QCODialect::getCanonicalizationPatterns(RewritePatternSet& results) const { + results.add(getContext()); +} + //===----------------------------------------------------------------------===// // Types //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp b/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp index 4ac2d98faa..bb401dd9f3 100644 --- a/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp @@ -9,6 +9,7 @@ */ #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/QCOUtils.h" #include #include @@ -234,11 +235,24 @@ struct ConditionPropagation : public OpRewritePattern { return success(changed); } }; + +/** + * @brief Remove dead resets. + */ +struct DeadIfRemoval final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IfOp op, + PatternRewriter& rewriter) const override { + return checkAndRemoveDeadGate(op, rewriter); + } +}; } // namespace void IfOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.add(context); + results.add(context); populateRegionBranchOpInterfaceCanonicalizationPatterns( results, IfOp::getOperationName()); } From ba57aa2f28d4df7a5f873314d2126572f1b22fc1 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Mon, 1 Jun 2026 15:15:54 +0200 Subject: [PATCH 04/21] fix(mlir): :bug: fix tests --- mlir/include/mlir/Dialect/QCO/QCOUtils.h | 15 ++++++++++++--- mlir/lib/Dialect/QCO/IR/QCOOps.cpp | 12 ++++++++++-- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/QCOUtils.h b/mlir/include/mlir/Dialect/QCO/QCOUtils.h index 5a21b5091c..8a46a7b72c 100644 --- a/mlir/include/mlir/Dialect/QCO/QCOUtils.h +++ b/mlir/include/mlir/Dialect/QCO/QCOUtils.h @@ -245,12 +245,21 @@ mergeTwoTargetOneParameterWithSwappedTargets(OpType op, * @param rewriter The pattern rewriter. * @return LogicalResult Success or failure of the removal. */ -LogicalResult checkAndRemoveDeadGate(Operation* op, PatternRewriter& rewriter) { +inline LogicalResult checkAndRemoveDeadGate(Operation* op, + PatternRewriter& rewriter) { if (std::all_of(op->getUsers().begin(), op->getUsers().end(), [](Operation* user) { return isa(user); })) { // If the operation is only used by deallocs, we can safely remove it. - rewriter.replaceOp(op, op->getOperands()); - return success(); + if (auto u = dyn_cast(op)) { + // We specifically have to replace the output *qubits* with the input + // *qubits* to ignore parameters. + rewriter.replaceOp(op, u.getInputQubits()); + return success(); + } else { + // This includes the `IfOp` as well as `Reset` and `Measure`. + rewriter.replaceOp(op, op->getOperands()); + return success(); + } } return failure(); } diff --git a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp index b2d75a7cad..61031d3258 100644 --- a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp +++ b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp @@ -36,11 +36,19 @@ using namespace mlir::qco; //===----------------------------------------------------------------------===// namespace { -struct DeadGateElimination - : public mlir::OpInterfaceRewritePattern { +struct DeadGateElimination final + : public OpInterfaceRewritePattern { + + explicit DeadGateElimination(MLIRContext* context) + : OpInterfaceRewritePattern(context) {} LogicalResult matchAndRewrite(UnitaryOpInterface op, PatternRewriter& rewriter) const override { + if (op->use_empty()) { + // This effectively ignores the GPhase operation and variants such as its + // inverse, which should never be considered dead. + return failure(); + } return checkAndRemoveDeadGate(op.getOperation(), rewriter); } }; From b83afab6110f9207b35b3a3cf67d515f9c085628 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Mon, 1 Jun 2026 15:26:59 +0200 Subject: [PATCH 05/21] test(mlir): :white_check_mark: add direct test for dead gate elimination --- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 413f29336d..1267ce553d 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -114,6 +114,39 @@ TEST_F(QCOTest, BuilderRejectsMixedStaticAndDynamicQubitAllocationModes) { "Cannot mix dynamic and static qubit allocation modes"); } +TEST_F(QCOTest, CheckDeadGateElimination) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto q0_0 = builder.allocQubit(); + auto q1_0 = builder.allocQubit(); + auto q0_1 = builder.h(q0_0); + auto [q0_2, q1_1] = builder.cx(q0_1, q1_0); + auto q1_2 = builder.h(q1_1); + builder.sink(q0_2); + builder.sink(q1_2); + auto module = builder.finalize(); + + QCOProgramBuilder reference(context.get()); + reference.initialize(); + auto r0 = reference.allocQubit(); + auto r1 = reference.allocQubit(); + reference.sink(r0); + reference.sink(r1); + auto ref = reference.finalize(); + + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(module.get()).succeeded()); + EXPECT_TRUE(verify(*module).succeeded()); + + ASSERT_TRUE(ref); + EXPECT_TRUE(verify(*ref).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(ref.get()).succeeded()); + EXPECT_TRUE(verify(*ref).succeeded()); + + EXPECT_TRUE(areModulesEquivalentWithPermutations(module.get(), ref.get())); +} + TEST_F(QCOTest, DirectIfBuilder) { // Test If construction directly QCOProgramBuilder builder(context.get()); From 4243ac0f78ac780cd7446b45f1411e72adc33893 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Mon, 1 Jun 2026 15:35:41 +0200 Subject: [PATCH 06/21] docs(mlir): :memo: update changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index dad0d3abf3..ed0b4dfbcc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel ### Added +- ✨ Add Dead Gate Elimination Pattern ([#1755]) ([**DRovara**]) - 🚸 Add [CMake presets] to provide a standardized and reproducible way to configure builds ([#1660]) ([**@denialhaag**]) - ✨ Add a `quantum-loop-unroll` pass for unrolling for-loop operations containing quantum operations ([#1718]) ([**@MatthiasReumann**]) - ✨ Add a `hadamard-lifting` pass for lifting Hadamard gates above Pauli gates ([#1605]) ([**@lirem101**], [**@burgholzer**]) @@ -402,6 +403,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool +[#1755]: https://github.com/munich-quantum-toolkit/core/pull/1755 [#1749]: https://github.com/munich-quantum-toolkit/core/pull/1749 [#1748]: https://github.com/munich-quantum-toolkit/core/pull/1748 [#1737]: https://github.com/munich-quantum-toolkit/core/pull/1737 From a50032b48137f23e43bc378d55afbc8795571540 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Mon, 1 Jun 2026 15:48:49 +0200 Subject: [PATCH 07/21] style(mlir): :rotating_light: fix linter issues --- mlir/lib/Dialect/QCO/IR/Operations/MeasureOp.cpp | 2 ++ mlir/lib/Dialect/QCO/IR/QCOOps.cpp | 3 +++ mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 14 +++++++------- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/QCO/IR/Operations/MeasureOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/MeasureOp.cpp index a71412a671..5d0462b6d7 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/MeasureOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/MeasureOp.cpp @@ -11,6 +11,8 @@ #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QCO/QCOUtils.h" +#include +#include #include using namespace mlir; diff --git a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp index 61031d3258..33474448d7 100644 --- a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp +++ b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp @@ -11,13 +11,16 @@ #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" // IWYU pragma: associated +#include "mlir/Dialect/QCO/IR/QCOInterfaces.h" #include "mlir/Dialect/QCO/QCOUtils.h" #include #include +#include #include #include #include +#include #include #include #include diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 1267ce553d..7c76b2a049 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -117,13 +117,13 @@ TEST_F(QCOTest, BuilderRejectsMixedStaticAndDynamicQubitAllocationModes) { TEST_F(QCOTest, CheckDeadGateElimination) { QCOProgramBuilder builder(context.get()); builder.initialize(); - auto q0_0 = builder.allocQubit(); - auto q1_0 = builder.allocQubit(); - auto q0_1 = builder.h(q0_0); - auto [q0_2, q1_1] = builder.cx(q0_1, q1_0); - auto q1_2 = builder.h(q1_1); - builder.sink(q0_2); - builder.sink(q1_2); + auto q0S0 = builder.allocQubit(); + auto q1S0 = builder.allocQubit(); + auto q0S1 = builder.h(q0S0); + auto [q0S2, q1S1] = builder.cx(q0S1, q1S0); + auto q1S2 = builder.h(q1S1); + builder.sink(q0S2); + builder.sink(q1S2); auto module = builder.finalize(); QCOProgramBuilder reference(context.get()); From d8c9ad5bfb3890caf957b9b4e8de206067435cf7 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Tue, 2 Jun 2026 09:37:45 +0200 Subject: [PATCH 08/21] fix(mlir): :recycle: guard RegionOp removal for child oeprations with memory effects --- mlir/include/mlir/Dialect/QCO/IR/QCOOps.td | 228 +++++++++--------- mlir/include/mlir/Dialect/QCO/QCOUtils.h | 12 +- mlir/lib/Dialect/QCO/IR/QCOOps.cpp | 9 +- mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp | 2 +- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 4 +- 5 files changed, 137 insertions(+), 118 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index 04cda7df4a..4b05cd1442 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -96,7 +96,7 @@ def StaticOp : QCOOp<"static", [Pure]> { // Measurement and Reset Operations //===----------------------------------------------------------------------===// -def MeasureOp : QCOOp<"measure"> { +def MeasureOp : QCOOp<"measure", [Pure]> { let summary = "Measure a qubit in the computational basis"; let description = [{ Measures a qubit in the computational (Z) basis, collapsing the state @@ -119,8 +119,7 @@ def MeasureOp : QCOOp<"measure"> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, OptionalAttr:$register_name, OptionalAttr>:$register_size, OptionalAttr>:$register_index); @@ -139,7 +138,7 @@ def MeasureOp : QCOOp<"measure"> { let hasCanonicalizer = 1; } -def ResetOp : QCOOp<"reset", [Idempotent, SameOperandsAndResultType]> { +def ResetOp : QCOOp<"reset", [Idempotent, SameOperandsAndResultType, Pure]> { let summary = "Reset a qubit to |0⟩ state"; let description = [{ Resets a qubit to the |0⟩ state, regardless of its current state, @@ -151,8 +150,7 @@ def ResetOp : QCOOp<"reset", [Idempotent, SameOperandsAndResultType]> { ``` }]; - let arguments = - (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -209,7 +207,8 @@ def GPhaseOp let hasCanonicalizer = 1; } -def IdOp : QCOOp<"id", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def IdOp + : QCOOp<"id", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply an Id gate to a qubit"; let description = [{ Applies an Id gate to a qubit and returns the transformed qubit. @@ -220,7 +219,7 @@ def IdOp : QCOOp<"id", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -233,7 +232,8 @@ def IdOp : QCOOp<"id", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def XOp : QCOOp<"x", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def XOp + : QCOOp<"x", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply an X gate to a qubit"; let description = [{ Applies an X gate to a qubit and returns the transformed qubit. @@ -244,7 +244,7 @@ def XOp : QCOOp<"x", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -257,7 +257,8 @@ def XOp : QCOOp<"x", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def YOp : QCOOp<"y", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def YOp + : QCOOp<"y", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply a Y gate to a qubit"; let description = [{ Applies a Y gate to a qubit and returns the transformed qubit. @@ -268,7 +269,7 @@ def YOp : QCOOp<"y", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -281,7 +282,8 @@ def YOp : QCOOp<"y", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def ZOp : QCOOp<"z", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def ZOp + : QCOOp<"z", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply a Z gate to a qubit"; let description = [{ Applies a Z gate to a qubit and returns the transformed qubit. @@ -292,7 +294,7 @@ def ZOp : QCOOp<"z", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -305,7 +307,8 @@ def ZOp : QCOOp<"z", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def HOp : QCOOp<"h", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def HOp + : QCOOp<"h", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply a H gate to a qubit"; let description = [{ Applies a H gate to a qubit and returns the transformed qubit. @@ -316,7 +319,7 @@ def HOp : QCOOp<"h", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -329,7 +332,8 @@ def HOp : QCOOp<"h", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def SOp : QCOOp<"s", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def SOp + : QCOOp<"s", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply an S gate to a qubit"; let description = [{ Applies an S gate to a qubit and returns the transformed qubit. @@ -340,7 +344,7 @@ def SOp : QCOOp<"s", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -353,8 +357,8 @@ def SOp : QCOOp<"s", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def SdgOp - : QCOOp<"sdg", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def SdgOp : QCOOp<"sdg", + traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply an Sdg gate to a qubit"; let description = [{ Applies an Sdg gate to a qubit and returns the transformed qubit. @@ -365,7 +369,7 @@ def SdgOp ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -378,7 +382,8 @@ def SdgOp let hasCanonicalizer = 1; } -def TOp : QCOOp<"t", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def TOp + : QCOOp<"t", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply a T gate to a qubit"; let description = [{ Applies a T gate to a qubit and returns the transformed qubit. @@ -389,7 +394,7 @@ def TOp : QCOOp<"t", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -402,8 +407,8 @@ def TOp : QCOOp<"t", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def TdgOp - : QCOOp<"tdg", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def TdgOp : QCOOp<"tdg", + traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply a Tdg gate to a qubit"; let description = [{ Applies a Tdg gate to a qubit and returns the transformed qubit. @@ -414,7 +419,7 @@ def TdgOp ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -427,7 +432,8 @@ def TdgOp let hasCanonicalizer = 1; } -def SXOp : QCOOp<"sx", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def SXOp + : QCOOp<"sx", traits = [UnitaryOpInterface, OneTargetZeroParameter, Pure]> { let summary = "Apply an SX gate to a qubit"; let description = [{ Applies an SX gate to a qubit and returns the transformed qubit. @@ -438,7 +444,7 @@ def SXOp : QCOOp<"sx", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -451,8 +457,8 @@ def SXOp : QCOOp<"sx", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { let hasCanonicalizer = 1; } -def SXdgOp - : QCOOp<"sxdg", traits = [UnitaryOpInterface, OneTargetZeroParameter]> { +def SXdgOp : QCOOp<"sxdg", traits = [UnitaryOpInterface, OneTargetZeroParameter, + Pure]> { let summary = "Apply an SXdg gate to a qubit"; let description = [{ Applies an SXdg gate to a qubit and returns the transformed qubit. @@ -463,7 +469,7 @@ def SXdgOp ``` }]; - let arguments = (ins Arg:$qubit_in); + let arguments = (ins Arg:$qubit_in); let results = (outs QubitType:$qubit_out); let assemblyFormat = "$qubit_in attr-dict `:` type($qubit_in) `->` type($qubit_out)"; @@ -476,7 +482,8 @@ def SXdgOp let hasCanonicalizer = 1; } -def RXOp : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def RXOp + : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter, Pure]> { let summary = "Apply an RX gate to a qubit"; let description = [{ Applies an RX gate to a qubit and returns the transformed qubit. @@ -487,7 +494,7 @@ def RXOp : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$theta); let results = (outs QubitType:$qubit_out); let assemblyFormat = "`(` $theta `)` $qubit_in attr-dict `:` type($qubit_in) " @@ -505,7 +512,8 @@ def RXOp : QCOOp<"rx", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def RYOp : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def RYOp + : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter, Pure]> { let summary = "Apply an RY gate to a qubit"; let description = [{ Applies an RY gate to a qubit and returns the transformed qubit. @@ -516,7 +524,7 @@ def RYOp : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$theta); let results = (outs QubitType:$qubit_out); let assemblyFormat = "`(` $theta `)` $qubit_in attr-dict `:` type($qubit_in) " @@ -534,7 +542,8 @@ def RYOp : QCOOp<"ry", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def RZOp : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def RZOp + : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter, Pure]> { let summary = "Apply an RZ gate to a qubit"; let description = [{ Applies an RZ gate to a qubit and returns the transformed qubit. @@ -545,7 +554,7 @@ def RZOp : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$theta); let results = (outs QubitType:$qubit_out); let assemblyFormat = "`(` $theta `)` $qubit_in attr-dict `:` type($qubit_in) " @@ -563,7 +572,8 @@ def RZOp : QCOOp<"rz", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def POp : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter]> { +def POp + : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter, Pure]> { let summary = "Apply a P gate to a qubit"; let description = [{ Applies a P gate to a qubit and returns the transformed qubit. @@ -574,7 +584,7 @@ def POp : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$theta); let results = (outs QubitType:$qubit_out); let assemblyFormat = "`(` $theta `)` $qubit_in attr-dict `:` type($qubit_in) " @@ -592,7 +602,8 @@ def POp : QCOOp<"p", traits = [UnitaryOpInterface, OneTargetOneParameter]> { let hasCanonicalizer = 1; } -def ROp : QCOOp<"r", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { +def ROp + : QCOOp<"r", traits = [UnitaryOpInterface, OneTargetTwoParameter, Pure]> { let summary = "Apply an R gate to a qubit"; let description = [{ Applies an R gate to a qubit and returns the transformed qubit. @@ -603,7 +614,7 @@ def ROp : QCOOp<"r", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$theta, Arg:$phi); let results = (outs QubitType:$qubit_out); @@ -622,7 +633,8 @@ def ROp : QCOOp<"r", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { let hasCanonicalizer = 1; } -def U2Op : QCOOp<"u2", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { +def U2Op + : QCOOp<"u2", traits = [UnitaryOpInterface, OneTargetTwoParameter, Pure]> { let summary = "Apply a U2 gate to a qubit"; let description = [{ Applies a U2 gate to a qubit and returns the transformed qubit. @@ -633,7 +645,7 @@ def U2Op : QCOOp<"u2", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$phi, Arg:$lambda); let results = (outs QubitType:$qubit_out); @@ -652,7 +664,8 @@ def U2Op : QCOOp<"u2", traits = [UnitaryOpInterface, OneTargetTwoParameter]> { let hasCanonicalizer = 1; } -def UOp : QCOOp<"u", traits = [UnitaryOpInterface, OneTargetThreeParameter]> { +def UOp + : QCOOp<"u", traits = [UnitaryOpInterface, OneTargetThreeParameter, Pure]> { let summary = "Apply a U gate to a qubit"; let description = [{ Applies a U gate to a qubit and returns the transformed qubit. @@ -663,7 +676,7 @@ def UOp : QCOOp<"u", traits = [UnitaryOpInterface, OneTargetThreeParameter]> { ``` }]; - let arguments = (ins Arg:$qubit_in, + let arguments = (ins Arg:$qubit_in, Arg:$theta, Arg:$phi, Arg:$lambda); @@ -684,8 +697,8 @@ def UOp : QCOOp<"u", traits = [UnitaryOpInterface, OneTargetThreeParameter]> { let hasCanonicalizer = 1; } -def SWAPOp - : QCOOp<"swap", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { +def SWAPOp : QCOOp<"swap", traits = [UnitaryOpInterface, TwoTargetZeroParameter, + Pure]> { let summary = "Apply a SWAP gate to two qubits"; let description = [{ Applies a SWAP gate to two qubits and returns the transformed qubits. @@ -696,9 +709,8 @@ def SWAPOp ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "$qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) `,` " @@ -712,8 +724,8 @@ def SWAPOp let hasCanonicalizer = 1; } -def iSWAPOp - : QCOOp<"iswap", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { +def iSWAPOp : QCOOp<"iswap", traits = [UnitaryOpInterface, + TwoTargetZeroParameter, Pure]> { let summary = "Apply a iSWAP gate to two qubits"; let description = [{ Applies a iSWAP gate to two qubits and returns the transformed qubits. @@ -724,9 +736,8 @@ def iSWAPOp ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "$qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) `,` " @@ -738,8 +749,8 @@ def iSWAPOp }]; } -def DCXOp - : QCOOp<"dcx", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { +def DCXOp : QCOOp<"dcx", + traits = [UnitaryOpInterface, TwoTargetZeroParameter, Pure]> { let summary = "Apply a DCX gate to two qubits"; let description = [{ Applies a DCX gate to two qubits and returns the transformed qubits. @@ -750,9 +761,8 @@ def DCXOp ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "$qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) `,` " @@ -766,8 +776,8 @@ def DCXOp let hasCanonicalizer = 1; } -def ECROp - : QCOOp<"ecr", traits = [UnitaryOpInterface, TwoTargetZeroParameter]> { +def ECROp : QCOOp<"ecr", + traits = [UnitaryOpInterface, TwoTargetZeroParameter, Pure]> { let summary = "Apply an ECR gate to two qubits"; let description = [{ Applies an ECR gate to two qubits and returns the transformed qubits. @@ -778,9 +788,8 @@ def ECROp ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "$qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) `,` " @@ -794,7 +803,8 @@ def ECROp let hasCanonicalizer = 1; } -def RXXOp : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RXXOp + : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter, Pure]> { let summary = "Apply an RXX gate to two qubits"; let description = [{ Applies an RXX gate to two qubits and returns the transformed qubits. @@ -805,10 +815,9 @@ def RXXOp : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in, - Arg:$theta); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in, + Arg:$theta); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "`(` $theta `)` $qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) " @@ -826,7 +835,8 @@ def RXXOp : QCOOp<"rxx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def RYYOp : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RYYOp + : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter, Pure]> { let summary = "Apply an RYY gate to two qubits"; let description = [{ Applies an RYY gate to two qubits and returns the transformed qubits. @@ -837,10 +847,9 @@ def RYYOp : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in, - Arg:$theta); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in, + Arg:$theta); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "`(` $theta `)` $qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) " @@ -858,7 +867,8 @@ def RYYOp : QCOOp<"ryy", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def RZXOp : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RZXOp + : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter, Pure]> { let summary = "Apply an RZX gate to two qubits"; let description = [{ Applies an RZX gate to two qubits and returns the transformed qubits. @@ -869,10 +879,9 @@ def RZXOp : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in, - Arg:$theta); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in, + Arg:$theta); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "`(` $theta `)` $qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) " @@ -890,7 +899,8 @@ def RZXOp : QCOOp<"rzx", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def RZZOp : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { +def RZZOp + : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter, Pure]> { let summary = "Apply an RZZ gate to two qubits"; let description = [{ Applies an RZZ gate to two qubits and returns the transformed qubits. @@ -901,10 +911,9 @@ def RZZOp : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in, - Arg:$theta); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in, + Arg:$theta); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "`(` $theta `)` $qubit0_in `,` $qubit1_in attr-dict `:` type($qubit0_in) " @@ -922,8 +931,8 @@ def RZZOp : QCOOp<"rzz", traits = [UnitaryOpInterface, TwoTargetOneParameter]> { let hasCanonicalizer = 1; } -def XXPlusYYOp : QCOOp<"xx_plus_yy", - traits = [UnitaryOpInterface, TwoTargetTwoParameter]> { +def XXPlusYYOp : QCOOp<"xx_plus_yy", traits = [UnitaryOpInterface, + TwoTargetTwoParameter, Pure]> { let summary = "Apply an XX+YY gate to two qubits"; let description = [{ Applies an XX+YY gate to two qubits and returns the transformed qubits. @@ -934,11 +943,10 @@ def XXPlusYYOp : QCOOp<"xx_plus_yy", ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in, - Arg:$theta, - Arg:$beta); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in, + Arg:$theta, + Arg:$beta); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "`(` $theta `,` $beta `)` $qubit0_in `,` $qubit1_in " "attr-dict `:` type($qubit0_in) `,` type($qubit1_in) " @@ -957,8 +965,8 @@ def XXPlusYYOp : QCOOp<"xx_plus_yy", let hasCanonicalizer = 1; } -def XXMinusYYOp : QCOOp<"xx_minus_yy", - traits = [UnitaryOpInterface, TwoTargetTwoParameter]> { +def XXMinusYYOp : QCOOp<"xx_minus_yy", traits = [UnitaryOpInterface, + TwoTargetTwoParameter, Pure]> { let summary = "Apply an XX-YY gate to two qubits"; let description = [{ Applies an XX-YY gate to two qubits and returns the transformed qubits. @@ -969,11 +977,10 @@ def XXMinusYYOp : QCOOp<"xx_minus_yy", ``` }]; - let arguments = - (ins Arg:$qubit0_in, - Arg:$qubit1_in, - Arg:$theta, - Arg:$beta); + let arguments = (ins Arg:$qubit0_in, + Arg:$qubit1_in, + Arg:$theta, + Arg:$beta); let results = (outs QubitType:$qubit0_out, QubitType:$qubit1_out); let assemblyFormat = "`(` $theta `,` $beta `)` $qubit0_in `,` $qubit1_in " "attr-dict `:` type($qubit0_in) `,` type($qubit1_in) " @@ -992,7 +999,7 @@ def XXMinusYYOp : QCOOp<"xx_minus_yy", let hasCanonicalizer = 1; } -def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface]> { +def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface, Pure]> { let summary = "Apply a barrier gate to a set of qubits"; let description = [{ Applies a barrier gate to a set of qubits and returns the transformed qubits. @@ -1004,7 +1011,7 @@ def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface]> { }]; let arguments = - (ins Arg, "the target qubits", [MemRead]>:$qubits_in); + (ins Arg, "the target qubits">:$qubits_in); let results = (outs Variadic:$qubits_out); let assemblyFormat = "$qubits_in attr-dict `:` type($qubits_in) `->` type($qubits_out)"; @@ -1042,7 +1049,7 @@ def BarrierOp : QCOOp<"barrier", traits = [UnitaryOpInterface]> { // Modifiers //===----------------------------------------------------------------------===// -def YieldOp : QCOOp<"yield", traits = [Terminator, ReturnLike]> { +def YieldOp : QCOOp<"yield", traits = [Terminator, ReturnLike, Pure]> { let summary = "Yield from a modifier region"; let description = [{ Terminates a modifier region, yielding the transformed target qubit and qtensor values back to the enclosing modifier operation. @@ -1067,7 +1074,7 @@ def CtrlOp AttrSizedResultSegments, SameOperandsAndResultType, SameOperandsAndResultShape, SingleBlockImplicitTerminator<"::mlir::qco::YieldOp">, - RecursiveMemoryEffects]> { + Pure, RecursiveMemoryEffects]> { let summary = "Add control qubits to a unitary operation"; let description = [{ A modifier operation that adds control qubits to the unitary operation @@ -1086,9 +1093,9 @@ def CtrlOp ``` }]; - let arguments = (ins Arg, - "the control qubits", [MemRead]>:$controls_in, - Arg, "the target qubits", [MemRead]>:$targets_in); + let arguments = + (ins Arg, "the control qubits">:$controls_in, + Arg, "the target qubits">:$targets_in); let results = (outs Variadic:$controls_out, Variadic:$targets_out); let regions = (region SizedRegion<1>:$region); @@ -1144,7 +1151,7 @@ def InvOp : QCOOp<"inv", traits = [UnitaryOpInterface, SingleBlockImplicitTerminator<"::mlir::qco::YieldOp">, - RecursiveMemoryEffects]> { + Pure, RecursiveMemoryEffects]> { let summary = "Invert a unitary operation"; let description = [{ A modifier operation that inverts the unitary operation defined in its body @@ -1160,9 +1167,8 @@ def InvOp ``` }]; - let arguments = - (ins Arg, - "the qubits involved in the operation", [MemRead]>:$qubits_in); + let arguments = (ins Arg, + "the qubits involved in the operation">:$qubits_in); let results = (outs Variadic:$qubits_out); let regions = (region SizedRegion<1>:$region); let assemblyFormat = [{ @@ -1222,7 +1228,7 @@ def IfOp "getRegionInvocationBounds", "getEntrySuccessorRegions"]>, SingleBlock, - SingleBlockImplicitTerminator<"::mlir::qco::YieldOp">, + SingleBlockImplicitTerminator<"::mlir::qco::YieldOp">, Pure, RecursiveMemoryEffects]> { let summary = "If-then-else operation for linear (qubit) types"; diff --git a/mlir/include/mlir/Dialect/QCO/QCOUtils.h b/mlir/include/mlir/Dialect/QCO/QCOUtils.h index 8a46a7b72c..6135c416cd 100644 --- a/mlir/include/mlir/Dialect/QCO/QCOUtils.h +++ b/mlir/include/mlir/Dialect/QCO/QCOUtils.h @@ -239,7 +239,7 @@ mergeTwoTargetOneParameterWithSwappedTargets(OpType op, /** * @brief Check if given quantum operation is unused (i.e., only used by - * deallocations) and remove it if so. + * sinks) and remove it if so. * * @param op The operation to check. * @param rewriter The pattern rewriter. @@ -249,14 +249,20 @@ inline LogicalResult checkAndRemoveDeadGate(Operation* op, PatternRewriter& rewriter) { if (std::all_of(op->getUsers().begin(), op->getUsers().end(), [](Operation* user) { return isa(user); })) { - // If the operation is only used by deallocs, we can safely remove it. + // If the operation is only used by sinks, we can safely remove it. if (auto u = dyn_cast(op)) { // We specifically have to replace the output *qubits* with the input // *qubits* to ignore parameters. rewriter.replaceOp(op, u.getInputQubits()); return success(); + } else if (auto m = dyn_cast(op)) { + // We specifically have to replace the output *qubits* with the input + // *qubits* to ignore the classical outcome. + rewriter.replaceAllUsesWith(m.getQubitOut(), m.getQubitIn()); + rewriter.eraseOp(op); + return success(); } else { - // This includes the `IfOp` as well as `Reset` and `Measure`. + // This includes the `IfOp` as well as `Reset`. rewriter.replaceOp(op, op->getOperands()); return success(); } diff --git a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp index 33474448d7..fe11640e0d 100644 --- a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp +++ b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp @@ -39,6 +39,10 @@ using namespace mlir::qco; //===----------------------------------------------------------------------===// namespace { + +/** + * @brief Remove dead measurements. + */ struct DeadGateElimination final : public OpInterfaceRewritePattern { @@ -47,9 +51,10 @@ struct DeadGateElimination final LogicalResult matchAndRewrite(UnitaryOpInterface op, PatternRewriter& rewriter) const override { - if (op->use_empty()) { + if (!isMemoryEffectFree(op)) { // This effectively ignores the GPhase operation and variants such as its - // inverse, which should never be considered dead. + // inverse or `if` ops containing it, which should never be considered + // dead. return failure(); } return checkAndRemoveDeadGate(op.getOperation(), rewriter); diff --git a/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp b/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp index bb401dd9f3..fc10e98157 100644 --- a/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp @@ -237,7 +237,7 @@ struct ConditionPropagation : public OpRewritePattern { }; /** - * @brief Remove dead resets. + * @brief Remove dead `IfOp` instructions. */ struct DeadIfRemoval final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 7c76b2a049..f1d3f02777 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -121,7 +121,7 @@ TEST_F(QCOTest, CheckDeadGateElimination) { auto q1S0 = builder.allocQubit(); auto q0S1 = builder.h(q0S0); auto [q0S2, q1S1] = builder.cx(q0S1, q1S0); - auto q1S2 = builder.h(q1S1); + auto [q1S2, c1] = builder.measure(q1S1); builder.sink(q0S2); builder.sink(q1S2); auto module = builder.finalize(); @@ -144,6 +144,8 @@ TEST_F(QCOTest, CheckDeadGateElimination) { EXPECT_TRUE(runQCOCleanupPipeline(ref.get()).succeeded()); EXPECT_TRUE(verify(*ref).succeeded()); + module.get().dump(); + EXPECT_TRUE(areModulesEquivalentWithPermutations(module.get(), ref.get())); } From 97175472e086cdd46e8bb5393d16cc7024729dd7 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Tue, 2 Jun 2026 09:54:03 +0200 Subject: [PATCH 09/21] fix(mlir): :bug: fix handling for `IfOp` removal and add specialized test --- mlir/include/mlir/Dialect/QCO/QCOUtils.h | 7 +- mlir/lib/Dialect/QCO/IR/QCOOps.cpp | 1 + mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 87 +++++++++++++++++-- 3 files changed, 87 insertions(+), 8 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/QCOUtils.h b/mlir/include/mlir/Dialect/QCO/QCOUtils.h index 6135c416cd..deda577eba 100644 --- a/mlir/include/mlir/Dialect/QCO/QCOUtils.h +++ b/mlir/include/mlir/Dialect/QCO/QCOUtils.h @@ -261,8 +261,13 @@ inline LogicalResult checkAndRemoveDeadGate(Operation* op, rewriter.replaceAllUsesWith(m.getQubitOut(), m.getQubitIn()); rewriter.eraseOp(op); return success(); + } else if (auto i = dyn_cast(op)) { + // We specifically have to replace the output *qubits* with the input + // *qubits* to ignore the condition. + rewriter.replaceOp(op, i.getQubits()); + return success(); } else { - // This includes the `IfOp` as well as `Reset`. + // This currently only includes the `Reset` operation. rewriter.replaceOp(op, op->getOperands()); return success(); } diff --git a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp index fe11640e0d..dc302e3028 100644 --- a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp +++ b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include // The following headers are needed for some template instantiations. diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index f1d3f02777..3af2feb196 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -132,21 +132,94 @@ TEST_F(QCOTest, CheckDeadGateElimination) { auto r1 = reference.allocQubit(); reference.sink(r0); reference.sink(r1); - auto ref = reference.finalize(); + auto refModule = reference.finalize(); ASSERT_TRUE(module); EXPECT_TRUE(verify(*module).succeeded()); EXPECT_TRUE(runQCOCleanupPipeline(module.get()).succeeded()); EXPECT_TRUE(verify(*module).succeeded()); - ASSERT_TRUE(ref); - EXPECT_TRUE(verify(*ref).succeeded()); - EXPECT_TRUE(runQCOCleanupPipeline(ref.get()).succeeded()); - EXPECT_TRUE(verify(*ref).succeeded()); + ASSERT_TRUE(refModule); + EXPECT_TRUE(verify(*refModule).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(refModule.get()).succeeded()); + EXPECT_TRUE(verify(*refModule).succeeded()); - module.get().dump(); + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), refModule.get())); +} + +TEST_F(QCOTest, CheckIfOpDeadGateElimination) { + QCOProgramBuilder builder(context.get()); + builder.initialize(); + auto q0S0 = builder.allocQubit(); + auto q1S0 = builder.allocQubit(); + auto q0S1 = builder.h(q0S0); + auto [q0S2, c0] = builder.measure(q0S1); + + // This is an `if` with memory effects - it can't be removed. + auto q1S1 = builder.qcoIf( + c0, {q1S0}, + [&](ValueRange qubits) -> SmallVector { + auto q1Then = builder.x(qubits[0]); + builder.gphase(0.5); + return SmallVector{q1Then}; + }, + [&](ValueRange qubits) -> SmallVector { + auto q1Else = builder.h(qubits[0]); + return SmallVector{q1Else}; + })[0]; + + // This is an `if` without memory effects - it can be removed. + auto q1S2 = builder.qcoIf( + c0, {q1S1}, + [&](ValueRange qubits) -> SmallVector { + auto q1Then = builder.x(qubits[0]); + return SmallVector{q1Then}; + }, + [&](ValueRange qubits) -> SmallVector { + auto q1Else = builder.h(qubits[0]); + return SmallVector{q1Else}; + })[0]; + builder.sink(q0S2); + builder.sink(q1S2); + auto module = builder.finalize(); - EXPECT_TRUE(areModulesEquivalentWithPermutations(module.get(), ref.get())); + QCOProgramBuilder reference(context.get()); + reference.initialize(); + auto r0S0 = reference.allocQubit(); + auto r1S0 = reference.allocQubit(); + auto r0S1 = reference.h(r0S0); + auto [r0S2, cr0] = reference.measure(r0S1); + + // This is an `if` with memory effects - it can't be removed. + auto r1S1 = reference.qcoIf( + cr0, {r1S0}, + [&](ValueRange qubits) -> SmallVector { + auto q1Then = reference.x(qubits[0]); + reference.gphase(0.5); + return SmallVector{q1Then}; + }, + [&](ValueRange qubits) -> SmallVector { + auto q1Else = reference.h(qubits[0]); + return SmallVector{q1Else}; + })[0]; + + reference.sink(r0S2); + reference.sink(r1S1); + auto refModule = reference.finalize(); + + ASSERT_TRUE(module); + EXPECT_TRUE(verify(*module).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(module.get()).succeeded()); + EXPECT_TRUE(verify(*module).succeeded()); + + ASSERT_TRUE(refModule); + EXPECT_TRUE(verify(*refModule).succeeded()); + EXPECT_TRUE(runQCOCleanupPipeline(refModule.get()).succeeded()); + EXPECT_TRUE(verify(*refModule).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), refModule.get())); } TEST_F(QCOTest, DirectIfBuilder) { From 2d65418e7a1e4b367f747902f5a5479571475873 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Tue, 2 Jun 2026 10:10:45 +0200 Subject: [PATCH 10/21] fix(mlir): :bug: minor bug and code style fixes --- mlir/include/mlir/Dialect/QCO/QCOUtils.h | 4 ++-- mlir/lib/Dialect/QCO/IR/QCOOps.cpp | 5 +++-- mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp | 5 +++++ mlir/lib/Support/IRVerification.cpp | 1 - 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/QCOUtils.h b/mlir/include/mlir/Dialect/QCO/QCOUtils.h index deda577eba..9c6af7a8f4 100644 --- a/mlir/include/mlir/Dialect/QCO/QCOUtils.h +++ b/mlir/include/mlir/Dialect/QCO/QCOUtils.h @@ -247,8 +247,8 @@ mergeTwoTargetOneParameterWithSwappedTargets(OpType op, */ inline LogicalResult checkAndRemoveDeadGate(Operation* op, PatternRewriter& rewriter) { - if (std::all_of(op->getUsers().begin(), op->getUsers().end(), - [](Operation* user) { return isa(user); })) { + if (llvm::all_of(op->getUsers(), + [](Operation* user) { return isa(user); })) { // If the operation is only used by sinks, we can safely remove it. if (auto u = dyn_cast(op)) { // We specifically have to replace the output *qubits* with the input diff --git a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp index dc302e3028..b1407ee0df 100644 --- a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp +++ b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include // The following headers are needed for some template instantiations. @@ -42,7 +43,7 @@ using namespace mlir::qco; namespace { /** - * @brief Remove dead measurements. + * @brief Remove dead gates. */ struct DeadGateElimination final : public OpInterfaceRewritePattern { @@ -54,7 +55,7 @@ struct DeadGateElimination final PatternRewriter& rewriter) const override { if (!isMemoryEffectFree(op)) { // This effectively ignores the GPhase operation and variants such as its - // inverse or `if` ops containing it, which should never be considered + // inverse or `ctrl` ops containing it, which should never be considered // dead. return failure(); } diff --git a/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp b/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp index fc10e98157..6589fbbad7 100644 --- a/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/SCF/IfOp.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -244,6 +245,10 @@ struct DeadIfRemoval final : OpRewritePattern { LogicalResult matchAndRewrite(IfOp op, PatternRewriter& rewriter) const override { + if (!isMemoryEffectFree(op)) { + // This effectively ignores `IfOp`s with memory effects. + return failure(); + } return checkAndRemoveDeadGate(op, rewriter); } }; diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index eaac426f0a..8aa982e976 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -33,7 +33,6 @@ #include #include #include -#include #include #include From 278a54e356e4417f4b9446a689c8c74835c35d5f Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Tue, 2 Jun 2026 10:19:43 +0200 Subject: [PATCH 11/21] style(mlir): :rotating_light: fix linter issues on includes --- mlir/lib/Dialect/QCO/IR/QCOOps.cpp | 1 - mlir/lib/Support/IRVerification.cpp | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp index b1407ee0df..e49d00da5a 100644 --- a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp +++ b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #include diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 8aa982e976..eaac426f0a 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include From a45a15085869079c1fda01e890d25f79ce77f0c8 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Tue, 2 Jun 2026 10:37:47 +0200 Subject: [PATCH 12/21] fix(mlir): :white_check_mark: fix tests --- .../test_qco_measurement_lifting.cpp | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp index f0373ebfd5..61b9458c13 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp @@ -215,18 +215,18 @@ TEST_F(QCOMeasurementLiftingTest, auto r1_0 = referenceBuilder.allocQubit(); auto r2_0 = referenceBuilder.allocQubit(); - auto [r1_1, cr1] = programBuilder.measure(r1_0); - auto [r2_1, cr2] = programBuilder.measure(r2_0); + auto [r1_1, cr1] = referenceBuilder.measure(r1_0); + auto [r2_1, cr2] = referenceBuilder.measure(r2_0); auto [r12_0, r0_1] = - programBuilder.ctrl({r1_1, r2_1}, {r0_0}, [&](const ValueRange target) { - return SmallVector{programBuilder.x(target[0])}; + referenceBuilder.ctrl({r1_1, r2_1}, {r0_0}, [&](const ValueRange target) { + return SmallVector{referenceBuilder.x(target[0])}; }); referenceBuilder.sink(r0_1[0]); referenceBuilder.sink(r12_0[0]); referenceBuilder.sink(r12_0[1]); - module = referenceBuilder.finalize(); + reference = referenceBuilder.finalize(); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); @@ -289,9 +289,6 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleX) { ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); - reference.get().dump(); - module.get().dump(); - EXPECT_TRUE( areModulesEquivalentWithPermutations(module.get(), reference.get())); } @@ -306,8 +303,8 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleY) { auto r_0 = referenceBuilder.allocQubit(); auto true_constant = referenceBuilder.boolConstant(true); auto [r_1, cr] = referenceBuilder.measure(r_0); - referenceBuilder.insert(arith::XOrIOp::create( - referenceBuilder, referenceBuilder.getLoc(), cr, true_constant)); + auto xorOp = arith::XOrIOp::create( + referenceBuilder, referenceBuilder.getLoc(), cr, true_constant); referenceBuilder.sink(r_1); reference = referenceBuilder.finalize(); From 0edfeb1edc5b47ea40c33612457fc507c10a50a2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Jun 2026 08:38:18 +0000 Subject: [PATCH 13/21] =?UTF-8?q?=F0=9F=8E=A8=20pre-commit=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Dialect/QCO/Transforms/Optimizations/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt index 40d46afa20..bdc78401c8 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt @@ -7,8 +7,8 @@ # Licensed under the MIT License set(target_name mqt-core-mlir-unittest-optimizations) -add_executable(${target_name} test_qco_hadamard_lifting.cpp - test_qco_merge_single_qubit_rotation.cpp test_quantum_loop_unroll.cpp test_qco_measurement_lifting.cpp) +add_executable(${target_name} test_qco_hadamard_lifting.cpp test_qco_measurement_lifting.cpp + test_qco_merge_single_qubit_rotation.cpp test_quantum_loop_unroll.cpp) target_link_libraries( ${target_name} From c4a43c2c0fc595bd69ad94e94118d1227b67ccc5 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Tue, 2 Jun 2026 10:48:50 +0200 Subject: [PATCH 14/21] style(mlir): :rotating_light: fix linter issues --- .../Optimizations/MeasurementLifting.cpp | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp b/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp index 02589d814d..5e6baea41f 100644 --- a/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp +++ b/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp @@ -16,9 +16,7 @@ #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QCO/Transforms/Passes.h" -#include #include -#include #include #include #include @@ -26,7 +24,6 @@ #include #include -#include #include namespace mlir::qco { @@ -34,21 +31,21 @@ namespace mlir::qco { #define GEN_PASS_DEF_MEASUREMENTLIFTING #include "mlir/Dialect/QCO/Transforms/Passes.h.inc" -namespace { - /** * @brief Checks if the given operation is an inverting gate. * @param op The operation to check. * @return True if the operation is an inverting gate, false otherwise. */ -bool isInverting(Operation* op) { return isa(op); } +static bool isInverting(Operation* op) { return isa(op); } /** * @brief Checks if the given operation is a diagonal gate. * @param op The operation to check. * @return True if the operation is a diagonal gate, false otherwise. */ -bool isDiagonal(Operation* op) { return isa(op); } +static bool isDiagonal(Operation* op) { + return isa(op); +} /** * @brief This method swaps a gate with a measurement. @@ -56,8 +53,9 @@ bool isDiagonal(Operation* op) { return isa(op); } * @param measurement The measurement to swap. * @param rewriter The used rewriter. */ -void swapGateWithMeasurement(UnitaryOpInterface gate, MeasureOp measurement, - mlir::PatternRewriter& rewriter) { +static void swapGateWithMeasurement(UnitaryOpInterface gate, + MeasureOp measurement, + mlir::PatternRewriter& rewriter) { auto measurementInput = measurement.getQubitIn(); auto gateInput = gate.getInputForOutput(measurementInput); rewriter.replaceUsesWithIf(measurementInput, gateInput, @@ -81,6 +79,7 @@ void swapGateWithMeasurement(UnitaryOpInterface gate, MeasureOp measurement, rewriter.moveOpBefore(measurement, gate); } +namespace { /** * @brief This pattern is responsible for lifting measurements above any phase * gates. From 44eaf7e6bdb0274d2057a64d5b4c1690e35bbb6b Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Tue, 2 Jun 2026 10:57:00 +0200 Subject: [PATCH 15/21] fix: :pencil2: fix typo in changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ed0b4dfbcc..22616216d5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel ### Added -- ✨ Add Dead Gate Elimination Pattern ([#1755]) ([**DRovara**]) +- ✨ Add Dead Gate Elimination Pattern ([#1755]) ([**@DRovara**]) - 🚸 Add [CMake presets] to provide a standardized and reproducible way to configure builds ([#1660]) ([**@denialhaag**]) - ✨ Add a `quantum-loop-unroll` pass for unrolling for-loop operations containing quantum operations ([#1718]) ([**@MatthiasReumann**]) - ✨ Add a `hadamard-lifting` pass for lifting Hadamard gates above Pauli gates ([#1605]) ([**@lirem101**], [**@burgholzer**]) From a695579e63cf2708b4139e214820ba54334d177d Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Tue, 2 Jun 2026 13:24:45 +0200 Subject: [PATCH 16/21] style(mlir): :recycle: implement coderabbit suggestions --- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 22 ++++++++++++++++--- .../mlir/Dialect/QCO/Transforms/Passes.td | 6 ++--- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 14 +++++++----- .../Optimizations/MeasurementLifting.cpp | 2 +- 4 files changed, 32 insertions(+), 12 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 600e1c1ffb..92e75be2f7 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -78,7 +78,8 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { //===--------------------------------------------------------------------===// /** - * @brief Initialize the builder and prepare for program construction + * @brief Initialize the builder and prepare for program construction, with + * a default return type of i64. * * @details * Creates a main function with an entry_point attribute. Must be called @@ -86,6 +87,17 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ void initialize(); + /** + * @brief Initialize the builder and prepare for program construction + * with specified return types. + * @param returnTypes The return types for the main function + * + * @details + * Creates a main function with an entry_point attribute. Must be called + * before adding operations. + */ + void initialize(TypeRange returnTypes); + //===--------------------------------------------------------------------===// // Constants //===--------------------------------------------------------------------===// @@ -1393,7 +1405,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { /** * @brief Finalize the program with a given exit code and return the * constructed module - * @param exitCode Value representing the exit code to return + * @param returnValues Values representing the exit code to return * * @details * Automatically deallocates all remaining valid qubits and tensors of qubits, @@ -1401,9 +1413,13 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { * and transfers ownership of the module to the caller. The builder should not * be used after calling this method. * + * The return values must have the types indicated by the function signature + * of the main function, which returns an `i64` by default and can be + * modified by passing different arguments to the `initialize()` method. + * * @return OwningOpRef containing the constructed quantum program module */ - OwningOpRef finalize(Value exitCode); + OwningOpRef finalize(ValueRange returnValues); /** * @brief Convenience method for building quantum programs diff --git a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td index 9fced722db..bad2fb4f9c 100644 --- a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td @@ -178,11 +178,11 @@ def MeasurementLifting : Pass<"measurement-lifting", "mlir::ModuleOp"> { let dependentDialects = ["mlir::qco::QCODialect", "::mlir::arith::ArithDialect", ]; - let summary = "This pass attempts to move measurements as far up as" - "possible, shiftling them above gates that commute with them." + let summary = "This pass attempts to move measurements as far up as " + "possible, shifting them above gates that commute with them. " "This is done to enable qubit reuse and other optimizations."; let description = [{ - This pass lifts measurements gates away from the measurements in order to apply measurement lifting more effectively. + This pass applies measurement lifting, moving measurements up the code as far as possible. Measurement lifting is a subroutine of the qubit reuse routine. The goal is to measure qubits earlier in the circuit to reuse them and to potentially remove some quantum gates. diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 5d9b243352..22c45132d4 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -52,12 +52,14 @@ QCOProgramBuilder::QCOProgramBuilder(MLIRContext* context) ctx->loadDialect(); } -void QCOProgramBuilder::initialize() { +void QCOProgramBuilder::initialize() { initialize({getI64Type()}); } + +void QCOProgramBuilder::initialize(TypeRange returnTypes) { // Set insertion point to the module body setInsertionPointToStart(mlir::cast(module).getBody()); // Create main function as entry point - auto funcType = getFunctionType({}, {getI64Type()}); + auto funcType = getFunctionType({}, returnTypes); auto mainFunc = func::FuncOp::create(*this, "main", funcType); // Add entry_point attribute to identify the main function @@ -1101,11 +1103,13 @@ void QCOProgramBuilder::ensureAllocationMode( } OwningOpRef QCOProgramBuilder::finalize() { + checkFinalized(); + auto exitCode = intConstant(0); - return finalize(exitCode); + return finalize({exitCode}); } -OwningOpRef QCOProgramBuilder::finalize(Value exitCode) { +OwningOpRef QCOProgramBuilder::finalize(ValueRange returnValues) { checkFinalized(); // Ensure that main function exists and insertion point is valid @@ -1157,7 +1161,7 @@ OwningOpRef QCOProgramBuilder::finalize(Value exitCode) { validTensors.clear(); // Add return statement with exit code 0 to the main function - func::ReturnOp::create(*this, exitCode); + func::ReturnOp::create(*this, returnValues); // Invalidate context to prevent use-after-finalize ctx = nullptr; diff --git a/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp b/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp index 5e6baea41f..db1b3c95cc 100644 --- a/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp +++ b/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp @@ -44,7 +44,7 @@ static bool isInverting(Operation* op) { return isa(op); } * @return True if the operation is a diagonal gate, false otherwise. */ static bool isDiagonal(Operation* op) { - return isa(op); + return isa(op); } /** From 51cc98218ba2f6412e6033e7c641fde2b0446ad9 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Tue, 2 Jun 2026 14:15:47 +0200 Subject: [PATCH 17/21] fix(mlir): :bug: fix all measurement lifting tests --- .../Optimizations/MeasurementLifting.cpp | 20 ++--- .../test_qco_measurement_lifting.cpp | 76 +++++++++++++------ 2 files changed, 63 insertions(+), 33 deletions(-) diff --git a/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp b/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp index db1b3c95cc..1fb173a348 100644 --- a/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp +++ b/mlir/lib/Dialect/QCO/Transforms/Optimizations/MeasurementLifting.cpp @@ -44,6 +44,12 @@ static bool isInverting(Operation* op) { return isa(op); } * @return True if the operation is a diagonal gate, false otherwise. */ static bool isDiagonal(Operation* op) { + if (auto c = dyn_cast(op)) { + return isDiagonal(c.getBodyUnitary()); + } + if (auto i = dyn_cast(op)) { + return isDiagonal(i.getBodyUnitary()); + } return isa(op); } @@ -136,10 +142,6 @@ struct LiftMeasurementsAboveInvertingGatesPattern final mlir::LogicalResult matchAndRewrite(MeasureOp op, mlir::PatternRewriter& rewriter) const override { - if (!outputQubitRemainsUnused(op.getQubitOut())) { - return mlir::failure(); // if the qubit is still used after the - // measurement, we cannot lift it above the gate. - } const auto qubitVariable = op.getQubitIn(); auto* predecessor = qubitVariable.getDefiningOp(); @@ -185,19 +187,19 @@ struct LiftMeasurementsAboveControlsPattern final mlir::PatternRewriter& rewriter) const override { const auto qubitVariable = op.getQubitIn(); auto* predecessor = qubitVariable.getDefiningOp(); - auto predecessorUnitary = mlir::dyn_cast(predecessor); + auto predecessorCtrl = mlir::dyn_cast(predecessor); - if (!predecessorUnitary) { + if (!predecessorCtrl) { return mlir::failure(); } - if (llvm::find(predecessorUnitary.getOutputQubits(), qubitVariable) != - predecessorUnitary.getOutputQubits().end()) { + if (llvm::find(predecessorCtrl.getControlsOut(), qubitVariable) == + predecessorCtrl.getControlsOut().end()) { // The measured qubit is a target, not a control of the gate. return mlir::failure(); } - swapGateWithMeasurement(predecessorUnitary, op, rewriter); + swapGateWithMeasurement(predecessorCtrl, op, rewriter); return mlir::success(); } diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp index 61b9458c13..eca6c31a16 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp @@ -54,9 +54,6 @@ class QCOMeasurementLiftingTest : public testing::Test { registry.insert(); context.appendDialectRegistry(registry); context.loadAllAvailableDialects(); - - programBuilder.initialize(); - referenceBuilder.initialize(); } /** @@ -87,6 +84,8 @@ class QCOMeasurementLiftingTest : public testing::Test { } // namespace TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPositiveControl) { + programBuilder.initialize( + {programBuilder.getI1Type(), programBuilder.getI1Type()}); auto q0_0 = programBuilder.allocQubit(); auto q1_0 = programBuilder.allocQubit(); @@ -99,8 +98,10 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPositiveControl) { programBuilder.sink(q0_4); programBuilder.sink(q1_4); - module = programBuilder.finalize(); + module = programBuilder.finalize({c0, c1}); + referenceBuilder.initialize( + {programBuilder.getI1Type(), programBuilder.getI1Type()}); auto r0_0 = referenceBuilder.allocQubit(); auto r1_0 = referenceBuilder.allocQubit(); @@ -113,7 +114,7 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPositiveControl) { referenceBuilder.sink(r0_4); referenceBuilder.sink(r1_4); - reference = referenceBuilder.finalize(); + reference = referenceBuilder.finalize({cr0, cr1}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); @@ -123,6 +124,9 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPositiveControl) { } TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverOneOfMultipleControls) { + programBuilder.initialize({programBuilder.getI1Type(), + programBuilder.getI1Type(), + programBuilder.getI1Type()}); auto q0_0 = programBuilder.allocQubit(); auto q1_0 = programBuilder.allocQubit(); auto q2_0 = programBuilder.allocQubit(); @@ -152,8 +156,11 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverOneOfMultipleControls) { programBuilder.sink(q1_4); programBuilder.sink(q2_5); - module = programBuilder.finalize(); + module = programBuilder.finalize({c0, c1, c2}); + referenceBuilder.initialize({programBuilder.getI1Type(), + programBuilder.getI1Type(), + programBuilder.getI1Type()}); auto r0_0 = referenceBuilder.allocQubit(); auto r1_0 = referenceBuilder.allocQubit(); auto r2_0 = referenceBuilder.allocQubit(); @@ -183,7 +190,7 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverOneOfMultipleControls) { referenceBuilder.sink(r12_2[0]); referenceBuilder.sink(r2_5); - reference = referenceBuilder.finalize(); + reference = referenceBuilder.finalize({cr0, cr1, cr2}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); @@ -194,6 +201,9 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverOneOfMultipleControls) { TEST_F(QCOMeasurementLiftingTest, liftMeasurementMultipleOverOneControlledGate) { + + programBuilder.initialize( + {programBuilder.getI1Type(), programBuilder.getI1Type()}); auto q0_0 = programBuilder.allocQubit(); auto q1_0 = programBuilder.allocQubit(); auto q2_0 = programBuilder.allocQubit(); @@ -209,8 +219,10 @@ TEST_F(QCOMeasurementLiftingTest, programBuilder.sink(q0_1[0]); programBuilder.sink(q1_1); programBuilder.sink(q2_1); - module = programBuilder.finalize(); + module = programBuilder.finalize({c1, c2}); + referenceBuilder.initialize( + {programBuilder.getI1Type(), programBuilder.getI1Type()}); auto r0_0 = referenceBuilder.allocQubit(); auto r1_0 = referenceBuilder.allocQubit(); auto r2_0 = referenceBuilder.allocQubit(); @@ -226,7 +238,7 @@ TEST_F(QCOMeasurementLiftingTest, referenceBuilder.sink(r0_1[0]); referenceBuilder.sink(r12_0[0]); referenceBuilder.sink(r12_0[1]); - reference = referenceBuilder.finalize(); + reference = referenceBuilder.finalize({cr1, cr2}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); @@ -237,6 +249,8 @@ TEST_F(QCOMeasurementLiftingTest, TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverControlledParametrizedGate) { + programBuilder.initialize( + {programBuilder.getI1Type(), programBuilder.getI1Type()}); auto q0_0 = programBuilder.allocQubit(); auto q1_0 = programBuilder.allocQubit(); @@ -247,8 +261,10 @@ TEST_F(QCOMeasurementLiftingTest, programBuilder.sink(q0_2); programBuilder.sink(q1_2); - module = programBuilder.finalize(); + module = programBuilder.finalize({c0, c1}); + referenceBuilder.initialize( + {programBuilder.getI1Type(), programBuilder.getI1Type()}); auto r0_0 = referenceBuilder.allocQubit(); auto r1_0 = referenceBuilder.allocQubit(); @@ -260,7 +276,7 @@ TEST_F(QCOMeasurementLiftingTest, referenceBuilder.sink(r0_2); referenceBuilder.sink(r1_2); - reference = referenceBuilder.finalize(); + reference = referenceBuilder.finalize({cr0, cr1}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); @@ -270,12 +286,15 @@ TEST_F(QCOMeasurementLiftingTest, } TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleX) { + + programBuilder.initialize({programBuilder.getI1Type()}); auto q_0 = programBuilder.allocQubit(); auto q_1 = programBuilder.x(q_0); auto [q_2, c] = programBuilder.measure(q_1); programBuilder.sink(q_2); - module = programBuilder.finalize(i1ToI64(c, programBuilder)); + module = programBuilder.finalize(c); + referenceBuilder.initialize({programBuilder.getI1Type()}); auto r_0 = referenceBuilder.allocQubit(); auto true_constant = referenceBuilder.boolConstant(true); auto [r_1, cr] = referenceBuilder.measure(r_0); @@ -283,8 +302,7 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleX) { auto xorOp = arith::XOrIOp::create( referenceBuilder, referenceBuilder.getLoc(), cr, true_constant); referenceBuilder.sink(r_1); - reference = - referenceBuilder.finalize(i1ToI64(xorOp.getResult(), referenceBuilder)); + reference = referenceBuilder.finalize(xorOp.getResult()); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); @@ -294,19 +312,21 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleX) { } TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleY) { + programBuilder.initialize({programBuilder.getI1Type()}); auto q_0 = programBuilder.allocQubit(); auto q_1 = programBuilder.y(q_0); auto [q_2, c] = programBuilder.measure(q_1); programBuilder.sink(q_2); - module = programBuilder.finalize(); + module = programBuilder.finalize({c}); + referenceBuilder.initialize({programBuilder.getI1Type()}); auto r_0 = referenceBuilder.allocQubit(); auto true_constant = referenceBuilder.boolConstant(true); auto [r_1, cr] = referenceBuilder.measure(r_0); auto xorOp = arith::XOrIOp::create( referenceBuilder, referenceBuilder.getLoc(), cr, true_constant); referenceBuilder.sink(r_1); - reference = referenceBuilder.finalize(); + reference = referenceBuilder.finalize({xorOp.getResult()}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); @@ -316,6 +336,7 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleY) { } TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPhaseGates) { + programBuilder.initialize({programBuilder.getI1Type()}); auto q_0 = programBuilder.allocQubit(); auto q_1 = programBuilder.id(q_0); auto q_2 = programBuilder.z(q_1); @@ -327,12 +348,13 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPhaseGates) { auto q_8 = programBuilder.rz(std::numbers::pi / 2, q_7); auto [q_9, c] = programBuilder.measure(q_8); programBuilder.sink(q_9); - module = programBuilder.finalize(); + module = programBuilder.finalize({c}); + referenceBuilder.initialize({programBuilder.getI1Type()}); auto r_0 = referenceBuilder.allocQubit(); auto [r_1, cr] = referenceBuilder.measure(r_0); referenceBuilder.sink(r_1); - reference = referenceBuilder.finalize(); + reference = referenceBuilder.finalize({cr}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); @@ -342,17 +364,19 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPhaseGates) { } TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverMultipleXY) { + programBuilder.initialize({programBuilder.getI1Type()}); auto q_0 = programBuilder.allocQubit(); auto q_1 = programBuilder.x(q_0); auto q_2 = programBuilder.y(q_1); auto [q_3, c] = programBuilder.measure(q_2); programBuilder.sink(q_3); - module = programBuilder.finalize(); + module = programBuilder.finalize({c}); + referenceBuilder.initialize({programBuilder.getI1Type()}); auto r_0 = referenceBuilder.allocQubit(); auto [r_1, cr] = referenceBuilder.measure(r_0); referenceBuilder.sink(r_1); - reference = referenceBuilder.finalize(); + reference = referenceBuilder.finalize({cr}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); @@ -362,6 +386,7 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverMultipleXY) { } TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverXAndControlledGates) { + programBuilder.initialize({programBuilder.getI1Type()}); auto q0_0 = programBuilder.allocQubit(); auto q1_0 = programBuilder.allocQubit(); @@ -374,8 +399,9 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverXAndControlledGates) { programBuilder.sink(q0_5); programBuilder.sink(q1_2); - module = programBuilder.finalize(); + module = programBuilder.finalize({c0}); + referenceBuilder.initialize({programBuilder.getI1Type()}); auto r0_0 = referenceBuilder.allocQubit(); auto r1_0 = referenceBuilder.allocQubit(); @@ -387,7 +413,7 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverXAndControlledGates) { referenceBuilder.sink(r0_4); referenceBuilder.sink(r1_2); - reference = referenceBuilder.finalize(); + reference = referenceBuilder.finalize({cr0}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); @@ -397,6 +423,7 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverXAndControlledGates) { } TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverDiagonalGateInControl) { + programBuilder.initialize({programBuilder.getI1Type()}); auto q0_0 = programBuilder.allocQubit(); auto q1_0 = programBuilder.allocQubit(); @@ -406,8 +433,9 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverDiagonalGateInControl) { programBuilder.sink(q0_2); programBuilder.sink(q1_1); - module = programBuilder.finalize(); + module = programBuilder.finalize({c0}); + referenceBuilder.initialize({programBuilder.getI1Type()}); auto r0_0 = referenceBuilder.allocQubit(); auto r1_0 = referenceBuilder.allocQubit(); @@ -415,7 +443,7 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverDiagonalGateInControl) { referenceBuilder.sink(r0_1); referenceBuilder.sink(r1_0); - reference = referenceBuilder.finalize(); + reference = referenceBuilder.finalize({cr0}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); From f6d5e3e1ff56f272575722809ec8c4682f986dd4 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Tue, 2 Jun 2026 14:21:20 +0200 Subject: [PATCH 18/21] style(mlir): :rotating_light: fix linter issues --- .../test_qco_measurement_lifting.cpp | 303 +++++++++--------- 1 file changed, 152 insertions(+), 151 deletions(-) diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp index eca6c31a16..4847e1b7cf 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_measurement_lifting.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -86,34 +87,34 @@ class QCOMeasurementLiftingTest : public testing::Test { TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPositiveControl) { programBuilder.initialize( {programBuilder.getI1Type(), programBuilder.getI1Type()}); - auto q0_0 = programBuilder.allocQubit(); - auto q1_0 = programBuilder.allocQubit(); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); - auto [q1_1, q0_1] = programBuilder.cx(q1_0, q0_0); - auto [q0_2, q1_2] = programBuilder.ch(q0_1, q1_1); - auto [q0_3, q1_3] = programBuilder.cx(q0_2, q1_2); + auto [q1S1, q0S1] = programBuilder.cx(q1S0, q0S0); + auto [q0S2, q1S2] = programBuilder.ch(q0S1, q1S1); + auto [q0S3, q1S3] = programBuilder.cx(q0S2, q1S2); - auto [q0_4, c0] = programBuilder.measure(q0_3); - auto [q1_4, c1] = programBuilder.measure(q1_3); + auto [q0S4, c0] = programBuilder.measure(q0S3); + auto [q1S4, c1] = programBuilder.measure(q1S3); - programBuilder.sink(q0_4); - programBuilder.sink(q1_4); + programBuilder.sink(q0S4); + programBuilder.sink(q1S4); module = programBuilder.finalize({c0, c1}); referenceBuilder.initialize( {programBuilder.getI1Type(), programBuilder.getI1Type()}); - auto r0_0 = referenceBuilder.allocQubit(); - auto r1_0 = referenceBuilder.allocQubit(); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); - auto [r1_1, r0_1] = referenceBuilder.cx(r1_0, r0_0); - auto [r0_2, cr0] = referenceBuilder.measure(r0_1); - auto [r0_3, r1_2] = referenceBuilder.ch(r0_2, r1_1); - auto [r0_4, r1_3] = referenceBuilder.cx(r0_3, r1_2); + auto [r1S1, r0S1] = referenceBuilder.cx(r1S0, r0S0); + auto [r0S2, cr0] = referenceBuilder.measure(r0S1); + auto [r0S3, r1S2] = referenceBuilder.ch(r0S2, r1S1); + auto [r0S4, r1S3] = referenceBuilder.cx(r0S3, r1S2); - auto [r1_4, cr1] = referenceBuilder.measure(r1_3); + auto [r1S4, cr1] = referenceBuilder.measure(r1S3); - referenceBuilder.sink(r0_4); - referenceBuilder.sink(r1_4); + referenceBuilder.sink(r0S4); + referenceBuilder.sink(r1S4); reference = referenceBuilder.finalize({cr0, cr1}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); @@ -127,68 +128,68 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverOneOfMultipleControls) { programBuilder.initialize({programBuilder.getI1Type(), programBuilder.getI1Type(), programBuilder.getI1Type()}); - auto q0_0 = programBuilder.allocQubit(); - auto q1_0 = programBuilder.allocQubit(); - auto q2_0 = programBuilder.allocQubit(); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + auto q2S0 = programBuilder.allocQubit(); - auto [q12_0, q0_1] = - programBuilder.ctrl({q1_0, q2_0}, {q0_0}, [&](const ValueRange target) { + auto [q12_0, q0S1] = + programBuilder.ctrl({q1S0, q2S0}, {q0S0}, [&](const ValueRange target) { return SmallVector{programBuilder.x(target[0])}; }); - auto [q12_1, q0_2] = programBuilder.ctrl( - {q12_0[1], q12_0[0]}, q0_1, [&](const ValueRange target) { + auto [q12_1, q0S2] = programBuilder.ctrl( + {q12_0[1], q12_0[0]}, q0S1, [&](const ValueRange target) { return SmallVector{programBuilder.h(target[0])}; }); - auto [q12_2, q0_3] = programBuilder.ctrl( - {q12_1[1], q12_1[0]}, q0_2, [&](const ValueRange target) { + auto [q12_2, q0S3] = programBuilder.ctrl( + {q12_1[1], q12_1[0]}, q0S2, [&](const ValueRange target) { return SmallVector{programBuilder.x(target[0])}; }); - auto [q1_4, c1] = programBuilder.measure(q12_2[0]); + auto [q1S4, c1] = programBuilder.measure(q12_2[0]); - auto q0_4 = programBuilder.h(q0_3[0]); - auto q2_4 = programBuilder.h(q12_2[1]); + auto q0S4 = programBuilder.h(q0S3[0]); + auto q2S4 = programBuilder.h(q12_2[1]); - auto [q0_5, c0] = programBuilder.measure(q0_4); - auto [q2_5, c2] = programBuilder.measure(q2_4); + auto [q0S5, c0] = programBuilder.measure(q0S4); + auto [q2S5, c2] = programBuilder.measure(q2S4); - programBuilder.sink(q0_5); - programBuilder.sink(q1_4); - programBuilder.sink(q2_5); + programBuilder.sink(q0S5); + programBuilder.sink(q1S4); + programBuilder.sink(q2S5); module = programBuilder.finalize({c0, c1, c2}); referenceBuilder.initialize({programBuilder.getI1Type(), programBuilder.getI1Type(), programBuilder.getI1Type()}); - auto r0_0 = referenceBuilder.allocQubit(); - auto r1_0 = referenceBuilder.allocQubit(); - auto r2_0 = referenceBuilder.allocQubit(); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + auto r2S0 = referenceBuilder.allocQubit(); - auto [r1_1, cr1] = referenceBuilder.measure(r1_0); + auto [r1S1, cr1] = referenceBuilder.measure(r1S0); - auto [r12_0, r0_1] = - referenceBuilder.ctrl({r1_1, r2_0}, {r0_0}, [&](const ValueRange target) { + auto [r12_0, r0S1] = + referenceBuilder.ctrl({r1S1, r2S0}, {r0S0}, [&](const ValueRange target) { return SmallVector{referenceBuilder.x(target[0])}; }); - auto [r12_1, r0_2] = referenceBuilder.ctrl( - {r12_0[1], r12_0[0]}, r0_1, [&](const ValueRange target) { + auto [r12_1, r0S2] = referenceBuilder.ctrl( + {r12_0[1], r12_0[0]}, r0S1, [&](const ValueRange target) { return SmallVector{referenceBuilder.h(target[0])}; }); - auto [r12_2, r0_3] = referenceBuilder.ctrl( - {r12_1[1], r12_1[0]}, r0_2, [&](const ValueRange target) { + auto [r12_2, r0S3] = referenceBuilder.ctrl( + {r12_1[1], r12_1[0]}, r0S2, [&](const ValueRange target) { return SmallVector{referenceBuilder.x(target[0])}; }); - auto r0_4 = referenceBuilder.h(r0_3[0]); - auto r2_4 = referenceBuilder.h(r12_2[1]); + auto r0S4 = referenceBuilder.h(r0S3[0]); + auto r2S4 = referenceBuilder.h(r12_2[1]); - auto [r0_5, cr0] = referenceBuilder.measure(r0_4); - auto [r2_5, cr2] = referenceBuilder.measure(r2_4); + auto [r0S5, cr0] = referenceBuilder.measure(r0S4); + auto [r2S5, cr2] = referenceBuilder.measure(r2S4); - referenceBuilder.sink(r0_5); + referenceBuilder.sink(r0S5); referenceBuilder.sink(r12_2[0]); - referenceBuilder.sink(r2_5); + referenceBuilder.sink(r2S5); reference = referenceBuilder.finalize({cr0, cr1, cr2}); @@ -204,38 +205,38 @@ TEST_F(QCOMeasurementLiftingTest, programBuilder.initialize( {programBuilder.getI1Type(), programBuilder.getI1Type()}); - auto q0_0 = programBuilder.allocQubit(); - auto q1_0 = programBuilder.allocQubit(); - auto q2_0 = programBuilder.allocQubit(); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + auto q2S0 = programBuilder.allocQubit(); - auto [q12_0, q0_1] = - programBuilder.ctrl({q1_0, q2_0}, {q0_0}, [&](const ValueRange target) { + auto [q12_0, q0S1] = + programBuilder.ctrl({q1S0, q2S0}, {q0S0}, [&](const ValueRange target) { return SmallVector{programBuilder.x(target[0])}; }); - auto [q1_1, c1] = programBuilder.measure(q12_0[0]); - auto [q2_1, c2] = programBuilder.measure(q12_0[1]); + auto [q1S1, c1] = programBuilder.measure(q12_0[0]); + auto [q2S1, c2] = programBuilder.measure(q12_0[1]); - programBuilder.sink(q0_1[0]); - programBuilder.sink(q1_1); - programBuilder.sink(q2_1); + programBuilder.sink(q0S1[0]); + programBuilder.sink(q1S1); + programBuilder.sink(q2S1); module = programBuilder.finalize({c1, c2}); referenceBuilder.initialize( {programBuilder.getI1Type(), programBuilder.getI1Type()}); - auto r0_0 = referenceBuilder.allocQubit(); - auto r1_0 = referenceBuilder.allocQubit(); - auto r2_0 = referenceBuilder.allocQubit(); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + auto r2S0 = referenceBuilder.allocQubit(); - auto [r1_1, cr1] = referenceBuilder.measure(r1_0); - auto [r2_1, cr2] = referenceBuilder.measure(r2_0); + auto [r1S1, cr1] = referenceBuilder.measure(r1S0); + auto [r2S1, cr2] = referenceBuilder.measure(r2S0); - auto [r12_0, r0_1] = - referenceBuilder.ctrl({r1_1, r2_1}, {r0_0}, [&](const ValueRange target) { + auto [r12_0, r0S1] = + referenceBuilder.ctrl({r1S1, r2S1}, {r0S0}, [&](const ValueRange target) { return SmallVector{referenceBuilder.x(target[0])}; }); - referenceBuilder.sink(r0_1[0]); + referenceBuilder.sink(r0S1[0]); referenceBuilder.sink(r12_0[0]); referenceBuilder.sink(r12_0[1]); reference = referenceBuilder.finalize({cr1, cr2}); @@ -251,31 +252,31 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverControlledParametrizedGate) { programBuilder.initialize( {programBuilder.getI1Type(), programBuilder.getI1Type()}); - auto q0_0 = programBuilder.allocQubit(); - auto q1_0 = programBuilder.allocQubit(); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); - auto [q0_1, q1_1] = programBuilder.crx(std::numbers::pi / 2, q0_0, q1_0); + auto [q0S1, q1S1] = programBuilder.crx(std::numbers::pi / 2, q0S0, q1S0); - auto [q0_2, c0] = programBuilder.measure(q0_1); - auto [q1_2, c1] = programBuilder.measure(q1_1); + auto [q0S2, c0] = programBuilder.measure(q0S1); + auto [q1S2, c1] = programBuilder.measure(q1S1); - programBuilder.sink(q0_2); - programBuilder.sink(q1_2); + programBuilder.sink(q0S2); + programBuilder.sink(q1S2); module = programBuilder.finalize({c0, c1}); referenceBuilder.initialize( {programBuilder.getI1Type(), programBuilder.getI1Type()}); - auto r0_0 = referenceBuilder.allocQubit(); - auto r1_0 = referenceBuilder.allocQubit(); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); - auto [r0_1, cr0] = referenceBuilder.measure(r0_0); + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); - auto [r0_2, r1_1] = referenceBuilder.crx(std::numbers::pi / 2, r0_1, r1_0); + auto [r0S2, r1S1] = referenceBuilder.crx(std::numbers::pi / 2, r0S1, r1S0); - auto [r1_2, cr1] = referenceBuilder.measure(r1_1); + auto [r1S2, cr1] = referenceBuilder.measure(r1S1); - referenceBuilder.sink(r0_2); - referenceBuilder.sink(r1_2); + referenceBuilder.sink(r0S2); + referenceBuilder.sink(r1S2); reference = referenceBuilder.finalize({cr0, cr1}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); @@ -288,20 +289,20 @@ TEST_F(QCOMeasurementLiftingTest, TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleX) { programBuilder.initialize({programBuilder.getI1Type()}); - auto q_0 = programBuilder.allocQubit(); - auto q_1 = programBuilder.x(q_0); - auto [q_2, c] = programBuilder.measure(q_1); - programBuilder.sink(q_2); + auto q0 = programBuilder.allocQubit(); + auto q1 = programBuilder.x(q0); + auto [q2, c] = programBuilder.measure(q1); + programBuilder.sink(q2); module = programBuilder.finalize(c); referenceBuilder.initialize({programBuilder.getI1Type()}); - auto r_0 = referenceBuilder.allocQubit(); - auto true_constant = referenceBuilder.boolConstant(true); - auto [r_1, cr] = referenceBuilder.measure(r_0); + auto r0 = referenceBuilder.allocQubit(); + auto trueConstant = referenceBuilder.boolConstant(true); + auto [r1, cr] = referenceBuilder.measure(r0); auto xorOp = arith::XOrIOp::create( - referenceBuilder, referenceBuilder.getLoc(), cr, true_constant); - referenceBuilder.sink(r_1); + referenceBuilder, referenceBuilder.getLoc(), cr, trueConstant); + referenceBuilder.sink(r1); reference = referenceBuilder.finalize(xorOp.getResult()); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); @@ -313,19 +314,19 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleX) { TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleY) { programBuilder.initialize({programBuilder.getI1Type()}); - auto q_0 = programBuilder.allocQubit(); - auto q_1 = programBuilder.y(q_0); - auto [q_2, c] = programBuilder.measure(q_1); - programBuilder.sink(q_2); + auto q0 = programBuilder.allocQubit(); + auto q1 = programBuilder.y(q0); + auto [q2, c] = programBuilder.measure(q1); + programBuilder.sink(q2); module = programBuilder.finalize({c}); referenceBuilder.initialize({programBuilder.getI1Type()}); - auto r_0 = referenceBuilder.allocQubit(); - auto true_constant = referenceBuilder.boolConstant(true); - auto [r_1, cr] = referenceBuilder.measure(r_0); + auto r0 = referenceBuilder.allocQubit(); + auto trueConstant = referenceBuilder.boolConstant(true); + auto [r1, cr] = referenceBuilder.measure(r0); auto xorOp = arith::XOrIOp::create( - referenceBuilder, referenceBuilder.getLoc(), cr, true_constant); - referenceBuilder.sink(r_1); + referenceBuilder, referenceBuilder.getLoc(), cr, trueConstant); + referenceBuilder.sink(r1); reference = referenceBuilder.finalize({xorOp.getResult()}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); @@ -337,23 +338,23 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverSingleY) { TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPhaseGates) { programBuilder.initialize({programBuilder.getI1Type()}); - auto q_0 = programBuilder.allocQubit(); - auto q_1 = programBuilder.id(q_0); - auto q_2 = programBuilder.z(q_1); - auto q_3 = programBuilder.s(q_2); - auto q_4 = programBuilder.sdg(q_3); - auto q_5 = programBuilder.t(q_4); - auto q_6 = programBuilder.tdg(q_5); - auto q_7 = programBuilder.p(std::numbers::pi / 2, q_6); - auto q_8 = programBuilder.rz(std::numbers::pi / 2, q_7); - auto [q_9, c] = programBuilder.measure(q_8); - programBuilder.sink(q_9); + auto q0 = programBuilder.allocQubit(); + auto q1 = programBuilder.id(q0); + auto q2 = programBuilder.z(q1); + auto q3 = programBuilder.s(q2); + auto q4 = programBuilder.sdg(q3); + auto q5 = programBuilder.t(q4); + auto q6 = programBuilder.tdg(q5); + auto q7 = programBuilder.p(std::numbers::pi / 2, q6); + auto q8 = programBuilder.rz(std::numbers::pi / 2, q7); + auto [q9, c] = programBuilder.measure(q8); + programBuilder.sink(q9); module = programBuilder.finalize({c}); referenceBuilder.initialize({programBuilder.getI1Type()}); - auto r_0 = referenceBuilder.allocQubit(); - auto [r_1, cr] = referenceBuilder.measure(r_0); - referenceBuilder.sink(r_1); + auto r0 = referenceBuilder.allocQubit(); + auto [r1, cr] = referenceBuilder.measure(r0); + referenceBuilder.sink(r1); reference = referenceBuilder.finalize({cr}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); @@ -365,17 +366,17 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverPhaseGates) { TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverMultipleXY) { programBuilder.initialize({programBuilder.getI1Type()}); - auto q_0 = programBuilder.allocQubit(); - auto q_1 = programBuilder.x(q_0); - auto q_2 = programBuilder.y(q_1); - auto [q_3, c] = programBuilder.measure(q_2); - programBuilder.sink(q_3); + auto q0 = programBuilder.allocQubit(); + auto q1 = programBuilder.x(q0); + auto q2 = programBuilder.y(q1); + auto [q3, c] = programBuilder.measure(q2); + programBuilder.sink(q3); module = programBuilder.finalize({c}); referenceBuilder.initialize({programBuilder.getI1Type()}); - auto r_0 = referenceBuilder.allocQubit(); - auto [r_1, cr] = referenceBuilder.measure(r_0); - referenceBuilder.sink(r_1); + auto r0 = referenceBuilder.allocQubit(); + auto [r1, cr] = referenceBuilder.measure(r0); + referenceBuilder.sink(r1); reference = referenceBuilder.finalize({cr}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); @@ -387,32 +388,32 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverMultipleXY) { TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverXAndControlledGates) { programBuilder.initialize({programBuilder.getI1Type()}); - auto q0_0 = programBuilder.allocQubit(); - auto q1_0 = programBuilder.allocQubit(); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); - auto [q0_1, q1_1] = programBuilder.cy(q0_0, q1_0); - auto q0_2 = programBuilder.x(q0_1); - auto [q0_3, q1_2] = programBuilder.cy(q0_2, q1_1); - auto q0_4 = programBuilder.x(q0_3); + auto [q0S1, q1S1] = programBuilder.cy(q0S0, q1S0); + auto q0S2 = programBuilder.x(q0S1); + auto [q0S3, q1S2] = programBuilder.cy(q0S2, q1S1); + auto q0S4 = programBuilder.x(q0S3); - auto [q0_5, c0] = programBuilder.measure(q0_4); + auto [q0S5, c0] = programBuilder.measure(q0S4); - programBuilder.sink(q0_5); - programBuilder.sink(q1_2); + programBuilder.sink(q0S5); + programBuilder.sink(q1S2); module = programBuilder.finalize({c0}); referenceBuilder.initialize({programBuilder.getI1Type()}); - auto r0_0 = referenceBuilder.allocQubit(); - auto r1_0 = referenceBuilder.allocQubit(); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); - auto [r0_1, cr0] = referenceBuilder.measure(r0_0); + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); - auto [r0_2, r1_1] = referenceBuilder.cx(r0_1, r1_0); - auto r0_3 = referenceBuilder.x(r0_2); - auto [r0_4, r1_2] = referenceBuilder.cx(r0_3, r1_1); + auto [r0S2, r1S1] = referenceBuilder.cx(r0S1, r1S0); + auto r0S3 = referenceBuilder.x(r0S2); + auto [r0S4, r1S2] = referenceBuilder.cx(r0S3, r1S1); - referenceBuilder.sink(r0_4); - referenceBuilder.sink(r1_2); + referenceBuilder.sink(r0S4); + referenceBuilder.sink(r1S2); reference = referenceBuilder.finalize({cr0}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); @@ -424,25 +425,25 @@ TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverXAndControlledGates) { TEST_F(QCOMeasurementLiftingTest, liftMeasurementOverDiagonalGateInControl) { programBuilder.initialize({programBuilder.getI1Type()}); - auto q0_0 = programBuilder.allocQubit(); - auto q1_0 = programBuilder.allocQubit(); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); - auto [q0_1, q1_1] = programBuilder.cz(q0_0, q1_0); + auto [q0S1, q1S1] = programBuilder.cz(q0S0, q1S0); - auto [q0_2, c0] = programBuilder.measure(q0_1); + auto [q0S2, c0] = programBuilder.measure(q0S1); - programBuilder.sink(q0_2); - programBuilder.sink(q1_1); + programBuilder.sink(q0S2); + programBuilder.sink(q1S1); module = programBuilder.finalize({c0}); referenceBuilder.initialize({programBuilder.getI1Type()}); - auto r0_0 = referenceBuilder.allocQubit(); - auto r1_0 = referenceBuilder.allocQubit(); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); - auto [r0_1, cr0] = referenceBuilder.measure(r0_0); + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); - referenceBuilder.sink(r0_1); - referenceBuilder.sink(r1_0); + referenceBuilder.sink(r0S1); + referenceBuilder.sink(r1S0); reference = referenceBuilder.finalize({cr0}); ASSERT_TRUE(runMeasurementLiftingPass(module.get()).succeeded()); From 90706126f497da5c718caa8b7bf678641f160d62 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Tue, 2 Jun 2026 17:29:08 +0200 Subject: [PATCH 19/21] feat(mlir): :sparkles: implement classical control replacement pattern and add first tests --- .../mlir/Dialect/QCO/Transforms/Passes.td | 18 + .../ReplaceClassicalControls.cpp | 193 +++++++ .../Transforms/Optimizations/CMakeLists.txt | 7 +- .../test_qco_replace_classical_controls.cpp | 491 ++++++++++++++++++ 4 files changed, 707 insertions(+), 2 deletions(-) create mode 100644 mlir/lib/Dialect/QCO/Transforms/Optimizations/ReplaceClassicalControls.cpp create mode 100644 mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_replace_classical_controls.cpp diff --git a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td index bad2fb4f9c..c23e18f988 100644 --- a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td @@ -208,4 +208,22 @@ def MeasurementLifting : Pass<"measurement-lifting", "mlir::ModuleOp"> { }]; } +def ReplaceClassicalControls + : Pass<"replace-classical-controls", "mlir::ModuleOp"> { + let dependentDialects = ["mlir::qco::QCODialect", + "::mlir::arith::ArithDialect", + ]; + let summary = + "This pass attempts to replace controlled gates that are controlled by " + "values that are in a computational basis state and available as " + "measurement outcome by a classical " + "`qco.if` instruction. This is done to enable qubit reuse and other " + "optimizations."; + let description = [{ + This pass searches for control operations that immediately follow measurements and replaces them + with classically controlled operations, represented by `qco.if` operations. + This reduces quantum interactions and allows for more optimizations, e.g., by enabling qubit reuse. + }]; +} + #endif // MLIR_DIALECT_QCO_TRANSFORMS_PASSES_TD diff --git a/mlir/lib/Dialect/QCO/Transforms/Optimizations/ReplaceClassicalControls.cpp b/mlir/lib/Dialect/QCO/Transforms/Optimizations/ReplaceClassicalControls.cpp new file mode 100644 index 0000000000..d9aa504513 --- /dev/null +++ b/mlir/lib/Dialect/QCO/Transforms/Optimizations/ReplaceClassicalControls.cpp @@ -0,0 +1,193 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +// +// Created by damian on 5/13/26. +// + +#include "mlir/Dialect/QCO/IR/QCOInterfaces.h" +#include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace mlir::qco { + +#define GEN_PASS_DEF_REPLACECLASSICALCONTROLS +#include "mlir/Dialect/QCO/Transforms/Passes.h.inc" + +/** + * @brief Retrieves the measurement outcome that directly precedes the given + * qubit, if it exists. + * @param qubit The qubit for which to find the predecessor measurement outcome + * @return The measurement outcome if a predecessor measurement exists, nullptr + * otherwise + */ +static Value getPredecessorMeasurementOutcome(Value qubit) { + auto* definingOp = qubit.getDefiningOp(); + if (auto measureOp = dyn_cast_or_null(definingOp)) { + return measureOp.getResult(); + } + return nullptr; +} + +/** + * @brief Checks if the given operation is a diagonal gate, i.e., it only + * applies a phase to the target qubit(s) and does not change their state. + * @param op The operation to check + * @return true if the operation is a diagonal gate, false otherwise + */ +static bool isDiagonal(Operation* op) { + if (auto i = dyn_cast(op)) { + return isDiagonal(i.getBodyUnitary()); + } + return isa(op); +} + +/** + * @brief For a diagonal gate with a control that has a predecessor measurement, + * swaps the control with the target. + * @param op The control operation containing the diagonal gate + * @param rewriter The pattern rewriter used to perform the transformation + */ +static void trySwapControlsOfDiagonalGate(CtrlOp op, + mlir::PatternRewriter& rewriter) { + assert(op.getBodyUnitary().getNumQubits() == 1 && + "Only single-qubit gates can be swapped around controls"); + auto target = op.getTargetsIn()[0]; + auto predecessorOutcome = getPredecessorMeasurementOutcome(target); + if (!predecessorOutcome) { + // No advantage gained from swapping. + return; + } + for (auto control : op.getControlsIn()) { + auto controlOutcome = getPredecessorMeasurementOutcome(control); + if (controlOutcome) { + continue; + } + rewriter.replaceAllUsesWith(control, target); + rewriter.modifyOpInPlace( + op, [&]() { op.getTargetsInMutable()[0].set(control); }); + auto dummyTarget = AllocOp::create(rewriter, op->getLoc()); + rewriter.replaceAllUsesWith(op.getOutputForInput(target), dummyTarget); + rewriter.replaceAllUsesWith(op.getOutputForInput(control), + op.getOutputForInput(target)); + rewriter.replaceAllUsesWith(dummyTarget, op.getOutputForInput(control)); + rewriter.eraseOp(dummyTarget); + break; + } +} + +namespace { +/** + * @brief This pattern is responsible for replacing controls after measurements + * with `if` constructs. + */ +struct ReplaceBasisStateControlsWithIfPattern final + : mlir::OpRewritePattern { + + explicit ReplaceBasisStateControlsWithIfPattern(mlir::MLIRContext* context) + : OpRewritePattern(context) {} + + mlir::LogicalResult + matchAndRewrite(CtrlOp op, mlir::PatternRewriter& rewriter) const override { + if (isDiagonal(op.getBodyUnitary())) { + trySwapControlsOfDiagonalGate(op, rewriter); + } + + SmallVector> toReplace; + SmallVector toKeep; + + for (const auto& operand : op.getControlsIn()) { + auto outcome = getPredecessorMeasurementOutcome(operand); + if (outcome) { + toReplace.emplace_back(operand, outcome); + } else { + toKeep.push_back(operand); + } + } + + if (toReplace.empty()) { + return mlir::failure(); + } + + auto condition = std::accumulate( + std::next(toReplace.begin()), toReplace.end(), + toReplace.begin()->second, [&](const auto& acc, const auto& pair) { + auto conjunction = + arith::AndIOp::create(rewriter, op.getLoc(), acc, pair.second); + return conjunction.getResult(); + }); + + auto allQubits = toKeep; + llvm::append_range(allQubits, op.getTargetsIn()); + + auto ifOp = IfOp::create( + rewriter, op->getLoc(), condition, allQubits, + [&](ValueRange qubits) -> SmallVector { + auto newControls = qubits.slice(0, toKeep.size()); + auto newTargets = + qubits.slice(toKeep.size(), qubits.size() - toKeep.size()); + + auto newCtrl = + CtrlOp::create(rewriter, op->getLoc(), newControls, newTargets); + rewriter.inlineRegionBefore(op.getRegion(), newCtrl.getRegion(), + newCtrl.getRegion().begin()); + return newCtrl.getOutputQubits(); + }); + + for (auto replace : toReplace) { + rewriter.replaceAllUsesWith(op.getOutputForInput(replace.first), + replace.first); + } + for (auto [oldInput, result] : llvm::zip(allQubits, ifOp.getResults())) { + rewriter.replaceAllUsesWith(op.getOutputForInput(oldInput), result); + } + rewriter.eraseOp(op); + + return mlir::success(); + } +}; + +/** + * @brief Pass replaces controls with `IfOp` operations if the qubits' + * control values are available classically. + */ +struct ReplaceClassicalControls final + : impl::ReplaceClassicalControlsBase { + using ReplaceClassicalControlsBase::ReplaceClassicalControlsBase; + +protected: + void runOnOperation() override { + const auto op = getOperation(); + auto* ctx = &getContext(); + + // Define the set of patterns to use. + RewritePatternSet patterns(ctx); + patterns.add(patterns.getContext()); + + // Apply patterns in an iterative and greedy manner. + if (failed(applyPatternsGreedily(op, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace mlir::qco diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt index bdc78401c8..5a4180196a 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/CMakeLists.txt @@ -7,8 +7,11 @@ # Licensed under the MIT License set(target_name mqt-core-mlir-unittest-optimizations) -add_executable(${target_name} test_qco_hadamard_lifting.cpp test_qco_measurement_lifting.cpp - test_qco_merge_single_qubit_rotation.cpp test_quantum_loop_unroll.cpp) +add_executable( + ${target_name} + test_qco_hadamard_lifting.cpp test_qco_measurement_lifting.cpp + test_qco_merge_single_qubit_rotation.cpp test_qco_replace_classical_controls.cpp + test_quantum_loop_unroll.cpp) target_link_libraries( ${target_name} diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_replace_classical_controls.cpp b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_replace_classical_controls.cpp new file mode 100644 index 0000000000..129a887f0d --- /dev/null +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_replace_classical_controls.cpp @@ -0,0 +1,491 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +// +// Created by damian on 5/21/26. +// + +#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" +#include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" +#include "mlir/Support/IRVerification.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace { + +using namespace mlir; +using namespace mlir::qco; + +class QCOReplaceClassicalControlsTest : public testing::Test { + +protected: + MLIRContext context; + QCOProgramBuilder programBuilder; + QCOProgramBuilder referenceBuilder; + OwningOpRef module; + OwningOpRef reference; + + QCOReplaceClassicalControlsTest() + : programBuilder(&context), referenceBuilder(&context) {} + + void SetUp() override { + // Register all necessary dialects + DialectRegistry registry; + registry.insert(); + context.appendDialectRegistry(registry); + context.loadAllAvailableDialects(); + } + + /** + * @brief Adds the replaceClassicalControls pass to the current context and + * runs it. + */ + static LogicalResult + runReplaceClassicalControlsPass(ModuleOp module, + bool liftMeasurements = false) { + PassManager pm(module.getContext()); + pm.addPass(createReplaceClassicalControls()); + if (liftMeasurements) { + pm.addPass(createMeasurementLifting()); + } + pm.addPass(createCanonicalizerPass()); + return pm.run(module); + } + + /** + * @brief Adds the canonicalizerPass to the current context and runs it. + */ + static LogicalResult runCanonicalizerPass(ModuleOp module) { + PassManager pm(module.getContext()); + pm.addPass(createCanonicalizerPass()); + return pm.run(module); + } +}; + +} // namespace + +TEST_F(QCOReplaceClassicalControlsTest, replaceClassicalControlsOnlyControl) { + programBuilder.initialize( + {programBuilder.getI1Type(), programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + + auto [q0S1, c0] = programBuilder.measure(q0S0); + auto [q0S2, q1S1] = programBuilder.cx(q0S1, q1S0); + auto [q1S2, c1] = programBuilder.measure(q1S1); + + programBuilder.sink(q0S2); + programBuilder.sink(q1S2); + module = programBuilder.finalize({c0, c1}); + + referenceBuilder.initialize( + {referenceBuilder.getI1Type(), referenceBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); + + auto r1S1 = referenceBuilder.qcoIf( + cr0, {r1S0}, [&](ValueRange qubits) -> SmallVector { + auto q1Then = referenceBuilder.x(qubits[0]); + return SmallVector{q1Then}; + })[0]; + auto [r1S2, cr1] = referenceBuilder.measure(r1S1); + referenceBuilder.sink(r0S1); + referenceBuilder.sink(r1S2); + + reference = referenceBuilder.finalize({cr0, cr1}); + + ASSERT_TRUE(runReplaceClassicalControlsPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOReplaceClassicalControlsTest, + replaceClassicalControlsOneOfTwoControls) { + programBuilder.initialize({programBuilder.getI1Type(), + programBuilder.getI1Type(), + programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + auto q2S0 = programBuilder.allocQubit(); + + auto [q0S1, c0] = programBuilder.measure(q0S0); + + auto [q01S1, q2S1] = + programBuilder.ctrl({q0S1, q1S0}, {q2S0}, + [&](const ValueRange targets) -> SmallVector { + auto q = programBuilder.x(targets[0]); + return SmallVector{q}; + }); + + auto [q1S2, c1] = programBuilder.measure(q01S1[1]); + auto [q2S2, c2] = programBuilder.measure(q2S1[0]); + + programBuilder.sink(q01S1[0]); + programBuilder.sink(q1S2); + programBuilder.sink(q2S2); + module = programBuilder.finalize({c0, c1, c2}); + + referenceBuilder.initialize({referenceBuilder.getI1Type(), + referenceBuilder.getI1Type(), + referenceBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + auto r2S0 = referenceBuilder.allocQubit(); + + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); + + auto r12 = referenceBuilder.qcoIf( + cr0, {r1S0, r2S0}, [&](ValueRange qubits) -> SmallVector { + auto [r1, r2] = referenceBuilder.cx(qubits[0], qubits[1]); + return SmallVector{r1, r2}; + }); + auto [r1S2, cr1] = referenceBuilder.measure(r12[0]); + auto [r2S2, cr2] = referenceBuilder.measure(r12[1]); + referenceBuilder.sink(r0S1); + referenceBuilder.sink(r1S2); + referenceBuilder.sink(r2S2); + + reference = referenceBuilder.finalize({cr0, cr1, cr2}); + + ASSERT_TRUE(runReplaceClassicalControlsPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOReplaceClassicalControlsTest, + replaceClassicalControlsTwoOfTwoControls) { + programBuilder.initialize({programBuilder.getI1Type(), + programBuilder.getI1Type(), + programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + auto q2S0 = programBuilder.allocQubit(); + + auto [q0S1, c0] = programBuilder.measure(q0S0); + auto [q1S1, c1] = programBuilder.measure(q1S0); + + auto [q01S1, q2S1] = + programBuilder.ctrl({q0S1, q1S1}, {q2S0}, + [&](const ValueRange targets) -> SmallVector { + auto q = programBuilder.x(targets[0]); + return SmallVector{q}; + }); + + auto [q2S2, c2] = programBuilder.measure(q2S1[0]); + + programBuilder.sink(q01S1[0]); + programBuilder.sink(q01S1[1]); + programBuilder.sink(q2S2); + module = programBuilder.finalize({c0, c1, c2}); + + referenceBuilder.initialize({referenceBuilder.getI1Type(), + referenceBuilder.getI1Type(), + referenceBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + auto r2S0 = referenceBuilder.allocQubit(); + + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); + auto [r1S1, cr1] = referenceBuilder.measure(r1S0); + + auto andOp = arith::AndIOp::create(referenceBuilder, cr0, cr1); + + auto r2S1 = referenceBuilder.qcoIf( + andOp.getResult(), {r2S0}, [&](ValueRange qubits) -> SmallVector { + auto r = referenceBuilder.x(qubits[0]); + return SmallVector{r}; + })[0]; + auto [r2S2, cr2] = referenceBuilder.measure(r2S1); + referenceBuilder.sink(r0S1); + referenceBuilder.sink(r1S1); + referenceBuilder.sink(r2S2); + + reference = referenceBuilder.finalize({cr0, cr1, cr2}); + + ASSERT_TRUE(runReplaceClassicalControlsPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOReplaceClassicalControlsTest, + replaceClassicalControlsTwoOfThreeControls) { + programBuilder.initialize( + {programBuilder.getI1Type(), programBuilder.getI1Type(), + programBuilder.getI1Type(), programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + auto q2S0 = programBuilder.allocQubit(); + auto q3S0 = programBuilder.allocQubit(); + + auto [q0S1, c0] = programBuilder.measure(q0S0); + auto [q1S1, c1] = programBuilder.measure(q1S0); + + auto [q012S1, q3S1] = + programBuilder.ctrl({q0S1, q1S1, q2S0}, {q3S0}, + [&](const ValueRange targets) -> SmallVector { + auto q = programBuilder.x(targets[0]); + return SmallVector{q}; + }); + + auto [q2S2, c2] = programBuilder.measure(q012S1[2]); + auto [q3S2, c3] = programBuilder.measure(q3S1[0]); + + programBuilder.sink(q012S1[0]); + programBuilder.sink(q012S1[1]); + programBuilder.sink(q2S2); + programBuilder.sink(q3S2); + module = programBuilder.finalize({c0, c1, c2, c3}); + + referenceBuilder.initialize( + {referenceBuilder.getI1Type(), referenceBuilder.getI1Type(), + referenceBuilder.getI1Type(), referenceBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + auto r2S0 = referenceBuilder.allocQubit(); + auto r3S0 = referenceBuilder.allocQubit(); + + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); + auto [r1S1, cr1] = referenceBuilder.measure(r1S0); + + auto andOp = arith::AndIOp::create(referenceBuilder, cr0, cr1); + + auto r23S1 = + referenceBuilder.qcoIf(andOp.getResult(), {r2S0, r3S0}, + [&](ValueRange qubits) -> SmallVector { + auto [r2, r3] = + referenceBuilder.cx(qubits[0], qubits[1]); + return SmallVector{r2, r3}; + }); + auto [r2S2, cr2] = referenceBuilder.measure(r23S1[0]); + auto [r3S2, cr3] = referenceBuilder.measure(r23S1[1]); + referenceBuilder.sink(r0S1); + referenceBuilder.sink(r1S1); + referenceBuilder.sink(r2S2); + referenceBuilder.sink(r3S2); + + reference = referenceBuilder.finalize({cr0, cr1, cr2, cr3}); + + ASSERT_TRUE(runReplaceClassicalControlsPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOReplaceClassicalControlsTest, replaceClassicalControlsSwapDiagonal) { + programBuilder.initialize( + {programBuilder.getI1Type(), programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + + auto [q0S1, c0] = programBuilder.measure(q0S0); + auto [q1S1, q0S2] = programBuilder.cz(q1S0, q0S1); + auto [q1S2, c1] = programBuilder.measure(q1S1); + + programBuilder.sink(q0S2); + programBuilder.sink(q1S2); + module = programBuilder.finalize({c0, c1}); + + referenceBuilder.initialize( + {referenceBuilder.getI1Type(), referenceBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); + + auto r1S1 = referenceBuilder.qcoIf( + cr0, {r1S0}, [&](ValueRange qubits) -> SmallVector { + auto q1Then = referenceBuilder.z(qubits[0]); + return SmallVector{q1Then}; + })[0]; + auto [r1S2, cr1] = referenceBuilder.measure(r1S1); + referenceBuilder.sink(r0S1); + referenceBuilder.sink(r1S2); + + reference = referenceBuilder.finalize({cr0, cr1}); + + ASSERT_TRUE(runReplaceClassicalControlsPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOReplaceClassicalControlsTest, + replaceClassicalControlsDontSwapDiagonalIfNotNecessary) { + programBuilder.initialize({programBuilder.getI1Type(), + programBuilder.getI1Type(), + programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + + auto [q0S1, c0] = programBuilder.measure(q0S0); + auto [q1S1, c1] = programBuilder.measure(q1S0); + auto [q1S2, q0S2] = programBuilder.cz(q1S1, q0S1); + auto [q0S3, c0_] = programBuilder.measure(q0S2); + + programBuilder.sink(q0S3); + programBuilder.sink(q1S2); + module = programBuilder.finalize({c0, c1, c0_}); + + referenceBuilder.initialize({referenceBuilder.getI1Type(), + referenceBuilder.getI1Type(), + programBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); + auto [r1S1, cr1] = referenceBuilder.measure(r1S0); + + auto r0S2 = referenceBuilder.qcoIf( + cr1, {r0S1}, [&](ValueRange qubits) -> SmallVector { + auto r0Then = referenceBuilder.z(qubits[0]); + return SmallVector{r0Then}; + })[0]; + auto [r0S3, cr0_] = referenceBuilder.measure(r0S2); + referenceBuilder.sink(r0S3); + referenceBuilder.sink(r1S1); + + reference = referenceBuilder.finalize({cr0, cr1, cr0_}); + + ASSERT_TRUE(runReplaceClassicalControlsPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + module->dump(); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOReplaceClassicalControlsTest, + replaceClassicalControlsSwapOneOfTwoDiagonal) { + programBuilder.initialize( + {programBuilder.getI1Type(), programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + auto q2S0 = programBuilder.allocQubit(); + + auto [q0S1, c0] = programBuilder.measure(q0S0); + auto [q12, q0S2] = + programBuilder.ctrl({q1S0, q2S0}, {q0S1}, + [&](const ValueRange targets) -> SmallVector { + auto q = programBuilder.z(targets[0]); + return SmallVector{q}; + }); + auto [q1S2, c1] = programBuilder.measure(q12[0]); + + programBuilder.sink(q0S2[0]); + programBuilder.sink(q1S2); + programBuilder.sink(q12[1]); + module = programBuilder.finalize({c0, c1}); + + referenceBuilder.initialize( + {referenceBuilder.getI1Type(), referenceBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + auto r2S0 = referenceBuilder.allocQubit(); + + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); + + auto r21 = referenceBuilder.qcoIf( + cr0, {r2S0, r1S0}, [&](ValueRange qubits) -> SmallVector { + auto [r2, r1] = referenceBuilder.cz(qubits[0], qubits[1]); + return SmallVector{r2, r1}; + }); + auto [r1S2, cr1] = referenceBuilder.measure(r21[1]); + referenceBuilder.sink(r0S1); + referenceBuilder.sink(r1S2); + referenceBuilder.sink(r21[0]); + + reference = referenceBuilder.finalize({cr0, cr1}); + + ASSERT_TRUE(runReplaceClassicalControlsPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} + +TEST_F(QCOReplaceClassicalControlsTest, + replaceClassicalControlsSwapOnlyPossibleDiagonal) { + programBuilder.initialize({programBuilder.getI1Type(), + programBuilder.getI1Type(), + programBuilder.getI1Type()}); + auto q0S0 = programBuilder.allocQubit(); + auto q1S0 = programBuilder.allocQubit(); + auto q2S0 = programBuilder.allocQubit(); + + auto [q0S1, c0] = programBuilder.measure(q0S0); + auto [q1S1, c1] = programBuilder.measure(q1S0); + auto [q12, q0S2] = + programBuilder.ctrl({q1S1, q2S0}, {q0S1}, + [&](const ValueRange targets) -> SmallVector { + auto q = programBuilder.z(targets[0]); + return SmallVector{q}; + }); + auto [q2S2, c2] = programBuilder.measure(q12[1]); + + programBuilder.sink(q0S2[0]); + programBuilder.sink(q12[0]); + programBuilder.sink(q2S2); + module = programBuilder.finalize({c0, c1, c2}); + + referenceBuilder.initialize({referenceBuilder.getI1Type(), + referenceBuilder.getI1Type(), + referenceBuilder.getI1Type()}); + auto r0S0 = referenceBuilder.allocQubit(); + auto r1S0 = referenceBuilder.allocQubit(); + auto r2S0 = referenceBuilder.allocQubit(); + + auto [r0S1, cr0] = referenceBuilder.measure(r0S0); + auto [r1S1, cr1] = referenceBuilder.measure(r1S0); + + auto andOp = arith::AndIOp::create(referenceBuilder, cr1, cr0); + + auto r2S1 = referenceBuilder.qcoIf( + andOp.getResult(), {r2S0}, [&](ValueRange qubits) -> SmallVector { + auto r = referenceBuilder.z(qubits[0]); + return SmallVector{r}; + })[0]; + auto [r2S2, cr2] = referenceBuilder.measure(r2S1); + referenceBuilder.sink(r0S1); + referenceBuilder.sink(r1S1); + referenceBuilder.sink(r2S2); + + reference = referenceBuilder.finalize({cr0, cr1, cr2}); + + ASSERT_TRUE(runReplaceClassicalControlsPass(module.get()).succeeded()); + ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); + + EXPECT_TRUE( + areModulesEquivalentWithPermutations(module.get(), reference.get())); +} From c1fda941ec958c2ce928531f8e60043b713c821c Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Tue, 2 Jun 2026 18:15:23 +0200 Subject: [PATCH 20/21] style(mlir): :rotating_light: fix includes --- .../QCO/Transforms/Optimizations/ReplaceClassicalControls.cpp | 3 +++ .../Optimizations/test_qco_replace_classical_controls.cpp | 3 --- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/QCO/Transforms/Optimizations/ReplaceClassicalControls.cpp b/mlir/lib/Dialect/QCO/Transforms/Optimizations/ReplaceClassicalControls.cpp index d9aa504513..353d36e0d6 100644 --- a/mlir/lib/Dialect/QCO/Transforms/Optimizations/ReplaceClassicalControls.cpp +++ b/mlir/lib/Dialect/QCO/Transforms/Optimizations/ReplaceClassicalControls.cpp @@ -24,6 +24,9 @@ #include #include +#include +#include +#include #include namespace mlir::qco { diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_replace_classical_controls.cpp b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_replace_classical_controls.cpp index 129a887f0d..11d5304cc5 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_replace_classical_controls.cpp +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_replace_classical_controls.cpp @@ -20,7 +20,6 @@ #include #include #include -#include #include #include #include @@ -30,8 +29,6 @@ #include #include -#include - namespace { using namespace mlir; From ad1be84e1ab39bcea17b63416ee6e0c85b3cb375 Mon Sep 17 00:00:00 2001 From: Damian Rovara Date: Wed, 3 Jun 2026 12:50:34 +0200 Subject: [PATCH 21/21] style(mlir): :recycle: apply coderabbit suggestions --- mlir/include/mlir/Dialect/QCO/QCOUtils.h | 2 ++ .../Optimizations/ReplaceClassicalControls.cpp | 10 +++++----- .../test_qco_replace_classical_controls.cpp | 2 -- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/QCOUtils.h b/mlir/include/mlir/Dialect/QCO/QCOUtils.h index 9c6af7a8f4..53639b3dc0 100644 --- a/mlir/include/mlir/Dialect/QCO/QCOUtils.h +++ b/mlir/include/mlir/Dialect/QCO/QCOUtils.h @@ -10,6 +10,8 @@ #pragma once +#include "mlir/Dialect/QCO/IR/QCOOps.h" + #include #include #include diff --git a/mlir/lib/Dialect/QCO/Transforms/Optimizations/ReplaceClassicalControls.cpp b/mlir/lib/Dialect/QCO/Transforms/Optimizations/ReplaceClassicalControls.cpp index 353d36e0d6..49651d7424 100644 --- a/mlir/lib/Dialect/QCO/Transforms/Optimizations/ReplaceClassicalControls.cpp +++ b/mlir/lib/Dialect/QCO/Transforms/Optimizations/ReplaceClassicalControls.cpp @@ -50,14 +50,14 @@ static Value getPredecessorMeasurementOutcome(Value qubit) { } /** - * @brief Checks if the given operation is a diagonal gate, i.e., it only - * applies a phase to the target qubit(s) and does not change their state. + * @brief Checks if the given operation is a phase gate, i.e., it only + * applies a phase to the target qubit(s) in the 1 state. * @param op The operation to check * @return true if the operation is a diagonal gate, false otherwise */ -static bool isDiagonal(Operation* op) { +static bool isPhaseGate(Operation* op) { if (auto i = dyn_cast(op)) { - return isDiagonal(i.getBodyUnitary()); + return isPhaseGate(i.getBodyUnitary()); } return isa(op); } @@ -109,7 +109,7 @@ struct ReplaceBasisStateControlsWithIfPattern final mlir::LogicalResult matchAndRewrite(CtrlOp op, mlir::PatternRewriter& rewriter) const override { - if (isDiagonal(op.getBodyUnitary())) { + if (isPhaseGate(op.getBodyUnitary())) { trySwapControlsOfDiagonalGate(op, rewriter); } diff --git a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_replace_classical_controls.cpp b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_replace_classical_controls.cpp index 11d5304cc5..4ff48c57a1 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_replace_classical_controls.cpp +++ b/mlir/unittests/Dialect/QCO/Transforms/Optimizations/test_qco_replace_classical_controls.cpp @@ -377,8 +377,6 @@ TEST_F(QCOReplaceClassicalControlsTest, ASSERT_TRUE(runReplaceClassicalControlsPass(module.get()).succeeded()); ASSERT_TRUE(runCanonicalizerPass(reference.get()).succeeded()); - module->dump(); - EXPECT_TRUE( areModulesEquivalentWithPermutations(module.get(), reference.get())); }