From 9070fdf867f94e3d0e6874808cc771a27cd9f8ef Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 11 Dec 2025 15:43:22 +0100 Subject: [PATCH 001/108] add support for if else conversion --- .../Conversion/FluxToQuartz/FluxToQuartz.cpp | 53 ++++- .../Conversion/QuartzToFlux/QuartzToFlux.cpp | 193 ++++++++++++++---- 2 files changed, 203 insertions(+), 43 deletions(-) diff --git a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp index 97e5448ed2..9ed723163e 100644 --- a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp +++ b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Flux/IR/FluxDialect.h" #include "mlir/Dialect/Quartz/IR/QuartzDialect.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include #include @@ -820,6 +821,44 @@ struct ConvertFluxYieldOp final : OpConversionPattern { return success(); } }; +struct ConvertFluxScfIfOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::IfOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + auto newIf = + rewriter.create(op.getLoc(), ValueRange{}, op.getCondition(), + op.getElseRegion().empty()); + // inline the regions + rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(), + newIf.getThenRegion().end()); + if (!op.getElseRegion().empty()) { + rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(), + newIf.getElseRegion().end()); + + } + rewriter.eraseBlock(&newIf.getThenRegion().front()); + + auto yield = + dyn_cast(newIf.getThenRegion().back().getTerminator()); + + rewriter.replaceOp(op, yield->getOperands()); + + return success(); + } +}; + +struct ConvertFluxScfYieldOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::YieldOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; /** * @brief Pass implementation for Flux-to-Quartz conversion @@ -862,7 +901,17 @@ struct FluxToQuartz final : impl::FluxToQuartzBase { // Configure conversion target: Flux illegal, Quartz legal target.addIllegalDialect(); target.addLegalDialect(); + target.addDynamicallyLegalOp([&](scf::IfOp op) { + return !llvm::any_of(op->getResultTypes(), [&](Type type) { + return type == flux::QubitType::get(context); + }); + }); + target.addDynamicallyLegalOp([&](scf::YieldOp op) { + return !llvm::any_of(op.getOperandTypes(), [&](Type type) { + return type == flux::QubitType::get(context); + }); + }); // Register operation conversion patterns // Note: No state tracking needed - OpAdaptors handle type conversion patterns @@ -876,8 +925,8 @@ struct FluxToQuartz final : impl::FluxToQuartzBase { ConvertFluxiSWAPOp, ConvertFluxDCXOp, ConvertFluxECROp, ConvertFluxRXXOp, ConvertFluxRYYOp, ConvertFluxRZXOp, ConvertFluxRZZOp, ConvertFluxXXPlusYYOp, ConvertFluxXXMinusYYOp, - ConvertFluxBarrierOp, ConvertFluxCtrlOp, ConvertFluxYieldOp>( - typeConverter, context); + ConvertFluxBarrierOp, ConvertFluxCtrlOp, ConvertFluxYieldOp, + ConvertFluxScfIfOp, ConvertFluxScfYieldOp>(typeConverter, context); // Conversion of flux types in func.func signatures // Note: This currently has limitations with signature changes diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index a7a4b93297..865ce9765d 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include @@ -69,7 +70,9 @@ namespace { */ struct LoweringState { /// Map from original Quartz qubit references to their latest Flux SSA values - llvm::DenseMap qubitMap; + llvm::DenseMap> qubitMap; + /// Map each initial op to its refQubits. + llvm::DenseMap> regionMap; /// Modifier information int64_t inCtrlOp = 0; @@ -107,6 +110,55 @@ class StatefulOpConversionPattern : public OpConversionPattern { LoweringState* state_; }; +llvm::SetVector collectRegionQubits(Operation* op, LoweringState* state, + MLIRContext* ctx) { + + // get the regions of the current operation + const auto regions = op->getRegions(); + SetVector uniqueQubits; + for (auto& region : regions) { + // skip empty regions e.g. empty else region of an If operation + if (region.empty()) { + continue; + } + // iterate over all operations inside the region + // currently assumes that each region only has one block + for (auto& operation : region.front().getOperations()) { + // check if the operation has an region, if yes recursively collect the + // qubits + if (operation.getNumRegions() > 0) { + auto qubits = collectRegionQubits(&operation, state, ctx); + for (auto qubit : qubits) { + uniqueQubits.insert(qubit); + } + } + // collect qubits form the operands + for (auto operand : operation.getOperands()) { + if (operand.getType() == quartz::QubitType::get(ctx)) { + uniqueQubits.insert(operand); + } + } + // collect qubits from the results + for (auto result : operation.getResults()) { + if (result.getType() == quartz::QubitType::get(ctx)) { + uniqueQubits.insert(result); + } + } + if (llvm::isa(operation) && uniqueQubits.size() > 0) { + operation.setAttr("needChange", StringAttr::get(ctx, "yes")); + } + } + } + if (!uniqueQubits.empty() && + (llvm::isa(op) || (llvm::isa(op)) || + llvm::isa(op))) { + state->regionMap[op] = uniqueQubits; + // mark operations that need to be changed afterwards + op->setAttr("needChange", StringAttr::get(ctx, "yes")); + } + return uniqueQubits; +} + /** * @brief Converts a zero-target, one-parameter Quartz operation to Flux * @@ -151,7 +203,7 @@ template LogicalResult convertOneTargetZeroParameter(QuartzOpType& op, ConversionPatternRewriter& rewriter, LoweringState& state) { - auto& qubitMap = state.qubitMap; + auto& qubitMap = state.qubitMap[op->getParentRegion()]; const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubit @@ -194,7 +246,7 @@ template LogicalResult convertOneTargetOneParameter(QuartzOpType& op, ConversionPatternRewriter& rewriter, LoweringState& state) { - auto& qubitMap = state.qubitMap; + auto& qubitMap = state.qubitMap[op->getParentRegion()]; const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubit @@ -238,7 +290,7 @@ template LogicalResult convertOneTargetTwoParameter(QuartzOpType& op, ConversionPatternRewriter& rewriter, LoweringState& state) { - auto& qubitMap = state.qubitMap; + auto& qubitMap = state.qubitMap[op->getParentRegion()]; const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubit @@ -283,7 +335,7 @@ LogicalResult convertOneTargetThreeParameter(QuartzOpType& op, ConversionPatternRewriter& rewriter, LoweringState& state) { - auto& qubitMap = state.qubitMap; + auto& qubitMap = state.qubitMap[op->getParentRegion()]; const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubit @@ -328,7 +380,7 @@ template LogicalResult convertTwoTargetZeroParameter(QuartzOpType& op, ConversionPatternRewriter& rewriter, LoweringState& state) { - auto& qubitMap = state.qubitMap; + auto& qubitMap = state.qubitMap[op->getParentRegion()]; const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubits @@ -379,7 +431,7 @@ template LogicalResult convertTwoTargetOneParameter(QuartzOpType& op, ConversionPatternRewriter& rewriter, LoweringState& state) { - auto& qubitMap = state.qubitMap; + auto& qubitMap = state.qubitMap[op->getParentRegion()]; const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubits @@ -430,7 +482,7 @@ template LogicalResult convertTwoTargetTwoParameter(QuartzOpType& op, ConversionPatternRewriter& rewriter, LoweringState& state) { - auto& qubitMap = state.qubitMap; + auto& qubitMap = state.qubitMap[op->getParentRegion()]; const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubits @@ -519,7 +571,7 @@ struct ConvertQuartzAllocOp final LogicalResult matchAndRewrite(quartz::AllocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& qubitMap = getState().qubitMap; + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; const auto& quartzQubit = op.getResult(); // Create the flux.alloc operation with preserved register metadata @@ -559,7 +611,7 @@ struct ConvertQuartzDeallocOp final LogicalResult matchAndRewrite(quartz::DeallocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& qubitMap = getState().qubitMap; + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; const auto& quartzQubit = op.getQubit(); // Look up the latest Flux value for this Quartz qubit @@ -597,7 +649,7 @@ struct ConvertQuartzStaticOp final LogicalResult matchAndRewrite(quartz::StaticOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& qubitMap = getState().qubitMap; + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; const auto& quartzQubit = op.getQubit(); // Create new flux.static operation with the same index @@ -646,7 +698,7 @@ struct ConvertQuartzMeasureOp final LogicalResult matchAndRewrite(quartz::MeasureOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& qubitMap = getState().qubitMap; + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; const auto& quartzQubit = op.getQubit(); // Get the latest Flux qubit value from the state map @@ -697,7 +749,7 @@ struct ConvertQuartzResetOp final LogicalResult matchAndRewrite(quartz::ResetOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& qubitMap = getState().qubitMap; + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; const auto& quartzQubit = op.getQubit(); // Get the latest Flux qubit value from the state map @@ -1014,7 +1066,7 @@ struct ConvertQuartzBarrierOp final matchAndRewrite(quartz::BarrierOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - auto& qubitMap = state.qubitMap; + auto& qubitMap = state.qubitMap[op->getParentRegion()]; // Get Flux qubits from state map const auto& quartzQubits = op.getQubits(); @@ -1063,7 +1115,7 @@ struct ConvertQuartzCtrlOp final : StatefulOpConversionPattern { matchAndRewrite(quartz::CtrlOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - auto& qubitMap = state.qubitMap; + auto& qubitMap = state.qubitMap[op->getParentRegion()]; // Get Flux controls from state map const auto& quartzControls = op.getControls(); @@ -1142,6 +1194,82 @@ struct ConvertQuartzYieldOp final } }; +struct ConvertScfIfOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(scf::IfOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + auto quartzQubits = getState().regionMap[op]; + SmallVector values; + values.reserve(quartzQubits.size()); + for (auto qubit : quartzQubits) { + values.push_back(qubit); + } + // create result typerange + auto const optType = flux::QubitType::get(rewriter.getContext()); + SmallVector resultTypes; + resultTypes.assign(quartzQubits.size(), optType); + + // create new if operation + auto newIf = rewriter.create( + op->getLoc(), TypeRange{resultTypes}, op.getCondition(), true); + auto& thenRegion = newIf.getThenRegion(); + auto& elseRegion = newIf.getElseRegion(); + // move the regions of the old operations inside the new operation + rewriter.inlineRegionBefore(op.getThenRegion(), thenRegion, + thenRegion.end()); + rewriter.eraseBlock(&thenRegion.front()); + + if (!op.getElseRegion().empty()) { + rewriter.inlineRegionBefore(op.getElseRegion(), elseRegion, + elseRegion.end()); + rewriter.eraseBlock(&elseRegion.front()); + } else { + rewriter.setInsertionPointToEnd(&elseRegion.front()); + const auto elseYield = + rewriter.create(op->getLoc(), values); + elseYield->setAttr("needChange", + StringAttr::get(rewriter.getContext(), "yes")); + } + + auto& thenRegionQubitMap = getState().qubitMap[&thenRegion]; + auto& elseRegionQubitMap = getState().qubitMap[&elseRegion]; + for (const auto& refQubit : quartzQubits) { + thenRegionQubitMap.try_emplace( + refQubit, getState().qubitMap[op->getParentRegion()][refQubit]); + elseRegionQubitMap.try_emplace( + refQubit, getState().qubitMap[op->getParentRegion()][refQubit]); + } + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + for (size_t i = 0; i < newIf->getResults().size(); i++) { + qubitMap[quartzQubits[i]] = newIf->getResult(i); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct ConvertScfYieldOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(scf::YieldOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + + auto* region = op->getParentRegion(); + auto& qubitMap = getState().qubitMap[region]; + + SmallVector optQubits; + for (auto [refQubit, optQubit] : qubitMap) { + optQubits.push_back(optQubit); + } + rewriter.replaceOpWithNewOp(op, optQubits); + return success(); + } +}; + /** * @brief Pass implementation for Quartz-to-Flux conversion * @@ -1179,11 +1307,18 @@ struct QuartzToFlux final : impl::QuartzToFluxBase { RewritePatternSet patterns(context); QuartzToFluxTypeConverter typeConverter(context); + collectRegionQubits(module, &state, context); // Configure conversion target: Quartz illegal, Flux // legal target.addIllegalDialect(); target.addLegalDialect(); + target.addDynamicallyLegalOp([&](scf::YieldOp op) { + return !(op->getAttrOfType("needChange")); + }); + target.addDynamicallyLegalOp([&](scf::IfOp op) { + return !(op->getAttrOfType("needChange")); + }); // Register operation conversion patterns with state // tracking patterns.add< @@ -1198,32 +1333,8 @@ struct QuartzToFlux final : impl::QuartzToFluxBase { ConvertQuartziSWAPOp, ConvertQuartzDCXOp, ConvertQuartzECROp, ConvertQuartzRXXOp, ConvertQuartzRYYOp, ConvertQuartzRZXOp, ConvertQuartzRZZOp, ConvertQuartzXXPlusYYOp, ConvertQuartzXXMinusYYOp, - ConvertQuartzBarrierOp, ConvertQuartzCtrlOp, ConvertQuartzYieldOp>( - typeConverter, context, &state); - - // Conversion of quartz types in func.func signatures - // Note: This currently has limitations with signature - // changes - populateFunctionOpInterfaceTypeConversionPattern( - patterns, typeConverter); - target.addDynamicallyLegalOp([&](func::FuncOp op) { - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()); - }); - - // Conversion of quartz types in func.return - populateReturnOpTypeConversionPattern(patterns, typeConverter); - target.addDynamicallyLegalOp( - [&](const func::ReturnOp op) { return typeConverter.isLegal(op); }); - - // Conversion of quartz types in func.call - populateCallOpTypeConversionPattern(patterns, typeConverter); - target.addDynamicallyLegalOp( - [&](const func::CallOp op) { return typeConverter.isLegal(op); }); - - // Conversion of quartz types in control-flow ops (e.g., - // cf.br, cf.cond_br) - populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); + ConvertQuartzBarrierOp, ConvertQuartzCtrlOp, ConvertQuartzYieldOp, + ConvertScfIfOp, ConvertScfYieldOp>(typeConverter, context, &state); // Apply the conversion if (failed(applyPartialConversion(module, target, std::move(patterns)))) { From 4916ea3a1a9c041111fa24b5f9e63ab3fccf2e72 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 11 Dec 2025 15:59:27 +0100 Subject: [PATCH 002/108] add initial support for scf while conversion --- .../Conversion/FluxToQuartz/FluxToQuartz.cpp | 83 +++++++++++++--- .../Conversion/QuartzToFlux/QuartzToFlux.cpp | 96 +++++++++++++++++-- 2 files changed, 159 insertions(+), 20 deletions(-) diff --git a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp index 9ed723163e..37b7810e22 100644 --- a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp +++ b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp @@ -836,7 +836,6 @@ struct ConvertFluxScfIfOp final : OpConversionPattern { if (!op.getElseRegion().empty()) { rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(), newIf.getElseRegion().end()); - } rewriter.eraseBlock(&newIf.getThenRegion().front()); @@ -848,6 +847,40 @@ struct ConvertFluxScfIfOp final : OpConversionPattern { return success(); } }; +struct ConvertFluxScfWhileOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + + // replace the uses of the blockarguments with the init values + const auto& inits = adaptor.getInits(); + const auto beforeArgs = op.getBeforeArguments(); + const auto afterArgs = op.getAfterArguments(); + for (size_t i = 0; i < beforeArgs.size(); i++) { + beforeArgs[i].replaceAllUsesWith(inits[i]); + afterArgs[i].replaceAllUsesWith(inits[i]); + } + // create the bew while operation + auto newWhileOp = + rewriter.create(op->getLoc(), ValueRange{}, ValueRange{}); + + // create the blocks of the new operation and move the operations to them + auto* newBeforeBlock = + rewriter.createBlock(&newWhileOp.getBefore(), {}, {}, {}); + auto* newAfterBlock = + rewriter.createBlock(&newWhileOp.getAfter(), {}, {}, {}); + newBeforeBlock->getOperations().splice(newBeforeBlock->end(), + op.getBeforeBody()->getOperations()); + newAfterBlock->getOperations().splice(newAfterBlock->end(), + op.getAfterBody()->getOperations()); + + // replace the result values with the init values + rewriter.replaceOp(op, adaptor.getInits()); + return success(); + } +}; struct ConvertFluxScfYieldOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -859,7 +892,18 @@ struct ConvertFluxScfYieldOp final : OpConversionPattern { return success(); } }; +struct ConvertFluxScfConditionOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(scf::ConditionOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + + rewriter.replaceOpWithNewOp(op, op.getCondition(), + ValueRange{}); + return success(); + } +}; /** * @brief Pass implementation for Flux-to-Quartz conversion * @@ -912,21 +956,32 @@ struct FluxToQuartz final : impl::FluxToQuartzBase { return type == flux::QubitType::get(context); }); }); + target.addDynamicallyLegalOp([&](scf::WhileOp op) { + return !llvm::any_of(op->getResultTypes(), [&](Type type) { + return type == flux::QubitType::get(context); + }); + }); + target.addDynamicallyLegalOp([&](scf::ConditionOp op) { + return !llvm::any_of(op.getOperandTypes(), [&](Type type) { + return type == flux::QubitType::get(context); + }); + }); // Register operation conversion patterns // Note: No state tracking needed - OpAdaptors handle type conversion - patterns - .add(typeConverter, context); + patterns.add< + ConvertFluxAllocOp, ConvertFluxDeallocOp, ConvertFluxStaticOp, + ConvertFluxMeasureOp, ConvertFluxResetOp, ConvertFluxGPhaseOp, + ConvertFluxIdOp, ConvertFluxXOp, ConvertFluxYOp, ConvertFluxZOp, + ConvertFluxHOp, ConvertFluxSOp, ConvertFluxSdgOp, ConvertFluxTOp, + ConvertFluxTdgOp, ConvertFluxSXOp, ConvertFluxSXdgOp, ConvertFluxRXOp, + ConvertFluxRYOp, ConvertFluxRZOp, ConvertFluxPOp, ConvertFluxROp, + ConvertFluxU2Op, ConvertFluxUOp, ConvertFluxSWAPOp, ConvertFluxiSWAPOp, + ConvertFluxDCXOp, ConvertFluxECROp, ConvertFluxRXXOp, ConvertFluxRYYOp, + ConvertFluxRZXOp, ConvertFluxRZZOp, ConvertFluxXXPlusYYOp, + ConvertFluxXXMinusYYOp, ConvertFluxBarrierOp, ConvertFluxCtrlOp, + ConvertFluxYieldOp, ConvertFluxScfIfOp, ConvertFluxScfYieldOp, + ConvertFluxScfWhileOp, ConvertFluxScfConditionOp>(typeConverter, + context); // Conversion of flux types in func.func signatures // Note: This currently has limitations with signature changes diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index 865ce9765d..f32e5e9fff 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -144,7 +144,9 @@ llvm::SetVector collectRegionQubits(Operation* op, LoweringState* state, uniqueQubits.insert(result); } } - if (llvm::isa(operation) && uniqueQubits.size() > 0) { + if ((llvm::isa(operation) || + llvm::isa(operation)) && + uniqueQubits.size() > 0) { operation.setAttr("needChange", StringAttr::get(ctx, "yes")); } } @@ -1194,7 +1196,7 @@ struct ConvertQuartzYieldOp final } }; -struct ConvertScfIfOp final : StatefulOpConversionPattern { +struct ConvertQuartzScfIfOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult @@ -1250,16 +1252,70 @@ struct ConvertScfIfOp final : StatefulOpConversionPattern { return success(); } }; +struct ConvertQuartzScfWhileOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; -struct ConvertScfYieldOp final : StatefulOpConversionPattern { + LogicalResult + matchAndRewrite(scf::WhileOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + auto refQubits = getState().regionMap[op]; + + SmallVector values; + values.reserve(refQubits.size()); + for (auto qubit : refQubits) { + values.push_back(qubit); + } + SmallVector optQubits; + SmallVector types(refQubits.size(), + flux::QubitType::get(rewriter.getContext())); + for (auto [refQubit, optQubit] : qubitMap) { + optQubits.push_back(optQubit); + } + auto newWhileOp = rewriter.create( + op.getLoc(), TypeRange(types), ValueRange(optQubits)); + auto& newBeforeRegion = newWhileOp.getBefore(); + auto& newAfterRegion = newWhileOp.getAfter(); + SmallVector locs(refQubits.size(), op->getLoc()); + auto* newBeforeBlock = + rewriter.createBlock(&newBeforeRegion, {}, types, locs); + auto* newAfterBlock = + rewriter.createBlock(&newAfterRegion, {}, types, locs); + + newBeforeBlock->getOperations().splice(newBeforeBlock->end(), + op.getBeforeBody()->getOperations()); + newAfterBlock->getOperations().splice(newAfterBlock->end(), + op.getAfterBody()->getOperations()); + auto& newBeforeRegionMap = getState().qubitMap[&newWhileOp.getBefore()]; + auto& newAfterRegionMap = getState().qubitMap[&newWhileOp.getAfter()]; + for (size_t i = 0; i < refQubits.size(); i++) { + newBeforeRegionMap.try_emplace(refQubits[i], + newWhileOp.getBeforeArguments()[i]); + } + for (size_t i = 0; i < refQubits.size(); i++) { + newAfterRegionMap.try_emplace(refQubits[i], + newWhileOp.getAfterArguments()[i]); + } + + for (size_t i = 0; i < newWhileOp->getResults().size(); i++) { + qubitMap[refQubits[i]] = newWhileOp->getResult(i); + } + rewriter.eraseOp(op); + return success(); + } +}; + +struct ConvertQuartzScfYieldOp final + : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto* region = op->getParentRegion(); - auto& qubitMap = getState().qubitMap[region]; + const auto& qubitMap = getState().qubitMap[op->getParentRegion()]; SmallVector optQubits; for (auto [refQubit, optQubit] : qubitMap) { @@ -1269,6 +1325,27 @@ struct ConvertScfYieldOp final : StatefulOpConversionPattern { return success(); } }; +struct ConvertQuartzScfCondtionOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern< + scf::ConditionOp>::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(scf::ConditionOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + + const auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + + SmallVector optQubits; + for (auto [refQubit, optQubit] : qubitMap) { + optQubits.push_back(optQubit); + } + rewriter.replaceOpWithNewOp(op, op.getCondition(), + optQubits); + + return success(); + } +}; /** * @brief Pass implementation for Quartz-to-Flux conversion @@ -1319,6 +1396,12 @@ struct QuartzToFlux final : impl::QuartzToFluxBase { target.addDynamicallyLegalOp([&](scf::IfOp op) { return !(op->getAttrOfType("needChange")); }); + target.addDynamicallyLegalOp([&](scf::WhileOp op) { + return !(op->getAttrOfType("needChange")); + }); + target.addDynamicallyLegalOp([&](scf::ConditionOp op) { + return !(op->getAttrOfType("needChange")); + }); // Register operation conversion patterns with state // tracking patterns.add< @@ -1334,7 +1417,8 @@ struct QuartzToFlux final : impl::QuartzToFluxBase { ConvertQuartzRXXOp, ConvertQuartzRYYOp, ConvertQuartzRZXOp, ConvertQuartzRZZOp, ConvertQuartzXXPlusYYOp, ConvertQuartzXXMinusYYOp, ConvertQuartzBarrierOp, ConvertQuartzCtrlOp, ConvertQuartzYieldOp, - ConvertScfIfOp, ConvertScfYieldOp>(typeConverter, context, &state); + ConvertQuartzScfIfOp, ConvertQuartzScfYieldOp, ConvertQuartzScfWhileOp, + ConvertQuartzScfCondtionOp>(typeConverter, context, &state); // Apply the conversion if (failed(applyPartialConversion(module, target, std::move(patterns)))) { From ff54901e806b9b93b88d20ae8fc19f315a2dce56 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 11 Dec 2025 16:33:14 +0100 Subject: [PATCH 003/108] add initial support for scf forOp conversion --- .../Conversion/FluxToQuartz/FluxToQuartz.cpp | 43 +++++++++++++-- .../Conversion/QuartzToFlux/QuartzToFlux.cpp | 53 ++++++++++++++++++- 2 files changed, 90 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp index 37b7810e22..5fc5339a36 100644 --- a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp +++ b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp @@ -853,6 +853,9 @@ struct ConvertFluxScfWhileOp final : OpConversionPattern { LogicalResult matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + // create the bew while operation + auto newWhileOp = + rewriter.create(op->getLoc(), ValueRange{}, ValueRange{}); // replace the uses of the blockarguments with the init values const auto& inits = adaptor.getInits(); @@ -862,9 +865,6 @@ struct ConvertFluxScfWhileOp final : OpConversionPattern { beforeArgs[i].replaceAllUsesWith(inits[i]); afterArgs[i].replaceAllUsesWith(inits[i]); } - // create the bew while operation - auto newWhileOp = - rewriter.create(op->getLoc(), ValueRange{}, ValueRange{}); // create the blocks of the new operation and move the operations to them auto* newBeforeBlock = @@ -881,6 +881,34 @@ struct ConvertFluxScfWhileOp final : OpConversionPattern { return success(); } }; +struct ConvertFluxScfForOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + // Create a new for-loop with no iter_args + auto newFor = rewriter.create( + op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), + adaptor.getStep(), ValueRange{}); + + for (const auto& [fluxQubit, quartzQubit] : + llvm::zip_equal(op.getRegionIterArgs(), adaptor.getInitArgs())) { + fluxQubit.replaceAllUsesWith(quartzQubit); + } + + // move all the operations from the old block to the new block + auto* newBlock = newFor.getBody(); + // erase the existing yield operation + rewriter.eraseOp(newBlock->getTerminator()); + newBlock->getOperations().splice(newBlock->end(), + op.getBody()->getOperations()); + + rewriter.replaceOp(op, adaptor.getInitArgs()); + + return success(); + } +}; struct ConvertFluxScfYieldOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -966,6 +994,11 @@ struct FluxToQuartz final : impl::FluxToQuartzBase { return type == flux::QubitType::get(context); }); }); + target.addDynamicallyLegalOp([&](scf::ForOp op) { + return !llvm::any_of(op->getResultTypes(), [&](Type type) { + return type == flux::QubitType::get(context); + }); + }); // Register operation conversion patterns // Note: No state tracking needed - OpAdaptors handle type conversion patterns.add< @@ -980,8 +1013,8 @@ struct FluxToQuartz final : impl::FluxToQuartzBase { ConvertFluxRZXOp, ConvertFluxRZZOp, ConvertFluxXXPlusYYOp, ConvertFluxXXMinusYYOp, ConvertFluxBarrierOp, ConvertFluxCtrlOp, ConvertFluxYieldOp, ConvertFluxScfIfOp, ConvertFluxScfYieldOp, - ConvertFluxScfWhileOp, ConvertFluxScfConditionOp>(typeConverter, - context); + ConvertFluxScfWhileOp, ConvertFluxScfConditionOp, ConvertFluxScfForOp>( + typeConverter, context); // Conversion of flux types in func.func signatures // Note: This currently has limitations with signature changes diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index f32e5e9fff..309bd1a7ac 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -1307,6 +1307,53 @@ struct ConvertQuartzScfWhileOp final } }; +struct ConvertQuartzScfForOp final : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + + auto refQubits = getState().regionMap[op]; + SmallVector values; + values.reserve(refQubits.size()); + for (auto qubit : refQubits) { + values.push_back(qubit); + } + + SmallVector optQubits; + for (auto [quartQubit, fluxQubit] : qubitMap) { + optQubits.push_back(fluxQubit); + } + // Create a new for-loop with flux qubits as iter_args + auto newFor = rewriter.create( + op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), + adaptor.getStep(), ValueRange(optQubits)); + auto& srcBlock = op.getRegion().front(); + auto& dstBlock = newFor.getRegion().front(); + dstBlock.getOperations().splice(dstBlock.end(), srcBlock.getOperations()); + + auto& newRegion = newFor.getRegion(); + + auto& regionQubitMap = getState().qubitMap[&newRegion]; + + + for (const auto& [refQubit, optQubit] : + llvm::zip_equal(refQubits, newFor.getRegionIterArgs())) { + regionQubitMap.try_emplace(refQubit, optQubit); + } + + auto& map = getState().qubitMap[op->getParentRegion()]; + for (size_t i = 0; i < newFor->getResults().size(); i++) { + map[refQubits[i]] = newFor->getResult(i); + } + rewriter.eraseOp(op); + return success(); + } +}; + struct ConvertQuartzScfYieldOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -1402,6 +1449,9 @@ struct QuartzToFlux final : impl::QuartzToFluxBase { target.addDynamicallyLegalOp([&](scf::ConditionOp op) { return !(op->getAttrOfType("needChange")); }); + target.addDynamicallyLegalOp([&](scf::ForOp op) { + return !(op->getAttrOfType("needChange")); + }); // Register operation conversion patterns with state // tracking patterns.add< @@ -1418,7 +1468,8 @@ struct QuartzToFlux final : impl::QuartzToFluxBase { ConvertQuartzRZZOp, ConvertQuartzXXPlusYYOp, ConvertQuartzXXMinusYYOp, ConvertQuartzBarrierOp, ConvertQuartzCtrlOp, ConvertQuartzYieldOp, ConvertQuartzScfIfOp, ConvertQuartzScfYieldOp, ConvertQuartzScfWhileOp, - ConvertQuartzScfCondtionOp>(typeConverter, context, &state); + ConvertQuartzScfCondtionOp, ConvertQuartzScfForOp>(typeConverter, + context, &state); // Apply the conversion if (failed(applyPartialConversion(module, target, std::move(patterns)))) { From e968570efa435101e4738b96d5b31549124d7366 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 12 Dec 2025 11:14:08 +0100 Subject: [PATCH 004/108] add docstrings to QuartzToFlux conversion --- .../Conversion/QuartzToFlux/QuartzToFlux.cpp | 271 +++++++++++++----- 1 file changed, 197 insertions(+), 74 deletions(-) diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index 309bd1a7ac..2b9bbe4cd9 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -70,8 +70,9 @@ namespace { */ struct LoweringState { /// Map from original Quartz qubit references to their latest Flux SSA values + /// for each region llvm::DenseMap> qubitMap; - /// Map each initial op to its refQubits. + /// Map each operation to its Set of Quartz qubit references llvm::DenseMap> regionMap; /// Modifier information @@ -110,9 +111,17 @@ class StatefulOpConversionPattern : public OpConversionPattern { LoweringState* state_; }; +/** + * @brief Recursively collects all the Quartz qubit references used by an + * operation and store them in map + * + * @param Operation The operation that is currently traversed + * @param state The lowering state + * @param ctx The MLIRContext of the current program + * @return llvm::Setvector The set of unique Quartz qubit references + */ llvm::SetVector collectRegionQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { - // get the regions of the current operation const auto regions = op->getRegions(); SetVector uniqueQubits; @@ -127,23 +136,25 @@ llvm::SetVector collectRegionQubits(Operation* op, LoweringState* state, // check if the operation has an region, if yes recursively collect the // qubits if (operation.getNumRegions() > 0) { - auto qubits = collectRegionQubits(&operation, state, ctx); - for (auto qubit : qubits) { + const auto& qubits = collectRegionQubits(&operation, state, ctx); + for (const auto& qubit : qubits) { uniqueQubits.insert(qubit); } } // collect qubits form the operands - for (auto operand : operation.getOperands()) { + for (const auto& operand : operation.getOperands()) { if (operand.getType() == quartz::QubitType::get(ctx)) { uniqueQubits.insert(operand); } } // collect qubits from the results - for (auto result : operation.getResults()) { + for (const auto& result : operation.getResults()) { if (result.getType() == quartz::QubitType::get(ctx)) { uniqueQubits.insert(result); } } + // mark scf terminator operations if they need to return a value after the + // conversion if ((llvm::isa(operation) || llvm::isa(operation)) && uniqueQubits.size() > 0) { @@ -155,7 +166,7 @@ llvm::SetVector collectRegionQubits(Operation* op, LoweringState* state, (llvm::isa(op) || (llvm::isa(op)) || llvm::isa(op))) { state->regionMap[op] = uniqueQubits; - // mark operations that need to be changed afterwards + // mark scf operations that need to be changed afterwards op->setAttr("needChange", StringAttr::get(ctx, "yes")); } return uniqueQubits; @@ -1196,31 +1207,53 @@ struct ConvertQuartzYieldOp final } }; +/** + * @brief Converts scf.if with memory semantics to scf.if with value semantics + * + * @par Example: + * ```mlir + * scf.if %cond { + * quartz.x %q0 + * scf.yield + * } + * ``` + * is converted to + * ```mlir + * %targets_out = scf.if %cond -> (!flux.qubit) { + * %q1 = flux.h %q0 : !flux.qubit -> !flux.qubit + * scf.yield %q1 : !flux.qubit + * } else { + * scf.yield %q0 : !flux.qubit + * } + * ``` + */ struct ConvertQuartzScfIfOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(scf::IfOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto quartzQubits = getState().regionMap[op]; - SmallVector values; - values.reserve(quartzQubits.size()); + const auto& quartzQubits = getState().regionMap[op]; + SmallVector quartzValues; + quartzValues.reserve(quartzQubits.size()); for (auto qubit : quartzQubits) { - values.push_back(qubit); + quartzValues.push_back(qubit); } // create result typerange - auto const optType = flux::QubitType::get(rewriter.getContext()); SmallVector resultTypes; - resultTypes.assign(quartzQubits.size(), optType); + resultTypes.assign(quartzQubits.size(), + flux::QubitType::get(rewriter.getContext())); // create new if operation - auto newIf = rewriter.create( + auto newIfOp = rewriter.create( op->getLoc(), TypeRange{resultTypes}, op.getCondition(), true); - auto& thenRegion = newIf.getThenRegion(); - auto& elseRegion = newIf.getElseRegion(); + auto& thenRegion = newIfOp.getThenRegion(); + auto& elseRegion = newIfOp.getElseRegion(); + // move the regions of the old operations inside the new operation rewriter.inlineRegionBefore(op.getThenRegion(), thenRegion, thenRegion.end()); + // eliminate the empty block that was created during the initialization rewriter.eraseBlock(&thenRegion.front()); if (!op.getElseRegion().empty()) { @@ -1228,30 +1261,63 @@ struct ConvertQuartzScfIfOp final : StatefulOpConversionPattern { elseRegion.end()); rewriter.eraseBlock(&elseRegion.front()); } else { + // create the yield operation if it does not exist yet rewriter.setInsertionPointToEnd(&elseRegion.front()); const auto elseYield = - rewriter.create(op->getLoc(), values); + rewriter.create(op->getLoc(), quartzValues); + // mark the yield operation for conversion elseYield->setAttr("needChange", StringAttr::get(rewriter.getContext(), "yes")); } + // create the qubit map for the regions auto& thenRegionQubitMap = getState().qubitMap[&thenRegion]; auto& elseRegionQubitMap = getState().qubitMap[&elseRegion]; - for (const auto& refQubit : quartzQubits) { + for (const auto& quartzQubit : quartzQubits) { thenRegionQubitMap.try_emplace( - refQubit, getState().qubitMap[op->getParentRegion()][refQubit]); + quartzQubit, getState().qubitMap[op->getParentRegion()][quartzQubit]); elseRegionQubitMap.try_emplace( - refQubit, getState().qubitMap[op->getParentRegion()][refQubit]); + quartzQubit, getState().qubitMap[op->getParentRegion()][quartzQubit]); } + + // update the qubit map in the current region auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - for (size_t i = 0; i < newIf->getResults().size(); i++) { - qubitMap[quartzQubits[i]] = newIf->getResult(i); + for (const auto& [quartzQubit, fluxQubit] : + llvm::zip_equal(quartzQubits, newIfOp->getResults())) { + qubitMap[quartzQubit] = fluxQubit; } rewriter.eraseOp(op); return success(); } }; + +/** + * @brief Converts scf.while with memory semantics to scf.while with value + * semantics + * + * @par Example: + * ```mlir + * scf.while : () -> () { + * quartz.x %q0 + * scf.condition(%cond) + * } do { + * quartz.x %q0 + * scf.yield + * } + * ``` + * is converted to + * ```mlir + * %targets_out = scf.while (%arg0 = %q0) : (!flux.qubit) -> !flux.qubit { + * %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit + * scf.condition(%cond) %q1 : !flux.qubit + * } do { + * ^bb0(%arg0: !flux.qubit): + * %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit + * scf.yield %q1 : !flux.qubit + * } + * ``` + */ struct ConvertQuartzScfWhileOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -1259,101 +1325,145 @@ struct ConvertQuartzScfWhileOp final LogicalResult matchAndRewrite(scf::WhileOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - auto refQubits = getState().regionMap[op]; + auto quartzQubits = getState().regionMap[op]; - SmallVector values; - values.reserve(refQubits.size()); - for (auto qubit : refQubits) { - values.push_back(qubit); + SmallVector quartzValues; + quartzValues.reserve(quartzQubits.size()); + for (auto qubit : quartzQubits) { + quartzValues.push_back(qubit); } - SmallVector optQubits; - SmallVector types(refQubits.size(), - flux::QubitType::get(rewriter.getContext())); - for (auto [refQubit, optQubit] : qubitMap) { - optQubits.push_back(optQubit); + SmallVector fluxQubits; + fluxQubits.reserve(quartzQubits.size()); + for (const auto& [quartzQubit, fluxQubit] : qubitMap) { + fluxQubits.push_back(fluxQubit); } + // create the result typerange + SmallVector fluxTypes(quartzQubits.size(), + flux::QubitType::get(rewriter.getContext())); + + // create the new while operation auto newWhileOp = rewriter.create( - op.getLoc(), TypeRange(types), ValueRange(optQubits)); + op.getLoc(), TypeRange(fluxTypes), ValueRange(fluxQubits)); auto& newBeforeRegion = newWhileOp.getBefore(); auto& newAfterRegion = newWhileOp.getAfter(); - SmallVector locs(refQubits.size(), op->getLoc()); + SmallVector locs(quartzQubits.size(), op->getLoc()); + // create the new blocks auto* newBeforeBlock = - rewriter.createBlock(&newBeforeRegion, {}, types, locs); + rewriter.createBlock(&newBeforeRegion, {}, fluxTypes, locs); auto* newAfterBlock = - rewriter.createBlock(&newAfterRegion, {}, types, locs); + rewriter.createBlock(&newAfterRegion, {}, fluxTypes, locs); + // move the operations to the new blocks newBeforeBlock->getOperations().splice(newBeforeBlock->end(), op.getBeforeBody()->getOperations()); newAfterBlock->getOperations().splice(newAfterBlock->end(), op.getAfterBody()->getOperations()); + + // create the qubit map for the new regions auto& newBeforeRegionMap = getState().qubitMap[&newWhileOp.getBefore()]; auto& newAfterRegionMap = getState().qubitMap[&newWhileOp.getAfter()]; - for (size_t i = 0; i < refQubits.size(); i++) { - newBeforeRegionMap.try_emplace(refQubits[i], - newWhileOp.getBeforeArguments()[i]); + + for (const auto& [quartzQubit, fluxQubit] : + llvm::zip_equal(quartzQubits, newWhileOp.getBeforeArguments())) { + newBeforeRegionMap.try_emplace(quartzQubit, fluxQubit); } - for (size_t i = 0; i < refQubits.size(); i++) { - newAfterRegionMap.try_emplace(refQubits[i], - newWhileOp.getAfterArguments()[i]); + for (const auto& [quartzQubit, fluxQubit] : + llvm::zip_equal(quartzQubits, newWhileOp.getAfterArguments())) { + newAfterRegionMap.try_emplace(quartzQubit, fluxQubit); } - for (size_t i = 0; i < newWhileOp->getResults().size(); i++) { - qubitMap[refQubits[i]] = newWhileOp->getResult(i); + // update the qubit map in the current region + for (const auto& [quartzQubit, fluxQubit] : + llvm::zip_equal(quartzQubits, newWhileOp->getResults())) { + qubitMap[quartzQubit] = fluxQubit; } + rewriter.eraseOp(op); return success(); } }; +/** + * @brief Converts scf.for with memory semantics to scf.while with value + * semantics + * + * @par Example: + * ```mlir + * scf.for %iv = %lb to %ub step %step { + * quartz.x %q0 + * scf.yield + * } + * ``` + * is converted to + * ```mlir + * %targets_out = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = q0) -> + * (!flux.qubit) { %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit scf.yield + * %q1 : !flux.qubit + * } + * ``` + */ struct ConvertQuartzScfForOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - auto refQubits = getState().regionMap[op]; + auto quartzQubits = getState().regionMap[op]; SmallVector values; - values.reserve(refQubits.size()); - for (auto qubit : refQubits) { + values.reserve(quartzQubits.size()); + for (auto qubit : quartzQubits) { values.push_back(qubit); } - SmallVector optQubits; + SmallVector fluxQubits; for (auto [quartQubit, fluxQubit] : qubitMap) { - optQubits.push_back(fluxQubit); + fluxQubits.push_back(fluxQubit); } // Create a new for-loop with flux qubits as iter_args auto newFor = rewriter.create( op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), - adaptor.getStep(), ValueRange(optQubits)); + adaptor.getStep(), ValueRange(fluxQubits)); + + // move the operations to the new block auto& srcBlock = op.getRegion().front(); auto& dstBlock = newFor.getRegion().front(); dstBlock.getOperations().splice(dstBlock.end(), srcBlock.getOperations()); auto& newRegion = newFor.getRegion(); - auto& regionQubitMap = getState().qubitMap[&newRegion]; - - for (const auto& [refQubit, optQubit] : - llvm::zip_equal(refQubits, newFor.getRegionIterArgs())) { - regionQubitMap.try_emplace(refQubit, optQubit); + // create the qubitmap for the new region + for (const auto& [quartzQubit, fluxQubit] : + llvm::zip_equal(quartzQubits, newFor.getRegionIterArgs())) { + regionQubitMap.try_emplace(quartzQubit, fluxQubit); } - - auto& map = getState().qubitMap[op->getParentRegion()]; - for (size_t i = 0; i < newFor->getResults().size(); i++) { - map[refQubits[i]] = newFor->getResult(i); + // update the qubitmap in the current region + for (const auto& [quartzQubit, fluxQubit] : + llvm::zip_equal(quartzQubits, newFor->getResults())) { + qubitMap[quartzQubit] = fluxQubit; } + rewriter.eraseOp(op); return success(); } }; +/** + * @brief Converts scf.yield with memory semantics to scf.yield with value + * semantics + * + * @par Example: + * ```mlir + * scf.yield + * ``` + * is converted to + * ```mlir + * scf.yield %targets + * ``` + */ struct ConvertQuartzScfYieldOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -1361,17 +1471,31 @@ struct ConvertQuartzScfYieldOp final LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - const auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - - SmallVector optQubits; - for (auto [refQubit, optQubit] : qubitMap) { - optQubits.push_back(optQubit); + SmallVector fluxQubits; + fluxQubits.reserve(qubitMap.size()); + for (auto [quartzQubit, fluxQubit] : qubitMap) { + fluxQubits.push_back(fluxQubit); } - rewriter.replaceOpWithNewOp(op, optQubits); + + rewriter.replaceOpWithNewOp(op, fluxQubits); return success(); } }; + +/** + * @brief Converts scf.condition with memory semantics to scf.condition with + * value semantics + * + * @par Example: + * ```mlir + * scf.condition(%cond) + * ``` + * is converted to + * ```mlir + * scf.condition(%cond) %targets + * ``` + */ struct ConvertQuartzScfCondtionOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern< @@ -1380,16 +1504,15 @@ struct ConvertQuartzScfCondtionOp final LogicalResult matchAndRewrite(scf::ConditionOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - const auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - - SmallVector optQubits; - for (auto [refQubit, optQubit] : qubitMap) { - optQubits.push_back(optQubit); + SmallVector fluxQubits; + fluxQubits.reserve(qubitMap.size()); + for (auto [quartzQubit, fluxQubit] : qubitMap) { + fluxQubits.push_back(fluxQubit); } - rewriter.replaceOpWithNewOp(op, op.getCondition(), - optQubits); + rewriter.replaceOpWithNewOp(op, op.getCondition(), + fluxQubits); return success(); } }; From 27909af62dc4129d5f30d4ead1512b9c8913ae71 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 12 Dec 2025 11:36:01 +0100 Subject: [PATCH 005/108] add docstrings to FluxToQuartz conversion --- .../Conversion/FluxToQuartz/FluxToQuartz.cpp | 114 ++++++++++++++++-- .../Conversion/QuartzToFlux/QuartzToFlux.cpp | 39 ++---- 2 files changed, 119 insertions(+), 34 deletions(-) diff --git a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp index 5fc5339a36..a8c0b535fc 100644 --- a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp +++ b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp @@ -821,12 +821,34 @@ struct ConvertFluxYieldOp final : OpConversionPattern { return success(); } }; + +/** + * @brief Converts scf.if with value semantics to scf.if with memory semantics + * + * @par Example: + * ```mlir + * %targets_out = scf.if %cond -> (!flux.qubit) { + * %q1 = flux.h %q0 : !flux.qubit -> !flux.qubit + * scf.yield %q1 : !flux.qubit + * } else { + * scf.yield %q0 : !flux.qubit + * } + * ``` + * is converted to + * ```mlir + * scf.if %cond { + * quartz.x %q0 + * scf.yield + * } + * ``` + */ struct ConvertFluxScfIfOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(scf::IfOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { + // create the new if operation auto newIf = rewriter.create(op.getLoc(), ValueRange{}, op.getCondition(), op.getElseRegion().empty()); @@ -837,23 +859,50 @@ struct ConvertFluxScfIfOp final : OpConversionPattern { rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(), newIf.getElseRegion().end()); } + // erase the empty block that was created during the initialization rewriter.eraseBlock(&newIf.getThenRegion().front()); - auto yield = - dyn_cast(newIf.getThenRegion().back().getTerminator()); + const auto& yield = + dyn_cast(newIf.getThenRegion().front().getTerminator()); rewriter.replaceOp(op, yield->getOperands()); - return success(); } }; + +/** + * @brief Converts scf.while with value semantics to scf.while with memory + * semantics + * + * @par Example: + * ```mlir + * %targets_out = scf.while (%arg0 = %q0) : (!flux.qubit) -> !flux.qubit { + * %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit + * scf.condition(%cond) %q1 : !flux.qubit + * } do { + * ^bb0(%arg0: !flux.qubit): + * %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit + * scf.yield %q1 : !flux.qubit + * } + * ``` + * is converted to + * ```mlir + * scf.while : () -> () { + * quartz.x %q0 + * scf.condition(%cond) + * } do { + * quartz.x %q0 + * scf.yield + * } + * ``` + */ struct ConvertFluxScfWhileOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - // create the bew while operation + // create the new while operation auto newWhileOp = rewriter.create(op->getLoc(), ValueRange{}, ValueRange{}); @@ -881,6 +930,27 @@ struct ConvertFluxScfWhileOp final : OpConversionPattern { return success(); } }; + +/** + * @brief Converts scf.for with value semantics to scf.while with memory + * semantics + * + * @par Example: + * ```mlir + * %targets_out = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = q0) -> + * (!flux.qubit) { + * %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit + * scf.yield %q1 : !flux.qubit + * } + * ``` + * is converted to + * ```mlir + * scf.for %iv = %lb to %ub step %step { + * quartz.x %q0 + * scf.yield + * } + * ``` + */ struct ConvertFluxScfForOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -892,6 +962,7 @@ struct ConvertFluxScfForOp final : OpConversionPattern { op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep(), ValueRange{}); + // replace the uses of the previous iter_args for (const auto& [fluxQubit, quartzQubit] : llvm::zip_equal(op.getRegionIterArgs(), adaptor.getInitArgs())) { fluxQubit.replaceAllUsesWith(quartzQubit); @@ -903,13 +974,26 @@ struct ConvertFluxScfForOp final : OpConversionPattern { rewriter.eraseOp(newBlock->getTerminator()); newBlock->getOperations().splice(newBlock->end(), op.getBody()->getOperations()); - - rewriter.replaceOp(op, adaptor.getInitArgs()); + // replace the result values with the init values + rewriter.replaceOp(op, adaptor.getInitArgs()); return success(); } }; +/** + * @brief Converts scf.yield with value semantics to scf.yield with memory + * semantics + * + * @par Example: + * ```mlir + * scf.yield %targets + * ``` + * is converted to + * ```mlir + * scf.yield + * ``` + */ struct ConvertFluxScfYieldOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -920,13 +1004,27 @@ struct ConvertFluxScfYieldOp final : OpConversionPattern { return success(); } }; + +/** + * @brief Converts scf.condition with value semantics to scf.condition with + * memory semantics + * + * @par Example: + * ```mlir + * scf.condition(%cond) %targets + * ``` + * is converted to + * ```mlir + * scf.condition(%cond) + + * ``` + */ struct ConvertFluxScfConditionOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(scf::ConditionOp op, OpAdaptor adaptor, + matchAndRewrite(scf::ConditionOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - rewriter.replaceOpWithNewOp(op, op.getCondition(), ValueRange{}); return success(); diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index 2b9bbe4cd9..1f2da0ca11 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -1234,19 +1234,16 @@ struct ConvertQuartzScfIfOp final : StatefulOpConversionPattern { matchAndRewrite(scf::IfOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { const auto& quartzQubits = getState().regionMap[op]; - SmallVector quartzValues; - quartzValues.reserve(quartzQubits.size()); - for (auto qubit : quartzQubits) { - quartzValues.push_back(qubit); - } + const SmallVector quartzValues(quartzQubits.begin(), + quartzQubits.end()); + // create result typerange - SmallVector resultTypes; - resultTypes.assign(quartzQubits.size(), - flux::QubitType::get(rewriter.getContext())); + const SmallVector fluxTypes( + quartzQubits.size(), flux::QubitType::get(rewriter.getContext())); // create new if operation auto newIfOp = rewriter.create( - op->getLoc(), TypeRange{resultTypes}, op.getCondition(), true); + op->getLoc(), TypeRange{fluxTypes}, op.getCondition(), true); auto& thenRegion = newIfOp.getThenRegion(); auto& elseRegion = newIfOp.getElseRegion(); @@ -1326,21 +1323,16 @@ struct ConvertQuartzScfWhileOp final matchAndRewrite(scf::WhileOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - auto quartzQubits = getState().regionMap[op]; + auto& quartzQubits = getState().regionMap[op]; - SmallVector quartzValues; - quartzValues.reserve(quartzQubits.size()); - for (auto qubit : quartzQubits) { - quartzValues.push_back(qubit); - } SmallVector fluxQubits; fluxQubits.reserve(quartzQubits.size()); for (const auto& [quartzQubit, fluxQubit] : qubitMap) { fluxQubits.push_back(fluxQubit); } // create the result typerange - SmallVector fluxTypes(quartzQubits.size(), - flux::QubitType::get(rewriter.getContext())); + const SmallVector fluxTypes( + quartzQubits.size(), flux::QubitType::get(rewriter.getContext())); // create the new while operation auto newWhileOp = rewriter.create( @@ -1398,8 +1390,9 @@ struct ConvertQuartzScfWhileOp final * is converted to * ```mlir * %targets_out = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = q0) -> - * (!flux.qubit) { %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit scf.yield - * %q1 : !flux.qubit + * (!flux.qubit) { + * %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit + * scf.yield %q1 : !flux.qubit * } * ``` */ @@ -1410,13 +1403,7 @@ struct ConvertQuartzScfForOp final : StatefulOpConversionPattern { matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - - auto quartzQubits = getState().regionMap[op]; - SmallVector values; - values.reserve(quartzQubits.size()); - for (auto qubit : quartzQubits) { - values.push_back(qubit); - } + auto& quartzQubits = getState().regionMap[op]; SmallVector fluxQubits; for (auto [quartQubit, fluxQubit] : qubitMap) { From b2ab37cd85d99591b9acb4bbe3104f4c868472fc Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 12 Dec 2025 11:46:18 +0100 Subject: [PATCH 006/108] fix typo --- mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index 1f2da0ca11..db6104bde5 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -1483,7 +1483,7 @@ struct ConvertQuartzScfYieldOp final * scf.condition(%cond) %targets * ``` */ -struct ConvertQuartzScfCondtionOp final +struct ConvertQuartzScfConditionOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern< scf::ConditionOp>::StatefulOpConversionPattern; @@ -1578,8 +1578,8 @@ struct QuartzToFlux final : impl::QuartzToFluxBase { ConvertQuartzRZZOp, ConvertQuartzXXPlusYYOp, ConvertQuartzXXMinusYYOp, ConvertQuartzBarrierOp, ConvertQuartzCtrlOp, ConvertQuartzYieldOp, ConvertQuartzScfIfOp, ConvertQuartzScfYieldOp, ConvertQuartzScfWhileOp, - ConvertQuartzScfCondtionOp, ConvertQuartzScfForOp>(typeConverter, - context, &state); + ConvertQuartzScfConditionOp, ConvertQuartzScfForOp>(typeConverter, + context, &state); // Apply the conversion if (failed(applyPartialConversion(module, target, std::move(patterns)))) { From 0fc6bc40699b2b006409a72127a407c962510958 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 12 Dec 2025 13:04:36 +0100 Subject: [PATCH 007/108] add support for multiple functions --- .../Conversion/FluxToQuartz/FluxToQuartz.cpp | 71 +++++++++++- .../Conversion/QuartzToFlux/QuartzToFlux.cpp | 108 +++++++++++++++++- 2 files changed, 176 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp index a8c0b535fc..28a14b70c0 100644 --- a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp +++ b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp @@ -1030,6 +1030,59 @@ struct ConvertFluxScfConditionOp final : OpConversionPattern { return success(); } }; + +struct ConvertFluxFuncCallOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + rewriter.create(op->getLoc(), adaptor.getCallee(), + TypeRange{}, adaptor.getOperands()); + rewriter.replaceOp(op, adaptor.getOperands()); + return success(); + } +}; + +struct ConvertFluxFuncFuncOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::FuncOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + const SmallVector argumentTypes( + op.front().getNumArguments(), + quartz::QubitType::get(rewriter.getContext())); + + auto newFuncType = rewriter.getFunctionType(argumentTypes, {}); + op.setFunctionType(newFuncType); + return success(); + } +}; + +/** + * @brief Converts func.return for fluxQubits to a trivial func.return + * + * @par Example: + * ```mlir + * scf.condition(%cond) %targets + * ``` + * is converted to + * ```mlir + * scf.condition(%cond) + + * ``` + */ +struct ConvertFluxFuncReturnOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(func::ReturnOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op); + return success(); + } +}; /** * @brief Pass implementation for Flux-to-Quartz conversion * @@ -1097,6 +1150,21 @@ struct FluxToQuartz final : impl::FluxToQuartzBase { return type == flux::QubitType::get(context); }); }); + target.addDynamicallyLegalOp([&](func::CallOp op) { + return !llvm::any_of(op->getResultTypes(), [&](Type type) { + return type == flux::QubitType::get(context); + }); + }); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return !llvm::any_of(op->getResultTypes(), [&](Type type) { + return type == flux::QubitType::get(context); + }); + }); + target.addDynamicallyLegalOp([&](func::ReturnOp op) { + return !llvm::any_of(op->getOperandTypes(), [&](Type type) { + return type == flux::QubitType::get(context); + }); + }); // Register operation conversion patterns // Note: No state tracking needed - OpAdaptors handle type conversion patterns.add< @@ -1111,7 +1179,8 @@ struct FluxToQuartz final : impl::FluxToQuartzBase { ConvertFluxRZXOp, ConvertFluxRZZOp, ConvertFluxXXPlusYYOp, ConvertFluxXXMinusYYOp, ConvertFluxBarrierOp, ConvertFluxCtrlOp, ConvertFluxYieldOp, ConvertFluxScfIfOp, ConvertFluxScfYieldOp, - ConvertFluxScfWhileOp, ConvertFluxScfConditionOp, ConvertFluxScfForOp>( + ConvertFluxScfWhileOp, ConvertFluxScfConditionOp, ConvertFluxScfForOp, + ConvertFluxFuncCallOp, ConvertFluxFuncFuncOp, ConvertFluxFuncReturnOp>( typeConverter, context); // Conversion of flux types in func.func signatures diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index db6104bde5..65637d8701 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Flux/IR/FluxDialect.h" #include "mlir/Dialect/Quartz/IR/QuartzDialect.h" +#include "mlir/IR/Block.h" #include #include @@ -160,6 +161,16 @@ llvm::SetVector collectRegionQubits(Operation* op, LoweringState* state, uniqueQubits.size() > 0) { operation.setAttr("needChange", StringAttr::get(ctx, "yes")); } + // mark func.return operation for functions that need to return a qubit + // value + if (llvm::isa(operation)) { + if (auto func = operation.getParentOfType()) { + if (!func.getArgumentTypes().empty() && + func.getArgumentTypes().front() == quartz::QubitType::get(ctx)) { + operation.setAttr("needChange", StringAttr::get(ctx, "yes")); + } + } + } } } if (!uniqueQubits.empty() && @@ -221,9 +232,12 @@ LogicalResult convertOneTargetZeroParameter(QuartzOpType& op, // Get the latest Flux qubit const auto& quartzQubit = op.getQubitIn(); + Value fluxQubit; if (inCtrlOp == 0) { + fluxQubit = qubitMap[quartzQubit]; + } else { fluxQubit = state.targetsIn[inCtrlOp].front(); } @@ -1406,6 +1420,7 @@ struct ConvertQuartzScfForOp final : StatefulOpConversionPattern { auto& quartzQubits = getState().regionMap[op]; SmallVector fluxQubits; + fluxQubits.reserve(qubitMap.size()); for (auto [quartQubit, fluxQubit] : qubitMap) { fluxQubits.push_back(fluxQubit); } @@ -1503,7 +1518,82 @@ struct ConvertQuartzScfConditionOp final return success(); } }; +struct ConvertQuartzFuncCallOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + LogicalResult + matchAndRewrite(func::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + auto& quartzQubits = getState().regionMap[op]; + + SmallVector fluxQubits; + fluxQubits.reserve(qubitMap.size()); + for (auto [quartQubit, fluxQubit] : qubitMap) { + fluxQubits.push_back(fluxQubit); + } + // create the result typerange + const SmallVector fluxTypes( + quartzQubits.size(), flux::QubitType::get(rewriter.getContext())); + + const auto callOp = rewriter.create( + op->getLoc(), adaptor.getCallee(), fluxTypes, fluxQubits); + + for (const auto& [quartzQubit, fluxQubit] : + llvm::zip_equal(quartzQubits, callOp->getResults())) { + qubitMap[quartzQubit] = fluxQubit; + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct ConvertQuartzFuncFuncOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(func::FuncOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + auto& qubitMap = getState().qubitMap[&op->getRegion(0)]; + const SmallVector fluxTypes( + op.front().getNumArguments(), + flux::QubitType::get(rewriter.getContext())); + + // set the arguments to flux qubit type + for (auto blockArg : op.front().getArguments()) { + blockArg.setType(flux::QubitType::get(rewriter.getContext())); + qubitMap.try_emplace(blockArg, blockArg); + } + + // change the function signature to return the same number of flux Qubits as + // it gets as input + auto newFuncType = rewriter.getFunctionType(fluxTypes, fluxTypes); // + op.setFunctionType(newFuncType); + return success(); + } +}; + +struct ConvertQuartzFuncReturnOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern< + func::ReturnOp>::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(func::ReturnOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + const auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + SmallVector fluxQubits; + fluxQubits.reserve(qubitMap.size()); + for (auto [quartzQubit, fluxQubit] : qubitMap) { + fluxQubits.push_back(fluxQubit); + } + rewriter.replaceOpWithNewOp(op, fluxQubits); + return success(); + } +}; /** * @brief Pass implementation for Quartz-to-Flux conversion * @@ -1562,6 +1652,19 @@ struct QuartzToFlux final : impl::QuartzToFluxBase { target.addDynamicallyLegalOp([&](scf::ForOp op) { return !(op->getAttrOfType("needChange")); }); + target.addDynamicallyLegalOp([&](func::FuncOp op) { + return !llvm::any_of(op.front().getArgumentTypes(), [&](Type type) { + return type == quartz::QubitType::get(context); + }); + }); + target.addDynamicallyLegalOp([&](func::CallOp op) { + return !llvm::any_of(op->getOperandTypes(), [&](Type type) { + return type == quartz::QubitType::get(context); + }); + }); + target.addDynamicallyLegalOp([&](func::ReturnOp op) { + return !op->getAttrOfType("needChange"); + }); // Register operation conversion patterns with state // tracking patterns.add< @@ -1578,8 +1681,9 @@ struct QuartzToFlux final : impl::QuartzToFluxBase { ConvertQuartzRZZOp, ConvertQuartzXXPlusYYOp, ConvertQuartzXXMinusYYOp, ConvertQuartzBarrierOp, ConvertQuartzCtrlOp, ConvertQuartzYieldOp, ConvertQuartzScfIfOp, ConvertQuartzScfYieldOp, ConvertQuartzScfWhileOp, - ConvertQuartzScfConditionOp, ConvertQuartzScfForOp>(typeConverter, - context, &state); + ConvertQuartzScfConditionOp, ConvertQuartzScfForOp, + ConvertQuartzFuncCallOp, ConvertQuartzFuncFuncOp, + ConvertQuartzFuncReturnOp>(typeConverter, context, &state); // Apply the conversion if (failed(applyPartialConversion(module, target, std::move(patterns)))) { From 7c323c622b85e4fb2dcd40905674eb9006253e80 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 12 Dec 2025 14:29:34 +0100 Subject: [PATCH 008/108] add more docstrings --- .../Conversion/FluxToQuartz/FluxToQuartz.cpp | 82 +++++++++++-------- .../Conversion/QuartzToFlux/QuartzToFlux.cpp | 58 +++++++++++-- 2 files changed, 100 insertions(+), 40 deletions(-) diff --git a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp index 28a14b70c0..f47d95ba70 100644 --- a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp +++ b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp @@ -824,6 +824,7 @@ struct ConvertFluxYieldOp final : OpConversionPattern { /** * @brief Converts scf.if with value semantics to scf.if with memory semantics + * for qubit values * * @par Example: * ```mlir @@ -872,7 +873,7 @@ struct ConvertFluxScfIfOp final : OpConversionPattern { /** * @brief Converts scf.while with value semantics to scf.while with memory - * semantics + * semantics for qubit values * * @par Example: * ```mlir @@ -933,7 +934,7 @@ struct ConvertFluxScfWhileOp final : OpConversionPattern { /** * @brief Converts scf.for with value semantics to scf.while with memory - * semantics + * semantics for qubit values * * @par Example: * ```mlir @@ -983,7 +984,7 @@ struct ConvertFluxScfForOp final : OpConversionPattern { /** * @brief Converts scf.yield with value semantics to scf.yield with memory - * semantics + * semantics for qubit values * * @par Example: * ```mlir @@ -1007,7 +1008,7 @@ struct ConvertFluxScfYieldOp final : OpConversionPattern { /** * @brief Converts scf.condition with value semantics to scf.condition with - * memory semantics + * memory semantics for qubit values * * @par Example: * ```mlir @@ -1031,6 +1032,20 @@ struct ConvertFluxScfConditionOp final : OpConversionPattern { } }; +/** + * @brief Converts func.call with value semantics to func.call with + * memory semantics for qubit values + * + * @par Example: + * ```mlir + * %q1 = call @test(%q1) : (!flux.qubit) -> !flux.qubit + * } + * ``` + * is converted to + * ```mlir + * call @test(%q0) : (!quartz.qubit) -> () + * ``` + */ struct ConvertFluxFuncCallOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -1044,6 +1059,23 @@ struct ConvertFluxFuncCallOp final : OpConversionPattern { } }; +/** + * @brief Converts func.func with memory semantics to func.func with + * value semantics for qubit values + * + * @par Example: + * ```mlir + * func.func @test(%arg0: !flux.qubit) -> !flux.qubit { + * ... + * } + * ``` + * is converted to + * ```mlir + * func.func @test(%arg0: !quartz.qubit){ + * ... + * } + * ``` + */ struct ConvertFluxFuncFuncOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -1053,7 +1085,9 @@ struct ConvertFluxFuncFuncOp final : OpConversionPattern { const SmallVector argumentTypes( op.front().getNumArguments(), quartz::QubitType::get(rewriter.getContext())); - + for (auto blockArg : op.front().getArguments()) { + blockArg.setType(quartz::QubitType::get(rewriter.getContext())); + } auto newFuncType = rewriter.getFunctionType(argumentTypes, {}); op.setFunctionType(newFuncType); return success(); @@ -1061,16 +1095,16 @@ struct ConvertFluxFuncFuncOp final : OpConversionPattern { }; /** - * @brief Converts func.return for fluxQubits to a trivial func.return + * @brief Converts func.return with value semantics to func.return with + * memory semantics for qubit values * * @par Example: * ```mlir - * scf.condition(%cond) %targets + * func.return %targets : !flux.qubit * ``` * is converted to * ```mlir - * scf.condition(%cond) - + * func.return * ``` */ struct ConvertFluxFuncReturnOp final : OpConversionPattern { @@ -1079,7 +1113,8 @@ struct ConvertFluxFuncReturnOp final : OpConversionPattern { LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - rewriter.replaceOpWithNewOp(op); + rewriter.create(op->getLoc()); + rewriter.eraseOp(op); return success(); } }; @@ -1156,13 +1191,14 @@ struct FluxToQuartz final : impl::FluxToQuartzBase { }); }); target.addDynamicallyLegalOp([&](func::FuncOp op) { - return !llvm::any_of(op->getResultTypes(), [&](Type type) { + return !llvm::any_of(op.getArgumentTypes(), [&](Type type) { return type == flux::QubitType::get(context); }); }); target.addDynamicallyLegalOp([&](func::ReturnOp op) { return !llvm::any_of(op->getOperandTypes(), [&](Type type) { - return type == flux::QubitType::get(context); + return type == quartz::QubitType::get(context) || + type == flux::QubitType::get(context); }); }); // Register operation conversion patterns @@ -1183,28 +1219,6 @@ struct FluxToQuartz final : impl::FluxToQuartzBase { ConvertFluxFuncCallOp, ConvertFluxFuncFuncOp, ConvertFluxFuncReturnOp>( typeConverter, context); - // Conversion of flux types in func.func signatures - // Note: This currently has limitations with signature changes - populateFunctionOpInterfaceTypeConversionPattern( - patterns, typeConverter); - target.addDynamicallyLegalOp([&](func::FuncOp op) { - return typeConverter.isSignatureLegal(op.getFunctionType()) && - typeConverter.isLegal(&op.getBody()); - }); - - // Conversion of flux types in func.return - populateReturnOpTypeConversionPattern(patterns, typeConverter); - target.addDynamicallyLegalOp( - [&](const func::ReturnOp op) { return typeConverter.isLegal(op); }); - - // Conversion of flux types in func.call - populateCallOpTypeConversionPattern(patterns, typeConverter); - target.addDynamicallyLegalOp( - [&](const func::CallOp op) { return typeConverter.isLegal(op); }); - - // Conversion of flux types in control-flow ops (e.g., cf.br, cf.cond_br) - populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter); - // Apply the conversion if (failed(applyPartialConversion(module, target, std::move(patterns)))) { signalPassFailure(); diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index 65637d8701..4410be5539 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -1223,6 +1223,7 @@ struct ConvertQuartzYieldOp final /** * @brief Converts scf.if with memory semantics to scf.if with value semantics + * for qubit values * * @par Example: * ```mlir @@ -1305,7 +1306,7 @@ struct ConvertQuartzScfIfOp final : StatefulOpConversionPattern { /** * @brief Converts scf.while with memory semantics to scf.while with value - * semantics + * semantics for qubit values * * @par Example: * ```mlir @@ -1392,7 +1393,7 @@ struct ConvertQuartzScfWhileOp final /** * @brief Converts scf.for with memory semantics to scf.while with value - * semantics + * semantics for qubit values * * @par Example: * ```mlir @@ -1455,7 +1456,7 @@ struct ConvertQuartzScfForOp final : StatefulOpConversionPattern { /** * @brief Converts scf.yield with memory semantics to scf.yield with value - * semantics + * semantics for qubit values * * @par Example: * ```mlir @@ -1487,7 +1488,7 @@ struct ConvertQuartzScfYieldOp final /** * @brief Converts scf.condition with memory semantics to scf.condition with - * value semantics + * value semantics for qubit values * * @par Example: * ```mlir @@ -1518,6 +1519,21 @@ struct ConvertQuartzScfConditionOp final return success(); } }; + +/** + * @brief Converts func.call with memory semantics to func.call with + * value semantics for qubit values + * + * @par Example: + * ```mlir + * call @test(%q0) : (!quartz.qubit) -> () + * } + * ``` + * is converted to + * ```mlir + * %q1 = call @test(%q1) : (!flux.qubit) -> !flux.qubit + * ``` + */ struct ConvertQuartzFuncCallOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -1526,11 +1542,11 @@ struct ConvertQuartzFuncCallOp final matchAndRewrite(func::CallOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - auto& quartzQubits = getState().regionMap[op]; + auto quartzQubits = op->getOperands(); SmallVector fluxQubits; fluxQubits.reserve(qubitMap.size()); - for (auto [quartQubit, fluxQubit] : qubitMap) { + for (auto [quartzQubit, fluxQubit] : qubitMap) { fluxQubits.push_back(fluxQubit); } // create the result typerange @@ -1550,6 +1566,23 @@ struct ConvertQuartzFuncCallOp final } }; +/** + * @brief Converts func.func with memory semantics to func.func with + * value semantics for qubit values + * + * @par Example: + * ```mlir + * func.func @test(%arg0: !quartz.qubit){ + * ... + * } + * ``` + * is converted to + * ```mlir + * func.func @test(%arg0: !flux.qubit) -> !flux.qubit{ + * ... + * } + * ``` + */ struct ConvertQuartzFuncFuncOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -1576,6 +1609,19 @@ struct ConvertQuartzFuncFuncOp final } }; +/** + * @brief Converts func.return with memory semantics to func.return with + * value semantics for qubit values + * + * @par Example: + * ```mlir + * func.return + * ``` + * is converted to + * ```mlir + * func.return %targets + * ``` + */ struct ConvertQuartzFuncReturnOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern< From f62fbcd54aeec253a2b03bb18f829f5dfdeaa18d Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 12 Dec 2025 14:30:55 +0100 Subject: [PATCH 009/108] fix header --- mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp | 2 +- mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp index f47d95ba70..d8d9064970 100644 --- a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp +++ b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp @@ -12,10 +12,10 @@ #include "mlir/Dialect/Flux/IR/FluxDialect.h" #include "mlir/Dialect/Quartz/IR/QuartzDialect.h" -#include "mlir/Dialect/SCF/IR/SCF.h" #include #include +#include #include #include #include diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index 4410be5539..a23a47b2ba 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -12,7 +12,6 @@ #include "mlir/Dialect/Flux/IR/FluxDialect.h" #include "mlir/Dialect/Quartz/IR/QuartzDialect.h" -#include "mlir/IR/Block.h" #include #include @@ -22,6 +21,7 @@ #include #include #include +#include #include #include #include From 500431e6fa504e2041a30b674e5536d58299e78d Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 12 Dec 2025 17:15:02 +0100 Subject: [PATCH 010/108] fix bug with multiple qubits --- .../Conversion/FluxToQuartz/FluxToQuartz.cpp | 9 +++---- .../Conversion/QuartzToFlux/QuartzToFlux.cpp | 24 +++++++++---------- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp index d8d9064970..7b5f177386 100644 --- a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp +++ b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp @@ -915,7 +915,6 @@ struct ConvertFluxScfWhileOp final : OpConversionPattern { beforeArgs[i].replaceAllUsesWith(inits[i]); afterArgs[i].replaceAllUsesWith(inits[i]); } - // create the blocks of the new operation and move the operations to them auto* newBeforeBlock = rewriter.createBlock(&newWhileOp.getBefore(), {}, {}, {}); @@ -927,7 +926,7 @@ struct ConvertFluxScfWhileOp final : OpConversionPattern { op.getAfterBody()->getOperations()); // replace the result values with the init values - rewriter.replaceOp(op, adaptor.getInits()); + rewriter.replaceOp(op, inits); return success(); } }; @@ -1028,6 +1027,7 @@ struct ConvertFluxScfConditionOp final : OpConversionPattern { ConversionPatternRewriter& rewriter) const override { rewriter.replaceOpWithNewOp(op, op.getCondition(), ValueRange{}); + return success(); } }; @@ -1166,8 +1166,9 @@ struct FluxToQuartz final : impl::FluxToQuartzBase { }); target.addDynamicallyLegalOp([&](scf::YieldOp op) { - return !llvm::any_of(op.getOperandTypes(), [&](Type type) { - return type == flux::QubitType::get(context); + return !llvm::any_of(op->getOperandTypes(), [&](Type type) { + return type == quartz::QubitType::get(context) || + type == flux::QubitType::get(context); }); }); target.addDynamicallyLegalOp([&](scf::WhileOp op) { diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index a23a47b2ba..142c249839 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -138,9 +138,7 @@ llvm::SetVector collectRegionQubits(Operation* op, LoweringState* state, // qubits if (operation.getNumRegions() > 0) { const auto& qubits = collectRegionQubits(&operation, state, ctx); - for (const auto& qubit : qubits) { - uniqueQubits.insert(qubit); - } + uniqueQubits.set_union(qubits); } // collect qubits form the operands for (const auto& operand : operation.getOperands()) { @@ -1338,12 +1336,12 @@ struct ConvertQuartzScfWhileOp final matchAndRewrite(scf::WhileOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - auto& quartzQubits = getState().regionMap[op]; + const auto& quartzQubits = getState().regionMap[op]; SmallVector fluxQubits; fluxQubits.reserve(quartzQubits.size()); - for (const auto& [quartzQubit, fluxQubit] : qubitMap) { - fluxQubits.push_back(fluxQubit); + for (const auto& quartzQubit : quartzQubits) { + fluxQubits.push_back(qubitMap[quartzQubit]); } // create the result typerange const SmallVector fluxTypes( @@ -1370,7 +1368,6 @@ struct ConvertQuartzScfWhileOp final // create the qubit map for the new regions auto& newBeforeRegionMap = getState().qubitMap[&newWhileOp.getBefore()]; auto& newAfterRegionMap = getState().qubitMap[&newWhileOp.getAfter()]; - for (const auto& [quartzQubit, fluxQubit] : llvm::zip_equal(quartzQubits, newWhileOp.getBeforeArguments())) { newBeforeRegionMap.try_emplace(quartzQubit, fluxQubit); @@ -1385,7 +1382,6 @@ struct ConvertQuartzScfWhileOp final llvm::zip_equal(quartzQubits, newWhileOp->getResults())) { qubitMap[quartzQubit] = fluxQubit; } - rewriter.eraseOp(op); return success(); } @@ -1418,13 +1414,14 @@ struct ConvertQuartzScfForOp final : StatefulOpConversionPattern { matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - auto& quartzQubits = getState().regionMap[op]; + const auto& quartzQubits = getState().regionMap[op]; SmallVector fluxQubits; fluxQubits.reserve(qubitMap.size()); - for (auto [quartQubit, fluxQubit] : qubitMap) { - fluxQubits.push_back(fluxQubit); + for (const auto& quartzQubit : quartzQubits) { + fluxQubits.push_back(qubitMap[quartzQubit]); } + // Create a new for-loop with flux qubits as iter_args auto newFor = rewriter.create( op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), @@ -1443,6 +1440,7 @@ struct ConvertQuartzScfForOp final : StatefulOpConversionPattern { llvm::zip_equal(quartzQubits, newFor.getRegionIterArgs())) { regionQubitMap.try_emplace(quartzQubit, fluxQubit); } + // update the qubitmap in the current region for (const auto& [quartzQubit, fluxQubit] : llvm::zip_equal(quartzQubits, newFor->getResults())) { @@ -1546,8 +1544,8 @@ struct ConvertQuartzFuncCallOp final SmallVector fluxQubits; fluxQubits.reserve(qubitMap.size()); - for (auto [quartzQubit, fluxQubit] : qubitMap) { - fluxQubits.push_back(fluxQubit); + for (const auto& quartzQubit : quartzQubits) { + fluxQubits.push_back(qubitMap[quartzQubit]); } // create the result typerange const SmallVector fluxTypes( From 31ee7a176475f4a19e1c6643d236ffbfa535d22f Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 12 Dec 2025 17:32:07 +0100 Subject: [PATCH 011/108] minor fixes --- mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp | 2 +- mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp index 7b5f177386..745e03de1b 100644 --- a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp +++ b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp @@ -1038,7 +1038,7 @@ struct ConvertFluxScfConditionOp final : OpConversionPattern { * * @par Example: * ```mlir - * %q1 = call @test(%q1) : (!flux.qubit) -> !flux.qubit + * %q1 = call @test(%q0) : (!flux.qubit) -> !flux.qubit * } * ``` * is converted to diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index 142c249839..4e4323d7a1 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -121,10 +121,10 @@ class StatefulOpConversionPattern : public OpConversionPattern { * @param ctx The MLIRContext of the current program * @return llvm::Setvector The set of unique Quartz qubit references */ -llvm::SetVector collectRegionQubits(Operation* op, LoweringState* state, +llvm::SetVector collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { // get the regions of the current operation - const auto regions = op->getRegions(); + const auto& regions = op->getRegions(); SetVector uniqueQubits; for (auto& region : regions) { // skip empty regions e.g. empty else region of an If operation @@ -137,7 +137,7 @@ llvm::SetVector collectRegionQubits(Operation* op, LoweringState* state, // check if the operation has an region, if yes recursively collect the // qubits if (operation.getNumRegions() > 0) { - const auto& qubits = collectRegionQubits(&operation, state, ctx); + const auto& qubits = collectUniqueQubits(&operation, state, ctx); uniqueQubits.set_union(qubits); } // collect qubits form the operands @@ -171,11 +171,11 @@ llvm::SetVector collectRegionQubits(Operation* op, LoweringState* state, } } } + // mark scf operations that need to be changed afterwards if (!uniqueQubits.empty() && (llvm::isa(op) || (llvm::isa(op)) || llvm::isa(op))) { state->regionMap[op] = uniqueQubits; - // mark scf operations that need to be changed afterwards op->setAttr("needChange", StringAttr::get(ctx, "yes")); } return uniqueQubits; @@ -1675,7 +1675,7 @@ struct QuartzToFlux final : impl::QuartzToFluxBase { RewritePatternSet patterns(context); QuartzToFluxTypeConverter typeConverter(context); - collectRegionQubits(module, &state, context); + collectUniqueQubits(module, &state, context); // Configure conversion target: Quartz illegal, Flux // legal target.addIllegalDialect(); From 183ffa772c3b63aab170b9a65e53bb3825c78f58 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sun, 14 Dec 2025 12:15:13 +0100 Subject: [PATCH 012/108] fix linter issues --- mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index 4e4323d7a1..6408d379a2 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -18,10 +18,12 @@ #include #include #include +#include #include #include #include #include +#include #include #include #include @@ -156,7 +158,7 @@ llvm::SetVector collectUniqueQubits(Operation* op, LoweringState* state, // conversion if ((llvm::isa(operation) || llvm::isa(operation)) && - uniqueQubits.size() > 0) { + !uniqueQubits.empty()) { operation.setAttr("needChange", StringAttr::get(ctx, "yes")); } // mark func.return operation for functions that need to return a qubit @@ -1352,7 +1354,7 @@ struct ConvertQuartzScfWhileOp final op.getLoc(), TypeRange(fluxTypes), ValueRange(fluxQubits)); auto& newBeforeRegion = newWhileOp.getBefore(); auto& newAfterRegion = newWhileOp.getAfter(); - SmallVector locs(quartzQubits.size(), op->getLoc()); + const SmallVector locs(quartzQubits.size(), op->getLoc()); // create the new blocks auto* newBeforeBlock = rewriter.createBlock(&newBeforeRegion, {}, fluxTypes, locs); From 69269f73803496677c8de42587c1bebcc930a3bb Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 15 Dec 2025 14:40:02 +0100 Subject: [PATCH 013/108] prepare programBuilders --- .../Dialect/Flux/Builder/FluxProgramBuilder.h | 14 +++- .../Quartz/Builder/QuartzProgramBuilder.h | 7 ++ .../Flux/Builder/FluxProgramBuilder.cpp | 81 ++++++++++++------- .../Quartz/Builder/QuartzProgramBuilder.cpp | 26 ++++++ 4 files changed, 98 insertions(+), 30 deletions(-) diff --git a/mlir/include/mlir/Dialect/Flux/Builder/FluxProgramBuilder.h b/mlir/include/mlir/Dialect/Flux/Builder/FluxProgramBuilder.h index a88015a54b..fecf1af2ef 100644 --- a/mlir/include/mlir/Dialect/Flux/Builder/FluxProgramBuilder.h +++ b/mlir/include/mlir/Dialect/Flux/Builder/FluxProgramBuilder.h @@ -1036,6 +1036,7 @@ class FluxProgramBuilder final : public OpBuilder { MLIRContext* ctx{}; Location loc; ModuleOp module; + Region* funcRegion; //===--------------------------------------------------------------------===// // Linear Type Tracking Helpers @@ -1046,22 +1047,29 @@ class FluxProgramBuilder final : public OpBuilder { * @param qubit Qubit value to validate * @throws Aborts if qubit is not tracked (consumed or never created) */ - void validateQubitValue(Value qubit) const; + void validateQubitValue(Value qubit); /** * @brief Update tracking when an operation consumes and produces a qubit * @param inputQubit Input qubit being consumed (must be valid) * @param outputQubit New output qubit being produced */ - void updateQubitTracking(Value inputQubit, Value outputQubit); + void updateQubitTracking(Value inputQubit, Value outputQubit, Region* region); /// Track valid (unconsumed) qubit SSA values for linear type enforcement. /// Only values present in this set are valid for use in operations. /// When an operation consumes a qubit and produces a new one, the old value /// is removed and the new output is added. - llvm::DenseSet validQubits; + llvm::DenseMap> validQubits; /// Track allocated classical Registers SmallVector allocatedClassicalRegisters; + + Value arithConstantIndex(int i); + + Value arithConstantBool(bool b); + + ValueRange scfFor(Value lowerbound, Value upperbound, Value step, + const std::function& body); }; } // namespace mlir::flux diff --git a/mlir/include/mlir/Dialect/Quartz/Builder/QuartzProgramBuilder.h b/mlir/include/mlir/Dialect/Quartz/Builder/QuartzProgramBuilder.h index f91370f95d..8dd95b66d4 100644 --- a/mlir/include/mlir/Dialect/Quartz/Builder/QuartzProgramBuilder.h +++ b/mlir/include/mlir/Dialect/Quartz/Builder/QuartzProgramBuilder.h @@ -862,6 +862,13 @@ class QuartzProgramBuilder final : public OpBuilder { */ OwningOpRef finalize(); + QuartzProgramBuilder& scfFor(Value lowerbound, Value upperbound, Value step, + const std::function& body); + + Value arithConstantIndex(int i); + + Value arithConstantBool(bool b); + private: MLIRContext* ctx{}; Location loc; diff --git a/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp b/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp index e487aed6b2..f5abe114c0 100644 --- a/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp +++ b/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -50,7 +51,7 @@ void FluxProgramBuilder::initialize() { // Add entry_point attribute to identify the main function auto entryPointAttr = getStringAttr("entry_point"); mainFunc->setAttr("passthrough", getArrayAttr({entryPointAttr})); - + funcRegion = &mainFunc->getRegion(0); // Create entry block and set insertion point auto& entryBlock = mainFunc.getBody().emplaceBlock(); setInsertionPointToStart(&entryBlock); @@ -61,7 +62,7 @@ Value FluxProgramBuilder::allocQubit() { const auto qubit = allocOp.getResult(); // Track the allocated qubit as valid - validQubits.insert(qubit); + validQubits[allocOp->getParentRegion()].insert(qubit); return qubit; } @@ -76,7 +77,7 @@ Value FluxProgramBuilder::staticQubit(const int64_t index) { const auto qubit = staticOp.getQubit(); // Track the static qubit as valid - validQubits.insert(qubit); + validQubits[staticOp->getParentRegion()].insert(qubit); return qubit; } @@ -99,7 +100,7 @@ FluxProgramBuilder::allocQubitRegister(const int64_t size, auto allocOp = create(loc, nameAttr, sizeAttr, indexAttr); const auto& qubit = qubits.emplace_back(allocOp.getResult()); // Track the allocated qubit as valid - validQubits.insert(qubit); + validQubits[allocOp->getParentRegion()].insert(qubit); } return qubits; @@ -118,8 +119,8 @@ FluxProgramBuilder::allocClassicalBitRegister(int64_t size, StringRef name) { // Linear Type Tracking Helpers //===----------------------------------------------------------------------===// -void FluxProgramBuilder::validateQubitValue(Value qubit) const { - if (!validQubits.contains(qubit)) { +void FluxProgramBuilder::validateQubitValue(Value qubit) { + if (!validQubits[qubit.getParentRegion()].contains(qubit)) { llvm::errs() << "Attempting to use an invalid qubit SSA value. " << "The value may have been consumed by a previous operation " << "or was never created through this builder.\n"; @@ -129,15 +130,16 @@ void FluxProgramBuilder::validateQubitValue(Value qubit) const { } void FluxProgramBuilder::updateQubitTracking(Value inputQubit, - Value outputQubit) { + Value outputQubit, + Region* region) { // Validate the input qubit validateQubitValue(inputQubit); // Remove the input (consumed) value from tracking - validQubits.erase(inputQubit); + validQubits[region].erase(inputQubit); // Add the output (new) value to tracking - validQubits.insert(outputQubit); + validQubits[region].insert(outputQubit); } //===----------------------------------------------------------------------===// @@ -150,7 +152,7 @@ std::pair FluxProgramBuilder::measure(Value qubit) { auto result = measureOp.getResult(); // Update tracking - updateQubitTracking(qubit, qubitOut); + updateQubitTracking(qubit, qubitOut, measureOp->getParentRegion()); return {qubitOut, result}; } @@ -163,7 +165,7 @@ Value FluxProgramBuilder::measure(Value qubit, const Bit& bit) { const auto qubitOut = measureOp.getQubitOut(); // Update tracking - updateQubitTracking(qubit, qubitOut); + updateQubitTracking(qubit, qubitOut, measureOp->getParentRegion()); return qubitOut; } @@ -173,7 +175,7 @@ Value FluxProgramBuilder::reset(Value qubit) { const auto qubitOut = resetOp.getQubitOut(); // Update tracking - updateQubitTracking(qubit, qubitOut); + updateQubitTracking(qubit, qubitOut, resetOp->getParentRegion()); return qubitOut; } @@ -219,7 +221,7 @@ DEFINE_ZERO_TARGET_ONE_PARAMETER(GPhaseOp, gphase, theta) Value FluxProgramBuilder::OP_NAME(Value qubit) { \ auto op = create(loc, qubit); \ const auto& qubitOut = op.getQubitOut(); \ - updateQubitTracking(qubit, qubitOut); \ + updateQubitTracking(qubit, qubitOut, op->getParentRegion()); \ return qubitOut; \ } \ std::pair FluxProgramBuilder::c##OP_NAME(Value control, \ @@ -263,7 +265,7 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdgOp, sxdg) Value qubit) { \ auto op = create(loc, qubit, PARAM); \ const auto& qubitOut = op.getQubitOut(); \ - updateQubitTracking(qubit, qubitOut); \ + updateQubitTracking(qubit, qubitOut, op->getParentRegion()); \ return qubitOut; \ } \ std::pair FluxProgramBuilder::c##OP_NAME( \ @@ -303,7 +305,7 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, p, phi) const std::variant&(PARAM2), Value qubit) { \ auto op = create(loc, qubit, PARAM1, PARAM2); \ const auto& qubitOut = op.getQubitOut(); \ - updateQubitTracking(qubit, qubitOut); \ + updateQubitTracking(qubit, qubitOut, op->getParentRegion()); \ return qubitOut; \ } \ std::pair FluxProgramBuilder::c##OP_NAME( \ @@ -346,7 +348,7 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, u2, phi, lambda) const std::variant&(PARAM3), Value qubit) { \ auto op = create(loc, qubit, PARAM1, PARAM2, PARAM3); \ const auto& qubitOut = op.getQubitOut(); \ - updateQubitTracking(qubit, qubitOut); \ + updateQubitTracking(qubit, qubitOut, op->getParentRegion()); \ return qubitOut; \ } \ std::pair FluxProgramBuilder::c##OP_NAME( \ @@ -389,8 +391,8 @@ DEFINE_ONE_TARGET_THREE_PARAMETER(UOp, u, theta, phi, lambda) auto op = create(loc, qubit0, qubit1); \ const auto& qubit0Out = op.getQubit0Out(); \ const auto& qubit1Out = op.getQubit1Out(); \ - updateQubitTracking(qubit0, qubit0Out); \ - updateQubitTracking(qubit1, qubit1Out); \ + updateQubitTracking(qubit0, qubit0Out, op->getParentRegion()); \ + updateQubitTracking(qubit1, qubit1Out, op->getParentRegion()); \ return {qubit0Out, qubit1Out}; \ } \ std::pair> FluxProgramBuilder::c##OP_NAME( \ @@ -432,8 +434,8 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ecr) auto op = create(loc, qubit0, qubit1, PARAM); \ const auto& qubit0Out = op.getQubit0Out(); \ const auto& qubit1Out = op.getQubit1Out(); \ - updateQubitTracking(qubit0, qubit0Out); \ - updateQubitTracking(qubit1, qubit1Out); \ + updateQubitTracking(qubit0, qubit0Out, op->getParentRegion()); \ + updateQubitTracking(qubit1, qubit1Out, op->getParentRegion()); \ return {qubit0Out, qubit1Out}; \ } \ std::pair> FluxProgramBuilder::c##OP_NAME( \ @@ -479,8 +481,8 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, rzz, theta) auto op = create(loc, qubit0, qubit1, PARAM1, PARAM2); \ const auto& qubit0Out = op.getQubit0Out(); \ const auto& qubit1Out = op.getQubit1Out(); \ - updateQubitTracking(qubit0, qubit0Out); \ - updateQubitTracking(qubit1, qubit1Out); \ + updateQubitTracking(qubit0, qubit0Out, op->getParentRegion()); \ + updateQubitTracking(qubit1, qubit1Out, op->getParentRegion()); \ return {qubit0Out, qubit1Out}; \ } \ std::pair> FluxProgramBuilder::c##OP_NAME( \ @@ -522,7 +524,7 @@ ValueRange FluxProgramBuilder::barrier(ValueRange qubits) { auto op = create(loc, qubits); const auto& qubitsOut = op.getQubitsOut(); for (const auto& [inputQubit, outputQubit] : llvm::zip(qubits, qubitsOut)) { - updateQubitTracking(inputQubit, outputQubit); + updateQubitTracking(inputQubit, outputQubit, op->getParentRegion()); } return qubitsOut; } @@ -539,11 +541,11 @@ std::pair FluxProgramBuilder::ctrl( // Update tracking const auto& controlsOut = ctrlOp.getControlsOut(); for (const auto& [control, controlOut] : llvm::zip(controls, controlsOut)) { - updateQubitTracking(control, controlOut); + updateQubitTracking(control, controlOut, ctrlOp->getParentRegion()); } const auto& targetsOut = ctrlOp.getTargetsOut(); for (const auto& [target, targetOut] : llvm::zip(targets, targetsOut)) { - updateQubitTracking(target, targetOut); + updateQubitTracking(target, targetOut, ctrlOp->getParentRegion()); } return {controlsOut, targetsOut}; @@ -555,7 +557,7 @@ std::pair FluxProgramBuilder::ctrl( FluxProgramBuilder& FluxProgramBuilder::dealloc(Value qubit) { validateQubitValue(qubit); - validQubits.erase(qubit); + validQubits[qubit.getParentRegion()].erase(qubit); create(loc, qubit); @@ -568,7 +570,7 @@ FluxProgramBuilder& FluxProgramBuilder::dealloc(Value qubit) { OwningOpRef FluxProgramBuilder::finalize() { // Automatically deallocate all remaining valid qubits - for (const auto qubit : validQubits) { + for (const auto qubit : validQubits[funcRegion]) { create(loc, qubit); } @@ -582,5 +584,30 @@ OwningOpRef FluxProgramBuilder::finalize() { return module; } +Value FluxProgramBuilder::arithConstantIndex(int i) { + + const auto op = + create(loc, getIndexType(), getIndexAttr(i)); + return op->getResult(0); +} +Value FluxProgramBuilder::arithConstantBool(bool b) { + const auto i1Type = getI1Type(); + const auto op = + b ? create(loc, i1Type, getIntegerAttr(i1Type, 1)) + : create(loc, i1Type, getIntegerAttr(i1Type, 0)); + return op->getResult(0); +} + +ValueRange +FluxProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, + const std::function& body) { + auto op = create(loc, lowerbound, upperbound, step, ValueRange{}, + [&](OpBuilder& b, Location, Value, ValueRange) { + body(b); // adapt + b.create(loc); + }); + + return op->getResults(); +} } // namespace mlir::flux diff --git a/mlir/lib/Dialect/Quartz/Builder/QuartzProgramBuilder.cpp b/mlir/lib/Dialect/Quartz/Builder/QuartzProgramBuilder.cpp index d54cb6af71..48297d9aea 100644 --- a/mlir/lib/Dialect/Quartz/Builder/QuartzProgramBuilder.cpp +++ b/mlir/lib/Dialect/Quartz/Builder/QuartzProgramBuilder.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -429,4 +430,29 @@ OwningOpRef QuartzProgramBuilder::finalize() { return module; } +QuartzProgramBuilder& +QuartzProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, + const std::function& body) { + create(loc, lowerbound, upperbound, step, ValueRange{}, + [&](OpBuilder& b, Location, Value, ValueRange) { + body(b); // adapt + b.create(loc); + }); + + return *this; +} +Value QuartzProgramBuilder::arithConstantIndex(int i) { + + const auto op = + create(loc, getIndexType(), getIndexAttr(i)); + return op->getResult(0); +} +Value QuartzProgramBuilder::arithConstantBool(bool b) { + const auto i1Type = getI1Type(); + const auto op = + b ? create(loc, i1Type, getIntegerAttr(i1Type, 1)) + : create(loc, i1Type, getIntegerAttr(i1Type, 0)); + return op->getResult(0); +} + } // namespace mlir::quartz From b38a9ce528b4f7598a1e142a78252d1696d5195a Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Tue, 16 Dec 2025 17:22:02 +0100 Subject: [PATCH 014/108] add quartzProgramBuilder for scf operations --- .../Quartz/Builder/QuartzProgramBuilder.h | 148 +++++++++++++++++- .../Quartz/Builder/QuartzProgramBuilder.cpp | 101 +++++++++--- 2 files changed, 217 insertions(+), 32 deletions(-) diff --git a/mlir/include/mlir/Dialect/Quartz/Builder/QuartzProgramBuilder.h b/mlir/include/mlir/Dialect/Quartz/Builder/QuartzProgramBuilder.h index 8dd95b66d4..af9f630de0 100644 --- a/mlir/include/mlir/Dialect/Quartz/Builder/QuartzProgramBuilder.h +++ b/mlir/include/mlir/Dialect/Quartz/Builder/QuartzProgramBuilder.h @@ -845,6 +845,147 @@ class QuartzProgramBuilder final : public OpBuilder { */ QuartzProgramBuilder& dealloc(Value qubit); + //===--------------------------------------------------------------------===// + // SCF operations + //===--------------------------------------------------------------------===// + + /** + * @brief Constructs a scf.for operation without iter args + * + * @param lowerbound Lowerbound of the loop + * @param upperbound Upperbound of the loop + * @param step Stepsize of the loop + * @param body Function that builds the body of the for operation + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.scfFor(lb, ub, step, [&](auto& b) { b.x(q0); }); + * ``` + * ```mlir + * scf.for %iv = %lb to %ub step %step { + * quartz.x %q0 : !quartz.qubit + * } + * ``` + */ + QuartzProgramBuilder& scfFor(Value lowerbound, Value upperbound, Value step, + const std::function& body); + + /** + * @brief Constructs a scf.while operation without return values + * + * @param beforeBody Function that builds the before body of the while + * operation + * @param afterBody Function that builds the after body of the while operation + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.scfWhile([&](auto& b) { + * b.h(q0); + * auto res = b.measure(q0) + * b.condition(res) + * }, [&](auto& b) { + * b.x(q0); + * b.yield() + * }); + * ``` + * ```mlir + * scf.while : () -> () { + * quartz.h %q0 : !quartz.qubit + * %res = quartz.measure %q0 : !quartz.qubit -> i1 + * scf.condition(%tres) + * } do { + * quartz.x %q0 : !quartz.qubit + * scf.yield + * } + * ``` + */ + QuartzProgramBuilder& + scfWhile(const std::function& beforeBody, + const std::function& afterBody); + + /** + * @brief Constructs a scf.if operation without return values + * + * @param condition Condition for the if operation + * @param thenBody Function that builds the then body of the if + * operation + * @param elseBody Function that builds the else body of the if operation + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.scf.if(condition, [&](auto& b) { + * b.h(q0); + * }, [&](auto& b) { + * b.x(q0); + * }); + * ``` + * ```mlir + * scf.if %condition { + * quartz.h %q0 : !quartz.qubit + * } else { + * quartz.x %q0 : !quartz.qubit + * } + * ``` + */ + QuartzProgramBuilder& + scfIf(Value condition, const std::function& thenBody, + const std::function& elseBody = nullptr); + + /** + * @brief Constructs a scf.condition operation without any additional Values + * + * @param condition Condition for condition operation + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.condition(condition); + * ``` + * ```mlir + * scf.condition(%condition) + * ``` + */ + QuartzProgramBuilder& scfCondition(Value condition); + + //===--------------------------------------------------------------------===// + // Arith operations + //===--------------------------------------------------------------------===// + + /** + * @brief Constructs a arith.constant of type Index with a given value + * + * @param index Value of the constant operation + * @return Result of the constant operation + * + * @par Example: + * ```c++ + * builder.arithConstantIndex(4); + * ``` + * ```mlir + * arith.constant 4 : index + * ``` + */ + Value arithConstantIndex(int index); + + /** + * @brief Constructs a arith.constant of type i1 with a given bool value + * + * @param b Bool value of the constant operation + * @return Result of the constant operation + * + * @par Example: + * ```c++ + * builder.arithConstantBool(true); + * ``` + * ```mlir + * arith.constant 1 : i1 + * ``` + */ + Value arithConstantBool(bool b); + //===--------------------------------------------------------------------===// // Finalization //===--------------------------------------------------------------------===// @@ -862,13 +1003,6 @@ class QuartzProgramBuilder final : public OpBuilder { */ OwningOpRef finalize(); - QuartzProgramBuilder& scfFor(Value lowerbound, Value upperbound, Value step, - const std::function& body); - - Value arithConstantIndex(int i); - - Value arithConstantBool(bool b); - private: MLIRContext* ctx{}; Location loc; diff --git a/mlir/lib/Dialect/Quartz/Builder/QuartzProgramBuilder.cpp b/mlir/lib/Dialect/Quartz/Builder/QuartzProgramBuilder.cpp index 48297d9aea..c481287155 100644 --- a/mlir/lib/Dialect/Quartz/Builder/QuartzProgramBuilder.cpp +++ b/mlir/lib/Dialect/Quartz/Builder/QuartzProgramBuilder.cpp @@ -408,51 +408,102 @@ QuartzProgramBuilder& QuartzProgramBuilder::dealloc(Value qubit) { } //===----------------------------------------------------------------------===// -// Finalization +// SCF operations //===----------------------------------------------------------------------===// -OwningOpRef QuartzProgramBuilder::finalize() { - // Automatically deallocate all remaining allocated qubits - for (Value qubit : allocatedQubits) { - create(loc, qubit); - } - - // Clear the tracking set - allocatedQubits.clear(); - - // Create constant 0 for successful exit code - auto exitCode = create(loc, getI64IntegerAttr(0)); - - // Add return statement with exit code 0 to the main function - create(loc, ValueRange{exitCode}); - - // Transfer ownership to the caller - return module; -} - QuartzProgramBuilder& QuartzProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, const std::function& body) { create(loc, lowerbound, upperbound, step, ValueRange{}, [&](OpBuilder& b, Location, Value, ValueRange) { - body(b); // adapt + body(b); b.create(loc); }); return *this; } -Value QuartzProgramBuilder::arithConstantIndex(int i) { +QuartzProgramBuilder& QuartzProgramBuilder::scfWhile( + const std::function& beforeBody, + const std::function& afterBody) { + create( + loc, TypeRange{}, ValueRange{}, + [&](OpBuilder& b, Location, ValueRange) { beforeBody(b); }, + [&](OpBuilder& b, Location loc, ValueRange) { + afterBody(b); + b.create(loc); + }); + + return *this; +} + +QuartzProgramBuilder& +QuartzProgramBuilder::scfIf(Value cond, + const std::function& thenBody, + const std::function& elseBody) { + if (!elseBody) { + create(loc, cond, [&](OpBuilder& b, Location loc) { + thenBody(b); + b.create(loc); + }); + } else { + create( + loc, cond, + [&](OpBuilder& b, Location loc) { + thenBody(b); + b.create(loc); + }, + [&](OpBuilder& b, Location loc) { + elseBody(b); + b.create(loc); + }); + } + return *this; +} + +QuartzProgramBuilder& QuartzProgramBuilder::scfCondition(Value condition) { + create(loc, condition, ValueRange{}); + return *this; +} + +//===----------------------------------------------------------------------===// +// Arith operations +//===----------------------------------------------------------------------===// + +Value QuartzProgramBuilder::arithConstantIndex(int index) { const auto op = - create(loc, getIndexType(), getIndexAttr(i)); + create(loc, getIndexType(), getIndexAttr(index)); return op->getResult(0); } + Value QuartzProgramBuilder::arithConstantBool(bool b) { const auto i1Type = getI1Type(); const auto op = - b ? create(loc, i1Type, getIntegerAttr(i1Type, 1)) - : create(loc, i1Type, getIntegerAttr(i1Type, 0)); + create(loc, i1Type, getIntegerAttr(i1Type, b ? 1 : 0)); return op->getResult(0); } +//===----------------------------------------------------------------------===// +// Finalization +//===----------------------------------------------------------------------===// + +OwningOpRef QuartzProgramBuilder::finalize() { + // Automatically deallocate all remaining allocated qubits + for (Value qubit : allocatedQubits) { + create(loc, qubit); + } + + // Clear the tracking set + allocatedQubits.clear(); + + // Create constant 0 for successful exit code + auto exitCode = create(loc, getI64IntegerAttr(0)); + + // Add return statement with exit code 0 to the main function + create(loc, ValueRange{exitCode}); + + // Transfer ownership to the caller + return module; +} + } // namespace mlir::quartz From 738daaee79d8e05d5fd93a078de4ace942080977 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Tue, 16 Dec 2025 17:43:16 +0100 Subject: [PATCH 015/108] fix typing issue with converted function arguments --- .../Conversion/QuartzToFlux/QuartzToFlux.cpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp index 6408d379a2..6208b46a52 100644 --- a/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp +++ b/mlir/lib/Conversion/QuartzToFlux/QuartzToFlux.cpp @@ -231,7 +231,7 @@ LogicalResult convertOneTargetZeroParameter(QuartzOpType& op, const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubit - const auto& quartzQubit = op.getQubitIn(); + const auto& quartzQubit = op.getOperand(); Value fluxQubit; if (inCtrlOp == 0) { @@ -277,7 +277,7 @@ LogicalResult convertOneTargetOneParameter(QuartzOpType& op, const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubit - const auto& quartzQubit = op.getQubitIn(); + const auto& quartzQubit = op.getOperand(0); Value fluxQubit; if (inCtrlOp == 0) { fluxQubit = qubitMap[quartzQubit]; @@ -321,7 +321,7 @@ LogicalResult convertOneTargetTwoParameter(QuartzOpType& op, const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubit - const auto& quartzQubit = op.getQubitIn(); + const auto& quartzQubit = op.getOperand(0); Value fluxQubit; if (inCtrlOp == 0) { fluxQubit = qubitMap[quartzQubit]; @@ -366,7 +366,7 @@ convertOneTargetThreeParameter(QuartzOpType& op, const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubit - const auto& quartzQubit = op.getQubitIn(); + const auto& quartzQubit = op.getOperand(0); Value fluxQubit; if (inCtrlOp == 0) { fluxQubit = qubitMap[quartzQubit]; @@ -411,8 +411,8 @@ LogicalResult convertTwoTargetZeroParameter(QuartzOpType& op, const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubits - const auto& quartzQubit0 = op.getQubit0In(); - const auto& quartzQubit1 = op.getQubit1In(); + const auto& quartzQubit0 = op.getOperand(0); + const auto& quartzQubit1 = op.getOperand(1); Value fluxQubit0; Value fluxQubit1; if (inCtrlOp == 0) { @@ -462,8 +462,8 @@ LogicalResult convertTwoTargetOneParameter(QuartzOpType& op, const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubits - const auto& quartzQubit0 = op.getQubit0In(); - const auto& quartzQubit1 = op.getQubit1In(); + const auto& quartzQubit0 = op.getOperand(0); + const auto& quartzQubit1 = op.getOperand(1); Value fluxQubit0; Value fluxQubit1; if (inCtrlOp == 0) { @@ -513,8 +513,8 @@ LogicalResult convertTwoTargetTwoParameter(QuartzOpType& op, const auto inCtrlOp = state.inCtrlOp; // Get the latest Flux qubits - const auto& quartzQubit0 = op.getQubit0In(); - const auto& quartzQubit1 = op.getQubit1In(); + const auto& quartzQubit0 = op.getOperand(0); + const auto& quartzQubit1 = op.getOperand(1); Value fluxQubit0; Value fluxQubit1; if (inCtrlOp == 0) { From 4e3306eae1bca1565d0e2b87e430dcd58a85ed2c Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Tue, 16 Dec 2025 17:45:35 +0100 Subject: [PATCH 016/108] add missing headers --- mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp index 745e03de1b..d63e96c320 100644 --- a/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp +++ b/mlir/lib/Conversion/FluxToQuartz/FluxToQuartz.cpp @@ -13,6 +13,8 @@ #include "mlir/Dialect/Flux/IR/FluxDialect.h" #include "mlir/Dialect/Quartz/IR/QuartzDialect.h" +#include +#include #include #include #include From e58701f50ddc2fcaafad8634e3e2f292345c083d Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Tue, 16 Dec 2025 17:49:27 +0100 Subject: [PATCH 017/108] start adding FluxProgramBuilders for scf operations --- .../Dialect/Flux/Builder/FluxProgramBuilder.h | 14 +++++++------- .../Flux/Builder/FluxProgramBuilder.cpp | 18 ++++++++---------- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/mlir/include/mlir/Dialect/Flux/Builder/FluxProgramBuilder.h b/mlir/include/mlir/Dialect/Flux/Builder/FluxProgramBuilder.h index fecf1af2ef..f7ee99c5fe 100644 --- a/mlir/include/mlir/Dialect/Flux/Builder/FluxProgramBuilder.h +++ b/mlir/include/mlir/Dialect/Flux/Builder/FluxProgramBuilder.h @@ -1015,6 +1015,13 @@ class FluxProgramBuilder final : public OpBuilder { */ FluxProgramBuilder& dealloc(Value qubit); + Value arithConstantIndex(int i); + + Value arithConstantBool(bool b); + + ValueRange + scfFor(Value lowerbound, Value upperbound, Value step, ValueRange initArgs, + const std::function& body); //===--------------------------------------------------------------------===// // Finalization //===--------------------------------------------------------------------===// @@ -1064,12 +1071,5 @@ class FluxProgramBuilder final : public OpBuilder { /// Track allocated classical Registers SmallVector allocatedClassicalRegisters; - - Value arithConstantIndex(int i); - - Value arithConstantBool(bool b); - - ValueRange scfFor(Value lowerbound, Value upperbound, Value step, - const std::function& body); }; } // namespace mlir::flux diff --git a/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp b/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp index f5abe114c0..3bc6b49348 100644 --- a/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp +++ b/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp @@ -599,15 +599,13 @@ Value FluxProgramBuilder::arithConstantBool(bool b) { return op->getResult(0); } -ValueRange -FluxProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, - const std::function& body) { - auto op = create(loc, lowerbound, upperbound, step, ValueRange{}, - [&](OpBuilder& b, Location, Value, ValueRange) { - body(b); // adapt - b.create(loc); - }); - - return op->getResults(); +ValueRange FluxProgramBuilder::scfFor( + Value lowerbound, Value upperbound, Value step, ValueRange initArgs, + const std::function& body) { + auto op = create(loc, lowerbound, upperbound, step, + initArgs // iter_args + ); + + return ValueRange{op->getResults()}; } } // namespace mlir::flux From 92b556ee28a2ad3e9ecee5c0d56c4a1f195aec28 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Tue, 16 Dec 2025 18:28:49 +0100 Subject: [PATCH 018/108] use region as key for validQubis mapt --- mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp b/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp index 3a56be6212..9a149fb781 100644 --- a/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp +++ b/mlir/lib/Dialect/Flux/Builder/FluxProgramBuilder.cpp @@ -645,7 +645,8 @@ OwningOpRef FluxProgramBuilder::finalize() { // Automatically deallocate all still-allocated qubits // Sort qubits for deterministic output - SmallVector sortedQubits(validQubits.begin(), validQubits.end()); + SmallVector sortedQubits(validQubits[&mainFunc->getRegion(0)].begin(), + validQubits[&mainFunc->getRegion(0)].end()); llvm::sort(sortedQubits, [](Value a, Value b) { auto* opA = a.getDefiningOp(); auto* opB = b.getDefiningOp(); From 4f0a6c7a35764e298738e554ff8fb0160ccc04ab Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 18 Dec 2025 11:31:30 +0100 Subject: [PATCH 019/108] resolve merge conflicts --- .../Dialect/QC/Builder/QCProgramBuilder.h | 25 +- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 96 ++++--- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 266 +++++++++--------- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 26 +- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 6 +- 5 files changed, 208 insertions(+), 211 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index 2fbf7b6724..e9534e3cd0 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -887,12 +887,12 @@ class QCProgramBuilder final : public OpBuilder { * ``` * ```mlir * scf.for %iv = %lb to %ub step %step { - * quartz.x %q0 : !quartz.qubit + * qc.x %q0 : !qc.qubit * } * ``` */ - QuartzProgramBuilder& scfFor(Value lowerbound, Value upperbound, Value step, - const std::function& body); + QCProgramBuilder& scfFor(Value lowerbound, Value upperbound, Value step, + const std::function& body); /** * @brief Constructs a scf.while operation without return values @@ -915,18 +915,17 @@ class QCProgramBuilder final : public OpBuilder { * ``` * ```mlir * scf.while : () -> () { - * quartz.h %q0 : !quartz.qubit - * %res = quartz.measure %q0 : !quartz.qubit -> i1 + * qc.h %q0 : !qc.qubit + * %res = qc.measure %q0 : !qc.qubit -> i1 * scf.condition(%tres) * } do { - * quartz.x %q0 : !quartz.qubit + * qc.x %q0 : !qc.qubit * scf.yield * } * ``` */ - QuartzProgramBuilder& - scfWhile(const std::function& beforeBody, - const std::function& afterBody); + QCProgramBuilder& scfWhile(const std::function& beforeBody, + const std::function& afterBody); /** * @brief Constructs a scf.if operation without return values @@ -947,13 +946,13 @@ class QCProgramBuilder final : public OpBuilder { * ``` * ```mlir * scf.if %condition { - * quartz.h %q0 : !quartz.qubit + * qc.h %q0 : !qc.qubit * } else { - * quartz.x %q0 : !quartz.qubit + * qc.x %q0 : !qc.qubit * } * ``` */ - QuartzProgramBuilder& + QCProgramBuilder& scfIf(Value condition, const std::function& thenBody, const std::function& elseBody = nullptr); @@ -971,7 +970,7 @@ class QCProgramBuilder final : public OpBuilder { * scf.condition(%condition) * ``` */ - QuartzProgramBuilder& scfCondition(Value condition); + QCProgramBuilder& scfCondition(Value condition); //===--------------------------------------------------------------------===// // Arith operations diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 22cc4512bb..7805080b26 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -819,22 +819,22 @@ struct ConvertQCOYieldOp final : OpConversionPattern { * * @par Example: * ```mlir - * %targets_out = scf.if %cond -> (!flux.qubit) { - * %q1 = flux.h %q0 : !flux.qubit -> !flux.qubit - * scf.yield %q1 : !flux.qubit + * %targets_out = scf.if %cond -> (!qco.qubit) { + * %q1 = qco.h %q0 : !qco.qubit -> !qco.qubit + * scf.yield %q1 : !qco.qubit * } else { - * scf.yield %q0 : !flux.qubit + * scf.yield %q0 : !qco.qubit * } * ``` * is converted to * ```mlir * scf.if %cond { - * quartz.x %q0 + * qc.x %q0 * scf.yield * } * ``` */ -struct ConvertFluxScfIfOp final : OpConversionPattern { +struct ConvertQCOScfIfOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -868,27 +868,27 @@ struct ConvertFluxScfIfOp final : OpConversionPattern { * * @par Example: * ```mlir - * %targets_out = scf.while (%arg0 = %q0) : (!flux.qubit) -> !flux.qubit { - * %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit - * scf.condition(%cond) %q1 : !flux.qubit + * %targets_out = scf.while (%arg0 = %q0) : (!qco.qubit) -> !qco.qubit { + * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * scf.condition(%cond) %q1 : !qco.qubit * } do { - * ^bb0(%arg0: !flux.qubit): - * %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit - * scf.yield %q1 : !flux.qubit + * ^bb0(%arg0: !qco.qubit): + * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * scf.yield %q1 : !qco.qubit * } * ``` * is converted to * ```mlir * scf.while : () -> () { - * quartz.x %q0 + * qc.x %q0 * scf.condition(%cond) * } do { - * quartz.x %q0 + * qc.x %q0 * scf.yield * } * ``` */ -struct ConvertFluxScfWhileOp final : OpConversionPattern { +struct ConvertQCOScfWhileOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -929,20 +929,20 @@ struct ConvertFluxScfWhileOp final : OpConversionPattern { * @par Example: * ```mlir * %targets_out = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = q0) -> - * (!flux.qubit) { - * %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit - * scf.yield %q1 : !flux.qubit + * (!qco.qubit) { + * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * scf.yield %q1 : !qco.qubit * } * ``` * is converted to * ```mlir * scf.for %iv = %lb to %ub step %step { - * quartz.x %q0 + * qc.x %q0 * scf.yield * } * ``` */ -struct ConvertFluxScfForOp final : OpConversionPattern { +struct ConvertQCOScfForOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -954,9 +954,9 @@ struct ConvertFluxScfForOp final : OpConversionPattern { adaptor.getStep(), ValueRange{}); // replace the uses of the previous iter_args - for (const auto& [fluxQubit, quartzQubit] : + for (const auto& [qcoQubit, qcQubit] : llvm::zip_equal(op.getRegionIterArgs(), adaptor.getInitArgs())) { - fluxQubit.replaceAllUsesWith(quartzQubit); + qcoQubit.replaceAllUsesWith(qcQubit); } // move all the operations from the old block to the new block @@ -985,7 +985,7 @@ struct ConvertFluxScfForOp final : OpConversionPattern { * scf.yield * ``` */ -struct ConvertFluxScfYieldOp final : OpConversionPattern { +struct ConvertQCOScfYieldOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -1010,7 +1010,7 @@ struct ConvertFluxScfYieldOp final : OpConversionPattern { * ``` */ -struct ConvertFluxScfConditionOp final : OpConversionPattern { +struct ConvertQCOScfConditionOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -1029,15 +1029,15 @@ struct ConvertFluxScfConditionOp final : OpConversionPattern { * * @par Example: * ```mlir - * %q1 = call @test(%q0) : (!flux.qubit) -> !flux.qubit + * %q1 = call @test(%q0) : (!qco.qubit) -> !qco.qubit * } * ``` * is converted to * ```mlir - * call @test(%q0) : (!quartz.qubit) -> () + * call @test(%q0) : (!qc.qubit) -> () * ``` */ -struct ConvertFluxFuncCallOp final : OpConversionPattern { +struct ConvertQCOFuncCallOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -1056,18 +1056,18 @@ struct ConvertFluxFuncCallOp final : OpConversionPattern { * * @par Example: * ```mlir - * func.func @test(%arg0: !flux.qubit) -> !flux.qubit { + * func.func @test(%arg0: !qco.qubit) -> !qco.qubit { * ... * } * ``` * is converted to * ```mlir - * func.func @test(%arg0: !quartz.qubit){ + * func.func @test(%arg0: !qc.qubit){ * ... * } * ``` */ -struct ConvertFluxFuncFuncOp final : OpConversionPattern { +struct ConvertQCOFuncFuncOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -1075,9 +1075,9 @@ struct ConvertFluxFuncFuncOp final : OpConversionPattern { ConversionPatternRewriter& rewriter) const override { const SmallVector argumentTypes( op.front().getNumArguments(), - quartz::QubitType::get(rewriter.getContext())); + qc::QubitType::get(rewriter.getContext())); for (auto blockArg : op.front().getArguments()) { - blockArg.setType(quartz::QubitType::get(rewriter.getContext())); + blockArg.setType(qc::QubitType::get(rewriter.getContext())); } auto newFuncType = rewriter.getFunctionType(argumentTypes, {}); op.setFunctionType(newFuncType); @@ -1091,14 +1091,14 @@ struct ConvertFluxFuncFuncOp final : OpConversionPattern { * * @par Example: * ```mlir - * func.return %targets : !flux.qubit + * func.return %targets : !qco.qubit * ``` * is converted to * ```mlir * func.return * ``` */ -struct ConvertFluxFuncReturnOp final : OpConversionPattern { +struct ConvertQCOFuncReturnOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -1197,18 +1197,22 @@ struct QCOToQC final : impl::QCOToQCBase { // Register operation conversion patterns // Note: No state tracking needed - OpAdaptors handle type conversion - patterns.add( - typeConverter, context); + patterns + .add(typeConverter, + context); // Apply the conversion if (failed(applyPartialConversion(module, target, std::move(patterns)))) { diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index bd9345d01a..d4ea2b95ac 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -72,10 +72,10 @@ namespace { * - %q2 after the X gate */ struct LoweringState { - /// Map from original Quartz qubit references to their latest Flux SSA values + /// Map from original QC qubit references to their latest Flux SSA values /// for each region llvm::DenseMap> qubitMap; - /// Map each operation to its Set of Quartz qubit references + /// Map each operation to its Set of QC qubit references llvm::DenseMap> regionMap; /// Modifier information @@ -117,13 +117,13 @@ class StatefulOpConversionPattern : public OpConversionPattern { } // namespace /** - * @brief Recursively collects all the Quartz qubit references used by an + * @brief Recursively collects all the QC qubit references used by an * operation and store them in map * * @param Operation The operation that is currently traversed * @param state The lowering state * @param ctx The MLIRContext of the current program - * @return llvm::Setvector The set of unique Quartz qubit references + * @return llvm::Setvector The set of unique QC qubit references */ llvm::SetVector collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { @@ -146,13 +146,13 @@ llvm::SetVector collectUniqueQubits(Operation* op, LoweringState* state, } // collect qubits form the operands for (const auto& operand : operation.getOperands()) { - if (operand.getType() == quartz::QubitType::get(ctx)) { + if (operand.getType() == qc::QubitType::get(ctx)) { uniqueQubits.insert(operand); } } // collect qubits from the results for (const auto& result : operation.getResults()) { - if (result.getType() == quartz::QubitType::get(ctx)) { + if (result.getType() == qc::QubitType::get(ctx)) { uniqueQubits.insert(result); } } @@ -168,7 +168,7 @@ llvm::SetVector collectUniqueQubits(Operation* op, LoweringState* state, if (llvm::isa(operation)) { if (auto func = operation.getParentOfType()) { if (!func.getArgumentTypes().empty() && - func.getArgumentTypes().front() == quartz::QubitType::get(ctx)) { + func.getArgumentTypes().front() == qc::QubitType::get(ctx)) { operation.setAttr("needChange", StringAttr::get(ctx, "yes")); } } @@ -1245,37 +1245,36 @@ struct ConvertQCYieldOp final : StatefulOpConversionPattern { * @par Example: * ```mlir * scf.if %cond { - * quartz.x %q0 + * qc.x %q0 * scf.yield * } * ``` * is converted to * ```mlir - * %targets_out = scf.if %cond -> (!flux.qubit) { - * %q1 = flux.h %q0 : !flux.qubit -> !flux.qubit - * scf.yield %q1 : !flux.qubit + * %targets_out = scf.if %cond -> (!qco.qubit) { + * %q1 = qco.h %q0 : !qco.qubit -> !qco.qubit + * scf.yield %q1 : !qco.qubit * } else { - * scf.yield %q0 : !flux.qubit + * scf.yield %q0 : !qco.qubit * } * ``` */ -struct ConvertQuartzScfIfOp final : StatefulOpConversionPattern { +struct ConvertQCScfIfOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(scf::IfOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - const auto& quartzQubits = getState().regionMap[op]; - const SmallVector quartzValues(quartzQubits.begin(), - quartzQubits.end()); + const auto& qcQubits = getState().regionMap[op]; + const SmallVector qcValues(qcQubits.begin(), qcQubits.end()); // create result typerange - const SmallVector fluxTypes( - quartzQubits.size(), flux::QubitType::get(rewriter.getContext())); + const SmallVector qcoTypes( + qcQubits.size(), qco::QubitType::get(rewriter.getContext())); // create new if operation - auto newIfOp = rewriter.create( - op->getLoc(), TypeRange{fluxTypes}, op.getCondition(), true); + auto newIfOp = rewriter.create(op->getLoc(), TypeRange{qcoTypes}, + op.getCondition(), true); auto& thenRegion = newIfOp.getThenRegion(); auto& elseRegion = newIfOp.getElseRegion(); @@ -1293,7 +1292,7 @@ struct ConvertQuartzScfIfOp final : StatefulOpConversionPattern { // create the yield operation if it does not exist yet rewriter.setInsertionPointToEnd(&elseRegion.front()); const auto elseYield = - rewriter.create(op->getLoc(), quartzValues); + rewriter.create(op->getLoc(), qcValues); // mark the yield operation for conversion elseYield->setAttr("needChange", StringAttr::get(rewriter.getContext(), "yes")); @@ -1302,18 +1301,18 @@ struct ConvertQuartzScfIfOp final : StatefulOpConversionPattern { // create the qubit map for the regions auto& thenRegionQubitMap = getState().qubitMap[&thenRegion]; auto& elseRegionQubitMap = getState().qubitMap[&elseRegion]; - for (const auto& quartzQubit : quartzQubits) { + for (const auto& qcQubit : qcQubits) { thenRegionQubitMap.try_emplace( - quartzQubit, getState().qubitMap[op->getParentRegion()][quartzQubit]); + qcQubit, getState().qubitMap[op->getParentRegion()][qcQubit]); elseRegionQubitMap.try_emplace( - quartzQubit, getState().qubitMap[op->getParentRegion()][quartzQubit]); + qcQubit, getState().qubitMap[op->getParentRegion()][qcQubit]); } // update the qubit map in the current region auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - for (const auto& [quartzQubit, fluxQubit] : - llvm::zip_equal(quartzQubits, newIfOp->getResults())) { - qubitMap[quartzQubit] = fluxQubit; + for (const auto& [qcQubit, qcoQubit] : + llvm::zip_equal(qcQubits, newIfOp->getResults())) { + qubitMap[qcQubit] = qcoQubit; } rewriter.eraseOp(op); @@ -1328,55 +1327,54 @@ struct ConvertQuartzScfIfOp final : StatefulOpConversionPattern { * @par Example: * ```mlir * scf.while : () -> () { - * quartz.x %q0 + * qc.x %q0 * scf.condition(%cond) * } do { - * quartz.x %q0 + * qc.x %q0 * scf.yield * } * ``` * is converted to * ```mlir - * %targets_out = scf.while (%arg0 = %q0) : (!flux.qubit) -> !flux.qubit { - * %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit - * scf.condition(%cond) %q1 : !flux.qubit + * %targets_out = scf.while (%arg0 = %q0) : (!qco.qubit) -> !qco.qubit { + * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * scf.condition(%cond) %q1 : !qco.qubit * } do { - * ^bb0(%arg0: !flux.qubit): - * %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit - * scf.yield %q1 : !flux.qubit + * ^bb0(%arg0: !qco.qubit): + * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * scf.yield %q1 : !qco.qubit * } * ``` */ -struct ConvertQuartzScfWhileOp final - : StatefulOpConversionPattern { +struct ConvertQCScfWhileOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(scf::WhileOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - const auto& quartzQubits = getState().regionMap[op]; + const auto& qcQubits = getState().regionMap[op]; - SmallVector fluxQubits; - fluxQubits.reserve(quartzQubits.size()); - for (const auto& quartzQubit : quartzQubits) { - fluxQubits.push_back(qubitMap[quartzQubit]); + SmallVector qcoQubits; + qcoQubits.reserve(qcQubits.size()); + for (const auto& qcQubit : qcQubits) { + qcoQubits.push_back(qubitMap[qcQubit]); } // create the result typerange - const SmallVector fluxTypes( - quartzQubits.size(), flux::QubitType::get(rewriter.getContext())); + const SmallVector qcoTypes( + qcQubits.size(), qco::QubitType::get(rewriter.getContext())); // create the new while operation auto newWhileOp = rewriter.create( - op.getLoc(), TypeRange(fluxTypes), ValueRange(fluxQubits)); + op.getLoc(), TypeRange(qcoTypes), ValueRange(qcoQubits)); auto& newBeforeRegion = newWhileOp.getBefore(); auto& newAfterRegion = newWhileOp.getAfter(); - const SmallVector locs(quartzQubits.size(), op->getLoc()); + const SmallVector locs(qcQubits.size(), op->getLoc()); // create the new blocks auto* newBeforeBlock = - rewriter.createBlock(&newBeforeRegion, {}, fluxTypes, locs); + rewriter.createBlock(&newBeforeRegion, {}, qcoTypes, locs); auto* newAfterBlock = - rewriter.createBlock(&newAfterRegion, {}, fluxTypes, locs); + rewriter.createBlock(&newAfterRegion, {}, qcoTypes, locs); // move the operations to the new blocks newBeforeBlock->getOperations().splice(newBeforeBlock->end(), @@ -1387,19 +1385,19 @@ struct ConvertQuartzScfWhileOp final // create the qubit map for the new regions auto& newBeforeRegionMap = getState().qubitMap[&newWhileOp.getBefore()]; auto& newAfterRegionMap = getState().qubitMap[&newWhileOp.getAfter()]; - for (const auto& [quartzQubit, fluxQubit] : - llvm::zip_equal(quartzQubits, newWhileOp.getBeforeArguments())) { - newBeforeRegionMap.try_emplace(quartzQubit, fluxQubit); + for (const auto& [qcQubit, qcoQubit] : + llvm::zip_equal(qcQubits, newWhileOp.getBeforeArguments())) { + newBeforeRegionMap.try_emplace(qcQubit, qcoQubit); } - for (const auto& [quartzQubit, fluxQubit] : - llvm::zip_equal(quartzQubits, newWhileOp.getAfterArguments())) { - newAfterRegionMap.try_emplace(quartzQubit, fluxQubit); + for (const auto& [qcQubit, qcoQubit] : + llvm::zip_equal(qcQubits, newWhileOp.getAfterArguments())) { + newAfterRegionMap.try_emplace(qcQubit, qcoQubit); } // update the qubit map in the current region - for (const auto& [quartzQubit, fluxQubit] : - llvm::zip_equal(quartzQubits, newWhileOp->getResults())) { - qubitMap[quartzQubit] = fluxQubit; + for (const auto& [qcQubit, qcoQubit] : + llvm::zip_equal(qcQubits, newWhileOp->getResults())) { + qubitMap[qcQubit] = qcoQubit; } rewriter.eraseOp(op); return success(); @@ -1413,38 +1411,38 @@ struct ConvertQuartzScfWhileOp final * @par Example: * ```mlir * scf.for %iv = %lb to %ub step %step { - * quartz.x %q0 + * qc.x %q0 * scf.yield * } * ``` * is converted to * ```mlir * %targets_out = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = q0) -> - * (!flux.qubit) { - * %q1 = quartz.x %arg0 : !flux.qubit -> !flux.qubit - * scf.yield %q1 : !flux.qubit + * (!qco.qubit) { + * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * scf.yield %q1 : !qco.qubit * } * ``` */ -struct ConvertQuartzScfForOp final : StatefulOpConversionPattern { +struct ConvertQCScfForOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - const auto& quartzQubits = getState().regionMap[op]; + const auto& qcQubits = getState().regionMap[op]; - SmallVector fluxQubits; - fluxQubits.reserve(qubitMap.size()); - for (const auto& quartzQubit : quartzQubits) { - fluxQubits.push_back(qubitMap[quartzQubit]); + SmallVector qcoQubits; + qcoQubits.reserve(qubitMap.size()); + for (const auto& qcQubit : qcQubits) { + qcoQubits.push_back(qubitMap[qcQubit]); } - // Create a new for-loop with flux qubits as iter_args + // Create a new for-loop with qco qubits as iter_args auto newFor = rewriter.create( op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), - adaptor.getStep(), ValueRange(fluxQubits)); + adaptor.getStep(), ValueRange(qcoQubits)); // move the operations to the new block auto& srcBlock = op.getRegion().front(); @@ -1455,15 +1453,15 @@ struct ConvertQuartzScfForOp final : StatefulOpConversionPattern { auto& regionQubitMap = getState().qubitMap[&newRegion]; // create the qubitmap for the new region - for (const auto& [quartzQubit, fluxQubit] : - llvm::zip_equal(quartzQubits, newFor.getRegionIterArgs())) { - regionQubitMap.try_emplace(quartzQubit, fluxQubit); + for (const auto& [qcQubit, qcoQubit] : + llvm::zip_equal(qcQubits, newFor.getRegionIterArgs())) { + regionQubitMap.try_emplace(qcQubit, qcoQubit); } // update the qubitmap in the current region - for (const auto& [quartzQubit, fluxQubit] : - llvm::zip_equal(quartzQubits, newFor->getResults())) { - qubitMap[quartzQubit] = fluxQubit; + for (const auto& [qcQubit, qcoQubit] : + llvm::zip_equal(qcQubits, newFor->getResults())) { + qubitMap[qcQubit] = qcoQubit; } rewriter.eraseOp(op); @@ -1484,21 +1482,20 @@ struct ConvertQuartzScfForOp final : StatefulOpConversionPattern { * scf.yield %targets * ``` */ -struct ConvertQuartzScfYieldOp final - : StatefulOpConversionPattern { +struct ConvertQCScfYieldOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { const auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - SmallVector fluxQubits; - fluxQubits.reserve(qubitMap.size()); - for (auto [quartzQubit, fluxQubit] : qubitMap) { - fluxQubits.push_back(fluxQubit); + SmallVector qcoQubits; + qcoQubits.reserve(qubitMap.size()); + for (auto [qcQubit, qcoQubit] : qubitMap) { + qcoQubits.push_back(qcoQubit); } - rewriter.replaceOpWithNewOp(op, fluxQubits); + rewriter.replaceOpWithNewOp(op, qcoQubits); return success(); } }; @@ -1516,7 +1513,7 @@ struct ConvertQuartzScfYieldOp final * scf.condition(%cond) %targets * ``` */ -struct ConvertQuartzScfConditionOp final +struct ConvertQCScfConditionOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern< scf::ConditionOp>::StatefulOpConversionPattern; @@ -1525,14 +1522,14 @@ struct ConvertQuartzScfConditionOp final matchAndRewrite(scf::ConditionOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { const auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - SmallVector fluxQubits; - fluxQubits.reserve(qubitMap.size()); - for (auto [quartzQubit, fluxQubit] : qubitMap) { - fluxQubits.push_back(fluxQubit); + SmallVector qcoQubits; + qcoQubits.reserve(qubitMap.size()); + for (auto [qcQubit, qcoQubit] : qubitMap) { + qcoQubits.push_back(qcoQubit); } rewriter.replaceOpWithNewOp(op, op.getCondition(), - fluxQubits); + qcoQubits); return success(); } }; @@ -1543,39 +1540,38 @@ struct ConvertQuartzScfConditionOp final * * @par Example: * ```mlir - * call @test(%q0) : (!quartz.qubit) -> () + * call @test(%q0) : (!qc.qubit) -> () * } * ``` * is converted to * ```mlir - * %q1 = call @test(%q1) : (!flux.qubit) -> !flux.qubit + * %q1 = call @test(%q1) : (!qco.qubit) -> !qco.qubit * ``` */ -struct ConvertQuartzFuncCallOp final - : StatefulOpConversionPattern { +struct ConvertQCFuncCallOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(func::CallOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - auto quartzQubits = op->getOperands(); + auto qcQubits = op->getOperands(); - SmallVector fluxQubits; - fluxQubits.reserve(qubitMap.size()); - for (const auto& quartzQubit : quartzQubits) { - fluxQubits.push_back(qubitMap[quartzQubit]); + SmallVector qcoQubits; + qcoQubits.reserve(qubitMap.size()); + for (const auto& qcQubit : qcQubits) { + qcoQubits.push_back(qubitMap[qcQubit]); } // create the result typerange - const SmallVector fluxTypes( - quartzQubits.size(), flux::QubitType::get(rewriter.getContext())); + const SmallVector qcoTypes( + qcQubits.size(), qco::QubitType::get(rewriter.getContext())); const auto callOp = rewriter.create( - op->getLoc(), adaptor.getCallee(), fluxTypes, fluxQubits); + op->getLoc(), adaptor.getCallee(), qcoTypes, qcoQubits); - for (const auto& [quartzQubit, fluxQubit] : - llvm::zip_equal(quartzQubits, callOp->getResults())) { - qubitMap[quartzQubit] = fluxQubit; + for (const auto& [qcQubit, qcoQubit] : + llvm::zip_equal(qcQubits, callOp->getResults())) { + qubitMap[qcQubit] = qcoQubit; } rewriter.eraseOp(op); @@ -1589,38 +1585,37 @@ struct ConvertQuartzFuncCallOp final * * @par Example: * ```mlir - * func.func @test(%arg0: !quartz.qubit){ + * func.func @test(%arg0: !qc.qubit){ * ... * } * ``` * is converted to * ```mlir - * func.func @test(%arg0: !flux.qubit) -> !flux.qubit{ + * func.func @test(%arg0: !qco.qubit) -> !qco.qubit{ * ... * } * ``` */ -struct ConvertQuartzFuncFuncOp final - : StatefulOpConversionPattern { +struct ConvertQCFuncFuncOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; LogicalResult matchAndRewrite(func::FuncOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[&op->getRegion(0)]; - const SmallVector fluxTypes( + const SmallVector qcoTypes( op.front().getNumArguments(), - flux::QubitType::get(rewriter.getContext())); + qco::QubitType::get(rewriter.getContext())); - // set the arguments to flux qubit type + // set the arguments to qco qubit type for (auto blockArg : op.front().getArguments()) { - blockArg.setType(flux::QubitType::get(rewriter.getContext())); + blockArg.setType(qco::QubitType::get(rewriter.getContext())); qubitMap.try_emplace(blockArg, blockArg); } - // change the function signature to return the same number of flux Qubits as + // change the function signature to return the same number of qco Qubits as // it gets as input - auto newFuncType = rewriter.getFunctionType(fluxTypes, fluxTypes); // + auto newFuncType = rewriter.getFunctionType(qcoTypes, qcoTypes); // op.setFunctionType(newFuncType); return success(); } @@ -1639,7 +1634,7 @@ struct ConvertQuartzFuncFuncOp final * func.return %targets * ``` */ -struct ConvertQuartzFuncReturnOp final +struct ConvertQCFuncReturnOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern< func::ReturnOp>::StatefulOpConversionPattern; @@ -1648,12 +1643,12 @@ struct ConvertQuartzFuncReturnOp final matchAndRewrite(func::ReturnOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { const auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - SmallVector fluxQubits; - fluxQubits.reserve(qubitMap.size()); - for (auto [quartzQubit, fluxQubit] : qubitMap) { - fluxQubits.push_back(fluxQubit); + SmallVector qcoQubits; + qcoQubits.reserve(qubitMap.size()); + for (auto [qcQubit, qcoQubit] : qubitMap) { + qcoQubits.push_back(qcoQubit); } - rewriter.replaceOpWithNewOp(op, fluxQubits); + rewriter.replaceOpWithNewOp(op, qcoQubits); return success(); } }; @@ -1717,12 +1712,12 @@ struct QCToQCO final : impl::QCToQCOBase { }); target.addDynamicallyLegalOp([&](func::FuncOp op) { return !llvm::any_of(op.front().getArgumentTypes(), [&](Type type) { - return type == quartz::QubitType::get(context); + return type == qc::QubitType::get(context); }); }); target.addDynamicallyLegalOp([&](func::CallOp op) { return !llvm::any_of(op->getOperandTypes(), [&](Type type) { - return type == quartz::QubitType::get(context); + return type == qc::QubitType::get(context); }); }); target.addDynamicallyLegalOp([&](func::ReturnOp op) { @@ -1731,21 +1726,20 @@ struct QCToQCO final : impl::QCToQCOBase { // Register operation conversion patterns with state // tracking - patterns.add(typeConverter, context, &state); + patterns.add< + ConvertQCAllocOp, ConvertQCDeallocOp, ConvertQCStaticOp, + ConvertQCMeasureOp, ConvertQCResetOp, ConvertQCGPhaseOp, ConvertQCIdOp, + ConvertQCXOp, ConvertQCYOp, ConvertQCZOp, ConvertQCHOp, ConvertQCSOp, + ConvertQCSdgOp, ConvertQCTOp, ConvertQCTdgOp, ConvertQCSXOp, + ConvertQCSXdgOp, ConvertQCRXOp, ConvertQCRYOp, ConvertQCRZOp, + ConvertQCPOp, ConvertQCROp, ConvertQCU2Op, ConvertQCUOp, + ConvertQCSWAPOp, ConvertQCiSWAPOp, ConvertQCDCXOp, ConvertQCECROp, + ConvertQCRXXOp, ConvertQCRYYOp, ConvertQCRZXOp, ConvertQCRZZOp, + ConvertQCXXPlusYYOp, ConvertQCXXMinusYYOp, ConvertQCBarrierOp, + ConvertQCCtrlOp, ConvertQCYieldOp, ConvertQCYieldOp, ConvertQCScfIfOp, + ConvertQCScfYieldOp, ConvertQCScfWhileOp, ConvertQCScfConditionOp, + ConvertQCScfForOp, ConvertQCFuncCallOp, ConvertQCFuncFuncOp, + ConvertQCFuncReturnOp>(typeConverter, context, &state); // Apply the conversion if (failed(applyPartialConversion(module, target, std::move(patterns)))) { diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index cb7098656a..202c6fb274 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -451,9 +451,9 @@ QCProgramBuilder& QCProgramBuilder::dealloc(Value qubit) { // SCF operations //===----------------------------------------------------------------------===// -QuartzProgramBuilder& -QuartzProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, - const std::function& body) { +QCProgramBuilder& +QCProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, + const std::function& body) { create(loc, lowerbound, upperbound, step, ValueRange{}, [&](OpBuilder& b, Location, Value, ValueRange) { body(b); @@ -463,9 +463,9 @@ QuartzProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, return *this; } -QuartzProgramBuilder& QuartzProgramBuilder::scfWhile( - const std::function& beforeBody, - const std::function& afterBody) { +QCProgramBuilder& +QCProgramBuilder::scfWhile(const std::function& beforeBody, + const std::function& afterBody) { create( loc, TypeRange{}, ValueRange{}, [&](OpBuilder& b, Location, ValueRange) { beforeBody(b); }, @@ -477,10 +477,10 @@ QuartzProgramBuilder& QuartzProgramBuilder::scfWhile( return *this; } -QuartzProgramBuilder& -QuartzProgramBuilder::scfIf(Value cond, - const std::function& thenBody, - const std::function& elseBody) { +QCProgramBuilder& +QCProgramBuilder::scfIf(Value cond, + const std::function& thenBody, + const std::function& elseBody) { if (!elseBody) { create(loc, cond, [&](OpBuilder& b, Location loc) { thenBody(b); @@ -501,7 +501,7 @@ QuartzProgramBuilder::scfIf(Value cond, return *this; } -QuartzProgramBuilder& QuartzProgramBuilder::scfCondition(Value condition) { +QCProgramBuilder& QCProgramBuilder::scfCondition(Value condition) { create(loc, condition, ValueRange{}); return *this; } @@ -510,13 +510,13 @@ QuartzProgramBuilder& QuartzProgramBuilder::scfCondition(Value condition) { // Arith operations //===----------------------------------------------------------------------===// -Value QuartzProgramBuilder::arithConstantIndex(int index) { +Value QCProgramBuilder::arithConstantIndex(int index) { const auto op = create(loc, getIndexType(), getIndexAttr(index)); return op->getResult(0); } -Value QuartzProgramBuilder::arithConstantBool(bool b) { +Value QCProgramBuilder::arithConstantBool(bool b) { const auto i1Type = getI1Type(); const auto op = create(loc, i1Type, getIntegerAttr(i1Type, b ? 1 : 0)); diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 80571f3604..d55d74b3de 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -675,14 +675,14 @@ OwningOpRef QCOProgramBuilder::finalize() { return module; } -Value FluxProgramBuilder::arithConstantIndex(int i) { +Value QCOProgramBuilder::arithConstantIndex(int i) { const auto op = create(loc, getIndexType(), getIndexAttr(i)); return op->getResult(0); } -Value FluxProgramBuilder::arithConstantBool(bool b) { +Value QCOProgramBuilder::arithConstantBool(bool b) { const auto i1Type = getI1Type(); const auto op = b ? create(loc, i1Type, getIntegerAttr(i1Type, 1)) @@ -690,7 +690,7 @@ Value FluxProgramBuilder::arithConstantBool(bool b) { return op->getResult(0); } -ValueRange FluxProgramBuilder::scfFor( +ValueRange QCOProgramBuilder::scfFor( Value lowerbound, Value upperbound, Value step, ValueRange initArgs, const std::function& body) { auto op = create(loc, lowerbound, upperbound, step, From 597f83849c40317c7510c5884cd526074415bc35 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 18 Dec 2025 19:29:42 +0100 Subject: [PATCH 020/108] add qco builder for scfFor operation --- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 8 ++-- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 41 +++++++++++++++---- 2 files changed, 38 insertions(+), 11 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 8b45fd330b..1850566acc 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1031,9 +1031,11 @@ class QCOProgramBuilder final : public OpBuilder { Value arithConstantBool(bool b); - ValueRange - scfFor(Value lowerbound, Value upperbound, Value step, ValueRange initArgs, - const std::function& body); + ValueRange scfFor(Value lowerbound, Value upperbound, Value step, + ValueRange initArgs, + const std::function& body); + QCOProgramBuilder& scfYield(Location loc, ValueRange yieldedValues); //===--------------------------------------------------------------------===// // Finalization //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index d55d74b3de..9b9e1df605 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -142,10 +142,8 @@ void QCOProgramBuilder::updateQubitTracking(Value inputQubit, Value outputQubit, Region* region) { // Validate the input qubit validateQubitValue(inputQubit); - // Remove the input (consumed) value from tracking validQubits[region].erase(inputQubit); - // Add the output (new) value to tracking validQubits[region].insert(outputQubit); } @@ -625,7 +623,6 @@ void QCOProgramBuilder::checkFinalized() const { OwningOpRef QCOProgramBuilder::finalize() { checkFinalized(); - // Ensure that main function exists and insertion point is valid auto* insertionBlock = getInsertionBlock(); func::FuncOp mainFunc = nullptr; @@ -690,14 +687,42 @@ Value QCOProgramBuilder::arithConstantBool(bool b) { return op->getResult(0); } +QCOProgramBuilder& QCOProgramBuilder::scfYield(Location loc, + ValueRange yieldedValues) { + create(loc, yieldedValues); + return *this; +} + ValueRange QCOProgramBuilder::scfFor( Value lowerbound, Value upperbound, Value step, ValueRange initArgs, - const std::function& body) { - auto op = create(loc, lowerbound, upperbound, step, - initArgs // iter_args - ); + const std::function& + body) { + + auto forOp = create(loc, lowerbound, upperbound, step, initArgs); + Block* block = forOp.getBody(); + + // Block arguments: + // - arg 0 : induction variable + // - arg 1..n : iter_args + Value iv = block->getArgument(0); + ValueRange loopArgs = block->getArguments().drop_front(); + + // Set insertion point into the loop body + OpBuilder::InsertionGuard guard(*this); + setInsertionPointToStart(block); + + // Register iter_args as valid qubits in this region + Region* bodyRegion = block->getParent(); + for (Value arg : loopArgs) { + validQubits[bodyRegion].insert(arg); + } - return ValueRange{op->getResults()}; + // Build user body + body(*this, loc, iv, loopArgs); + for (auto [initArg, result] : llvm::zip_equal(initArgs, forOp.getResults())) { + updateQubitTracking(initArg, result, forOp->getParentRegion()); + } + return forOp->getResults(); } } // namespace mlir::qco From 7ad53d3bc3e3b377bf1e73b1cd8b21336fbbd37f Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 18 Dec 2025 23:13:26 +0100 Subject: [PATCH 021/108] add qco builders for scf while and if --- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 17 +++- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 87 +++++++++++++++++-- 2 files changed, 98 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 1850566acc..ed17f38445 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1035,7 +1035,22 @@ class QCOProgramBuilder final : public OpBuilder { ValueRange initArgs, const std::function& body); - QCOProgramBuilder& scfYield(Location loc, ValueRange yieldedValues); + + ValueRange + scfWhile(ValueRange args, + const std::function& + beforeBody, + const std::function& + afterBody); + ValueRange + scfIf(Value condition, ValueRange args, + const std::function& thenBody, + const std::function& elseBody); + + QCOProgramBuilder& scfYield(ValueRange yieldedValues); + + QCOProgramBuilder& scfCondition(Value condition, ValueRange yieldedValues); + //===--------------------------------------------------------------------===// // Finalization //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 9b9e1df605..4940eb4416 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -687,27 +687,29 @@ Value QCOProgramBuilder::arithConstantBool(bool b) { return op->getResult(0); } -QCOProgramBuilder& QCOProgramBuilder::scfYield(Location loc, - ValueRange yieldedValues) { +QCOProgramBuilder& QCOProgramBuilder::scfYield(ValueRange yieldedValues) { create(loc, yieldedValues); return *this; } +QCOProgramBuilder& QCOProgramBuilder::scfCondition(Value condition, + ValueRange yieldedValues) { + create(loc, condition, yieldedValues); + return *this; +} ValueRange QCOProgramBuilder::scfFor( Value lowerbound, Value upperbound, Value step, ValueRange initArgs, const std::function& body) { auto forOp = create(loc, lowerbound, upperbound, step, initArgs); - Block* block = forOp.getBody(); + auto* block = forOp.getBody(); // Block arguments: // - arg 0 : induction variable // - arg 1..n : iter_args Value iv = block->getArgument(0); ValueRange loopArgs = block->getArguments().drop_front(); - - // Set insertion point into the loop body OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(block); @@ -725,4 +727,79 @@ ValueRange QCOProgramBuilder::scfFor( return forOp->getResults(); } +ValueRange QCOProgramBuilder::scfWhile( + ValueRange initArgs, + const std::function& + beforeBody, + const std::function& + afterBody) { + auto whileOp = create(loc, initArgs.getTypes(), initArgs); + const SmallVector locs(initArgs.size(), loc); + // Before region (condition) + { + Block* block = + createBlock(&whileOp.getBefore(), {}, initArgs.getTypes(), locs); + ValueRange args = block->getArguments(); + + OpBuilder::InsertionGuard guard(*this); + setInsertionPointToStart(block); + + Region* region = block->getParent(); + for (Value arg : args) { + validQubits[region].insert(arg); + } + + beforeBody(*this, loc, args); + } + + // After region (body) + { + Block* block = + createBlock(&whileOp.getAfter(), {}, initArgs.getTypes(), locs); + ValueRange args = block->getArguments(); + + OpBuilder::InsertionGuard guard(*this); + setInsertionPointToStart(block); + + Region* region = block->getParent(); + for (Value arg : args) { + validQubits[region].insert(arg); + } + + ValueRange yields = afterBody(*this, loc, args); + } + for (auto [arg, result] : llvm::zip_equal(initArgs, whileOp.getResults())) { + updateQubitTracking(arg, result, whileOp->getParentRegion()); + } + setInsertionPointAfter(whileOp); + return whileOp->getResults(); +} +ValueRange QCOProgramBuilder::scfIf( + Value condition, ValueRange initArgs, + const std::function& thenBody, + const std::function& elseBody) { + auto ifOp = create(loc, initArgs.getTypes(), condition, + /*withElseRegion=*/true); + auto& thenBlock = ifOp.getThenRegion().front(); + auto& elseBlock = ifOp.getElseRegion().front(); + OpBuilder::InsertionGuard guard(*this); + setInsertionPointToStart(&thenBlock); + for (Value arg : initArgs) { + validQubits[thenBlock.getParent()].insert(arg); + } + thenBody(*this, loc); + + OpBuilder::InsertionGuard guardElse(*this); + setInsertionPointToStart(&elseBlock); + for (Value arg : initArgs) { + validQubits[elseBlock.getParent()].insert(arg); + } + elseBody(*this, loc); + for (auto [arg, result] : llvm::zip_equal(initArgs, ifOp.getResults())) { + updateQubitTracking(arg, result, ifOp->getParentRegion()); + } + setInsertionPointAfter(ifOp); + return ifOp->getResults(); +} + } // namespace mlir::qco From 614f2f1ac65bf28c25f15c400851e48669a6931c Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 19 Dec 2025 10:45:23 +0100 Subject: [PATCH 022/108] add stubs for docstrings --- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 138 ++++++++- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 272 ++++++++++-------- 2 files changed, 278 insertions(+), 132 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index ed17f38445..6af6bd2a06 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1027,29 +1027,155 @@ class QCOProgramBuilder final : public OpBuilder { */ QCOProgramBuilder& dealloc(Value qubit); - Value arithConstantIndex(int i); - - Value arithConstantBool(bool b); + //===--------------------------------------------------------------------===// + // SCF operations + //===--------------------------------------------------------------------===// + /** + * @brief Constructs a scf.for operation with iterArgs + * + * @param lowerbound Lowerbound of the loop + * @param upperbound Upperbound of the loop + * @param step Stepsize of the loop + * @param initArgs Initial arguments for the iterArgs + * @param body Function that builds the body of the for operation + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.scfFor(lb, ub, step, [&](auto& b) { b.x(q0); }); + * ``` + * ```mlir + * scf.for %iv = %lb to %ub step %step { + * qc.x %q0 : !qc.qubit + * } + * ``` + */ ValueRange scfFor(Value lowerbound, Value upperbound, Value step, ValueRange initArgs, const std::function& body); - + /** + * @brief Constructs a scf.while operation with return values + * + * @param args Arguments for the while loop + * @param beforeBody Function that builds the before body of the while + * operation + * @param afterBody Function that builds the after body of the while operation + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.scfWhile([&](auto& b) { + * b.h(q0); + * auto res = b.measure(q0) + * b.condition(res) + * }, [&](auto& b) { + * b.x(q0); + * b.yield() + * }); + * ``` + * ```mlir + * scf.while : () -> () { + * qc.h %q0 : !qc.qubit + * %res = qc.measure %q0 : !qc.qubit -> i1 + * scf.condition(%tres) + * } do { + * qc.x %q0 : !qc.qubit + * scf.yield + * } + * ``` + */ ValueRange scfWhile(ValueRange args, const std::function& beforeBody, const std::function& afterBody); + + /** + * @brief Constructs a scf.if operation with return values + * + * @param condition Condition for the if operation + * @param qubits Qubits used in the if/else body + * @param thenBody Function that builds the then body of the if + * operation + * @param elseBody Function that builds the else body of the if operation + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.scf.if(condition, [&](auto& b) { + * b.h(q0); + * }, [&](auto& b) { + * b.x(q0); + * }); + * ``` + * ```mlir + * scf.if %condition { + * qc.h %q0 : !qc.qubit + * } else { + * qc.x %q0 : !qc.qubit + * } + * ``` + */ ValueRange - scfIf(Value condition, ValueRange args, + scfIf(Value condition, ValueRange qubits, const std::function& thenBody, const std::function& elseBody); + /** + * @brief Constructs a scf.condition operation without any additional Values + * + * @param condition Condition for condition operation + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.condition(condition); + * ``` + * ```mlir + * scf.condition(%condition) + * ``` + */ + QCOProgramBuilder& scfCondition(Value condition, ValueRange yieldedValues); + QCOProgramBuilder& scfYield(ValueRange yieldedValues); + //===--------------------------------------------------------------------===// + // Arith operations + //===--------------------------------------------------------------------===// - QCOProgramBuilder& scfCondition(Value condition, ValueRange yieldedValues); + /** + * @brief Constructs a arith.constant of type Index with a given value + * + * @param index Value of the constant operation + * @return Result of the constant operation + * + * @par Example: + * ```c++ + * builder.arithConstantIndex(4); + * ``` + * ```mlir + * arith.constant 4 : index + * ``` + */ + Value arithConstantIndex(int i); + + /** + * @brief Constructs a arith.constant of type i1 with a given bool value + * + * @param b Bool value of the constant operation + * @return Result of the constant operation + * + * @par Example: + * ```c++ + * builder.arithConstantBool(true); + * ``` + * ```mlir + * arith.constant 1 : i1 + * ``` + */ + Value arithConstantBool(bool b); //===--------------------------------------------------------------------===// // Finalization diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 4940eb4416..9bd2bd3618 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -611,117 +611,36 @@ QCOProgramBuilder& QCOProgramBuilder::dealloc(Value qubit) { } //===----------------------------------------------------------------------===// -// Finalization +// SCF operations //===----------------------------------------------------------------------===// -void QCOProgramBuilder::checkFinalized() const { - if (ctx == nullptr) { - llvm::reportFatalUsageError( - "QCOProgramBuilder instance has been finalized"); - } -} - -OwningOpRef QCOProgramBuilder::finalize() { - checkFinalized(); - // Ensure that main function exists and insertion point is valid - auto* insertionBlock = getInsertionBlock(); - func::FuncOp mainFunc = nullptr; - for (auto op : module.getOps()) { - if (op.getName() == "main") { - mainFunc = op; - break; - } - } - if (!mainFunc) { - llvm::reportFatalUsageError("Could not find main function"); - } - if ((insertionBlock == nullptr) || - insertionBlock != &mainFunc.getBody().front()) { - llvm::reportFatalUsageError( - "Insertion point is not in entry block of main function"); - } - - // Automatically deallocate all still-allocated qubits - // Sort qubits for deterministic output - llvm::SmallVector sortedQubits( - validQubits[&mainFunc->getRegion(0)].begin(), - validQubits[&mainFunc->getRegion(0)].end()); - llvm::sort(sortedQubits, [](Value a, Value b) { - auto* opA = a.getDefiningOp(); - auto* opB = b.getDefiningOp(); - if (!opA || !opB || opA->getBlock() != opB->getBlock()) { - return a.getAsOpaquePointer() < b.getAsOpaquePointer(); - } - return opA->isBeforeInBlock(opB); - }); - for (auto qubit : sortedQubits) { - DeallocOp::create(*this, loc, qubit); - } - - validQubits.clear(); - - // Create constant 0 for successful exit code - auto exitCode = arith::ConstantOp::create(*this, loc, getI64IntegerAttr(0)); - - // Add return statement with exit code 0 to the main function - func::ReturnOp::create(*this, loc, ValueRange{exitCode}); - - // Invalidate context to prevent use-after-finalize - ctx = nullptr; - - return module; -} - -Value QCOProgramBuilder::arithConstantIndex(int i) { - - const auto op = - create(loc, getIndexType(), getIndexAttr(i)); - return op->getResult(0); -} - -Value QCOProgramBuilder::arithConstantBool(bool b) { - const auto i1Type = getI1Type(); - const auto op = - b ? create(loc, i1Type, getIntegerAttr(i1Type, 1)) - : create(loc, i1Type, getIntegerAttr(i1Type, 0)); - return op->getResult(0); -} - -QCOProgramBuilder& QCOProgramBuilder::scfYield(ValueRange yieldedValues) { - create(loc, yieldedValues); - return *this; -} - -QCOProgramBuilder& QCOProgramBuilder::scfCondition(Value condition, - ValueRange yieldedValues) { - create(loc, condition, yieldedValues); - return *this; -} ValueRange QCOProgramBuilder::scfFor( Value lowerbound, Value upperbound, Value step, ValueRange initArgs, const std::function& body) { + // Create the empty for operation auto forOp = create(loc, lowerbound, upperbound, step, initArgs); - auto* block = forOp.getBody(); + auto* forBody = forOp.getBody(); + auto iv = forBody->getArgument(0); + auto loopArgs = forBody->getArguments().drop_front(); - // Block arguments: - // - arg 0 : induction variable - // - arg 1..n : iter_args - Value iv = block->getArgument(0); - ValueRange loopArgs = block->getArguments().drop_front(); + // Set the insertionpoint OpBuilder::InsertionGuard guard(*this); - setInsertionPointToStart(block); + setInsertionPointToStart(forBody); - // Register iter_args as valid qubits in this region - Region* bodyRegion = block->getParent(); + // Add the iterArgs to the validQubits + auto* bodyRegion = forBody->getParent(); for (Value arg : loopArgs) { validQubits[bodyRegion].insert(arg); } - // Build user body + // Build the body body(*this, loc, iv, loopArgs); - for (auto [initArg, result] : llvm::zip_equal(initArgs, forOp.getResults())) { + + // Update the qubit tracking + for (const auto& [initArg, result] : + llvm::zip_equal(initArgs, forOp.getResults())) { updateQubitTracking(initArg, result, forOp->getParentRegion()); } return forOp->getResults(); @@ -735,71 +654,172 @@ ValueRange QCOProgramBuilder::scfWhile( afterBody) { auto whileOp = create(loc, initArgs.getTypes(), initArgs); const SmallVector locs(initArgs.size(), loc); - // Before region (condition) { - Block* block = + auto* beforeBlock = createBlock(&whileOp.getBefore(), {}, initArgs.getTypes(), locs); - ValueRange args = block->getArguments(); - - OpBuilder::InsertionGuard guard(*this); - setInsertionPointToStart(block); + auto beforeArgs = beforeBlock->getArguments(); + auto* beforeRegion = beforeBlock->getParent(); + OpBuilder::InsertionGuard beforeGuard(*this); + setInsertionPointToStart(beforeBlock); - Region* region = block->getParent(); - for (Value arg : args) { - validQubits[region].insert(arg); + for (Value arg : beforeArgs) { + validQubits[beforeRegion].insert(arg); } - beforeBody(*this, loc, args); + beforeBody(*this, loc, beforeArgs); } - - // After region (body) { - Block* block = + auto* afterBlock = createBlock(&whileOp.getAfter(), {}, initArgs.getTypes(), locs); - ValueRange args = block->getArguments(); + auto afterArgs = afterBlock->getArguments(); - OpBuilder::InsertionGuard guard(*this); - setInsertionPointToStart(block); + OpBuilder::InsertionGuard afterGuard(*this); + setInsertionPointToStart(afterBlock); - Region* region = block->getParent(); - for (Value arg : args) { - validQubits[region].insert(arg); + auto* afterRegion = afterBlock->getParent(); + for (Value arg : afterArgs) { + validQubits[afterRegion].insert(arg); } - ValueRange yields = afterBody(*this, loc, args); - } - for (auto [arg, result] : llvm::zip_equal(initArgs, whileOp.getResults())) { - updateQubitTracking(arg, result, whileOp->getParentRegion()); + afterBody(*this, loc, afterArgs); + + for (auto [arg, result] : llvm::zip_equal(initArgs, whileOp.getResults())) { + updateQubitTracking(arg, result, whileOp->getParentRegion()); + } } + setInsertionPointAfter(whileOp); return whileOp->getResults(); } + ValueRange QCOProgramBuilder::scfIf( - Value condition, ValueRange initArgs, + Value condition, ValueRange qubits, const std::function& thenBody, const std::function& elseBody) { - auto ifOp = create(loc, initArgs.getTypes(), condition, + // Create the empty while operation + auto ifOp = create(loc, qubits.getTypes(), condition, /*withElseRegion=*/true); auto& thenBlock = ifOp.getThenRegion().front(); auto& elseBlock = ifOp.getElseRegion().front(); + auto* thenRegion = thenBlock.getParent(); + auto* elseRegion = elseBlock.getParent(); + + // Set the insertionpoint OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(&thenBlock); - for (Value arg : initArgs) { - validQubits[thenBlock.getParent()].insert(arg); + + // add the qubits to the validQubits of the then and else region + for (Value arg : qubits) { + validQubits[thenRegion].insert(arg); + validQubits[elseRegion].insert(arg); } + + // Build the then body thenBody(*this, loc); + // Set the insertionpoint OpBuilder::InsertionGuard guardElse(*this); setInsertionPointToStart(&elseBlock); - for (Value arg : initArgs) { - validQubits[elseBlock.getParent()].insert(arg); - } + + // Build the else body elseBody(*this, loc); - for (auto [arg, result] : llvm::zip_equal(initArgs, ifOp.getResults())) { + + // Update the qubit tracking + for (auto [arg, result] : llvm::zip_equal(qubits, ifOp.getResults())) { updateQubitTracking(arg, result, ifOp->getParentRegion()); } + setInsertionPointAfter(ifOp); return ifOp->getResults(); } +QCOProgramBuilder& QCOProgramBuilder::scfCondition(Value condition, + ValueRange yieldedValues) { + create(loc, condition, yieldedValues); + return *this; +} + +QCOProgramBuilder& QCOProgramBuilder::scfYield(ValueRange yieldedValues) { + create(loc, yieldedValues); + return *this; +} + +//===----------------------------------------------------------------------===// +// Arith operations +//===----------------------------------------------------------------------===// + +Value QCOProgramBuilder::arithConstantIndex(int i) { + const auto op = + create(loc, getIndexType(), getIndexAttr(i)); + return op->getResult(0); +} + +Value QCOProgramBuilder::arithConstantBool(bool b) { + const auto i1Type = getI1Type(); + const auto op = + create(loc, i1Type, getIntegerAttr(i1Type, b ? 1 : 0)); + return op->getResult(0); +} +//===----------------------------------------------------------------------===// +// Finalization +//===----------------------------------------------------------------------===// + +void QCOProgramBuilder::checkFinalized() const { + if (ctx == nullptr) { + llvm::reportFatalUsageError( + "QCOProgramBuilder instance has been finalized"); + } +} + +OwningOpRef QCOProgramBuilder::finalize() { + checkFinalized(); + // Ensure that main function exists and insertion point is valid + auto* insertionBlock = getInsertionBlock(); + func::FuncOp mainFunc = nullptr; + for (auto op : module.getOps()) { + if (op.getName() == "main") { + mainFunc = op; + break; + } + } + if (!mainFunc) { + llvm::reportFatalUsageError("Could not find main function"); + } + if ((insertionBlock == nullptr) || + insertionBlock != &mainFunc.getBody().front()) { + llvm::reportFatalUsageError( + "Insertion point is not in entry block of main function"); + } + + // Automatically deallocate all still-allocated qubits + // Sort qubits for deterministic output + llvm::SmallVector sortedQubits( + validQubits[&mainFunc->getRegion(0)].begin(), + validQubits[&mainFunc->getRegion(0)].end()); + llvm::sort(sortedQubits, [](Value a, Value b) { + auto* opA = a.getDefiningOp(); + auto* opB = b.getDefiningOp(); + if (!opA || !opB || opA->getBlock() != opB->getBlock()) { + return a.getAsOpaquePointer() < b.getAsOpaquePointer(); + } + return opA->isBeforeInBlock(opB); + }); + for (auto qubit : sortedQubits) { + DeallocOp::create(*this, loc, qubit); + } + + validQubits.clear(); + + // Create constant 0 for successful exit code + auto exitCode = arith::ConstantOp::create(*this, loc, getI64IntegerAttr(0)); + + // Add return statement with exit code 0 to the main function + func::ReturnOp::create(*this, loc, ValueRange{exitCode}); + + // Invalidate context to prevent use-after-finalize + ctx = nullptr; + + return module; +} + } // namespace mlir::qco From 6ada6048fcc385f6923fc4d86b47694d10f71b88 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 19 Dec 2025 16:15:23 +0100 Subject: [PATCH 023/108] add builders for func operations --- .../Dialect/QC/Builder/QCProgramBuilder.h | 11 +++ .../Dialect/QCO/Builder/QCOProgramBuilder.h | 13 +++ .../Dialect/QC/Builder/QCProgramBuilder.cpp | 31 +++++++ .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 89 +++++++++++++------ 4 files changed, 117 insertions(+), 27 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index e9534e3cd0..3e26ee66c1 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -972,6 +972,17 @@ class QCProgramBuilder final : public OpBuilder { */ QCProgramBuilder& scfCondition(Value condition); + //===--------------------------------------------------------------------===// + // Func operations + //===--------------------------------------------------------------------===// + QCProgramBuilder& funcReturn(); + + QCProgramBuilder& funcCall(StringRef name, ValueRange operands); + + QCProgramBuilder& + funcFunc(StringRef name, TypeRange argTypes, TypeRange resultTypes, + const std::function& body); + //===--------------------------------------------------------------------===// // Arith operations //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 6af6bd2a06..391e02ceef 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1141,6 +1141,19 @@ class QCOProgramBuilder final : public OpBuilder { QCOProgramBuilder& scfCondition(Value condition, ValueRange yieldedValues); QCOProgramBuilder& scfYield(ValueRange yieldedValues); + + //===--------------------------------------------------------------------===// + // Func operations + //===--------------------------------------------------------------------===// + + QCOProgramBuilder& funcReturn(ValueRange yieldedValues); + + ValueRange funcCall(StringRef name, ValueRange operands); + + QCOProgramBuilder& + funcFunc(StringRef name, TypeRange argTypes, TypeRange resultTypes, + const std::function& body); + //===--------------------------------------------------------------------===// // Arith operations //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 202c6fb274..bd0f30cf07 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -506,6 +506,37 @@ QCProgramBuilder& QCProgramBuilder::scfCondition(Value condition) { return *this; } +//===----------------------------------------------------------------------===// +// Func operations +//===----------------------------------------------------------------------===// + +QCProgramBuilder& QCProgramBuilder::funcCall(StringRef name, + ValueRange operands) { + create(loc, name, TypeRange{}, operands); + return *this; +} + +QCProgramBuilder& QCProgramBuilder::funcReturn() { + create(loc); + return *this; +} +QCProgramBuilder& QCProgramBuilder::funcFunc( + StringRef name, TypeRange argTypes, TypeRange resultTypes, + const std::function& body) { + const auto funcType = getFunctionType(argTypes, resultTypes); + OpBuilder::InsertionGuard guard(*this); + setInsertionPointToEnd(module.getBody()); + auto funcOp = create(loc, name, funcType); + + auto* entryBlock = funcOp.addEntryBlock(); + + setInsertionPointToStart(entryBlock); + + // Build function body + body(*this, loc, entryBlock->getArguments()); + return *this; +} + //===----------------------------------------------------------------------===// // Arith operations //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 9bd2bd3618..2748aa75f0 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -634,7 +634,6 @@ ValueRange QCOProgramBuilder::scfFor( for (Value arg : loopArgs) { validQubits[bodyRegion].insert(arg); } - // Build the body body(*this, loc, iv, loopArgs); @@ -654,38 +653,36 @@ ValueRange QCOProgramBuilder::scfWhile( afterBody) { auto whileOp = create(loc, initArgs.getTypes(), initArgs); const SmallVector locs(initArgs.size(), loc); - { - auto* beforeBlock = - createBlock(&whileOp.getBefore(), {}, initArgs.getTypes(), locs); - auto beforeArgs = beforeBlock->getArguments(); - auto* beforeRegion = beforeBlock->getParent(); - OpBuilder::InsertionGuard beforeGuard(*this); - setInsertionPointToStart(beforeBlock); - - for (Value arg : beforeArgs) { - validQubits[beforeRegion].insert(arg); - } - beforeBody(*this, loc, beforeArgs); + OpBuilder::InsertionGuard guard(*this); + auto* beforeBlock = + createBlock(&whileOp.getBefore(), {}, initArgs.getTypes(), locs); + auto beforeArgs = beforeBlock->getArguments(); + auto* beforeRegion = beforeBlock->getParent(); + + setInsertionPointToStart(beforeBlock); + + for (Value arg : beforeArgs) { + validQubits[beforeRegion].insert(arg); } - { - auto* afterBlock = - createBlock(&whileOp.getAfter(), {}, initArgs.getTypes(), locs); - auto afterArgs = afterBlock->getArguments(); - OpBuilder::InsertionGuard afterGuard(*this); - setInsertionPointToStart(afterBlock); + beforeBody(*this, loc, beforeArgs); - auto* afterRegion = afterBlock->getParent(); - for (Value arg : afterArgs) { - validQubits[afterRegion].insert(arg); - } + auto* afterBlock = + createBlock(&whileOp.getAfter(), {}, initArgs.getTypes(), locs); + auto afterArgs = afterBlock->getArguments(); - afterBody(*this, loc, afterArgs); + setInsertionPointToStart(afterBlock); - for (auto [arg, result] : llvm::zip_equal(initArgs, whileOp.getResults())) { - updateQubitTracking(arg, result, whileOp->getParentRegion()); - } + auto* afterRegion = afterBlock->getParent(); + for (Value arg : afterArgs) { + validQubits[afterRegion].insert(arg); + } + + afterBody(*this, loc, afterArgs); + + for (auto [arg, result] : llvm::zip_equal(initArgs, whileOp.getResults())) { + updateQubitTracking(arg, result, whileOp->getParentRegion()); } setInsertionPointAfter(whileOp); @@ -744,6 +741,44 @@ QCOProgramBuilder& QCOProgramBuilder::scfYield(ValueRange yieldedValues) { return *this; } +//===----------------------------------------------------------------------===// +// Func operations +//===----------------------------------------------------------------------===// + +ValueRange QCOProgramBuilder::funcCall(StringRef name, ValueRange operands) { + const auto callOp = + create(loc, name, operands.getTypes(), operands); + for (auto [arg, result] : llvm::zip_equal(operands, callOp->getResults())) { + updateQubitTracking(arg, result, callOp->getParentRegion()); + } + return callOp->getResults(); +} + +QCOProgramBuilder& QCOProgramBuilder::funcReturn(ValueRange yieldedValues) { + create(loc, yieldedValues); + return *this; +} +QCOProgramBuilder& QCOProgramBuilder::funcFunc( + StringRef name, TypeRange argTypes, TypeRange resultTypes, + const std::function& body) { + const auto funcType = getFunctionType(argTypes, resultTypes); + OpBuilder::InsertionGuard guard(*this); + setInsertionPointToEnd(module.getBody()); + auto funcOp = create(loc, name, funcType); + + auto* entryBlock = funcOp.addEntryBlock(); + + for (Value arg : entryBlock->getArguments()) { + validQubits[entryBlock->getParent()].insert(arg); + } + + setInsertionPointToStart(entryBlock); + + // Build function body + body(*this, loc, entryBlock->getArguments()); + return *this; +} + //===----------------------------------------------------------------------===// // Arith operations //===----------------------------------------------------------------------===// From 7358d3b6562f3faaa67695cf727df720c0732b58 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 22 Dec 2025 15:05:13 +0100 Subject: [PATCH 024/108] correctly change the induction variable in scf for conversion --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 24 ++++++++++++------------ mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 6 ++++-- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 7805080b26..b3ba10e90b 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -840,18 +840,18 @@ struct ConvertQCOScfIfOp final : OpConversionPattern { LogicalResult matchAndRewrite(scf::IfOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - // create the new if operation + // Create the new if operation auto newIf = rewriter.create(op.getLoc(), ValueRange{}, op.getCondition(), op.getElseRegion().empty()); - // inline the regions + // Inline the regions rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(), newIf.getThenRegion().end()); if (!op.getElseRegion().empty()) { rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(), newIf.getElseRegion().end()); } - // erase the empty block that was created during the initialization + // Erase the empty block that was created during the initialization rewriter.eraseBlock(&newIf.getThenRegion().front()); const auto& yield = @@ -894,11 +894,11 @@ struct ConvertQCOScfWhileOp final : OpConversionPattern { LogicalResult matchAndRewrite(scf::WhileOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - // create the new while operation + // Create the new while operation auto newWhileOp = rewriter.create(op->getLoc(), ValueRange{}, ValueRange{}); - // replace the uses of the blockarguments with the init values + // Replace the uses of the blockarguments with the init values const auto& inits = adaptor.getInits(); const auto beforeArgs = op.getBeforeArguments(); const auto afterArgs = op.getAfterArguments(); @@ -906,7 +906,7 @@ struct ConvertQCOScfWhileOp final : OpConversionPattern { beforeArgs[i].replaceAllUsesWith(inits[i]); afterArgs[i].replaceAllUsesWith(inits[i]); } - // create the blocks of the new operation and move the operations to them + // Create the blocks of the new operation and move the operations to them auto* newBeforeBlock = rewriter.createBlock(&newWhileOp.getBefore(), {}, {}, {}); auto* newAfterBlock = @@ -953,20 +953,20 @@ struct ConvertQCOScfForOp final : OpConversionPattern { op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep(), ValueRange{}); - // replace the uses of the previous iter_args + // Replace the uses of the previous iter_args for (const auto& [qcoQubit, qcQubit] : llvm::zip_equal(op.getRegionIterArgs(), adaptor.getInitArgs())) { qcoQubit.replaceAllUsesWith(qcQubit); } - // move all the operations from the old block to the new block + // Move all the operations from the old block to the new block auto* newBlock = newFor.getBody(); - // erase the existing yield operation + // Erase the existing yield operation rewriter.eraseOp(newBlock->getTerminator()); newBlock->getOperations().splice(newBlock->end(), op.getBody()->getOperations()); - - // replace the result values with the init values + rewriter.replaceAllUsesWith(op.getInductionVar(), newFor.getInductionVar()); + // Replace the result values with the init values rewriter.replaceOp(op, adaptor.getInitArgs()); return success(); } @@ -1194,7 +1194,7 @@ struct QCOToQC final : impl::QCOToQCBase { // Configure conversion target: QCO illegal, QC legal target.addIllegalDialect(); target.addLegalDialect(); - + target.addLegalDialect(); // Register operation conversion patterns // Note: No state tracking needed - OpAdaptors handle type conversion patterns diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index d4ea2b95ac..10f4cfa925 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1434,7 +1434,7 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { const auto& qcQubits = getState().regionMap[op]; SmallVector qcoQubits; - qcoQubits.reserve(qubitMap.size()); + qcoQubits.reserve(qcoQubits.size()); for (const auto& qcQubit : qcQubits) { qcoQubits.push_back(qubitMap[qcQubit]); } @@ -1447,7 +1447,9 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { // move the operations to the new block auto& srcBlock = op.getRegion().front(); auto& dstBlock = newFor.getRegion().front(); + dstBlock.getOperations().splice(dstBlock.end(), srcBlock.getOperations()); + rewriter.replaceAllUsesWith(op.getInductionVar(), newFor.getInductionVar()); auto& newRegion = newFor.getRegion(); auto& regionQubitMap = getState().qubitMap[&newRegion]; @@ -1694,7 +1696,7 @@ struct QCToQCO final : impl::QCToQCOBase { // legal target.addIllegalDialect(); target.addLegalDialect(); - + target.addLegalDialect(); target.addDynamicallyLegalOp([&](scf::YieldOp op) { return !(op->getAttrOfType("needChange")); }); From 010892b0d412d3438e5352a72e496f572d54760e Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 22 Dec 2025 16:18:19 +0100 Subject: [PATCH 025/108] add simple conversion tests --- mlir/unittests/CMakeLists.txt | 6 +- mlir/unittests/conversion/CMakeLists.txt | 30 ++ mlir/unittests/conversion/test_conversion.cpp | 352 ++++++++++++++++++ 3 files changed, 386 insertions(+), 2 deletions(-) create mode 100644 mlir/unittests/conversion/CMakeLists.txt create mode 100644 mlir/unittests/conversion/test_conversion.cpp diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index 43ffbbd241..3060e2f424 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -9,8 +9,10 @@ add_subdirectory(dialect) add_subdirectory(pipeline) add_subdirectory(translation) +add_subdirectory(conversion) add_custom_target(mqt-core-mlir-unittests) -add_dependencies(mqt-core-mlir-unittests mqt-core-mlir-compiler-pipeline-test - mqt-core-mlir-translation-test mqt-core-mlir-wireiterator-test) +add_dependencies( + mqt-core-mlir-unittests mqt-core-mlir-compiler-pipeline-test mqt-core-mlir-translation-test + mqt-core-mlir-wireiterator-test mqt-core-mlir-conversion-test) diff --git a/mlir/unittests/conversion/CMakeLists.txt b/mlir/unittests/conversion/CMakeLists.txt new file mode 100644 index 0000000000..833dc2371b --- /dev/null +++ b/mlir/unittests/conversion/CMakeLists.txt @@ -0,0 +1,30 @@ +# Copyright (c) 2023 - 2025 Chair for Design Automation, TUM +# Copyright (c) 2025 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(testname "mqt-core-mlir-conversion-test") +file(GLOB_RECURSE CONVERSION_TEST_SOURCES *.cpp) + +if(NOT TARGET ${testname}) + # create an executable in which the tests will be stored + add_executable(${testname} ${CONVERSION_TEST_SOURCES}) + # link the Google test infrastructure and a default main function to the test executable. + target_link_libraries( + ${testname} + PRIVATE GTest::gtest_main + MLIRParser + MLIRQCProgramBuilder + QCToQCO + MLIRPass + MLIRTransforms + MLIRLLVMDialect + QCOToQC + MLIRQCOProgramBuilder) + # discover tests + gtest_discover_tests(${testname} DISCOVERY_TIMEOUT 60) + set_target_properties(${testname} PROPERTIES FOLDER unittests) +endif() diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp new file mode 100644 index 0000000000..2b69ef8699 --- /dev/null +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -0,0 +1,352 @@ +/* + * Copyright (c) 2023 - 2025 Chair for Design Automation, TUM + * Copyright (c) 2025 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Conversion/QCOToQC/QCOToQC.h" +#include "mlir/Conversion/QCToQCO/QCToQCO.h" +#include "mlir/Dialect/QC/Builder/QCProgramBuilder.h" +#include "mlir/Dialect/QC/IR/QCDialect.h" +#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" +#include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/IR/Builders.h" + +#include +#include +#include +#include +#include +#include + +using namespace mlir; + +class ConversionTest : public ::testing::Test { +protected: + std::unique_ptr context; + void SetUp() override { + // Register all dialects needed for the full compilation pipeline + DialectRegistry registry; + registry.insert(); + + context = std::make_unique(); + context->appendDialectRegistry(registry); + context->loadAllAvailableDialects(); + } + + [[nodiscard]] OwningOpRef buildQCIR( + const std::function& buildFunc) const { + mlir::qc::QCProgramBuilder builder(context.get()); + builder.initialize(); + buildFunc(builder); + auto module = builder.finalize(); + return module; + } + [[nodiscard]] OwningOpRef buildQCOIR( + const std::function& buildFunc) const { + qco::QCOProgramBuilder builder(context.get()); + builder.initialize(); + buildFunc(builder); + auto module = builder.finalize(); + return module; + } +}; + +static std::string getOutputString(mlir::OwningOpRef* module) { + std::string outputString; + llvm::raw_string_ostream os(outputString); + (*module)->print(os); + os.flush(); + return outputString; +} + +TEST_F(ConversionTest, ScfForTest) { + auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto c0 = b.arithConstantIndex(0); + auto c1 = b.arithConstantIndex(1); + auto c2 = b.arithConstantIndex(2); + b.scfFor(c0, c2, c1, [&](OpBuilder& b) { + static_cast(b).h(q0); + static_cast(b).x(q0); + static_cast(b).h(q0); + }); + b.h(q0); + }); + + PassManager pm(context.get()); + pm.addPass(createQCToQCO()); + if (failed(pm.run(input.get()))) { + } + + auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto c0 = b.arithConstantIndex(0); + auto c1 = b.arithConstantIndex(1); + auto c2 = b.arithConstantIndex(2); + auto scfForRes = b.scfFor( + c0, c2, c1, ValueRange{q0}, + [&](OpBuilder& b, Location, Value, ValueRange iterArgs) -> ValueRange { + auto q1 = + static_cast(b).h(iterArgs[0]); + auto q2 = static_cast(b).x(q1); + auto q3 = static_cast(b).h(q2); + static_cast(b).scfYield( + ValueRange{q3}); + return {q3}; + }); + b.h(scfForRes[0]); + }); + + const auto outputString = getOutputString(&input); + const auto checkString = getOutputString(&expectedOutput); + + ASSERT_EQ(outputString, checkString); +} + +TEST_F(ConversionTest, ScfForTest2) { + auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto c0 = b.arithConstantIndex(0); + auto c1 = b.arithConstantIndex(1); + auto c2 = b.arithConstantIndex(2); + auto scfForRes = b.scfFor( + c0, c2, c1, ValueRange{q0}, + [&](OpBuilder& b, Location, Value, ValueRange iterArgs) -> ValueRange { + auto q1 = + static_cast(b).h(iterArgs[0]); + auto q2 = static_cast(b).x(q1); + auto q3 = static_cast(b).h(q2); + static_cast(b).scfYield( + ValueRange{q3}); + return {q3}; + }); + b.h(scfForRes[0]); + }); + + PassManager pm(context.get()); + pm.addPass(createQCOToQC()); + if (failed(pm.run(input.get()))) { + } + + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto c0 = b.arithConstantIndex(0); + auto c1 = b.arithConstantIndex(1); + auto c2 = b.arithConstantIndex(2); + b.scfFor(c0, c2, c1, [&](OpBuilder& b) { + static_cast(b).h(q0); + static_cast(b).x(q0); + static_cast(b).h(q0); + }); + b.h(q0); + }); + + const auto outputString = getOutputString(&input); + const auto checkString = getOutputString(&expectedOutput); + + ASSERT_EQ(outputString, checkString); +} + +TEST_F(ConversionTest, ScfWhileTest) { + auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + + b.scfWhile( + [&](OpBuilder& b) { + auto measure = + static_cast(b).measure(q0); + static_cast(b).scfCondition(measure); + }, + [&](OpBuilder& b) { + static_cast(b).h(q0); + static_cast(b).y(q0); + }); + b.h(q0); + }); + + PassManager pm(context.get()); + pm.addPass(createQCToQCO()); + if (failed(pm.run(input.get()))) { + } + + auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto scfWhileResult = b.scfWhile( + ValueRange{q0}, + [&](OpBuilder& b, Location, ValueRange iterArgs) { + auto measure = static_cast(b).measure( + iterArgs[0]); + static_cast(b).scfCondition( + measure.second, ValueRange{measure.first}); + return ValueRange{measure.first}; + }, + [&](OpBuilder& b, Location, ValueRange iterArgs) { + auto q1 = + static_cast(b).h(iterArgs[0]); + auto q2 = static_cast(b).y(q1); + static_cast(b).scfYield({q2}); + return ValueRange{q2}; + }); + b.h(scfWhileResult[0]); + }); + + const auto outputString = getOutputString(&input); + const auto checkString = getOutputString(&expectedOutput); + + ASSERT_EQ(outputString, checkString); +} + +TEST_F(ConversionTest, ScfWhileTest2) { + + auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto scfWhileResult = b.scfWhile( + ValueRange{q0}, + [&](OpBuilder& b, Location, ValueRange iterArgs) { + auto measure = static_cast(b).measure( + iterArgs[0]); + static_cast(b).scfCondition( + measure.second, ValueRange{measure.first}); + return ValueRange{measure.first}; + }, + [&](OpBuilder& b, Location, ValueRange iterArgs) { + auto q1 = + static_cast(b).h(iterArgs[0]); + auto q2 = static_cast(b).y(q1); + static_cast(b).scfYield({q2}); + return ValueRange{q2}; + }); + b.h(scfWhileResult[0]); + }); + + PassManager pm(context.get()); + pm.addPass(createQCOToQC()); + if (failed(pm.run(input.get()))) { + } + + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + b.scfWhile( + [&](OpBuilder& b) { + auto measure = + static_cast(b).measure(q0); + static_cast(b).scfCondition(measure); + }, + [&](OpBuilder& b) { + static_cast(b).h(q0); + static_cast(b).y(q0); + }); + b.h(q0); + }); + const auto outputString = getOutputString(&input); + const auto checkString = getOutputString(&expectedOutput); + + ASSERT_EQ(outputString, checkString); +} + +TEST_F(ConversionTest, ScfIfTest) { + auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto measure = b.measure(q0); + b.scfIf( + measure, + [&](OpBuilder& b) { + static_cast(b).h(q0); + static_cast(b).y(q0); + }, + [&](OpBuilder& b) { + static_cast(b).y(q0); + static_cast(b).h(q0); + }); + b.h(q0); + }); + + PassManager pm(context.get()); + pm.addPass(createQCToQCO()); + if (failed(pm.run(input.get()))) { + } + + auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto measure = b.measure(q0); + auto scfIfResult = b.scfIf( + measure.second, ValueRange{measure.first}, + [&](OpBuilder& b, Location) -> ValueRange { + auto q1 = + static_cast(b).h(measure.first); + auto q2 = static_cast(b).y(q1); + static_cast(b).scfYield(q2); + return {q2}; + }, + [&](OpBuilder& b, Location) -> ValueRange { + auto q1 = + static_cast(b).y(measure.first); + auto q2 = static_cast(b).h(q1); + static_cast(b).scfYield(q2); + return {q2}; + }); + b.h(scfIfResult[0]); + }); + + const auto outputString = getOutputString(&input); + const auto checkString = getOutputString(&expectedOutput); + + ASSERT_EQ(outputString, checkString); +} + +TEST_F(ConversionTest, ScfIfTest2) { + + auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto measure = b.measure(q0); + auto scfIfResult = b.scfIf( + measure.second, ValueRange{measure.first}, + [&](OpBuilder& b, Location) -> ValueRange { + auto q1 = + static_cast(b).h(measure.first); + auto q2 = static_cast(b).y(q1); + static_cast(b).scfYield(q2); + return {q2}; + }, + [&](OpBuilder& b, Location) -> ValueRange { + auto q1 = + static_cast(b).y(measure.first); + auto q2 = static_cast(b).h(q1); + static_cast(b).scfYield(q2); + return {q2}; + }); + b.h(scfIfResult[0]); + }); + + PassManager pm(context.get()); + pm.addPass(createQCOToQC()); + if (failed(pm.run(input.get()))) { + } + + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto measure = b.measure(q0); + b.scfIf( + measure, + [&](OpBuilder& b) { + static_cast(b).h(q0); + static_cast(b).y(q0); + }, + [&](OpBuilder& b) { + static_cast(b).y(q0); + static_cast(b).h(q0); + }); + b.h(q0); + }); + + const auto outputString = getOutputString(&input); + const auto checkString = getOutputString(&expectedOutput); + + ASSERT_EQ(outputString, checkString); +} From 5a42a8db098f9e3e58456ef7b9dd4bd9e060fea9 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 22 Dec 2025 17:09:50 +0100 Subject: [PATCH 026/108] fixing some tests --- mlir/unittests/conversion/test_conversion.cpp | 180 ++++++++++++------ 1 file changed, 124 insertions(+), 56 deletions(-) diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index 2b69ef8699..faecd932a7 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -14,14 +14,16 @@ #include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" -#include "mlir/IR/Builders.h" #include #include #include #include #include +#include +#include #include +#include using namespace mlir; @@ -66,14 +68,18 @@ static std::string getOutputString(mlir::OwningOpRef* module) { } TEST_F(ConversionTest, ScfForTest) { + // Test conversion from qc to qco for scf.for operation auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); auto c0 = b.arithConstantIndex(0); auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); b.scfFor(c0, c2, c1, [&](OpBuilder& b) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h(q0); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).x(q0); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h(q0); }); b.h(q0); @@ -82,6 +88,7 @@ TEST_F(ConversionTest, ScfForTest) { PassManager pm(context.get()); pm.addPass(createQCToQCO()); if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QC-QCO conversion for scf.for"; } auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { @@ -91,14 +98,18 @@ TEST_F(ConversionTest, ScfForTest) { auto c2 = b.arithConstantIndex(2); auto scfForRes = b.scfFor( c0, c2, c1, ValueRange{q0}, - [&](OpBuilder& b, Location, Value, ValueRange iterArgs) -> ValueRange { - auto q1 = + [&](OpBuilder& b, Location, Value, ValueRange iterArgs) { + auto + q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h(iterArgs[0]); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) auto q2 = static_cast(b).x(q1); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) auto q3 = static_cast(b).h(q2); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfYield( ValueRange{q3}); - return {q3}; + return q3; }); b.h(scfForRes[0]); }); @@ -110,6 +121,7 @@ TEST_F(ConversionTest, ScfForTest) { } TEST_F(ConversionTest, ScfForTest2) { + // Test conversion from qco to qc for scf.for operation auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); auto c0 = b.arithConstantIndex(0); @@ -117,14 +129,18 @@ TEST_F(ConversionTest, ScfForTest2) { auto c2 = b.arithConstantIndex(2); auto scfForRes = b.scfFor( c0, c2, c1, ValueRange{q0}, - [&](OpBuilder& b, Location, Value, ValueRange iterArgs) -> ValueRange { - auto q1 = + [&](OpBuilder& b, Location, Value, ValueRange iterArgs) { + auto + q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h(iterArgs[0]); - auto q2 = static_cast(b).x(q1); - auto q3 = static_cast(b).h(q2); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q2 = static_cast(b).x( + q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q3 = static_cast(b).h( + q2); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfYield( ValueRange{q3}); - return {q3}; + return q3; }); b.h(scfForRes[0]); }); @@ -132,6 +148,7 @@ TEST_F(ConversionTest, ScfForTest2) { PassManager pm(context.get()); pm.addPass(createQCOToQC()); if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QCO-QC conversion for scf.for"; } auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { @@ -140,8 +157,11 @@ TEST_F(ConversionTest, ScfForTest2) { auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); b.scfFor(c0, c2, c1, [&](OpBuilder& b) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h(q0); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).x(q0); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h(q0); }); b.h(q0); @@ -154,17 +174,21 @@ TEST_F(ConversionTest, ScfForTest2) { } TEST_F(ConversionTest, ScfWhileTest) { + // Test conversion from qc to qco for scf.while operation auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); - b.scfWhile( [&](OpBuilder& b) { - auto measure = + auto + measure = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).measure(q0); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfCondition(measure); }, - [&](OpBuilder& b) { + [&](OpBuilder& + b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h(q0); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).y(q0); }); b.h(q0); @@ -173,6 +197,7 @@ TEST_F(ConversionTest, ScfWhileTest) { PassManager pm(context.get()); pm.addPass(createQCToQCO()); if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QC-QCO conversion for scf.while"; } auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { @@ -180,18 +205,24 @@ TEST_F(ConversionTest, ScfWhileTest) { auto scfWhileResult = b.scfWhile( ValueRange{q0}, [&](OpBuilder& b, Location, ValueRange iterArgs) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) auto measure = static_cast(b).measure( - iterArgs[0]); + iterArgs + [0]); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfCondition( measure.second, ValueRange{measure.first}); - return ValueRange{measure.first}; + return measure.first; }, [&](OpBuilder& b, Location, ValueRange iterArgs) { - auto q1 = - static_cast(b).h(iterArgs[0]); - auto q2 = static_cast(b).y(q1); + auto + q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).h( + iterArgs + [0]); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q2 = static_cast(b).y( + q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfYield({q2}); - return ValueRange{q2}; + return q2; }); b.h(scfWhileResult[0]); }); @@ -203,24 +234,31 @@ TEST_F(ConversionTest, ScfWhileTest) { } TEST_F(ConversionTest, ScfWhileTest2) { - + // Test conversion from qco to qc for scf.while operation auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); auto scfWhileResult = b.scfWhile( ValueRange{q0}, - [&](OpBuilder& b, Location, ValueRange iterArgs) { + [&](OpBuilder& b, Location, + ValueRange + iterArgs) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) auto measure = static_cast(b).measure( - iterArgs[0]); + iterArgs + [0]); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfCondition( measure.second, ValueRange{measure.first}); - return ValueRange{measure.first}; + return measure.first; }, [&](OpBuilder& b, Location, ValueRange iterArgs) { - auto q1 = - static_cast(b).h(iterArgs[0]); - auto q2 = static_cast(b).y(q1); + auto + q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).h( + iterArgs + [0]); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q2 = static_cast(b).y( + q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfYield({q2}); - return ValueRange{q2}; + return q2; }); b.h(scfWhileResult[0]); }); @@ -228,18 +266,22 @@ TEST_F(ConversionTest, ScfWhileTest2) { PassManager pm(context.get()); pm.addPass(createQCOToQC()); if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QCO-QC conversion for scf.while"; } auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); b.scfWhile( - [&](OpBuilder& b) { - auto measure = - static_cast(b).measure(q0); + [&](OpBuilder& + b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto measure = static_cast(b).measure( + q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfCondition(measure); }, - [&](OpBuilder& b) { - static_cast(b).h(q0); + [&](OpBuilder& + b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).h( + q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).y(q0); }); b.h(q0); @@ -251,17 +293,22 @@ TEST_F(ConversionTest, ScfWhileTest2) { } TEST_F(ConversionTest, ScfIfTest) { + // Test conversion from qc to qco for scf.if operation auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); auto measure = b.measure(q0); b.scfIf( measure, - [&](OpBuilder& b) { - static_cast(b).h(q0); + [&](OpBuilder& + b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).h( + q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).y(q0); }, - [&](OpBuilder& b) { - static_cast(b).y(q0); + [&](OpBuilder& + b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).y( + q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h(q0); }); b.h(q0); @@ -270,6 +317,7 @@ TEST_F(ConversionTest, ScfIfTest) { PassManager pm(context.get()); pm.addPass(createQCToQCO()); if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QC-QCO conversion for scf.if"; } auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { @@ -277,19 +325,26 @@ TEST_F(ConversionTest, ScfIfTest) { auto measure = b.measure(q0); auto scfIfResult = b.scfIf( measure.second, ValueRange{measure.first}, - [&](OpBuilder& b, Location) -> ValueRange { - auto q1 = + [&](OpBuilder& b, Location) { + auto + q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h(measure.first); - auto q2 = static_cast(b).y(q1); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q2 = static_cast(b).y( + q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfYield(q2); - return {q2}; + return q2; }, - [&](OpBuilder& b, Location) -> ValueRange { - auto q1 = - static_cast(b).y(measure.first); - auto q2 = static_cast(b).h(q1); + [&](OpBuilder& b, Location) { + auto + q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).y( + measure + .first); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q2 = static_cast(b).h( + q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfYield(q2); - return {q2}; + return q2; }); b.h(scfIfResult[0]); }); @@ -301,25 +356,33 @@ TEST_F(ConversionTest, ScfIfTest) { } TEST_F(ConversionTest, ScfIfTest2) { - + // Test conversion from qco to qc for scf.if operation auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); auto measure = b.measure(q0); auto scfIfResult = b.scfIf( measure.second, ValueRange{measure.first}, [&](OpBuilder& b, Location) -> ValueRange { - auto q1 = - static_cast(b).h(measure.first); - auto q2 = static_cast(b).y(q1); + auto + q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).h( + measure + .first); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q2 = static_cast(b).y( + q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfYield(q2); - return {q2}; + return q2; // NOLINT }, [&](OpBuilder& b, Location) -> ValueRange { - auto q1 = - static_cast(b).y(measure.first); + auto + q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).y( + measure + .first); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) auto q2 = static_cast(b).h(q1); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfYield(q2); - return {q2}; + return q2; // NOLINT }); b.h(scfIfResult[0]); }); @@ -327,6 +390,7 @@ TEST_F(ConversionTest, ScfIfTest2) { PassManager pm(context.get()); pm.addPass(createQCOToQC()); if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QCO-QC conversion for scf.if"; } auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { @@ -334,12 +398,16 @@ TEST_F(ConversionTest, ScfIfTest2) { auto measure = b.measure(q0); b.scfIf( measure, - [&](OpBuilder& b) { - static_cast(b).h(q0); + [&](OpBuilder& + b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).h( + q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).y(q0); }, - [&](OpBuilder& b) { - static_cast(b).y(q0); + [&](OpBuilder& + b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).y( + q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h(q0); }); b.h(q0); From ad9acd2847c739dd4d36875d10f227f29306b130 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 22 Dec 2025 17:26:11 +0100 Subject: [PATCH 027/108] remove redundant location --- .../Dialect/QC/Builder/QCProgramBuilder.h | 2 +- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 22 +++++++-------- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 4 +-- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 27 +++++++++---------- 4 files changed, 24 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index 7b64bc0a09..85d065ddda 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -981,7 +981,7 @@ class QCProgramBuilder final : public OpBuilder { QCProgramBuilder& funcFunc(StringRef name, TypeRange argTypes, TypeRange resultTypes, - const std::function& body); + const std::function& body); //===--------------------------------------------------------------------===// // Arith operations diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index a832cff874..a3ccca4745 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1052,10 +1052,9 @@ class QCOProgramBuilder final : public OpBuilder { * } * ``` */ - ValueRange scfFor(Value lowerbound, Value upperbound, Value step, - ValueRange initArgs, - const std::function& body); + ValueRange + scfFor(Value lowerbound, Value upperbound, Value step, ValueRange initArgs, + const std::function& body); /** * @brief Constructs a scf.while operation with return values * @@ -1089,10 +1088,8 @@ class QCOProgramBuilder final : public OpBuilder { */ ValueRange scfWhile(ValueRange args, - const std::function& - beforeBody, - const std::function& - afterBody); + const std::function& beforeBody, + const std::function& afterBody); /** * @brief Constructs a scf.if operation with return values @@ -1120,10 +1117,9 @@ class QCOProgramBuilder final : public OpBuilder { * } * ``` */ - ValueRange - scfIf(Value condition, ValueRange qubits, - const std::function& thenBody, - const std::function& elseBody); + ValueRange scfIf(Value condition, ValueRange qubits, + const std::function& thenBody, + const std::function& elseBody); /** * @brief Constructs a scf.condition operation without any additional Values @@ -1153,7 +1149,7 @@ class QCOProgramBuilder final : public OpBuilder { QCOProgramBuilder& funcFunc(StringRef name, TypeRange argTypes, TypeRange resultTypes, - const std::function& body); + const std::function& body); //===--------------------------------------------------------------------===// // Arith operations diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index ccba506088..7fa15e5502 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -519,7 +519,7 @@ QCProgramBuilder& QCProgramBuilder::funcReturn() { } QCProgramBuilder& QCProgramBuilder::funcFunc( StringRef name, TypeRange argTypes, TypeRange resultTypes, - const std::function& body) { + const std::function& body) { const auto funcType = getFunctionType(argTypes, resultTypes); OpBuilder::InsertionGuard guard(*this); setInsertionPointToEnd(module.getBody()); @@ -530,7 +530,7 @@ QCProgramBuilder& QCProgramBuilder::funcFunc( setInsertionPointToStart(entryBlock); // Build function body - body(*this, loc, entryBlock->getArguments()); + body(*this, entryBlock->getArguments()); return *this; } diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index cf0dbf5edf..4aec0c74d7 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -602,8 +602,7 @@ QCOProgramBuilder& QCOProgramBuilder::dealloc(Value qubit) { ValueRange QCOProgramBuilder::scfFor( Value lowerbound, Value upperbound, Value step, ValueRange initArgs, - const std::function& - body) { + const std::function& body) { // Create the empty for operation auto forOp = create(loc, lowerbound, upperbound, step, initArgs); @@ -621,7 +620,7 @@ ValueRange QCOProgramBuilder::scfFor( validQubits[bodyRegion].insert(arg); } // Build the body - body(*this, loc, iv, loopArgs); + body(*this, iv, loopArgs); // Update the qubit tracking for (const auto& [initArg, result] : @@ -633,10 +632,8 @@ ValueRange QCOProgramBuilder::scfFor( ValueRange QCOProgramBuilder::scfWhile( ValueRange initArgs, - const std::function& - beforeBody, - const std::function& - afterBody) { + const std::function& beforeBody, + const std::function& afterBody) { auto whileOp = create(loc, initArgs.getTypes(), initArgs); const SmallVector locs(initArgs.size(), loc); @@ -652,7 +649,7 @@ ValueRange QCOProgramBuilder::scfWhile( validQubits[beforeRegion].insert(arg); } - beforeBody(*this, loc, beforeArgs); + beforeBody(*this, beforeArgs); auto* afterBlock = createBlock(&whileOp.getAfter(), {}, initArgs.getTypes(), locs); @@ -665,7 +662,7 @@ ValueRange QCOProgramBuilder::scfWhile( validQubits[afterRegion].insert(arg); } - afterBody(*this, loc, afterArgs); + afterBody(*this, afterArgs); for (auto [arg, result] : llvm::zip_equal(initArgs, whileOp.getResults())) { updateQubitTracking(arg, result, whileOp->getParentRegion()); @@ -677,8 +674,8 @@ ValueRange QCOProgramBuilder::scfWhile( ValueRange QCOProgramBuilder::scfIf( Value condition, ValueRange qubits, - const std::function& thenBody, - const std::function& elseBody) { + const std::function& thenBody, + const std::function& elseBody) { // Create the empty while operation auto ifOp = create(loc, qubits.getTypes(), condition, /*withElseRegion=*/true); @@ -698,14 +695,14 @@ ValueRange QCOProgramBuilder::scfIf( } // Build the then body - thenBody(*this, loc); + thenBody(*this); // Set the insertionpoint OpBuilder::InsertionGuard guardElse(*this); setInsertionPointToStart(&elseBlock); // Build the else body - elseBody(*this, loc); + elseBody(*this); // Update the qubit tracking for (auto [arg, result] : llvm::zip_equal(qubits, ifOp.getResults())) { @@ -746,7 +743,7 @@ QCOProgramBuilder& QCOProgramBuilder::funcReturn(ValueRange yieldedValues) { } QCOProgramBuilder& QCOProgramBuilder::funcFunc( StringRef name, TypeRange argTypes, TypeRange resultTypes, - const std::function& body) { + const std::function& body) { const auto funcType = getFunctionType(argTypes, resultTypes); OpBuilder::InsertionGuard guard(*this); setInsertionPointToEnd(module.getBody()); @@ -761,7 +758,7 @@ QCOProgramBuilder& QCOProgramBuilder::funcFunc( setInsertionPointToStart(entryBlock); // Build function body - body(*this, loc, entryBlock->getArguments()); + body(*this, entryBlock->getArguments()); return *this; } From 537f19f788179fb704e79263ccb0c003e7f441c7 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 22 Dec 2025 17:26:35 +0100 Subject: [PATCH 028/108] add func conversion tests --- mlir/unittests/conversion/test_conversion.cpp | 111 ++++++++++++++++-- 1 file changed, 101 insertions(+), 10 deletions(-) diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index faecd932a7..02ec82197f 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -98,7 +98,7 @@ TEST_F(ConversionTest, ScfForTest) { auto c2 = b.arithConstantIndex(2); auto scfForRes = b.scfFor( c0, c2, c1, ValueRange{q0}, - [&](OpBuilder& b, Location, Value, ValueRange iterArgs) { + [&](OpBuilder& b, Value, ValueRange iterArgs) { auto q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h(iterArgs[0]); @@ -129,7 +129,7 @@ TEST_F(ConversionTest, ScfForTest2) { auto c2 = b.arithConstantIndex(2); auto scfForRes = b.scfFor( c0, c2, c1, ValueRange{q0}, - [&](OpBuilder& b, Location, Value, ValueRange iterArgs) { + [&](OpBuilder& b, Value, ValueRange iterArgs) { auto q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h(iterArgs[0]); @@ -204,7 +204,7 @@ TEST_F(ConversionTest, ScfWhileTest) { auto q0 = b.allocQubit(); auto scfWhileResult = b.scfWhile( ValueRange{q0}, - [&](OpBuilder& b, Location, ValueRange iterArgs) { + [&](OpBuilder& b, ValueRange iterArgs) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) auto measure = static_cast(b).measure( iterArgs @@ -213,7 +213,7 @@ TEST_F(ConversionTest, ScfWhileTest) { measure.second, ValueRange{measure.first}); return measure.first; }, - [&](OpBuilder& b, Location, ValueRange iterArgs) { + [&](OpBuilder& b, ValueRange iterArgs) { auto q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h( @@ -239,7 +239,7 @@ TEST_F(ConversionTest, ScfWhileTest2) { auto q0 = b.allocQubit(); auto scfWhileResult = b.scfWhile( ValueRange{q0}, - [&](OpBuilder& b, Location, + [&](OpBuilder& b, ValueRange iterArgs) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) auto measure = static_cast(b).measure( @@ -249,7 +249,7 @@ TEST_F(ConversionTest, ScfWhileTest2) { measure.second, ValueRange{measure.first}); return measure.first; }, - [&](OpBuilder& b, Location, ValueRange iterArgs) { + [&](OpBuilder& b, ValueRange iterArgs) { auto q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h( @@ -325,7 +325,7 @@ TEST_F(ConversionTest, ScfIfTest) { auto measure = b.measure(q0); auto scfIfResult = b.scfIf( measure.second, ValueRange{measure.first}, - [&](OpBuilder& b, Location) { + [&](OpBuilder& b) { auto q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h(measure.first); @@ -335,7 +335,7 @@ TEST_F(ConversionTest, ScfIfTest) { static_cast(b).scfYield(q2); return q2; }, - [&](OpBuilder& b, Location) { + [&](OpBuilder& b) { auto q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).y( @@ -362,7 +362,7 @@ TEST_F(ConversionTest, ScfIfTest2) { auto measure = b.measure(q0); auto scfIfResult = b.scfIf( measure.second, ValueRange{measure.first}, - [&](OpBuilder& b, Location) -> ValueRange { + [&](OpBuilder& b) -> ValueRange { auto q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).h( @@ -373,7 +373,7 @@ TEST_F(ConversionTest, ScfIfTest2) { static_cast(b).scfYield(q2); return q2; // NOLINT }, - [&](OpBuilder& b, Location) -> ValueRange { + [&](OpBuilder& b) -> ValueRange { auto q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).y( @@ -418,3 +418,94 @@ TEST_F(ConversionTest, ScfIfTest2) { ASSERT_EQ(outputString, checkString); } + +TEST_F(ConversionTest, FuncFuncTest) { + // Test conversion from qc to qco for func.func operation + auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + b.funcCall("test", q0); + b.h(q0); + b.funcFunc( + "test", q0.getType(), {}, + [&](OpBuilder& b, + ValueRange + args) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).h(args[0]); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).y(args[0]); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).funcReturn(); + }); + }); + + PassManager pm(context.get()); + pm.addPass(createQCToQCO()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QC-QCO conversion for func.func"; + } + + auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto q1 = b.funcCall("test", q0); + b.h(q1[0]); + b.funcFunc( + "test", q0.getType(), q0.getType(), [&](OpBuilder& b, ValueRange args) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q2 = static_cast(b).h(args[0]); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q3 = static_cast(b).y(q2); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).funcReturn(q3); + }); + }); + const auto outputString = getOutputString(&input); + const auto checkString = getOutputString(&expectedOutput); + + ASSERT_EQ(outputString, checkString); +} + +TEST_F(ConversionTest, FuncFuncTest2) { + // Test conversion from qco to qc for func.func operation + auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto q1 = b.funcCall("test", q0); + b.h(q1[0]); + b.funcFunc( + "test", q0.getType(), q0.getType(), [&](OpBuilder& b, ValueRange args) { + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q2 = static_cast(b).h(args[0]); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q3 = static_cast(b).y(q2); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).funcReturn(q3); + }); + }); + + PassManager pm(context.get()); + pm.addPass(createQCOToQC()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QCO-QC conversion for func.func"; + } + + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + b.funcCall("test", q0); + b.h(q0); + b.funcFunc( + "test", q0.getType(), {}, + [&](OpBuilder& b, + ValueRange + args) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).h(args[0]); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).y(args[0]); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).funcReturn(); + }); + }); + + const auto outputString = getOutputString(&input); + const auto checkString = getOutputString(&expectedOutput); + + ASSERT_EQ(outputString, checkString); +} From bc278b9c2c86c37aa84649592a9f5608e9049ef3 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 22 Dec 2025 18:03:21 +0100 Subject: [PATCH 029/108] add more docstrings --- .../Dialect/QC/Builder/QCProgramBuilder.h | 53 ++++++- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 138 ++++++++++++++---- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 4 +- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 4 +- mlir/unittests/conversion/test_conversion.cpp | 20 +-- 5 files changed, 174 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index 85d065ddda..dad8b23fbf 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -975,12 +975,63 @@ class QCProgramBuilder final : public OpBuilder { //===--------------------------------------------------------------------===// // Func operations //===--------------------------------------------------------------------===// + + /** + * @brief Constructs a func.return operation without return values + * + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.funcReturn(); + * ``` + * ```mlir + * func.return + * ``` + */ QCProgramBuilder& funcReturn(); + /** + * @brief Constructs a func.call operation without return values + * + * @param name Name of the function that is called + * @param operands ValueRange of the used operands + * + * @par Example: + * ```c++ + * builder.funcCall("test", {q0}); + * ``` + * ```mlir + * func.call @test(%q0) : (!qco.qubit) -> () + * ``` + */ QCProgramBuilder& funcCall(StringRef name, ValueRange operands); + /** + * @brief Constructs a func.func operation with return values + * + * @param name Name of the function that is called + * @param argTypes TypeRange of the arguments + * @param body Body of the function + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.funcFunc("test", argTypes, [&](OpBuilder& b, + * ValueRange args) { + * b.h(args[0]); + * b.funcReturn(); + * }) + * ``` + * ```mlir + * func.func @test(%arg0 : !qco.qubit) { + * qc.h %arg0 : !qc.qubit + * func.return + * } + * ``` + */ QCProgramBuilder& - funcFunc(StringRef name, TypeRange argTypes, TypeRange resultTypes, + funcFunc(StringRef name, TypeRange argTypes, const std::function& body); //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index a3ccca4745..1e893f4a27 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1040,15 +1040,20 @@ class QCOProgramBuilder final : public OpBuilder { * @param step Stepsize of the loop * @param initArgs Initial arguments for the iterArgs * @param body Function that builds the body of the for operation - * @return Reference to this builder for method chaining + * @return ValueRange of the results * * @par Example: * ```c++ - * builder.scfFor(lb, ub, step, [&](auto& b) { b.x(q0); }); + * builder.scfFor(lb, ub, step, initArgs, [&](auto& b) { + * auto q1 = b.x(initArgs[0]); + * b.scfYield(q1); + }); * ``` * ```mlir - * scf.for %iv = %lb to %ub step %step { - * qc.x %q0 : !qc.qubit + * %q1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %q0) -> + !qco.qubit { + * %q2 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * scf.yield %q2 : !qco.qubit * } * ``` */ @@ -1062,27 +1067,28 @@ class QCOProgramBuilder final : public OpBuilder { * @param beforeBody Function that builds the before body of the while * operation * @param afterBody Function that builds the after body of the while operation - * @return Reference to this builder for method chaining + * @return ValueRange of the results * * @par Example: * ```c++ - * builder.scfWhile([&](auto& b) { - * b.h(q0); - * auto res = b.measure(q0) - * b.condition(res) - * }, [&](auto& b) { - * b.x(q0); - * b.yield() + * builder.scfWhile(args, [&](auto& b, ValueRange iterArgs) { + * auto q1 = b.h(iterArgs[0]); + * auto [q2, measureRes] = b.measure(q1); + * b.condition(measureRes); + * }, [&](auto& b, ValueRange iterArgs) { + * auto q1 = b.x(iterArgs[0]); + * b.scfYield(q1); * }); * ``` * ```mlir - * scf.while : () -> () { - * qc.h %q0 : !qc.qubit - * %res = qc.measure %q0 : !qc.qubit -> i1 - * scf.condition(%tres) - * } do { - * qc.x %q0 : !qc.qubit - * scf.yield + * %q1 = scf.while (%arg0 = %q0): (!qco.qubit) -> (!qco.qubit) { + * %q2 = qco.h(%arg0) + * %q3, %result = qco.measure %q2 : !qco.qubit + * scf.condition(%result) %q3 : !qco.qubit + * } do { + * ^bb0(%arg0 : !qco.qubit): + * %q4 = qco.x %arg0 : !qco.qubit -> !qco.qubit + * scf.yield %q4 : !qco.qubit * } * ``` */ @@ -1099,21 +1105,25 @@ class QCOProgramBuilder final : public OpBuilder { * @param thenBody Function that builds the then body of the if * operation * @param elseBody Function that builds the else body of the if operation - * @return Reference to this builder for method chaining + * @return ValueRange of the results * * @par Example: * ```c++ - * builder.scf.if(condition, [&](auto& b) { - * b.h(q0); + * builder.scf.if(condition, qubits, [&](auto& b) { + * auto q1 = b.h(q0); + * b.scfYield(q1); * }, [&](auto& b) { - * b.x(q0); + * auto q1 = b.x(q0); + * b.scfYield(q1); * }); * ``` * ```mlir - * scf.if %condition { - * qc.h %q0 : !qc.qubit + * %q1 = scf.if %condition -> (!qco.qubit) { + * %q2 = qco.h %q0 : !qco.qubit -> !qco.qubit + * scf.yield %q2 : !qco.qubit * } else { - * qc.x %q0 : !qc.qubit + * %q2 = qco.x %q0 : !qco.qubit -> !qco.qubit + * scf.yield %q2 : !qco.qubit * } * ``` */ @@ -1122,31 +1132,99 @@ class QCOProgramBuilder final : public OpBuilder { const std::function& elseBody); /** - * @brief Constructs a scf.condition operation without any additional Values + * @brief Constructs a scf.condition operation with yielded values * * @param condition Condition for condition operation + * @param yieldedValues ValueRange of the yieldedValues * @return Reference to this builder for method chaining * * @par Example: * ```c++ - * builder.condition(condition); + * builder.scfCondition(condition, yieldedValues); * ``` * ```mlir - * scf.condition(%condition) + * scf.condition(%condition) %q0 : !qco.qubit * ``` */ QCOProgramBuilder& scfCondition(Value condition, ValueRange yieldedValues); + /** + * @brief Constructs a scf.yield operation with yielded values + * + * @param yieldedValues ValueRange of the yieldedValues + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.scfYield( yieldedValues); + * ``` + * ```mlir + * scf.yield %q0 : !qco.qubit + * ``` + */ QCOProgramBuilder& scfYield(ValueRange yieldedValues); //===--------------------------------------------------------------------===// // Func operations //===--------------------------------------------------------------------===// - QCOProgramBuilder& funcReturn(ValueRange yieldedValues); + /** + * @brief Constructs a func.return operation with return values + * + * @param returnValues ValueRange of the returned values + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.funcReturn( yieldedValues); + * ``` + * ```mlir + * func.return %q0 : !qco.qubit + * ``` + */ + QCOProgramBuilder& funcReturn(ValueRange returnValues); + /** + * @brief Constructs a func.call operation with return values + * + * @param name Name of the function that is called + * @param operands ValueRange of the used operands + * @return ValueRange of the results + * + * @par Example: + * ```c++ + * auto q1 = builder.funcCall("test", {q0}); + * ``` + * ```mlir + * %q1 = func.call @test(%q0) : (!qco.qubit) -> !qco.qubit + * ``` + */ ValueRange funcCall(StringRef name, ValueRange operands); + /** + * @brief Constructs a func.func operation with return values + * + * @param name Name of the function that is called + * @param argTypes TypeRange of the arguments + * @param resultTypes TypeRange of the results + * @param body Body of the function + * @return Reference to this builder for method chaining + * + * @par Example: + * ```c++ + * builder.funcFunc("test", argTypes, resultTypes, [&](OpBuilder& b, + * ValueRange args) { + * auto q1 = b.h(args[0]); + * b.funcReturn({q1}); + * }) + * ``` + * ```mlir + * func.func @test(%arg0 : !qco.qubit) -> !qco.qubit { + * %q1 = qco.h %arg0 : !qco.qubit -> !qco.qubit + * func.return %q1 : !qco.qubit + * } + * ``` + */ QCOProgramBuilder& funcFunc(StringRef name, TypeRange argTypes, TypeRange resultTypes, const std::function& body); diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 7fa15e5502..aaa3a2adc0 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -518,9 +518,9 @@ QCProgramBuilder& QCProgramBuilder::funcReturn() { return *this; } QCProgramBuilder& QCProgramBuilder::funcFunc( - StringRef name, TypeRange argTypes, TypeRange resultTypes, + StringRef name, TypeRange argTypes, const std::function& body) { - const auto funcType = getFunctionType(argTypes, resultTypes); + const auto funcType = getFunctionType(argTypes, {}); OpBuilder::InsertionGuard guard(*this); setInsertionPointToEnd(module.getBody()); auto funcOp = create(loc, name, funcType); diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 4aec0c74d7..1a83cca3df 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -737,8 +737,8 @@ ValueRange QCOProgramBuilder::funcCall(StringRef name, ValueRange operands) { return callOp->getResults(); } -QCOProgramBuilder& QCOProgramBuilder::funcReturn(ValueRange yieldedValues) { - create(loc, yieldedValues); +QCOProgramBuilder& QCOProgramBuilder::funcReturn(ValueRange returnValues) { + create(loc, returnValues); return *this; } QCOProgramBuilder& QCOProgramBuilder::funcFunc( diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index 02ec82197f..0e85fa27af 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -144,13 +144,13 @@ TEST_F(ConversionTest, ScfForTest2) { }); b.h(scfForRes[0]); }); - + input->print(llvm::outs()); PassManager pm(context.get()); pm.addPass(createQCOToQC()); if (failed(pm.run(input.get()))) { FAIL() << "Conversion error during QCO-QC conversion for scf.for"; } - + input->print(llvm::outs()); auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); auto c0 = b.arithConstantIndex(0); @@ -193,13 +193,13 @@ TEST_F(ConversionTest, ScfWhileTest) { }); b.h(q0); }); - + input->print(llvm::outs()); PassManager pm(context.get()); pm.addPass(createQCToQCO()); if (failed(pm.run(input.get()))) { FAIL() << "Conversion error during QC-QCO conversion for scf.while"; } - + input->print(llvm::outs()); auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); auto scfWhileResult = b.scfWhile( @@ -313,13 +313,13 @@ TEST_F(ConversionTest, ScfIfTest) { }); b.h(q0); }); - + input->print(llvm::outs()); PassManager pm(context.get()); pm.addPass(createQCToQCO()); if (failed(pm.run(input.get()))) { FAIL() << "Conversion error during QC-QCO conversion for scf.if"; } - + input->print(llvm::outs()); auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); auto measure = b.measure(q0); @@ -426,7 +426,7 @@ TEST_F(ConversionTest, FuncFuncTest) { b.funcCall("test", q0); b.h(q0); b.funcFunc( - "test", q0.getType(), {}, + "test", q0.getType(), [&](OpBuilder& b, ValueRange args) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) @@ -437,13 +437,13 @@ TEST_F(ConversionTest, FuncFuncTest) { static_cast(b).funcReturn(); }); }); - + input->print(llvm::outs()); PassManager pm(context.get()); pm.addPass(createQCToQCO()); if (failed(pm.run(input.get()))) { FAIL() << "Conversion error during QC-QCO conversion for func.func"; } - + input->print(llvm::outs()); auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); auto q1 = b.funcCall("test", q0); @@ -492,7 +492,7 @@ TEST_F(ConversionTest, FuncFuncTest2) { b.funcCall("test", q0); b.h(q0); b.funcFunc( - "test", q0.getType(), {}, + "test", q0.getType(), [&](OpBuilder& b, ValueRange args) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) From 59edb8eb3b593f1528f1e5e31ddb599486b812d4 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 22 Dec 2025 18:09:12 +0100 Subject: [PATCH 030/108] remove unnecessary print statements --- mlir/unittests/conversion/test_conversion.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index 0e85fa27af..d8ac0889ca 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -144,13 +144,13 @@ TEST_F(ConversionTest, ScfForTest2) { }); b.h(scfForRes[0]); }); - input->print(llvm::outs()); + PassManager pm(context.get()); pm.addPass(createQCOToQC()); if (failed(pm.run(input.get()))) { FAIL() << "Conversion error during QCO-QC conversion for scf.for"; } - input->print(llvm::outs()); + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); auto c0 = b.arithConstantIndex(0); @@ -193,13 +193,13 @@ TEST_F(ConversionTest, ScfWhileTest) { }); b.h(q0); }); - input->print(llvm::outs()); + PassManager pm(context.get()); pm.addPass(createQCToQCO()); if (failed(pm.run(input.get()))) { FAIL() << "Conversion error during QC-QCO conversion for scf.while"; } - input->print(llvm::outs()); + auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); auto scfWhileResult = b.scfWhile( @@ -313,13 +313,13 @@ TEST_F(ConversionTest, ScfIfTest) { }); b.h(q0); }); - input->print(llvm::outs()); + PassManager pm(context.get()); pm.addPass(createQCToQCO()); if (failed(pm.run(input.get()))) { FAIL() << "Conversion error during QC-QCO conversion for scf.if"; } - input->print(llvm::outs()); + auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); auto measure = b.measure(q0); @@ -437,13 +437,13 @@ TEST_F(ConversionTest, FuncFuncTest) { static_cast(b).funcReturn(); }); }); - input->print(llvm::outs()); + PassManager pm(context.get()); pm.addPass(createQCToQCO()); if (failed(pm.run(input.get()))) { FAIL() << "Conversion error during QC-QCO conversion for func.func"; } - input->print(llvm::outs()); + auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); auto q1 = b.funcCall("test", q0); From b1e77ecc4b171e46249ccac865063525531d7a4d Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 22 Dec 2025 18:18:19 +0100 Subject: [PATCH 031/108] fix comments --- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 8 +- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 26 +++-- mlir/unittests/conversion/test_conversion.cpp | 109 +++++++++--------- 3 files changed, 80 insertions(+), 63 deletions(-) diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index aaa3a2adc0..b29710f8ad 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -517,20 +517,24 @@ QCProgramBuilder& QCProgramBuilder::funcReturn() { create(loc); return *this; } + QCProgramBuilder& QCProgramBuilder::funcFunc( StringRef name, TypeRange argTypes, const std::function& body) { - const auto funcType = getFunctionType(argTypes, {}); + // Set the insertionPoint OpBuilder::InsertionGuard guard(*this); setInsertionPointToEnd(module.getBody()); - auto funcOp = create(loc, name, funcType); + // Create the empty func operation + const auto funcType = getFunctionType(argTypes, {}); + auto funcOp = create(loc, name, funcType); auto* entryBlock = funcOp.addEntryBlock(); setInsertionPointToStart(entryBlock); // Build function body body(*this, entryBlock->getArguments()); + return *this; } diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 1a83cca3df..adf4c9fb00 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -603,7 +603,6 @@ QCOProgramBuilder& QCOProgramBuilder::dealloc(Value qubit) { ValueRange QCOProgramBuilder::scfFor( Value lowerbound, Value upperbound, Value step, ValueRange initArgs, const std::function& body) { - // Create the empty for operation auto forOp = create(loc, lowerbound, upperbound, step, initArgs); auto* forBody = forOp.getBody(); @@ -627,6 +626,7 @@ ValueRange QCOProgramBuilder::scfFor( llvm::zip_equal(initArgs, forOp.getResults())) { updateQubitTracking(initArg, result, forOp->getParentRegion()); } + return forOp->getResults(); } @@ -634,41 +634,49 @@ ValueRange QCOProgramBuilder::scfWhile( ValueRange initArgs, const std::function& beforeBody, const std::function& afterBody) { + // Create the empty while operation auto whileOp = create(loc, initArgs.getTypes(), initArgs); const SmallVector locs(initArgs.size(), loc); OpBuilder::InsertionGuard guard(*this); + + // Construct the before block auto* beforeBlock = createBlock(&whileOp.getBefore(), {}, initArgs.getTypes(), locs); auto beforeArgs = beforeBlock->getArguments(); auto* beforeRegion = beforeBlock->getParent(); + // Set the insertionpoint setInsertionPointToStart(beforeBlock); + // Add the beforeArgs to the validQubits for (Value arg : beforeArgs) { validQubits[beforeRegion].insert(arg); } beforeBody(*this, beforeArgs); + // Construct the after block auto* afterBlock = createBlock(&whileOp.getAfter(), {}, initArgs.getTypes(), locs); auto afterArgs = afterBlock->getArguments(); + auto* afterRegion = afterBlock->getParent(); + // Set the insertionpoint setInsertionPointToStart(afterBlock); - auto* afterRegion = afterBlock->getParent(); + // Add the afterArgs to the validQubits for (Value arg : afterArgs) { validQubits[afterRegion].insert(arg); } afterBody(*this, afterArgs); + // Update the qubit tracking for (auto [arg, result] : llvm::zip_equal(initArgs, whileOp.getResults())) { updateQubitTracking(arg, result, whileOp->getParentRegion()); } - setInsertionPointAfter(whileOp); return whileOp->getResults(); } @@ -688,7 +696,7 @@ ValueRange QCOProgramBuilder::scfIf( OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(&thenBlock); - // add the qubits to the validQubits of the then and else region + // Add the qubits to the validQubits of the then and else region for (Value arg : qubits) { validQubits[thenRegion].insert(arg); validQubits[elseRegion].insert(arg); @@ -698,7 +706,6 @@ ValueRange QCOProgramBuilder::scfIf( thenBody(*this); // Set the insertionpoint - OpBuilder::InsertionGuard guardElse(*this); setInsertionPointToStart(&elseBlock); // Build the else body @@ -709,7 +716,6 @@ ValueRange QCOProgramBuilder::scfIf( updateQubitTracking(arg, result, ifOp->getParentRegion()); } - setInsertionPointAfter(ifOp); return ifOp->getResults(); } @@ -744,13 +750,16 @@ QCOProgramBuilder& QCOProgramBuilder::funcReturn(ValueRange returnValues) { QCOProgramBuilder& QCOProgramBuilder::funcFunc( StringRef name, TypeRange argTypes, TypeRange resultTypes, const std::function& body) { - const auto funcType = getFunctionType(argTypes, resultTypes); + // Set the insertionPoint OpBuilder::InsertionGuard guard(*this); setInsertionPointToEnd(module.getBody()); - auto funcOp = create(loc, name, funcType); + // Create the empty func operation + const auto funcType = getFunctionType(argTypes, resultTypes); + auto funcOp = create(loc, name, funcType); auto* entryBlock = funcOp.addEntryBlock(); + // Add the arguments to the validQubits for (Value arg : entryBlock->getArguments()) { validQubits[entryBlock->getParent()].insert(arg); } @@ -759,6 +768,7 @@ QCOProgramBuilder& QCOProgramBuilder::funcFunc( // Build function body body(*this, entryBlock->getArguments()); + return *this; } diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index d8ac0889ca..8777c50295 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -180,10 +180,11 @@ TEST_F(ConversionTest, ScfWhileTest) { b.scfWhile( [&](OpBuilder& b) { auto - measure = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + measureResult = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).measure(q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfCondition(measure); + static_cast(b).scfCondition( + measureResult); }, [&](OpBuilder& b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) @@ -205,13 +206,15 @@ TEST_F(ConversionTest, ScfWhileTest) { auto scfWhileResult = b.scfWhile( ValueRange{q0}, [&](OpBuilder& b, ValueRange iterArgs) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto measure = static_cast(b).measure( - iterArgs - [0]); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto + [q1, + measureResult] = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).measure( + iterArgs + [0]); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfCondition( - measure.second, ValueRange{measure.first}); - return measure.first; + measureResult, ValueRange{q1}); + return q1; }, [&](OpBuilder& b, ValueRange iterArgs) { auto @@ -239,15 +242,16 @@ TEST_F(ConversionTest, ScfWhileTest2) { auto q0 = b.allocQubit(); auto scfWhileResult = b.scfWhile( ValueRange{q0}, - [&](OpBuilder& b, - ValueRange - iterArgs) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto measure = static_cast(b).measure( - iterArgs - [0]); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + [&](OpBuilder& b, ValueRange iterArgs) { + auto + [q1, + measureResult] = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).measure( + iterArgs + [0]); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfCondition( - measure.second, ValueRange{measure.first}); - return measure.first; + measureResult, ValueRange{q1}); + return q1; }, [&](OpBuilder& b, ValueRange iterArgs) { auto @@ -274,9 +278,11 @@ TEST_F(ConversionTest, ScfWhileTest2) { b.scfWhile( [&](OpBuilder& b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto measure = static_cast(b).measure( - q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfCondition(measure); + auto measureResult = + static_cast(b).measure( + q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).scfCondition( + measureResult); }, [&](OpBuilder& b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) @@ -322,29 +328,28 @@ TEST_F(ConversionTest, ScfIfTest) { auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); - auto measure = b.measure(q0); + auto [q1, measureResult] = b.measure(q0); auto scfIfResult = b.scfIf( - measure.second, ValueRange{measure.first}, + measureResult, ValueRange{q1}, [&](OpBuilder& b) { auto - q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(measure.first); + q2 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).h(q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q2 = static_cast(b).y( - q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfYield(q2); - return q2; + auto q3 = static_cast(b).y( + q2); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).scfYield(q3); + return q3; }, [&](OpBuilder& b) { auto - q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + q2 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).y( - measure - .first); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q2 = static_cast(b).h( - q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfYield(q2); - return q2; + q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q3 = static_cast(b).h( + q2); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).scfYield(q3); + return q3; }); b.h(scfIfResult[0]); }); @@ -359,30 +364,28 @@ TEST_F(ConversionTest, ScfIfTest2) { // Test conversion from qco to qc for scf.if operation auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); - auto measure = b.measure(q0); + auto [q1, measureResult] = b.measure(q0); auto scfIfResult = b.scfIf( - measure.second, ValueRange{measure.first}, - [&](OpBuilder& b) -> ValueRange { + measureResult, ValueRange{q1}, + [&](OpBuilder& b) { auto - q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h( - measure - .first); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q2 = static_cast(b).y( - q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfYield(q2); - return q2; // NOLINT + q2 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).h(q1); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q3 = static_cast(b).y( + q2); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).scfYield(q3); + return q3; }, - [&](OpBuilder& b) -> ValueRange { + [&](OpBuilder& b) { auto - q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + q2 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).y( - measure - .first); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q2 = static_cast(b).h(q1); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfYield(q2); - return q2; // NOLINT + q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + auto q3 = static_cast(b).h( + q2); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) + static_cast(b).scfYield(q3); + return q3; }); b.h(scfIfResult[0]); }); From abe1eef6b4aa26ae64a66af4c7894cf5b9987e6c Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 22 Dec 2025 18:47:57 +0100 Subject: [PATCH 032/108] fix linter issues --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 3 +- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 7 ++--- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 2 +- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 30 ++++++++++--------- mlir/unittests/conversion/test_conversion.cpp | 16 ++++++---- 5 files changed, 32 insertions(+), 26 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index b3ba10e90b..0f6a758e14 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -1145,7 +1145,7 @@ struct QCOToQC final : impl::QCOToQCBase { ConversionTarget target(*context); RewritePatternSet patterns(context); - QCOToQCTypeConverter typeConverter(context); + const QCOToQCTypeConverter typeConverter(context); target.addDynamicallyLegalOp([&](scf::IfOp op) { return !llvm::any_of(op->getResultTypes(), [&](Type type) { @@ -1194,7 +1194,6 @@ struct QCOToQC final : impl::QCOToQCBase { // Configure conversion target: QCO illegal, QC legal target.addIllegalDialect(); target.addLegalDialect(); - target.addLegalDialect(); // Register operation conversion patterns // Note: No state tracking needed - OpAdaptors handle type conversion patterns diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 10f4cfa925..745259c0a0 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -125,8 +125,8 @@ class StatefulOpConversionPattern : public OpConversionPattern { * @param ctx The MLIRContext of the current program * @return llvm::Setvector The set of unique QC qubit references */ -llvm::SetVector collectUniqueQubits(Operation* op, LoweringState* state, - MLIRContext* ctx) { +static llvm::SetVector +collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { // get the regions of the current operation const auto& regions = op->getRegions(); SetVector uniqueQubits; @@ -1696,7 +1696,7 @@ struct QCToQCO final : impl::QCToQCOBase { // legal target.addIllegalDialect(); target.addLegalDialect(); - target.addLegalDialect(); + target.addDynamicallyLegalOp([&](scf::YieldOp op) { return !(op->getAttrOfType("needChange")); }); @@ -1727,7 +1727,6 @@ struct QCToQCO final : impl::QCToQCOBase { }); // Register operation conversion patterns with state // tracking - patterns.add< ConvertQCAllocOp, ConvertQCDeallocOp, ConvertQCStaticOp, ConvertQCMeasureOp, ConvertQCResetOp, ConvertQCGPhaseOp, ConvertQCIdOp, diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index b29710f8ad..9e22bffd30 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -522,7 +522,7 @@ QCProgramBuilder& QCProgramBuilder::funcFunc( StringRef name, TypeRange argTypes, const std::function& body) { // Set the insertionPoint - OpBuilder::InsertionGuard guard(*this); + const OpBuilder::InsertionGuard guard(*this); setInsertionPointToEnd(module.getBody()); // Create the empty func operation diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index adf4c9fb00..0f1ba02c16 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -36,7 +37,7 @@ namespace mlir::qco { QCOProgramBuilder::QCOProgramBuilder(MLIRContext* context) : OpBuilder(context), ctx(context), loc(getUnknownLoc()), - module(ModuleOp::create(loc)) { + module(ModuleOp::create(loc)), funcRegion(nullptr) { ctx->loadDialect(); } @@ -606,16 +607,16 @@ ValueRange QCOProgramBuilder::scfFor( // Create the empty for operation auto forOp = create(loc, lowerbound, upperbound, step, initArgs); auto* forBody = forOp.getBody(); - auto iv = forBody->getArgument(0); - auto loopArgs = forBody->getArguments().drop_front(); + const auto iv = forBody->getArgument(0); + const auto loopArgs = forBody->getArguments().drop_front(); // Set the insertionpoint - OpBuilder::InsertionGuard guard(*this); + const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(forBody); // Add the iterArgs to the validQubits auto* bodyRegion = forBody->getParent(); - for (Value arg : loopArgs) { + for (const auto& arg : loopArgs) { validQubits[bodyRegion].insert(arg); } // Build the body @@ -638,7 +639,7 @@ ValueRange QCOProgramBuilder::scfWhile( auto whileOp = create(loc, initArgs.getTypes(), initArgs); const SmallVector locs(initArgs.size(), loc); - OpBuilder::InsertionGuard guard(*this); + const OpBuilder::InsertionGuard guard(*this); // Construct the before block auto* beforeBlock = @@ -650,7 +651,7 @@ ValueRange QCOProgramBuilder::scfWhile( setInsertionPointToStart(beforeBlock); // Add the beforeArgs to the validQubits - for (Value arg : beforeArgs) { + for (const auto& arg : beforeArgs) { validQubits[beforeRegion].insert(arg); } @@ -666,14 +667,15 @@ ValueRange QCOProgramBuilder::scfWhile( setInsertionPointToStart(afterBlock); // Add the afterArgs to the validQubits - for (Value arg : afterArgs) { + for (const auto& arg : afterArgs) { validQubits[afterRegion].insert(arg); } afterBody(*this, afterArgs); // Update the qubit tracking - for (auto [arg, result] : llvm::zip_equal(initArgs, whileOp.getResults())) { + for (const auto& [arg, result] : + llvm::zip_equal(initArgs, whileOp.getResults())) { updateQubitTracking(arg, result, whileOp->getParentRegion()); } @@ -693,11 +695,11 @@ ValueRange QCOProgramBuilder::scfIf( auto* elseRegion = elseBlock.getParent(); // Set the insertionpoint - OpBuilder::InsertionGuard guard(*this); + const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(&thenBlock); // Add the qubits to the validQubits of the then and else region - for (Value arg : qubits) { + for (const auto& arg : qubits) { validQubits[thenRegion].insert(arg); validQubits[elseRegion].insert(arg); } @@ -712,7 +714,7 @@ ValueRange QCOProgramBuilder::scfIf( elseBody(*this); // Update the qubit tracking - for (auto [arg, result] : llvm::zip_equal(qubits, ifOp.getResults())) { + for (const auto& [arg, result] : llvm::zip_equal(qubits, ifOp.getResults())) { updateQubitTracking(arg, result, ifOp->getParentRegion()); } @@ -751,7 +753,7 @@ QCOProgramBuilder& QCOProgramBuilder::funcFunc( StringRef name, TypeRange argTypes, TypeRange resultTypes, const std::function& body) { // Set the insertionPoint - OpBuilder::InsertionGuard guard(*this); + const OpBuilder::InsertionGuard guard(*this); setInsertionPointToEnd(module.getBody()); // Create the empty func operation @@ -760,7 +762,7 @@ QCOProgramBuilder& QCOProgramBuilder::funcFunc( auto* entryBlock = funcOp.addEntryBlock(); // Add the arguments to the validQubits - for (Value arg : entryBlock->getArguments()) { + for (const auto& arg : entryBlock->getArguments()) { validQubits[entryBlock->getParent()].insert(arg); } diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index 8777c50295..c389232dd9 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -16,14 +16,20 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include -#include +#include +#include +#include +#include #include #include #include #include +#include +#include +#include #include #include -#include +#include using namespace mlir; @@ -276,9 +282,9 @@ TEST_F(ConversionTest, ScfWhileTest2) { auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); b.scfWhile( - [&](OpBuilder& - b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto measureResult = + [&](OpBuilder& b) { + auto + measureResult = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).measure( q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) static_cast(b).scfCondition( From 981c6e0b1ffa91aa84c04393003438df993cbf03 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 22 Dec 2025 18:58:32 +0100 Subject: [PATCH 033/108] fix headers --- mlir/unittests/conversion/test_conversion.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index c389232dd9..833367959c 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -15,8 +15,8 @@ #include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" +#include #include -#include #include #include #include @@ -29,6 +29,7 @@ #include #include #include +#include #include using namespace mlir; From a8d71fa29a9de268aeba6c36a2b3eca4415ceb31 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 3 Jan 2026 16:45:09 +0100 Subject: [PATCH 034/108] apply codeRabbit feedback --- .../Dialect/QC/Builder/QCProgramBuilder.h | 4 +- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 4 +- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 12 +-- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 98 +++++++++++++------ mlir/unittests/conversion/test_conversion.cpp | 36 +++---- 5 files changed, 94 insertions(+), 60 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index dad8b23fbf..ac9d76af74 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -938,7 +938,7 @@ class QCProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scf.if(condition, [&](auto& b) { + * builder.scfIf(condition, [&](auto& b) { * b.h(q0); * }, [&](auto& b) { * b.x(q0); @@ -1002,7 +1002,7 @@ class QCProgramBuilder final : public OpBuilder { * builder.funcCall("test", {q0}); * ``` * ```mlir - * func.call @test(%q0) : (!qco.qubit) -> () + * func.call @test(%q0) : (!qc.qubit) -> () * ``` */ QCProgramBuilder& funcCall(StringRef name, ValueRange operands); diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 1e893f4a27..3924077601 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1109,7 +1109,7 @@ class QCOProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scf.if(condition, qubits, [&](auto& b) { + * builder.scfIf(condition, qubits, [&](auto& b) { * auto q1 = b.h(q0); * b.scfYield(q1); * }, [&](auto& b) { @@ -1156,7 +1156,7 @@ class QCOProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scfYield( yieldedValues); + * builder.scfYield(yieldedValues); * ``` * ```mlir * scf.yield %q0 : !qco.qubit diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 0f6a758e14..8e69d26a9f 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -1104,8 +1104,7 @@ struct ConvertQCOFuncReturnOp final : OpConversionPattern { LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - rewriter.create(op->getLoc()); - rewriter.eraseOp(op); + rewriter.replaceOpWithNewOp(op); return success(); } }; @@ -1207,11 +1206,10 @@ struct QCOToQC final : impl::QCOToQCBase { ConvertQCODCXOp, ConvertQCOECROp, ConvertQCORXXOp, ConvertQCORYYOp, ConvertQCORZXOp, ConvertQCORZZOp, ConvertQCOXXPlusYYOp, ConvertQCOXXMinusYYOp, ConvertQCOBarrierOp, ConvertQCOCtrlOp, - ConvertQCOYieldOp, ConvertQCOYieldOp, ConvertQCOScfIfOp, - ConvertQCOScfYieldOp, ConvertQCOScfWhileOp, - ConvertQCOScfConditionOp, ConvertQCOScfForOp, ConvertQCOFuncCallOp, - ConvertQCOFuncFuncOp, ConvertQCOFuncReturnOp>(typeConverter, - context); + ConvertQCOYieldOp, ConvertQCOScfIfOp, ConvertQCOScfYieldOp, + ConvertQCOScfWhileOp, ConvertQCOScfConditionOp, ConvertQCOScfForOp, + ConvertQCOFuncCallOp, ConvertQCOFuncFuncOp, + ConvertQCOFuncReturnOp>(typeConverter, context); // Apply the conversion if (failed(applyPartialConversion(module, target, std::move(patterns)))) { diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 745259c0a0..98e45f2cf8 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -170,6 +170,7 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { if (!func.getArgumentTypes().empty() && func.getArgumentTypes().front() == qc::QubitType::get(ctx)) { operation.setAttr("needChange", StringAttr::get(ctx, "yes")); + state->regionMap[func] = uniqueQubits; } } } @@ -1265,7 +1266,8 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(scf::IfOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - const auto& qcQubits = getState().regionMap[op]; + auto& regionMap = getState().regionMap; + const auto& qcQubits = regionMap[op]; const SmallVector qcValues(qcQubits.begin(), qcQubits.end()); // create result typerange @@ -1315,6 +1317,12 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { qubitMap[qcQubit] = qcoQubit; } + // replace the old entry in the regionMap with the new operation + const auto& it = regionMap.find(op); + const auto values = std::move(it->second); + regionMap.erase(op); + regionMap.try_emplace(newIfOp, values); + rewriter.eraseOp(op); return success(); } @@ -1353,7 +1361,8 @@ struct ConvertQCScfWhileOp final : StatefulOpConversionPattern { matchAndRewrite(scf::WhileOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - const auto& qcQubits = getState().regionMap[op]; + auto& regionMap = getState().regionMap; + const auto& qcQubits = regionMap[op]; SmallVector qcoQubits; qcoQubits.reserve(qcQubits.size()); @@ -1399,7 +1408,14 @@ struct ConvertQCScfWhileOp final : StatefulOpConversionPattern { llvm::zip_equal(qcQubits, newWhileOp->getResults())) { qubitMap[qcQubit] = qcoQubit; } + + // replace the old entry in the regionMap with the new operation + const auto& it = regionMap.find(op); + const auto values = std::move(it->second); + regionMap.erase(op); + regionMap.try_emplace(newWhileOp, values); rewriter.eraseOp(op); + return success(); } }; @@ -1431,10 +1447,11 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - const auto& qcQubits = getState().regionMap[op]; + auto& regionMap = getState().regionMap; + const auto& qcQubits = regionMap[op]; SmallVector qcoQubits; - qcoQubits.reserve(qcoQubits.size()); + qcoQubits.reserve(qcQubits.size()); for (const auto& qcQubit : qcQubits) { qcoQubits.push_back(qubitMap[qcQubit]); } @@ -1466,6 +1483,12 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { qubitMap[qcQubit] = qcoQubit; } + // replace the old entry in the regionMap with the new operation + const auto& it = regionMap.find(op); + const auto values = std::move(it->second); + regionMap.erase(op); + regionMap.try_emplace(newFor, values); + rewriter.eraseOp(op); return success(); } @@ -1490,11 +1513,15 @@ struct ConvertQCScfYieldOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - const auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + const auto& parentRegion = op->getParentRegion(); + const auto& qubitMap = getState().qubitMap[parentRegion]; + const auto& orderedQubits = + getState().regionMap[parentRegion->getParentOp()]; + SmallVector qcoQubits; - qcoQubits.reserve(qubitMap.size()); - for (auto [qcQubit, qcoQubit] : qubitMap) { - qcoQubits.push_back(qcoQubit); + qcoQubits.reserve(orderedQubits.size()); + for (const auto& qcQubit : orderedQubits) { + qcoQubits.push_back(qubitMap.lookup(qcQubit)); } rewriter.replaceOpWithNewOp(op, qcoQubits); @@ -1523,11 +1550,15 @@ struct ConvertQCScfConditionOp final LogicalResult matchAndRewrite(scf::ConditionOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - const auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + const auto& parentRegion = op->getParentRegion(); + const auto& qubitMap = getState().qubitMap[parentRegion]; + const auto& orderedQubits = + getState().regionMap[parentRegion->getParentOp()]; + SmallVector qcoQubits; - qcoQubits.reserve(qubitMap.size()); - for (auto [qcQubit, qcoQubit] : qubitMap) { - qcoQubits.push_back(qcoQubit); + qcoQubits.reserve(orderedQubits.size()); + for (const auto& qcQubit : orderedQubits) { + qcoQubits.push_back(qubitMap.lookup(qcQubit)); } rewriter.replaceOpWithNewOp(op, op.getCondition(), @@ -1560,7 +1591,7 @@ struct ConvertQCFuncCallOp final : StatefulOpConversionPattern { auto qcQubits = op->getOperands(); SmallVector qcoQubits; - qcoQubits.reserve(qubitMap.size()); + qcoQubits.reserve(qcQubits.size()); for (const auto& qcQubit : qcQubits) { qcoQubits.push_back(qubitMap[qcQubit]); } @@ -1644,11 +1675,15 @@ struct ConvertQCFuncReturnOp final LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - const auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + const auto& parentRegion = op->getParentRegion(); + const auto& qubitMap = getState().qubitMap[parentRegion]; + const auto& orderedQubits = + getState().regionMap[parentRegion->getParentOp()]; + SmallVector qcoQubits; - qcoQubits.reserve(qubitMap.size()); - for (auto [qcQubit, qcoQubit] : qubitMap) { - qcoQubits.push_back(qcoQubit); + qcoQubits.reserve(orderedQubits.size()); + for (const auto& qcQubit : orderedQubits) { + qcoQubits.push_back(qubitMap.lookup(qcQubit)); } rewriter.replaceOpWithNewOp(op, qcoQubits); return success(); @@ -1727,20 +1762,21 @@ struct QCToQCO final : impl::QCToQCOBase { }); // Register operation conversion patterns with state // tracking - patterns.add< - ConvertQCAllocOp, ConvertQCDeallocOp, ConvertQCStaticOp, - ConvertQCMeasureOp, ConvertQCResetOp, ConvertQCGPhaseOp, ConvertQCIdOp, - ConvertQCXOp, ConvertQCYOp, ConvertQCZOp, ConvertQCHOp, ConvertQCSOp, - ConvertQCSdgOp, ConvertQCTOp, ConvertQCTdgOp, ConvertQCSXOp, - ConvertQCSXdgOp, ConvertQCRXOp, ConvertQCRYOp, ConvertQCRZOp, - ConvertQCPOp, ConvertQCROp, ConvertQCU2Op, ConvertQCUOp, - ConvertQCSWAPOp, ConvertQCiSWAPOp, ConvertQCDCXOp, ConvertQCECROp, - ConvertQCRXXOp, ConvertQCRYYOp, ConvertQCRZXOp, ConvertQCRZZOp, - ConvertQCXXPlusYYOp, ConvertQCXXMinusYYOp, ConvertQCBarrierOp, - ConvertQCCtrlOp, ConvertQCYieldOp, ConvertQCYieldOp, ConvertQCScfIfOp, - ConvertQCScfYieldOp, ConvertQCScfWhileOp, ConvertQCScfConditionOp, - ConvertQCScfForOp, ConvertQCFuncCallOp, ConvertQCFuncFuncOp, - ConvertQCFuncReturnOp>(typeConverter, context, &state); + patterns + .add( + typeConverter, context, &state); // Apply the conversion if (failed(applyPartialConversion(module, target, std::move(patterns)))) { diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index 833367959c..529a7a2480 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -66,10 +66,10 @@ class ConversionTest : public ::testing::Test { } }; -static std::string getOutputString(mlir::OwningOpRef* module) { +static std::string getOutputString(mlir::OwningOpRef& module) { std::string outputString; llvm::raw_string_ostream os(outputString); - (*module)->print(os); + module->print(os); os.flush(); return outputString; } @@ -121,8 +121,8 @@ TEST_F(ConversionTest, ScfForTest) { b.h(scfForRes[0]); }); - const auto outputString = getOutputString(&input); - const auto checkString = getOutputString(&expectedOutput); + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); ASSERT_EQ(outputString, checkString); } @@ -174,8 +174,8 @@ TEST_F(ConversionTest, ScfForTest2) { b.h(q0); }); - const auto outputString = getOutputString(&input); - const auto checkString = getOutputString(&expectedOutput); + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); ASSERT_EQ(outputString, checkString); } @@ -237,8 +237,8 @@ TEST_F(ConversionTest, ScfWhileTest) { b.h(scfWhileResult[0]); }); - const auto outputString = getOutputString(&input); - const auto checkString = getOutputString(&expectedOutput); + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); ASSERT_EQ(outputString, checkString); } @@ -299,8 +299,8 @@ TEST_F(ConversionTest, ScfWhileTest2) { }); b.h(q0); }); - const auto outputString = getOutputString(&input); - const auto checkString = getOutputString(&expectedOutput); + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); ASSERT_EQ(outputString, checkString); } @@ -361,8 +361,8 @@ TEST_F(ConversionTest, ScfIfTest) { b.h(scfIfResult[0]); }); - const auto outputString = getOutputString(&input); - const auto checkString = getOutputString(&expectedOutput); + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); ASSERT_EQ(outputString, checkString); } @@ -423,8 +423,8 @@ TEST_F(ConversionTest, ScfIfTest2) { b.h(q0); }); - const auto outputString = getOutputString(&input); - const auto checkString = getOutputString(&expectedOutput); + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); ASSERT_EQ(outputString, checkString); } @@ -468,8 +468,8 @@ TEST_F(ConversionTest, FuncFuncTest) { static_cast(b).funcReturn(q3); }); }); - const auto outputString = getOutputString(&input); - const auto checkString = getOutputString(&expectedOutput); + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); ASSERT_EQ(outputString, checkString); } @@ -514,8 +514,8 @@ TEST_F(ConversionTest, FuncFuncTest2) { }); }); - const auto outputString = getOutputString(&input); - const auto checkString = getOutputString(&expectedOutput); + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); ASSERT_EQ(outputString, checkString); } From 67ecea72a9025b4ba6bd43b69536592874a78c53 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 10 Jan 2026 13:55:13 +0100 Subject: [PATCH 035/108] small fixes --- .../include/mlir/Dialect/QC/Builder/QCProgramBuilder.h | 10 +++++----- .../mlir/Dialect/QCO/Builder/QCOProgramBuilder.h | 2 +- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 2 +- mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp | 2 +- mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp | 2 +- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index ac9d76af74..8b6ccf76c0 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -906,11 +906,11 @@ class QCProgramBuilder final : public OpBuilder { * ```c++ * builder.scfWhile([&](auto& b) { * b.h(q0); - * auto res = b.measure(q0) - * b.condition(res) + * auto res = b.measure(q0); + * b.condition(res); * }, [&](auto& b) { * b.x(q0); - * b.yield() + * b.yield(); * }); * ``` * ```mlir @@ -1024,7 +1024,7 @@ class QCProgramBuilder final : public OpBuilder { * }) * ``` * ```mlir - * func.func @test(%arg0 : !qco.qubit) { + * func.func @test(%arg0 : !qc.qubit) { * qc.h %arg0 : !qc.qubit * func.return * } @@ -1052,7 +1052,7 @@ class QCProgramBuilder final : public OpBuilder { * arith.constant 4 : index * ``` */ - Value arithConstantIndex(int index); + Value arithConstantIndex(int64_t index); /** * @brief Constructs a arith.constant of type i1 with a given bool value diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 3924077601..d5a4445364 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1247,7 +1247,7 @@ class QCOProgramBuilder final : public OpBuilder { * arith.constant 4 : index * ``` */ - Value arithConstantIndex(int i); + Value arithConstantIndex(int64_t i); /** * @brief Constructs a arith.constant of type i1 with a given bool value diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 98e45f2cf8..26de1e7933 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -72,7 +72,7 @@ namespace { * - %q2 after the X gate */ struct LoweringState { - /// Map from original QC qubit references to their latest Flux SSA values + /// Map from original QC qubit references to their latest QCO SSA values /// for each region llvm::DenseMap> qubitMap; /// Map each operation to its Set of QC qubit references diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 9e22bffd30..fe6004668d 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -542,7 +542,7 @@ QCProgramBuilder& QCProgramBuilder::funcFunc( // Arith operations //===----------------------------------------------------------------------===// -Value QCProgramBuilder::arithConstantIndex(int index) { +Value QCProgramBuilder::arithConstantIndex(int64_t index) { const auto op = create(loc, getIndexType(), getIndexAttr(index)); return op->getResult(0); diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 0f1ba02c16..9df66090cb 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -778,7 +778,7 @@ QCOProgramBuilder& QCOProgramBuilder::funcFunc( // Arith operations //===----------------------------------------------------------------------===// -Value QCOProgramBuilder::arithConstantIndex(int i) { +Value QCOProgramBuilder::arithConstantIndex(int64_t i) { const auto op = create(loc, getIndexType(), getIndexAttr(i)); return op->getResult(0); From 338b7eeb783efb2eac4046e67292778b832bb4c3 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 10 Jan 2026 15:05:32 +0100 Subject: [PATCH 036/108] simplify body builders --- .../Dialect/QC/Builder/QCProgramBuilder.h | 17 +- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 23 +- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 35 +-- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 33 +- mlir/unittests/CMakeLists.txt | 1 + mlir/unittests/conversion/test_conversion.cpp | 295 ++++++------------ 6 files changed, 155 insertions(+), 249 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index c4d3874b0d..0212c01cd2 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -892,7 +892,7 @@ class QCProgramBuilder final : public OpBuilder { * ``` */ QCProgramBuilder& scfFor(Value lowerbound, Value upperbound, Value step, - const std::function& body); + const std::function& body); /** * @brief Constructs a scf.while operation without return values @@ -924,8 +924,8 @@ class QCProgramBuilder final : public OpBuilder { * } * ``` */ - QCProgramBuilder& scfWhile(const std::function& beforeBody, - const std::function& afterBody); + QCProgramBuilder& scfWhile(const std::function& beforeBody, + const std::function& afterBody); /** * @brief Constructs a scf.if operation without return values @@ -952,9 +952,9 @@ class QCProgramBuilder final : public OpBuilder { * } * ``` */ - QCProgramBuilder& - scfIf(Value condition, const std::function& thenBody, - const std::function& elseBody = nullptr); + QCProgramBuilder& scfIf(Value condition, + const std::function& thenBody, + const std::function& elseBody = nullptr); /** * @brief Constructs a scf.condition operation without any additional Values @@ -1030,9 +1030,8 @@ class QCProgramBuilder final : public OpBuilder { * } * ``` */ - QCProgramBuilder& - funcFunc(StringRef name, TypeRange argTypes, - const std::function& body); + QCProgramBuilder& funcFunc(StringRef name, TypeRange argTypes, + const std::function& body); //===--------------------------------------------------------------------===// // Arith operations diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 5812089561..b450bf2668 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1057,9 +1057,9 @@ class QCOProgramBuilder final : public OpBuilder { * } * ``` */ - ValueRange - scfFor(Value lowerbound, Value upperbound, Value step, ValueRange initArgs, - const std::function& body); + ValueRange scfFor(Value lowerbound, Value upperbound, Value step, + ValueRange initArgs, + const std::function& body); /** * @brief Constructs a scf.while operation with return values * @@ -1092,10 +1092,9 @@ class QCOProgramBuilder final : public OpBuilder { * } * ``` */ - ValueRange - scfWhile(ValueRange args, - const std::function& beforeBody, - const std::function& afterBody); + ValueRange scfWhile(ValueRange args, + const std::function& beforeBody, + const std::function& afterBody); /** * @brief Constructs a scf.if operation with return values @@ -1128,8 +1127,8 @@ class QCOProgramBuilder final : public OpBuilder { * ``` */ ValueRange scfIf(Value condition, ValueRange qubits, - const std::function& thenBody, - const std::function& elseBody); + const std::function& thenBody, + const std::function& elseBody); /** * @brief Constructs a scf.condition operation with yielded values @@ -1225,9 +1224,9 @@ class QCOProgramBuilder final : public OpBuilder { * } * ``` */ - QCOProgramBuilder& - funcFunc(StringRef name, TypeRange argTypes, TypeRange resultTypes, - const std::function& body); + QCOProgramBuilder& funcFunc(StringRef name, TypeRange argTypes, + TypeRange resultTypes, + const std::function& body); //===--------------------------------------------------------------------===// // Arith operations diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 6b4191609c..0cbdb28c66 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -448,12 +448,12 @@ QCProgramBuilder& QCProgramBuilder::dealloc(Value qubit) { // SCF operations //===----------------------------------------------------------------------===// -QCProgramBuilder& -QCProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, - const std::function& body) { +QCProgramBuilder& QCProgramBuilder::scfFor(Value lowerbound, Value upperbound, + Value step, + const std::function& body) { create(loc, lowerbound, upperbound, step, ValueRange{}, [&](OpBuilder& b, Location, Value, ValueRange) { - body(b); + body(); b.create(loc); }); @@ -461,13 +461,13 @@ QCProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, } QCProgramBuilder& -QCProgramBuilder::scfWhile(const std::function& beforeBody, - const std::function& afterBody) { +QCProgramBuilder::scfWhile(const std::function& beforeBody, + const std::function& afterBody) { create( loc, TypeRange{}, ValueRange{}, - [&](OpBuilder& b, Location, ValueRange) { beforeBody(b); }, + [&](OpBuilder& /*b*/, Location, ValueRange) { beforeBody(); }, [&](OpBuilder& b, Location loc, ValueRange) { - afterBody(b); + afterBody(); b.create(loc); }); @@ -475,23 +475,22 @@ QCProgramBuilder::scfWhile(const std::function& beforeBody, } QCProgramBuilder& -QCProgramBuilder::scfIf(Value cond, - const std::function& thenBody, - const std::function& elseBody) { +QCProgramBuilder::scfIf(Value cond, const std::function& thenBody, + const std::function& elseBody) { if (!elseBody) { create(loc, cond, [&](OpBuilder& b, Location loc) { - thenBody(b); + thenBody(); b.create(loc); }); } else { create( loc, cond, [&](OpBuilder& b, Location loc) { - thenBody(b); + thenBody(); b.create(loc); }, [&](OpBuilder& b, Location loc) { - elseBody(b); + elseBody(); b.create(loc); }); } @@ -518,9 +517,9 @@ QCProgramBuilder& QCProgramBuilder::funcReturn() { return *this; } -QCProgramBuilder& QCProgramBuilder::funcFunc( - StringRef name, TypeRange argTypes, - const std::function& body) { +QCProgramBuilder& +QCProgramBuilder::funcFunc(StringRef name, TypeRange argTypes, + const std::function& body) { // Set the insertionPoint const OpBuilder::InsertionGuard guard(*this); setInsertionPointToEnd(module.getBody()); @@ -533,7 +532,7 @@ QCProgramBuilder& QCProgramBuilder::funcFunc( setInsertionPointToStart(entryBlock); // Build function body - body(*this, entryBlock->getArguments()); + body(entryBlock->getArguments()); return *this; } diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index c51f5f01c6..b1e1eea417 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -603,7 +603,7 @@ QCOProgramBuilder& QCOProgramBuilder::dealloc(Value qubit) { ValueRange QCOProgramBuilder::scfFor( Value lowerbound, Value upperbound, Value step, ValueRange initArgs, - const std::function& body) { + const std::function& body) { // Create the empty for operation auto forOp = create(loc, lowerbound, upperbound, step, initArgs); auto* forBody = forOp.getBody(); @@ -620,7 +620,7 @@ ValueRange QCOProgramBuilder::scfFor( validQubits[bodyRegion].insert(arg); } // Build the body - body(*this, iv, loopArgs); + body(iv, loopArgs); // Update the qubit tracking for (const auto& [initArg, result] : @@ -633,8 +633,8 @@ ValueRange QCOProgramBuilder::scfFor( ValueRange QCOProgramBuilder::scfWhile( ValueRange initArgs, - const std::function& beforeBody, - const std::function& afterBody) { + const std::function& beforeBody, + const std::function& afterBody) { // Create the empty while operation auto whileOp = create(loc, initArgs.getTypes(), initArgs); const SmallVector locs(initArgs.size(), loc); @@ -655,7 +655,7 @@ ValueRange QCOProgramBuilder::scfWhile( validQubits[beforeRegion].insert(arg); } - beforeBody(*this, beforeArgs); + beforeBody(beforeArgs); // Construct the after block auto* afterBlock = @@ -671,7 +671,7 @@ ValueRange QCOProgramBuilder::scfWhile( validQubits[afterRegion].insert(arg); } - afterBody(*this, afterArgs); + afterBody(afterArgs); // Update the qubit tracking for (const auto& [arg, result] : @@ -682,10 +682,10 @@ ValueRange QCOProgramBuilder::scfWhile( return whileOp->getResults(); } -ValueRange QCOProgramBuilder::scfIf( - Value condition, ValueRange qubits, - const std::function& thenBody, - const std::function& elseBody) { +ValueRange +QCOProgramBuilder::scfIf(Value condition, ValueRange qubits, + const std::function& thenBody, + const std::function& elseBody) { // Create the empty while operation auto ifOp = create(loc, qubits.getTypes(), condition, /*withElseRegion=*/true); @@ -705,13 +705,13 @@ ValueRange QCOProgramBuilder::scfIf( } // Build the then body - thenBody(*this); + thenBody(); // Set the insertionpoint setInsertionPointToStart(&elseBlock); // Build the else body - elseBody(*this); + elseBody(); // Update the qubit tracking for (const auto& [arg, result] : llvm::zip_equal(qubits, ifOp.getResults())) { @@ -749,9 +749,10 @@ QCOProgramBuilder& QCOProgramBuilder::funcReturn(ValueRange returnValues) { create(loc, returnValues); return *this; } -QCOProgramBuilder& QCOProgramBuilder::funcFunc( - StringRef name, TypeRange argTypes, TypeRange resultTypes, - const std::function& body) { +QCOProgramBuilder& +QCOProgramBuilder::funcFunc(StringRef name, TypeRange argTypes, + TypeRange resultTypes, + const std::function& body) { // Set the insertionPoint const OpBuilder::InsertionGuard guard(*this); setInsertionPointToEnd(module.getBody()); @@ -769,7 +770,7 @@ QCOProgramBuilder& QCOProgramBuilder::funcFunc( setInsertionPointToStart(entryBlock); // Build function body - body(*this, entryBlock->getArguments()); + body(entryBlock->getArguments()); return *this; } diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index 18e2d296de..e054c049e5 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -7,6 +7,7 @@ # Licensed under the MIT License add_subdirectory(pipeline) +add_subdirectory(conversion) add_custom_target(mqt-core-mlir-unittests) diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index 529a7a2480..ce506e26ca 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -1,6 +1,6 @@ /* - * Copyright (c) 2023 - 2025 Chair for Design Automation, TUM - * Copyright (c) 2025 Munich Quantum Software Company GmbH + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH * All rights reserved. * * SPDX-License-Identifier: MIT @@ -81,13 +81,10 @@ TEST_F(ConversionTest, ScfForTest) { auto c0 = b.arithConstantIndex(0); auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); - b.scfFor(c0, c2, c1, [&](OpBuilder& b) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(q0); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).x(q0); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(q0); + b.scfFor(c0, c2, c1, [&] { + b.h(q0); + b.x(q0); + b.h(q0); }); b.h(q0); }); @@ -103,19 +100,12 @@ TEST_F(ConversionTest, ScfForTest) { auto c0 = b.arithConstantIndex(0); auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); - auto scfForRes = b.scfFor( - c0, c2, c1, ValueRange{q0}, - [&](OpBuilder& b, Value, ValueRange iterArgs) { - auto - q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(iterArgs[0]); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q2 = static_cast(b).x(q1); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q3 = static_cast(b).h(q2); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfYield( - ValueRange{q3}); + auto scfForRes = + b.scfFor(c0, c2, c1, ValueRange{q0}, [&](Value, ValueRange iterArgs) { + auto q1 = b.h(iterArgs[0]); + auto q2 = b.x(q1); + auto q3 = b.h(q2); + b.scfYield(ValueRange{q3}); return q3; }); b.h(scfForRes[0]); @@ -134,19 +124,12 @@ TEST_F(ConversionTest, ScfForTest2) { auto c0 = b.arithConstantIndex(0); auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); - auto scfForRes = b.scfFor( - c0, c2, c1, ValueRange{q0}, - [&](OpBuilder& b, Value, ValueRange iterArgs) { - auto - q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(iterArgs[0]); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q2 = static_cast(b).x( - q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q3 = static_cast(b).h( - q2); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfYield( - ValueRange{q3}); + auto scfForRes = + b.scfFor(c0, c2, c1, ValueRange{q0}, [&](Value, ValueRange iterArgs) { + auto q1 = b.h(iterArgs[0]); + auto q2 = b.x(q1); + auto q3 = b.h(q2); + b.scfYield(ValueRange{q3}); return q3; }); b.h(scfForRes[0]); @@ -163,13 +146,10 @@ TEST_F(ConversionTest, ScfForTest2) { auto c0 = b.arithConstantIndex(0); auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); - b.scfFor(c0, c2, c1, [&](OpBuilder& b) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(q0); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).x(q0); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(q0); + b.scfFor(c0, c2, c1, [&] { + b.h(q0); + b.x(q0); + b.h(q0); }); b.h(q0); }); @@ -185,19 +165,14 @@ TEST_F(ConversionTest, ScfWhileTest) { auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); b.scfWhile( - [&](OpBuilder& b) { - auto - measureResult = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).measure(q0); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfCondition( - measureResult); + [&] { + auto measureResult = b.measure(q0); + + b.scfCondition(measureResult); }, - [&](OpBuilder& - b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(q0); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).y(q0); + [&] { + b.h(q0); + b.y(q0); }); b.h(q0); }); @@ -212,26 +187,15 @@ TEST_F(ConversionTest, ScfWhileTest) { auto q0 = b.allocQubit(); auto scfWhileResult = b.scfWhile( ValueRange{q0}, - [&](OpBuilder& b, ValueRange iterArgs) { - auto - [q1, - measureResult] = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).measure( - iterArgs - [0]); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfCondition( - measureResult, ValueRange{q1}); + [&](ValueRange iterArgs) { + auto [q1, measureResult] = b.measure(iterArgs[0]); + b.scfCondition(measureResult, ValueRange{q1}); return q1; }, - [&](OpBuilder& b, ValueRange iterArgs) { - auto - q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h( - iterArgs - [0]); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q2 = static_cast(b).y( - q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfYield({q2}); + [&](ValueRange iterArgs) { + auto q1 = b.h(iterArgs[0]); + auto q2 = b.y(q1); + b.scfYield({q2}); return q2; }); b.h(scfWhileResult[0]); @@ -249,26 +213,15 @@ TEST_F(ConversionTest, ScfWhileTest2) { auto q0 = b.allocQubit(); auto scfWhileResult = b.scfWhile( ValueRange{q0}, - [&](OpBuilder& b, ValueRange iterArgs) { - auto - [q1, - measureResult] = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).measure( - iterArgs - [0]); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfCondition( - measureResult, ValueRange{q1}); + [&](ValueRange iterArgs) { + auto [q1, measureResult] = b.measure(iterArgs[0]); + b.scfCondition(measureResult, ValueRange{q1}); return q1; }, - [&](OpBuilder& b, ValueRange iterArgs) { - auto - q1 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h( - iterArgs - [0]); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q2 = static_cast(b).y( - q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfYield({q2}); + [&](ValueRange iterArgs) { + auto q1 = b.h(iterArgs[0]); + auto q2 = b.y(q1); + b.scfYield({q2}); return q2; }); b.h(scfWhileResult[0]); @@ -283,19 +236,14 @@ TEST_F(ConversionTest, ScfWhileTest2) { auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); b.scfWhile( - [&](OpBuilder& b) { - auto - measureResult = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).measure( - q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfCondition( - measureResult); + [&] { + auto measureResult = b.measure(q0); + + b.scfCondition(measureResult); }, - [&](OpBuilder& - b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h( - q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).y(q0); + [&] { + b.h(q0); + b.y(q0); }); b.h(q0); }); @@ -312,17 +260,13 @@ TEST_F(ConversionTest, ScfIfTest) { auto measure = b.measure(q0); b.scfIf( measure, - [&](OpBuilder& - b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h( - q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).y(q0); + [&] { + b.h(q0); + b.y(q0); }, - [&](OpBuilder& - b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).y( - q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(q0); + [&] { + b.y(q0); + b.h(q0); }); b.h(q0); }); @@ -338,24 +282,17 @@ TEST_F(ConversionTest, ScfIfTest) { auto [q1, measureResult] = b.measure(q0); auto scfIfResult = b.scfIf( measureResult, ValueRange{q1}, - [&](OpBuilder& b) { - auto - q2 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(q1); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q3 = static_cast(b).y( - q2); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfYield(q3); + [&] { + auto q2 = b.h(q1); + + auto q3 = b.y(q2); + b.scfYield(q3); return q3; }, - [&](OpBuilder& b) { - auto - q2 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).y( - q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q3 = static_cast(b).h( - q2); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfYield(q3); + [&] { + auto q2 = b.y(q1); + auto q3 = b.h(q2); + b.scfYield(q3); return q3; }); b.h(scfIfResult[0]); @@ -374,24 +311,17 @@ TEST_F(ConversionTest, ScfIfTest2) { auto [q1, measureResult] = b.measure(q0); auto scfIfResult = b.scfIf( measureResult, ValueRange{q1}, - [&](OpBuilder& b) { - auto - q2 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(q1); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q3 = static_cast(b).y( - q2); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfYield(q3); + [&] { + auto q2 = b.h(q1); + + auto q3 = b.y(q2); + b.scfYield(q3); return q3; }, - [&](OpBuilder& b) { - auto - q2 = // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).y( - q1); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q3 = static_cast(b).h( - q2); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).scfYield(q3); + [&] { + auto q2 = b.y(q1); + auto q3 = b.h(q2); + b.scfYield(q3); return q3; }); b.h(scfIfResult[0]); @@ -408,17 +338,13 @@ TEST_F(ConversionTest, ScfIfTest2) { auto measure = b.measure(q0); b.scfIf( measure, - [&](OpBuilder& - b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h( - q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).y(q0); + [&] { + b.h(q0); + b.y(q0); }, - [&](OpBuilder& - b) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).y( - q0); // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(q0); + [&] { + b.y(q0); + b.h(q0); }); b.h(q0); }); @@ -435,17 +361,11 @@ TEST_F(ConversionTest, FuncFuncTest) { auto q0 = b.allocQubit(); b.funcCall("test", q0); b.h(q0); - b.funcFunc( - "test", q0.getType(), - [&](OpBuilder& b, - ValueRange - args) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(args[0]); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).y(args[0]); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).funcReturn(); - }); + b.funcFunc("test", q0.getType(), [&](ValueRange args) { + b.h(args[0]); + b.y(args[0]); + b.funcReturn(); + }); }); PassManager pm(context.get()); @@ -458,16 +378,13 @@ TEST_F(ConversionTest, FuncFuncTest) { auto q0 = b.allocQubit(); auto q1 = b.funcCall("test", q0); b.h(q1[0]); - b.funcFunc( - "test", q0.getType(), q0.getType(), [&](OpBuilder& b, ValueRange args) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q2 = static_cast(b).h(args[0]); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q3 = static_cast(b).y(q2); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).funcReturn(q3); - }); + b.funcFunc("test", q0.getType(), q0.getType(), [&](ValueRange args) { + auto q2 = b.h(args[0]); + auto q3 = b.y(q2); + b.funcReturn(q3); + }); }); + const auto outputString = getOutputString(input); const auto checkString = getOutputString(expectedOutput); @@ -480,15 +397,11 @@ TEST_F(ConversionTest, FuncFuncTest2) { auto q0 = b.allocQubit(); auto q1 = b.funcCall("test", q0); b.h(q1[0]); - b.funcFunc( - "test", q0.getType(), q0.getType(), [&](OpBuilder& b, ValueRange args) { - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q2 = static_cast(b).h(args[0]); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - auto q3 = static_cast(b).y(q2); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).funcReturn(q3); - }); + b.funcFunc("test", q0.getType(), q0.getType(), [&](ValueRange args) { + auto q2 = b.h(args[0]); + auto q3 = b.y(q2); + b.funcReturn(q3); + }); }); PassManager pm(context.get()); @@ -501,17 +414,11 @@ TEST_F(ConversionTest, FuncFuncTest2) { auto q0 = b.allocQubit(); b.funcCall("test", q0); b.h(q0); - b.funcFunc( - "test", q0.getType(), - [&](OpBuilder& b, - ValueRange - args) { // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).h(args[0]); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).y(args[0]); - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast) - static_cast(b).funcReturn(); - }); + b.funcFunc("test", q0.getType(), [&](ValueRange args) { + b.h(args[0]); + b.y(args[0]); + b.funcReturn(); + }); }); const auto outputString = getOutputString(input); From 7fb419104126eb1f8c6cc8540b6f57611f7c00e6 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 10 Jan 2026 15:49:48 +0100 Subject: [PATCH 037/108] fix docstrings --- .../Dialect/QC/Builder/QCProgramBuilder.h | 45 ++++++------ .../Dialect/QCO/Builder/QCOProgramBuilder.h | 73 +++++++++---------- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 8 +- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 8 +- 4 files changed, 66 insertions(+), 68 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index 0212c01cd2..fe116d1d0e 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -883,7 +883,7 @@ class QCProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scfFor(lb, ub, step, [&](auto& b) { b.x(q0); }); + * builder.scfFor(lb, ub, step, [&] { builder.x(q0); }); * ``` * ```mlir * scf.for %iv = %lb to %ub step %step { @@ -904,23 +904,23 @@ class QCProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scfWhile([&](auto& b) { - * b.h(q0); - * auto res = b.measure(q0); - * b.condition(res); - * }, [&](auto& b) { - * b.x(q0); - * b.yield(); + * builder.scfWhile([&] { + * builder.h(q0); + * auto res = builder.measure(q0); + * builder.condition(res); + * }, [&] { + * builder.x(q0); + * builder.yield(); * }); * ``` * ```mlir * scf.while : () -> () { - * qc.h %q0 : !qc.qubit - * %res = qc.measure %q0 : !qc.qubit -> i1 - * scf.condition(%tres) + * qc.h %q0 : !qc.qubit + * %res = qc.measure %q0 : !qc.qubit -> i1 + * scf.condition(%tres) * } do { - * qc.x %q0 : !qc.qubit - * scf.yield + * qc.x %q0 : !qc.qubit + * scf.yield * } * ``` */ @@ -938,17 +938,17 @@ class QCProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scfIf(condition, [&](auto& b) { - * b.h(q0); - * }, [&](auto& b) { - * b.x(q0); + * builder.scfIf(condition, [&] { + * builder.h(q0); + * }, [&] { + * builder.x(q0); * }); * ``` * ```mlir * scf.if %condition { - * qc.h %q0 : !qc.qubit + * qc.h %q0 : !qc.qubit * } else { - * qc.x %q0 : !qc.qubit + * qc.x %q0 : !qc.qubit * } * ``` */ @@ -1017,10 +1017,9 @@ class QCProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.funcFunc("test", argTypes, [&](OpBuilder& b, - * ValueRange args) { - * b.h(args[0]); - * b.funcReturn(); + * builder.funcFunc("test", argTypes, [&](ValueRange args) { + * builder.h(args[0]); + * builder.funcReturn(); * }) * ``` * ```mlir diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index b450bf2668..0fb4e3a15f 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1044,16 +1044,16 @@ class QCOProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scfFor(lb, ub, step, initArgs, [&](auto& b) { - * auto q1 = b.x(initArgs[0]); - * b.scfYield(q1); - }); + * builder.scfFor(lb, ub, step, initArgs, [&] { + * auto q1 = builder.x(initArgs[0]); + * builder.scfYield(q1); + * }); * ``` * ```mlir - * %q1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %q0) -> - !qco.qubit { - * %q2 = qc.x %arg0 : !qco.qubit -> !qco.qubit - * scf.yield %q2 : !qco.qubit + * %q1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %q0) + * -> !qco.qubit { + * %q2 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * scf.yield %q2 : !qco.qubit * } * ``` */ @@ -1071,24 +1071,24 @@ class QCOProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scfWhile(args, [&](auto& b, ValueRange iterArgs) { - * auto q1 = b.h(iterArgs[0]); - * auto [q2, measureRes] = b.measure(q1); - * b.condition(measureRes); - * }, [&](auto& b, ValueRange iterArgs) { - * auto q1 = b.x(iterArgs[0]); - * b.scfYield(q1); + * builder.scfWhile(args, [&](ValueRange iterArgs) { + * auto q1 = builder.h(iterArgs[0]); + * auto [q2, measureRes] = builder.measure(q1); + * builder.condition(measureRes, q2); + * }, [&](ValueRange iterArgs) { + * auto q1 = builder.x(iterArgs[0]); + * builder.scfYield(q1); * }); * ``` * ```mlir * %q1 = scf.while (%arg0 = %q0): (!qco.qubit) -> (!qco.qubit) { - * %q2 = qco.h(%arg0) - * %q3, %result = qco.measure %q2 : !qco.qubit - * scf.condition(%result) %q3 : !qco.qubit - * } do { - * ^bb0(%arg0 : !qco.qubit): - * %q4 = qco.x %arg0 : !qco.qubit -> !qco.qubit - * scf.yield %q4 : !qco.qubit + * %q2 = qco.h(%arg0) + * %q3, %result = qco.measure %q2 : !qco.qubit + * scf.condition(%result) %q3 : !qco.qubit + * } do { + * ^bb0(%arg0 : !qco.qubit): + * %q4 = qco.x %arg0 : !qco.qubit -> !qco.qubit + * scf.yield %q4 : !qco.qubit * } * ``` */ @@ -1108,21 +1108,21 @@ class QCOProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scfIf(condition, qubits, [&](auto& b) { - * auto q1 = b.h(q0); - * b.scfYield(q1); - * }, [&](auto& b) { - * auto q1 = b.x(q0); - * b.scfYield(q1); + * builder.scfIf(condition, qubits, [&] { + * auto q1 = builder.h(q0); + * builder.scfYield(q1); + * }, [&] { + * auto q1 = builder.x(q0); + * builder.scfYield(q1); * }); * ``` * ```mlir * %q1 = scf.if %condition -> (!qco.qubit) { - * %q2 = qco.h %q0 : !qco.qubit -> !qco.qubit - * scf.yield %q2 : !qco.qubit + * %q2 = qco.h %q0 : !qco.qubit -> !qco.qubit + * scf.yield %q2 : !qco.qubit * } else { - * %q2 = qco.x %q0 : !qco.qubit -> !qco.qubit - * scf.yield %q2 : !qco.qubit + * %q2 = qco.x %q0 : !qco.qubit -> !qco.qubit + * scf.yield %q2 : !qco.qubit * } * ``` */ @@ -1175,7 +1175,7 @@ class QCOProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.funcReturn( yieldedValues); + * builder.funcReturn(yieldedValues); * ``` * ```mlir * func.return %q0 : !qco.qubit @@ -1211,10 +1211,9 @@ class QCOProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.funcFunc("test", argTypes, resultTypes, [&](OpBuilder& b, - * ValueRange args) { - * auto q1 = b.h(args[0]); - * b.funcReturn({q1}); + * builder.funcFunc("test", argTypes, resultTypes, [&](ValueRange args) { + * auto q1 = builder.h(args[0]); + * builder.funcReturn(q1); * }) * ``` * ```mlir diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 303f8bd36d..92bd0860b2 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -829,7 +829,7 @@ struct ConvertQCOYieldOp final : OpConversionPattern { * is converted to * ```mlir * scf.if %cond { - * qc.x %q0 + * qc.x %q0 : !qc.qubit * scf.yield * } * ``` @@ -880,10 +880,10 @@ struct ConvertQCOScfIfOp final : OpConversionPattern { * is converted to * ```mlir * scf.while : () -> () { - * qc.x %q0 + * qc.x %q0 : !qc.qubit * scf.condition(%cond) * } do { - * qc.x %q0 + * qc.x %q0 : !qc.qubit * scf.yield * } * ``` @@ -937,7 +937,7 @@ struct ConvertQCOScfWhileOp final : OpConversionPattern { * is converted to * ```mlir * scf.for %iv = %lb to %ub step %step { - * qc.x %q0 + * qc.x %q0 : !qc.qubit * scf.yield * } * ``` diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 005a89e7a3..1f98494c75 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1246,7 +1246,7 @@ struct ConvertQCYieldOp final : StatefulOpConversionPattern { * @par Example: * ```mlir * scf.if %cond { - * qc.x %q0 + * qc.x %q0 : !qc.qubit * scf.yield * } * ``` @@ -1335,10 +1335,10 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { * @par Example: * ```mlir * scf.while : () -> () { - * qc.x %q0 + * qc.x %q0 : !qc.qubit * scf.condition(%cond) * } do { - * qc.x %q0 + * qc.x %q0 : !qc.qubit * scf.yield * } * ``` @@ -1427,7 +1427,7 @@ struct ConvertQCScfWhileOp final : StatefulOpConversionPattern { * @par Example: * ```mlir * scf.for %iv = %lb to %ub step %step { - * qc.x %q0 + * qc.x %q0 : !qc.qubit * scf.yield * } * ``` From 6fed2462759240fdf1172c9e96c8c56ef4b1c4bc Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 10 Jan 2026 15:50:04 +0100 Subject: [PATCH 038/108] change testnames --- mlir/unittests/conversion/test_conversion.cpp | 39 +++++++++---------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index ce506e26ca..f6f2738f07 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -74,7 +74,7 @@ static std::string getOutputString(mlir::OwningOpRef& module) { return outputString; } -TEST_F(ConversionTest, ScfForTest) { +TEST_F(ConversionTest, ScfForQCToQCOTest) { // Test conversion from qc to qco for scf.for operation auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); @@ -101,11 +101,11 @@ TEST_F(ConversionTest, ScfForTest) { auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); auto scfForRes = - b.scfFor(c0, c2, c1, ValueRange{q0}, [&](Value, ValueRange iterArgs) { + b.scfFor(c0, c2, c1, {q0}, [&](Value, ValueRange iterArgs) { auto q1 = b.h(iterArgs[0]); auto q2 = b.x(q1); auto q3 = b.h(q2); - b.scfYield(ValueRange{q3}); + b.scfYield(q3); return q3; }); b.h(scfForRes[0]); @@ -117,7 +117,7 @@ TEST_F(ConversionTest, ScfForTest) { ASSERT_EQ(outputString, checkString); } -TEST_F(ConversionTest, ScfForTest2) { +TEST_F(ConversionTest, ScfForQCOToQCTest) { // Test conversion from qco to qc for scf.for operation auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); @@ -125,11 +125,11 @@ TEST_F(ConversionTest, ScfForTest2) { auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); auto scfForRes = - b.scfFor(c0, c2, c1, ValueRange{q0}, [&](Value, ValueRange iterArgs) { + b.scfFor(c0, c2, c1, {q0}, [&](Value, ValueRange iterArgs) { auto q1 = b.h(iterArgs[0]); auto q2 = b.x(q1); auto q3 = b.h(q2); - b.scfYield(ValueRange{q3}); + b.scfYield(q3); return q3; }); b.h(scfForRes[0]); @@ -160,14 +160,13 @@ TEST_F(ConversionTest, ScfForTest2) { ASSERT_EQ(outputString, checkString); } -TEST_F(ConversionTest, ScfWhileTest) { +TEST_F(ConversionTest, ScfWhileQCToQCOTest) { // Test conversion from qc to qco for scf.while operation auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); b.scfWhile( [&] { auto measureResult = b.measure(q0); - b.scfCondition(measureResult); }, [&] { @@ -189,13 +188,13 @@ TEST_F(ConversionTest, ScfWhileTest) { ValueRange{q0}, [&](ValueRange iterArgs) { auto [q1, measureResult] = b.measure(iterArgs[0]); - b.scfCondition(measureResult, ValueRange{q1}); + b.scfCondition(measureResult, q1); return q1; }, [&](ValueRange iterArgs) { auto q1 = b.h(iterArgs[0]); auto q2 = b.y(q1); - b.scfYield({q2}); + b.scfYield(q2); return q2; }); b.h(scfWhileResult[0]); @@ -207,7 +206,7 @@ TEST_F(ConversionTest, ScfWhileTest) { ASSERT_EQ(outputString, checkString); } -TEST_F(ConversionTest, ScfWhileTest2) { +TEST_F(ConversionTest, ScfWhileQCOToQCTest) { // Test conversion from qco to qc for scf.while operation auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); @@ -215,13 +214,13 @@ TEST_F(ConversionTest, ScfWhileTest2) { ValueRange{q0}, [&](ValueRange iterArgs) { auto [q1, measureResult] = b.measure(iterArgs[0]); - b.scfCondition(measureResult, ValueRange{q1}); + b.scfCondition(measureResult, q1); return q1; }, [&](ValueRange iterArgs) { auto q1 = b.h(iterArgs[0]); auto q2 = b.y(q1); - b.scfYield({q2}); + b.scfYield(q2); return q2; }); b.h(scfWhileResult[0]); @@ -253,7 +252,7 @@ TEST_F(ConversionTest, ScfWhileTest2) { ASSERT_EQ(outputString, checkString); } -TEST_F(ConversionTest, ScfIfTest) { +TEST_F(ConversionTest, ScfIfQCToQCOTest) { // Test conversion from qc to qco for scf.if operation auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); @@ -281,10 +280,9 @@ TEST_F(ConversionTest, ScfIfTest) { auto q0 = b.allocQubit(); auto [q1, measureResult] = b.measure(q0); auto scfIfResult = b.scfIf( - measureResult, ValueRange{q1}, + measureResult, {q1}, [&] { auto q2 = b.h(q1); - auto q3 = b.y(q2); b.scfYield(q3); return q3; @@ -304,16 +302,15 @@ TEST_F(ConversionTest, ScfIfTest) { ASSERT_EQ(outputString, checkString); } -TEST_F(ConversionTest, ScfIfTest2) { +TEST_F(ConversionTest, ScfIfQCOToQCTest) { // Test conversion from qco to qc for scf.if operation auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); auto [q1, measureResult] = b.measure(q0); auto scfIfResult = b.scfIf( - measureResult, ValueRange{q1}, + measureResult, {q1}, [&] { auto q2 = b.h(q1); - auto q3 = b.y(q2); b.scfYield(q3); return q3; @@ -355,7 +352,7 @@ TEST_F(ConversionTest, ScfIfTest2) { ASSERT_EQ(outputString, checkString); } -TEST_F(ConversionTest, FuncFuncTest) { +TEST_F(ConversionTest, FuncFuncQCToQCOTest) { // Test conversion from qc to qco for func.func operation auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); @@ -391,7 +388,7 @@ TEST_F(ConversionTest, FuncFuncTest) { ASSERT_EQ(outputString, checkString); } -TEST_F(ConversionTest, FuncFuncTest2) { +TEST_F(ConversionTest, FuncFuncQCOToQCTest) { // Test conversion from qco to qc for func.func operation auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); From 30cea5525e8735860301d7e5a986a07ec677c7eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 10 Jan 2026 15:02:37 +0000 Subject: [PATCH 039/108] =?UTF-8?q?=F0=9F=8E=A8=20pre-commit=20fixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mlir/unittests/conversion/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/unittests/conversion/CMakeLists.txt b/mlir/unittests/conversion/CMakeLists.txt index 833dc2371b..a90d3bc2a4 100644 --- a/mlir/unittests/conversion/CMakeLists.txt +++ b/mlir/unittests/conversion/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2023 - 2025 Chair for Design Automation, TUM -# Copyright (c) 2025 Munich Quantum Software Company GmbH +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH # All rights reserved. # # SPDX-License-Identifier: MIT From 7596d58e5068066bd5c376820c49890fbec79d86 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 10 Jan 2026 17:32:41 +0100 Subject: [PATCH 040/108] address coderabbit suggestions --- .../Dialect/QC/Builder/QCProgramBuilder.h | 14 ++--- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 1 - mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 45 ++++++++-------- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 9 ++-- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 14 ++--- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 3 +- mlir/unittests/conversion/test_conversion.cpp | 51 +++++++++++++++++-- 7 files changed, 92 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index fe116d1d0e..e046f1a475 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -892,7 +892,7 @@ class QCProgramBuilder final : public OpBuilder { * ``` */ QCProgramBuilder& scfFor(Value lowerbound, Value upperbound, Value step, - const std::function& body); + const std::function& body); /** * @brief Constructs a scf.while operation without return values @@ -910,7 +910,7 @@ class QCProgramBuilder final : public OpBuilder { * builder.condition(res); * }, [&] { * builder.x(q0); - * builder.yield(); + * builder.scfYield(); * }); * ``` * ```mlir @@ -952,9 +952,9 @@ class QCProgramBuilder final : public OpBuilder { * } * ``` */ - QCProgramBuilder& scfIf(Value condition, - const std::function& thenBody, - const std::function& elseBody = nullptr); + QCProgramBuilder& + scfIf(Value condition, const std::function& thenBody, + std::optional> elseBody = std::nullopt); /** * @brief Constructs a scf.condition operation without any additional Values @@ -964,7 +964,7 @@ class QCProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.condition(condition); + * builder.scfCondition(condition); * ``` * ```mlir * scf.condition(%condition) @@ -1008,7 +1008,7 @@ class QCProgramBuilder final : public OpBuilder { QCProgramBuilder& funcCall(StringRef name, ValueRange operands); /** - * @brief Constructs a func.func operation with return values + * @brief Constructs a func.func operation without return values * * @param name Name of the function that is called * @param argTypes TypeRange of the arguments diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 0fb4e3a15f..7edefab802 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1284,7 +1284,6 @@ class QCOProgramBuilder final : public OpBuilder { MLIRContext* ctx{}; Location loc; ModuleOp module; - Region* funcRegion; /// Check if the builder has been finalized void checkFinalized() const; diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 92bd0860b2..b139c7c9f0 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -815,7 +815,7 @@ struct ConvertQCOYieldOp final : OpConversionPattern { /** * @brief Converts scf.if with value semantics to scf.if with memory semantics - * for qubit values + * for qubit values. This currently assumes no mixed types as return values. * * @par Example: * ```mlir @@ -841,16 +841,15 @@ struct ConvertQCOScfIfOp final : OpConversionPattern { matchAndRewrite(scf::IfOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { // Create the new if operation - auto newIf = - rewriter.create(op.getLoc(), ValueRange{}, op.getCondition(), - op.getElseRegion().empty()); + auto newIf = rewriter.create(op.getLoc(), ValueRange{}, + op.getCondition(), false); // Inline the regions rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(), newIf.getThenRegion().end()); - if (!op.getElseRegion().empty()) { - rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(), - newIf.getElseRegion().end()); - } + + rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(), + newIf.getElseRegion().end()); + // Erase the empty block that was created during the initialization rewriter.eraseBlock(&newIf.getThenRegion().front()); @@ -864,7 +863,8 @@ struct ConvertQCOScfIfOp final : OpConversionPattern { /** * @brief Converts scf.while with value semantics to scf.while with memory - * semantics for qubit values + * semantics for qubit values. This currently assumes no mixed types as return + * values. * * @par Example: * ```mlir @@ -906,6 +906,7 @@ struct ConvertQCOScfWhileOp final : OpConversionPattern { beforeArgs[i].replaceAllUsesWith(inits[i]); afterArgs[i].replaceAllUsesWith(inits[i]); } + // Create the blocks of the new operation and move the operations to them auto* newBeforeBlock = rewriter.createBlock(&newWhileOp.getBefore(), {}, {}, {}); @@ -915,7 +916,7 @@ struct ConvertQCOScfWhileOp final : OpConversionPattern { op.getBeforeBody()->getOperations()); newAfterBlock->getOperations().splice(newAfterBlock->end(), op.getAfterBody()->getOperations()); - + llvm::outs() << newWhileOp.getBefore().getBlocks().size() << "\n"; // replace the result values with the init values rewriter.replaceOp(op, inits); return success(); @@ -924,7 +925,8 @@ struct ConvertQCOScfWhileOp final : OpConversionPattern { /** * @brief Converts scf.for with value semantics to scf.while with memory - * semantics for qubit values + * semantics for qubit values. This currently assumes no mixed types as return + * values. * * @par Example: * ```mlir @@ -974,7 +976,8 @@ struct ConvertQCOScfForOp final : OpConversionPattern { /** * @brief Converts scf.yield with value semantics to scf.yield with memory - * semantics for qubit values + * semantics for qubit values. This currently assumes no mixed types as yielded + * values. * * @par Example: * ```mlir @@ -998,7 +1001,8 @@ struct ConvertQCOScfYieldOp final : OpConversionPattern { /** * @brief Converts scf.condition with value semantics to scf.condition with - * memory semantics for qubit values + * memory semantics for qubit values. This currently assumes no mixed types as + * target values. * * @par Example: * ```mlir @@ -1025,7 +1029,8 @@ struct ConvertQCOScfConditionOp final : OpConversionPattern { /** * @brief Converts func.call with value semantics to func.call with - * memory semantics for qubit values + * memory semantics for qubit values. This currently assumes no mixed types as + * parameters/return values. * * @par Example: * ```mlir @@ -1052,7 +1057,8 @@ struct ConvertQCOFuncCallOp final : OpConversionPattern { /** * @brief Converts func.func with memory semantics to func.func with - * value semantics for qubit values + * value semantics for qubit values. This currently assumes no mixed types as + * parameters/return values. * * @par Example: * ```mlir @@ -1087,7 +1093,8 @@ struct ConvertQCOFuncFuncOp final : OpConversionPattern { /** * @brief Converts func.return with value semantics to func.return with - * memory semantics for qubit values + * memory semantics for qubit values. This currently assumes no mixed types as + * target values. * * @par Example: * ```mlir @@ -1154,8 +1161,7 @@ struct QCOToQC final : impl::QCOToQCBase { target.addDynamicallyLegalOp([&](scf::YieldOp op) { return !llvm::any_of(op->getOperandTypes(), [&](Type type) { - return type == qc::QubitType::get(context) || - type == qco::QubitType::get(context); + return type == qco::QubitType::get(context); }); }); target.addDynamicallyLegalOp([&](scf::WhileOp op) { @@ -1185,8 +1191,7 @@ struct QCOToQC final : impl::QCOToQCBase { }); target.addDynamicallyLegalOp([&](func::ReturnOp op) { return !llvm::any_of(op->getOperandTypes(), [&](Type type) { - return type == qc::QubitType::get(context) || - type == qco::QubitType::get(context); + return type == qco::QubitType::get(context); }); }); diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 1f98494c75..bfe0459135 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1276,7 +1276,8 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { // create new if operation auto newIfOp = rewriter.create(op->getLoc(), TypeRange{qcoTypes}, - op.getCondition(), true); + op.getCondition(), + op.getElseRegion().empty()); auto& thenRegion = newIfOp.getThenRegion(); auto& elseRegion = newIfOp.getElseRegion(); @@ -1289,7 +1290,6 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { if (!op.getElseRegion().empty()) { rewriter.inlineRegionBefore(op.getElseRegion(), elseRegion, elseRegion.end()); - rewriter.eraseBlock(&elseRegion.front()); } else { // create the yield operation if it does not exist yet rewriter.setInsertionPointToEnd(&elseRegion.front()); @@ -1330,7 +1330,7 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { /** * @brief Converts scf.while with memory semantics to scf.while with value - * semantics for qubit values + * semantics for qubit values. * * @par Example: * ```mlir @@ -1521,6 +1521,7 @@ struct ConvertQCScfYieldOp final : StatefulOpConversionPattern { SmallVector qcoQubits; qcoQubits.reserve(orderedQubits.size()); for (const auto& qcQubit : orderedQubits) { + assert(qubitMap.contains(qcQubit) && "QC qubit not found"); qcoQubits.push_back(qubitMap.lookup(qcQubit)); } @@ -1558,6 +1559,7 @@ struct ConvertQCScfConditionOp final SmallVector qcoQubits; qcoQubits.reserve(orderedQubits.size()); for (const auto& qcQubit : orderedQubits) { + assert(qubitMap.contains(qcQubit) && "QC qubit not found"); qcoQubits.push_back(qubitMap.lookup(qcQubit)); } @@ -1683,6 +1685,7 @@ struct ConvertQCFuncReturnOp final SmallVector qcoQubits; qcoQubits.reserve(orderedQubits.size()); for (const auto& qcQubit : orderedQubits) { + assert(qubitMap.contains(qcQubit) && "QC qubit not found"); qcoQubits.push_back(qubitMap.lookup(qcQubit)); } rewriter.replaceOpWithNewOp(op, qcoQubits); diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 0cbdb28c66..c1e18e4251 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -448,12 +448,12 @@ QCProgramBuilder& QCProgramBuilder::dealloc(Value qubit) { // SCF operations //===----------------------------------------------------------------------===// -QCProgramBuilder& QCProgramBuilder::scfFor(Value lowerbound, Value upperbound, - Value step, - const std::function& body) { +QCProgramBuilder& +QCProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, + const std::function& body) { create(loc, lowerbound, upperbound, step, ValueRange{}, - [&](OpBuilder& b, Location, Value, ValueRange) { - body(); + [&](OpBuilder& b, Location, Value iv, ValueRange) { + body(iv); b.create(loc); }); @@ -476,7 +476,7 @@ QCProgramBuilder::scfWhile(const std::function& beforeBody, QCProgramBuilder& QCProgramBuilder::scfIf(Value cond, const std::function& thenBody, - const std::function& elseBody) { + std::optional> elseBody) { if (!elseBody) { create(loc, cond, [&](OpBuilder& b, Location loc) { thenBody(); @@ -490,7 +490,7 @@ QCProgramBuilder::scfIf(Value cond, const std::function& thenBody, b.create(loc); }, [&](OpBuilder& b, Location loc) { - elseBody(); + (*elseBody)(); b.create(loc); }); } diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index b1e1eea417..8096b9ae0c 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -37,7 +37,7 @@ namespace mlir::qco { QCOProgramBuilder::QCOProgramBuilder(MLIRContext* context) : OpBuilder(context), ctx(context), loc(getUnknownLoc()), - module(ModuleOp::create(loc)), funcRegion(nullptr) { + module(ModuleOp::create(loc)) { ctx->loadDialect(); } @@ -52,7 +52,6 @@ void QCOProgramBuilder::initialize() { // Add entry_point attribute to identify the main function auto entryPointAttr = getStringAttr("entry_point"); mainFunc->setAttr("passthrough", getArrayAttr({entryPointAttr})); - funcRegion = &mainFunc->getRegion(0); // Create entry block and set insertion point auto& entryBlock = mainFunc.getBody().emplaceBlock(); setInsertionPointToStart(&entryBlock); diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index f6f2738f07..0eb91b121d 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #include #include @@ -81,7 +80,7 @@ TEST_F(ConversionTest, ScfForQCToQCOTest) { auto c0 = b.arithConstantIndex(0); auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); - b.scfFor(c0, c2, c1, [&] { + b.scfFor(c0, c2, c1, [&](Value /*iv*/) { b.h(q0); b.x(q0); b.h(q0); @@ -101,7 +100,7 @@ TEST_F(ConversionTest, ScfForQCToQCOTest) { auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); auto scfForRes = - b.scfFor(c0, c2, c1, {q0}, [&](Value, ValueRange iterArgs) { + b.scfFor(c0, c2, c1, {q0}, [&](Value /*iv*/, ValueRange iterArgs) { auto q1 = b.h(iterArgs[0]); auto q2 = b.x(q1); auto q3 = b.h(q2); @@ -125,7 +124,7 @@ TEST_F(ConversionTest, ScfForQCOToQCTest) { auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); auto scfForRes = - b.scfFor(c0, c2, c1, {q0}, [&](Value, ValueRange iterArgs) { + b.scfFor(c0, c2, c1, {q0}, [&](Value /*iv*/, ValueRange iterArgs) { auto q1 = b.h(iterArgs[0]); auto q2 = b.x(q1); auto q3 = b.h(q2); @@ -146,7 +145,7 @@ TEST_F(ConversionTest, ScfForQCOToQCTest) { auto c0 = b.arithConstantIndex(0); auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); - b.scfFor(c0, c2, c1, [&] { + b.scfFor(c0, c2, c1, [&](Value /*iv*/) { b.h(q0); b.x(q0); b.h(q0); @@ -352,6 +351,48 @@ TEST_F(ConversionTest, ScfIfQCOToQCTest) { ASSERT_EQ(outputString, checkString); } +TEST_F(ConversionTest, ScfIfEmptyElseTest) { + // Test conversion from qc to qco for scf.if operation without an else body + auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto measure = b.measure(q0); + b.scfIf(measure, [&] { + b.h(q0); + b.y(q0); + }); + b.h(q0); + }); + + PassManager pm(context.get()); + pm.addPass(createQCToQCO()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QC-QCO conversion for scf.if"; + } + + auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto [q1, measureResult] = b.measure(q0); + auto scfIfResult = b.scfIf( + measureResult, {q1}, + [&] { + auto q2 = b.h(q1); + auto q3 = b.y(q2); + b.scfYield(q3); + return q3; + }, + [&] { + b.scfYield(q1); + return q1; + }); + b.h(scfIfResult[0]); + }); + + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); + + ASSERT_EQ(outputString, checkString); +} + TEST_F(ConversionTest, FuncFuncQCToQCOTest) { // Test conversion from qc to qco for func.func operation auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { From b52fa7b876845d0ad6f77ce21c578e4a8dc33666 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 10 Jan 2026 17:36:05 +0100 Subject: [PATCH 041/108] add explanation for nested qubit tracking --- mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 7edefab802..5e302eb835 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1303,11 +1303,13 @@ class QCOProgramBuilder final : public OpBuilder { * @brief Update tracking when an operation consumes and produces a qubit * @param inputQubit Input qubit being consumed (must be valid) * @param outputQubit New output qubit being produced + * @param region The Region in where the qubits are defined. */ void updateQubitTracking(Value inputQubit, Value outputQubit, Region* region); /// Track valid (unconsumed) qubit SSA values for linear type enforcement. /// Only values present in this set are valid for use in operations. + /// Each Region has its own set of valid qubits. /// When an operation consumes a qubit and produces a new one, the old value /// is removed and the new output is added. llvm::DenseMap> validQubits; From 18487bc8214dbdc5c15d4d33244eeb7cf579963d Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 10 Jan 2026 17:44:36 +0100 Subject: [PATCH 042/108] fix linter issues --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 1 - mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index b139c7c9f0..8bbfa44657 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -916,7 +916,6 @@ struct ConvertQCOScfWhileOp final : OpConversionPattern { op.getBeforeBody()->getOperations()); newAfterBlock->getOperations().splice(newAfterBlock->end(), op.getAfterBody()->getOperations()); - llvm::outs() << newWhileOp.getBefore().getBlocks().size() << "\n"; // replace the result values with the init values rewriter.replaceOp(op, inits); return success(); diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index c1e18e4251..17cdc319d0 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include From b4da26f5d4f2e0be4db7dd242d28559de1802e20 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 10 Jan 2026 18:06:54 +0100 Subject: [PATCH 043/108] add checkFinalized() calls --- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 18 +++++++++++++++++ .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 20 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 17cdc319d0..19f3c8bd27 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -452,6 +452,8 @@ QCProgramBuilder& QCProgramBuilder::dealloc(Value qubit) { QCProgramBuilder& QCProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, const std::function& body) { + checkFinalized(); + create(loc, lowerbound, upperbound, step, ValueRange{}, [&](OpBuilder& b, Location, Value iv, ValueRange) { body(iv); @@ -464,6 +466,8 @@ QCProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, QCProgramBuilder& QCProgramBuilder::scfWhile(const std::function& beforeBody, const std::function& afterBody) { + checkFinalized(); + create( loc, TypeRange{}, ValueRange{}, [&](OpBuilder& /*b*/, Location, ValueRange) { beforeBody(); }, @@ -478,6 +482,8 @@ QCProgramBuilder::scfWhile(const std::function& beforeBody, QCProgramBuilder& QCProgramBuilder::scfIf(Value cond, const std::function& thenBody, std::optional> elseBody) { + checkFinalized(); + if (!elseBody) { create(loc, cond, [&](OpBuilder& b, Location loc) { thenBody(); @@ -499,6 +505,8 @@ QCProgramBuilder::scfIf(Value cond, const std::function& thenBody, } QCProgramBuilder& QCProgramBuilder::scfCondition(Value condition) { + checkFinalized(); + create(loc, condition, ValueRange{}); return *this; } @@ -509,11 +517,15 @@ QCProgramBuilder& QCProgramBuilder::scfCondition(Value condition) { QCProgramBuilder& QCProgramBuilder::funcCall(StringRef name, ValueRange operands) { + checkFinalized(); + create(loc, name, TypeRange{}, operands); return *this; } QCProgramBuilder& QCProgramBuilder::funcReturn() { + checkFinalized(); + create(loc); return *this; } @@ -521,6 +533,8 @@ QCProgramBuilder& QCProgramBuilder::funcReturn() { QCProgramBuilder& QCProgramBuilder::funcFunc(StringRef name, TypeRange argTypes, const std::function& body) { + checkFinalized(); + // Set the insertionPoint const OpBuilder::InsertionGuard guard(*this); setInsertionPointToEnd(module.getBody()); @@ -543,12 +557,16 @@ QCProgramBuilder::funcFunc(StringRef name, TypeRange argTypes, //===----------------------------------------------------------------------===// Value QCProgramBuilder::arithConstantIndex(int64_t index) { + checkFinalized(); + const auto op = create(loc, getIndexType(), getIndexAttr(index)); return op->getResult(0); } Value QCProgramBuilder::arithConstantBool(bool b) { + checkFinalized(); + const auto i1Type = getI1Type(); const auto op = create(loc, i1Type, getIntegerAttr(i1Type, b ? 1 : 0)); diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 8096b9ae0c..2b7eac3292 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -603,6 +603,8 @@ QCOProgramBuilder& QCOProgramBuilder::dealloc(Value qubit) { ValueRange QCOProgramBuilder::scfFor( Value lowerbound, Value upperbound, Value step, ValueRange initArgs, const std::function& body) { + checkFinalized(); + // Create the empty for operation auto forOp = create(loc, lowerbound, upperbound, step, initArgs); auto* forBody = forOp.getBody(); @@ -634,6 +636,8 @@ ValueRange QCOProgramBuilder::scfWhile( ValueRange initArgs, const std::function& beforeBody, const std::function& afterBody) { + checkFinalized(); + // Create the empty while operation auto whileOp = create(loc, initArgs.getTypes(), initArgs); const SmallVector locs(initArgs.size(), loc); @@ -685,6 +689,8 @@ ValueRange QCOProgramBuilder::scfIf(Value condition, ValueRange qubits, const std::function& thenBody, const std::function& elseBody) { + checkFinalized(); + // Create the empty while operation auto ifOp = create(loc, qubits.getTypes(), condition, /*withElseRegion=*/true); @@ -722,11 +728,15 @@ QCOProgramBuilder::scfIf(Value condition, ValueRange qubits, QCOProgramBuilder& QCOProgramBuilder::scfCondition(Value condition, ValueRange yieldedValues) { + checkFinalized(); + create(loc, condition, yieldedValues); return *this; } QCOProgramBuilder& QCOProgramBuilder::scfYield(ValueRange yieldedValues) { + checkFinalized(); + create(loc, yieldedValues); return *this; } @@ -736,6 +746,8 @@ QCOProgramBuilder& QCOProgramBuilder::scfYield(ValueRange yieldedValues) { //===----------------------------------------------------------------------===// ValueRange QCOProgramBuilder::funcCall(StringRef name, ValueRange operands) { + checkFinalized(); + const auto callOp = create(loc, name, operands.getTypes(), operands); for (auto [arg, result] : llvm::zip_equal(operands, callOp->getResults())) { @@ -745,6 +757,8 @@ ValueRange QCOProgramBuilder::funcCall(StringRef name, ValueRange operands) { } QCOProgramBuilder& QCOProgramBuilder::funcReturn(ValueRange returnValues) { + checkFinalized(); + create(loc, returnValues); return *this; } @@ -752,6 +766,8 @@ QCOProgramBuilder& QCOProgramBuilder::funcFunc(StringRef name, TypeRange argTypes, TypeRange resultTypes, const std::function& body) { + checkFinalized(); + // Set the insertionPoint const OpBuilder::InsertionGuard guard(*this); setInsertionPointToEnd(module.getBody()); @@ -779,12 +795,16 @@ QCOProgramBuilder::funcFunc(StringRef name, TypeRange argTypes, //===----------------------------------------------------------------------===// Value QCOProgramBuilder::arithConstantIndex(int64_t i) { + checkFinalized(); + const auto op = create(loc, getIndexType(), getIndexAttr(i)); return op->getResult(0); } Value QCOProgramBuilder::arithConstantBool(bool b) { + checkFinalized(); + const auto i1Type = getI1Type(); const auto op = create(loc, i1Type, getIntegerAttr(i1Type, b ? 1 : 0)); From d45f4488e5bb5b97a2035dcf4957ebb856cc92cc Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 10 Jan 2026 18:24:41 +0100 Subject: [PATCH 044/108] add headers --- mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h | 1 + mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index e046f1a475..88edc5041e 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 5e302eb835..63a415a6bd 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -14,6 +14,7 @@ #include #include +#include #include #include #include @@ -1052,7 +1053,7 @@ class QCOProgramBuilder final : public OpBuilder { * ```mlir * %q1 = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %q0) * -> !qco.qubit { - * %q2 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * %q2 = qco.x %arg0 : !qco.qubit -> !qco.qubit * scf.yield %q2 : !qco.qubit * } * ``` @@ -1074,7 +1075,7 @@ class QCOProgramBuilder final : public OpBuilder { * builder.scfWhile(args, [&](ValueRange iterArgs) { * auto q1 = builder.h(iterArgs[0]); * auto [q2, measureRes] = builder.measure(q1); - * builder.condition(measureRes, q2); + * builder.scfCondition(measureRes, q2); * }, [&](ValueRange iterArgs) { * auto q1 = builder.x(iterArgs[0]); * builder.scfYield(q1); From dbe805d9f8bb6b57c685251eba2fcf89b993eaf2 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 10 Jan 2026 18:44:38 +0100 Subject: [PATCH 045/108] add additional assertions and fix builder parameters --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 26 +++++++----- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 55 ++++++++++++++++++------- 2 files changed, 56 insertions(+), 25 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 8bbfa44657..64061d30d1 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -841,7 +841,7 @@ struct ConvertQCOScfIfOp final : OpConversionPattern { matchAndRewrite(scf::IfOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { // Create the new if operation - auto newIf = rewriter.create(op.getLoc(), ValueRange{}, + auto newIf = rewriter.create(op.getLoc(), TypeRange{}, op.getCondition(), false); // Inline the regions rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(), @@ -896,7 +896,7 @@ struct ConvertQCOScfWhileOp final : OpConversionPattern { ConversionPatternRewriter& rewriter) const override { // Create the new while operation auto newWhileOp = - rewriter.create(op->getLoc(), ValueRange{}, ValueRange{}); + rewriter.create(op->getLoc(), TypeRange{}, ValueRange{}); // Replace the uses of the blockarguments with the init values const auto& inits = adaptor.getInits(); @@ -959,6 +959,7 @@ struct ConvertQCOScfForOp final : OpConversionPattern { llvm::zip_equal(op.getRegionIterArgs(), adaptor.getInitArgs())) { qcoQubit.replaceAllUsesWith(qcQubit); } + rewriter.replaceAllUsesWith(op.getInductionVar(), newFor.getInductionVar()); // Move all the operations from the old block to the new block auto* newBlock = newFor.getBody(); @@ -966,7 +967,7 @@ struct ConvertQCOScfForOp final : OpConversionPattern { rewriter.eraseOp(newBlock->getTerminator()); newBlock->getOperations().splice(newBlock->end(), op.getBody()->getOperations()); - rewriter.replaceAllUsesWith(op.getInductionVar(), newFor.getInductionVar()); + // Replace the result values with the init values rewriter.replaceOp(op, adaptor.getInitArgs()); return success(); @@ -1078,14 +1079,16 @@ struct ConvertQCOFuncFuncOp final : OpConversionPattern { LogicalResult matchAndRewrite(func::FuncOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - const SmallVector argumentTypes( - op.front().getNumArguments(), - qc::QubitType::get(rewriter.getContext())); - for (auto blockArg : op.front().getArguments()) { - blockArg.setType(qc::QubitType::get(rewriter.getContext())); - } - auto newFuncType = rewriter.getFunctionType(argumentTypes, {}); - op.setFunctionType(newFuncType); + rewriter.modifyOpInPlace(op, [&] { + const SmallVector argumentTypes( + op.front().getNumArguments(), + qc::QubitType::get(rewriter.getContext())); + for (auto blockArg : op.front().getArguments()) { + blockArg.setType(qc::QubitType::get(rewriter.getContext())); + } + auto newFuncType = rewriter.getFunctionType(argumentTypes, {}); + op.setFunctionType(newFuncType); + }); return success(); } }; @@ -1114,6 +1117,7 @@ struct ConvertQCOFuncReturnOp final : OpConversionPattern { return success(); } }; + /** * @brief Pass implementation for QCO-to-QC conversion * diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index bfe0459135..7b70699c22 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -135,6 +135,9 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { if (region.empty()) { continue; } + // check that the region has only one block + assert(region.hasOneBlock() && "Expected single-block region"); + // iterate over all operations inside the region // currently assumes that each region only has one block for (auto& operation : region.front().getOperations()) { @@ -1513,6 +1516,11 @@ struct ConvertQCScfYieldOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { + auto const qcType = qc::QubitType::get(rewriter.getContext()); + assert(llvm::all_of(op.getOperandTypes(), + [&](Type type) { return type == qcType; }) && + "Not all operands are qc qubits"); + const auto& parentRegion = op->getParentRegion(); const auto& qubitMap = getState().qubitMap[parentRegion]; const auto& orderedQubits = @@ -1551,6 +1559,7 @@ struct ConvertQCScfConditionOp final LogicalResult matchAndRewrite(scf::ConditionOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { + const auto& parentRegion = op->getParentRegion(); const auto& qubitMap = getState().qubitMap[parentRegion]; const auto& orderedQubits = @@ -1589,6 +1598,11 @@ struct ConvertQCFuncCallOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(func::CallOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + auto const qcType = qc::QubitType::get(rewriter.getContext()); + assert(llvm::all_of(op.getOperandTypes(), + [&](Type type) { return type == qcType; }) && + "Not all operands are qc qubits"); + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; auto qcQubits = op->getOperands(); @@ -1637,21 +1651,28 @@ struct ConvertQCFuncFuncOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(func::FuncOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto& qubitMap = getState().qubitMap[&op->getRegion(0)]; - const SmallVector qcoTypes( - op.front().getNumArguments(), - qco::QubitType::get(rewriter.getContext())); - - // set the arguments to qco qubit type - for (auto blockArg : op.front().getArguments()) { - blockArg.setType(qco::QubitType::get(rewriter.getContext())); - qubitMap.try_emplace(blockArg, blockArg); - } + auto const qcType = qc::QubitType::get(rewriter.getContext()); + assert(llvm::all_of(op->getOperandTypes(), + [&](Type type) { return type == qcType; }) && + "Not all operands are qc qubits"); + + rewriter.modifyOpInPlace(op, [&] { + auto& qubitMap = getState().qubitMap[&op->getRegion(0)]; + const SmallVector qcoTypes( + op.front().getNumArguments(), + qco::QubitType::get(rewriter.getContext())); + + // set the arguments to qco qubit type + for (auto blockArg : op.front().getArguments()) { + blockArg.setType(qco::QubitType::get(rewriter.getContext())); + qubitMap.try_emplace(blockArg, blockArg); + } - // change the function signature to return the same number of qco Qubits as - // it gets as input - auto newFuncType = rewriter.getFunctionType(qcoTypes, qcoTypes); // - op.setFunctionType(newFuncType); + // change the function signature to return the same number of qco Qubits + // as it gets as input + auto newFuncType = rewriter.getFunctionType(qcoTypes, qcoTypes); // + op.setFunctionType(newFuncType); + }); return success(); } }; @@ -1677,6 +1698,11 @@ struct ConvertQCFuncReturnOp final LogicalResult matchAndRewrite(func::ReturnOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { + auto const qcType = qc::QubitType::get(rewriter.getContext()); + assert(llvm::all_of(op.getOperandTypes(), + [&](Type type) { return type == qcType; }) && + "Not all operands are qc qubits"); + const auto& parentRegion = op->getParentRegion(); const auto& qubitMap = getState().qubitMap[parentRegion]; const auto& orderedQubits = @@ -1692,6 +1718,7 @@ struct ConvertQCFuncReturnOp final return success(); } }; + /** * @brief Pass implementation for QC-to-QCO conversion * From c4a1d3568dcac2b3f7cfa919580aaf21889fca34 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 10 Jan 2026 19:22:48 +0100 Subject: [PATCH 046/108] fix more wrong docstrings --- mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h | 5 ++--- mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h | 6 +++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index 88edc5041e..f413ab45f4 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -884,7 +884,7 @@ class QCProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scfFor(lb, ub, step, [&] { builder.x(q0); }); + * builder.scfFor(lb, ub, step, [&](Value iv) { builder.x(q0); }); * ``` * ```mlir * scf.for %iv = %lb to %ub step %step { @@ -908,10 +908,9 @@ class QCProgramBuilder final : public OpBuilder { * builder.scfWhile([&] { * builder.h(q0); * auto res = builder.measure(q0); - * builder.condition(res); + * builder.scfCondition(res); * }, [&] { * builder.x(q0); - * builder.scfYield(); * }); * ``` * ```mlir diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 63a415a6bd..1fddf0c901 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1045,8 +1045,8 @@ class QCOProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scfFor(lb, ub, step, initArgs, [&] { - * auto q1 = builder.x(initArgs[0]); + * builder.scfFor(lb, ub, step, initArgs, [&](Value iv, ValueRange iterArgs) { + * auto q1 = builder.x(iterArgs[0]); * builder.scfYield(q1); * }); * ``` @@ -1298,7 +1298,7 @@ class QCOProgramBuilder final : public OpBuilder { * @param qubit Qubit value to validate * @throws Aborts if qubit is not tracked (consumed or never created) */ - void validateQubitValue(Value qubit); + void validateQubitValue(Value qubit, Region* region) const; /** * @brief Update tracking when an operation consumes and produces a qubit From 9866d2abdcc5e73bb6ee7b024bf3e4e8f3abef1a Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 10 Jan 2026 19:23:31 +0100 Subject: [PATCH 047/108] add additional operation legality checks --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 64061d30d1..522d2cbc16 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -1164,7 +1164,8 @@ struct QCOToQC final : impl::QCOToQCBase { target.addDynamicallyLegalOp([&](scf::YieldOp op) { return !llvm::any_of(op->getOperandTypes(), [&](Type type) { - return type == qco::QubitType::get(context); + return type == qco::QubitType::get(context) || + type == qc::QubitType::get(context); }); }); target.addDynamicallyLegalOp([&](scf::WhileOp op) { @@ -1174,7 +1175,8 @@ struct QCOToQC final : impl::QCOToQCBase { }); target.addDynamicallyLegalOp([&](scf::ConditionOp op) { return !llvm::any_of(op.getOperandTypes(), [&](Type type) { - return type == qco::QubitType::get(context); + return type == qco::QubitType::get(context) || + type == qc::QubitType::get(context); }); }); target.addDynamicallyLegalOp([&](scf::ForOp op) { @@ -1194,7 +1196,8 @@ struct QCOToQC final : impl::QCOToQCBase { }); target.addDynamicallyLegalOp([&](func::ReturnOp op) { return !llvm::any_of(op->getOperandTypes(), [&](Type type) { - return type == qco::QubitType::get(context); + return type == qco::QubitType::get(context) || + type == qc::QubitType::get(context); }); }); From b36f5ab7c2602d2bbceb991e64bbd841b0a11963 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 10 Jan 2026 19:29:18 +0100 Subject: [PATCH 048/108] add insertionGuards in QCBuilders --- mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 19f3c8bd27..68495d388b 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -456,6 +456,8 @@ QCProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, create(loc, lowerbound, upperbound, step, ValueRange{}, [&](OpBuilder& b, Location, Value iv, ValueRange) { + const OpBuilder::InsertionGuard guard(*this); + setInsertionPointToStart(b.getInsertionBlock()); body(iv); b.create(loc); }); @@ -470,8 +472,14 @@ QCProgramBuilder::scfWhile(const std::function& beforeBody, create( loc, TypeRange{}, ValueRange{}, - [&](OpBuilder& /*b*/, Location, ValueRange) { beforeBody(); }, + [&](OpBuilder& b, Location, ValueRange) { + const OpBuilder::InsertionGuard guard(*this); + setInsertionPointToStart(b.getInsertionBlock()); + beforeBody(); + }, [&](OpBuilder& b, Location loc, ValueRange) { + const OpBuilder::InsertionGuard guard(*this); + setInsertionPointToStart(b.getInsertionBlock()); afterBody(); b.create(loc); }); @@ -486,6 +494,8 @@ QCProgramBuilder::scfIf(Value cond, const std::function& thenBody, if (!elseBody) { create(loc, cond, [&](OpBuilder& b, Location loc) { + const OpBuilder::InsertionGuard guard(*this); + setInsertionPointToStart(b.getInsertionBlock()); thenBody(); b.create(loc); }); @@ -493,10 +503,14 @@ QCProgramBuilder::scfIf(Value cond, const std::function& thenBody, create( loc, cond, [&](OpBuilder& b, Location loc) { + const OpBuilder::InsertionGuard guard(*this); + setInsertionPointToStart(b.getInsertionBlock()); thenBody(); b.create(loc); }, [&](OpBuilder& b, Location loc) { + const OpBuilder::InsertionGuard guard(*this); + setInsertionPointToStart(b.getInsertionBlock()); (*elseBody)(); b.create(loc); }); From e2680892e86c27fc3934d47347d7b8589fa50f87 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 12 Jan 2026 12:27:09 +0100 Subject: [PATCH 049/108] smaller fixes --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 2 +- mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 7b70699c22..f4f1406558 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1652,7 +1652,7 @@ struct ConvertQCFuncFuncOp final : StatefulOpConversionPattern { matchAndRewrite(func::FuncOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto const qcType = qc::QubitType::get(rewriter.getContext()); - assert(llvm::all_of(op->getOperandTypes(), + assert(llvm::all_of(op.getArgumentTypes(), [&](Type type) { return type == qcType; }) && "Not all operands are qc qubits"); diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 2b7eac3292..a195246be4 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -128,8 +128,10 @@ QCOProgramBuilder::allocClassicalBitRegister(const int64_t size, // Linear Type Tracking Helpers //===----------------------------------------------------------------------===// -void QCOProgramBuilder::validateQubitValue(Value qubit) { - if (!validQubits[qubit.getParentRegion()].contains(qubit)) { +void QCOProgramBuilder::validateQubitValue(Value qubit, Region* region) const { + auto qubits = validQubits.lookup(region); + + if (qubits.empty() || !qubits.contains(qubit)) { llvm::errs() << "Attempting to use an invalid qubit SSA value. " << "The value may have been consumed by a previous operation " << "or was never created through this builder.\n"; @@ -141,7 +143,7 @@ void QCOProgramBuilder::validateQubitValue(Value qubit) { void QCOProgramBuilder::updateQubitTracking(Value inputQubit, Value outputQubit, Region* region) { // Validate the input qubit - validateQubitValue(inputQubit); + validateQubitValue(inputQubit, region); // Remove the input (consumed) value from tracking validQubits[region].erase(inputQubit); // Add the output (new) value to tracking @@ -588,7 +590,7 @@ QCOProgramBuilder::ctrl(ValueRange controls, ValueRange targets, QCOProgramBuilder& QCOProgramBuilder::dealloc(Value qubit) { checkFinalized(); - validateQubitValue(qubit); + validateQubitValue(qubit, qubit.getParentRegion()); validQubits[qubit.getParentRegion()].erase(qubit); DeallocOp::create(*this, loc, qubit); From 4ea6c50a942fea90492c169d9e68fc06e25a3de3 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 12 Jan 2026 13:26:39 +0100 Subject: [PATCH 050/108] update scf builder signatures --- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 36 ++++++---- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 14 ++-- mlir/unittests/conversion/test_conversion.cpp | 71 ++++++++++--------- 3 files changed, 65 insertions(+), 56 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index f42c528998..d3ba1ad44e 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1046,9 +1046,10 @@ class QCOProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scfFor(lb, ub, step, initArgs, [&](Value iv, ValueRange iterArgs) { - * auto q1 = builder.x(iterArgs[0]); + * builder.scfFor(lb, ub, step, initArgs, [&](Value iv, ValueRange iterArgs) + * -> llvm::SmallVector { auto q1 = builder.x(iterArgs[0]); * builder.scfYield(q1); + * return {q1}; * }); * ``` * ```mlir @@ -1059,9 +1060,9 @@ class QCOProgramBuilder final : public OpBuilder { * } * ``` */ - ValueRange scfFor(Value lowerbound, Value upperbound, Value step, - ValueRange initArgs, - const std::function& body); + ValueRange + scfFor(Value lowerbound, Value upperbound, Value step, ValueRange initArgs, + llvm::function_ref(Value, ValueRange)> body); /** * @brief Constructs a scf.while operation with return values * @@ -1073,13 +1074,15 @@ class QCOProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scfWhile(args, [&](ValueRange iterArgs) { - * auto q1 = builder.h(iterArgs[0]); + * builder.scfWhile(args, [&](ValueRange iterArgs) -> llvm::SmallVector + * { auto q1 = builder.h(iterArgs[0]); * auto [q2, measureRes] = builder.measure(q1); * builder.scfCondition(measureRes, q2); - * }, [&](ValueRange iterArgs) { + * return {q2}; + * }, [&](ValueRange iterArgs) -> llvm::SmallVector { * auto q1 = builder.x(iterArgs[0]); * builder.scfYield(q1); + * return {q1}; * }); * ``` * ```mlir @@ -1094,9 +1097,10 @@ class QCOProgramBuilder final : public OpBuilder { * } * ``` */ - ValueRange scfWhile(ValueRange args, - const std::function& beforeBody, - const std::function& afterBody); + ValueRange + scfWhile(ValueRange args, + llvm::function_ref(ValueRange)> beforeBody, + llvm::function_ref(ValueRange)> afterBody); /** * @brief Constructs a scf.if operation with return values @@ -1110,12 +1114,14 @@ class QCOProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.scfIf(condition, qubits, [&] { + * builder.scfIf(condition, qubits, [&]() -> llvm::SmallVector { * auto q1 = builder.h(q0); * builder.scfYield(q1); - * }, [&] { + * return {q1}; + * }, [&]() -> llvm::SmallVector { * auto q1 = builder.x(q0); * builder.scfYield(q1); + * return {q1}; * }); * ``` * ```mlir @@ -1129,8 +1135,8 @@ class QCOProgramBuilder final : public OpBuilder { * ``` */ ValueRange scfIf(Value condition, ValueRange qubits, - const std::function& thenBody, - const std::function& elseBody); + llvm::function_ref()> thenBody, + llvm::function_ref()> elseBody); /** * @brief Constructs a scf.condition operation with yielded values diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 5e32604eb5..ff1f26b60d 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -618,7 +618,7 @@ QCOProgramBuilder& QCOProgramBuilder::dealloc(Value qubit) { ValueRange QCOProgramBuilder::scfFor( Value lowerbound, Value upperbound, Value step, ValueRange initArgs, - const std::function& body) { + llvm::function_ref(Value, ValueRange)> body) { checkFinalized(); // Create the empty for operation @@ -650,8 +650,8 @@ ValueRange QCOProgramBuilder::scfFor( ValueRange QCOProgramBuilder::scfWhile( ValueRange initArgs, - const std::function& beforeBody, - const std::function& afterBody) { + llvm::function_ref(ValueRange)> beforeBody, + llvm::function_ref(ValueRange)> afterBody) { checkFinalized(); // Create the empty while operation @@ -701,10 +701,10 @@ ValueRange QCOProgramBuilder::scfWhile( return whileOp->getResults(); } -ValueRange -QCOProgramBuilder::scfIf(Value condition, ValueRange qubits, - const std::function& thenBody, - const std::function& elseBody) { +ValueRange QCOProgramBuilder::scfIf( + Value condition, ValueRange qubits, + llvm::function_ref()> thenBody, + llvm::function_ref()> elseBody) { checkFinalized(); // Create the empty while operation diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index 0eb91b121d..a75e842a5c 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -100,13 +101,14 @@ TEST_F(ConversionTest, ScfForQCToQCOTest) { auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); auto scfForRes = - b.scfFor(c0, c2, c1, {q0}, [&](Value /*iv*/, ValueRange iterArgs) { - auto q1 = b.h(iterArgs[0]); - auto q2 = b.x(q1); - auto q3 = b.h(q2); - b.scfYield(q3); - return q3; - }); + b.scfFor(c0, c2, c1, {q0}, + [&](Value /*iv*/, ValueRange iterArgs) -> SmallVector { + auto q1 = b.h(iterArgs[0]); + auto q2 = b.x(q1); + auto q3 = b.h(q2); + b.scfYield(q3); + return {q3}; + }); b.h(scfForRes[0]); }); @@ -124,13 +126,14 @@ TEST_F(ConversionTest, ScfForQCOToQCTest) { auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); auto scfForRes = - b.scfFor(c0, c2, c1, {q0}, [&](Value /*iv*/, ValueRange iterArgs) { - auto q1 = b.h(iterArgs[0]); - auto q2 = b.x(q1); - auto q3 = b.h(q2); - b.scfYield(q3); - return q3; - }); + b.scfFor(c0, c2, c1, {q0}, + [&](Value /*iv*/, ValueRange iterArgs) -> SmallVector { + auto q1 = b.h(iterArgs[0]); + auto q2 = b.x(q1); + auto q3 = b.h(q2); + b.scfYield(q3); + return {q3}; + }); b.h(scfForRes[0]); }); @@ -185,16 +188,16 @@ TEST_F(ConversionTest, ScfWhileQCToQCOTest) { auto q0 = b.allocQubit(); auto scfWhileResult = b.scfWhile( ValueRange{q0}, - [&](ValueRange iterArgs) { + [&](ValueRange iterArgs) -> SmallVector { auto [q1, measureResult] = b.measure(iterArgs[0]); b.scfCondition(measureResult, q1); - return q1; + return {q1}; }, - [&](ValueRange iterArgs) { + [&](ValueRange iterArgs) -> SmallVector { auto q1 = b.h(iterArgs[0]); auto q2 = b.y(q1); b.scfYield(q2); - return q2; + return {q2}; }); b.h(scfWhileResult[0]); }); @@ -211,16 +214,16 @@ TEST_F(ConversionTest, ScfWhileQCOToQCTest) { auto q0 = b.allocQubit(); auto scfWhileResult = b.scfWhile( ValueRange{q0}, - [&](ValueRange iterArgs) { + [&](ValueRange iterArgs) -> SmallVector { auto [q1, measureResult] = b.measure(iterArgs[0]); b.scfCondition(measureResult, q1); - return q1; + return {q1}; }, - [&](ValueRange iterArgs) { + [&](ValueRange iterArgs) -> SmallVector { auto q1 = b.h(iterArgs[0]); auto q2 = b.y(q1); b.scfYield(q2); - return q2; + return {q2}; }); b.h(scfWhileResult[0]); }); @@ -280,17 +283,17 @@ TEST_F(ConversionTest, ScfIfQCToQCOTest) { auto [q1, measureResult] = b.measure(q0); auto scfIfResult = b.scfIf( measureResult, {q1}, - [&] { + [&]() -> SmallVector { auto q2 = b.h(q1); auto q3 = b.y(q2); b.scfYield(q3); - return q3; + return {q3}; }, - [&] { + [&]() -> SmallVector { auto q2 = b.y(q1); auto q3 = b.h(q2); b.scfYield(q3); - return q3; + return {q3}; }); b.h(scfIfResult[0]); }); @@ -308,17 +311,17 @@ TEST_F(ConversionTest, ScfIfQCOToQCTest) { auto [q1, measureResult] = b.measure(q0); auto scfIfResult = b.scfIf( measureResult, {q1}, - [&] { + [&]() -> SmallVector { auto q2 = b.h(q1); auto q3 = b.y(q2); b.scfYield(q3); - return q3; + return {q3}; }, - [&] { + [&]() -> SmallVector { auto q2 = b.y(q1); auto q3 = b.h(q2); b.scfYield(q3); - return q3; + return {q3}; }); b.h(scfIfResult[0]); }); @@ -374,15 +377,15 @@ TEST_F(ConversionTest, ScfIfEmptyElseTest) { auto [q1, measureResult] = b.measure(q0); auto scfIfResult = b.scfIf( measureResult, {q1}, - [&] { + [&]() -> SmallVector { auto q2 = b.h(q1); auto q3 = b.y(q2); b.scfYield(q3); - return q3; + return {q3}; }, - [&] { + [&]() -> SmallVector { b.scfYield(q1); - return q1; + return {q1}; }); b.h(scfIfResult[0]); }); From 7c03d28187e3373831c37adfe3ffff34a6c20e56 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 12 Jan 2026 13:42:05 +0100 Subject: [PATCH 051/108] fix signatures --- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 3 +- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 2 +- mlir/unittests/conversion/test_conversion.cpp | 56 +++++++++---------- 3 files changed, 30 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index d3ba1ad44e..0cb382ade9 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -13,7 +13,6 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include -#include #include #include #include @@ -1233,7 +1232,7 @@ class QCOProgramBuilder final : public OpBuilder { */ QCOProgramBuilder& funcFunc(StringRef name, TypeRange argTypes, TypeRange resultTypes, - const std::function& body); + llvm::function_ref body); //===--------------------------------------------------------------------===// // Arith operations diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index ff1f26b60d..27f3839712 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -781,7 +781,7 @@ QCOProgramBuilder& QCOProgramBuilder::funcReturn(ValueRange returnValues) { QCOProgramBuilder& QCOProgramBuilder::funcFunc(StringRef name, TypeRange argTypes, TypeRange resultTypes, - const std::function& body) { + llvm::function_ref body) { checkFinalized(); // Set the insertionPoint diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index a75e842a5c..3e13ab4bb7 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -100,15 +100,15 @@ TEST_F(ConversionTest, ScfForQCToQCOTest) { auto c0 = b.arithConstantIndex(0); auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); - auto scfForRes = - b.scfFor(c0, c2, c1, {q0}, - [&](Value /*iv*/, ValueRange iterArgs) -> SmallVector { - auto q1 = b.h(iterArgs[0]); - auto q2 = b.x(q1); - auto q3 = b.h(q2); - b.scfYield(q3); - return {q3}; - }); + auto scfForRes = b.scfFor( + c0, c2, c1, {q0}, + [&](Value /*iv*/, ValueRange iterArgs) -> llvm::SmallVector { + auto q1 = b.h(iterArgs[0]); + auto q2 = b.x(q1); + auto q3 = b.h(q2); + b.scfYield(q3); + return {q3}; + }); b.h(scfForRes[0]); }); @@ -125,15 +125,15 @@ TEST_F(ConversionTest, ScfForQCOToQCTest) { auto c0 = b.arithConstantIndex(0); auto c1 = b.arithConstantIndex(1); auto c2 = b.arithConstantIndex(2); - auto scfForRes = - b.scfFor(c0, c2, c1, {q0}, - [&](Value /*iv*/, ValueRange iterArgs) -> SmallVector { - auto q1 = b.h(iterArgs[0]); - auto q2 = b.x(q1); - auto q3 = b.h(q2); - b.scfYield(q3); - return {q3}; - }); + auto scfForRes = b.scfFor( + c0, c2, c1, {q0}, + [&](Value /*iv*/, ValueRange iterArgs) -> llvm::SmallVector { + auto q1 = b.h(iterArgs[0]); + auto q2 = b.x(q1); + auto q3 = b.h(q2); + b.scfYield(q3); + return {q3}; + }); b.h(scfForRes[0]); }); @@ -188,12 +188,12 @@ TEST_F(ConversionTest, ScfWhileQCToQCOTest) { auto q0 = b.allocQubit(); auto scfWhileResult = b.scfWhile( ValueRange{q0}, - [&](ValueRange iterArgs) -> SmallVector { + [&](ValueRange iterArgs) -> llvm::SmallVector { auto [q1, measureResult] = b.measure(iterArgs[0]); b.scfCondition(measureResult, q1); return {q1}; }, - [&](ValueRange iterArgs) -> SmallVector { + [&](ValueRange iterArgs) -> llvm::SmallVector { auto q1 = b.h(iterArgs[0]); auto q2 = b.y(q1); b.scfYield(q2); @@ -214,12 +214,12 @@ TEST_F(ConversionTest, ScfWhileQCOToQCTest) { auto q0 = b.allocQubit(); auto scfWhileResult = b.scfWhile( ValueRange{q0}, - [&](ValueRange iterArgs) -> SmallVector { + [&](ValueRange iterArgs) -> llvm::SmallVector { auto [q1, measureResult] = b.measure(iterArgs[0]); b.scfCondition(measureResult, q1); return {q1}; }, - [&](ValueRange iterArgs) -> SmallVector { + [&](ValueRange iterArgs) -> llvm::SmallVector { auto q1 = b.h(iterArgs[0]); auto q2 = b.y(q1); b.scfYield(q2); @@ -283,13 +283,13 @@ TEST_F(ConversionTest, ScfIfQCToQCOTest) { auto [q1, measureResult] = b.measure(q0); auto scfIfResult = b.scfIf( measureResult, {q1}, - [&]() -> SmallVector { + [&]() -> llvm::SmallVector { auto q2 = b.h(q1); auto q3 = b.y(q2); b.scfYield(q3); return {q3}; }, - [&]() -> SmallVector { + [&]() -> llvm::SmallVector { auto q2 = b.y(q1); auto q3 = b.h(q2); b.scfYield(q3); @@ -311,13 +311,13 @@ TEST_F(ConversionTest, ScfIfQCOToQCTest) { auto [q1, measureResult] = b.measure(q0); auto scfIfResult = b.scfIf( measureResult, {q1}, - [&]() -> SmallVector { + [&]() -> llvm::SmallVector { auto q2 = b.h(q1); auto q3 = b.y(q2); b.scfYield(q3); return {q3}; }, - [&]() -> SmallVector { + [&]() -> llvm::SmallVector { auto q2 = b.y(q1); auto q3 = b.h(q2); b.scfYield(q3); @@ -377,13 +377,13 @@ TEST_F(ConversionTest, ScfIfEmptyElseTest) { auto [q1, measureResult] = b.measure(q0); auto scfIfResult = b.scfIf( measureResult, {q1}, - [&]() -> SmallVector { + [&]() -> llvm::SmallVector { auto q2 = b.h(q1); auto q3 = b.y(q2); b.scfYield(q3); return {q3}; }, - [&]() -> SmallVector { + [&]() -> llvm::SmallVector { b.scfYield(q1); return {q1}; }); From 7933d3b9c843219bc47b3107b1d32b17c3385818 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 12 Jan 2026 14:15:57 +0100 Subject: [PATCH 052/108] simplify the builders --- .../Dialect/QC/Builder/QCProgramBuilder.h | 16 ------- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 47 ++----------------- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 9 +--- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 36 +++++--------- mlir/unittests/conversion/test_conversion.cpp | 38 +++++---------- 5 files changed, 32 insertions(+), 114 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index f413ab45f4..1d794e01a1 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -976,21 +976,6 @@ class QCProgramBuilder final : public OpBuilder { // Func operations //===--------------------------------------------------------------------===// - /** - * @brief Constructs a func.return operation without return values - * - * @return Reference to this builder for method chaining - * - * @par Example: - * ```c++ - * builder.funcReturn(); - * ``` - * ```mlir - * func.return - * ``` - */ - QCProgramBuilder& funcReturn(); - /** * @brief Constructs a func.call operation without return values * @@ -1019,7 +1004,6 @@ class QCProgramBuilder final : public OpBuilder { * ```c++ * builder.funcFunc("test", argTypes, [&](ValueRange args) { * builder.h(args[0]); - * builder.funcReturn(); * }) * ``` * ```mlir diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 0cb382ade9..d9b7544d8b 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1047,7 +1047,6 @@ class QCOProgramBuilder final : public OpBuilder { * ```c++ * builder.scfFor(lb, ub, step, initArgs, [&](Value iv, ValueRange iterArgs) * -> llvm::SmallVector { auto q1 = builder.x(iterArgs[0]); - * builder.scfYield(q1); * return {q1}; * }); * ``` @@ -1080,7 +1079,6 @@ class QCOProgramBuilder final : public OpBuilder { * return {q2}; * }, [&](ValueRange iterArgs) -> llvm::SmallVector { * auto q1 = builder.x(iterArgs[0]); - * builder.scfYield(q1); * return {q1}; * }); * ``` @@ -1115,11 +1113,9 @@ class QCOProgramBuilder final : public OpBuilder { * ```c++ * builder.scfIf(condition, qubits, [&]() -> llvm::SmallVector { * auto q1 = builder.h(q0); - * builder.scfYield(q1); * return {q1}; * }, [&]() -> llvm::SmallVector { * auto q1 = builder.x(q0); - * builder.scfYield(q1); * return {q1}; * }); * ``` @@ -1154,42 +1150,10 @@ class QCOProgramBuilder final : public OpBuilder { */ QCOProgramBuilder& scfCondition(Value condition, ValueRange yieldedValues); - /** - * @brief Constructs a scf.yield operation with yielded values - * - * @param yieldedValues ValueRange of the yieldedValues - * @return Reference to this builder for method chaining - * - * @par Example: - * ```c++ - * builder.scfYield(yieldedValues); - * ``` - * ```mlir - * scf.yield %q0 : !qco.qubit - * ``` - */ - QCOProgramBuilder& scfYield(ValueRange yieldedValues); - //===--------------------------------------------------------------------===// // Func operations //===--------------------------------------------------------------------===// - /** - * @brief Constructs a func.return operation with return values - * - * @param returnValues ValueRange of the returned values - * @return Reference to this builder for method chaining - * - * @par Example: - * ```c++ - * builder.funcReturn(yieldedValues); - * ``` - * ```mlir - * func.return %q0 : !qco.qubit - * ``` - */ - QCOProgramBuilder& funcReturn(ValueRange returnValues); - /** * @brief Constructs a func.call operation with return values * @@ -1218,9 +1182,8 @@ class QCOProgramBuilder final : public OpBuilder { * * @par Example: * ```c++ - * builder.funcFunc("test", argTypes, resultTypes, [&](ValueRange args) { - * auto q1 = builder.h(args[0]); - * builder.funcReturn(q1); + * builder.funcFunc("test", argTypes, resultTypes, [&](ValueRange args) -> + * llvm::SmallVector { auto q1 = builder.h(args[0]); return {q1}; * }) * ``` * ```mlir @@ -1230,9 +1193,9 @@ class QCOProgramBuilder final : public OpBuilder { * } * ``` */ - QCOProgramBuilder& funcFunc(StringRef name, TypeRange argTypes, - TypeRange resultTypes, - llvm::function_ref body); + QCOProgramBuilder& + funcFunc(StringRef name, TypeRange argTypes, TypeRange resultTypes, + llvm::function_ref(ValueRange)> body); //===--------------------------------------------------------------------===// // Arith operations diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 68495d388b..50df363874 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -537,13 +537,6 @@ QCProgramBuilder& QCProgramBuilder::funcCall(StringRef name, return *this; } -QCProgramBuilder& QCProgramBuilder::funcReturn() { - checkFinalized(); - - create(loc); - return *this; -} - QCProgramBuilder& QCProgramBuilder::funcFunc(StringRef name, TypeRange argTypes, const std::function& body) { @@ -562,7 +555,7 @@ QCProgramBuilder::funcFunc(StringRef name, TypeRange argTypes, // Build function body body(entryBlock->getArguments()); - + create(loc); return *this; } diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 27f3839712..3041eb76fc 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -637,7 +637,8 @@ ValueRange QCOProgramBuilder::scfFor( validQubits[bodyRegion].insert(arg); } // Build the body - body(iv, loopArgs); + const auto bodyResults = body(iv, loopArgs); + create(loc, bodyResults); // Update the qubit tracking for (const auto& [initArg, result] : @@ -690,7 +691,8 @@ ValueRange QCOProgramBuilder::scfWhile( validQubits[afterRegion].insert(arg); } - afterBody(afterArgs); + const auto afterResults = afterBody(afterArgs); + create(loc, afterResults); // Update the qubit tracking for (const auto& [arg, result] : @@ -726,13 +728,14 @@ ValueRange QCOProgramBuilder::scfIf( } // Build the then body - thenBody(); - + const auto thenResults = thenBody(); + create(loc, thenResults); // Set the insertionpoint setInsertionPointToStart(&elseBlock); // Build the else body - elseBody(); + const auto elseResults = elseBody(); + create(loc, elseResults); // Update the qubit tracking for (const auto& [arg, result] : llvm::zip_equal(qubits, ifOp.getResults())) { @@ -750,13 +753,6 @@ QCOProgramBuilder& QCOProgramBuilder::scfCondition(Value condition, return *this; } -QCOProgramBuilder& QCOProgramBuilder::scfYield(ValueRange yieldedValues) { - checkFinalized(); - - create(loc, yieldedValues); - return *this; -} - //===----------------------------------------------------------------------===// // Func operations //===----------------------------------------------------------------------===// @@ -772,16 +768,9 @@ ValueRange QCOProgramBuilder::funcCall(StringRef name, ValueRange operands) { return callOp->getResults(); } -QCOProgramBuilder& QCOProgramBuilder::funcReturn(ValueRange returnValues) { - checkFinalized(); - - create(loc, returnValues); - return *this; -} -QCOProgramBuilder& -QCOProgramBuilder::funcFunc(StringRef name, TypeRange argTypes, - TypeRange resultTypes, - llvm::function_ref body) { +QCOProgramBuilder& QCOProgramBuilder::funcFunc( + StringRef name, TypeRange argTypes, TypeRange resultTypes, + llvm::function_ref(ValueRange)> body) { checkFinalized(); // Set the insertionPoint @@ -801,7 +790,8 @@ QCOProgramBuilder::funcFunc(StringRef name, TypeRange argTypes, setInsertionPointToStart(entryBlock); // Build function body - body(entryBlock->getArguments()); + const auto bodyResults = body(entryBlock->getArguments()); + create(loc, bodyResults); return *this; } diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index 3e13ab4bb7..e03b521cd1 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -106,7 +106,6 @@ TEST_F(ConversionTest, ScfForQCToQCOTest) { auto q1 = b.h(iterArgs[0]); auto q2 = b.x(q1); auto q3 = b.h(q2); - b.scfYield(q3); return {q3}; }); b.h(scfForRes[0]); @@ -131,7 +130,6 @@ TEST_F(ConversionTest, ScfForQCOToQCTest) { auto q1 = b.h(iterArgs[0]); auto q2 = b.x(q1); auto q3 = b.h(q2); - b.scfYield(q3); return {q3}; }); b.h(scfForRes[0]); @@ -196,7 +194,6 @@ TEST_F(ConversionTest, ScfWhileQCToQCOTest) { [&](ValueRange iterArgs) -> llvm::SmallVector { auto q1 = b.h(iterArgs[0]); auto q2 = b.y(q1); - b.scfYield(q2); return {q2}; }); b.h(scfWhileResult[0]); @@ -222,7 +219,6 @@ TEST_F(ConversionTest, ScfWhileQCOToQCTest) { [&](ValueRange iterArgs) -> llvm::SmallVector { auto q1 = b.h(iterArgs[0]); auto q2 = b.y(q1); - b.scfYield(q2); return {q2}; }); b.h(scfWhileResult[0]); @@ -286,13 +282,11 @@ TEST_F(ConversionTest, ScfIfQCToQCOTest) { [&]() -> llvm::SmallVector { auto q2 = b.h(q1); auto q3 = b.y(q2); - b.scfYield(q3); return {q3}; }, [&]() -> llvm::SmallVector { auto q2 = b.y(q1); auto q3 = b.h(q2); - b.scfYield(q3); return {q3}; }); b.h(scfIfResult[0]); @@ -314,13 +308,11 @@ TEST_F(ConversionTest, ScfIfQCOToQCTest) { [&]() -> llvm::SmallVector { auto q2 = b.h(q1); auto q3 = b.y(q2); - b.scfYield(q3); return {q3}; }, [&]() -> llvm::SmallVector { auto q2 = b.y(q1); auto q3 = b.h(q2); - b.scfYield(q3); return {q3}; }); b.h(scfIfResult[0]); @@ -380,13 +372,9 @@ TEST_F(ConversionTest, ScfIfEmptyElseTest) { [&]() -> llvm::SmallVector { auto q2 = b.h(q1); auto q3 = b.y(q2); - b.scfYield(q3); return {q3}; }, - [&]() -> llvm::SmallVector { - b.scfYield(q1); - return {q1}; - }); + [&]() -> llvm::SmallVector { return {q1}; }); b.h(scfIfResult[0]); }); @@ -405,7 +393,6 @@ TEST_F(ConversionTest, FuncFuncQCToQCOTest) { b.funcFunc("test", q0.getType(), [&](ValueRange args) { b.h(args[0]); b.y(args[0]); - b.funcReturn(); }); }); @@ -419,11 +406,12 @@ TEST_F(ConversionTest, FuncFuncQCToQCOTest) { auto q0 = b.allocQubit(); auto q1 = b.funcCall("test", q0); b.h(q1[0]); - b.funcFunc("test", q0.getType(), q0.getType(), [&](ValueRange args) { - auto q2 = b.h(args[0]); - auto q3 = b.y(q2); - b.funcReturn(q3); - }); + b.funcFunc("test", q0.getType(), q0.getType(), + [&](ValueRange args) -> llvm::SmallVector { + auto q2 = b.h(args[0]); + auto q3 = b.y(q2); + return {q3}; + }); }); const auto outputString = getOutputString(input); @@ -438,11 +426,12 @@ TEST_F(ConversionTest, FuncFuncQCOToQCTest) { auto q0 = b.allocQubit(); auto q1 = b.funcCall("test", q0); b.h(q1[0]); - b.funcFunc("test", q0.getType(), q0.getType(), [&](ValueRange args) { - auto q2 = b.h(args[0]); - auto q3 = b.y(q2); - b.funcReturn(q3); - }); + b.funcFunc("test", q0.getType(), q0.getType(), + [&](ValueRange args) -> llvm::SmallVector { + auto q2 = b.h(args[0]); + auto q3 = b.y(q2); + return {q3}; + }); }); PassManager pm(context.get()); @@ -458,7 +447,6 @@ TEST_F(ConversionTest, FuncFuncQCOToQCTest) { b.funcFunc("test", q0.getType(), [&](ValueRange args) { b.h(args[0]); b.y(args[0]); - b.funcReturn(); }); }); From fb8b255ae27b28279f7886e886f4eea64a03514e Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 12 Jan 2026 15:09:56 +0100 Subject: [PATCH 053/108] remove const from typeConverter --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 1ca375f865..2077735268 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -1179,7 +1179,7 @@ struct QCOToQC final : impl::QCOToQCBase { ConversionTarget target(*context); RewritePatternSet patterns(context); - const QCOToQCTypeConverter typeConverter(context); + QCOToQCTypeConverter typeConverter(context); target.addDynamicallyLegalOp([&](scf::IfOp op) { return !llvm::any_of(op->getResultTypes(), [&](Type type) { From f81ea76bc86f7fd9605b00b46cd822cc2508a996 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 12 Jan 2026 15:13:30 +0100 Subject: [PATCH 054/108] add test for nested operation --- mlir/unittests/conversion/test_conversion.cpp | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index e03b521cd1..733a3d5f99 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -455,3 +455,105 @@ TEST_F(ConversionTest, FuncFuncQCOToQCTest) { ASSERT_EQ(outputString, checkString); } + +TEST_F(ConversionTest, ScfCtrlQCtoQCOTest) { + // Test conversion from qc to qco for scf.for operation with nested ctrl + auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto control = b.allocQubit(); + auto c0 = b.arithConstantIndex(0); + auto c1 = b.arithConstantIndex(1); + auto c2 = b.arithConstantIndex(2); + b.scfFor(c0, c2, c1, [&](Value) { + b.ctrl(control, [&] { b.h(q0); }); + b.x(q0); + b.h(q0); + }); + b.h(control); + }); + + PassManager pm(context.get()); + pm.addPass(createQCToQCO()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QC-QCO conversion for scf nested"; + } + + auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto control = b.allocQubit(); + auto c0 = b.arithConstantIndex(0); + auto c1 = b.arithConstantIndex(1); + auto c2 = b.arithConstantIndex(2); + auto scfForRes = + b.scfFor(c0, c2, c1, {q0, control}, + [&](Value, ValueRange iterArgs) -> llvm::SmallVector { + auto [controls, targets] = b.ctrl( + iterArgs[1], iterArgs[0], + [&](ValueRange targets) -> llvm::SmallVector { + auto target = b.h(targets[0]); + return {target}; + }); + auto q1 = b.x(targets[0]); + auto q2 = b.h(q1); + return {q2, controls[0]}; + }); + + b.h(scfForRes[1]); + }); + + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); + + ASSERT_EQ(outputString, checkString); +} + +TEST_F(ConversionTest, ScfCtrlQCOtoQCTest) { + // Test conversion from qco to qc for scf.for operation with nested ctrl + auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto control = b.allocQubit(); + auto c0 = b.arithConstantIndex(0); + auto c1 = b.arithConstantIndex(1); + auto c2 = b.arithConstantIndex(2); + auto scfForRes = + b.scfFor(c0, c2, c1, {q0, control}, + [&](Value, ValueRange iterArgs) -> llvm::SmallVector { + auto [controls, targets] = b.ctrl( + iterArgs[1], iterArgs[0], + [&](ValueRange targets) -> llvm::SmallVector { + auto target = b.h(targets[0]); + return {target}; + }); + auto q1 = b.x(targets[0]); + auto q2 = b.h(q1); + return {q2, controls[0]}; + }); + + b.h(scfForRes[1]); + }); + + PassManager pm(context.get()); + pm.addPass(createQCOToQC()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QCO-QC conversion for scf nested"; + } + + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto control = b.allocQubit(); + auto c0 = b.arithConstantIndex(0); + auto c1 = b.arithConstantIndex(1); + auto c2 = b.arithConstantIndex(2); + b.scfFor(c0, c2, c1, [&](Value) { + b.ctrl(control, [&] { b.h(q0); }); + b.x(q0); + b.h(q0); + }); + b.h(control); + }); + + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); + + ASSERT_EQ(outputString, checkString); +} From beb6ed6514e0c896f7943e2ac0f2dea3f237c092 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Mon, 12 Jan 2026 16:22:35 +0100 Subject: [PATCH 055/108] simplify the building process with std::variant --- .../Dialect/QC/Builder/QCProgramBuilder.h | 7 +++- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 7 +++- mlir/include/mlir/Dialect/Utils/Utils.h | 24 +++++++++++ .../Dialect/QC/Builder/QCProgramBuilder.cpp | 20 ++++++--- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 15 +++++-- mlir/unittests/conversion/test_conversion.cpp | 41 ++++--------------- 6 files changed, 69 insertions(+), 45 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index 1d794e01a1..715e8f0a03 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -892,7 +892,9 @@ class QCProgramBuilder final : public OpBuilder { * } * ``` */ - QCProgramBuilder& scfFor(Value lowerbound, Value upperbound, Value step, + QCProgramBuilder& scfFor(const std::variant& lowerbound, + const std::variant& upperbound, + const std::variant& step, const std::function& body); /** @@ -953,7 +955,8 @@ class QCProgramBuilder final : public OpBuilder { * ``` */ QCProgramBuilder& - scfIf(Value condition, const std::function& thenBody, + scfIf(const std::variant& condition, + const std::function& thenBody, std::optional> elseBody = std::nullopt); /** diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index d9b7544d8b..b6e9967f0e 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1059,7 +1059,9 @@ class QCOProgramBuilder final : public OpBuilder { * ``` */ ValueRange - scfFor(Value lowerbound, Value upperbound, Value step, ValueRange initArgs, + scfFor(const std::variant& lowerbound, + const std::variant& upperbound, + const std::variant& step, ValueRange initArgs, llvm::function_ref(Value, ValueRange)> body); /** * @brief Constructs a scf.while operation with return values @@ -1129,7 +1131,8 @@ class QCOProgramBuilder final : public OpBuilder { * } * ``` */ - ValueRange scfIf(Value condition, ValueRange qubits, + ValueRange scfIf(const std::variant& condition, + ValueRange qubits, llvm::function_ref()> thenBody, llvm::function_ref()> elseBody); diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 08f6412f7b..f991efcb6c 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -40,4 +40,28 @@ inline Value variantToValue(OpBuilder& builder, const OperationState& state, return operand; } +inline Value constantFromScalar(OpBuilder& builder, Location loc, int64_t v) { + return builder.create(loc, builder.getI64IntegerAttr(v)); +} + +inline Value constantFromScalar(OpBuilder& builder, Location loc, bool v) { + return builder.create(loc, builder.getBoolAttr(v)); +} + +/** + * @brief Convert a variant parameter (T or Value) to a Value + * + * @param builder The operation builder. + * @param state The location of the operation. + * @param parameter The parameter as a variant (T or Value). + * @return Value The parameter as a Value. + */ +template +Value variantToValue(OpBuilder& builder, const Location loc, + const std::variant& parameter) { + if (std::holds_alternative(parameter)) { + return std::get(parameter); + } + return constantFromScalar(builder, loc, std::get(parameter)); +} } // namespace mlir::utils diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 50df363874..2c3c99a461 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/QC/Builder/QCProgramBuilder.h" #include "mlir/Dialect/QC/IR/QCDialect.h" +#include "mlir/Dialect/Utils/Utils.h" #include #include @@ -450,11 +451,17 @@ QCProgramBuilder& QCProgramBuilder::dealloc(Value qubit) { //===----------------------------------------------------------------------===// QCProgramBuilder& -QCProgramBuilder::scfFor(Value lowerbound, Value upperbound, Value step, +QCProgramBuilder::scfFor(const std::variant& lowerbound, + const std::variant& upperbound, + const std::variant& step, const std::function& body) { checkFinalized(); - create(loc, lowerbound, upperbound, step, ValueRange{}, + const auto lb = utils::variantToValue(*this, loc, lowerbound); + const auto ub = utils::variantToValue(*this, loc, upperbound); + const auto stepSize = utils::variantToValue(*this, loc, step); + + create(loc, lb, ub, stepSize, ValueRange{}, [&](OpBuilder& b, Location, Value iv, ValueRange) { const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(b.getInsertionBlock()); @@ -488,12 +495,15 @@ QCProgramBuilder::scfWhile(const std::function& beforeBody, } QCProgramBuilder& -QCProgramBuilder::scfIf(Value cond, const std::function& thenBody, +QCProgramBuilder::scfIf(const std::variant& cond, + const std::function& thenBody, std::optional> elseBody) { checkFinalized(); + const auto condition = utils::variantToValue(*this, loc, cond); + if (!elseBody) { - create(loc, cond, [&](OpBuilder& b, Location loc) { + create(loc, condition, [&](OpBuilder& b, Location loc) { const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(b.getInsertionBlock()); thenBody(); @@ -501,7 +511,7 @@ QCProgramBuilder::scfIf(Value cond, const std::function& thenBody, }); } else { create( - loc, cond, + loc, condition, [&](OpBuilder& b, Location loc) { const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(b.getInsertionBlock()); diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 3041eb76fc..4eb12f8299 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/Utils/Utils.h" #include #include @@ -617,12 +618,18 @@ QCOProgramBuilder& QCOProgramBuilder::dealloc(Value qubit) { //===----------------------------------------------------------------------===// ValueRange QCOProgramBuilder::scfFor( - Value lowerbound, Value upperbound, Value step, ValueRange initArgs, + const std::variant& lowerbound, + const std::variant& upperbound, + const std::variant& step, ValueRange initArgs, llvm::function_ref(Value, ValueRange)> body) { checkFinalized(); + const auto lb = utils::variantToValue(*this, loc, lowerbound); + const auto ub = utils::variantToValue(*this, loc, upperbound); + const auto stepSize = utils::variantToValue(*this, loc, step); + // Create the empty for operation - auto forOp = create(loc, lowerbound, upperbound, step, initArgs); + auto forOp = create(loc, lb, ub, stepSize, initArgs); auto* forBody = forOp.getBody(); const auto iv = forBody->getArgument(0); const auto loopArgs = forBody->getArguments().drop_front(); @@ -704,11 +711,13 @@ ValueRange QCOProgramBuilder::scfWhile( } ValueRange QCOProgramBuilder::scfIf( - Value condition, ValueRange qubits, + const std::variant& cond, ValueRange qubits, llvm::function_ref()> thenBody, llvm::function_ref()> elseBody) { checkFinalized(); + const auto condition = utils::variantToValue(*this, loc, cond); + // Create the empty while operation auto ifOp = create(loc, qubits.getTypes(), condition, /*withElseRegion=*/true); diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index 733a3d5f99..be86f85163 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -78,10 +78,7 @@ TEST_F(ConversionTest, ScfForQCToQCOTest) { // Test conversion from qc to qco for scf.for operation auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); - auto c0 = b.arithConstantIndex(0); - auto c1 = b.arithConstantIndex(1); - auto c2 = b.arithConstantIndex(2); - b.scfFor(c0, c2, c1, [&](Value /*iv*/) { + b.scfFor(0, 2, 1, [&](Value /*iv*/) { b.h(q0); b.x(q0); b.h(q0); @@ -97,11 +94,8 @@ TEST_F(ConversionTest, ScfForQCToQCOTest) { auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); - auto c0 = b.arithConstantIndex(0); - auto c1 = b.arithConstantIndex(1); - auto c2 = b.arithConstantIndex(2); auto scfForRes = b.scfFor( - c0, c2, c1, {q0}, + 0, 2, 1, {q0}, [&](Value /*iv*/, ValueRange iterArgs) -> llvm::SmallVector { auto q1 = b.h(iterArgs[0]); auto q2 = b.x(q1); @@ -121,11 +115,8 @@ TEST_F(ConversionTest, ScfForQCOToQCTest) { // Test conversion from qco to qc for scf.for operation auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); - auto c0 = b.arithConstantIndex(0); - auto c1 = b.arithConstantIndex(1); - auto c2 = b.arithConstantIndex(2); auto scfForRes = b.scfFor( - c0, c2, c1, {q0}, + 0, 2, 1, {q0}, [&](Value /*iv*/, ValueRange iterArgs) -> llvm::SmallVector { auto q1 = b.h(iterArgs[0]); auto q2 = b.x(q1); @@ -143,10 +134,7 @@ TEST_F(ConversionTest, ScfForQCOToQCTest) { auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); - auto c0 = b.arithConstantIndex(0); - auto c1 = b.arithConstantIndex(1); - auto c2 = b.arithConstantIndex(2); - b.scfFor(c0, c2, c1, [&](Value /*iv*/) { + b.scfFor(0, 2, 1, [&](Value /*iv*/) { b.h(q0); b.x(q0); b.h(q0); @@ -235,7 +223,6 @@ TEST_F(ConversionTest, ScfWhileQCOToQCTest) { b.scfWhile( [&] { auto measureResult = b.measure(q0); - b.scfCondition(measureResult); }, [&] { @@ -461,10 +448,7 @@ TEST_F(ConversionTest, ScfCtrlQCtoQCOTest) { auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); auto control = b.allocQubit(); - auto c0 = b.arithConstantIndex(0); - auto c1 = b.arithConstantIndex(1); - auto c2 = b.arithConstantIndex(2); - b.scfFor(c0, c2, c1, [&](Value) { + b.scfFor(0, 2, 1, [&](Value) { b.ctrl(control, [&] { b.h(q0); }); b.x(q0); b.h(q0); @@ -481,11 +465,8 @@ TEST_F(ConversionTest, ScfCtrlQCtoQCOTest) { auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); auto control = b.allocQubit(); - auto c0 = b.arithConstantIndex(0); - auto c1 = b.arithConstantIndex(1); - auto c2 = b.arithConstantIndex(2); auto scfForRes = - b.scfFor(c0, c2, c1, {q0, control}, + b.scfFor(0, 2, 1, {q0, control}, [&](Value, ValueRange iterArgs) -> llvm::SmallVector { auto [controls, targets] = b.ctrl( iterArgs[1], iterArgs[0], @@ -512,11 +493,8 @@ TEST_F(ConversionTest, ScfCtrlQCOtoQCTest) { auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto q0 = b.allocQubit(); auto control = b.allocQubit(); - auto c0 = b.arithConstantIndex(0); - auto c1 = b.arithConstantIndex(1); - auto c2 = b.arithConstantIndex(2); auto scfForRes = - b.scfFor(c0, c2, c1, {q0, control}, + b.scfFor(0, 2, 1, {q0, control}, [&](Value, ValueRange iterArgs) -> llvm::SmallVector { auto [controls, targets] = b.ctrl( iterArgs[1], iterArgs[0], @@ -541,10 +519,7 @@ TEST_F(ConversionTest, ScfCtrlQCOtoQCTest) { auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto q0 = b.allocQubit(); auto control = b.allocQubit(); - auto c0 = b.arithConstantIndex(0); - auto c1 = b.arithConstantIndex(1); - auto c2 = b.arithConstantIndex(2); - b.scfFor(c0, c2, c1, [&](Value) { + b.scfFor(0, 2, 1, [&](Value) { b.ctrl(control, [&] { b.h(q0); }); b.x(q0); b.h(q0); From 5abab32c81b1e426c7a539c6fb360e6f1c265b80 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Tue, 13 Jan 2026 14:50:48 +0100 Subject: [PATCH 056/108] remove constant builders --- .../Dialect/QC/Builder/QCProgramBuilder.h | 36 ------------------- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 36 ------------------- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 21 ----------- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 19 ---------- 4 files changed, 112 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index b72081b1fe..b095c86667 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -1019,42 +1019,6 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { QCProgramBuilder& funcFunc(StringRef name, TypeRange argTypes, const std::function& body); - //===--------------------------------------------------------------------===// - // Arith operations - //===--------------------------------------------------------------------===// - - /** - * @brief Constructs a arith.constant of type Index with a given value - * - * @param index Value of the constant operation - * @return Result of the constant operation - * - * @par Example: - * ```c++ - * builder.arithConstantIndex(4); - * ``` - * ```mlir - * arith.constant 4 : index - * ``` - */ - Value arithConstantIndex(int64_t index); - - /** - * @brief Constructs a arith.constant of type i1 with a given bool value - * - * @param b Bool value of the constant operation - * @return Result of the constant operation - * - * @par Example: - * ```c++ - * builder.arithConstantBool(true); - * ``` - * ```mlir - * arith.constant 1 : i1 - * ``` - */ - Value arithConstantBool(bool b); - //===--------------------------------------------------------------------===// // Finalization //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index ddfd207adc..fb77b9347d 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1200,42 +1200,6 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { funcFunc(StringRef name, TypeRange argTypes, TypeRange resultTypes, llvm::function_ref(ValueRange)> body); - //===--------------------------------------------------------------------===// - // Arith operations - //===--------------------------------------------------------------------===// - - /** - * @brief Constructs a arith.constant of type Index with a given value - * - * @param index Value of the constant operation - * @return Result of the constant operation - * - * @par Example: - * ```c++ - * builder.arithConstantIndex(4); - * ``` - * ```mlir - * arith.constant 4 : index - * ``` - */ - Value arithConstantIndex(int64_t i); - - /** - * @brief Constructs a arith.constant of type i1 with a given bool value - * - * @param b Bool value of the constant operation - * @return Result of the constant operation - * - * @par Example: - * ```c++ - * builder.arithConstantBool(true); - * ``` - * ```mlir - * arith.constant 1 : i1 - * ``` - */ - Value arithConstantBool(bool b); - //===--------------------------------------------------------------------===// // Finalization //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index ea589e3bd2..b2ddcedc86 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -568,27 +568,6 @@ QCProgramBuilder::funcFunc(StringRef name, TypeRange argTypes, return *this; } -//===----------------------------------------------------------------------===// -// Arith operations -//===----------------------------------------------------------------------===// - -Value QCProgramBuilder::arithConstantIndex(int64_t index) { - checkFinalized(); - - const auto op = - create(getIndexType(), getIndexAttr(index)); - return op->getResult(0); -} - -Value QCProgramBuilder::arithConstantBool(bool b) { - checkFinalized(); - - const auto i1Type = getI1Type(); - const auto op = - create(i1Type, getIntegerAttr(i1Type, b ? 1 : 0)); - return op->getResult(0); -} - //===----------------------------------------------------------------------===// // Finalization //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index b65c9a81b7..f4d8dd330c 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -807,25 +807,6 @@ QCOProgramBuilder& QCOProgramBuilder::funcFunc( return *this; } -//===----------------------------------------------------------------------===// -// Arith operations -//===----------------------------------------------------------------------===// - -Value QCOProgramBuilder::arithConstantIndex(int64_t i) { - checkFinalized(); - - const auto op = create(getIndexType(), getIndexAttr(i)); - return op->getResult(0); -} - -Value QCOProgramBuilder::arithConstantBool(bool b) { - checkFinalized(); - - const auto i1Type = getI1Type(); - const auto op = - create(i1Type, getIntegerAttr(i1Type, b ? 1 : 0)); - return op->getResult(0); -} //===----------------------------------------------------------------------===// // Finalization //===----------------------------------------------------------------------===// From ce52debeda3b7cddf1c8792a24383ceee3410f80 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Tue, 13 Jan 2026 14:57:52 +0100 Subject: [PATCH 057/108] use more idiomatic way to build operations --- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index b2ddcedc86..e0d1ace504 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -460,7 +460,7 @@ QCProgramBuilder::scfFor(const std::variant& lowerbound, const auto ub = utils::variantToValue(*this, loc, upperbound); const auto stepSize = utils::variantToValue(*this, loc, step); - create(lb, ub, stepSize, ValueRange{}, + scf::ForOp::create(*this, lb, ub, stepSize, ValueRange{}, [&](OpBuilder& b, Location, Value iv, ValueRange) { const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(b.getInsertionBlock()); @@ -476,8 +476,8 @@ QCProgramBuilder::scfWhile(const std::function& beforeBody, const std::function& afterBody) { checkFinalized(); - create( - TypeRange{}, ValueRange{}, + scf::WhileOp::create( + *this, TypeRange{}, ValueRange{}, [&](OpBuilder& b, Location, ValueRange) { const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(b.getInsertionBlock()); @@ -502,15 +502,15 @@ QCProgramBuilder::scfIf(const std::variant& cond, const auto condition = utils::variantToValue(*this, getLoc(), cond); if (!elseBody) { - create(condition, [&](OpBuilder& b, Location loc) { + scf::IfOp::create(*this, condition, [&](OpBuilder& b, Location loc) { const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(b.getInsertionBlock()); thenBody(); b.create(loc); }); } else { - create( - condition, + scf::IfOp::create( + *this, condition, [&](OpBuilder& b, Location loc) { const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(b.getInsertionBlock()); @@ -530,7 +530,7 @@ QCProgramBuilder::scfIf(const std::variant& cond, QCProgramBuilder& QCProgramBuilder::scfCondition(Value condition) { checkFinalized(); - create(condition, ValueRange{}); + scf::ConditionOp::create(*this, condition, ValueRange{}); return *this; } @@ -542,7 +542,7 @@ QCProgramBuilder& QCProgramBuilder::funcCall(StringRef name, ValueRange operands) { checkFinalized(); - create(name, TypeRange{}, operands); + func::CallOp::create(*this, name, TypeRange{}, operands); return *this; } @@ -557,14 +557,14 @@ QCProgramBuilder::funcFunc(StringRef name, TypeRange argTypes, // Create the empty func operation const auto funcType = getFunctionType(argTypes, {}); - auto funcOp = create(name, funcType); + auto funcOp = func::FuncOp::create(*this, name, funcType); auto* entryBlock = funcOp.addEntryBlock(); setInsertionPointToStart(entryBlock); // Build function body body(entryBlock->getArguments()); - create(); + func::ReturnOp::create(*this); return *this; } From 551ff31c37c982cbe1153f08e4eb841b582c8383 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Tue, 13 Jan 2026 14:59:13 +0100 Subject: [PATCH 058/108] also change creation in QCO --- .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 23 ++++++++++--------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index f4d8dd330c..166bd90348 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -632,7 +632,7 @@ ValueRange QCOProgramBuilder::scfFor( const auto stepSize = utils::variantToValue(*this, loc, step); // Create the empty for operation - auto forOp = create(lb, ub, stepSize, initArgs); + auto forOp = scf::ForOp::create(*this, lb, ub, stepSize, initArgs); auto* forBody = forOp.getBody(); const auto iv = forBody->getArgument(0); const auto loopArgs = forBody->getArguments().drop_front(); @@ -648,7 +648,7 @@ ValueRange QCOProgramBuilder::scfFor( } // Build the body const auto bodyResults = body(iv, loopArgs); - create(loc, bodyResults); + scf::YieldOp::create(*this, bodyResults); // Update the qubit tracking for (const auto& [initArg, result] : @@ -666,7 +666,7 @@ ValueRange QCOProgramBuilder::scfWhile( checkFinalized(); // Create the empty while operation - auto whileOp = create(initArgs.getTypes(), initArgs); + auto whileOp = scf::WhileOp::create(*this, initArgs.getTypes(), initArgs); const SmallVector locs(initArgs.size(), getLoc()); const OpBuilder::InsertionGuard guard(*this); @@ -702,7 +702,7 @@ ValueRange QCOProgramBuilder::scfWhile( } const auto afterResults = afterBody(afterArgs); - create(afterResults); + scf::YieldOp::create(*this, afterResults); // Update the qubit tracking for (const auto& [arg, result] : @@ -722,7 +722,7 @@ ValueRange QCOProgramBuilder::scfIf( const auto condition = utils::variantToValue(*this, getLoc(), cond); // Create the empty while operation - auto ifOp = create(qubits.getTypes(), condition, + auto ifOp = scf::IfOp::create(*this, qubits.getTypes(), condition, /*withElseRegion=*/true); auto& thenBlock = ifOp.getThenRegion().front(); auto& elseBlock = ifOp.getElseRegion().front(); @@ -741,13 +741,13 @@ ValueRange QCOProgramBuilder::scfIf( // Build the then body const auto thenResults = thenBody(); - create(thenResults); + scf::YieldOp::create(*this, thenResults); // Set the insertionpoint setInsertionPointToStart(&elseBlock); // Build the else body const auto elseResults = elseBody(); - create(elseResults); + scf::YieldOp::create(*this, elseResults); // Update the qubit tracking for (const auto& [arg, result] : llvm::zip_equal(qubits, ifOp.getResults())) { @@ -761,7 +761,7 @@ QCOProgramBuilder& QCOProgramBuilder::scfCondition(Value condition, ValueRange yieldedValues) { checkFinalized(); - create(condition, yieldedValues); + scf::ConditionOp::create(*this, condition, yieldedValues); return *this; } @@ -772,7 +772,8 @@ QCOProgramBuilder& QCOProgramBuilder::scfCondition(Value condition, ValueRange QCOProgramBuilder::funcCall(StringRef name, ValueRange operands) { checkFinalized(); - const auto callOp = create(name, operands.getTypes(), operands); + const auto callOp = + func::CallOp::create(*this, name, operands.getTypes(), operands); for (auto [arg, result] : llvm::zip_equal(operands, callOp->getResults())) { updateQubitTracking(arg, result, callOp->getParentRegion()); } @@ -790,7 +791,7 @@ QCOProgramBuilder& QCOProgramBuilder::funcFunc( // Create the empty func operation const auto funcType = getFunctionType(argTypes, resultTypes); - auto funcOp = create(name, funcType); + auto funcOp = func::FuncOp::create(*this, name, funcType); auto* entryBlock = funcOp.addEntryBlock(); auto* region = entryBlock->getParent(); // Add the arguments to the validQubits @@ -802,7 +803,7 @@ QCOProgramBuilder& QCOProgramBuilder::funcFunc( // Build function body const auto bodyResults = body(entryBlock->getArguments()); - create(bodyResults); + func::ReturnOp::create(*this, bodyResults); return *this; } From fd4671b975da2ac719169b3f297040e4c3564e3e Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 15 Jan 2026 14:56:23 +0100 Subject: [PATCH 059/108] fix linter issues --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 2 +- mlir/unittests/conversion/test_conversion.cpp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 2077735268..1ca375f865 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -1179,7 +1179,7 @@ struct QCOToQC final : impl::QCOToQCBase { ConversionTarget target(*context); RewritePatternSet patterns(context); - QCOToQCTypeConverter typeConverter(context); + const QCOToQCTypeConverter typeConverter(context); target.addDynamicallyLegalOp([&](scf::IfOp op) { return !llvm::any_of(op->getResultTypes(), [&](Type type) { diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/conversion/test_conversion.cpp index be86f85163..deeaa474a0 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/conversion/test_conversion.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include From 5cab58ba477cfaeabf67de592579bebb695b4d61 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 15 Jan 2026 15:46:39 +0100 Subject: [PATCH 060/108] make variantToValue generic --- mlir/include/mlir/Dialect/Utils/Utils.h | 39 ++++++++++++------- .../IR/Operations/StandardGates/GPhaseOp.cpp | 2 +- .../QC/IR/Operations/StandardGates/POp.cpp | 2 +- .../QC/IR/Operations/StandardGates/ROp.cpp | 4 +- .../QC/IR/Operations/StandardGates/RXOp.cpp | 2 +- .../QC/IR/Operations/StandardGates/RXXOp.cpp | 2 +- .../QC/IR/Operations/StandardGates/RYOp.cpp | 2 +- .../QC/IR/Operations/StandardGates/RYYOp.cpp | 2 +- .../QC/IR/Operations/StandardGates/RZOp.cpp | 2 +- .../QC/IR/Operations/StandardGates/RZXOp.cpp | 2 +- .../QC/IR/Operations/StandardGates/RZZOp.cpp | 2 +- .../QC/IR/Operations/StandardGates/U2Op.cpp | 4 +- .../QC/IR/Operations/StandardGates/UOp.cpp | 6 +-- .../Operations/StandardGates/XXMinusYYOp.cpp | 4 +- .../Operations/StandardGates/XXPlusYYOp.cpp | 4 +- .../IR/Operations/StandardGates/GPhaseOp.cpp | 2 +- .../QCO/IR/Operations/StandardGates/POp.cpp | 2 +- .../QCO/IR/Operations/StandardGates/ROp.cpp | 4 +- .../QCO/IR/Operations/StandardGates/RXOp.cpp | 2 +- .../QCO/IR/Operations/StandardGates/RXXOp.cpp | 2 +- .../QCO/IR/Operations/StandardGates/RYOp.cpp | 2 +- .../QCO/IR/Operations/StandardGates/RYYOp.cpp | 2 +- .../QCO/IR/Operations/StandardGates/RZOp.cpp | 2 +- .../QCO/IR/Operations/StandardGates/RZXOp.cpp | 2 +- .../QCO/IR/Operations/StandardGates/RZZOp.cpp | 2 +- .../QCO/IR/Operations/StandardGates/U2Op.cpp | 4 +- .../QCO/IR/Operations/StandardGates/UOp.cpp | 6 +-- .../Operations/StandardGates/XXMinusYYOp.cpp | 4 +- .../Operations/StandardGates/XXPlusYYOp.cpp | 4 +- 29 files changed, 65 insertions(+), 54 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 1674c57795..be0ebe4f2d 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -12,7 +12,7 @@ #include #include -#include +#include #include #include @@ -20,25 +20,36 @@ namespace mlir::utils { constexpr auto TOLERANCE = 1e-15; +inline Value constantFromScalar(OpBuilder& builder, const Location& loc, + double v) { + return builder.create(loc, builder.getF64FloatAttr(v)); +} + +inline Value constantFromScalar(OpBuilder& builder, const Location& loc, + int64_t v) { + return builder.create(loc, builder.getI64IntegerAttr(v)); +} + +inline Value constantFromScalar(OpBuilder& builder, const Location& loc, + bool v) { + return builder.create(loc, builder.getBoolAttr(v)); +} + /** - * @brief Convert a variant parameter (double or Value) to a Value + * @brief Convert a variant parameter (T or Value) to a Value * * @param builder The operation builder. - * @param state The operation state. - * @param parameter The parameter as a variant (double or Value). + * @param state The location of the operation. + * @param parameter The parameter as a variant (T or Value). * @return Value The parameter as a Value. */ -[[nodiscard]] inline Value -variantToValue(OpBuilder& builder, const OperationState& state, - const std::variant& parameter) { - Value operand; - if (std::holds_alternative(parameter)) { - operand = builder.create( - state.location, builder.getF64FloatAttr(std::get(parameter))); - } else { - operand = std::get(parameter); +template +[[nodiscard]] Value variantToValue(OpBuilder& builder, const Location& loc, + const std::variant& parameter) { + if (std::holds_alternative(parameter)) { + return std::get(parameter); } - return operand; + return constantFromScalar(builder, loc, std::get(parameter)); } /** diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/GPhaseOp.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/GPhaseOp.cpp index 02e84385ca..515c093f75 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/GPhaseOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/GPhaseOp.cpp @@ -21,6 +21,6 @@ using namespace mlir::utils; void GPhaseOp::build(OpBuilder& builder, OperationState& state, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, thetaOperand); } diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/POp.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/POp.cpp index e221b63cdf..2a0c8cfbd2 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/POp.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/POp.cpp @@ -21,6 +21,6 @@ using namespace mlir::utils; void POp::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubitIn, thetaOperand); } diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/ROp.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/ROp.cpp index 9301c79eba..5d02702962 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/ROp.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/ROp.cpp @@ -22,7 +22,7 @@ using namespace mlir::utils; void ROp::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& theta, const std::variant& phi) { - auto thetaOperand = variantToValue(builder, state, theta); - auto phiOperand = variantToValue(builder, state, phi); + auto thetaOperand = variantToValue(builder, state.location, theta); + auto phiOperand = variantToValue(builder, state.location, phi); build(builder, state, qubitIn, thetaOperand, phiOperand); } diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RXOp.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RXOp.cpp index 5ef1b019e7..bc1cdaee9c 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RXOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RXOp.cpp @@ -21,6 +21,6 @@ using namespace mlir::utils; void RXOp::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubitIn, thetaOperand); } diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RXXOp.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RXXOp.cpp index bd16fd7226..2f8488897e 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RXXOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RXXOp.cpp @@ -21,6 +21,6 @@ using namespace mlir::utils; void RXXOp::build(OpBuilder& builder, OperationState& state, Value qubit0In, Value qubit1In, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubit0In, qubit1In, thetaOperand); } diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RYOp.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RYOp.cpp index 1b61b88ac0..dcce9a0ed7 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RYOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RYOp.cpp @@ -21,6 +21,6 @@ using namespace mlir::utils; void RYOp::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubitIn, thetaOperand); } diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RYYOp.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RYYOp.cpp index ef8defe3a4..ba744bac60 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RYYOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RYYOp.cpp @@ -21,6 +21,6 @@ using namespace mlir::utils; void RYYOp::build(OpBuilder& builder, OperationState& state, Value qubit0In, Value qubit1In, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubit0In, qubit1In, thetaOperand); } diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RZOp.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RZOp.cpp index 082dd8c885..67593b975b 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RZOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RZOp.cpp @@ -21,6 +21,6 @@ using namespace mlir::utils; void RZOp::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubitIn, thetaOperand); } diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RZXOp.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RZXOp.cpp index 44a2020655..db9e6124da 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RZXOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RZXOp.cpp @@ -21,6 +21,6 @@ using namespace mlir::utils; void RZXOp::build(OpBuilder& builder, OperationState& state, Value qubit0In, Value qubit1In, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubit0In, qubit1In, thetaOperand); } diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RZZOp.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RZZOp.cpp index 48e26c9eb6..27beb91d75 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RZZOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/RZZOp.cpp @@ -21,6 +21,6 @@ using namespace mlir::utils; void RZZOp::build(OpBuilder& builder, OperationState& state, Value qubit0In, Value qubit1In, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubit0In, qubit1In, thetaOperand); } diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/U2Op.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/U2Op.cpp index b7d4ea7a73..cfb4d344bb 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/U2Op.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/U2Op.cpp @@ -22,7 +22,7 @@ using namespace mlir::utils; void U2Op::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& phi, const std::variant& lambda) { - auto phiOperand = variantToValue(builder, state, phi); - auto lambdaOperand = variantToValue(builder, state, lambda); + auto phiOperand = variantToValue(builder, state.location, phi); + auto lambdaOperand = variantToValue(builder, state.location, lambda); build(builder, state, qubitIn, phiOperand, lambdaOperand); } diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/UOp.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/UOp.cpp index 3a2f860c04..80a6a6f93d 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/UOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/UOp.cpp @@ -23,8 +23,8 @@ void UOp::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& theta, const std::variant& phi, const std::variant& lambda) { - auto thetaOperand = variantToValue(builder, state, theta); - auto phiOperand = variantToValue(builder, state, phi); - auto lambdaOperand = variantToValue(builder, state, lambda); + auto thetaOperand = variantToValue(builder, state.location, theta); + auto phiOperand = variantToValue(builder, state.location, phi); + auto lambdaOperand = variantToValue(builder, state.location, lambda); build(builder, state, qubitIn, thetaOperand, phiOperand, lambdaOperand); } diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/XXMinusYYOp.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/XXMinusYYOp.cpp index 2a43925fbe..9570056439 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/XXMinusYYOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/XXMinusYYOp.cpp @@ -23,7 +23,7 @@ void XXMinusYYOp::build(OpBuilder& builder, OperationState& state, Value qubit0In, Value qubit1In, const std::variant& theta, const std::variant& beta) { - auto thetaOperand = variantToValue(builder, state, theta); - auto betaOperand = variantToValue(builder, state, beta); + auto thetaOperand = variantToValue(builder, state.location, theta); + auto betaOperand = variantToValue(builder, state.location, beta); build(builder, state, qubit0In, qubit1In, thetaOperand, betaOperand); } diff --git a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/XXPlusYYOp.cpp b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/XXPlusYYOp.cpp index 737df244e4..f45286c94e 100644 --- a/mlir/lib/Dialect/QC/IR/Operations/StandardGates/XXPlusYYOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Operations/StandardGates/XXPlusYYOp.cpp @@ -23,7 +23,7 @@ void XXPlusYYOp::build(OpBuilder& builder, OperationState& state, Value qubit0In, Value qubit1In, const std::variant& theta, const std::variant& beta) { - auto thetaOperand = variantToValue(builder, state, theta); - auto betaOperand = variantToValue(builder, state, beta); + auto thetaOperand = variantToValue(builder, state.location, theta); + auto betaOperand = variantToValue(builder, state.location, beta); build(builder, state, qubit0In, qubit1In, thetaOperand, betaOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/GPhaseOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/GPhaseOp.cpp index a9f9781991..b1f1e497be 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/GPhaseOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/GPhaseOp.cpp @@ -47,7 +47,7 @@ struct RemoveTrivialGPhase final : OpRewritePattern { void GPhaseOp::build(OpBuilder& builder, OperationState& state, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, thetaOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/POp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/POp.cpp index 6ef68dfc6c..3052d692e6 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/POp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/POp.cpp @@ -54,7 +54,7 @@ struct RemoveTrivialP final : OpRewritePattern { void POp::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubitIn, thetaOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/ROp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/ROp.cpp index 2f798d1e03..b9f5d510e5 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/ROp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/ROp.cpp @@ -73,8 +73,8 @@ struct ReplaceRWithRY final : OpRewritePattern { void ROp::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& theta, const std::variant& phi) { - auto thetaOperand = variantToValue(builder, state, theta); - auto phiOperand = variantToValue(builder, state, phi); + auto thetaOperand = variantToValue(builder, state.location, theta); + auto phiOperand = variantToValue(builder, state.location, phi); build(builder, state, qubitIn, thetaOperand, phiOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXOp.cpp index abaaa2657f..9245007c9c 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXOp.cpp @@ -54,7 +54,7 @@ struct RemoveTrivialRX final : OpRewritePattern { void RXOp::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubitIn, thetaOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp index 27c6564c52..5924bbcb76 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RXXOp.cpp @@ -54,7 +54,7 @@ struct RemoveTrivialRXX final : OpRewritePattern { void RXXOp::build(OpBuilder& builder, OperationState& state, Value qubit0In, Value qubit1In, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubit0In, qubit1In, thetaOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYOp.cpp index ae1beb94fc..1ae0e3de03 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYOp.cpp @@ -54,7 +54,7 @@ struct RemoveTrivialRY final : OpRewritePattern { void RYOp::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubitIn, thetaOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp index 92515aea51..7b181a91c8 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RYYOp.cpp @@ -54,7 +54,7 @@ struct RemoveTrivialRYY final : OpRewritePattern { void RYYOp::build(OpBuilder& builder, OperationState& state, Value qubit0In, Value qubit1In, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubit0In, qubit1In, thetaOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp index 913112c490..2ca6092f7e 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZOp.cpp @@ -54,7 +54,7 @@ struct RemoveTrivialRZ final : OpRewritePattern { void RZOp::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubitIn, thetaOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp index dfc48d0700..4b2b3a471c 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZXOp.cpp @@ -54,7 +54,7 @@ struct RemoveTrivialRZX final : OpRewritePattern { void RZXOp::build(OpBuilder& builder, OperationState& state, Value qubit0In, Value qubit1In, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubit0In, qubit1In, thetaOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp index 041308030e..728c8e0aab 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/RZZOp.cpp @@ -54,7 +54,7 @@ struct RemoveTrivialRZZ final : OpRewritePattern { void RZZOp::build(OpBuilder& builder, OperationState& state, Value qubit0In, Value qubit1In, const std::variant& theta) { - auto thetaOperand = variantToValue(builder, state, theta); + auto thetaOperand = variantToValue(builder, state.location, theta); build(builder, state, qubit0In, qubit1In, thetaOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/U2Op.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/U2Op.cpp index ced6297924..b7fc3e6fda 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/U2Op.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/U2Op.cpp @@ -99,8 +99,8 @@ struct ReplaceU2WithRY final : OpRewritePattern { void U2Op::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& phi, const std::variant& lambda) { - auto phiOperand = variantToValue(builder, state, phi); - auto lambdaOperand = variantToValue(builder, state, lambda); + auto phiOperand = variantToValue(builder, state.location, phi); + auto lambdaOperand = variantToValue(builder, state.location, lambda); build(builder, state, qubitIn, phiOperand, lambdaOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/UOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/UOp.cpp index fcf69025ae..b8fd4dd422 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/UOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/UOp.cpp @@ -101,9 +101,9 @@ void UOp::build(OpBuilder& builder, OperationState& state, Value qubitIn, const std::variant& theta, const std::variant& phi, const std::variant& lambda) { - auto thetaOperand = variantToValue(builder, state, theta); - auto phiOperand = variantToValue(builder, state, phi); - auto lambdaOperand = variantToValue(builder, state, lambda); + auto thetaOperand = variantToValue(builder, state.location, theta); + auto phiOperand = variantToValue(builder, state.location, phi); + auto lambdaOperand = variantToValue(builder, state.location, lambda); build(builder, state, qubitIn, thetaOperand, phiOperand, lambdaOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp index 4a45be0bc1..ebc743c4bb 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp @@ -75,8 +75,8 @@ void XXMinusYYOp::build(OpBuilder& builder, OperationState& state, Value qubit0In, Value qubit1In, const std::variant& theta, const std::variant& beta) { - auto thetaOperand = variantToValue(builder, state, theta); - auto betaOperand = variantToValue(builder, state, beta); + auto thetaOperand = variantToValue(builder, state.location, theta); + auto betaOperand = variantToValue(builder, state.location, beta); build(builder, state, qubit0In, qubit1In, thetaOperand, betaOperand); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp index 3574f41e9a..02a1d493a3 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp @@ -75,8 +75,8 @@ void XXPlusYYOp::build(OpBuilder& builder, OperationState& state, Value qubit0In, Value qubit1In, const std::variant& theta, const std::variant& beta) { - auto thetaOperand = variantToValue(builder, state, theta); - auto betaOperand = variantToValue(builder, state, beta); + auto thetaOperand = variantToValue(builder, state.location, theta); + auto betaOperand = variantToValue(builder, state.location, beta); build(builder, state, qubit0In, qubit1In, thetaOperand, betaOperand); } From 7a32e1ffa177351658b56edf04e6d7c5ce5c2b85 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 17 Jan 2026 09:34:57 +0100 Subject: [PATCH 061/108] fix docstrings --- mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h | 2 +- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index b095c86667..85c9a019dd 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -919,7 +919,7 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * scf.while : () -> () { * qc.h %q0 : !qc.qubit * %res = qc.measure %q0 : !qc.qubit -> i1 - * scf.condition(%tres) + * scf.condition(%res) * } do { * qc.x %q0 : !qc.qubit * scf.yield diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 1ca375f865..feefd6b9df 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -948,7 +948,7 @@ struct ConvertQCOScfWhileOp final : OpConversionPattern { }; /** - * @brief Converts scf.for with value semantics to scf.while with memory + * @brief Converts scf.for with value semantics to scf.for with memory * semantics for qubit values. This currently assumes no mixed types as return * values. * @@ -956,8 +956,8 @@ struct ConvertQCOScfWhileOp final : OpConversionPattern { * ```mlir * %targets_out = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = q0) -> * (!qco.qubit) { - * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit - * scf.yield %q1 : !qco.qubit + * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * scf.yield %q1 : !qco.qubit * } * ``` * is converted to @@ -1036,7 +1036,6 @@ struct ConvertQCOScfYieldOp final : OpConversionPattern { * is converted to * ```mlir * scf.condition(%cond) - * ``` */ struct ConvertQCOScfConditionOp final : OpConversionPattern { From f2963b2672d105560d9a082ac45e6d37a0fc35dc Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 17 Jan 2026 10:16:37 +0100 Subject: [PATCH 062/108] reuse existing yield in scf for conversion --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index feefd6b9df..9e1e0f4c71 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -988,10 +988,9 @@ struct ConvertQCOScfForOp final : OpConversionPattern { // Move all the operations from the old block to the new block auto* newBlock = newFor.getBody(); - // Erase the existing yield operation - rewriter.eraseOp(newBlock->getTerminator()); - newBlock->getOperations().splice(newBlock->end(), - op.getBody()->getOperations()); + auto& srcOps = op.getBody()->getOperations(); + newBlock->getOperations().splice(newBlock->begin(), srcOps, srcOps.begin(), + std::prev(srcOps.end())); // Replace the result values with the init values rewriter.replaceOp(op, adaptor.getInitArgs()); From cafd34e365b73c8f1b479740f1732441351d812f Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 17 Jan 2026 10:56:52 +0100 Subject: [PATCH 063/108] move type creation out of the loop --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 9e1e0f4c71..5096139cb2 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -1106,8 +1106,10 @@ struct ConvertQCOFuncFuncOp final : OpConversionPattern { const SmallVector argumentTypes( op.front().getNumArguments(), qc::QubitType::get(rewriter.getContext())); + const auto qcType = qc::QubitType::get(rewriter.getContext()); + for (auto blockArg : op.front().getArguments()) { - blockArg.setType(qc::QubitType::get(rewriter.getContext())); + blockArg.setType(qcType); } auto newFuncType = rewriter.getFunctionType(argumentTypes, {}); op.setFunctionType(newFuncType); From 372a2318e0d6cbec98b179cc759b849dd8e04521 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 17 Jan 2026 11:29:09 +0100 Subject: [PATCH 064/108] use more idiomatic way of building operations --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 14 +++++++------- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 20 ++++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 5096139cb2..030d386560 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -866,8 +866,8 @@ struct ConvertQCOScfIfOp final : OpConversionPattern { matchAndRewrite(scf::IfOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { // Create the new if operation - auto newIf = rewriter.create(op.getLoc(), TypeRange{}, - op.getCondition(), false); + auto newIf = scf::IfOp::create(rewriter, op.getLoc(), TypeRange{}, + op.getCondition(), false); // Inline the regions rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(), newIf.getThenRegion().end()); @@ -921,7 +921,7 @@ struct ConvertQCOScfWhileOp final : OpConversionPattern { ConversionPatternRewriter& rewriter) const override { // Create the new while operation auto newWhileOp = - rewriter.create(op->getLoc(), TypeRange{}, ValueRange{}); + scf::WhileOp::create(rewriter, op->getLoc(), TypeRange{}, ValueRange{}); // Replace the uses of the blockarguments with the init values const auto& inits = adaptor.getInits(); @@ -975,8 +975,8 @@ struct ConvertQCOScfForOp final : OpConversionPattern { matchAndRewrite(scf::ForOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { // Create a new for-loop with no iter_args - auto newFor = rewriter.create( - op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), + auto newFor = scf::ForOp::create( + rewriter, op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep(), ValueRange{}); // Replace the uses of the previous iter_args @@ -1071,8 +1071,8 @@ struct ConvertQCOFuncCallOp final : OpConversionPattern { LogicalResult matchAndRewrite(func::CallOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - rewriter.create(op->getLoc(), adaptor.getCallee(), - TypeRange{}, adaptor.getOperands()); + func::CallOp::create(rewriter, op->getLoc(), adaptor.getCallee(), + TypeRange{}, adaptor.getOperands()); rewriter.replaceOp(op, adaptor.getOperands()); return success(); } diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index f43911abef..89304e3e56 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1293,9 +1293,9 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { qcQubits.size(), qco::QubitType::get(rewriter.getContext())); // create new if operation - auto newIfOp = rewriter.create(op->getLoc(), TypeRange{qcoTypes}, - op.getCondition(), - op.getElseRegion().empty()); + auto newIfOp = + scf::IfOp::create(rewriter, op->getLoc(), TypeRange{qcoTypes}, + op.getCondition(), op.getElseRegion().empty()); auto& thenRegion = newIfOp.getThenRegion(); auto& elseRegion = newIfOp.getElseRegion(); @@ -1312,7 +1312,7 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { // create the yield operation if it does not exist yet rewriter.setInsertionPointToEnd(&elseRegion.front()); const auto elseYield = - rewriter.create(op->getLoc(), qcValues); + scf::YieldOp::create(rewriter, op->getLoc(), qcValues); // mark the yield operation for conversion elseYield->setAttr("needChange", StringAttr::get(rewriter.getContext(), "yes")); @@ -1392,8 +1392,8 @@ struct ConvertQCScfWhileOp final : StatefulOpConversionPattern { qcQubits.size(), qco::QubitType::get(rewriter.getContext())); // create the new while operation - auto newWhileOp = rewriter.create( - op.getLoc(), TypeRange(qcoTypes), ValueRange(qcoQubits)); + auto newWhileOp = scf::WhileOp::create( + rewriter, op.getLoc(), TypeRange(qcoTypes), ValueRange(qcoQubits)); auto& newBeforeRegion = newWhileOp.getBefore(); auto& newAfterRegion = newWhileOp.getAfter(); const SmallVector locs(qcQubits.size(), op->getLoc()); @@ -1475,8 +1475,8 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { } // Create a new for-loop with qco qubits as iter_args - auto newFor = rewriter.create( - op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), + auto newFor = scf::ForOp::create( + rewriter, op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep(), ValueRange(qcoQubits)); // move the operations to the new block @@ -1630,8 +1630,8 @@ struct ConvertQCFuncCallOp final : StatefulOpConversionPattern { const SmallVector qcoTypes( qcQubits.size(), qco::QubitType::get(rewriter.getContext())); - const auto callOp = rewriter.create( - op->getLoc(), adaptor.getCallee(), qcoTypes, qcoQubits); + const auto callOp = func::CallOp::create( + rewriter, op->getLoc(), adaptor.getCallee(), qcoTypes, qcoQubits); for (const auto& [qcQubit, qcoQubit] : llvm::zip_equal(qcQubits, callOp->getResults())) { From 470dfa9335a472f798ad25e8efdc0b0361b9ad36 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 17 Jan 2026 12:16:05 +0100 Subject: [PATCH 065/108] update CHANGELOG.md --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 938f4ada36..fe862f0d6d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel ### Added -- ✨ Add initial infrastructure for new QC and QCO MLIR dialects ([#1264], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1465]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**]) +- ✨ Add initial infrastructure for new QC and QCO MLIR dialects ([#1264], [#1396], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1465]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**]) ### Changed @@ -335,6 +335,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool [#1406]: https://github.com/munich-quantum-toolkit/core/pull/1406 [#1403]: https://github.com/munich-quantum-toolkit/core/pull/1403 [#1402]: https://github.com/munich-quantum-toolkit/core/pull/1402 +[#1396]: https://github.com/munich-quantum-toolkit/core/pull/1396 [#1385]: https://github.com/munich-quantum-toolkit/core/pull/1385 [#1384]: https://github.com/munich-quantum-toolkit/core/pull/1384 [#1383]: https://github.com/munich-quantum-toolkit/core/pull/1383 From fb2e3cca7c2a7a152265e58e6867f61facc2075c Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 17 Jan 2026 12:34:35 +0100 Subject: [PATCH 066/108] update conversion test structure --- mlir/unittests/CMakeLists.txt | 5 +- mlir/unittests/Conversion/CMakeLists.txt | 10 + .../Conversion/QCOToQC/CMakeLists.txt | 22 ++ .../QCOToQC/test_conversion_qco_to_qc.cpp | 285 ++++++++++++++++++ .../Conversion/QCToQCO/CMakeLists.txt | 22 ++ .../QCToQCO/test_conversion_qc_to_qco.cpp} | 211 ------------- mlir/unittests/conversion/CMakeLists.txt | 30 -- 7 files changed, 342 insertions(+), 243 deletions(-) create mode 100644 mlir/unittests/Conversion/CMakeLists.txt create mode 100644 mlir/unittests/Conversion/QCOToQC/CMakeLists.txt create mode 100644 mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp create mode 100644 mlir/unittests/Conversion/QCToQCO/CMakeLists.txt rename mlir/unittests/{conversion/test_conversion.cpp => Conversion/QCToQCO/test_conversion_qc_to_qco.cpp} (60%) delete mode 100644 mlir/unittests/conversion/CMakeLists.txt diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index b91b76fa3b..46e4504b15 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -7,11 +7,12 @@ # Licensed under the MIT License add_subdirectory(pipeline) -add_subdirectory(conversion) +add_subdirectory(Conversion) add_subdirectory(Dialect) add_custom_target(mqt-core-mlir-unittests) add_dependencies( - mqt-core-mlir-unittests mqt-core-mlir-compiler-pipeline-test mqt-core-mlir-conversion-test + mqt-core-mlir-unittests mqt-core-mlir-compiler-pipeline-test + mqt-core-mlir-conversion-qc-to-qco-test mqt-core-mlir-conversion-qco-to-qc-test mqt-core-mlir-dialect-qco-ir-modifiers-test mqt-core-mlir-dialect-utils-test) diff --git a/mlir/unittests/Conversion/CMakeLists.txt b/mlir/unittests/Conversion/CMakeLists.txt new file mode 100644 index 0000000000..f1b033ed54 --- /dev/null +++ b/mlir/unittests/Conversion/CMakeLists.txt @@ -0,0 +1,10 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +add_subdirectory(QCToQCO) +add_subdirectory(QCOToQC) diff --git a/mlir/unittests/Conversion/QCOToQC/CMakeLists.txt b/mlir/unittests/Conversion/QCOToQC/CMakeLists.txt new file mode 100644 index 0000000000..1b60b2b927 --- /dev/null +++ b/mlir/unittests/Conversion/QCOToQC/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +add_executable(mqt-core-mlir-conversion-qco-to-qc-test test_conversion_qco_to_qc.cpp) + +target_link_libraries( + mqt-core-mlir-conversion-qco-to-qc-test + PRIVATE GTest::gtest_main + MLIRParser + MLIRQCProgramBuilder + QCOToQC + MLIRPass + MLIRTransforms + MLIRLLVMDialect + MLIRQCOProgramBuilder) + +gtest_discover_tests(mqt-core-mlir-conversion-qco-to-qc-test) diff --git a/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp new file mode 100644 index 0000000000..6256245bd3 --- /dev/null +++ b/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp @@ -0,0 +1,285 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Conversion/QCOToQC/QCOToQC.h" +#include "mlir/Dialect/QC/Builder/QCProgramBuilder.h" +#include "mlir/Dialect/QC/IR/QCDialect.h" +#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" +#include "mlir/Dialect/QCO/IR/QCODialect.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; + +class ConversionTest : public ::testing::Test { +protected: + std::unique_ptr context; + void SetUp() override { + // Register all dialects needed for the full compilation pipeline + DialectRegistry registry; + registry.insert(); + + context = std::make_unique(); + context->appendDialectRegistry(registry); + context->loadAllAvailableDialects(); + } + + [[nodiscard]] OwningOpRef buildQCIR( + const std::function& buildFunc) const { + mlir::qc::QCProgramBuilder builder(context.get()); + builder.initialize(); + buildFunc(builder); + auto module = builder.finalize(); + return module; + } + [[nodiscard]] OwningOpRef buildQCOIR( + const std::function& buildFunc) const { + qco::QCOProgramBuilder builder(context.get()); + builder.initialize(); + buildFunc(builder); + auto module = builder.finalize(); + return module; + } +}; + +static std::string getOutputString(mlir::OwningOpRef& module) { + std::string outputString; + llvm::raw_string_ostream os(outputString); + module->print(os); + os.flush(); + return outputString; +} + +TEST_F(ConversionTest, ScfForQCOToQCTest) { + // Test conversion from qco to qc for scf.for operation + auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto scfForRes = b.scfFor( + 0, 2, 1, {q0}, + [&](Value /*iv*/, ValueRange iterArgs) -> llvm::SmallVector { + auto q1 = b.h(iterArgs[0]); + auto q2 = b.x(q1); + auto q3 = b.h(q2); + return {q3}; + }); + b.h(scfForRes[0]); + }); + + PassManager pm(context.get()); + pm.addPass(createQCOToQC()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QCO-QC conversion for scf.for"; + } + + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + b.scfFor(0, 2, 1, [&](Value /*iv*/) { + b.h(q0); + b.x(q0); + b.h(q0); + }); + b.h(q0); + }); + + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); + + ASSERT_EQ(outputString, checkString); +} + +TEST_F(ConversionTest, ScfWhileQCOToQCTest) { + // Test conversion from qco to qc for scf.while operation + auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto scfWhileResult = b.scfWhile( + ValueRange{q0}, + [&](ValueRange iterArgs) -> llvm::SmallVector { + auto [q1, measureResult] = b.measure(iterArgs[0]); + b.scfCondition(measureResult, q1); + return {q1}; + }, + [&](ValueRange iterArgs) -> llvm::SmallVector { + auto q1 = b.h(iterArgs[0]); + auto q2 = b.y(q1); + return {q2}; + }); + b.h(scfWhileResult[0]); + }); + + PassManager pm(context.get()); + pm.addPass(createQCOToQC()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QCO-QC conversion for scf.while"; + } + + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + b.scfWhile( + [&] { + auto measureResult = b.measure(q0); + b.scfCondition(measureResult); + }, + [&] { + b.h(q0); + b.y(q0); + }); + b.h(q0); + }); + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); + + ASSERT_EQ(outputString, checkString); +} + +TEST_F(ConversionTest, ScfIfQCOToQCTest) { + // Test conversion from qco to qc for scf.if operation + auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto [q1, measureResult] = b.measure(q0); + auto scfIfResult = b.scfIf( + measureResult, {q1}, + [&]() -> llvm::SmallVector { + auto q2 = b.h(q1); + auto q3 = b.y(q2); + return {q3}; + }, + [&]() -> llvm::SmallVector { + auto q2 = b.y(q1); + auto q3 = b.h(q2); + return {q3}; + }); + b.h(scfIfResult[0]); + }); + + PassManager pm(context.get()); + pm.addPass(createQCOToQC()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QCO-QC conversion for scf.if"; + } + + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto measure = b.measure(q0); + b.scfIf( + measure, + [&] { + b.h(q0); + b.y(q0); + }, + [&] { + b.y(q0); + b.h(q0); + }); + b.h(q0); + }); + + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); + + ASSERT_EQ(outputString, checkString); +} + +TEST_F(ConversionTest, FuncFuncQCOToQCTest) { + // Test conversion from qco to qc for func.func operation + auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto q1 = b.funcCall("test", q0); + b.h(q1[0]); + b.funcFunc("test", q0.getType(), q0.getType(), + [&](ValueRange args) -> llvm::SmallVector { + auto q2 = b.h(args[0]); + auto q3 = b.y(q2); + return {q3}; + }); + }); + + PassManager pm(context.get()); + pm.addPass(createQCOToQC()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QCO-QC conversion for func.func"; + } + + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + b.funcCall("test", q0); + b.h(q0); + b.funcFunc("test", q0.getType(), [&](ValueRange args) { + b.h(args[0]); + b.y(args[0]); + }); + }); + + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); + + ASSERT_EQ(outputString, checkString); +} + +TEST_F(ConversionTest, ScfCtrlQCOtoQCTest) { + // Test conversion from qco to qc for scf.for operation with nested ctrl + auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto control = b.allocQubit(); + auto scfForRes = + b.scfFor(0, 2, 1, {q0, control}, + [&](Value, ValueRange iterArgs) -> llvm::SmallVector { + auto [controls, targets] = b.ctrl( + iterArgs[1], iterArgs[0], + [&](ValueRange targets) -> llvm::SmallVector { + auto target = b.h(targets[0]); + return {target}; + }); + auto q1 = b.x(targets[0]); + auto q2 = b.h(q1); + return {q2, controls[0]}; + }); + + b.h(scfForRes[1]); + }); + + PassManager pm(context.get()); + pm.addPass(createQCOToQC()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QCO-QC conversion for scf nested"; + } + + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto q0 = b.allocQubit(); + auto control = b.allocQubit(); + b.scfFor(0, 2, 1, [&](Value) { + b.ctrl(control, [&] { b.h(q0); }); + b.x(q0); + b.h(q0); + }); + b.h(control); + }); + + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); + + ASSERT_EQ(outputString, checkString); +} diff --git a/mlir/unittests/Conversion/QCToQCO/CMakeLists.txt b/mlir/unittests/Conversion/QCToQCO/CMakeLists.txt new file mode 100644 index 0000000000..8cdc2543e3 --- /dev/null +++ b/mlir/unittests/Conversion/QCToQCO/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +add_executable(mqt-core-mlir-conversion-qc-to-qco-test test_conversion_qc_to_qco.cpp) + +target_link_libraries( + mqt-core-mlir-conversion-qc-to-qco-test + PRIVATE GTest::gtest_main + MLIRParser + MLIRQCProgramBuilder + QCToQCO + MLIRPass + MLIRTransforms + MLIRLLVMDialect + MLIRQCOProgramBuilder) + +gtest_discover_tests(mqt-core-mlir-conversion-qc-to-qco-test) diff --git a/mlir/unittests/conversion/test_conversion.cpp b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp similarity index 60% rename from mlir/unittests/conversion/test_conversion.cpp rename to mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp index deeaa474a0..df80a0711b 100644 --- a/mlir/unittests/conversion/test_conversion.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp @@ -8,7 +8,6 @@ * Licensed under the MIT License */ -#include "mlir/Conversion/QCOToQC/QCOToQC.h" #include "mlir/Conversion/QCToQCO/QCToQCO.h" #include "mlir/Dialect/QC/Builder/QCProgramBuilder.h" #include "mlir/Dialect/QC/IR/QCDialect.h" @@ -112,43 +111,6 @@ TEST_F(ConversionTest, ScfForQCToQCOTest) { ASSERT_EQ(outputString, checkString); } -TEST_F(ConversionTest, ScfForQCOToQCTest) { - // Test conversion from qco to qc for scf.for operation - auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { - auto q0 = b.allocQubit(); - auto scfForRes = b.scfFor( - 0, 2, 1, {q0}, - [&](Value /*iv*/, ValueRange iterArgs) -> llvm::SmallVector { - auto q1 = b.h(iterArgs[0]); - auto q2 = b.x(q1); - auto q3 = b.h(q2); - return {q3}; - }); - b.h(scfForRes[0]); - }); - - PassManager pm(context.get()); - pm.addPass(createQCOToQC()); - if (failed(pm.run(input.get()))) { - FAIL() << "Conversion error during QCO-QC conversion for scf.for"; - } - - auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { - auto q0 = b.allocQubit(); - b.scfFor(0, 2, 1, [&](Value /*iv*/) { - b.h(q0); - b.x(q0); - b.h(q0); - }); - b.h(q0); - }); - - const auto outputString = getOutputString(input); - const auto checkString = getOutputString(expectedOutput); - - ASSERT_EQ(outputString, checkString); -} - TEST_F(ConversionTest, ScfWhileQCToQCOTest) { // Test conversion from qc to qco for scf.while operation auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { @@ -194,50 +156,6 @@ TEST_F(ConversionTest, ScfWhileQCToQCOTest) { ASSERT_EQ(outputString, checkString); } -TEST_F(ConversionTest, ScfWhileQCOToQCTest) { - // Test conversion from qco to qc for scf.while operation - auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { - auto q0 = b.allocQubit(); - auto scfWhileResult = b.scfWhile( - ValueRange{q0}, - [&](ValueRange iterArgs) -> llvm::SmallVector { - auto [q1, measureResult] = b.measure(iterArgs[0]); - b.scfCondition(measureResult, q1); - return {q1}; - }, - [&](ValueRange iterArgs) -> llvm::SmallVector { - auto q1 = b.h(iterArgs[0]); - auto q2 = b.y(q1); - return {q2}; - }); - b.h(scfWhileResult[0]); - }); - - PassManager pm(context.get()); - pm.addPass(createQCOToQC()); - if (failed(pm.run(input.get()))) { - FAIL() << "Conversion error during QCO-QC conversion for scf.while"; - } - - auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { - auto q0 = b.allocQubit(); - b.scfWhile( - [&] { - auto measureResult = b.measure(q0); - b.scfCondition(measureResult); - }, - [&] { - b.h(q0); - b.y(q0); - }); - b.h(q0); - }); - const auto outputString = getOutputString(input); - const auto checkString = getOutputString(expectedOutput); - - ASSERT_EQ(outputString, checkString); -} - TEST_F(ConversionTest, ScfIfQCToQCOTest) { // Test conversion from qc to qco for scf.if operation auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { @@ -286,54 +204,6 @@ TEST_F(ConversionTest, ScfIfQCToQCOTest) { ASSERT_EQ(outputString, checkString); } -TEST_F(ConversionTest, ScfIfQCOToQCTest) { - // Test conversion from qco to qc for scf.if operation - auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { - auto q0 = b.allocQubit(); - auto [q1, measureResult] = b.measure(q0); - auto scfIfResult = b.scfIf( - measureResult, {q1}, - [&]() -> llvm::SmallVector { - auto q2 = b.h(q1); - auto q3 = b.y(q2); - return {q3}; - }, - [&]() -> llvm::SmallVector { - auto q2 = b.y(q1); - auto q3 = b.h(q2); - return {q3}; - }); - b.h(scfIfResult[0]); - }); - - PassManager pm(context.get()); - pm.addPass(createQCOToQC()); - if (failed(pm.run(input.get()))) { - FAIL() << "Conversion error during QCO-QC conversion for scf.if"; - } - - auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { - auto q0 = b.allocQubit(); - auto measure = b.measure(q0); - b.scfIf( - measure, - [&] { - b.h(q0); - b.y(q0); - }, - [&] { - b.y(q0); - b.h(q0); - }); - b.h(q0); - }); - - const auto outputString = getOutputString(input); - const auto checkString = getOutputString(expectedOutput); - - ASSERT_EQ(outputString, checkString); -} - TEST_F(ConversionTest, ScfIfEmptyElseTest) { // Test conversion from qc to qco for scf.if operation without an else body auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { @@ -408,42 +278,6 @@ TEST_F(ConversionTest, FuncFuncQCToQCOTest) { ASSERT_EQ(outputString, checkString); } -TEST_F(ConversionTest, FuncFuncQCOToQCTest) { - // Test conversion from qco to qc for func.func operation - auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { - auto q0 = b.allocQubit(); - auto q1 = b.funcCall("test", q0); - b.h(q1[0]); - b.funcFunc("test", q0.getType(), q0.getType(), - [&](ValueRange args) -> llvm::SmallVector { - auto q2 = b.h(args[0]); - auto q3 = b.y(q2); - return {q3}; - }); - }); - - PassManager pm(context.get()); - pm.addPass(createQCOToQC()); - if (failed(pm.run(input.get()))) { - FAIL() << "Conversion error during QCO-QC conversion for func.func"; - } - - auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { - auto q0 = b.allocQubit(); - b.funcCall("test", q0); - b.h(q0); - b.funcFunc("test", q0.getType(), [&](ValueRange args) { - b.h(args[0]); - b.y(args[0]); - }); - }); - - const auto outputString = getOutputString(input); - const auto checkString = getOutputString(expectedOutput); - - ASSERT_EQ(outputString, checkString); -} - TEST_F(ConversionTest, ScfCtrlQCtoQCOTest) { // Test conversion from qc to qco for scf.for operation with nested ctrl auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { @@ -488,48 +322,3 @@ TEST_F(ConversionTest, ScfCtrlQCtoQCOTest) { ASSERT_EQ(outputString, checkString); } - -TEST_F(ConversionTest, ScfCtrlQCOtoQCTest) { - // Test conversion from qco to qc for scf.for operation with nested ctrl - auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { - auto q0 = b.allocQubit(); - auto control = b.allocQubit(); - auto scfForRes = - b.scfFor(0, 2, 1, {q0, control}, - [&](Value, ValueRange iterArgs) -> llvm::SmallVector { - auto [controls, targets] = b.ctrl( - iterArgs[1], iterArgs[0], - [&](ValueRange targets) -> llvm::SmallVector { - auto target = b.h(targets[0]); - return {target}; - }); - auto q1 = b.x(targets[0]); - auto q2 = b.h(q1); - return {q2, controls[0]}; - }); - - b.h(scfForRes[1]); - }); - - PassManager pm(context.get()); - pm.addPass(createQCOToQC()); - if (failed(pm.run(input.get()))) { - FAIL() << "Conversion error during QCO-QC conversion for scf nested"; - } - - auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { - auto q0 = b.allocQubit(); - auto control = b.allocQubit(); - b.scfFor(0, 2, 1, [&](Value) { - b.ctrl(control, [&] { b.h(q0); }); - b.x(q0); - b.h(q0); - }); - b.h(control); - }); - - const auto outputString = getOutputString(input); - const auto checkString = getOutputString(expectedOutput); - - ASSERT_EQ(outputString, checkString); -} diff --git a/mlir/unittests/conversion/CMakeLists.txt b/mlir/unittests/conversion/CMakeLists.txt deleted file mode 100644 index a90d3bc2a4..0000000000 --- a/mlir/unittests/conversion/CMakeLists.txt +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM -# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH -# All rights reserved. -# -# SPDX-License-Identifier: MIT -# -# Licensed under the MIT License - -set(testname "mqt-core-mlir-conversion-test") -file(GLOB_RECURSE CONVERSION_TEST_SOURCES *.cpp) - -if(NOT TARGET ${testname}) - # create an executable in which the tests will be stored - add_executable(${testname} ${CONVERSION_TEST_SOURCES}) - # link the Google test infrastructure and a default main function to the test executable. - target_link_libraries( - ${testname} - PRIVATE GTest::gtest_main - MLIRParser - MLIRQCProgramBuilder - QCToQCO - MLIRPass - MLIRTransforms - MLIRLLVMDialect - QCOToQC - MLIRQCOProgramBuilder) - # discover tests - gtest_discover_tests(${testname} DISCOVERY_TIMEOUT 60) - set_target_properties(${testname} PROPERTIES FOLDER unittests) -endif() From e6ff89a3f90b851f89353c7464879c1967837829 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 17 Jan 2026 12:44:42 +0100 Subject: [PATCH 067/108] fix more docstrings --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 15 ++++++++------- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 18 +++++++++--------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 030d386560..4877a88a5a 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -854,7 +854,9 @@ struct ConvertQCOYieldOp final : OpConversionPattern { * is converted to * ```mlir * scf.if %cond { - * qc.x %q0 : !qc.qubit + * qc.h %q0 : !qc.qubit + * scf.yield + * } else { * scf.yield * } * ``` @@ -894,11 +896,11 @@ struct ConvertQCOScfIfOp final : OpConversionPattern { * @par Example: * ```mlir * %targets_out = scf.while (%arg0 = %q0) : (!qco.qubit) -> !qco.qubit { - * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * %q1 = qco.x %arg0 : !qco.qubit -> !qco.qubit * scf.condition(%cond) %q1 : !qco.qubit * } do { * ^bb0(%arg0: !qco.qubit): - * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * %q1 = qco.x %arg0 : !qco.qubit -> !qco.qubit * scf.yield %q1 : !qco.qubit * } * ``` @@ -954,9 +956,9 @@ struct ConvertQCOScfWhileOp final : OpConversionPattern { * * @par Example: * ```mlir - * %targets_out = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = q0) -> + * %targets_out = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %q0) -> * (!qco.qubit) { - * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * %q1 = qco.x %arg0 : !qco.qubit -> !qco.qubit * scf.yield %q1 : !qco.qubit * } * ``` @@ -1058,7 +1060,6 @@ struct ConvertQCOScfConditionOp final : OpConversionPattern { * @par Example: * ```mlir * %q1 = call @test(%q0) : (!qco.qubit) -> !qco.qubit - * } * ``` * is converted to * ```mlir @@ -1091,7 +1092,7 @@ struct ConvertQCOFuncCallOp final : OpConversionPattern { * ``` * is converted to * ```mlir - * func.func @test(%arg0: !qc.qubit){ + * func.func @test(%arg0: !qc.qubit) { * ... * } * ``` diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 89304e3e56..613064ac1a 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1363,11 +1363,11 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { * is converted to * ```mlir * %targets_out = scf.while (%arg0 = %q0) : (!qco.qubit) -> !qco.qubit { - * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * %q1 = qco.x %arg0 : !qco.qubit -> !qco.qubit * scf.condition(%cond) %q1 : !qco.qubit * } do { * ^bb0(%arg0: !qco.qubit): - * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit + * %q1 = qco.x %arg0 : !qco.qubit -> !qco.qubit * scf.yield %q1 : !qco.qubit * } * ``` @@ -1439,7 +1439,7 @@ struct ConvertQCScfWhileOp final : StatefulOpConversionPattern { }; /** - * @brief Converts scf.for with memory semantics to scf.while with value + * @brief Converts scf.for with memory semantics to scf.for with value * semantics for qubit values * * @par Example: @@ -1451,10 +1451,10 @@ struct ConvertQCScfWhileOp final : StatefulOpConversionPattern { * ``` * is converted to * ```mlir - * %targets_out = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = q0) -> + * %targets_out = scf.for %iv = %lb to %ub step %step iter_args(%arg0 = %q0) -> * (!qco.qubit) { - * %q1 = qc.x %arg0 : !qco.qubit -> !qco.qubit - * scf.yield %q1 : !qco.qubit + * %q1 = qco.x %arg0 : !qco.qubit -> !qco.qubit + * scf.yield %q1 : !qco.qubit * } * ``` */ @@ -1604,7 +1604,7 @@ struct ConvertQCScfConditionOp final * ``` * is converted to * ```mlir - * %q1 = call @test(%q1) : (!qco.qubit) -> !qco.qubit + * %q1 = call @test(%q0) : (!qco.qubit) -> !qco.qubit * ``` */ struct ConvertQCFuncCallOp final : StatefulOpConversionPattern { @@ -1649,13 +1649,13 @@ struct ConvertQCFuncCallOp final : StatefulOpConversionPattern { * * @par Example: * ```mlir - * func.func @test(%arg0: !qc.qubit){ + * func.func @test(%arg0: !qc.qubit) { * ... * } * ``` * is converted to * ```mlir - * func.func @test(%arg0: !qco.qubit) -> !qco.qubit{ + * func.func @test(%arg0: !qco.qubit) -> !qco.qubit { * ... * } * ``` From 8c5612d66c5327dc4525a02bba5fdb539bb573e1 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 17 Jan 2026 13:41:12 +0100 Subject: [PATCH 068/108] use more idiomatic way to buld yieldOp in body --- mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index e0d1ace504..d6cbbed815 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -465,7 +465,7 @@ QCProgramBuilder::scfFor(const std::variant& lowerbound, const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(b.getInsertionBlock()); body(iv); - b.create(loc); + scf::YieldOp::create(b, loc); }); return *this; @@ -487,7 +487,7 @@ QCProgramBuilder::scfWhile(const std::function& beforeBody, const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(b.getInsertionBlock()); afterBody(); - b.create(loc); + scf::YieldOp::create(b, loc); }); return *this; @@ -506,7 +506,7 @@ QCProgramBuilder::scfIf(const std::variant& cond, const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(b.getInsertionBlock()); thenBody(); - b.create(loc); + scf::YieldOp::create(b, loc); }); } else { scf::IfOp::create( @@ -515,13 +515,13 @@ QCProgramBuilder::scfIf(const std::variant& cond, const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(b.getInsertionBlock()); thenBody(); - b.create(loc); + scf::YieldOp::create(b, loc); }, [&](OpBuilder& b, Location loc) { const OpBuilder::InsertionGuard guard(*this); setInsertionPointToStart(b.getInsertionBlock()); (*elseBody)(); - b.create(loc); + scf::YieldOp::create(b, loc); }); } return *this; From e1a8512e29a4339ce373ee6f2b9e15ee3ce97d49 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 17 Jan 2026 14:23:27 +0100 Subject: [PATCH 069/108] add missing header --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 4877a88a5a..b78b1892e8 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include +#include #include #include #include From ab8f74d71e6f399e0fad1896bba8092b9522e757 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 17 Jan 2026 14:35:53 +0100 Subject: [PATCH 070/108] add additional asserts --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 613064ac1a..81a1cbd658 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1321,15 +1321,16 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { // create the qubit map for the regions auto& thenRegionQubitMap = getState().qubitMap[&thenRegion]; auto& elseRegionQubitMap = getState().qubitMap[&elseRegion]; + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + for (const auto& qcQubit : qcQubits) { - thenRegionQubitMap.try_emplace( - qcQubit, getState().qubitMap[op->getParentRegion()][qcQubit]); - elseRegionQubitMap.try_emplace( - qcQubit, getState().qubitMap[op->getParentRegion()][qcQubit]); + assert(qubitMap.contains(qcQubit) && "QC qubit not found"); + thenRegionQubitMap.try_emplace(qcQubit, qubitMap[qcQubit]); + elseRegionQubitMap.try_emplace(qcQubit, qubitMap[qcQubit]); } // update the qubit map in the current region - auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + for (const auto& [qcQubit, qcoQubit] : llvm::zip_equal(qcQubits, newIfOp->getResults())) { qubitMap[qcQubit] = qcoQubit; @@ -1385,6 +1386,7 @@ struct ConvertQCScfWhileOp final : StatefulOpConversionPattern { SmallVector qcoQubits; qcoQubits.reserve(qcQubits.size()); for (const auto& qcQubit : qcQubits) { + assert(qubitMap.contains(qcQubit) && "QC qubit not found"); qcoQubits.push_back(qubitMap[qcQubit]); } // create the result typerange @@ -1471,6 +1473,7 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { SmallVector qcoQubits; qcoQubits.reserve(qcQubits.size()); for (const auto& qcQubit : qcQubits) { + assert(qubitMap.contains(qcQubit) && "QC qubit not found"); qcoQubits.push_back(qubitMap[qcQubit]); } @@ -1624,6 +1627,7 @@ struct ConvertQCFuncCallOp final : StatefulOpConversionPattern { SmallVector qcoQubits; qcoQubits.reserve(qcQubits.size()); for (const auto& qcQubit : qcQubits) { + assert(qubitMap.contains(qcQubit) && "QC qubit not found"); qcoQubits.push_back(qubitMap[qcQubit]); } // create the result typerange From 9145b4390227fc0fbd0a34a50a7f3acf168a03f9 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 17 Jan 2026 14:46:06 +0100 Subject: [PATCH 071/108] merge loops together --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 47 +++++++++---------------- 1 file changed, 17 insertions(+), 30 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 81a1cbd658..3a0970edb5 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1284,6 +1284,7 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(scf::IfOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; auto& regionMap = getState().regionMap; const auto& qcQubits = regionMap[op]; const SmallVector qcValues(qcQubits.begin(), qcQubits.end()); @@ -1318,21 +1319,16 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { StringAttr::get(rewriter.getContext(), "yes")); } - // create the qubit map for the regions auto& thenRegionQubitMap = getState().qubitMap[&thenRegion]; auto& elseRegionQubitMap = getState().qubitMap[&elseRegion]; - auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - for (const auto& qcQubit : qcQubits) { + // create the qubit map for the regions and update the qubit map for the + // current region + for (const auto& [qcQubit, qcoQubit] : + llvm::zip_equal(qcQubits, newIfOp->getResults())) { assert(qubitMap.contains(qcQubit) && "QC qubit not found"); thenRegionQubitMap.try_emplace(qcQubit, qubitMap[qcQubit]); elseRegionQubitMap.try_emplace(qcQubit, qubitMap[qcQubit]); - } - - // update the qubit map in the current region - - for (const auto& [qcQubit, qcoQubit] : - llvm::zip_equal(qcQubits, newIfOp->getResults())) { qubitMap[qcQubit] = qcoQubit; } @@ -1411,21 +1407,16 @@ struct ConvertQCScfWhileOp final : StatefulOpConversionPattern { newAfterBlock->getOperations().splice(newAfterBlock->end(), op.getAfterBody()->getOperations()); - // create the qubit map for the new regions auto& newBeforeRegionMap = getState().qubitMap[&newWhileOp.getBefore()]; auto& newAfterRegionMap = getState().qubitMap[&newWhileOp.getAfter()]; - for (const auto& [qcQubit, qcoQubit] : - llvm::zip_equal(qcQubits, newWhileOp.getBeforeArguments())) { - newBeforeRegionMap.try_emplace(qcQubit, qcoQubit); - } - for (const auto& [qcQubit, qcoQubit] : - llvm::zip_equal(qcQubits, newWhileOp.getAfterArguments())) { - newAfterRegionMap.try_emplace(qcQubit, qcoQubit); - } - // update the qubit map in the current region - for (const auto& [qcQubit, qcoQubit] : - llvm::zip_equal(qcQubits, newWhileOp->getResults())) { + // create the qubit map for the new regions and update the qubit map in the + // current region + for (const auto& [qcQubit, beforeArg, afterArg, qcoQubit] : llvm::zip_equal( + qcQubits, newWhileOp.getBeforeArguments(), + newWhileOp.getAfterArguments(), newWhileOp->getResults())) { + newBeforeRegionMap.try_emplace(qcQubit, beforeArg); + newAfterRegionMap.try_emplace(qcQubit, afterArg); qubitMap[qcQubit] = qcoQubit; } @@ -1492,15 +1483,11 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { auto& newRegion = newFor.getRegion(); auto& regionQubitMap = getState().qubitMap[&newRegion]; - // create the qubitmap for the new region - for (const auto& [qcQubit, qcoQubit] : - llvm::zip_equal(qcQubits, newFor.getRegionIterArgs())) { - regionQubitMap.try_emplace(qcQubit, qcoQubit); - } - - // update the qubitmap in the current region - for (const auto& [qcQubit, qcoQubit] : - llvm::zip_equal(qcQubits, newFor->getResults())) { + // create the qubitmap for the new region and update the qubitmap in the + // current region + for (const auto& [qcQubit, iterArg, qcoQubit] : llvm::zip_equal( + qcQubits, newFor.getRegionIterArgs(), newFor->getResults())) { + regionQubitMap.try_emplace(qcQubit, iterArg); qubitMap[qcQubit] = qcoQubit; } From 7ef1f7fdb3dfc3585c701c0939bf90ad9f42a798 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 17 Jan 2026 15:03:10 +0100 Subject: [PATCH 072/108] fix conversion direction in docstrings --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index b78b1892e8..4ce1155144 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -1081,8 +1081,8 @@ struct ConvertQCOFuncCallOp final : OpConversionPattern { }; /** - * @brief Converts func.func with memory semantics to func.func with - * value semantics for qubit values. This currently assumes no mixed types as + * @brief Converts func.func with value semantics to func.func with + * memory semantics for qubit values. This currently assumes no mixed types as * parameters/return values. * * @par Example: From 3938420004ccd2cb34132bee223963def047ba73 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 17 Jan 2026 15:10:31 +0100 Subject: [PATCH 073/108] apply codeRabbit feedback --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 3a0970edb5..66b4f25964 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -130,6 +130,7 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { // get the regions of the current operation const auto& regions = op->getRegions(); SetVector uniqueQubits; + auto const qcType = qc::QubitType::get(ctx); for (auto& region : regions) { // skip empty regions e.g. empty else region of an If operation if (region.empty()) { @@ -138,6 +139,13 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { // check that the region has only one block assert(region.hasOneBlock() && "Expected single-block region"); + // collect qubits from the blockarguments + for (auto arg : region.front().getArguments()) { + if (arg.getType() == qcType) { + uniqueQubits.insert(arg); + } + } + // iterate over all operations inside the region // currently assumes that each region only has one block for (auto& operation : region.front().getOperations()) { @@ -149,13 +157,13 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { } // collect qubits form the operands for (const auto& operand : operation.getOperands()) { - if (operand.getType() == qc::QubitType::get(ctx)) { + if (operand.getType() == qcType) { uniqueQubits.insert(operand); } } // collect qubits from the results for (const auto& result : operation.getResults()) { - if (result.getType() == qc::QubitType::get(ctx)) { + if (result.getType() == qcType) { uniqueQubits.insert(result); } } @@ -171,7 +179,7 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { if (llvm::isa(operation)) { if (auto func = operation.getParentOfType()) { if (!func.getArgumentTypes().empty() && - func.getArgumentTypes().front() == qc::QubitType::get(ctx)) { + func.getArgumentTypes().front() == qcType) { operation.setAttr("needChange", StringAttr::get(ctx, "yes")); state->regionMap[func] = uniqueQubits; } From ba0c481b820b88d235c8c1e5c8a0364f8228b1e8 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Wed, 21 Jan 2026 15:18:44 +0100 Subject: [PATCH 074/108] add support for memref in dialects --- mlir/include/mlir/Dialect/QC/IR/QCOps.td | 2 +- mlir/include/mlir/Dialect/QCO/IR/QCOOps.td | 2 +- mlir/lib/Conversion/QCToQCO/CMakeLists.txt | 1 + mlir/tools/mqt-cc/mqt-cc.cpp | 2 ++ 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/IR/QCOps.td b/mlir/include/mlir/Dialect/QC/IR/QCOps.td index f8754887e2..f246f1a6ac 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCOps.td +++ b/mlir/include/mlir/Dialect/QC/IR/QCOps.td @@ -60,7 +60,7 @@ class QCType traits = []> let mnemonic = typeMnemonic; } -def QubitType : QCType<"Qubit", "qubit"> { +def QubitType : QCType<"Qubit", "qubit", [MemRefElementTypeInterface]> { let summary = "QC qubit reference type"; let description = [{ The `!qc.qubit` type represents a reference to a quantum bit in the diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index aef50a3198..74aa84e736 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -59,7 +59,7 @@ class QCOType traits = []> let mnemonic = typeMnemonic; } -def QubitType : QCOType<"Qubit", "qubit"> { +def QubitType : QCOType<"Qubit", "qubit", [MemRefElementTypeInterface]> { let summary = "QCO qubit value type"; let description = [{ The `!qco.qubit` type represents an SSA value holding a quantum bit diff --git a/mlir/lib/Conversion/QCToQCO/CMakeLists.txt b/mlir/lib/Conversion/QCToQCO/CMakeLists.txt index 25c6dc7940..3d330da09c 100644 --- a/mlir/lib/Conversion/QCToQCO/CMakeLists.txt +++ b/mlir/lib/Conversion/QCToQCO/CMakeLists.txt @@ -17,6 +17,7 @@ add_mlir_library( MLIRQCDialect MLIRQCODialect MLIRArithDialect + MLIRMemRefDialect MLIRFuncDialect MLIRTransforms MLIRFuncTransforms diff --git a/mlir/tools/mqt-cc/mqt-cc.cpp b/mlir/tools/mqt-cc/mqt-cc.cpp index c7de855855..993f01ecbb 100644 --- a/mlir/tools/mqt-cc/mqt-cc.cpp +++ b/mlir/tools/mqt-cc/mqt-cc.cpp @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -119,6 +120,7 @@ int main(int argc, char** argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); MLIRContext context(registry); context.loadAllAvailableDialects(); From d14aecd31cfea5a1074a063ad1269a6716e1c614 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Wed, 21 Jan 2026 16:39:57 +0100 Subject: [PATCH 075/108] add initial conversion of memref ops --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 110 +++++++++++++++++++++--- mlir/tools/mqt-cc/mqt-cc.cpp | 2 + 2 files changed, 99 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 66b4f25964..ff25673f3e 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -21,7 +21,9 @@ #include #include #include +#include #include +#include #include #include #include @@ -116,6 +118,17 @@ class StatefulOpConversionPattern : public OpConversionPattern { } // namespace +static bool isQubitType(Type type) { + if (!llvm::isa(type)) { + auto memrefType = dyn_cast(type); + if (memrefType) { + return llvm::isa(memrefType.getElementType()); + } + return false; + } + return true; +} + /** * @brief Recursively collects all the QC qubit references used by an * operation and store them in map @@ -130,7 +143,6 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { // get the regions of the current operation const auto& regions = op->getRegions(); SetVector uniqueQubits; - auto const qcType = qc::QubitType::get(ctx); for (auto& region : regions) { // skip empty regions e.g. empty else region of an If operation if (region.empty()) { @@ -141,7 +153,7 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { // collect qubits from the blockarguments for (auto arg : region.front().getArguments()) { - if (arg.getType() == qcType) { + if (isQubitType(arg.getType())) { uniqueQubits.insert(arg); } } @@ -157,13 +169,14 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { } // collect qubits form the operands for (const auto& operand : operation.getOperands()) { - if (operand.getType() == qcType) { + if (isQubitType(operand.getType())) { uniqueQubits.insert(operand); } } // collect qubits from the results for (const auto& result : operation.getResults()) { - if (result.getType() == qcType) { + if (!llvm::isa(operation) && + isQubitType(result.getType())) { uniqueQubits.insert(result); } } @@ -179,7 +192,7 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { if (llvm::isa(operation)) { if (auto func = operation.getParentOfType()) { if (!func.getArgumentTypes().empty() && - func.getArgumentTypes().front() == qcType) { + isQubitType(func.getArgumentTypes().front())) { operation.setAttr("needChange", StringAttr::get(ctx, "yes")); state->regionMap[func] = uniqueQubits; } @@ -1265,6 +1278,65 @@ struct ConvertQCYieldOp final : StatefulOpConversionPattern { } }; +struct ConvertQCMemRefAllocOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::AllocOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + + SmallVector qcoQubits; + for (auto* user : op->getUsers()) { + if (llvm::isa(user)) { + auto storeOp = dyn_cast(user); + qcoQubits.push_back(qubitMap[storeOp.getValue()]); + } + } + auto const qcoType = qco::QubitType::get(rewriter.getContext()); + const auto tensorType = RankedTensorType::get( + {static_cast(qcoQubits.size())}, qcoType); + auto fromElements = tensor::FromElementsOp::create(rewriter, op->getLoc(), + tensorType, qcoQubits); + qubitMap.try_emplace(op->getResult(0), fromElements->getResult(0)); + rewriter.eraseOp(op); + + return success(); + } +}; + +struct ConvertQCMemRefStoreOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::StoreOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +struct ConvertQCMemRefLoadOp final + : StatefulOpConversionPattern { + using StatefulOpConversionPattern::StatefulOpConversionPattern; + + LogicalResult + matchAndRewrite(memref::LoadOp op, OpAdaptor /*adaptor*/, + ConversionPatternRewriter& rewriter) const override { + auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + auto tensor = qubitMap[op.getMemRef()]; + auto const qcoType = qco::QubitType::get(rewriter.getContext()); + + auto extractOp = tensor::ExtractOp::create(rewriter, op->getLoc(), qcoType, + tensor, {op.getIndices()}); + + qubitMap.try_emplace(op.getResult(), extractOp.getResult()); + rewriter.eraseOp(op); + return success(); + } +}; /** * @brief Converts scf.if with memory semantics to scf.if with value semantics * for qubit values @@ -1771,9 +1843,10 @@ struct QCToQCO final : impl::QCToQCOBase { QCToQCOTypeConverter typeConverter(context); collectUniqueQubits(module, &state, context); - // Configure conversion target: QC illegal, QCO + // Configure conversion target: QC illegal, QCO and tensor // legal target.addIllegalDialect(); + target.addLegalDialect(); target.addLegalDialect(); target.addDynamicallyLegalOp([&](scf::YieldOp op) { @@ -1792,18 +1865,28 @@ struct QCToQCO final : impl::QCToQCOBase { return !(op->getAttrOfType("needChange")); }); target.addDynamicallyLegalOp([&](func::FuncOp op) { - return !llvm::any_of(op.front().getArgumentTypes(), [&](Type type) { - return type == qc::QubitType::get(context); - }); + return !llvm::any_of(op.front().getArgumentTypes(), + [&](Type type) { return isQubitType(type); }); }); target.addDynamicallyLegalOp([&](func::CallOp op) { - return !llvm::any_of(op->getOperandTypes(), [&](Type type) { - return type == qc::QubitType::get(context); - }); + return !llvm::any_of(op->getOperandTypes(), + [&](Type type) { return isQubitType(type); }); }); target.addDynamicallyLegalOp([&](func::ReturnOp op) { return !op->getAttrOfType("needChange"); }); + target.addDynamicallyLegalOp([&](memref::AllocOp op) { + return !llvm::any_of(op->getResultTypes(), + [&](Type type) { return isQubitType(type); }); + }); + target.addDynamicallyLegalOp([&](memref::StoreOp op) { + return !llvm::any_of(op.getOperandTypes(), + [&](Type type) { return isQubitType(type); }); + }); + target.addDynamicallyLegalOp([&](memref::LoadOp op) { + return !llvm::any_of(op->getResultTypes(), + [&](Type type) { return isQubitType(type); }); + }); // Register operation conversion patterns with state // tracking patterns @@ -1817,7 +1900,8 @@ struct QCToQCO final : impl::QCToQCOBase { ConvertQCDCXOp, ConvertQCECROp, ConvertQCRXXOp, ConvertQCRYYOp, ConvertQCRZXOp, ConvertQCRZZOp, ConvertQCXXPlusYYOp, ConvertQCXXMinusYYOp, ConvertQCBarrierOp, ConvertQCCtrlOp, - ConvertQCYieldOp, ConvertQCScfIfOp, ConvertQCScfYieldOp, + ConvertQCYieldOp, ConvertQCMemRefAllocOp, ConvertQCMemRefStoreOp, + ConvertQCMemRefLoadOp, ConvertQCScfIfOp, ConvertQCScfYieldOp, ConvertQCScfWhileOp, ConvertQCScfConditionOp, ConvertQCScfForOp, ConvertQCFuncCallOp, ConvertQCFuncFuncOp, ConvertQCFuncReturnOp>( typeConverter, context, &state); diff --git a/mlir/tools/mqt-cc/mqt-cc.cpp b/mlir/tools/mqt-cc/mqt-cc.cpp index 993f01ecbb..070a589651 100644 --- a/mlir/tools/mqt-cc/mqt-cc.cpp +++ b/mlir/tools/mqt-cc/mqt-cc.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -121,6 +122,7 @@ int main(int argc, char** argv) { registry.insert(); registry.insert(); registry.insert(); + registry.insert(); MLIRContext context(registry); context.loadAllAvailableDialects(); From 86429c508e9cf6c538fdbee8ba53d29b299c9302 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Wed, 21 Jan 2026 18:17:43 +0100 Subject: [PATCH 076/108] add memref to tensor conversion --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 64 ++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 6 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index ff25673f3e..2201de9bb8 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -26,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -169,14 +171,19 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { } // collect qubits form the operands for (const auto& operand : operation.getOperands()) { + if (operand.getDefiningOp()) { + continue; + } if (isQubitType(operand.getType())) { uniqueQubits.insert(operand); } } // collect qubits from the results for (const auto& result : operation.getResults()) { - if (!llvm::isa(operation) && - isQubitType(result.getType())) { + if (llvm::isa(operation)) { + break; + } + if (isQubitType(result.getType())) { uniqueQubits.insert(result); } } @@ -204,6 +211,8 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { if (!uniqueQubits.empty() && (llvm::isa(op) || (llvm::isa(op)) || llvm::isa(op))) { + if (llvm::isa(op)) { + } state->regionMap[op] = uniqueQubits; op->setAttr("needChange", StringAttr::get(ctx, "yes")); } @@ -1569,6 +1578,27 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { qcQubits, newFor.getRegionIterArgs(), newFor->getResults())) { regionQubitMap.try_emplace(qcQubit, iterArg); qubitMap[qcQubit] = qcoQubit; + + // if the value of the qc qubit is a memref register, extract each value + // from the new tensor and update the qubitmap for each value + if (llvm::isa(qcQubit.getType())) { + // get all the qubits that were stored in the memref register + for (const auto* user : qcQubit.getUsers()) { + if (auto storeOp = dyn_cast(user)) { + // get the qubit + const auto qubit = storeOp.getValueToStore(); + auto const qcoType = qco::QubitType::get(rewriter.getContext()); + + // create the extract operation for each qubit from the resulting + // tensor of the scf.for operation + auto extractOp = + tensor::ExtractOp::create(rewriter, op->getLoc(), qcoType, + qcoQubit, {storeOp.getIndices()}); + // update the qubit map for each of them + qubitMap[qubit] = extractOp.getResult(); + } + } + } } // replace the old entry in the regionMap with the new operation @@ -1601,24 +1631,45 @@ struct ConvertQCScfYieldOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(scf::YieldOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { - auto const qcType = qc::QubitType::get(rewriter.getContext()); assert(llvm::all_of(op.getOperandTypes(), - [&](Type type) { return type == qcType; }) && + [&](Type type) { return isQubitType(type); }) && "Not all operands are qc qubits"); const auto& parentRegion = op->getParentRegion(); - const auto& qubitMap = getState().qubitMap[parentRegion]; + auto& qubitMap = getState().qubitMap[parentRegion]; const auto& orderedQubits = getState().regionMap[parentRegion->getParentOp()]; SmallVector qcoQubits; qcoQubits.reserve(orderedQubits.size()); + // get the latest qco qubit or the latest qco tensor from the qubitMap for (const auto& qcQubit : orderedQubits) { assert(qubitMap.contains(qcQubit) && "QC qubit not found"); - qcoQubits.push_back(qubitMap.lookup(qcQubit)); + const auto qcoQubit = qubitMap[qcQubit]; + + // add an insert operation for every qubit that was extract from a + // register + if (dyn_cast(qcQubit.getType())) { + // find all extracted values of the register + for (const auto* user : qcQubit.getUsers()) { + if (auto loadOp = dyn_cast(user)) { + // get the latest qco qubit and add it back to the tensor + auto qubit = loadOp.getResult(); + assert(qubitMap.contains(qubit) && "QC qubit not found"); + + auto latestQcoQubit = qubitMap.lookup(qubit); + auto insertOp = + tensor::InsertOp::create(rewriter, op.getLoc(), latestQcoQubit, + qcoQubit, loadOp.getIndices()); + qubitMap[qcQubit] = insertOp.getResult(); + } + } + } + qcoQubits.push_back(qubitMap[qcQubit]); } rewriter.replaceOpWithNewOp(op, qcoQubits); + return success(); } }; @@ -1848,6 +1899,7 @@ struct QCToQCO final : impl::QCToQCOBase { target.addIllegalDialect(); target.addLegalDialect(); target.addLegalDialect(); + target.addLegalDialect(); target.addDynamicallyLegalOp([&](scf::YieldOp op) { return !(op->getAttrOfType("needChange")); From aee2bca1836f65f0502c3a61e2602af9dd1826cf Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Wed, 21 Jan 2026 18:46:42 +0100 Subject: [PATCH 077/108] fix multiple tensor inserts --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 2201de9bb8..dc08d9724a 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1645,7 +1645,6 @@ struct ConvertQCScfYieldOp final : StatefulOpConversionPattern { // get the latest qco qubit or the latest qco tensor from the qubitMap for (const auto& qcQubit : orderedQubits) { assert(qubitMap.contains(qcQubit) && "QC qubit not found"); - const auto qcoQubit = qubitMap[qcQubit]; // add an insert operation for every qubit that was extract from a // register @@ -1658,9 +1657,9 @@ struct ConvertQCScfYieldOp final : StatefulOpConversionPattern { assert(qubitMap.contains(qubit) && "QC qubit not found"); auto latestQcoQubit = qubitMap.lookup(qubit); - auto insertOp = - tensor::InsertOp::create(rewriter, op.getLoc(), latestQcoQubit, - qcoQubit, loadOp.getIndices()); + auto insertOp = tensor::InsertOp::create( + rewriter, op.getLoc(), latestQcoQubit, qubitMap[qcQubit], + loadOp.getIndices()); qubitMap[qcQubit] = insertOp.getResult(); } } From 686ebe304f05ba3f7d1126b6031d57638202a025 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 22 Jan 2026 13:08:01 +0100 Subject: [PATCH 078/108] fix order of values --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index dc08d9724a..73cf2bafe3 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1297,7 +1297,8 @@ struct ConvertQCMemRefAllocOp final auto& qubitMap = getState().qubitMap[op->getParentRegion()]; SmallVector qcoQubits; - for (auto* user : op->getUsers()) { + const auto users = llvm::to_vector(op->getUsers()); + for (auto* user : llvm::reverse(users)) { if (llvm::isa(user)) { auto storeOp = dyn_cast(user); qcoQubits.push_back(qubitMap[storeOp.getValue()]); @@ -1346,6 +1347,7 @@ struct ConvertQCMemRefLoadOp final return success(); } }; + /** * @brief Converts scf.if with memory semantics to scf.if with value semantics * for qubit values From aa2a70e2bd6b1b3cd8f19eb8234669b5ea13b9fd Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 22 Jan 2026 14:34:36 +0100 Subject: [PATCH 079/108] add tensor conversion from QCO to QC --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 126 ++++++++++++++++++++- mlir/lib/Conversion/QCToQCO/CMakeLists.txt | 1 - 2 files changed, 120 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 4ce1155144..2ec448d3a9 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -16,10 +16,14 @@ #include #include #include +#include #include #include +#include #include +#include #include +#include #include #include #include @@ -271,6 +275,13 @@ class QCOToQCTypeConverter final : public TypeConverter { addConversion([ctx](qco::QubitType /*type*/) -> Type { return qc::QubitType::get(ctx); }); + + addConversion([&](RankedTensorType t) -> Type { + if (t.getElementType() == qco::QubitType::get(ctx)) { + return MemRefType::get(t.getShape(), qc::QubitType::get(ctx)); + } + return t; + }); } }; @@ -839,6 +850,62 @@ struct ConvertQCOYieldOp final : OpConversionPattern { } }; +struct ConvertQCOTensorFromElementsOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + + const auto qcType = qc::QubitType::get(rewriter.getContext()); + + auto memrefType = MemRefType::get(op.getType().getShape(), qcType); + + auto memrefAllocOp = + rewriter.create(op.getLoc(), memrefType); + + // Store each element + for (auto it : llvm::enumerate(adaptor.getElements())) { + Value idx = + rewriter.create(op.getLoc(), it.index()); + rewriter.create(op.getLoc(), it.value(), memrefAllocOp, + idx); + } + + // Replace all uses of the tensor result with the memref + rewriter.replaceOp(op, memrefAllocOp); + + return success(); + } +}; + +struct ConvertQCOTensorExtractOp final + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + const auto memref = adaptor.getTensor(); + const auto loadOp = memref::LoadOp::create(rewriter, op.getLoc(), memref, + adaptor.getIndices()); + rewriter.replaceOp(op, loadOp); + return success(); + } +}; + +struct ConvertQCOTensorInsertOp final : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + + rewriter.replaceOp(op, adaptor.getDest()); + return success(); + } +}; /** * @brief Converts scf.if with value semantics to scf.if with memory semantics * for qubit values. This currently assumes no mixed types as return values. @@ -995,6 +1062,33 @@ struct ConvertQCOScfForOp final : OpConversionPattern { newBlock->getOperations().splice(newBlock->begin(), srcOps, srcOps.begin(), std::prev(srcOps.end())); + // find the init args that are tensors + for (auto initArg : llvm::enumerate(op.getInitArgs())) { + auto value = initArg.value(); + if (llvm::isa(value.getType())) { + // find the equivalent memref register from the adaptor + const auto memref = adaptor.getInitArgs()[initArg.index()]; + SmallVector qcQubits; + // get the qc qubits from them + const auto memrefUsers = llvm::to_vector(memref.getUsers()); + for (auto* user : llvm::reverse(memrefUsers)) { + if (llvm::isa(user)) { + auto storeOp = dyn_cast(user); + qcQubits.push_back(storeOp.getValueToStore()); + } + } + // get the users of the returned tensor + const auto users = + llvm::to_vector(op->getResult(initArg.index()).getUsers()); + for (auto user : llvm::enumerate(llvm::reverse(users))) { + if (llvm::isa(user.value())) { + rewriter.replaceAllUsesWith(user.value()->getResult(0), + qcQubits[user.index()]); + rewriter.eraseOp(user.value()); + } + } + } + } // Replace the result values with the init values rewriter.replaceOp(op, adaptor.getInitArgs()); return success(); @@ -1183,6 +1277,22 @@ struct QCOToQC final : impl::QCOToQCBase { RewritePatternSet patterns(context); const QCOToQCTypeConverter typeConverter(context); + target.addDynamicallyLegalOp( + [&](tensor::FromElementsOp op) { + return !llvm::any_of(op.getOperandTypes(), [&](Type type) { + return type == qco::QubitType::get(context); + }); + }); + target.addDynamicallyLegalOp([&](tensor::ExtractOp op) { + return !llvm::any_of(op->getResultTypes(), [&](Type type) { + return type == qco::QubitType::get(context); + }); + }); + target.addDynamicallyLegalOp([&](tensor::InsertOp op) { + return !llvm::any_of(op.getOperandTypes(), [&](Type type) { + return type == qco::QubitType::get(context); + }); + }); target.addDynamicallyLegalOp([&](scf::IfOp op) { return !llvm::any_of(op->getResultTypes(), [&](Type type) { return type == qco::QubitType::get(context); @@ -1208,7 +1318,8 @@ struct QCOToQC final : impl::QCOToQCBase { }); target.addDynamicallyLegalOp([&](scf::ForOp op) { return !llvm::any_of(op->getResultTypes(), [&](Type type) { - return type == qco::QubitType::get(context); + return type == qco::QubitType::get(context) || + llvm::isa(type); }); }); target.addDynamicallyLegalOp([&](func::CallOp op) { @@ -1227,10 +1338,11 @@ struct QCOToQC final : impl::QCOToQCBase { type == qc::QubitType::get(context); }); }); - // Configure conversion target: QCO illegal, QC legal target.addIllegalDialect(); target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); // Register operation conversion patterns // Note: No state tracking needed - OpAdaptors handle type conversion patterns @@ -1244,10 +1356,12 @@ struct QCOToQC final : impl::QCOToQCBase { ConvertQCODCXOp, ConvertQCOECROp, ConvertQCORXXOp, ConvertQCORYYOp, ConvertQCORZXOp, ConvertQCORZZOp, ConvertQCOXXPlusYYOp, ConvertQCOXXMinusYYOp, ConvertQCOBarrierOp, ConvertQCOCtrlOp, - ConvertQCOYieldOp, ConvertQCOScfIfOp, ConvertQCOScfYieldOp, - ConvertQCOScfWhileOp, ConvertQCOScfConditionOp, ConvertQCOScfForOp, - ConvertQCOFuncCallOp, ConvertQCOFuncFuncOp, - ConvertQCOFuncReturnOp>(typeConverter, context); + ConvertQCOTensorFromElementsOp, ConvertQCOTensorExtractOp, + ConvertQCOTensorInsertOp, ConvertQCOYieldOp, ConvertQCOScfIfOp, + ConvertQCOScfYieldOp, ConvertQCOScfWhileOp, + ConvertQCOScfConditionOp, ConvertQCOScfForOp, ConvertQCOFuncCallOp, + ConvertQCOFuncFuncOp, ConvertQCOFuncReturnOp>(typeConverter, + context); // Apply the conversion if (failed(applyPartialConversion(module, target, std::move(patterns)))) { diff --git a/mlir/lib/Conversion/QCToQCO/CMakeLists.txt b/mlir/lib/Conversion/QCToQCO/CMakeLists.txt index 3d330da09c..25c6dc7940 100644 --- a/mlir/lib/Conversion/QCToQCO/CMakeLists.txt +++ b/mlir/lib/Conversion/QCToQCO/CMakeLists.txt @@ -17,7 +17,6 @@ add_mlir_library( MLIRQCDialect MLIRQCODialect MLIRArithDialect - MLIRMemRefDialect MLIRFuncDialect MLIRTransforms MLIRFuncTransforms From 7950efeaa15cfa863615f7a9c01b41228caf70b2 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 22 Jan 2026 15:19:36 +0100 Subject: [PATCH 080/108] update the docstrings for the new conversions --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 102 +++++++++++++++++------- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 95 ++++++++++++++++------ 2 files changed, 140 insertions(+), 57 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 2ec448d3a9..3d5bd37943 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -850,6 +850,21 @@ struct ConvertQCOYieldOp final : OpConversionPattern { } }; +/** + * @brief Converts tensor.from_elements to memref.alloc for qubits + * + * @par Example: + * ```mlir + * %tensor = tensor.from_elements %q0, %q1, %q2 : tensore<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * %alloc = memref.alloc() : memref<3x!qc.qubit> + * memref.store %q0, %alloc[%c0] : memref<3x!qc.qubit> + * memref.store %q1, %alloc[%c1] : memref<3x!qc.qubit> + * memref.store %q2, %alloc[%c2] : memref<3x!qc.qubit> + * ``` + */ struct ConvertQCOTensorFromElementsOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -857,29 +872,37 @@ struct ConvertQCOTensorFromElementsOp final LogicalResult matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - const auto qcType = qc::QubitType::get(rewriter.getContext()); - + const auto loc = op.getLoc(); auto memrefType = MemRefType::get(op.getType().getShape(), qcType); + // create the memref alloc operation + auto memrefAllocOp = rewriter.create(loc, memrefType); - auto memrefAllocOp = - rewriter.create(op.getLoc(), memrefType); - - // Store each element + // store each qubit into the memref for (auto it : llvm::enumerate(adaptor.getElements())) { - Value idx = - rewriter.create(op.getLoc(), it.index()); - rewriter.create(op.getLoc(), it.value(), memrefAllocOp, - idx); + Value idx = rewriter.create(loc, it.index()); + rewriter.create(loc, it.value(), memrefAllocOp, idx); } - // Replace all uses of the tensor result with the memref + // replace all uses of the tensor result with the memref rewriter.replaceOp(op, memrefAllocOp); return success(); } }; +/** + * @brief Converts tensor.extract to memref.load for qubits + * + * @par Example: + * ```mlir + * %q0 = tensor.extract %tensor[%c0] : tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * %q0 = memref.load %memref[%c0] : memref<3x!qco.qubit> + * ``` + */ struct ConvertQCOTensorExtractOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -887,25 +910,34 @@ struct ConvertQCOTensorExtractOp final LogicalResult matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - const auto memref = adaptor.getTensor(); - const auto loadOp = memref::LoadOp::create(rewriter, op.getLoc(), memref, - adaptor.getIndices()); - rewriter.replaceOp(op, loadOp); + rewriter.replaceOpWithNewOp(op, adaptor.getTensor(), + adaptor.getIndices()); return success(); } }; +/** + * @brief Removes tensor.insert for qubits + * + * @par Example: + * ```mlir + * %new_tensor = tensor.insert %q0 into %tensor[%c0] : tensor<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * ``` + */ struct ConvertQCOTensorInsertOp final : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(tensor::InsertOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - rewriter.replaceOp(op, adaptor.getDest()); return success(); } }; + /** * @brief Converts scf.if with value semantics to scf.if with memory semantics * for qubit values. This currently assumes no mixed types as return values. @@ -1020,7 +1052,7 @@ struct ConvertQCOScfWhileOp final : OpConversionPattern { /** * @brief Converts scf.for with value semantics to scf.for with memory * semantics for qubit values. This currently assumes no mixed types as return - * values. + * values except for qco.qubits and tensors of qco.qubits. * * @par Example: * ```mlir @@ -1049,7 +1081,7 @@ struct ConvertQCOScfForOp final : OpConversionPattern { rewriter, op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep(), ValueRange{}); - // Replace the uses of the previous iter_args + // Replace the uses of the previous iter_args and the induction variable for (const auto& [qcoQubit, qcQubit] : llvm::zip_equal(op.getRegionIterArgs(), adaptor.getInitArgs())) { qcoQubit.replaceAllUsesWith(qcQubit); @@ -1062,14 +1094,15 @@ struct ConvertQCOScfForOp final : OpConversionPattern { newBlock->getOperations().splice(newBlock->begin(), srcOps, srcOps.begin(), std::prev(srcOps.end())); - // find the init args that are tensors + // Find the init args that are tensors for (auto initArg : llvm::enumerate(op.getInitArgs())) { - auto value = initArg.value(); + const auto value = initArg.value(); if (llvm::isa(value.getType())) { - // find the equivalent memref register from the adaptor + // Find the equivalent memref register from the adaptor const auto memref = adaptor.getInitArgs()[initArg.index()]; SmallVector qcQubits; - // get the qc qubits from them + + // Get the qc qubits from them const auto memrefUsers = llvm::to_vector(memref.getUsers()); for (auto* user : llvm::reverse(memrefUsers)) { if (llvm::isa(user)) { @@ -1077,18 +1110,23 @@ struct ConvertQCOScfForOp final : OpConversionPattern { qcQubits.push_back(storeOp.getValueToStore()); } } - // get the users of the returned tensor + + // Get the users of the result tensor of the current operation const auto users = llvm::to_vector(op->getResult(initArg.index()).getUsers()); for (auto user : llvm::enumerate(llvm::reverse(users))) { - if (llvm::isa(user.value())) { - rewriter.replaceAllUsesWith(user.value()->getResult(0), + auto* const extractOp = user.value(); + if (llvm::isa(extractOp)) { + // Replace the extract operations with the values of the memref + // register and delete the extract operation + rewriter.replaceAllUsesWith(extractOp->getResult(0), qcQubits[user.index()]); - rewriter.eraseOp(user.value()); + rewriter.eraseOp(extractOp); } } } } + // Replace the result values with the init values rewriter.replaceOp(op, adaptor.getInitArgs()); return success(); @@ -1277,6 +1315,12 @@ struct QCOToQC final : impl::QCOToQCBase { RewritePatternSet patterns(context); const QCOToQCTypeConverter typeConverter(context); + // Configure conversion target: QCO illegal, QC, Arith, MemRef legal + target.addIllegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addDynamicallyLegalOp( [&](tensor::FromElementsOp op) { return !llvm::any_of(op.getOperandTypes(), [&](Type type) { @@ -1338,11 +1382,7 @@ struct QCOToQC final : impl::QCOToQCBase { type == qc::QubitType::get(context); }); }); - // Configure conversion target: QCO illegal, QC legal - target.addIllegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); + // Register operation conversion patterns // Note: No state tracking needed - OpAdaptors handle type conversion patterns diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 73cf2bafe3..acb6632dc1 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1287,6 +1287,18 @@ struct ConvertQCYieldOp final : StatefulOpConversionPattern { } }; +/** + * @brief Converts memref.alloc to tensor.from_elements for qubits + * + * @par Example: + * ```mlir + * %alloc = memref.alloc() : memref<3x!qc.qubit> + * ``` + * is converted to + * ```mlir + * %tensor = tensor.from_elements %q0, %q1, %q2 : tensore<3x!qco.qubit> + * ``` + */ struct ConvertQCMemRefAllocOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -1296,6 +1308,7 @@ struct ConvertQCMemRefAllocOp final ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + // Get the qco qubits from the users SmallVector qcoQubits; const auto users = llvm::to_vector(op->getUsers()); for (auto* user : llvm::reverse(users)) { @@ -1307,15 +1320,29 @@ struct ConvertQCMemRefAllocOp final auto const qcoType = qco::QubitType::get(rewriter.getContext()); const auto tensorType = RankedTensorType::get( {static_cast(qcoQubits.size())}, qcoType); + // Create the FromElements operation auto fromElements = tensor::FromElementsOp::create(rewriter, op->getLoc(), tensorType, qcoQubits); + // Add them to the qubitMap qubitMap.try_emplace(op->getResult(0), fromElements->getResult(0)); - rewriter.eraseOp(op); + // Erase the old operation + rewriter.eraseOp(op); return success(); } }; +/** + * @brief Deletes the memref.store operation for qubits + * + * @par Example: + * ```mlir + * memref.store %q0, %alloc[%c0] : memref<3x!qc.qubit> + * ``` + * is converted to + * ```mlir + * ``` + */ struct ConvertQCMemRefStoreOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -1328,6 +1355,18 @@ struct ConvertQCMemRefStoreOp final } }; +/** + * @brief Converts memref.load to tensor.extract for qubits + * + * @par Example: + * ```mlir + * %q0 = memref.load %memref[%c0] : memref<3x!qco.qubit> + * ``` + * is converted to + * ```mlir + * %q0 = tensor.extract %tensor[%c0] : tensor<3x!qco.qubit> + * ``` + */ struct ConvertQCMemRefLoadOp final : StatefulOpConversionPattern { using StatefulOpConversionPattern::StatefulOpConversionPattern; @@ -1336,13 +1375,15 @@ struct ConvertQCMemRefLoadOp final matchAndRewrite(memref::LoadOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - auto tensor = qubitMap[op.getMemRef()]; + const auto tensor = qubitMap[op.getMemRef()]; auto const qcoType = qco::QubitType::get(rewriter.getContext()); - + // Create the extract operation auto extractOp = tensor::ExtractOp::create(rewriter, op->getLoc(), qcoType, tensor, {op.getIndices()}); + // Update the qubitMap + qubitMap[op.getResult()] = extractOp.getResult(); - qubitMap.try_emplace(op.getResult(), extractOp.getResult()); + // Erase the old operation rewriter.eraseOp(op); return success(); } @@ -1380,32 +1421,32 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { const auto& qcQubits = regionMap[op]; const SmallVector qcValues(qcQubits.begin(), qcQubits.end()); - // create result typerange + // Create result typerange const SmallVector qcoTypes( qcQubits.size(), qco::QubitType::get(rewriter.getContext())); - // create new if operation + // Create new if operation auto newIfOp = scf::IfOp::create(rewriter, op->getLoc(), TypeRange{qcoTypes}, op.getCondition(), op.getElseRegion().empty()); auto& thenRegion = newIfOp.getThenRegion(); auto& elseRegion = newIfOp.getElseRegion(); - // move the regions of the old operations inside the new operation + // Move the regions of the old operations inside the new operation rewriter.inlineRegionBefore(op.getThenRegion(), thenRegion, thenRegion.end()); - // eliminate the empty block that was created during the initialization + // Eliminate the empty block that was created during the initialization rewriter.eraseBlock(&thenRegion.front()); if (!op.getElseRegion().empty()) { rewriter.inlineRegionBefore(op.getElseRegion(), elseRegion, elseRegion.end()); } else { - // create the yield operation if it does not exist yet + // Create the yield operation if it does not exist yet rewriter.setInsertionPointToEnd(&elseRegion.front()); const auto elseYield = scf::YieldOp::create(rewriter, op->getLoc(), qcValues); - // mark the yield operation for conversion + // Mark the yield operation for conversion elseYield->setAttr("needChange", StringAttr::get(rewriter.getContext(), "yes")); } @@ -1413,7 +1454,7 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { auto& thenRegionQubitMap = getState().qubitMap[&thenRegion]; auto& elseRegionQubitMap = getState().qubitMap[&elseRegion]; - // create the qubit map for the regions and update the qubit map for the + // Create the qubit map for the regions and update the qubit map for the // current region for (const auto& [qcQubit, qcoQubit] : llvm::zip_equal(qcQubits, newIfOp->getResults())) { @@ -1423,7 +1464,7 @@ struct ConvertQCScfIfOp final : StatefulOpConversionPattern { qubitMap[qcQubit] = qcoQubit; } - // replace the old entry in the regionMap with the new operation + // Replace the old entry in the regionMap with the new operation const auto& it = regionMap.find(op); const auto values = std::move(it->second); regionMap.erase(op); @@ -1476,23 +1517,23 @@ struct ConvertQCScfWhileOp final : StatefulOpConversionPattern { assert(qubitMap.contains(qcQubit) && "QC qubit not found"); qcoQubits.push_back(qubitMap[qcQubit]); } - // create the result typerange + // Create the result typerange const SmallVector qcoTypes( qcQubits.size(), qco::QubitType::get(rewriter.getContext())); - // create the new while operation + // Create the new while operation auto newWhileOp = scf::WhileOp::create( rewriter, op.getLoc(), TypeRange(qcoTypes), ValueRange(qcoQubits)); auto& newBeforeRegion = newWhileOp.getBefore(); auto& newAfterRegion = newWhileOp.getAfter(); const SmallVector locs(qcQubits.size(), op->getLoc()); - // create the new blocks + // Create the new blocks auto* newBeforeBlock = rewriter.createBlock(&newBeforeRegion, {}, qcoTypes, locs); auto* newAfterBlock = rewriter.createBlock(&newAfterRegion, {}, qcoTypes, locs); - // move the operations to the new blocks + // Move the operations to the new blocks newBeforeBlock->getOperations().splice(newBeforeBlock->end(), op.getBeforeBody()->getOperations()); newAfterBlock->getOperations().splice(newAfterBlock->end(), @@ -1501,7 +1542,7 @@ struct ConvertQCScfWhileOp final : StatefulOpConversionPattern { auto& newBeforeRegionMap = getState().qubitMap[&newWhileOp.getBefore()]; auto& newAfterRegionMap = getState().qubitMap[&newWhileOp.getAfter()]; - // create the qubit map for the new regions and update the qubit map in the + // Create the qubit map for the new regions and update the qubit map in the // current region for (const auto& [qcQubit, beforeArg, afterArg, qcoQubit] : llvm::zip_equal( qcQubits, newWhileOp.getBeforeArguments(), @@ -1511,7 +1552,7 @@ struct ConvertQCScfWhileOp final : StatefulOpConversionPattern { qubitMap[qcQubit] = qcoQubit; } - // replace the old entry in the regionMap with the new operation + // Replace the old entry in the regionMap with the new operation const auto& it = regionMap.find(op); const auto values = std::move(it->second); regionMap.erase(op); @@ -1564,7 +1605,7 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { rewriter, op.getLoc(), adaptor.getLowerBound(), adaptor.getUpperBound(), adaptor.getStep(), ValueRange(qcoQubits)); - // move the operations to the new block + // Move the operations to the new block auto& srcBlock = op.getRegion().front(); auto& dstBlock = newFor.getRegion().front(); @@ -1574,36 +1615,36 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { auto& newRegion = newFor.getRegion(); auto& regionQubitMap = getState().qubitMap[&newRegion]; - // create the qubitmap for the new region and update the qubitmap in the + // Create the qubitmap for the new region and update the qubitmap in the // current region for (const auto& [qcQubit, iterArg, qcoQubit] : llvm::zip_equal( qcQubits, newFor.getRegionIterArgs(), newFor->getResults())) { regionQubitMap.try_emplace(qcQubit, iterArg); qubitMap[qcQubit] = qcoQubit; - // if the value of the qc qubit is a memref register, extract each value + // If the value of the qc qubit is a memref register, extract each value // from the new tensor and update the qubitmap for each value if (llvm::isa(qcQubit.getType())) { - // get all the qubits that were stored in the memref register + // Get all the qubits that were stored in the memref register for (const auto* user : qcQubit.getUsers()) { if (auto storeOp = dyn_cast(user)) { - // get the qubit + // gGet the qubit const auto qubit = storeOp.getValueToStore(); auto const qcoType = qco::QubitType::get(rewriter.getContext()); - // create the extract operation for each qubit from the resulting + // Create the extract operation for each qubit from the resulting // tensor of the scf.for operation auto extractOp = tensor::ExtractOp::create(rewriter, op->getLoc(), qcoType, qcoQubit, {storeOp.getIndices()}); - // update the qubit map for each of them + // Update the qubit map for each of them qubitMap[qubit] = extractOp.getResult(); } } } } - // replace the old entry in the regionMap with the new operation + // Replace the old entry in the regionMap with the new operation const auto& it = regionMap.find(op); const auto values = std::move(it->second); regionMap.erase(op); @@ -1894,6 +1935,7 @@ struct QCToQCO final : impl::QCToQCOBase { RewritePatternSet patterns(context); QCToQCOTypeConverter typeConverter(context); + // Collect the qubits for each region collectUniqueQubits(module, &state, context); // Configure conversion target: QC illegal, QCO and tensor // legal @@ -1940,6 +1982,7 @@ struct QCToQCO final : impl::QCToQCOBase { return !llvm::any_of(op->getResultTypes(), [&](Type type) { return isQubitType(type); }); }); + // Register operation conversion patterns with state // tracking patterns From b48080184b472d96a0633a46c47bfc71bf6d0941 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 22 Jan 2026 16:56:28 +0100 Subject: [PATCH 081/108] add QC builders for memref --- .../Dialect/QC/Builder/QCProgramBuilder.h | 7 ++++++ .../Dialect/QC/Builder/QCProgramBuilder.cpp | 24 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index 85c9a019dd..edf20e3606 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -869,6 +869,13 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { */ QCProgramBuilder& dealloc(Value qubit); + //===--------------------------------------------------------------------===// + // MemRef operations + //===--------------------------------------------------------------------===// + Value memrefAlloc(ValueRange elements); + + Value memrefLoad(Value memref, const std::variant& index); + //===--------------------------------------------------------------------===// // SCF operations //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index d6cbbed815..99c5813021 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/QC/Builder/QCProgramBuilder.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/Utils/Utils.h" @@ -444,6 +445,29 @@ QCProgramBuilder& QCProgramBuilder::dealloc(Value qubit) { return *this; } +//===----------------------------------------------------------------------===// +// MemRef operations +//===----------------------------------------------------------------------===// +Value QCProgramBuilder::memrefAlloc(ValueRange elements) { + const auto qcType = qc::QubitType::get(ctx); + const auto memType = + MemRefType::get({static_cast(elements.size())}, qcType); + auto allocOp = memref::AllocOp::create(*this, memType); + for (auto it : llvm::enumerate(elements)) { + Value idx = arith::ConstantOp::create( + *this, getIndexAttr(static_cast(it.index()))); + memref::StoreOp::create(*this, it.value(), allocOp, idx); + } + return allocOp.getResult(); +} + +Value QCProgramBuilder::memrefLoad(Value memref, + const std::variant& index) { + const auto indexValue = utils::variantToValue(*this, getLoc(), index); + const auto loadOp = memref::LoadOp::create(*this, memref, indexValue); + return loadOp->getResult(0); +} + //===----------------------------------------------------------------------===// // SCF operations //===----------------------------------------------------------------------===// From b320196d7bcb3935075f81b8e91a8b9ea132d574 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 22 Jan 2026 16:57:01 +0100 Subject: [PATCH 082/108] add QCO builders for tensor --- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 11 ++++ .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 64 ++++++++++++++++++- 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index fb77b9347d..5de7fa4385 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1029,6 +1029,17 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { */ QCOProgramBuilder& dealloc(Value qubit); + //===--------------------------------------------------------------------===// + // Tensor operations + //===--------------------------------------------------------------------===// + + Value tensorFromElements(ValueRange elements); + + Value tensorExtract(Value tensor, const std::variant& index); + + Value tensorInsert(Value element, Value tensor, + const std::variant& index); + //===--------------------------------------------------------------------===// // SCF operations //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 166bd90348..94b6ae2462 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/Utils/Utils.h" +#include "mlir/IR/BuiltinTypes.h" #include #include @@ -23,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -135,6 +137,7 @@ void QCOProgramBuilder::validateQubitValue(Value qubit, Region* region) const { auto qubits = validQubits.lookup(region); if (qubits.empty() || !qubits.contains(qubit)) { + qubit.print(llvm::outs()); llvm::errs() << "Attempting to use an invalid qubit SSA value. " << "The value may have been consumed by a previous operation " << "or was never created through this builder.\n"; @@ -615,6 +618,62 @@ QCOProgramBuilder& QCOProgramBuilder::dealloc(Value qubit) { return *this; } +//===----------------------------------------------------------------------===// +// Tensor operations +//===----------------------------------------------------------------------===// + +Value QCOProgramBuilder::tensorFromElements(ValueRange elements) { + checkFinalized(); + auto const qcoType = qco::QubitType::get(ctx); + const auto tensorType = + RankedTensorType::get({static_cast(elements.size())}, qcoType); + // Create the FromElements operation + auto fromElements = + tensor::FromElementsOp::create(*this, tensorType, elements); + return fromElements.getResult(); +} + +Value QCOProgramBuilder::tensorExtract( + Value tensor, const std::variant& index) { + checkFinalized(); + + auto const qcoType = qco::QubitType::get(ctx); + const auto indexValue = utils::variantToValue(*this, getLoc(), index); + auto extractOp = + tensor::ExtractOp::create(*this, qcoType, tensor, indexValue); + auto* const extractParentRegion = extractOp->getParentRegion(); + if (auto scfFor = tensor.getDefiningOp()) { + for (auto arg : scfFor.getInitArgs()) { + if (llvm::isa(arg.getType())) { + auto fromTensorOp = arg.getDefiningOp(); + int64_t val = 0; + if (std::holds_alternative(index)) { + val = std::get(index); + } else { + auto constantOp = + std::get(index).getDefiningOp(); + val = dyn_cast(constantOp.getValue()).getInt(); + } + updateQubitTracking(fromTensorOp.getElements()[val], + extractOp.getResult(), + extractOp->getParentRegion()); + } + } + } + if (!llvm::isa(extractParentRegion->getParentOp())) { + validQubits[extractOp->getParentRegion()].insert(extractOp); + } + + return extractOp.getResult(); +} + +Value QCOProgramBuilder::tensorInsert( + Value element, Value tensor, const std::variant& index) { + checkFinalized(); + const auto indexValue = utils::variantToValue(*this, getLoc(), index); + auto insertOp = tensor::InsertOp::create(*this, element, tensor, indexValue); + return insertOp.getResult(); +} //===----------------------------------------------------------------------===// // SCF operations //===----------------------------------------------------------------------===// @@ -653,7 +712,10 @@ ValueRange QCOProgramBuilder::scfFor( // Update the qubit tracking for (const auto& [initArg, result] : llvm::zip_equal(initArgs, forOp.getResults())) { - updateQubitTracking(initArg, result, forOp->getParentRegion()); + if (!llvm::isa(initArg.getType())) { + + updateQubitTracking(initArg, result, forOp->getParentRegion()); + } } return forOp->getResults(); From ec8936b1413be723c108c970a322dba50893b6e1 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 22 Jan 2026 16:58:17 +0100 Subject: [PATCH 083/108] fix order of users --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index acb6632dc1..5f43542a69 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1379,7 +1379,7 @@ struct ConvertQCMemRefLoadOp final auto const qcoType = qco::QubitType::get(rewriter.getContext()); // Create the extract operation auto extractOp = tensor::ExtractOp::create(rewriter, op->getLoc(), qcoType, - tensor, {op.getIndices()}); + tensor, op.getIndices()); // Update the qubitMap qubitMap[op.getResult()] = extractOp.getResult(); @@ -1592,6 +1592,7 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; auto& regionMap = getState().regionMap; const auto& qcQubits = regionMap[op]; + const auto qcoType = qco::QubitType::get(rewriter.getContext()); SmallVector qcoQubits; qcoQubits.reserve(qcQubits.size()); @@ -1626,11 +1627,11 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { // from the new tensor and update the qubitmap for each value if (llvm::isa(qcQubit.getType())) { // Get all the qubits that were stored in the memref register - for (const auto* user : qcQubit.getUsers()) { + const auto qcQubitUsers = llvm::to_vector(qcQubit.getUsers()); + for (const auto* user : llvm::reverse(qcQubitUsers)) { if (auto storeOp = dyn_cast(user)) { // gGet the qubit const auto qubit = storeOp.getValueToStore(); - auto const qcoType = qco::QubitType::get(rewriter.getContext()); // Create the extract operation for each qubit from the resulting // tensor of the scf.for operation From ca52bf3642bdab37b1ffea391fdd597b64681be4 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 23 Jan 2026 12:42:44 +0100 Subject: [PATCH 084/108] add initial test for registers --- .../QCToQCO/test_conversion_qc_to_qco.cpp | 92 +++++++++++++++++-- 1 file changed, 82 insertions(+), 10 deletions(-) diff --git a/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp index df80a0711b..96f6850a6f 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp @@ -22,14 +22,19 @@ #include #include #include +#include #include +#include #include #include +#include #include #include #include #include #include +#include +#include #include using namespace mlir; @@ -41,19 +46,29 @@ class ConversionTest : public ::testing::Test { // Register all dialects needed for the full compilation pipeline DialectRegistry registry; registry.insert(); + func::FuncDialect, scf::SCFDialect, LLVM::LLVMDialect, + memref::MemRefDialect, tensor::TensorDialect>(); context = std::make_unique(); context->appendDialectRegistry(registry); context->loadAllAvailableDialects(); } + static void runCanonicalizationPass(ModuleOp module) { + PassManager pm(module.getContext()); + pm.addPass(createCanonicalizerPass()); + if (pm.run(module).failed()) { + llvm::errs() << "Failed to run canonicalization passes.\n"; + } + } + [[nodiscard]] OwningOpRef buildQCIR( const std::function& buildFunc) const { mlir::qc::QCProgramBuilder builder(context.get()); builder.initialize(); buildFunc(builder); auto module = builder.finalize(); + runCanonicalizationPass(module.get()); return module; } [[nodiscard]] OwningOpRef buildQCOIR( @@ -62,17 +77,36 @@ class ConversionTest : public ::testing::Test { builder.initialize(); buildFunc(builder); auto module = builder.finalize(); + runCanonicalizationPass(module.get()); return module; } -}; -static std::string getOutputString(mlir::OwningOpRef& module) { - std::string outputString; - llvm::raw_string_ostream os(outputString); - module->print(os); - os.flush(); - return outputString; -} + static std::string + getOutputString(mlir::OwningOpRef& module) { + std::string outputString; + llvm::raw_string_ostream os(outputString); + + auto* moduleOp = module->getOperation(); + const auto* qcoDialect = + moduleOp->getContext()->getLoadedDialect(); + const auto* scfDialect = + moduleOp->getContext()->getLoadedDialect(); + + moduleOp->walk([&](Operation* op) -> WalkResult { + const auto* opDialect = op->getDialect(); + // Only consider operations from the qco dialect and the scf dialect or + // func.call or func.return op + if (opDialect == qcoDialect || opDialect == scfDialect || + llvm::isa(op) || llvm::isa(op)) { + op->print(os); + } + return WalkResult::advance(); + }); + + os.flush(); + return outputString; + } +}; TEST_F(ConversionTest, ScfForQCToQCOTest) { // Test conversion from qc to qco for scf.for operation @@ -274,7 +308,6 @@ TEST_F(ConversionTest, FuncFuncQCToQCOTest) { const auto outputString = getOutputString(input); const auto checkString = getOutputString(expectedOutput); - ASSERT_EQ(outputString, checkString); } @@ -322,3 +355,42 @@ TEST_F(ConversionTest, ScfCtrlQCtoQCOTest) { ASSERT_EQ(outputString, checkString); } + +TEST_F(ConversionTest, ScfCtrlQCtoQCOTest2) { + // Test conversion from qc to qco for scf.for operation with a memref register + auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto reg = b.allocQubitRegister(4); + auto memref = b.memrefAlloc(reg); + b.scfFor(0, 3, 1, [&](Value iv) { + auto extractedQubit = b.memrefLoad(memref, iv); + b.h(extractedQubit); + }); + }); + + PassManager pm(context.get()); + pm.addPass(createQCToQCO()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QC-QCO conversion for scf nested"; + } + + auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto reg = b.allocQubitRegister(4); + auto tensor = b.tensorFromElements(reg); + auto scfForRes = b.scfFor( + 0, 3, 1, {tensor}, + [&](Value iv, ValueRange iterArgs) -> llvm::SmallVector { + auto extractedQubit = b.tensorExtract(iterArgs[0], iv); + auto q4 = b.h(extractedQubit); + auto newTensor = b.tensorInsert(q4, iterArgs[0], iv); + return {newTensor}; + }); + auto extractedq0 = b.tensorExtract(scfForRes[0], 0); + auto extractedq1 = b.tensorExtract(scfForRes[0], 1); + auto extractedq2 = b.tensorExtract(scfForRes[0], 2); + auto extractedq3 = b.tensorExtract(scfForRes[0], 3); + }); + + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); + ASSERT_EQ(outputString, checkString); +} From 6473046976a26d82cbfb2dc7def3b9b1e9c5e0ba Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 23 Jan 2026 15:13:19 +0100 Subject: [PATCH 085/108] trying to fix nested operations for qc conversion --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 47 +++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 5f43542a69..1d4076b9e7 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -166,12 +166,35 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { // check if the operation has an region, if yes recursively collect the // qubits if (operation.getNumRegions() > 0) { - const auto& qubits = collectUniqueQubits(&operation, state, ctx); + auto qubits = collectUniqueQubits(&operation, state, ctx); + + qubits.remove_if([&](Value qubit) { + return llvm::isa(qubit.getType()) || + (llvm::isa(qubit.getDefiningOp()) && + ®ion == qubit.getParentRegion()); + }); + uniqueQubits.set_union(qubits); } + + if (llvm::isa(operation)) { + if (llvm::isa(operation.getParentOp())) { + continue; + } + } + if (llvm::isa(operation)) { + auto loadOp = dyn_cast(operation); + uniqueQubits.insert(loadOp.getMemRef()); + continue; + } // collect qubits form the operands for (const auto& operand : operation.getOperands()) { - if (operand.getDefiningOp()) { + if ((operand.getDefiningOp() || + operand.getDefiningOp())) { + continue; + } + if (operand.getDefiningOp() && + llvm::isa(op)) { continue; } if (isQubitType(operand.getType())) { @@ -207,6 +230,18 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { } } } + for (const auto& operand : op->getOperands()) { + if ((operand.getDefiningOp() || + operand.getDefiningOp())) { + continue; + } + if (operand.getDefiningOp() && llvm::isa(op)) { + continue; + } + if (isQubitType(operand.getType())) { + uniqueQubits.insert(operand); + } + } // mark scf operations that need to be changed afterwards if (!uniqueQubits.empty() && (llvm::isa(op) || (llvm::isa(op)) || @@ -1198,6 +1233,7 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { const auto& qcControls = op.getControls(); SmallVector qcoControls; qcoControls.reserve(qcControls.size()); + for (const auto& qcControl : qcControls) { assert(qubitMap.contains(qcControl) && "QC qubit not found"); qcoControls.push_back(qubitMap[qcControl]); @@ -1616,11 +1652,16 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { auto& newRegion = newFor.getRegion(); auto& regionQubitMap = getState().qubitMap[&newRegion]; + // Copy the qubit Map into the region + for (const auto& [key, value] : qubitMap) { + regionQubitMap[key] = value; + } + // Create the qubitmap for the new region and update the qubitmap in the // current region for (const auto& [qcQubit, iterArg, qcoQubit] : llvm::zip_equal( qcQubits, newFor.getRegionIterArgs(), newFor->getResults())) { - regionQubitMap.try_emplace(qcQubit, iterArg); + regionQubitMap[qcQubit] = iterArg; qubitMap[qcQubit] = qcoQubit; // If the value of the qc qubit is a memref register, extract each value From 9be2e7cff6b310d0412e20106bb19c6a18cc425e Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 23 Jan 2026 17:10:42 +0100 Subject: [PATCH 086/108] fix nested tensor conversion in QCO --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 73 ++++++++++++------------- 1 file changed, 35 insertions(+), 38 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 3d5bd37943..89b24e1a65 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -910,8 +910,35 @@ struct ConvertQCOTensorExtractOp final LogicalResult matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - rewriter.replaceOpWithNewOp(op, adaptor.getTensor(), - adaptor.getIndices()); + // Remove the extract operations following a scf.for operation + if (!llvm::isa(op.getOperand(0).getType())) { + // Find the memref register + const auto memref = adaptor.getTensor(); + const auto memrefUsers = llvm::to_vector(memref.getUsers()); + // Get the index where the value was extracted + int64_t index = -1; + auto constantOp = + adaptor.getIndices().front().getDefiningOp(); + const auto indexToStore = + dyn_cast(constantOp.getValue()).getInt(); + // Find the appropriate store operation depending on the index to get the + // qubit + for (auto* user : llvm::reverse(memrefUsers)) { + if (llvm::isa(user)) { + index++; + if (index == indexToStore) { + auto storeOp = dyn_cast(user); + rewriter.replaceOp(op, storeOp.getValueToStore()); + } + } + } + } + // Replace other extract operations with a memref.load operation + else { + rewriter.replaceOpWithNewOp(op, adaptor.getTensor(), + adaptor.getIndices()); + } + return success(); } }; @@ -1094,39 +1121,6 @@ struct ConvertQCOScfForOp final : OpConversionPattern { newBlock->getOperations().splice(newBlock->begin(), srcOps, srcOps.begin(), std::prev(srcOps.end())); - // Find the init args that are tensors - for (auto initArg : llvm::enumerate(op.getInitArgs())) { - const auto value = initArg.value(); - if (llvm::isa(value.getType())) { - // Find the equivalent memref register from the adaptor - const auto memref = adaptor.getInitArgs()[initArg.index()]; - SmallVector qcQubits; - - // Get the qc qubits from them - const auto memrefUsers = llvm::to_vector(memref.getUsers()); - for (auto* user : llvm::reverse(memrefUsers)) { - if (llvm::isa(user)) { - auto storeOp = dyn_cast(user); - qcQubits.push_back(storeOp.getValueToStore()); - } - } - - // Get the users of the result tensor of the current operation - const auto users = - llvm::to_vector(op->getResult(initArg.index()).getUsers()); - for (auto user : llvm::enumerate(llvm::reverse(users))) { - auto* const extractOp = user.value(); - if (llvm::isa(extractOp)) { - // Replace the extract operations with the values of the memref - // register and delete the extract operation - rewriter.replaceAllUsesWith(extractOp->getResult(0), - qcQubits[user.index()]); - rewriter.eraseOp(extractOp); - } - } - } - } - // Replace the result values with the init values rewriter.replaceOp(op, adaptor.getInitArgs()); return success(); @@ -1323,9 +1317,12 @@ struct QCOToQC final : impl::QCOToQCBase { target.addDynamicallyLegalOp( [&](tensor::FromElementsOp op) { - return !llvm::any_of(op.getOperandTypes(), [&](Type type) { - return type == qco::QubitType::get(context); - }); + return !llvm::any_of(op.getOperandTypes(), + [&](Type type) { + return type == qco::QubitType::get(context); + }) && + !(op.getType().getElementType() == + qco::QubitType::get(context)); }); target.addDynamicallyLegalOp([&](tensor::ExtractOp op) { return !llvm::any_of(op->getResultTypes(), [&](Type type) { From ffb23132d82c9418b45291a22c83eddf28675130 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 23 Jan 2026 17:26:24 +0100 Subject: [PATCH 087/108] fix issue with func arguments --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 50 ++++++++++++------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 1d4076b9e7..420830bbd0 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -142,18 +142,18 @@ static bool isQubitType(Type type) { */ static llvm::SetVector collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { - // get the regions of the current operation + // Get the regions of the current operation const auto& regions = op->getRegions(); SetVector uniqueQubits; for (auto& region : regions) { - // skip empty regions e.g. empty else region of an If operation + // Skip empty regions e.g. empty else region of an If operation if (region.empty()) { continue; } - // check that the region has only one block + // Check that the region has only one block assert(region.hasOneBlock() && "Expected single-block region"); - // collect qubits from the blockarguments + // Collect qubits from the blockarguments for (auto arg : region.front().getArguments()) { if (isQubitType(arg.getType())) { uniqueQubits.insert(arg); @@ -167,32 +167,37 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { // qubits if (operation.getNumRegions() > 0) { auto qubits = collectUniqueQubits(&operation, state, ctx); - - qubits.remove_if([&](Value qubit) { - return llvm::isa(qubit.getType()) || - (llvm::isa(qubit.getDefiningOp()) && - ®ion == qubit.getParentRegion()); - }); - + // Remove the memref registers and exclude qubits from load operations + // in the same region in the set of unique qubits + if (!llvm::isa(operation)) { + qubits.remove_if([&](Value qubit) { + return llvm::isa(qubit.getType()) || + (llvm::isa(qubit.getDefiningOp()) && + ®ion == qubit.getParentRegion()); + }); + } uniqueQubits.set_union(qubits); } - + // Ignore the alloc operations inside scf.for operations if (llvm::isa(operation)) { if (llvm::isa(operation.getParentOp())) { continue; } } + // Only add the memref to the register when the load operation is matched if (llvm::isa(operation)) { auto loadOp = dyn_cast(operation); uniqueQubits.insert(loadOp.getMemRef()); continue; } - // collect qubits form the operands + // Collect qubits form the operands for (const auto& operand : operation.getOperands()) { + // Ignore the values from memref store and alloc operations if ((operand.getDefiningOp() || operand.getDefiningOp())) { continue; } + // Ignore the qubits that stems from load operations if (operand.getDefiningOp() && llvm::isa(op)) { continue; @@ -201,23 +206,20 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { uniqueQubits.insert(operand); } } - // collect qubits from the results + // Collect qubits from the results for (const auto& result : operation.getResults()) { - if (llvm::isa(operation)) { - break; - } if (isQubitType(result.getType())) { uniqueQubits.insert(result); } } - // mark scf terminator operations if they need to return a value after the + // Mark scf terminator operations if they need to return a value after the // conversion if ((llvm::isa(operation) || llvm::isa(operation)) && !uniqueQubits.empty()) { operation.setAttr("needChange", StringAttr::get(ctx, "yes")); } - // mark func.return operation for functions that need to return a qubit + // Mark func.return operation for functions that need to return a qubit // value if (llvm::isa(operation)) { if (auto func = operation.getParentOfType()) { @@ -230,19 +232,13 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { } } } + // Add the operands from the operation itself for (const auto& operand : op->getOperands()) { - if ((operand.getDefiningOp() || - operand.getDefiningOp())) { - continue; - } - if (operand.getDefiningOp() && llvm::isa(op)) { - continue; - } if (isQubitType(operand.getType())) { uniqueQubits.insert(operand); } } - // mark scf operations that need to be changed afterwards + // Mark scf operations that need to be changed afterwards if (!uniqueQubits.empty() && (llvm::isa(op) || (llvm::isa(op)) || llvm::isa(op))) { From 8a411cf72fbcf3fcb758ebf68fe2b47acaf4e70b Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 23 Jan 2026 17:28:48 +0100 Subject: [PATCH 088/108] add tests for qc to qco conversion with registers --- .../QCToQCO/test_conversion_qc_to_qco.cpp | 76 ++++++++++++++++++- 1 file changed, 74 insertions(+), 2 deletions(-) diff --git a/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp index 96f6850a6f..a672512840 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp @@ -27,7 +27,7 @@ #include #include #include -#include +#include #include #include #include @@ -94,6 +94,11 @@ class ConversionTest : public ::testing::Test { moduleOp->walk([&](Operation* op) -> WalkResult { const auto* opDialect = op->getDialect(); + + // Ignore dealloc operations as the order does not matter + if (llvm::isa(op)) { + return WalkResult::advance(); + } // Only consider operations from the qco dialect and the scf dialect or // func.call or func.return op if (opDialect == qcoDialect || opDialect == scfDialect || @@ -287,7 +292,6 @@ TEST_F(ConversionTest, FuncFuncQCToQCOTest) { b.y(args[0]); }); }); - PassManager pm(context.get()); pm.addPass(createQCToQCO()); if (failed(pm.run(input.get()))) { @@ -365,6 +369,8 @@ TEST_F(ConversionTest, ScfCtrlQCtoQCOTest2) { auto extractedQubit = b.memrefLoad(memref, iv); b.h(extractedQubit); }); + b.swap(reg[0], reg[1]); + b.swap(reg[2], reg[3]); }); PassManager pm(context.get()); @@ -388,9 +394,75 @@ TEST_F(ConversionTest, ScfCtrlQCtoQCOTest2) { auto extractedq1 = b.tensorExtract(scfForRes[0], 1); auto extractedq2 = b.tensorExtract(scfForRes[0], 2); auto extractedq3 = b.tensorExtract(scfForRes[0], 3); + b.swap(extractedq0, extractedq1); + b.swap(extractedq2, extractedq3); }); const auto outputString = getOutputString(input); const auto checkString = getOutputString(expectedOutput); ASSERT_EQ(outputString, checkString); } + +TEST_F(ConversionTest, ScfCtrlQCtoQCOTest3) { + // Test conversion from qc to qco for scf.for operation with a memref register + auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto reg0 = b.allocQubitRegister(4, "q0"); + auto reg1 = b.allocQubitRegister(4, "q1"); + auto memref0 = b.memrefAlloc(reg0); + b.scfFor(0, 3, 1, [&](Value iv) { + auto extractedQubit = b.memrefLoad(memref0, iv); + b.x(extractedQubit); + auto memref1 = b.memrefAlloc(reg1); + b.scfFor(0, 3, 1, [&](Value iv2) { + auto q1 = b.memrefLoad(memref1, iv2); + b.cx(extractedQubit, q1); + }); + }); + b.swap(reg0[0], reg0[1]); + b.swap(reg0[2], reg0[3]); + }); + PassManager pm(context.get()); + pm.addPass(createQCToQCO()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QC-QCO conversion for scf nested"; + } + + auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto reg0 = b.allocQubitRegister(4, "q0"); + auto reg1 = b.allocQubitRegister(4, "q1"); + auto tensor0 = b.tensorFromElements(reg0); + auto scfForRes = b.scfFor( + 0, 3, 1, {tensor0, reg1[0], reg1[1], reg1[2], reg1[3]}, + [&](Value iv, ValueRange iterArgs) -> llvm::SmallVector { + auto extractedQubit = b.tensorExtract(iterArgs[0], iv); + auto outerQubit = b.x(extractedQubit); + auto tensor1 = b.tensorFromElements( + {iterArgs[1], iterArgs[2], iterArgs[3], iterArgs[4]}); + auto innerResults = b.scfFor( + 0, 3, 1, {tensor1, outerQubit}, + [&](Value innerIv, + ValueRange innerIterArgs) -> llvm::SmallVector { + auto innerQubit = b.tensorExtract(innerIterArgs[0], innerIv); + auto ctrlOp = b.cx(innerIterArgs[1], innerQubit); + auto innerTensor = + b.tensorInsert(ctrlOp.second, innerIterArgs[0], innerIv); + return {innerTensor, ctrlOp.first}; + }); + auto extractedq0 = b.tensorExtract(innerResults[0], 0); + auto extractedq1 = b.tensorExtract(innerResults[0], 1); + auto extractedq2 = b.tensorExtract(innerResults[0], 2); + auto extractedq3 = b.tensorExtract(innerResults[0], 3); + auto tensor2 = b.tensorInsert(innerResults[1], iterArgs[0], iv); + return {tensor2, extractedq0, extractedq1, extractedq2, extractedq3}; + }); + auto extractedq0 = b.tensorExtract(scfForRes[0], 0); + auto extractedq1 = b.tensorExtract(scfForRes[0], 1); + auto extractedq2 = b.tensorExtract(scfForRes[0], 2); + auto extractedq3 = b.tensorExtract(scfForRes[0], 3); + b.swap(extractedq0, extractedq1); + b.swap(extractedq2, extractedq3); + }); + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); + ASSERT_EQ(outputString, checkString); +} From fa1a93889386a8ee5d1f7b03cb4aa4b3d3b1513b Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 23 Jan 2026 17:44:43 +0100 Subject: [PATCH 089/108] add tests for qco conversion with tensor --- .../QCOToQC/test_conversion_qco_to_qc.cpp | 155 +++++++++++++++++- 1 file changed, 153 insertions(+), 2 deletions(-) diff --git a/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp index 6256245bd3..d521ea4a5d 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp @@ -22,14 +22,18 @@ #include #include #include +#include #include +#include #include #include +#include #include #include #include #include #include +#include #include using namespace mlir; @@ -41,12 +45,20 @@ class ConversionTest : public ::testing::Test { // Register all dialects needed for the full compilation pipeline DialectRegistry registry; registry.insert(); + func::FuncDialect, scf::SCFDialect, LLVM::LLVMDialect, + tensor::TensorDialect, memref::MemRefDialect>(); context = std::make_unique(); context->appendDialectRegistry(registry); context->loadAllAvailableDialects(); } + static void runCanonicalizationPass(ModuleOp module) { + PassManager pm(module.getContext()); + pm.addPass(createCanonicalizerPass()); + if (pm.run(module).failed()) { + llvm::errs() << "Failed to run canonicalization passes.\n"; + } + } [[nodiscard]] OwningOpRef buildQCIR( const std::function& buildFunc) const { @@ -54,14 +66,17 @@ class ConversionTest : public ::testing::Test { builder.initialize(); buildFunc(builder); auto module = builder.finalize(); + runCanonicalizationPass(module.get()); return module; } + [[nodiscard]] OwningOpRef buildQCOIR( const std::function& buildFunc) const { qco::QCOProgramBuilder builder(context.get()); builder.initialize(); buildFunc(builder); auto module = builder.finalize(); + runCanonicalizationPass(module.get()); return module; } }; @@ -69,7 +84,29 @@ class ConversionTest : public ::testing::Test { static std::string getOutputString(mlir::OwningOpRef& module) { std::string outputString; llvm::raw_string_ostream os(outputString); - module->print(os); + + auto* moduleOp = module->getOperation(); + const auto* qcoDialect = + moduleOp->getContext()->getLoadedDialect(); + const auto* scfDialect = + moduleOp->getContext()->getLoadedDialect(); + + moduleOp->walk([&](Operation* op) -> WalkResult { + const auto* opDialect = op->getDialect(); + + // Ignore dealloc operations as the order does not matter + if (llvm::isa(op)) { + return WalkResult::advance(); + } + // Only consider operations from the qco dialect and the scf dialect or + // func.call or func.return op + if (opDialect == qcoDialect || opDialect == scfDialect || + llvm::isa(op) || llvm::isa(op)) { + op->print(os); + } + return WalkResult::advance(); + }); + os.flush(); return outputString; } @@ -283,3 +320,117 @@ TEST_F(ConversionTest, ScfCtrlQCOtoQCTest) { ASSERT_EQ(outputString, checkString); } + +TEST_F(ConversionTest, ScfForTensorQCOtoQCTest) { + // Test conversion from qco to qc for scf.for operation with a tensor + auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto reg = b.allocQubitRegister(4); + auto tensor = b.tensorFromElements(reg); + auto scfForRes = b.scfFor( + 0, 3, 1, {tensor}, + [&](Value iv, ValueRange iterArgs) -> llvm::SmallVector { + auto extractedQubit = b.tensorExtract(iterArgs[0], iv); + auto q4 = b.h(extractedQubit); + auto newTensor = b.tensorInsert(q4, iterArgs[0], iv); + return {newTensor}; + }); + auto extractedq0 = b.tensorExtract(scfForRes[0], 0); + auto extractedq1 = b.tensorExtract(scfForRes[0], 1); + auto extractedq2 = b.tensorExtract(scfForRes[0], 2); + auto extractedq3 = b.tensorExtract(scfForRes[0], 3); + b.swap(extractedq0, extractedq1); + b.swap(extractedq2, extractedq3); + }); + + PassManager pm(context.get()); + pm.addPass(createQCOToQC()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QC-QCO conversion for scf nested"; + } + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto reg = b.allocQubitRegister(4); + auto memref = b.memrefAlloc(reg); + b.scfFor(0, 3, 1, [&](Value iv) { + auto extractedQubit = b.memrefLoad(memref, iv); + b.h(extractedQubit); + }); + b.swap(reg[0], reg[1]); + b.swap(reg[2], reg[3]); + }); + + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); + ASSERT_EQ(outputString, checkString); +} + +TEST_F(ConversionTest, ScfForNestedTensorQCOtoQCTest) { + // Test conversion from qco to qc for scf.for operation with a nested tesnor + auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { + auto reg0 = b.allocQubitRegister(4, "q0"); + auto reg1 = b.allocQubitRegister(4, "q1"); + auto tensor0 = b.tensorFromElements(reg0); + auto scfForRes = b.scfFor( + 0, 3, 1, {tensor0, reg1[0], reg1[1], reg1[2], reg1[3]}, + [&](Value iv, ValueRange iterArgs) -> llvm::SmallVector { + auto extractedQubit = b.tensorExtract(iterArgs[0], iv); + auto outerQubit = b.x(extractedQubit); + auto tensor1 = b.tensorFromElements( + {iterArgs[1], iterArgs[2], iterArgs[3], iterArgs[4]}); + auto innerResults = b.scfFor( + 0, 3, 1, {tensor1, outerQubit}, + [&](Value innerIv, + ValueRange innerIterArgs) -> llvm::SmallVector { + auto innerQubit = b.tensorExtract(innerIterArgs[0], innerIv); + auto ctrlOp = b.cx(innerIterArgs[1], innerQubit); + auto innerTensor = + b.tensorInsert(ctrlOp.second, innerIterArgs[0], innerIv); + return {innerTensor, ctrlOp.first}; + }); + auto extractedq0 = b.tensorExtract(innerResults[0], 0); + auto extractedq1 = b.tensorExtract(innerResults[0], 1); + auto extractedq2 = b.tensorExtract(innerResults[0], 2); + auto extractedq3 = b.tensorExtract(innerResults[0], 3); + auto tensor2 = b.tensorInsert(innerResults[1], iterArgs[0], iv); + return {tensor2, extractedq0, extractedq1, extractedq2, extractedq3}; + }); + auto extractedq0 = b.tensorExtract(scfForRes[0], 0); + auto extractedq1 = b.tensorExtract(scfForRes[0], 1); + auto extractedq2 = b.tensorExtract(scfForRes[0], 2); + auto extractedq3 = b.tensorExtract(scfForRes[0], 3); + b.swap(extractedq0, extractedq1); + b.swap(extractedq2, extractedq3); + }); + + PassManager pm(context.get()); + pm.addPass(createQCOToQC()); + if (failed(pm.run(input.get()))) { + FAIL() << "Conversion error during QC-QCO conversion for scf nested"; + } + // Run the canonicalizer again to remove the additional constants + pm.clear(); + pm.addPass(createCanonicalizerPass()); + if (failed(pm.run(input.get()))) { + FAIL() << "Error during canonicalization"; + } + + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { + auto reg0 = b.allocQubitRegister(4, "q0"); + auto reg1 = b.allocQubitRegister(4, "q1"); + auto memref0 = b.memrefAlloc(reg0); + b.scfFor(0, 3, 1, [&](Value iv) { + auto extractedQubit = b.memrefLoad(memref0, iv); + b.x(extractedQubit); + auto memref1 = b.memrefAlloc(reg1); + b.scfFor(0, 3, 1, [&](Value iv2) { + auto q1 = b.memrefLoad(memref1, iv2); + b.cx(extractedQubit, q1); + }); + }); + b.swap(reg0[0], reg0[1]); + b.swap(reg0[2], reg0[3]); + }); + + const auto outputString = getOutputString(input); + const auto checkString = getOutputString(expectedOutput); + ASSERT_EQ(outputString, checkString); +} From a82c3d4f9682db7448bb87049f75c30f5efc49e3 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 23 Jan 2026 18:02:35 +0100 Subject: [PATCH 090/108] add docstrings to the builders --- .../Dialect/QC/Builder/QCProgramBuilder.h | 32 +++++++++++++ .../Dialect/QCO/Builder/QCOProgramBuilder.h | 45 +++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index edf20e3606..8060086371 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -872,8 +872,40 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { //===--------------------------------------------------------------------===// // MemRef operations //===--------------------------------------------------------------------===// + + /** + * @brief Allocates a memref register and insert the given values + * + * @param elements The stored elements + * @return The memref register + * + * @par Example: + * ```c++ + * builder.memrefAlloc(elements); + * ``` + * ```mlir + * %memref = memref.alloc() : memref<2x!qc.qubit> + * memref.store %q0, %memref[%c0] : memref<2x!qc.qubit> + * memref.store %q1, %memref[%c1] : memref<2x!qc.qubit> + * ``` + */ Value memrefAlloc(ValueRange elements); + /** + * @brief Loads a value from a memref register + * + * @param memref The memref register + * @param index The index where the value is extracted + * @return The extracted value + * + * @par Example: + * ```c++ + * builder.memrefLoad(memref, index); + * ``` + * ```mlir + * %q0 = memref.load %memref[%c0] : memref<2x!qc.qubit> + * ``` + */ Value memrefLoad(Value memref, const std::variant& index); //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 5de7fa4385..492e70d449 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1033,10 +1033,55 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { // Tensor operations //===--------------------------------------------------------------------===// + /** + * @brief Constructs a tensor.from_elements operation with the given values + * + * @param elements The elements of the tensor + * @return The resulting tensor + * + * @par Example: + * ```c++ + * builder.tensorFromElements(elements); + * ``` + * ```mlir + * %tensor = tensor.from_elements %q0, %q1, %q2 : tensor<3x!qco.qubit> + * ``` + */ Value tensorFromElements(ValueRange elements); + /** + * @brief Constructs a tensor.extract operation at the given index + * + * @param tensor The tensor where the value is extracted + * @param index The index where the value is extracted + * @return The extracted value + * + * @par Example: + * ```c++ + * q = builder.tensorExtract(tensor, iv); + * ``` + * ```mlir + * %q = tensor.extract %tensor[%iv] : tensor<3x!qco.qubit> + * ``` + */ Value tensorExtract(Value tensor, const std::variant& index); + /** + * @brief Constructs a tensor.insert operation at the given index + * + * @param element The inserted value + * @param tensor The tensor where the value is inserted + * @param index The index where the value is inserted + * @return The resulting tensor + * + * @par Example: + * ```c++ + * newTensor = builder.tensorInsert(element, tensor, iv); + * ``` + * ```mlir + * %newTensor = tensor.insert %q into %tensor[%iv] : tensor<3x!qco.qubit> + * ``` + */ Value tensorInsert(Value element, Value tensor, const std::variant& index); From ef771132f1514656468f8e50c926d601c9712105 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 23 Jan 2026 18:13:17 +0100 Subject: [PATCH 091/108] smaller fixes --- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 8 ++++++++ .../Dialect/QCO/Builder/QCOProgramBuilder.cpp | 18 ++++++++++++++---- .../QCToQCO/test_conversion_qc_to_qco.cpp | 8 +++++--- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index a9881646ed..275a271770 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -460,11 +460,17 @@ QCProgramBuilder& QCProgramBuilder::dealloc(Value qubit) { //===----------------------------------------------------------------------===// // MemRef operations //===----------------------------------------------------------------------===// + Value QCProgramBuilder::memrefAlloc(ValueRange elements) { + checkFinalized(); + const auto qcType = qc::QubitType::get(ctx); const auto memType = MemRefType::get({static_cast(elements.size())}, qcType); + // Create the alloc operation auto allocOp = memref::AllocOp::create(*this, memType); + + // Iterate through all elements and create a store operation for each qubit for (auto it : llvm::enumerate(elements)) { Value idx = arith::ConstantOp::create( *this, getIndexAttr(static_cast(it.index()))); @@ -475,6 +481,8 @@ Value QCProgramBuilder::memrefAlloc(ValueRange elements) { Value QCProgramBuilder::memrefLoad(Value memref, const std::variant& index) { + checkFinalized(); + const auto indexValue = utils::variantToValue(*this, getLoc(), index); const auto loadOp = memref::LoadOp::create(*this, memref, indexValue); return loadOp->getResult(0); diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 7f7cd09351..139d0a51b8 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -646,10 +646,11 @@ QCOProgramBuilder& QCOProgramBuilder::dealloc(Value qubit) { Value QCOProgramBuilder::tensorFromElements(ValueRange elements) { checkFinalized(); + auto const qcoType = qco::QubitType::get(ctx); const auto tensorType = RankedTensorType::get({static_cast(elements.size())}, qcoType); - // Create the FromElements operation + auto fromElements = tensor::FromElementsOp::create(*this, tensorType, elements); return fromElements.getResult(); @@ -659,15 +660,21 @@ Value QCOProgramBuilder::tensorExtract( Value tensor, const std::variant& index) { checkFinalized(); - auto const qcoType = qco::QubitType::get(ctx); + const auto qcoType = qco::QubitType::get(ctx); const auto indexValue = utils::variantToValue(*this, getLoc(), index); + // Create the extract operation auto extractOp = tensor::ExtractOp::create(*this, qcoType, tensor, indexValue); - auto* const extractParentRegion = extractOp->getParentRegion(); + + // Check if the tensor stems from a scf.for operation + // These are the extract operations directly following the scf.for operation if (auto scfFor = tensor.getDefiningOp()) { + // Find the initial qubit before it was inserted to the tensor for (auto arg : scfFor.getInitArgs()) { if (llvm::isa(arg.getType())) { auto fromTensorOp = arg.getDefiningOp(); + + // Get the index as integer int64_t val = 0; if (std::holds_alternative(index)) { val = std::get(index); @@ -676,13 +683,15 @@ Value QCOProgramBuilder::tensorExtract( std::get(index).getDefiningOp(); val = dyn_cast(constantOp.getValue()).getInt(); } + // Update the tracking of the qubit updateQubitTracking(fromTensorOp.getElements()[val], extractOp.getResult(), extractOp->getParentRegion()); } } } - if (!llvm::isa(extractParentRegion->getParentOp())) { + // Add the extracted Qubit to the qubit tracking if it is inside a loop + if (!llvm::isa(extractOp->getParentRegion()->getParentOp())) { validQubits[extractOp->getParentRegion()].insert(extractOp); } @@ -692,6 +701,7 @@ Value QCOProgramBuilder::tensorExtract( Value QCOProgramBuilder::tensorInsert( Value element, Value tensor, const std::variant& index) { checkFinalized(); + const auto indexValue = utils::variantToValue(*this, getLoc(), index); auto insertOp = tensor::InsertOp::create(*this, element, tensor, indexValue); return insertOp.getResult(); diff --git a/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp index a672512840..e64caf1a42 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp @@ -71,6 +71,7 @@ class ConversionTest : public ::testing::Test { runCanonicalizationPass(module.get()); return module; } + [[nodiscard]] OwningOpRef buildQCOIR( const std::function& buildFunc) const { qco::QCOProgramBuilder builder(context.get()); @@ -360,7 +361,7 @@ TEST_F(ConversionTest, ScfCtrlQCtoQCOTest) { ASSERT_EQ(outputString, checkString); } -TEST_F(ConversionTest, ScfCtrlQCtoQCOTest2) { +TEST_F(ConversionTest, ScfForRegisterQCtoQCOTest) { // Test conversion from qc to qco for scf.for operation with a memref register auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); @@ -403,8 +404,9 @@ TEST_F(ConversionTest, ScfCtrlQCtoQCOTest2) { ASSERT_EQ(outputString, checkString); } -TEST_F(ConversionTest, ScfCtrlQCtoQCOTest3) { - // Test conversion from qc to qco for scf.for operation with a memref register +TEST_F(ConversionTest, ScfForNestedRegisterQCtoQCOTest) { + // Test conversion from qc to qco for scf.for operation with a nested memref + // register auto input = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto reg0 = b.allocQubitRegister(4, "q0"); auto reg1 = b.allocQubitRegister(4, "q1"); From f24a768d92cd76cc60ee9e656171726c9a8ba71d Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Fri, 23 Jan 2026 18:25:22 +0100 Subject: [PATCH 092/108] only use arguments in the regionMap of func --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 420830bbd0..4ee4d0adc5 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -226,7 +226,14 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { if (!func.getArgumentTypes().empty() && isQubitType(func.getArgumentTypes().front())) { operation.setAttr("needChange", StringAttr::get(ctx, "yes")); - state->regionMap[func] = uniqueQubits; + // Only add the arguments as qubits for the regionMap of func + llvm::SetVector argQubits; + for (auto arg : func.getArguments()) { + if (isQubitType(arg.getType())) { + argQubits.insert(arg); + } + } + state->regionMap[func] = argQubits; } } } @@ -242,8 +249,6 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { if (!uniqueQubits.empty() && (llvm::isa(op) || (llvm::isa(op)) || llvm::isa(op))) { - if (llvm::isa(op)) { - } state->regionMap[op] = uniqueQubits; op->setAttr("needChange", StringAttr::get(ctx, "yes")); } From 46851b9ab3a440dc6a128b6e5690ee088a848342 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 14:24:46 +0100 Subject: [PATCH 093/108] remove print statement --- mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 139d0a51b8..88baca5ce3 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -139,7 +139,6 @@ void QCOProgramBuilder::validateQubitValue(Value qubit, Region* region) const { auto qubits = validQubits.lookup(region); if (qubits.empty() || !qubits.contains(qubit)) { - qubit.print(llvm::outs()); llvm::errs() << "Attempting to use an invalid qubit SSA value. " << "The value may have been consumed by a previous operation " << "or was never created through this builder.\n"; From d6c6b9495e9dff2d7f1fd101776f5424e69de9cd Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 14:45:50 +0100 Subject: [PATCH 094/108] apply coderabbit suggestion and fix smaller typos --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 11 +++++++---- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 4 +++- .../Conversion/QCOToQC/test_conversion_qco_to_qc.cpp | 8 +++++--- .../Conversion/QCToQCO/test_conversion_qc_to_qco.cpp | 6 ++++-- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 89b24e1a65..300792de93 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -14,8 +14,11 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include +#include #include #include +#include +#include #include #include #include @@ -880,7 +883,7 @@ struct ConvertQCOTensorFromElementsOp final // store each qubit into the memref for (auto it : llvm::enumerate(adaptor.getElements())) { - Value idx = rewriter.create(loc, it.index()); + const auto idx = rewriter.create(loc, it.index()); rewriter.create(loc, it.value(), memrefAllocOp, idx); } @@ -911,7 +914,7 @@ struct ConvertQCOTensorExtractOp final matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { // Remove the extract operations following a scf.for operation - if (!llvm::isa(op.getOperand(0).getType())) { + if (!llvm::isa(adaptor.getOperands().front().getType())) { // Find the memref register const auto memref = adaptor.getTensor(); const auto memrefUsers = llvm::to_vector(memref.getUsers()); @@ -921,8 +924,8 @@ struct ConvertQCOTensorExtractOp final adaptor.getIndices().front().getDefiningOp(); const auto indexToStore = dyn_cast(constantOp.getValue()).getInt(); - // Find the appropriate store operation depending on the index to get the - // qubit + // Find the appropriate store operation depending on the index to get + // the qubit for (auto* user : llvm::reverse(memrefUsers)) { if (llvm::isa(user)) { index++; diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 4ee4d0adc5..652fb3b654 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1351,6 +1351,7 @@ struct ConvertQCMemRefAllocOp final for (auto* user : llvm::reverse(users)) { if (llvm::isa(user)) { auto storeOp = dyn_cast(user); + assert(qubitMap.contains(storeOp.getValue()) && "QC qubit not found"); qcoQubits.push_back(qubitMap[storeOp.getValue()]); } } @@ -1412,6 +1413,7 @@ struct ConvertQCMemRefLoadOp final matchAndRewrite(memref::LoadOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; + assert(qubitMap.contains(op.getMemRef()) && "QC memref not found"); const auto tensor = qubitMap[op.getMemRef()]; auto const qcoType = qco::QubitType::get(rewriter.getContext()); // Create the extract operation @@ -1672,7 +1674,7 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { const auto qcQubitUsers = llvm::to_vector(qcQubit.getUsers()); for (const auto* user : llvm::reverse(qcQubitUsers)) { if (auto storeOp = dyn_cast(user)) { - // gGet the qubit + // Get the qubit const auto qubit = storeOp.getValueToStore(); // Create the extract operation for each qubit from the resulting diff --git a/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp index d521ea4a5d..f4fc1b92c5 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp @@ -345,7 +345,8 @@ TEST_F(ConversionTest, ScfForTensorQCOtoQCTest) { PassManager pm(context.get()); pm.addPass(createQCOToQC()); if (failed(pm.run(input.get()))) { - FAIL() << "Conversion error during QC-QCO conversion for scf nested"; + FAIL() + << "Conversion error during QCO-QC conversion for scf.for with tensor"; } auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); @@ -364,7 +365,7 @@ TEST_F(ConversionTest, ScfForTensorQCOtoQCTest) { } TEST_F(ConversionTest, ScfForNestedTensorQCOtoQCTest) { - // Test conversion from qco to qc for scf.for operation with a nested tesnor + // Test conversion from qco to qc for scf.for operation with a nested tensor auto input = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { auto reg0 = b.allocQubitRegister(4, "q0"); auto reg1 = b.allocQubitRegister(4, "q1"); @@ -404,7 +405,8 @@ TEST_F(ConversionTest, ScfForNestedTensorQCOtoQCTest) { PassManager pm(context.get()); pm.addPass(createQCOToQC()); if (failed(pm.run(input.get()))) { - FAIL() << "Conversion error during QC-QCO conversion for scf nested"; + FAIL() << "Conversion error during QCO-QC Conversion for scf.for with " + "nested tensor"; } // Run the canonicalizer again to remove the additional constants pm.clear(); diff --git a/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp index e64caf1a42..40d04d5d52 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp @@ -377,7 +377,8 @@ TEST_F(ConversionTest, ScfForRegisterQCtoQCOTest) { PassManager pm(context.get()); pm.addPass(createQCToQCO()); if (failed(pm.run(input.get()))) { - FAIL() << "Conversion error during QC-QCO conversion for scf nested"; + FAIL() << "Conversion error during QC-QCO conversion for scf.for with " + "memref register"; } auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { @@ -426,7 +427,8 @@ TEST_F(ConversionTest, ScfForNestedRegisterQCtoQCOTest) { PassManager pm(context.get()); pm.addPass(createQCToQCO()); if (failed(pm.run(input.get()))) { - FAIL() << "Conversion error during QC-QCO conversion for scf nested"; + FAIL() << "Conversion error during QC-QCO conversion for scf.for with " + "nested memref register"; } auto expectedOutput = buildQCOIR([](mlir::qco::QCOProgramBuilder& b) { From d093ee7aabda3292c45b6c880abea587f27789ad Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 14:49:28 +0100 Subject: [PATCH 095/108] fix more linter issues --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 1 + mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp | 2 +- mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp | 8 ++++++++ 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 652fb3b654..4922f6b09a 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 275a271770..7a02994073 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -472,7 +472,7 @@ Value QCProgramBuilder::memrefAlloc(ValueRange elements) { // Iterate through all elements and create a store operation for each qubit for (auto it : llvm::enumerate(elements)) { - Value idx = arith::ConstantOp::create( + const Value idx = arith::ConstantOp::create( *this, getIndexAttr(static_cast(it.index()))); memref::StoreOp::create(*this, it.value(), allocOp, idx); } diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 88baca5ce3..3ea54653e3 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #include #include #include @@ -673,6 +674,10 @@ Value QCOProgramBuilder::tensorExtract( if (llvm::isa(arg.getType())) { auto fromTensorOp = arg.getDefiningOp(); + if (!fromTensorOp) { + continue; + } + // Get the index as integer int64_t val = 0; if (std::holds_alternative(index)) { @@ -680,6 +685,9 @@ Value QCOProgramBuilder::tensorExtract( } else { auto constantOp = std::get(index).getDefiningOp(); + if (!constantOp) { + continue; + } val = dyn_cast(constantOp.getValue()).getInt(); } // Update the tracking of the qubit From aa6a8909cc6e795b12eb76de316cce0f2e71f570 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 15:43:44 +0100 Subject: [PATCH 096/108] refactor type checks --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 87 +++++++++++++++---------- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 44 +++++++------ 2 files changed, 76 insertions(+), 55 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 300792de93..a4aa6138e1 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -288,6 +288,36 @@ class QCOToQCTypeConverter final : public TypeConverter { } }; +/** + * @brief Helper function to check whether the type is a qco qubit type or a + * container that holds qco qubit types + * + * @param type The type that is checked + * @return Whether it is a qco type or not + */ +static bool isQCOQubitType(Type type) { + if (llvm::isa(type)) { + return true; + } + auto tensor = dyn_cast(type); + return tensor && llvm::isa(tensor.getElementType()); +} + +/** + * @brief Helper function to check whether the type is a qc qubit type or a + * container that holds qc qubit types + * + * @param type The type that is checked + * @return Whether it is a qc type or not + */ +static bool isQCQubitType(Type type) { + if (llvm::isa(type)) { + return true; + } + auto memref = dyn_cast(type); + return memref && llvm::isa(memref.getElementType()); +} + /** * @brief Converts qco.alloc to qc.alloc * @@ -883,7 +913,8 @@ struct ConvertQCOTensorFromElementsOp final // store each qubit into the memref for (auto it : llvm::enumerate(adaptor.getElements())) { - const auto idx = rewriter.create(loc, it.index()); + const Value idx = + rewriter.create(loc, it.index()); rewriter.create(loc, it.value(), memrefAllocOp, idx); } @@ -1320,66 +1351,52 @@ struct QCOToQC final : impl::QCOToQCBase { target.addDynamicallyLegalOp( [&](tensor::FromElementsOp op) { - return !llvm::any_of(op.getOperandTypes(), - [&](Type type) { - return type == qco::QubitType::get(context); - }) && - !(op.getType().getElementType() == - qco::QubitType::get(context)); + return !llvm::any_of(op.getOperandTypes(), [&](Type type) { + return isQCOQubitType(type); + }) && !(isQCOQubitType(op.getType())); }); target.addDynamicallyLegalOp([&](tensor::ExtractOp op) { - return !llvm::any_of(op->getResultTypes(), [&](Type type) { - return type == qco::QubitType::get(context); - }); + return !llvm::any_of(op->getResultTypes(), + [&](Type type) { return isQCOQubitType(type); }); }); target.addDynamicallyLegalOp([&](tensor::InsertOp op) { - return !llvm::any_of(op.getOperandTypes(), [&](Type type) { - return type == qco::QubitType::get(context); - }); + return !llvm::any_of(op.getOperandTypes(), + [&](Type type) { return isQCOQubitType(type); }); }); target.addDynamicallyLegalOp([&](scf::IfOp op) { - return !llvm::any_of(op->getResultTypes(), [&](Type type) { - return type == qco::QubitType::get(context); - }); + return !llvm::any_of(op->getResultTypes(), + [&](Type type) { return isQCOQubitType(type); }); }); target.addDynamicallyLegalOp([&](scf::YieldOp op) { return !llvm::any_of(op->getOperandTypes(), [&](Type type) { - return type == qco::QubitType::get(context) || - type == qc::QubitType::get(context); + return isQCOQubitType(type) || isQCQubitType(type); }); }); target.addDynamicallyLegalOp([&](scf::WhileOp op) { - return !llvm::any_of(op->getResultTypes(), [&](Type type) { - return type == qco::QubitType::get(context); - }); + return !llvm::any_of(op->getResultTypes(), + [&](Type type) { return isQCOQubitType(type); }); }); target.addDynamicallyLegalOp([&](scf::ConditionOp op) { return !llvm::any_of(op.getOperandTypes(), [&](Type type) { - return type == qco::QubitType::get(context) || - type == qc::QubitType::get(context); + return isQCOQubitType(type) || isQCQubitType(type); }); }); target.addDynamicallyLegalOp([&](scf::ForOp op) { - return !llvm::any_of(op->getResultTypes(), [&](Type type) { - return type == qco::QubitType::get(context) || - llvm::isa(type); - }); + return !llvm::any_of(op->getResultTypes(), + [&](Type type) { return isQCOQubitType(type); }); }); target.addDynamicallyLegalOp([&](func::CallOp op) { - return !llvm::any_of(op->getResultTypes(), [&](Type type) { - return type == qco::QubitType::get(context); - }); + return !llvm::any_of(op->getResultTypes(), + [&](Type type) { return isQCOQubitType(type); }); }); target.addDynamicallyLegalOp([&](func::FuncOp op) { - return !llvm::any_of(op.getArgumentTypes(), [&](Type type) { - return type == qco::QubitType::get(context); - }); + return !llvm::any_of(op.getArgumentTypes(), + [&](Type type) { return isQCOQubitType(type); }); }); target.addDynamicallyLegalOp([&](func::ReturnOp op) { return !llvm::any_of(op->getOperandTypes(), [&](Type type) { - return type == qco::QubitType::get(context) || - type == qc::QubitType::get(context); + return isQCOQubitType(type) || isQCQubitType(type); }); }); diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 4922f6b09a..cb51e608d6 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -121,15 +121,19 @@ class StatefulOpConversionPattern : public OpConversionPattern { } // namespace -static bool isQubitType(Type type) { - if (!llvm::isa(type)) { - auto memrefType = dyn_cast(type); - if (memrefType) { - return llvm::isa(memrefType.getElementType()); - } - return false; +/** + * @brief Helper function to check whether the type is a qc qubit type or a + * container that holds qc qubit types + * + * @param type The type that is checked + * @return Whether it is a qc type or not + */ +static bool isQCQubitType(Type type) { + if (llvm::isa(type)) { + return true; } - return true; + auto memref = dyn_cast(type); + return memref && llvm::isa(memref.getElementType()); } /** @@ -156,7 +160,7 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { // Collect qubits from the blockarguments for (auto arg : region.front().getArguments()) { - if (isQubitType(arg.getType())) { + if (isQCQubitType(arg.getType())) { uniqueQubits.insert(arg); } } @@ -203,13 +207,13 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { llvm::isa(op)) { continue; } - if (isQubitType(operand.getType())) { + if (isQCQubitType(operand.getType())) { uniqueQubits.insert(operand); } } // Collect qubits from the results for (const auto& result : operation.getResults()) { - if (isQubitType(result.getType())) { + if (isQCQubitType(result.getType())) { uniqueQubits.insert(result); } } @@ -225,12 +229,12 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { if (llvm::isa(operation)) { if (auto func = operation.getParentOfType()) { if (!func.getArgumentTypes().empty() && - isQubitType(func.getArgumentTypes().front())) { + isQCQubitType(func.getArgumentTypes().front())) { operation.setAttr("needChange", StringAttr::get(ctx, "yes")); // Only add the arguments as qubits for the regionMap of func llvm::SetVector argQubits; for (auto arg : func.getArguments()) { - if (isQubitType(arg.getType())) { + if (isQCQubitType(arg.getType())) { argQubits.insert(arg); } } @@ -242,7 +246,7 @@ collectUniqueQubits(Operation* op, LoweringState* state, MLIRContext* ctx) { } // Add the operands from the operation itself for (const auto& operand : op->getOperands()) { - if (isQubitType(operand.getType())) { + if (isQCQubitType(operand.getType())) { uniqueQubits.insert(operand); } } @@ -1721,7 +1725,7 @@ struct ConvertQCScfYieldOp final : StatefulOpConversionPattern { matchAndRewrite(scf::YieldOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { assert(llvm::all_of(op.getOperandTypes(), - [&](Type type) { return isQubitType(type); }) && + [&](Type type) { return isQCQubitType(type); }) && "Not all operands are qc qubits"); const auto& parentRegion = op->getParentRegion(); @@ -2007,26 +2011,26 @@ struct QCToQCO final : impl::QCToQCOBase { }); target.addDynamicallyLegalOp([&](func::FuncOp op) { return !llvm::any_of(op.front().getArgumentTypes(), - [&](Type type) { return isQubitType(type); }); + [&](Type type) { return isQCQubitType(type); }); }); target.addDynamicallyLegalOp([&](func::CallOp op) { return !llvm::any_of(op->getOperandTypes(), - [&](Type type) { return isQubitType(type); }); + [&](Type type) { return isQCQubitType(type); }); }); target.addDynamicallyLegalOp([&](func::ReturnOp op) { return !op->getAttrOfType("needChange"); }); target.addDynamicallyLegalOp([&](memref::AllocOp op) { return !llvm::any_of(op->getResultTypes(), - [&](Type type) { return isQubitType(type); }); + [&](Type type) { return isQCQubitType(type); }); }); target.addDynamicallyLegalOp([&](memref::StoreOp op) { return !llvm::any_of(op.getOperandTypes(), - [&](Type type) { return isQubitType(type); }); + [&](Type type) { return isQCQubitType(type); }); }); target.addDynamicallyLegalOp([&](memref::LoadOp op) { return !llvm::any_of(op->getResultTypes(), - [&](Type type) { return isQubitType(type); }); + [&](Type type) { return isQCQubitType(type); }); }); // Register operation conversion patterns with state From a6c34245af04bd2cd6f0593d0f86fa83eede7023 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 15:47:25 +0100 Subject: [PATCH 097/108] trying to fix extractOp --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index a4aa6138e1..c98f0f3e8e 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -944,12 +944,14 @@ struct ConvertQCOTensorExtractOp final LogicalResult matchAndRewrite(tensor::ExtractOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + const auto memref = adaptor.getTensor(); // Remove the extract operations following a scf.for operation - if (!llvm::isa(adaptor.getOperands().front().getType())) { - // Find the memref register - const auto memref = adaptor.getTensor(); + // if the region of the converted memref is the same as the extract + // operation + if (memref.getDefiningOp()->getParentRegion() == op->getParentRegion()) { + // get the users of the memref register const auto memrefUsers = llvm::to_vector(memref.getUsers()); - // Get the index where the value was extracted + // Get the index where the value was extracted int64_t index = -1; auto constantOp = adaptor.getIndices().front().getDefiningOp(); From 415b3999d1da727226d0b3390ccbeb67e2ebb8e1 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 16:04:01 +0100 Subject: [PATCH 098/108] fix typeConverter --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index c98f0f3e8e..796ccd9419 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -279,7 +279,7 @@ class QCOToQCTypeConverter final : public TypeConverter { return qc::QubitType::get(ctx); }); - addConversion([&](RankedTensorType t) -> Type { + addConversion([ctx](RankedTensorType t) -> Type { if (t.getElementType() == qco::QubitType::get(ctx)) { return MemRefType::get(t.getShape(), qc::QubitType::get(ctx)); } From 41e3c58ef6dd58bb7d9d7746371c65e3116a4641 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 16:24:10 +0100 Subject: [PATCH 099/108] fix conversion test for scf operations --- .../QCOToQC/test_conversion_qco_to_qc.cpp | 26 ++++++++++++++----- .../QCToQCO/test_conversion_qc_to_qco.cpp | 11 +++++--- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp index f4fc1b92c5..75b1c24fc7 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -33,6 +34,7 @@ #include #include #include +#include #include #include @@ -86,22 +88,25 @@ static std::string getOutputString(mlir::OwningOpRef& module) { llvm::raw_string_ostream os(outputString); auto* moduleOp = module->getOperation(); - const auto* qcoDialect = - moduleOp->getContext()->getLoadedDialect(); + const auto* qcDialect = + moduleOp->getContext()->getLoadedDialect(); const auto* scfDialect = moduleOp->getContext()->getLoadedDialect(); + const auto* memrefDialect = + moduleOp->getContext()->getLoadedDialect(); moduleOp->walk([&](Operation* op) -> WalkResult { const auto* opDialect = op->getDialect(); // Ignore dealloc operations as the order does not matter - if (llvm::isa(op)) { + if (llvm::isa(op)) { return WalkResult::advance(); } - // Only consider operations from the qco dialect and the scf dialect or - // func.call or func.return op - if (opDialect == qcoDialect || opDialect == scfDialect || - llvm::isa(op) || llvm::isa(op)) { + // Only consider operations from the qc dialect, scf dialect or memref + // dialect or func.call or func.return op + if (opDialect == qcDialect || opDialect == scfDialect || + opDialect == memrefDialect || llvm::isa(op) || + llvm::isa(op)) { op->print(os); } return WalkResult::advance(); @@ -348,6 +353,13 @@ TEST_F(ConversionTest, ScfForTensorQCOtoQCTest) { FAIL() << "Conversion error during QCO-QC conversion for scf.for with tensor"; } + // Run the canonicalizer again to remove the additional constants + pm.clear(); + pm.addPass(createCanonicalizerPass()); + if (failed(pm.run(input.get()))) { + FAIL() << "Error during canonicalization"; + } + auto expectedOutput = buildQCIR([](mlir::qc::QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); auto memref = b.memrefAlloc(reg); diff --git a/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp index 40d04d5d52..2f6bf5fbb1 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -92,7 +93,8 @@ class ConversionTest : public ::testing::Test { moduleOp->getContext()->getLoadedDialect(); const auto* scfDialect = moduleOp->getContext()->getLoadedDialect(); - + const auto* tensorDialect = + moduleOp->getContext()->getLoadedDialect(); moduleOp->walk([&](Operation* op) -> WalkResult { const auto* opDialect = op->getDialect(); @@ -100,10 +102,11 @@ class ConversionTest : public ::testing::Test { if (llvm::isa(op)) { return WalkResult::advance(); } - // Only consider operations from the qco dialect and the scf dialect or - // func.call or func.return op + // Only consider operations from the qco dialect, scf dialect, + // tensor dialect or func.call op or func.return op if (opDialect == qcoDialect || opDialect == scfDialect || - llvm::isa(op) || llvm::isa(op)) { + opDialect == tensorDialect || llvm::isa(op) || + llvm::isa(op)) { op->print(os); } return WalkResult::advance(); From 2f3731accaab7282374ffda02a999a299e30c467 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 16:34:47 +0100 Subject: [PATCH 100/108] fix typo --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index cb51e608d6..a296ab9ce3 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1338,7 +1338,7 @@ struct ConvertQCYieldOp final : StatefulOpConversionPattern { * ``` * is converted to * ```mlir - * %tensor = tensor.from_elements %q0, %q1, %q2 : tensore<3x!qco.qubit> + * %tensor = tensor.from_elements %q0, %q1, %q2 : tensor<3x!qco.qubit> * ``` */ struct ConvertQCMemRefAllocOp final From b9d3e035ff23f68705eafcd5bf36b1ae92ab7a67 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 16:35:08 +0100 Subject: [PATCH 101/108] apply coderabbit suggestion for dealloc --- mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp index 3ea54653e3..5eda5cec1b 100644 --- a/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp +++ b/mlir/lib/Dialect/QCO/Builder/QCOProgramBuilder.cpp @@ -632,8 +632,9 @@ std::pair QCOProgramBuilder::ctrl( QCOProgramBuilder& QCOProgramBuilder::dealloc(Value qubit) { checkFinalized(); - validateQubitValue(qubit, qubit.getParentRegion()); - validQubits[qubit.getParentRegion()].erase(qubit); + auto* region = getInsertionBlock()->getParent(); + validateQubitValue(qubit, region); + validQubits[region].erase(qubit); DeallocOp::create(*this, qubit); From 7c4fe402ca823ce32372acd2bbc2c1c70add9019 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 16:38:45 +0100 Subject: [PATCH 102/108] add failure when canonicalization fails --- mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp | 2 +- mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp index 75b1c24fc7..eae4dbbcd1 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_conversion_qco_to_qc.cpp @@ -58,7 +58,7 @@ class ConversionTest : public ::testing::Test { PassManager pm(module.getContext()); pm.addPass(createCanonicalizerPass()); if (pm.run(module).failed()) { - llvm::errs() << "Failed to run canonicalization passes.\n"; + FAIL() << "Error during canonicalization"; } } diff --git a/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp index 2f6bf5fbb1..d4238336a4 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_conversion_qc_to_qco.cpp @@ -59,7 +59,7 @@ class ConversionTest : public ::testing::Test { PassManager pm(module.getContext()); pm.addPass(createCanonicalizerPass()); if (pm.run(module).failed()) { - llvm::errs() << "Failed to run canonicalization passes.\n"; + FAIL() << "Error during canonicalization"; } } From e712c98b696ec09a9bfaba03a14d82a7fe736de4 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 18:15:43 +0100 Subject: [PATCH 103/108] refactor memref qubit tracking --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 80 +++++++++++++++---------- 1 file changed, 48 insertions(+), 32 deletions(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index a296ab9ce3..16140c90f7 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -82,6 +82,9 @@ struct LoweringState { llvm::DenseMap> qubitMap; /// Map each operation to its Set of QC qubit references llvm::DenseMap> regionMap; + /// Map each memref operation to its stored qubits and their index + llvm::DenseMap>> + memrefMap; /// Modifier information int64_t inCtrlOp = 0; @@ -1349,20 +1352,37 @@ struct ConvertQCMemRefAllocOp final matchAndRewrite(memref::AllocOp op, OpAdaptor /*adaptor*/, ConversionPatternRewriter& rewriter) const override { auto& qubitMap = getState().qubitMap[op->getParentRegion()]; - - // Get the qco qubits from the users + auto& memrefMap = getState().memrefMap[op]; SmallVector qcoQubits; - const auto users = llvm::to_vector(op->getUsers()); - for (auto* user : llvm::reverse(users)) { - if (llvm::isa(user)) { - auto storeOp = dyn_cast(user); + SmallVector> indexedQubits; + // Get all qubits that are stored in the memref register and find the qco + // qubits + for (const auto* user : op->getUsers()) { + if (auto storeOp = dyn_cast(user)) { + auto storeIndex = storeOp.getIndices() + .front() + .getDefiningOp(); + assert(storeIndex && "Expected constant index for memref index"); assert(qubitMap.contains(storeOp.getValue()) && "QC qubit not found"); - qcoQubits.push_back(qubitMap[storeOp.getValue()]); + + indexedQubits.emplace_back(storeOp.getValue(), storeIndex); } } + // Sort the list of users + llvm::sort(indexedQubits, [](auto& a, auto& b) { + return a.second.value() < b.second.value(); + }); + + // Get the qco qubits and add values to the memref map + qcoQubits.reserve(indexedQubits.size()); + for (auto& [qubit, index] : indexedQubits) { + memrefMap.emplace_back(qubit, index.getResult()); + qcoQubits.push_back(qubitMap[qubit]); + } + auto const qcoType = qco::QubitType::get(rewriter.getContext()); - const auto tensorType = RankedTensorType::get( - {static_cast(qcoQubits.size())}, qcoType); + const auto tensorType = + RankedTensorType::get(op.getType().getShape(), qcoType); // Create the FromElements operation auto fromElements = tensor::FromElementsOp::create(rewriter, op->getLoc(), tensorType, qcoQubits); @@ -1672,24 +1692,20 @@ struct ConvertQCScfForOp final : StatefulOpConversionPattern { regionQubitMap[qcQubit] = iterArg; qubitMap[qcQubit] = qcoQubit; - // If the value of the qc qubit is a memref register, extract each value - // from the new tensor and update the qubitmap for each value + // If the value of the qc qubit is a memref register, create an extract + // operation for each qubit afterwards and update the qubitmap if (llvm::isa(qcQubit.getType())) { - // Get all the qubits that were stored in the memref register - const auto qcQubitUsers = llvm::to_vector(qcQubit.getUsers()); - for (const auto* user : llvm::reverse(qcQubitUsers)) { - if (auto storeOp = dyn_cast(user)) { - // Get the qubit - const auto qubit = storeOp.getValueToStore(); - - // Create the extract operation for each qubit from the resulting - // tensor of the scf.for operation - auto extractOp = - tensor::ExtractOp::create(rewriter, op->getLoc(), qcoType, - qcoQubit, {storeOp.getIndices()}); - // Update the qubit map for each of them - qubitMap[qubit] = extractOp.getResult(); - } + // Get the memref map + auto& memrefMap = getState().memrefMap[qcQubit.getDefiningOp()]; + + // Iterate over all entries + for (const auto& [memrefQubit, index] : memrefMap) { + // Create the extract operation for each qubit from the resulting + // tensor of the scf.for operation + auto extractOp = tensor::ExtractOp::create(rewriter, op->getLoc(), + qcoType, qcoQubit, index); + // Update the qubit map for each of them + qubitMap[memrefQubit] = extractOp.getResult(); } } } @@ -1735,21 +1751,21 @@ struct ConvertQCScfYieldOp final : StatefulOpConversionPattern { SmallVector qcoQubits; qcoQubits.reserve(orderedQubits.size()); - // get the latest qco qubit or the latest qco tensor from the qubitMap + // Get the latest qco qubit or the latest qco tensor from the qubitMap for (const auto& qcQubit : orderedQubits) { assert(qubitMap.contains(qcQubit) && "QC qubit not found"); - // add an insert operation for every qubit that was extract from a + // Add an insert operation for every qubit that was extract from a // register - if (dyn_cast(qcQubit.getType())) { - // find all extracted values of the register + if (llvm::isa(qcQubit.getType())) { + // Find all extracted values of the register for (const auto* user : qcQubit.getUsers()) { if (auto loadOp = dyn_cast(user)) { - // get the latest qco qubit and add it back to the tensor + // Get the latest qco qubit and add it back to the tensor auto qubit = loadOp.getResult(); assert(qubitMap.contains(qubit) && "QC qubit not found"); - auto latestQcoQubit = qubitMap.lookup(qubit); + auto latestQcoQubit = qubitMap[qubit]; auto insertOp = tensor::InsertOp::create( rewriter, op.getLoc(), latestQcoQubit, qubitMap[qcQubit], loadOp.getIndices()); From 6690227d8a6e7415c42397d7a7deeb7ce2dadd0d Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 18:16:12 +0100 Subject: [PATCH 104/108] improve extract op conversion --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 36 ++++++++++++------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 796ccd9419..1d7b0115f9 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -14,10 +14,8 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include -#include #include #include -#include #include #include #include @@ -888,7 +886,7 @@ struct ConvertQCOYieldOp final : OpConversionPattern { * * @par Example: * ```mlir - * %tensor = tensor.from_elements %q0, %q1, %q2 : tensore<3x!qco.qubit> + * %tensor = tensor.from_elements %q0, %q1, %q2 : tensor<3x!qco.qubit> * ``` * is converted to * ```mlir @@ -934,7 +932,7 @@ struct ConvertQCOTensorFromElementsOp final * ``` * is converted to * ```mlir - * %q0 = memref.load %memref[%c0] : memref<3x!qco.qubit> + * %q0 = memref.load %memref[%c0] : memref<3x!qc.qubit> * ``` */ struct ConvertQCOTensorExtractOp final @@ -949,22 +947,24 @@ struct ConvertQCOTensorExtractOp final // if the region of the converted memref is the same as the extract // operation if (memref.getDefiningOp()->getParentRegion() == op->getParentRegion()) { - // get the users of the memref register - const auto memrefUsers = llvm::to_vector(memref.getUsers()); // Get the index where the value was extracted - int64_t index = -1; - auto constantOp = - adaptor.getIndices().front().getDefiningOp(); - const auto indexToStore = - dyn_cast(constantOp.getValue()).getInt(); - // Find the appropriate store operation depending on the index to get - // the qubit - for (auto* user : llvm::reverse(memrefUsers)) { - if (llvm::isa(user)) { - index++; - if (index == indexToStore) { - auto storeOp = dyn_cast(user); + auto extractIndex = + adaptor.getIndices().front().getDefiningOp(); + assert(extractIndex && "Expected constant index for tensor index"); + const auto indexToStore = extractIndex.value(); + + // Find the store operation with the same index + for (const auto* user : memref.getUsers()) { + if (auto storeOp = dyn_cast(user)) { + auto memrefIndex = storeOp.getIndices() + .front() + .getDefiningOp(); + assert(memrefIndex && "Expected constant index for memref index"); + + // Replace the extract op with the qubit value + if (indexToStore == memrefIndex.value()) { rewriter.replaceOp(op, storeOp.getValueToStore()); + break; } } } From 35a41ec6cc8be997f6305753c743eea4ea681336 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 18:24:52 +0100 Subject: [PATCH 105/108] update comment --- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 16140c90f7..26d1c1b925 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1368,7 +1368,7 @@ struct ConvertQCMemRefAllocOp final indexedQubits.emplace_back(storeOp.getValue(), storeIndex); } } - // Sort the list of users + // Sort the list of qubits depending on their index llvm::sort(indexedQubits, [](auto& a, auto& b) { return a.second.value() < b.second.value(); }); From a62b91c65dc5cf0021f71aa26ff02058563a7464 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Sat, 24 Jan 2026 18:44:22 +0100 Subject: [PATCH 106/108] add missing header --- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 1d7b0115f9..bf4a7c622a 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" +#include #include #include #include From 1ad3e77100516f3941ac159d3f3b3701b5743817 Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Wed, 28 Jan 2026 18:51:55 +0100 Subject: [PATCH 107/108] apply feedback from codereview --- .../Dialect/QC/Builder/QCProgramBuilder.h | 16 +++++++-------- .../Dialect/QCO/Builder/QCOProgramBuilder.h | 20 +++++++++---------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index 8060086371..871dd59703 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -874,7 +874,7 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { //===--------------------------------------------------------------------===// /** - * @brief Allocates a memref register and insert the given values + * @brief Allocate a memref register and insert the given values * * @param elements The stored elements * @return The memref register @@ -892,7 +892,7 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { Value memrefAlloc(ValueRange elements); /** - * @brief Loads a value from a memref register + * @brief Load a value from a memref register * * @param memref The memref register * @param index The index where the value is extracted @@ -913,7 +913,7 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { //===--------------------------------------------------------------------===// /** - * @brief Constructs a scf.for operation without iter args + * @brief Construct a scf.for operation without iter args * * @param lowerbound Lowerbound of the loop * @param upperbound Upperbound of the loop @@ -937,7 +937,7 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { const std::function& body); /** - * @brief Constructs a scf.while operation without return values + * @brief Construct a scf.while operation without return values * * @param beforeBody Function that builds the before body of the while * operation @@ -969,7 +969,7 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { const std::function& afterBody); /** - * @brief Constructs a scf.if operation without return values + * @brief Construct a scf.if operation without return values * * @param condition Condition for the if operation * @param thenBody Function that builds the then body of the if @@ -999,7 +999,7 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { std::optional> elseBody = std::nullopt); /** - * @brief Constructs a scf.condition operation without any additional Values + * @brief Construct a scf.condition operation without yielded values * * @param condition Condition for condition operation * @return Reference to this builder for method chaining @@ -1019,7 +1019,7 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { //===--------------------------------------------------------------------===// /** - * @brief Constructs a func.call operation without return values + * @brief Construct a func.call operation without return values * * @param name Name of the function that is called * @param operands ValueRange of the used operands @@ -1035,7 +1035,7 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { QCProgramBuilder& funcCall(StringRef name, ValueRange operands); /** - * @brief Constructs a func.func operation without return values + * @brief Construct a func.func operation without return values * * @param name Name of the function that is called * @param argTypes TypeRange of the arguments diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 492e70d449..049d7b908f 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1034,7 +1034,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { //===--------------------------------------------------------------------===// /** - * @brief Constructs a tensor.from_elements operation with the given values + * @brief Construct a tensor.from_elements operation with the given values * * @param elements The elements of the tensor * @return The resulting tensor @@ -1050,7 +1050,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { Value tensorFromElements(ValueRange elements); /** - * @brief Constructs a tensor.extract operation at the given index + * @brief Construct a tensor.extract operation at the given index * * @param tensor The tensor where the value is extracted * @param index The index where the value is extracted @@ -1067,7 +1067,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { Value tensorExtract(Value tensor, const std::variant& index); /** - * @brief Constructs a tensor.insert operation at the given index + * @brief Construct a tensor.insert operation at the given index * * @param element The inserted value * @param tensor The tensor where the value is inserted @@ -1090,12 +1090,12 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { //===--------------------------------------------------------------------===// /** - * @brief Constructs a scf.for operation with iterArgs + * @brief Construct a scf.for operation with iter args * * @param lowerbound Lowerbound of the loop * @param upperbound Upperbound of the loop * @param step Stepsize of the loop - * @param initArgs Initial arguments for the iterArgs + * @param initArgs Initial arguments for the iter args * @param body Function that builds the body of the for operation * @return ValueRange of the results * @@ -1120,7 +1120,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { const std::variant& step, ValueRange initArgs, llvm::function_ref(Value, ValueRange)> body); /** - * @brief Constructs a scf.while operation with return values + * @brief Construct a scf.while operation with return values * * @param args Arguments for the while loop * @param beforeBody Function that builds the before body of the while @@ -1158,7 +1158,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { llvm::function_ref(ValueRange)> afterBody); /** - * @brief Constructs a scf.if operation with return values + * @brief Construct a scf.if operation with return values * * @param condition Condition for the if operation * @param qubits Qubits used in the if/else body @@ -1193,7 +1193,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { llvm::function_ref()> elseBody); /** - * @brief Constructs a scf.condition operation with yielded values + * @brief Construct a scf.condition operation with yielded values * * @param condition Condition for condition operation * @param yieldedValues ValueRange of the yieldedValues @@ -1214,7 +1214,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { //===--------------------------------------------------------------------===// /** - * @brief Constructs a func.call operation with return values + * @brief Construct a func.call operation with return values * * @param name Name of the function that is called * @param operands ValueRange of the used operands @@ -1231,7 +1231,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { ValueRange funcCall(StringRef name, ValueRange operands); /** - * @brief Constructs a func.func operation with return values + * @brief Construct a func.func operation with return values * * @param name Name of the function that is called * @param argTypes TypeRange of the arguments From 254edb44430d8e2315a6fc33a07aeea3b8d7e7dc Mon Sep 17 00:00:00 2001 From: Li-ming Bao Date: Thu, 29 Jan 2026 11:15:27 +0100 Subject: [PATCH 108/108] add mssing parameter in docstring --- mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h index 049d7b908f..765ab2bd4d 100644 --- a/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QCO/Builder/QCOProgramBuilder.h @@ -1287,6 +1287,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { /** * @brief Validate that a qubit value is valid and unconsumed * @param qubit Qubit value to validate + * @param region Region that owns the qubit SSA value * @throws Aborts if qubit is not tracked (consumed or never created) */ void validateQubitValue(Value qubit, Region* region) const; @@ -1295,7 +1296,7 @@ class QCOProgramBuilder final : public ImplicitLocOpBuilder { * @brief Update tracking when an operation consumes and produces a qubit * @param inputQubit Input qubit being consumed (must be valid) * @param outputQubit New output qubit being produced - * @param region The Region in where the qubits are defined. + * @param region Region where the qubits are defined */ void updateQubitTracking(Value inputQubit, Value outputQubit, Region* region);