From 9682f414529d379b2dc13a45be8522dcfc39096e Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 26 May 2026 18:37:39 +0200 Subject: [PATCH 01/17] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20Relax=20condition=20?= =?UTF-8?q?on=20modifiers?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../Dialect/QC/Builder/QCProgramBuilder.h | 6 +- mlir/include/mlir/Dialect/QC/IR/QCOps.td | 77 ++-- mlir/include/mlir/Dialect/QCO/IR/QCOOps.td | 22 +- mlir/include/mlir/Dialect/Utils/Utils.h | 107 +++++ mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp | 55 ++- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 66 ++-- mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 61 ++- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 6 +- .../Dialect/QC/Builder/QCProgramBuilder.cpp | 42 +- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 164 +++++--- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 331 ++++++++++------ mlir/lib/Dialect/QC/IR/QCOps.cpp | 12 + .../TranslateQuantumComputationToQC.cpp | 10 +- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 172 ++++---- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 369 +++++++++++------- mlir/lib/Dialect/QCO/IR/QCOOps.cpp | 80 +--- .../Optimizations/HadamardLifting.cpp | 3 +- mlir/unittests/programs/qc_programs.cpp | 347 +++++++++++----- 18 files changed, 1210 insertions(+), 720 deletions(-) diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index ab9532cfb4..c2483d9f06 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -917,7 +917,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * } : !qc.qubit * ``` */ - QCProgramBuilder& ctrl(ValueRange controls, const function_ref& body); + QCProgramBuilder& ctrl(ValueRange controls, ValueRange targets, + const function_ref& body); /** * @brief Apply an inverse (i.e., adjoint) operation. @@ -936,7 +937,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * } * ``` */ - QCProgramBuilder& inv(const function_ref& body); + QCProgramBuilder& inv(ValueRange qubits, + const function_ref& body); //===--------------------------------------------------------------------===// // Deallocation diff --git a/mlir/include/mlir/Dialect/QC/IR/QCOps.td b/mlir/include/mlir/Dialect/QC/IR/QCOps.td index 8e76c9c3ba..cce7265131 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCOps.td +++ b/mlir/include/mlir/Dialect/QC/IR/QCOps.td @@ -916,7 +916,7 @@ def YieldOp : QCOp<"yield", traits = [Terminator]> { def CtrlOp : QCOp<"ctrl", - traits = [UnitaryOpInterface, + traits = [UnitaryOpInterface, AttrSizedOperandSegments, SingleBlockImplicitTerminator<"::mlir::qc::YieldOp">, RecursiveMemoryEffects]> { let summary = "Add control qubits to a unitary operation"; @@ -937,30 +937,36 @@ def CtrlOp ``` }]; - let arguments = - (ins Arg, - "the control qubits", [MemRead, MemWrite]>:$controls); + let arguments = (ins Arg, + "the control qubits", [MemRead, MemWrite]>:$controls, + Arg, + "the target qubits", [MemRead, MemWrite]>:$targets); let regions = (region SizedRegion<1>:$region); - let assemblyFormat = - "`(` $controls `)` $region attr-dict `:` type($controls)"; + let assemblyFormat = [{ + `(` $controls `)` + `targets` + custom($region, $targets) + attr-dict `:` + `{` type($controls) `}` ( `,` `{` type($targets)^ `}` )? + }]; let extraClassDeclaration = [{ - [[nodiscard]] UnitaryOpInterface getBodyUnitary(); + size_t getNumBodyUnitaries(); + [[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i); size_t getNumQubits() { return getNumTargets() + getNumControls(); } - size_t getNumTargets() { return getBodyUnitary().getNumTargets(); } + size_t getNumTargets() { return getTargets().size(); } size_t getNumControls() { return getControls().size(); } Value getQubit(size_t i); - Value getTarget(size_t i) { return getBodyUnitary().getTarget(i); } - ValueRange getTargets() { return getBodyUnitary().getTargets(); } + Value getTarget(size_t i) { return getTargets()[i]; } Value getControl(size_t i); - size_t getNumParams() { return getBodyUnitary().getNumParams(); } - Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); } - ValueRange getParameters() { return getBodyUnitary().getParameters(); } + size_t getNumParams() { return 0; } + Value getParameter(size_t i) { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } + ValueRange getParameters() { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } static StringRef getBaseSymbol() { return "ctrl"; } }]; - let builders = [OpBuilder<(ins "ValueRange":$controls, - "const function_ref&":$bodyBuilder)>]; + let builders = [OpBuilder<(ins "ValueRange":$controls, "ValueRange":$targets, + "const function_ref&":$body)>]; let hasCanonicalizer = 1; let hasVerifier = 1; @@ -983,26 +989,35 @@ def InvOp : QCOp<"inv", ``` }]; + let arguments = (ins Arg< + Variadic, + "the qubits involved in the operation", [MemRead, MemWrite]>:$qubits); let regions = (region SizedRegion<1>:$region); - let assemblyFormat = "$region attr-dict"; - - let extraClassDeclaration = [{ - [[nodiscard]] UnitaryOpInterface getBodyUnitary(); - size_t getNumQubits() { return getBodyUnitary().getNumQubits(); } - size_t getNumTargets() { return getBodyUnitary().getNumTargets(); } - size_t getNumControls() { return getBodyUnitary().getNumControls(); } - Value getQubit(size_t i) { return getBodyUnitary().getQubit(i); } - Value getTarget(size_t i) { return getBodyUnitary().getTarget(i); } - ValueRange getTargets() { return getBodyUnitary().getTargets(); } - Value getControl(size_t i) { return getBodyUnitary().getControl(i); } - ValueRange getControls() { return getBodyUnitary().getControls(); } - size_t getNumParams() { return getBodyUnitary().getNumParams(); } - Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); } - ValueRange getParameters() { return getBodyUnitary().getParameters(); } + let assemblyFormat = [{ + custom($region, $qubits) + attr-dict `:` + type($qubits) + }]; + + let extraClassDeclaration = [{ + size_t getNumBodyUnitaries(); + [[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i); + size_t getNumQubits() { return getNumTargets(); } + size_t getNumTargets() { return getQubits().size(); } + size_t getNumControls() { return 0; } + Value getQubit(size_t i) { return getTarget(i); } + Value getTarget(size_t i) { return getQubits()[i]; } + ValueRange getTargets() { return getQubits(); } + Value getControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); } + ValueRange getControls() { return {nullptr, 0}; } + size_t getNumParams() { return 0; } + Value getParameter(size_t i) { llvm::reportFatalUsageError("InvOp does not have parameters"); } + ValueRange getParameters() { return {nullptr, 0}; } static StringRef getBaseSymbol() { return "inv"; } }]; - let builders = [OpBuilder<(ins "const function_ref&":$bodyBuilder)>]; + let builders = [OpBuilder<(ins "ValueRange":$qubits, + "const function_ref&":$body)>]; let hasCanonicalizer = 1; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index a5bbfb7f51..78e15ecc86 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -1102,7 +1102,8 @@ def CtrlOp }]; let extraClassDeclaration = [{ - UnitaryOpInterface getBodyUnitary(); + size_t getNumBodyUnitaries(); + [[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i); size_t getNumQubits() { return getNumControls() + getNumTargets(); } size_t getNumTargets() { return getTargetsIn().size(); } size_t getNumControls() { return getControlsIn().size(); } @@ -1120,9 +1121,9 @@ def CtrlOp ResultRange getOutputControls() { return getControlsOut(); } Value getInputForOutput(Value output); Value getOutputForInput(Value input); - size_t getNumParams() { return getBodyUnitary().getNumParams(); } - Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); } - ValueRange getParameters() { return getBodyUnitary().getParameters(); } + size_t getNumParams() { return 0; } + Value getParameter(size_t i) { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } + ValueRange getParameters() { return {nullptr, 0}; } [[nodiscard]] static StringRef getBaseSymbol() { return "ctrl"; } [[nodiscard]] std::optional getUnitaryMatrix(); }]; @@ -1173,7 +1174,8 @@ def InvOp }]; let extraClassDeclaration = [{ - UnitaryOpInterface getBodyUnitary(); + size_t getNumBodyUnitaries(); + [[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i); size_t getNumQubits() { return getNumTargets(); } size_t getNumTargets() { return getQubitsIn().size(); } static size_t getNumControls() { return 0; } @@ -1184,16 +1186,16 @@ def InvOp ResultRange getOutputQubits() { return getQubitsOut(); } Value getInputTarget(size_t i) { return getInputQubit(i); } Value getOutputTarget(size_t i) { return getOutputQubit(i); } - static Value getInputControl(size_t i) { llvm::reportFatalUsageError("Operation does not have controls"); } + static Value getInputControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); } static OperandRange getInputControls() { return {nullptr, 0}; } - static Value getOutputControl(size_t i) { llvm::reportFatalUsageError("Operation does not have controls"); } + static Value getOutputControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); } static ResultRange getOutputControls() { return {nullptr, 0}; } ResultRange getOutputTargets() { return getOutputQubits(); } Value getInputForOutput(Value output); Value getOutputForInput(Value input); - size_t getNumParams() { return getBodyUnitary().getNumParams(); } - Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); } - ValueRange getParameters() { return getBodyUnitary().getParameters(); } + size_t getNumParams() { return 0; } + Value getParameter(size_t i) { llvm::reportFatalUsageError("InvOp does not have parameters"); } + ValueRange getParameters() { return {nullptr, 0}; } [[nodiscard]] static StringRef getBaseSymbol() { return "inv"; } [[nodiscard]] std::optional getUnitaryMatrix(); }]; diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 3d976a5a63..546ecc479c 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -78,4 +79,110 @@ template return std::nullopt; } +template +[[nodiscard]] +static ParseResult +parseTargetAliasing(OpAsmParser& parser, Region& region, + SmallVectorImpl& operands) { + // 1. Parse the opening parenthesis + if (parser.parseLParen()) { + return failure(); + } + + // Temporary storage for block arguments we are about to create + SmallVector blockArgs; + + // 2. Prepare to parse the list + if (failed(parser.parseOptionalRParen())) { + do { + OpAsmParser::Argument newArg; // The "new" variable name + OpAsmParser::UnresolvedOperand oldOperand; // The "old" input variable + + // Parse "%new" + if (parser.parseArgument(newArg)) { + return failure(); + } + + // Parse "=" + if (parser.parseEqual()) { + return failure(); + } + + // Parse "%old" + if (parser.parseOperand(oldOperand)) { + return failure(); + } + operands.push_back(oldOperand); + + // Hard-code QubitType since targets in qco.ctrl are always qubits. + // This avoids double-binding type($targets_in) in the assembly format + // while keeping the parser simple and the assembly format clean. + newArg.type = QubitType::get(parser.getBuilder().getContext()); + blockArgs.push_back(newArg); + + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRParen()) { + return failure(); + } + } + + // 4. Parse the Region + // We explicitly pass the blockArgs we just parsed so they become the entry + // block! + if (parser.parseRegion(region, blockArgs)) { + return failure(); + } + + return success(); +} + +static void printTargetAliasing(OpAsmPrinter& printer, Region& region, + OperandRange targetsIn) { + printer << "("; + if (region.empty()) { + printer << ") "; + printer.printRegion(region, false); + return; + } + Block& entryBlock = region.front(); + + const auto numTargets = targetsIn.size(); + for (unsigned i = 0; i < numTargets; ++i) { + if (i > 0) { + printer << ", "; + } + printer.printOperand(entryBlock.getArgument(i)); + printer << " = "; + printer.printOperand(targetsIn[i]); + } + printer << ") "; + + printer.printRegion(region, false); +} + +// TODO: Document +static Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { + if (auto blockArg = dyn_cast(qubit)) { + return qubits[blockArg.getArgNumber()]; + } + return qubit; +} + +// TODO: Rename and document +static void prova(Block& block, IRMapping& mapping, ValueRange innerQubits, + ValueRange outerQubits, ValueRange newQubits, + ValueRange qubitArgs) { + for (auto arg : block.getArguments()) { + auto innerQubit = innerQubits[arg.getArgNumber()]; + auto outerQubit = getValueFromBlockArgument(innerQubit, outerQubits); + if (auto it = llvm::find(newQubits, outerQubit); it != newQubits.end()) { + auto index = std::distance(newQubits.begin(), it); + mapping.map(arg, qubitArgs[index]); + } else { + llvm::reportFatalInternalError("TODO"); + } + } +} + } // namespace mlir::utils diff --git a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp index a486e82c5d..001844e1e6 100644 --- a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp +++ b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp @@ -155,19 +155,30 @@ static void handleResult(Operation* op, ConversionPatternRewriter& rewriter, * @brief Target operands: `adaptor.getOperands()` at the matched op, or * `state.targetsIn` while lowering inside `qco.ctrl` / `qco.inv`. * - * @param state Lowering state. - * @param adaptor Operand adaptor for the matched op. + * @param op The operation being converted. + * @param adaptor The operation adaptor of the operation. + * @param state The lowering state. * @tparam NumParams Number of parameters to drop from the end of the operand * list. - * @tparam OpAdaptor Adaptor with `getOperands()`. - * @return ValueRange The target operands. + * @tparam OpType The type of the operation. + * @tparam OpAdaptorType The type of the operation adaptor. + * @return The target operands. */ -template -[[nodiscard]] static ValueRange getEffectiveTargetOperands(LoweringState& state, - OpAdaptor adaptor) { - return state.inModifier() - ? ValueRange(state.targetsIn) - : ValueRange(adaptor.getOperands().drop_back(NumParams)); +template +[[nodiscard]] static SmallVector +getEffectiveTargetOperands(OpType op, OpAdaptorType adaptor, + LoweringState& state) { + if (!state.inModifier()) { + return adaptor.getOperands().drop_back(NumParams); + } + + SmallVector targets; + for (auto targetArg : op->getOperands().drop_back(NumParams)) { + auto target = + state.targetsIn[cast(targetArg).getArgNumber()]; + targets.push_back(target); + } + return targets; } /** @@ -190,10 +201,10 @@ convertJeffGate(QCOOpType op, typename QCOOpType::Adaptor adaptor, std::index_sequence /*targetIndices*/, std::index_sequence /*paramIndices*/) { constexpr std::size_t numParams = sizeof...(ParamIndices); - ValueRange targets = getEffectiveTargetOperands(state, adaptor); + auto targets = getEffectiveTargetOperands(op, adaptor, state); assert(targets.size() >= sizeof...(TargetIndices) && "Not enough operands available for conversion"); - ValueRange params = op.getParameters(); + auto params = op.getParameters(); auto jeffOp = JeffOpType::create( rewriter, op.getLoc(), targets[TargetIndices]..., params[ParamIndices]..., @@ -336,7 +347,7 @@ static LogicalResult moveRegion(Region& source, Region& dest, ConversionPatternRewriter& rewriter, const TypeConverter* typeConverter) { rewriter.inlineRegionBefore(source, dest, dest.end()); - Block* block = &dest.front(); + auto* block = &dest.front(); TypeConverter::SignatureConversion sc(block->getNumArguments()); if (failed( typeConverter->convertSignatureArgs(block->getArgumentTypes(), sc))) { @@ -728,7 +739,7 @@ struct ConvertQCOCustomGateToJeff final } } - ValueRange targets = getEffectiveTargetOperands(state, adaptor); + auto targets = getEffectiveTargetOperands(op, adaptor, state); assert(targets.size() >= NumTargets && "Not enough operands available for conversion"); @@ -764,7 +775,7 @@ struct ConvertQCOPPRGateToJeff final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); - ValueRange targets = getEffectiveTargetOperands<1>(state, adaptor); + auto targets = getEffectiveTargetOperands<1>(op, adaptor, state); assert(targets.size() >= 2 && "Not enough operands available for conversion"); createPPROp(op, rewriter, state, targets, {p0_, p1_}); @@ -798,7 +809,7 @@ struct ConvertQCOU2OpToJeff final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - ValueRange targets = getEffectiveTargetOperands<2>(state, adaptor); + auto targets = getEffectiveTargetOperands<2>(op, adaptor, state); assert(!targets.empty() && "Not enough operands available for conversion"); auto target = targets.front(); @@ -840,11 +851,8 @@ struct ConvertQCOBarrierOpToJeff final matchAndRewrite(BarrierOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - - ValueRange targets = getEffectiveTargetOperands<0>(state, adaptor); - + auto targets = getEffectiveTargetOperands<0>(op, adaptor, state); createCustomOp(op, rewriter, state, targets, {}, false, "barrier"); - return success(); } }; @@ -934,6 +942,13 @@ struct ConvertQCOInvOpToJeff final : StatefulOpConversionPattern { state.invOp = op; if (state.targetsIn.empty()) { state.targetsIn = llvm::to_vector(adaptor.getQubitsIn()); + } else { + auto outerQubits = state.targetsIn; + SmallVector innerQubits; + for (auto arg : op.getBody()->getArguments()) { + innerQubits.push_back(outerQubits[arg.getArgNumber()]); + } + state.targetsIn = std::move(innerQubits); } // Inline region diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 97d7071f69..77758bd3ba 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -120,17 +120,12 @@ class StatefulOpConversionPattern : public OpConversionPattern { * @param sourceRegion Source region where the operations are moved from * @param targetRegion Target region where the operations are moved to * @param offset Offset to the arguments that are dropped - * @param numArgs Number of arguments that are dropped * @param replacementValues Values to replace the uses of the arguments * @param rewriter PatternRewriter of the current conversion pass */ static void inlineRegion(Region& sourceRegion, Region& targetRegion, - unsigned int offset, unsigned int numArgs, - ValueRange replacementValues, + unsigned int offset, ValueRange replacementValues, ConversionPatternRewriter& rewriter) { - assert(replacementValues.size() == numArgs && - "replacementValues size must match numArgs"); - rewriter.inlineRegionBefore(sourceRegion, targetRegion, targetRegion.end()); auto& block = targetRegion.front(); @@ -138,7 +133,7 @@ static void inlineRegion(Region& sourceRegion, Region& targetRegion, block.getArguments().drop_front(offset), replacementValues)) { arg.replaceAllUsesWith(replacementVal); } - block.eraseArguments(offset, numArgs); + block.eraseArguments(offset, replacementValues.size()); } #define GEN_PASS_DEF_QCOTOQC @@ -645,16 +640,19 @@ struct ConvertQCOCtrlOp final : OpConversionPattern { LogicalResult matchAndRewrite(qco::CtrlOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - // Get QC controls - auto qcControls = adaptor.getControlsIn(); - // Create qc.ctrl operation - auto qcOp = qc::CtrlOp::create(rewriter, op.getLoc(), qcControls); - - // Inline the region and replace the blockarguments - inlineRegion(op.getRegion(), qcOp.getRegion(), 0, - adaptor.getTargetsIn().size(), adaptor.getTargetsIn(), - rewriter); + auto qcOp = qc::CtrlOp::create( + rewriter, op.getLoc(), adaptor.getControlsIn(), adaptor.getTargetsIn()); + + auto& dstRegion = qcOp.getRegion(); + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); + auto* block = &dstRegion.front(); + TypeConverter::SignatureConversion sc(block->getNumArguments()); + if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), + sc))) { + return failure(); + } + rewriter.applySignatureConversion(block, sc); // Replace the output qubits with the same QC references rewriter.replaceOp(op, adaptor.getOperands()); @@ -687,11 +685,17 @@ struct ConvertQCOInvOp final : OpConversionPattern { matchAndRewrite(qco::InvOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { // Create qc.inv operation - auto qcOp = qc::InvOp::create(rewriter, op.getLoc()); - - // Inline the region and replace the blockarguments - inlineRegion(op.getRegion(), qcOp.getRegion(), 0, - adaptor.getOperands().size(), adaptor.getQubitsIn(), rewriter); + auto qcOp = qc::InvOp::create(rewriter, op.getLoc(), adaptor.getQubitsIn()); + + auto& dstRegion = qcOp.getRegion(); + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); + auto* block = &dstRegion.front(); + TypeConverter::SignatureConversion sc(block->getNumArguments()); + if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), + sc))) { + return failure(); + } + rewriter.applySignatureConversion(block, sc); // Replace the output qubits with the same QC references rewriter.replaceOp(op, adaptor.getOperands()); @@ -764,9 +768,9 @@ struct ConvertQCOSCFForOp final : OpConversionPattern { // Erase default block rewriter.eraseBlock(&newFor.getRegion().front()); - // Inline the region and replace the blockarguments - inlineRegion(op.getRegion(), newFor.getRegion(), 1, - adaptor.getInitArgs().size(), adaptor.getInitArgs(), rewriter); + // Inline the region and replace the block arguments + inlineRegion(op.getRegion(), newFor.getRegion(), 1, adaptor.getInitArgs(), + rewriter); rewriter.replaceOp(op, adaptor.getInitArgs()); @@ -810,11 +814,11 @@ struct ConvertQCOSCFWhileOp final : OpConversionPattern { auto newWhileOp = scf::WhileOp::create(rewriter, op->getLoc(), TypeRange{}, ValueRange{}); - // Inline the regions and replace the blockarguments - inlineRegion(op.getBefore(), newWhileOp.getBefore(), 0, - adaptor.getInits().size(), adaptor.getInits(), rewriter); - inlineRegion(op.getAfter(), newWhileOp.getAfter(), 0, - adaptor.getInits().size(), adaptor.getInits(), rewriter); + // Inline the regions and replace the block arguments + inlineRegion(op.getBefore(), newWhileOp.getBefore(), 0, adaptor.getInits(), + rewriter); + inlineRegion(op.getAfter(), newWhileOp.getAfter(), 0, adaptor.getInits(), + rewriter); rewriter.replaceOp(op, adaptor.getInits()); @@ -855,15 +859,13 @@ struct ConvertQCOIfOp final : OpConversionPattern { // Erase the default empty then block rewriter.eraseBlock(&newThenRegion.front()); - // Inline the region and replace the blockarguments + // Inline the region and replace the block arguments inlineRegion(op.getThenRegion(), newThenRegion, 0, - adaptor.getOperands().size() - 1, adaptor.getOperands().drop_front(1), rewriter); // Inline the else block if it has more than just the yield operation if (oldElseRegion.front().getOperations().size() > 1) { inlineRegion(oldElseRegion, newIf.getElseRegion(), 0, - adaptor.getOperands().size() - 1, adaptor.getOperands().drop_front(1), rewriter); } diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index b4a4b0a151..9c32fc302d 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -391,22 +391,6 @@ static void popModifierFrame(LoweringState& state) { state.modifierFrames.pop_back(); } -/** @brief Adds entry block aliases for modifier target values. */ -template -[[nodiscard]] static ValueRange addModifierAliases(OpType op, - const size_t numTargets, - PatternRewriter& rewriter) { - auto& entryBlock = op.getRegion().front(); - const auto opLoc = op.getLoc(); - const auto qubitType = qco::QubitType::get(op.getContext()); - rewriter.modifyOpInPlace(op, [&] { - for (size_t i = 0; i < numTargets; ++i) { - entryBlock.addArgument(qubitType, opLoc); - } - }); - return entryBlock.getArguments().take_back(numTargets); -} - /** * @brief Inserts extracted qubits that are not required by @p target back into * their tensors. @@ -525,7 +509,8 @@ collectQubitValuesInsideSCFOps(Operation* op, LoweringState* state) { // Iterate through all operations of the current region for (auto& operation : region.front().getOperations()) { // Recursively walk through nested regions - if (operation.getNumRegions() > 0) { + if (operation.getNumRegions() > 0 && + !isa(operation)) { auto [qubits, registers] = collectQubitValuesInsideSCFOps(&operation, state); auto& regionQubitMap = state->regionQubitMap[op]; @@ -1124,16 +1109,20 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { assignMappedQubits(state, operation, qcControls, qcoOp.getControlsOut()); assignMappedQubits(state, operation, qcTargets, qcoOp.getTargetsOut()); - // Clone body region from QC to QCO + auto qcArgs = op.getRegion().front().getArguments(); + + // Inline region auto& dstRegion = qcoOp.getRegion(); - rewriter.cloneRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); + auto* block = &dstRegion.front(); + TypeConverter::SignatureConversion sc(block->getNumArguments()); + if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), + sc))) { + return failure(); + } + rewriter.applySignatureConversion(block, sc); - // Create block arguments for QCO targets - auto& entryBlock = dstRegion.front(); - assert(entryBlock.getNumArguments() == 0 && - "QC ctrl region unexpectedly has entry block arguments"); - pushModifierFrame(state, qcTargets, - addModifierAliases(qcoOp, numTargets, rewriter)); + pushModifierFrame(state, qcArgs, qcoOp.getRegion().front().getArguments()); rewriter.eraseOp(op); return success(); @@ -1174,16 +1163,20 @@ struct ConvertQCInvOp final : StatefulOpConversionPattern { assignMappedQubits(state, operation, qcTargets, qcoOp.getOutputTargets()); - // Clone body region from QC to QCO + auto qcArgs = op.getRegion().front().getArguments(); + + // Inline region auto& dstRegion = qcoOp.getRegion(); - rewriter.cloneRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - - // Create block arguments for target qubits and seed the nested frame. - auto& entryBlock = dstRegion.front(); - assert(entryBlock.getNumArguments() == 0 && - "QC inv region unexpectedly has entry block arguments"); - pushModifierFrame(state, qcTargets, - addModifierAliases(qcoOp, numTargets, rewriter)); + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); + auto* block = &dstRegion.front(); + TypeConverter::SignatureConversion sc(block->getNumArguments()); + if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), + sc))) { + return failure(); + } + rewriter.applySignatureConversion(block, sc); + + pushModifierFrame(state, qcArgs, qcoOp.getRegion().front().getArguments()); rewriter.eraseOp(op); return success(); diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 98942d998c..6c432f57d2 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -870,9 +870,9 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { adaptor.getControls().end()); state.controls[state.inCtrlOp] = controls; - // Inline region and remove operation - rewriter.inlineBlockBefore(&op.getRegion().front(), op->getBlock(), - op->getIterator()); + // Inline block and remove operation + rewriter.inlineBlockBefore(&op.getRegion().front(), op, + adaptor.getTargets()); rewriter.eraseOp(op); return success(); } diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 5eb022c03a..4bde38d301 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -223,7 +223,8 @@ QCProgramBuilder& QCProgramBuilder::reset(Value qubit) { const std::variant&(PARAM), ValueRange controls) { \ checkFinalized(); \ auto param = variantToValue(*this, getLoc(), PARAM); \ - CtrlOp::create(*this, controls, [&] { OP_CLASS::create(*this, param); }); \ + ctrl(controls, ValueRange{}, \ + [&](ValueRange /*targets*/) { OP_NAME(param); }); \ return *this; \ } @@ -247,7 +248,7 @@ DEFINE_ZERO_TARGET_ONE_PARAMETER(GPhaseOp, gphase, theta) QCProgramBuilder& QCProgramBuilder::mc##OP_NAME(ValueRange controls, \ Value target) { \ checkFinalized(); \ - CtrlOp::create(*this, controls, [&] { OP_CLASS::create(*this, target); }); \ + ctrl(controls, target, [&](ValueRange targets) { OP_NAME(targets[0]); }); \ return *this; \ } @@ -285,8 +286,8 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdgOp, sxdg) Value target) { \ checkFinalized(); \ auto param = variantToValue(*this, getLoc(), PARAM); \ - CtrlOp::create(*this, controls, \ - [&] { OP_CLASS::create(*this, target, param); }); \ + ctrl(controls, target, \ + [&](ValueRange targets) { OP_NAME(param, targets[0]); }); \ return *this; \ } @@ -321,8 +322,8 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, p, theta) checkFinalized(); \ auto param1 = variantToValue(*this, getLoc(), PARAM1); \ auto param2 = variantToValue(*this, getLoc(), PARAM2); \ - CtrlOp::create(*this, controls, \ - [&] { OP_CLASS::create(*this, target, param1, param2); }); \ + ctrl(controls, target, \ + [&](ValueRange targets) { OP_NAME(param1, param2, targets[0]); }); \ return *this; \ } @@ -360,8 +361,8 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, u2, phi, lambda) auto param1 = variantToValue(*this, getLoc(), PARAM1); \ auto param2 = variantToValue(*this, getLoc(), PARAM2); \ auto param3 = variantToValue(*this, getLoc(), PARAM3); \ - CtrlOp::create(*this, controls, [&] { \ - OP_CLASS::create(*this, target, param1, param2, param3); \ + ctrl(controls, target, [&](ValueRange targets) { \ + OP_NAME(param1, param2, param3, targets[0]); \ }); \ return *this; \ } @@ -386,8 +387,8 @@ DEFINE_ONE_TARGET_THREE_PARAMETER(UOp, u, theta, phi, lambda) QCProgramBuilder& QCProgramBuilder::mc##OP_NAME( \ ValueRange controls, Value qubit0, Value qubit1) { \ checkFinalized(); \ - CtrlOp::create(*this, controls, \ - [&] { OP_CLASS::create(*this, qubit0, qubit1); }); \ + ctrl(controls, ValueRange{qubit0, qubit1}, \ + [&](ValueRange targets) { OP_NAME(targets[0], targets[1]); }); \ return *this; \ } @@ -418,8 +419,8 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ecr) Value qubit0, Value qubit1) { \ checkFinalized(); \ auto param = variantToValue(*this, getLoc(), PARAM); \ - CtrlOp::create(*this, controls, \ - [&] { OP_CLASS::create(*this, qubit0, qubit1, param); }); \ + ctrl(controls, ValueRange{qubit0, qubit1}, \ + [&](ValueRange targets) { OP_NAME(param, targets[0], targets[1]); }); \ return *this; \ } @@ -455,8 +456,8 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, rzz, theta) checkFinalized(); \ auto param1 = variantToValue(*this, getLoc(), PARAM1); \ auto param2 = variantToValue(*this, getLoc(), PARAM2); \ - CtrlOp::create(*this, controls, [&] { \ - OP_CLASS::create(*this, qubit0, qubit1, param1, param2); \ + ctrl(controls, ValueRange{qubit0, qubit1}, [&](ValueRange targets) { \ + OP_NAME(param1, param2, targets[0], targets[1]); \ }); \ return *this; \ } @@ -478,16 +479,19 @@ QCProgramBuilder& QCProgramBuilder::barrier(ValueRange qubits) { // Modifiers //===----------------------------------------------------------------------===// -QCProgramBuilder& QCProgramBuilder::ctrl(ValueRange controls, - const function_ref& body) { +QCProgramBuilder& +QCProgramBuilder::ctrl(ValueRange controls, ValueRange targets, + const function_ref& body) { checkFinalized(); - CtrlOp::create(*this, controls, body); + CtrlOp::create(*this, controls, targets, body); return *this; } -QCProgramBuilder& QCProgramBuilder::inv(const function_ref& body) { +QCProgramBuilder& +QCProgramBuilder::inv(ValueRange qubits, + const function_ref& body) { checkFinalized(); - InvOp::create(*this, body); + InvOp::create(*this, qubits, body); return *this; } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index d457fa9a35..cf943b1f99 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -9,10 +9,12 @@ */ #include "mlir/Dialect/QC/IR/QCOps.h" +#include "mlir/Dialect/Utils/Utils.h" #include #include #include +#include #include #include #include @@ -33,22 +35,47 @@ struct MergeNestedCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - auto* bodyUnitary = op.getBodyUnitary().getOperation(); - auto bodyCtrlOp = dyn_cast(bodyUnitary); - if (!bodyCtrlOp) { + // Require at least one control + // Trivial case is handled by ReduceCtrl + if (op.getNumControls() == 0) { return failure(); } - // add the inner controls as operands to the outer one - op->insertOperands(op.getNumOperands(), bodyCtrlOp.getControls()); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); + if (!innerCtrlOp) { + return failure(); + } - // Move the inner unitary op into the outer one's body region and replace - // the outer one with the inner one's results - const OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(bodyUnitary); - auto* innerUnitaryOp = bodyCtrlOp.getBodyUnitary().getOperation(); - rewriter.moveOpBefore(innerUnitaryOp, bodyUnitary); - rewriter.replaceOp(bodyUnitary, innerUnitaryOp->getResults()); + auto outerControls = op.getControls(); + auto outerTargets = op.getTargets(); + auto innerTargets = innerCtrlOp.getTargets(); + + SmallVector controls; + SmallVector targets; + llvm::append_range(controls, outerControls); + for (auto [arg, qubit] : + llvm::zip_equal(op.getBody()->getArguments(), outerTargets)) { + if (llvm::is_contained(innerTargets, arg)) { + targets.push_back(qubit); + } else { + controls.push_back(qubit); + } + } + + rewriter.replaceOpWithNewOp( + op, controls, targets, [&](ValueRange targetArgs) { + auto* innerCtrlBody = innerCtrlOp.getBody(); + IRMapping mapping; + utils::prova(*innerCtrlBody, mapping, innerTargets, outerTargets, + targets, targetArgs); + for (auto& op : innerCtrlBody->without_terminator()) { + rewriter.clone(op, mapping); + } + }); return success(); } @@ -63,16 +90,30 @@ struct ReduceCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - auto* bodyUnitary = op.getBodyUnitary().getOperation(); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); + // Inline ops from empty control modifiers, IdOp and BarrierOp - if (op.getNumControls() == 0 || isa(bodyUnitary)) { - rewriter.moveOpBefore(bodyUnitary, op); - rewriter.replaceOp(op, bodyUnitary->getResults()); + if (op.getNumControls() == 0 || isa(innerOp)) { + const auto numTargets = op.getNumTargets(); + auto outerTargets = op.getTargets(); + SmallVector targets; + for (auto target : innerOp->getOperands().take_front(numTargets)) { + targets.push_back( + utils::getValueFromBlockArgument(target, outerTargets)); + } + + rewriter.moveOpBefore(innerOp, op); + innerOp->setOperands(0, numTargets, targets); + rewriter.eraseOp(op); return success(); } // The remaining code explicitly handles GPhaseOp and nothing else - auto gPhaseOp = dyn_cast(bodyUnitary); + auto gPhaseOp = dyn_cast(innerOp); if (!gPhaseOp) { return failure(); } @@ -84,16 +125,23 @@ struct ReduceCtrl final : OpRewritePattern { return success(); } - // Remove the last control and replace with a single POp with the removed - // control as target - auto controls = op.getControls(); - auto target = controls.back(); - controls = controls.drop_back(); - op->setOperands(controls); + // Adjust the segment sizes of the control and target operands + const auto opSegmentsAttrName = CtrlOp::getOperandSegmentSizeAttr(); + auto segmentsAttr = + op->getAttrOfType(opSegmentsAttrName); + auto newSegments = DenseI32ArrayAttr::get( + rewriter.getContext(), {segmentsAttr[0] - 1, segmentsAttr[1] + 1}); + op->setAttr(opSegmentsAttrName, newSegments); + + // Add a block argument for the target qubit + auto arg = op.getBody()->addArgument(QubitType::get(rewriter.getContext()), + op.getLoc()); + // Replace the current GPhaseOp with a PhaseOp const OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(gPhaseOp); - rewriter.replaceOpWithNewOp(gPhaseOp, target, gPhaseOp.getTheta()); + POp::create(rewriter, gPhaseOp.getLoc(), arg, gPhaseOp.getTheta()); + rewriter.eraseOp(gPhaseOp); return success(); } @@ -101,13 +149,27 @@ struct ReduceCtrl final : OpRewritePattern { } // namespace -UnitaryOpInterface CtrlOp::getBodyUnitary() { - // In principle, the body region should only contain exactly two operations, - // the actual unitary operation and a yield operation. However, the region may - // also contain constants and arithmetic operations, e.g., created as part of - // canonicalization. Thus, the only safe way to access the unitary operation - // is to get the second operation from the back of the region. - return cast(*(++getBody()->rbegin())); +size_t CtrlOp::getNumBodyUnitaries() { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + count++; + } + } + return count; +} + +UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + if (count == i) { + return cast(op); + } + count++; + } + } + llvm::reportFatalUsageError("Unitary index out of bounds"); } Value CtrlOp::getQubit(const size_t i) { @@ -116,9 +178,9 @@ Value CtrlOp::getQubit(const size_t i) { return getControls()[i]; } if (numControls <= i && i < getNumQubits()) { - return getBodyUnitary().getQubit(i - numControls); + return getTarget(i - numControls); } - llvm::reportFatalUsageError("Invalid qubit index"); + llvm::reportFatalUsageError("Qubit index out of bounds"); } Value CtrlOp::getControl(const size_t i) { @@ -129,15 +191,19 @@ Value CtrlOp::getControl(const size_t i) { } void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, - ValueRange controls, - const function_ref& bodyBuilder) { - const OpBuilder::InsertionGuard guard(odsBuilder); - odsState.addOperands(controls); - auto* region = odsState.addRegion(); - auto& block = region->emplaceBlock(); + ValueRange controls, ValueRange targets, + const function_ref& body) { + build(odsBuilder, odsState, controls, targets); + auto& block = odsState.regions.front()->emplaceBlock(); + + auto qubitType = QubitType::get(odsBuilder.getContext()); + for (size_t i = 0; i < targets.size(); ++i) { + block.addArgument(qubitType, odsState.location); + } + const OpBuilder::InsertionGuard guard(odsBuilder); odsBuilder.setInsertionPointToStart(&block); - bodyBuilder(); + body(block.getArguments()); YieldOp::create(odsBuilder, odsState.location); } @@ -150,16 +216,6 @@ LogicalResult CtrlOp::verify() { return emitOpError( "last operation in body region must be a yield operation"); } - auto iter = ++block.rbegin(); - if (!isa(*iter)) { - return emitOpError( - "second to last operation in body region must be a unitary operation"); - } - for (auto it = ++iter; it != block.rend(); ++it) { - if (isa(*it)) { - return emitOpError("body region may only contain a single unitary op"); - } - } SmallPtrSet uniqueQubits; for (const auto& control : getControls()) { @@ -167,11 +223,9 @@ LogicalResult CtrlOp::verify() { return emitOpError("duplicate control qubit found"); } } - auto bodyUnitary = getBodyUnitary(); - const auto numQubits = bodyUnitary.getNumQubits(); - for (size_t i = 0; i < numQubits; i++) { - if (!uniqueQubits.insert(bodyUnitary.getQubit(i)).second) { - return emitOpError("duplicate qubit found"); + for (const auto& target : getTargets()) { + if (!uniqueQubits.insert(target).second) { + return emitOpError("duplicate target qubit found"); } } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 065fe431be..b935e5c823 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -9,11 +9,13 @@ */ #include "mlir/Dialect/QC/IR/QCOps.h" +#include "mlir/Dialect/Utils/Utils.h" #include #include #include #include +#include #include #include #include @@ -33,20 +35,36 @@ namespace { struct MoveCtrlOutside final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(InvOp invOp, + LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto bodyUnitary = invOp.getBodyUnitary(); - auto innerCtrlOp = dyn_cast(bodyUnitary.getOperation()); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); if (!innerCtrlOp) { return failure(); } - auto controls = innerCtrlOp.getControls(); - rewriter.replaceOpWithNewOp(invOp, controls, [&] { - InvOp::create(rewriter, invOp.getLoc(), [&] { - rewriter.clone(*innerCtrlOp.getBodyUnitary().getOperation()); - }); - }); + const auto numControls = innerCtrlOp.getNumControls(); + const auto numTargets = innerCtrlOp.getNumTargets(); + auto outerQubits = op.getQubits(); + auto controls = outerQubits.take_front(numControls); + auto targets = outerQubits.take_back(numTargets); + + rewriter.replaceOpWithNewOp( + op, controls, targets, [&](ValueRange targetArgs) { + InvOp::create( + rewriter, op.getLoc(), targetArgs, [&](ValueRange qubitArgs) { + auto* innerCtrlBody = innerCtrlOp.getBody(); + IRMapping mapping; + utils::prova(*innerCtrlBody, mapping, innerCtrlOp.getTargets(), + outerQubits, targets, qubitArgs); + for (auto& op : innerCtrlBody->without_terminator()) { + rewriter.clone(op, mapping); + } + }); + }); return success(); } @@ -62,13 +80,24 @@ struct InlineSelfAdjoint final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerOp = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); if (!isa(innerOp)) { return failure(); } + const auto numQubits = op.getNumQubits(); + auto outerQubits = op.getQubits(); + SmallVector qubits; + for (auto qubit : innerOp->getOperands().take_front(numQubits)) { + qubits.push_back(utils::getValueFromBlockArgument(qubit, outerQubits)); + } + rewriter.moveOpBefore(innerOp, op); + innerOp->setOperands(0, numQubits, qubits); rewriter.replaceOp(op, innerOp->getResults()); return success(); } @@ -85,143 +114,181 @@ struct ReplaceWithKnownGates final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerOp = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); + + auto loc = op.getLoc(); + auto outerQubits = op.getQubits(); return TypeSwitch(innerOp) .Case([&](auto g) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), g.getTheta()); + Value negTheta = arith::NegFOp::create(rewriter, loc, g.getTheta()); rewriter.replaceOpWithNewOp(op, negTheta); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto t) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(t.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto tdg) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(tdg.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto s) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(s.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto sdg) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(sdg.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto sx) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(sx.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto sxdg) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(sxdg.getTarget(0), outerQubits)); return success(); }) .Case([&](auto p) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), p.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, p.getTheta()); + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(p.getTarget(0), outerQubits), + negTheta); return success(); }) .Case([&](auto r) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), r.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta, - r.getPhi()); + auto negTheta = arith::NegFOp::create(rewriter, loc, r.getTheta()); + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(r.getTarget(0), outerQubits), + negTheta, r.getPhi()); return success(); }) .Case([&](auto rx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rx.getTarget(0), outerQubits), + negTheta); return success(); }) .Case([&](auto u) { - Value newPhi = - arith::NegFOp::create(rewriter, op.getLoc(), u.getLambda()); - Value newLambda = - arith::NegFOp::create(rewriter, op.getLoc(), u.getPhi()); - Value newTheta = - arith::NegFOp::create(rewriter, op.getLoc(), u.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), newTheta, - newPhi, newLambda); + Value newPhi = arith::NegFOp::create(rewriter, loc, u.getLambda()); + Value newLambda = arith::NegFOp::create(rewriter, loc, u.getPhi()); + Value newTheta = arith::NegFOp::create(rewriter, loc, u.getTheta()); + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(u.getTarget(0), outerQubits), + newTheta, newPhi, newLambda); return success(); }) - .Case([&](auto u) { - auto pi = arith::ConstantOp::create( - rewriter, op.getLoc(), - rewriter.getF64FloatAttr(std::numbers::pi)); - Value newPhi = - arith::NegFOp::create(rewriter, op.getLoc(), u.getLambda()); - newPhi = arith::SubFOp::create(rewriter, op.getLoc(), newPhi, pi); - Value newLambda = - arith::NegFOp::create(rewriter, op.getLoc(), u.getPhi()); - newLambda = - arith::AddFOp::create(rewriter, op.getLoc(), newLambda, pi); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), newPhi, - newLambda); + .Case([&](auto u2) { + Value pi = arith::ConstantOp::create( + rewriter, loc, rewriter.getF64FloatAttr(std::numbers::pi)); + Value newPhi = arith::NegFOp::create(rewriter, loc, u2.getLambda()); + newPhi = arith::SubFOp::create(rewriter, loc, newPhi, pi); + Value newLambda = arith::NegFOp::create(rewriter, loc, u2.getPhi()); + newLambda = arith::AddFOp::create(rewriter, loc, newLambda, pi); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(u2.getTarget(0), outerQubits), + newPhi, newLambda); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(1), - op.getTarget(0)); + .Case([&](auto dcx) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(dcx.getTarget(1), outerQubits), + utils::getValueFromBlockArgument(dcx.getTarget(0), outerQubits)); return success(); }) .Case([&](auto rxx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rxx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rxx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rxx.getTarget(0), outerQubits), + utils::getValueFromBlockArgument(rxx.getTarget(1), outerQubits), + negTheta); return success(); }) .Case([&](auto ry) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), ry.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, ry.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(ry.getTarget(0), outerQubits), + negTheta); return success(); }) .Case([&](auto ryy) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), ryy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, ryy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(ryy.getTarget(0), outerQubits), + utils::getValueFromBlockArgument(ryy.getTarget(1), outerQubits), + negTheta); return success(); }) .Case([&](auto rz) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rz.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rz.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rz.getTarget(0), outerQubits), + negTheta); return success(); }) .Case([&](auto rzx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rzx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rzx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rzx.getTarget(0), outerQubits), + utils::getValueFromBlockArgument(rzx.getTarget(1), outerQubits), + negTheta); return success(); }) .Case([&](auto rzz) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rzz.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rzz.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rzz.getTarget(0), outerQubits), + utils::getValueFromBlockArgument(rzz.getTarget(1), outerQubits), + negTheta); return success(); }) .Case([&](auto xxminusyy) { - Value negTheta = arith::NegFOp::create(rewriter, op.getLoc(), - xxminusyy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta, - xxminusyy.getBeta()); + Value negTheta = + arith::NegFOp::create(rewriter, loc, xxminusyy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(xxminusyy.getTarget(0), + outerQubits), + utils::getValueFromBlockArgument(xxminusyy.getTarget(1), + outerQubits), + negTheta, xxminusyy.getBeta()); return success(); }) .Case([&](auto xxplusyy) { Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), xxplusyy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta, - xxplusyy.getBeta()); + arith::NegFOp::create(rewriter, loc, xxplusyy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(xxplusyy.getTarget(0), + outerQubits), + utils::getValueFromBlockArgument(xxplusyy.getTarget(1), + outerQubits), + negTheta, xxplusyy.getBeta()); return success(); }) .Default([&](auto) { return failure(); }); @@ -233,41 +300,79 @@ struct ReplaceWithKnownGates final : OpRewritePattern { */ struct CancelNestedInv final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(InvOp invOp, + LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto innerUnitary = invOp.getBodyUnitary(); - auto innerInvOp = dyn_cast(innerUnitary.getOperation()); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerInvOp = dyn_cast(op.getBodyUnitary(0).getOperation()); if (!innerInvOp) { return failure(); } - auto* innerInnerUnitary = innerInvOp.getBodyUnitary().getOperation(); - rewriter.moveOpBefore(innerInnerUnitary, invOp); - rewriter.replaceOp(invOp, innerInnerUnitary->getResults()); + // TODO: Relax this condition? + if (innerInvOp.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerInnerOp = innerInvOp.getBodyUnitary(0).getOperation(); + const auto numQubits = op.getNumQubits(); + auto outerQubits = op.getQubits(); + auto innerQubits = innerInvOp.getQubits(); + SmallVector qubits; + for (auto qubit : innerInnerOp->getOperands().take_front(numQubits)) { + auto innerQubit = utils::getValueFromBlockArgument(qubit, innerQubits); + qubits.push_back( + utils::getValueFromBlockArgument(innerQubit, outerQubits)); + } + + rewriter.moveOpBefore(innerInnerOp, op); + innerInnerOp->setOperands(0, numQubits, qubits); + rewriter.replaceOp(op, innerInnerOp->getResults()); return success(); } }; } // namespace -UnitaryOpInterface InvOp::getBodyUnitary() { - // In principle, the body region should only contain exactly two operations, - // the actual unitary operation and a yield operation. However, the region may - // also contain constants and arithmetic operations, e.g., created as part of - // canonicalization. Thus, the only safe way to access the unitary operation - // is to get the second operation from the back of the region. - return cast(*(++getBody()->rbegin())); +size_t InvOp::getNumBodyUnitaries() { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + count++; + } + } + return count; +} + +UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + if (count == i) { + return cast(op); + } + count++; + } + } + llvm::reportFatalUsageError("Invalid unitary index"); } void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, - const function_ref& bodyBuilder) { - const OpBuilder::InsertionGuard guard(odsBuilder); - auto* region = odsState.addRegion(); - auto& block = region->emplaceBlock(); + ValueRange qubits, + const function_ref& body) { + build(odsBuilder, odsState, qubits); + auto& block = odsState.regions.front()->emplaceBlock(); + + auto qubitType = QubitType::get(odsBuilder.getContext()); + for (size_t i = 0; i < qubits.size(); ++i) { + block.addArgument(qubitType, odsState.location); + } + const OpBuilder::InsertionGuard guard(odsBuilder); odsBuilder.setInsertionPointToStart(&block); - bodyBuilder(); + body(block.getArguments()); YieldOp::create(odsBuilder, odsState.location); } @@ -280,16 +385,6 @@ LogicalResult InvOp::verify() { return emitOpError( "last operation in body region must be a yield operation"); } - auto iter = ++block.rbegin(); - if (!isa(*iter)) { - return emitOpError( - "second to last operation in body region must be a unitary operation"); - } - for (auto it = ++iter; it != block.rend(); ++it) { - if (isa(*it)) { - return emitOpError("body region may only contain a single unitary op"); - } - } return success(); } diff --git a/mlir/lib/Dialect/QC/IR/QCOps.cpp b/mlir/lib/Dialect/QC/IR/QCOps.cpp index 5b93c2ebaa..bf6551f924 100644 --- a/mlir/lib/Dialect/QC/IR/QCOps.cpp +++ b/mlir/lib/Dialect/QC/IR/QCOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/QC/IR/QCDialect.h" // IWYU pragma: associated +#include "mlir/Dialect/Utils/Utils.h" // The following headers are needed for some template instantiations. // IWYU pragma: begin_keep @@ -21,6 +22,17 @@ using namespace mlir; using namespace mlir::qc; +static ParseResult +parseTargetAliasing(OpAsmParser& parser, Region& region, + SmallVectorImpl& operands) { + return utils::parseTargetAliasing(parser, region, operands); +} + +static void printTargetAliasing(OpAsmPrinter& printer, Operation* /*op*/, + Region& region, OperandRange targetsIn) { + utils::printTargetAliasing(printer, region, targetsIn); +} + //===----------------------------------------------------------------------===// // Dialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp index ac70cd4269..66d562ae82 100644 --- a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp +++ b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp @@ -452,10 +452,14 @@ static void addISWAPdgOp(QCProgramBuilder& builder, auto target0 = qubits[operation.getTargets()[0]]; auto target1 = qubits[operation.getTargets()[1]]; if (const auto& controls = getControls(operation, qubits); controls.empty()) { - builder.inv([&] { builder.iswap(target0, target1); }); + builder.inv({target0, target1}, [&](ValueRange qubits) { + builder.iswap(qubits[0], qubits[1]); + }); } else { - builder.ctrl(controls, [&] { - builder.inv([&] { builder.iswap(target0, target1); }); + builder.ctrl(controls, {target0, target1}, [&](ValueRange targets) { + builder.inv(targets, [&](ValueRange qubits) { + builder.iswap(qubits[0], qubits[1]); + }); }); } } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 25fc88d084..e86f3f7dc3 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/Utils/Utils.h" #include #include @@ -42,38 +43,52 @@ struct MergeNestedCtrl final : OpRewritePattern { LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - // Require at least one positive control + // Require at least one control // Trivial case is handled by ReduceCtrl - const auto numOuterControls = op.getNumControls(); - if (numOuterControls == 0) { + if (op.getNumControls() == 0) { return failure(); } - auto bodyCtrlOp = dyn_cast(op.getBodyUnitary().getOperation()); - if (!bodyCtrlOp) { + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { return failure(); } - const auto numInnerControls = bodyCtrlOp.getNumControls(); - auto outerControls = op.getControlsIn(); + auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); + if (!innerCtrlOp) { + return failure(); + } + auto outerTargets = op.getTargetsIn(); - auto newAdditionalControls = outerTargets.take_front(numInnerControls); - auto newTargets = outerTargets.drop_front(numInnerControls); - auto newControls = llvm::to_vector( - llvm::concat(outerControls, newAdditionalControls)); + auto outerControls = op.getControlsIn(); + auto innerTargets = innerCtrlOp.getTargetsIn(); + + SmallVector controls; + SmallVector targets; + llvm::append_range(controls, outerControls); + for (auto [arg, qubit] : + llvm::zip_equal(op.getBody()->getArguments(), outerTargets)) { + if (llvm::is_contained(innerTargets, arg)) { + targets.push_back(qubit); + } else { + controls.push_back(qubit); + } + } rewriter.replaceOpWithNewOp( - op, newControls, newTargets, - [&](ValueRange newTargetArgs) -> SmallVector { + op, controls, targets, + [&](ValueRange targetArgs) -> SmallVector { + auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - auto* innerBody = bodyCtrlOp.getBody(); - for (size_t i = 0; i < bodyCtrlOp.getNumTargets(); ++i) { - mapping.map(innerBody->getArgument(i), newTargetArgs[i]); + utils::prova(*innerCtrlBody, mapping, innerTargets, outerTargets, + targets, targetArgs); + SmallVector yields; + for (auto& op : innerCtrlBody->without_terminator()) { + auto results = rewriter.clone(op, mapping)->getResults(); + llvm::append_range(yields, results); } - - return rewriter - .clone(*bodyCtrlOp.getBodyUnitary().getOperation(), mapping) - ->getResults(); + return yields; }); + return success(); } }; @@ -87,20 +102,32 @@ struct ReduceCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - auto* bodyUnitary = op.getBodyUnitary().getOperation(); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); + // Inline ops from empty control modifiers, IdOp and BarrierOp - if (op.getNumControls() == 0 || isa(bodyUnitary)) { - rewriter.moveOpBefore(bodyUnitary, op); - bodyUnitary->setOperands(0, op.getNumTargets(), op.getTargetsIn()); + if (op.getNumControls() == 0 || isa(innerOp)) { + const auto numTargets = op.getNumTargets(); + auto outerTargets = op.getTargetsIn(); + SmallVector targets; + for (auto target : innerOp->getOperands().take_front(numTargets)) { + targets.push_back( + utils::getValueFromBlockArgument(target, outerTargets)); + } + + rewriter.moveOpBefore(innerOp, op); + innerOp->setOperands(0, numTargets, targets); rewriter.replaceAllUsesWith(op.getControlsOut(), op.getControlsIn()); - rewriter.replaceAllUsesWith(op.getTargetsOut(), - bodyUnitary->getResults()); + rewriter.replaceAllUsesWith(op.getTargetsOut(), innerOp->getResults()); rewriter.eraseOp(op); return success(); } // The remaining code explicitly handles GPhaseOp and nothing else - auto gPhaseOp = dyn_cast(bodyUnitary); + auto gPhaseOp = dyn_cast(innerOp); if (!gPhaseOp) { return failure(); } @@ -136,7 +163,7 @@ struct ReduceCtrl final : OpRewritePattern { auto yieldOp = cast(op.getBody()->back()); yieldOp->setOperands(pOp->getResults()); - // erase the GPhaseOp + // Erase the GPhaseOp rewriter.eraseOp(gPhaseOp); return success(); @@ -145,13 +172,27 @@ struct ReduceCtrl final : OpRewritePattern { } // namespace -UnitaryOpInterface CtrlOp::getBodyUnitary() { - // In principle, the body region should only contain exactly two operations, - // the actual unitary operation and a yield operation. However, the region may - // also contain constants and arithmetic operations, e.g., created as part of - // canonicalization. Thus, the only safe way to access the unitary operation - // is to get the second operation from the back of the region. - return cast(*(++getBody()->rbegin())); +size_t CtrlOp::getNumBodyUnitaries() { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + count++; + } + } + return count; +} + +UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + if (count == i) { + return cast(op); + } + count++; + } + } + llvm::reportFatalUsageError("Unitary index out of bounds"); } Value CtrlOp::getInputQubit(const size_t i) { @@ -162,7 +203,7 @@ Value CtrlOp::getInputQubit(const size_t i) { if (numControls <= i && i < getNumQubits()) { return getTargetsIn()[i - numControls]; } - llvm::reportFatalUsageError("Invalid qubit index"); + llvm::reportFatalUsageError("Qubit index out of bounds"); } Value CtrlOp::getOutputQubit(const size_t i) { @@ -173,7 +214,7 @@ Value CtrlOp::getOutputQubit(const size_t i) { if (numControls <= i && i < getNumQubits()) { return getTargetsOut()[i - numControls]; } - llvm::reportFatalUsageError("Invalid qubit index"); + llvm::reportFatalUsageError("Qubit index out of bounds"); } Value CtrlOp::getInputTarget(const size_t i) { @@ -238,7 +279,7 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, controls, targets); auto& block = odsState.regions.front()->emplaceBlock(); - const auto qubitType = QubitType::get(odsBuilder.getContext()); + auto qubitType = QubitType::get(odsBuilder.getContext()); for (size_t i = 0; i < targets.size(); ++i) { block.addArgument(qubitType, odsState.location); } @@ -275,18 +316,9 @@ LogicalResult CtrlOp::verify() { return emitOpError("yield operation must yield ") << numTargets << " values, but found " << numYieldOperands; } - auto iter = ++block.rbegin(); - if (!isa(*iter)) { - return emitOpError( - "second to last operation in body region must be a unitary operation"); - } - for (auto it = ++iter; it != block.rend(); ++it) { - if (isa(*it)) { - return emitOpError("body region may only contain a single unitary op"); - } - } SmallPtrSet uniqueQubitsIn; + SmallPtrSet uniqueTargetsIn; for (const auto& control : getControlsIn()) { if (!uniqueQubitsIn.insert(control).second) { return emitOpError("duplicate control qubit found"); @@ -296,29 +328,20 @@ LogicalResult CtrlOp::verify() { if (!uniqueQubitsIn.insert(target).second) { return emitOpError("duplicate target qubit found"); } - } - - auto bodyUnitary = getBodyUnitary(); - if (bodyUnitary.getNumQubits() != numTargets) { - return emitOpError("body unitary must operate on exactly ") - << numTargets << " target qubits, but found " - << bodyUnitary.getNumQubits(); - } - const auto numQubits = bodyUnitary.getNumQubits(); - for (size_t i = 0; i < numQubits; i++) { - if (bodyUnitary.getInputQubit(i) != block.getArgument(i)) { - return emitOpError("body unitary must use target alias block argument ") - << i << " (and not the original target operand)"; + if (!uniqueTargetsIn.insert(target).second) { + return emitOpError("duplicate target qubit found"); } } - // Also require yield to forward the unitary's outputs in-order. - for (size_t i = 0; i < numTargets; ++i) { - if (block.back().getOperand(i) != bodyUnitary.getOutputQubit(i)) { - return emitOpError("yield operand ") - << i << " must be the body unitary output qubit " << i; - } - } + // TODO: Re-enable + // for (size_t i = 0; i < getNumBodyUnitaries(); ++i) { + // auto bodyUnitary = getBodyUnitary(i); + // for (size_t j = 0; j < bodyUnitary.getNumQubits(); ++j) { + // if (!uniqueTargetsIn.contains(bodyUnitary.getInputQubit(j))) { + // return emitOpError("unitary is using an unknown input qubit"); + // } + // } + // } SmallPtrSet uniqueQubitsOut; for (const auto& control : getControlsOut()) { @@ -326,8 +349,8 @@ LogicalResult CtrlOp::verify() { return emitOpError("duplicate control qubit found"); } } - for (size_t i = 0; i < numQubits; i++) { - if (!uniqueQubitsOut.insert(bodyUnitary.getOutputQubit(i)).second) { + for (size_t i = 0; i < numTargets; i++) { + if (!uniqueQubitsOut.insert(block.back().getOperand(i)).second) { return emitOpError("duplicate qubit found"); } } @@ -341,11 +364,16 @@ void CtrlOp::getCanonicalizationPatterns(RewritePatternSet& results, } std::optional CtrlOp::getUnitaryMatrix() { - auto&& bodyUnitary = getBodyUnitary(); + // TODO: Relax this condition + if (getNumBodyUnitaries() != 1) { + return std::nullopt; + } + + auto bodyUnitary = getBodyUnitary(0); if (!bodyUnitary) { return std::nullopt; } - auto&& targetMatrix = bodyUnitary.getUnitaryMatrix(); + auto targetMatrix = bodyUnitary.getUnitaryMatrix(); if (!targetMatrix) { return std::nullopt; } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index d82a64f819..1b6a98c07e 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/Utils/Utils.h" #include #include @@ -40,36 +41,41 @@ namespace { struct MoveCtrlOutside final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(InvOp invOp, + LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto bodyUnitary = invOp.getBodyUnitary(); - auto innerCtrlOp = dyn_cast(bodyUnitary.getOperation()); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); if (!innerCtrlOp) { return failure(); } const auto numControls = innerCtrlOp.getNumControls(); const auto numTargets = innerCtrlOp.getNumTargets(); - auto invTargets = invOp.getInputQubits(); - auto controls = invTargets.take_front(numControls); - auto targets = invTargets.take_back(numTargets); + auto outerQubits = op.getQubitsIn(); + auto controls = outerQubits.take_front(numControls); + auto targets = outerQubits.take_back(numTargets); rewriter.replaceOpWithNewOp( - invOp, controls, targets, - [&](ValueRange newTargetArgs) -> SmallVector { + op, controls, targets, + [&](ValueRange targetArgs) -> SmallVector { return InvOp::create( - rewriter, invOp.getLoc(), newTargetArgs, - [&](ValueRange invArgs) -> SmallVector { + rewriter, op.getLoc(), targetArgs, + [&](ValueRange qubitArgs) -> SmallVector { + auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - auto* innerBody = innerCtrlOp.getBody(); - for (size_t i = 0; i < innerCtrlOp.getNumTargets(); - ++i) { - mapping.map(innerBody->getArgument(i), invArgs[i]); + utils::prova(*innerCtrlBody, mapping, + innerCtrlOp.getTargetsIn(), outerQubits, + targets, qubitArgs); + SmallVector yields; + for (auto& op : innerCtrlBody->without_terminator()) { + auto results = + rewriter.clone(op, mapping)->getResults(); + llvm::append_range(yields, results); } - auto* cloned = rewriter.clone( - *innerCtrlOp.getBodyUnitary().getOperation(), - mapping); - return cloned->getResults(); + return yields; }) .getResults(); }); @@ -88,14 +94,24 @@ struct InlineSelfAdjoint final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerOp = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); if (!isa(innerOp)) { return failure(); } + const auto numQubits = op.getNumQubits(); + auto outerQubits = op.getInputQubits(); + SmallVector qubits; + for (auto qubit : innerOp->getOperands().take_front(numQubits)) { + qubits.push_back(utils::getValueFromBlockArgument(qubit, outerQubits)); + } + rewriter.moveOpBefore(innerOp, op); - innerOp->setOperands(0, op.getNumQubits(), op.getInputQubits()); + innerOp->setOperands(0, numQubits, qubits); rewriter.replaceOp(op, innerOp->getResults()); return success(); } @@ -112,138 +128,192 @@ struct ReplaceWithKnownGates final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerOp = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); + + auto loc = op.getLoc(); + auto outerQubits = op.getInputQubits(); return TypeSwitch(innerOp) .Case([&](auto g) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), g.getTheta()); + Value negTheta = arith::NegFOp::create(rewriter, loc, g.getTheta()); rewriter.replaceOpWithNewOp(op, negTheta); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto t) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(t.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto tdg) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(tdg.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto s) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(s.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto sdg) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(sdg.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto sx) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(sx.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto sxdg) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(sxdg.getInputTarget(0), + outerQubits)); return success(); }) .Case([&](auto p) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), p.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, p.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(p.getInputTarget(0), + outerQubits), + negTheta); return success(); }) .Case([&](auto r) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), r.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta, - r.getPhi()); + Value negTheta = arith::NegFOp::create(rewriter, loc, r.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(r.getInputTarget(0), + outerQubits), + negTheta, r.getPhi()); return success(); }) .Case([&](auto rx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rx.getInputTarget(0), + outerQubits), + negTheta); return success(); }) .Case([&](auto u) { - Value newPhi = - arith::NegFOp::create(rewriter, op.getLoc(), u.getLambda()); - Value newLambda = - arith::NegFOp::create(rewriter, op.getLoc(), u.getPhi()); - Value newTheta = - arith::NegFOp::create(rewriter, op.getLoc(), u.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), newTheta, - newPhi, newLambda); + Value newPhi = arith::NegFOp::create(rewriter, loc, u.getLambda()); + Value newLambda = arith::NegFOp::create(rewriter, loc, u.getPhi()); + Value newTheta = arith::NegFOp::create(rewriter, loc, u.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(u.getInputTarget(0), + outerQubits), + newTheta, newPhi, newLambda); return success(); }) - .Case([&](auto u) { + .Case([&](auto u2) { auto pi = arith::ConstantOp::create( - rewriter, op.getLoc(), - rewriter.getF64FloatAttr(std::numbers::pi)); - Value newPhi = - arith::NegFOp::create(rewriter, op.getLoc(), u.getLambda()); - newPhi = arith::SubFOp::create(rewriter, op.getLoc(), newPhi, pi); - Value newLambda = - arith::NegFOp::create(rewriter, op.getLoc(), u.getPhi()); - newLambda = - arith::AddFOp::create(rewriter, op.getLoc(), newLambda, pi); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), newPhi, - newLambda); + rewriter, loc, rewriter.getF64FloatAttr(std::numbers::pi)); + Value newPhi = arith::NegFOp::create(rewriter, loc, u2.getLambda()); + newPhi = arith::SubFOp::create(rewriter, loc, newPhi, pi); + Value newLambda = arith::NegFOp::create(rewriter, loc, u2.getPhi()); + newLambda = arith::AddFOp::create(rewriter, loc, newLambda, pi); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(u2.getInputTarget(0), + outerQubits), + newPhi, newLambda); return success(); }) .Case([&](auto rxx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rxx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rxx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rxx.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(rxx.getInputTarget(1), + outerQubits), + negTheta); return success(); }) .Case([&](auto ry) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), ry.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, ry.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(ry.getInputTarget(0), + outerQubits), + negTheta); return success(); }) .Case([&](auto ryy) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), ryy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, ryy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(ryy.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(ryy.getInputTarget(1), + outerQubits), + negTheta); return success(); }) .Case([&](auto rz) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rz.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rz.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rz.getInputTarget(0), + outerQubits), + negTheta); return success(); }) .Case([&](auto rzx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rzx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rzx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rzx.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(rzx.getInputTarget(1), + outerQubits), + negTheta); return success(); }) .Case([&](auto rzz) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rzz.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rzz.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rzz.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(rzz.getInputTarget(1), + outerQubits), + negTheta); return success(); }) .Case([&](auto xxminusyy) { - Value negTheta = arith::NegFOp::create(rewriter, op.getLoc(), - xxminusyy.getTheta()); + Value negTheta = + arith::NegFOp::create(rewriter, loc, xxminusyy.getTheta()); rewriter.replaceOpWithNewOp( - op, op.getInputTarget(0), op.getInputTarget(1), negTheta, - xxminusyy.getBeta()); + op, + utils::getValueFromBlockArgument(xxminusyy.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(xxminusyy.getInputTarget(1), + outerQubits), + negTheta, xxminusyy.getBeta()); return success(); }) .Case([&](auto xxplusyy) { Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), xxplusyy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), - negTheta, xxplusyy.getBeta()); + arith::NegFOp::create(rewriter, loc, xxplusyy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(xxplusyy.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(xxplusyy.getInputTarget(1), + outerQubits), + negTheta, xxplusyy.getBeta()); return success(); }) .Default([&](auto) { return failure(); }); @@ -258,30 +328,61 @@ struct CancelNestedInv final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerUnitary = op.getBodyUnitary().getOperation(); - auto innerInvOp = dyn_cast(innerUnitary); + // TODO: Relax this condition? + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerInvOp = dyn_cast(op.getBodyUnitary(0).getOperation()); if (!innerInvOp) { return failure(); } - auto* innerInnerUnitary = innerInvOp.getBodyUnitary().getOperation(); - rewriter.moveOpBefore(innerInnerUnitary, op); - innerInnerUnitary->setOperands(0, op.getNumQubits(), op.getInputQubits()); - rewriter.replaceOp(op, innerInnerUnitary->getResults()); + // TODO: Relax this condition? + if (innerInvOp.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerInnerOp = innerInvOp.getBodyUnitary(0).getOperation(); + + const auto numQubits = op.getNumQubits(); + auto outerQubits = op.getInputQubits(); + auto innerQubits = innerInvOp.getInputQubits(); + SmallVector qubits; + for (auto qubit : innerInnerOp->getOperands().take_front(numQubits)) { + auto innerQubit = utils::getValueFromBlockArgument(qubit, innerQubits); + qubits.push_back( + utils::getValueFromBlockArgument(innerQubit, outerQubits)); + } + rewriter.moveOpBefore(innerInnerOp, op); + innerInnerOp->setOperands(0, numQubits, qubits); + rewriter.replaceOp(op, innerInnerOp->getResults()); return success(); } }; } // namespace -UnitaryOpInterface InvOp::getBodyUnitary() { - // In principle, the body region should only contain exactly two operations, - // the actual unitary operation and a yield operation. However, the region may - // also contain constants and arithmetic operations, e.g., created as part of - // canonicalization. Thus, the only safe way to access the unitary operation - // is to get the second operation from the back of the region. - return cast(*(++getBody()->rbegin())); +size_t InvOp::getNumBodyUnitaries() { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + count++; + } + } + return count; +} + +UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { + size_t count = 0; + for (auto& op : *getBody()) { + if (isa(op)) { + if (count == i) { + return cast(op); + } + count++; + } + } + llvm::reportFatalUsageError("Unitary index out of bounds"); } Value InvOp::getInputQubit(const size_t i) { @@ -322,7 +423,7 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubits); auto& block = odsState.regions.front()->emplaceBlock(); - const auto qubitType = QubitType::get(odsBuilder.getContext()); + auto qubitType = QubitType::get(odsBuilder.getContext()); for (size_t i = 0; i < qubits.size(); ++i) { block.addArgument(qubitType, odsState.location); } @@ -359,38 +460,23 @@ LogicalResult InvOp::verify() { return emitOpError("yield operation must yield ") << numTargets << " values, but found " << numYieldOperands; } - auto iter = ++block.rbegin(); - if (!isa(*iter)) { - return emitOpError( - "second to last operation in body region must be a unitary operation"); - } - for (auto it = ++iter; it != block.rend(); ++it) { - if (isa(*it)) { - return emitOpError("body region may only contain a single unitary op"); - } - } - auto bodyUnitary = getBodyUnitary(); - if (bodyUnitary.getNumQubits() != numTargets) { - return emitOpError("body unitary must operate on exactly ") - << numTargets << " target qubits, but found " - << bodyUnitary.getNumQubits(); - } - const auto numQubits = bodyUnitary.getNumQubits(); - for (size_t i = 0; i < numQubits; i++) { - if (bodyUnitary.getInputQubit(i) != block.getArgument(i)) { - return emitOpError("body unitary must use target alias block argument ") - << i << " (and not the original target operand)"; + SmallPtrSet uniqueQubitsIn; + for (const auto& target : getQubitsIn()) { + if (!uniqueQubitsIn.insert(target).second) { + return emitOpError("duplicate qubit found"); } } - // Also require yield to forward the unitary's outputs in-order. - for (size_t i = 0; i < numTargets; ++i) { - if (block.back().getOperand(i) != bodyUnitary.getOutputQubit(i)) { - return emitOpError("yield operand ") - << i << " must be the body unitary output qubit " << i; - } - } + // TODO: Re-enable + // for (size_t i = 0; i < getNumBodyUnitaries(); ++i) { + // auto bodyUnitary = getBodyUnitary(i); + // for (size_t j = 0; j < bodyUnitary.getNumQubits(); ++j) { + // if (!uniqueQubitsIn.contains(bodyUnitary.getInputQubit(j))) { + // return emitOpError("unitary is using an unknown qubit"); + // } + // } + // } return success(); } @@ -402,11 +488,16 @@ void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, } std::optional InvOp::getUnitaryMatrix() { - auto&& bodyUnitary = getBodyUnitary(); + // TODO: Relax this condition + if (getNumBodyUnitaries() != 1) { + return std::nullopt; + } + + auto bodyUnitary = getBodyUnitary(0); if (!bodyUnitary) { return std::nullopt; } - auto&& targetMatrix = bodyUnitary.getUnitaryMatrix(); + auto targetMatrix = bodyUnitary.getUnitaryMatrix(); if (!targetMatrix) { return std::nullopt; } diff --git a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp index a3ce816081..f1cb23a849 100644 --- a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp +++ b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" // IWYU pragma: associated +#include "mlir/Dialect/Utils/Utils.h" #include #include @@ -37,57 +38,12 @@ using namespace mlir::qco; static ParseResult parseTargetAliasing(OpAsmParser& parser, Region& region, SmallVectorImpl& operands) { - // 1. Parse the opening parenthesis - if (parser.parseLParen()) { - return failure(); - } - - // Temporary storage for block arguments we are about to create - SmallVector blockArgs; - - // 2. Prepare to parse the list - if (failed(parser.parseOptionalRParen())) { - do { - OpAsmParser::Argument newArg; // The "new" variable name - OpAsmParser::UnresolvedOperand oldOperand; // The "old" input variable - - // Parse "%new" - if (parser.parseArgument(newArg)) { - return failure(); - } - - // Parse "=" - if (parser.parseEqual()) { - return failure(); - } - - // Parse "%old" - if (parser.parseOperand(oldOperand)) { - return failure(); - } - operands.push_back(oldOperand); - - // Hard-code QubitType since targets in qco.ctrl are always qubits. - // This avoids double-binding type($targets_in) in the assembly format - // while keeping the parser simple and the assembly format clean. - newArg.type = QubitType::get(parser.getBuilder().getContext()); - blockArgs.push_back(newArg); - - } while (succeeded(parser.parseOptionalComma())); - - if (parser.parseRParen()) { - return failure(); - } - } - - // 4. Parse the Region - // We explicitly pass the blockArgs we just parsed so they become the entry - // block! - if (parser.parseRegion(region, blockArgs)) { - return failure(); - } + return utils::parseTargetAliasing(parser, region, operands); +} - return success(); +static void printTargetAliasing(OpAsmPrinter& printer, Operation* /*op*/, + Region& region, OperandRange targetsIn) { + utils::printTargetAliasing(printer, region, targetsIn); } ParseResult IfOp::parse(::mlir::OpAsmParser& parser, @@ -213,30 +169,6 @@ void IfOp::print(OpAsmPrinter& p) { p.printOptionalAttrDict((*this)->getAttrs()); } -static void printTargetAliasing(OpAsmPrinter& printer, Operation* /*op*/, - Region& region, OperandRange targetsIn) { - printer << "("; - if (region.empty()) { - printer << ") "; - printer.printRegion(region, false); - return; - } - Block& entryBlock = region.front(); - - const auto numTargets = targetsIn.size(); - for (unsigned i = 0; i < numTargets; ++i) { - if (i > 0) { - printer << ", "; - } - printer.printOperand(entryBlock.getArgument(i)); - printer << " = "; - printer.printOperand(targetsIn[i]); - } - printer << ") "; - - printer.printRegion(region, false); -} - //===----------------------------------------------------------------------===// // Dialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp b/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp index 0ca22726a1..3d874533b6 100644 --- a/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp +++ b/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp @@ -162,7 +162,8 @@ struct LiftHadamardAboveCNOTPattern final : OpRewritePattern { if (!cnotGate) { return failure(); } - if (!isa(cnotGate.getBodyUnitary()) || + if (cnotGate.getNumBodyUnitaries() != 1 || + !isa(cnotGate.getBodyUnitary(0)) || cnotGate.getOutputTarget(0) != inQubitHadamard) { return failure(); } diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 373452252a..234c70cc9a 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -62,7 +62,7 @@ void staticQubitsWithCtrl(QCProgramBuilder& b) { void staticQubitsWithInv(QCProgramBuilder& b) { auto q0 = b.staticQubit(0); - b.inv([&]() { b.t(q0); }); + b.inv({q0}, [&](ValueRange qubits) { b.t(qubits[0]); }); } void staticQubitsWithDuplicates(QCProgramBuilder& b) { @@ -75,7 +75,7 @@ void staticQubitsWithDuplicates(QCProgramBuilder& b) { b.p(std::numbers::pi / 2., q1a); b.rzz(0.123, q0b, q1b); b.cx(q0b, q1b); - b.inv([&]() { b.t(q0a); }); + b.inv({q0a}, [&](ValueRange qubits) { b.t(qubits[0]); }); } void staticQubitsCanonical(QCProgramBuilder& b) { @@ -86,7 +86,7 @@ void staticQubitsCanonical(QCProgramBuilder& b) { b.p(std::numbers::pi / 2., q1); b.rzz(0.123, q0, q1); b.cx(q0, q1); - b.inv([&]() { b.t(q0); }); + b.inv({q0}, [&](ValueRange qubits) { b.t(qubits[0]); }); } void allocDeallocPair(QCProgramBuilder& b) { @@ -194,7 +194,8 @@ void multipleControlledGlobalPhase(QCProgramBuilder& b) { void nestedControlledGlobalPhase(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.ctrl(q[0], [&] { b.cgphase(0.123, q[1]); }); + b.ctrl(q[0], {q[1]}, + [&](ValueRange targets) { b.cgphase(0.123, targets[0]); }); } void trivialControlledGlobalPhase(QCProgramBuilder& b) { @@ -203,12 +204,13 @@ void trivialControlledGlobalPhase(QCProgramBuilder& b) { } void inverseGlobalPhase(QCProgramBuilder& b) { - b.inv([&]() { b.gphase(-0.123); }); + b.inv({}, [&](ValueRange qubits) { b.gphase(-0.123); }); } void inverseMultipleControlledGlobalPhase(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcgphase(-0.123, {q[0], q[1], q[2]}); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcgphase(-0.123, qubits); }); } void identity(QCProgramBuilder& b) { @@ -228,7 +230,8 @@ void multipleControlledIdentity(QCProgramBuilder& b) { void nestedControlledIdentity(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.ctrl(q[2], [&] { b.cid(q[1], q[0]); }); + b.ctrl(q[2], {q[0], q[1]}, + [&](ValueRange targets) { b.cid(targets[1], targets[0]); }); } void trivialControlledIdentity(QCProgramBuilder& b) { @@ -238,12 +241,13 @@ void trivialControlledIdentity(QCProgramBuilder& b) { void inverseIdentity(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.id(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.id(qubits[0]); }); } void inverseMultipleControlledIdentity(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcid({q[2], q[1]}, q[0]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcid({qubits[0], qubits[1]}, qubits[2]); }); } void x(QCProgramBuilder& b) { @@ -263,7 +267,8 @@ void multipleControlledX(QCProgramBuilder& b) { void nestedControlledX(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cx(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cx(targets[0], targets[1]); }); } void trivialControlledX(QCProgramBuilder& b) { @@ -282,12 +287,13 @@ void repeatedControlledX(QCProgramBuilder& b) { void inverseX(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.x(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.x(qubits[0]); }); } void inverseMultipleControlledX(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcx({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcx({qubits[0], qubits[1]}, qubits[2]); }); } void y(QCProgramBuilder& b) { @@ -307,7 +313,8 @@ void multipleControlledY(QCProgramBuilder& b) { void nestedControlledY(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cy(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cy(targets[0], targets[1]); }); } void trivialControlledY(QCProgramBuilder& b) { @@ -317,12 +324,13 @@ void trivialControlledY(QCProgramBuilder& b) { void inverseY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.y(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.y(qubits[0]); }); } void inverseMultipleControlledY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcy({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcy({qubits[0], qubits[1]}, qubits[2]); }); } void z(QCProgramBuilder& b) { @@ -342,7 +350,8 @@ void multipleControlledZ(QCProgramBuilder& b) { void nestedControlledZ(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cz(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cz(targets[0], targets[1]); }); } void trivialControlledZ(QCProgramBuilder& b) { @@ -352,12 +361,13 @@ void trivialControlledZ(QCProgramBuilder& b) { void inverseZ(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.z(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.z(qubits[0]); }); } void inverseMultipleControlledZ(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcz({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcz({qubits[0], qubits[1]}, qubits[2]); }); } void h(QCProgramBuilder& b) { @@ -377,7 +387,8 @@ void multipleControlledH(QCProgramBuilder& b) { void nestedControlledH(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.ch(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.ch(targets[0], targets[1]); }); } void trivialControlledH(QCProgramBuilder& b) { @@ -387,12 +398,13 @@ void trivialControlledH(QCProgramBuilder& b) { void inverseH(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.h(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.h(qubits[0]); }); } void inverseMultipleControlledH(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mch({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mch({qubits[0], qubits[1]}, qubits[2]); }); } void hWithoutRegister(QCProgramBuilder& b) { @@ -417,7 +429,8 @@ void multipleControlledS(QCProgramBuilder& b) { void nestedControlledS(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cs(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cs(targets[0], targets[1]); }); } void trivialControlledS(QCProgramBuilder& b) { @@ -427,12 +440,13 @@ void trivialControlledS(QCProgramBuilder& b) { void inverseS(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.s(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.s(qubits[0]); }); } void inverseMultipleControlledS(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcs({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcs({qubits[0], qubits[1]}, qubits[2]); }); } void sdg(QCProgramBuilder& b) { @@ -452,7 +466,8 @@ void multipleControlledSdg(QCProgramBuilder& b) { void nestedControlledSdg(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.csdg(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.csdg(targets[0], targets[1]); }); } void trivialControlledSdg(QCProgramBuilder& b) { @@ -462,12 +477,13 @@ void trivialControlledSdg(QCProgramBuilder& b) { void inverseSdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.sdg(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.sdg(qubits[0]); }); } void inverseMultipleControlledSdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcsdg({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcsdg({qubits[0], qubits[1]}, qubits[2]); }); } void t_(QCProgramBuilder& b) { @@ -487,7 +503,8 @@ void multipleControlledT(QCProgramBuilder& b) { void nestedControlledT(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.ct(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.ct(targets[0], targets[1]); }); } void trivialControlledT(QCProgramBuilder& b) { @@ -497,12 +514,13 @@ void trivialControlledT(QCProgramBuilder& b) { void inverseT(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.t(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.t(qubits[0]); }); } void inverseMultipleControlledT(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mct({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mct({qubits[0], qubits[1]}, qubits[2]); }); } void tdg(QCProgramBuilder& b) { @@ -522,7 +540,8 @@ void multipleControlledTdg(QCProgramBuilder& b) { void nestedControlledTdg(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.ctdg(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.ctdg(targets[0], targets[1]); }); } void trivialControlledTdg(QCProgramBuilder& b) { @@ -532,12 +551,13 @@ void trivialControlledTdg(QCProgramBuilder& b) { void inverseTdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.tdg(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.tdg(qubits[0]); }); } void inverseMultipleControlledTdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mctdg({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mctdg({qubits[0], qubits[1]}, qubits[2]); }); } void sx(QCProgramBuilder& b) { @@ -557,7 +577,8 @@ void multipleControlledSx(QCProgramBuilder& b) { void nestedControlledSx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.csx(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.csx(targets[0], targets[1]); }); } void trivialControlledSx(QCProgramBuilder& b) { @@ -567,12 +588,13 @@ void trivialControlledSx(QCProgramBuilder& b) { void inverseSx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.sx(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.sx(qubits[0]); }); } void inverseMultipleControlledSx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcsx({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcsx({qubits[0], qubits[1]}, qubits[2]); }); } void sxdg(QCProgramBuilder& b) { @@ -592,7 +614,8 @@ void multipleControlledSxdg(QCProgramBuilder& b) { void nestedControlledSxdg(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.csxdg(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.csxdg(targets[0], targets[1]); }); } void trivialControlledSxdg(QCProgramBuilder& b) { @@ -602,12 +625,14 @@ void trivialControlledSxdg(QCProgramBuilder& b) { void inverseSxdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.sxdg(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.sxdg(qubits[0]); }); } void inverseMultipleControlledSxdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcsxdg({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcsxdg({qubits[0], qubits[1]}, qubits[2]); + }); } void rx(QCProgramBuilder& b) { @@ -627,7 +652,8 @@ void multipleControlledRx(QCProgramBuilder& b) { void nestedControlledRx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.crx(0.123, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.crx(0.123, targets[0], targets[1]); }); } void trivialControlledRx(QCProgramBuilder& b) { @@ -637,12 +663,14 @@ void trivialControlledRx(QCProgramBuilder& b) { void inverseRx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.rx(-0.123, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.rx(-0.123, qubits[0]); }); } void inverseMultipleControlledRx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcrx(-0.123, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcrx(-0.123, {qubits[0], qubits[1]}, qubits[2]); + }); } void ry(QCProgramBuilder& b) { @@ -662,7 +690,8 @@ void multipleControlledRy(QCProgramBuilder& b) { void nestedControlledRy(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cry(0.456, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cry(0.456, targets[0], targets[1]); }); } void trivialControlledRy(QCProgramBuilder& b) { @@ -672,12 +701,14 @@ void trivialControlledRy(QCProgramBuilder& b) { void inverseRy(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.ry(-0.456, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.ry(-0.456, qubits[0]); }); } void inverseMultipleControlledRy(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcry(-0.456, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcry(-0.456, {qubits[0], qubits[1]}, qubits[2]); + }); } void rz(QCProgramBuilder& b) { @@ -697,7 +728,8 @@ void multipleControlledRz(QCProgramBuilder& b) { void nestedControlledRz(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.crz(0.789, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.crz(0.789, targets[0], targets[1]); }); } void trivialControlledRz(QCProgramBuilder& b) { @@ -707,12 +739,14 @@ void trivialControlledRz(QCProgramBuilder& b) { void inverseRz(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.rz(-0.789, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.rz(-0.789, qubits[0]); }); } void inverseMultipleControlledRz(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcrz(-0.789, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcrz(-0.789, {qubits[0], qubits[1]}, qubits[2]); + }); } void p(QCProgramBuilder& b) { @@ -732,7 +766,8 @@ void multipleControlledP(QCProgramBuilder& b) { void nestedControlledP(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cp(0.123, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cp(0.123, targets[0], targets[1]); }); } void trivialControlledP(QCProgramBuilder& b) { @@ -742,12 +777,14 @@ void trivialControlledP(QCProgramBuilder& b) { void inverseP(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.p(-0.123, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.p(-0.123, qubits[0]); }); } void inverseMultipleControlledP(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcp(-0.123, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcp(-0.123, {qubits[0], qubits[1]}, qubits[2]); + }); } void r(QCProgramBuilder& b) { @@ -767,7 +804,9 @@ void multipleControlledR(QCProgramBuilder& b) { void nestedControlledR(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cr(0.123, 0.456, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, [&](ValueRange targets) { + b.cr(0.123, 0.456, targets[0], targets[1]); + }); } void trivialControlledR(QCProgramBuilder& b) { @@ -777,12 +816,14 @@ void trivialControlledR(QCProgramBuilder& b) { void inverseR(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.r(-0.123, 0.456, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.r(-0.123, 0.456, qubits[0]); }); } void inverseMultipleControlledR(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcr(-0.123, 0.456, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcr(-0.123, 0.456, {qubits[0], qubits[1]}, qubits[2]); + }); } void u2(QCProgramBuilder& b) { @@ -802,7 +843,9 @@ void multipleControlledU2(QCProgramBuilder& b) { void nestedControlledU2(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cu2(0.234, 0.567, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, [&](ValueRange targets) { + b.cu2(0.234, 0.567, targets[0], targets[1]); + }); } void trivialControlledU2(QCProgramBuilder& b) { @@ -813,13 +856,16 @@ void trivialControlledU2(QCProgramBuilder& b) { void inverseU2(QCProgramBuilder& b) { constexpr double pi = std::numbers::pi; auto q = b.allocQubitRegister(1); - b.inv([&]() { b.u2(-0.567 + pi, -0.234 - pi, q[0]); }); + b.inv(q[0], + [&](ValueRange qubits) { b.u2(-0.567 + pi, -0.234 - pi, qubits[0]); }); } void inverseMultipleControlledU2(QCProgramBuilder& b) { constexpr double pi = std::numbers::pi; auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcu2(-0.567 + pi, -0.234 - pi, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcu2(-0.567 + pi, -0.234 - pi, {qubits[0], qubits[1]}, qubits[2]); + }); } void u(QCProgramBuilder& b) { @@ -839,7 +885,9 @@ void multipleControlledU(QCProgramBuilder& b) { void nestedControlledU(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cu(0.1, 0.2, 0.3, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, [&](ValueRange targets) { + b.cu(0.1, 0.2, 0.3, targets[0], targets[1]); + }); } void trivialControlledU(QCProgramBuilder& b) { @@ -849,12 +897,14 @@ void trivialControlledU(QCProgramBuilder& b) { void inverseU(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.u(-0.1, -0.3, -0.2, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.u(-0.1, -0.3, -0.2, qubits[0]); }); } void inverseMultipleControlledU(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcu(-0.1, -0.3, -0.2, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcu(-0.1, -0.3, -0.2, {qubits[0], qubits[1]}, qubits[2]); + }); } void swap(QCProgramBuilder& b) { @@ -874,7 +924,9 @@ void multipleControlledSwap(QCProgramBuilder& b) { void nestedControlledSwap(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cswap(reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cswap(targets[0], targets[1], targets[2]); + }); } void trivialControlledSwap(QCProgramBuilder& b) { @@ -884,12 +936,14 @@ void trivialControlledSwap(QCProgramBuilder& b) { void inverseSwap(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.swap(q[0], q[1]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { b.swap(qubits[0], qubits[1]); }); } void inverseMultipleControlledSwap(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcswap({q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcswap({qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void iswap(QCProgramBuilder& b) { @@ -909,7 +963,9 @@ void multipleControlledIswap(QCProgramBuilder& b) { void nestedControlledIswap(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.ciswap(reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.ciswap(targets[0], targets[1], targets[2]); + }); } void trivialControlledIswap(QCProgramBuilder& b) { @@ -919,12 +975,15 @@ void trivialControlledIswap(QCProgramBuilder& b) { void inverseIswap(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.iswap(q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.iswap(qubits[0], qubits[1]); }); } void inverseMultipleControlledIswap(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mciswap({q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mciswap({qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void dcx(QCProgramBuilder& b) { @@ -944,7 +1003,9 @@ void multipleControlledDcx(QCProgramBuilder& b) { void nestedControlledDcx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cdcx(reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cdcx(targets[0], targets[1], targets[2]); + }); } void trivialControlledDcx(QCProgramBuilder& b) { @@ -954,12 +1015,14 @@ void trivialControlledDcx(QCProgramBuilder& b) { void inverseDcx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.dcx(q[1], q[0]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { b.dcx(qubits[1], qubits[0]); }); } void inverseMultipleControlledDcx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcdcx({q[0], q[1]}, q[3], q[2]); }); + b.inv({q[0], q[1], q[3], q[2]}, [&](ValueRange qubits) { + b.mcdcx({qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void ecr(QCProgramBuilder& b) { @@ -979,7 +1042,9 @@ void multipleControlledEcr(QCProgramBuilder& b) { void nestedControlledEcr(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cecr(reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cecr(targets[0], targets[1], targets[2]); + }); } void trivialControlledEcr(QCProgramBuilder& b) { @@ -989,12 +1054,14 @@ void trivialControlledEcr(QCProgramBuilder& b) { void inverseEcr(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.ecr(q[0], q[1]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { b.ecr(qubits[0], qubits[1]); }); } void inverseMultipleControlledEcr(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcecr({q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcecr({qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void rxx(QCProgramBuilder& b) { @@ -1014,7 +1081,9 @@ void multipleControlledRxx(QCProgramBuilder& b) { void nestedControlledRxx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.crxx(0.123, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.crxx(0.123, targets[0], targets[1], targets[2]); + }); } void trivialControlledRxx(QCProgramBuilder& b) { @@ -1024,18 +1093,22 @@ void trivialControlledRxx(QCProgramBuilder& b) { void inverseRxx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.rxx(-0.123, q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.rxx(-0.123, qubits[0], qubits[1]); }); } void inverseMultipleControlledRxx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcrxx(-0.123, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcrxx(-0.123, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void tripleControlledRxx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(5); b.mcrxx(0.123, {q[0], q[1], q[2]}, q[3], q[4]); } + void fourControlledRxx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(6); b.mcrxx(0.123, {q[0], q[1], q[2], q[3]}, q[4], q[5]); @@ -1058,7 +1131,9 @@ void multipleControlledRyy(QCProgramBuilder& b) { void nestedControlledRyy(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cryy(0.123, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cryy(0.123, targets[0], targets[1], targets[2]); + }); } void trivialControlledRyy(QCProgramBuilder& b) { @@ -1068,12 +1143,15 @@ void trivialControlledRyy(QCProgramBuilder& b) { void inverseRyy(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.ryy(-0.123, q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.ryy(-0.123, qubits[0], qubits[1]); }); } void inverseMultipleControlledRyy(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcryy(-0.123, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcryy(-0.123, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void rzx(QCProgramBuilder& b) { @@ -1093,7 +1171,9 @@ void multipleControlledRzx(QCProgramBuilder& b) { void nestedControlledRzx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.crzx(0.123, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.crzx(0.123, targets[0], targets[1], targets[2]); + }); } void trivialControlledRzx(QCProgramBuilder& b) { @@ -1103,12 +1183,15 @@ void trivialControlledRzx(QCProgramBuilder& b) { void inverseRzx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.rzx(-0.123, q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.rzx(-0.123, qubits[0], qubits[1]); }); } void inverseMultipleControlledRzx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcrzx(-0.123, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcrzx(-0.123, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void rzz(QCProgramBuilder& b) { @@ -1128,7 +1211,9 @@ void multipleControlledRzz(QCProgramBuilder& b) { void nestedControlledRzz(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.crzz(0.123, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.crzz(0.123, targets[0], targets[1], targets[2]); + }); } void trivialControlledRzz(QCProgramBuilder& b) { @@ -1138,12 +1223,15 @@ void trivialControlledRzz(QCProgramBuilder& b) { void inverseRzz(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.rzz(-0.123, q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.rzz(-0.123, qubits[0], qubits[1]); }); } void inverseMultipleControlledRzz(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcrzz(-0.123, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcrzz(-0.123, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void xxPlusYY(QCProgramBuilder& b) { @@ -1163,7 +1251,9 @@ void multipleControlledXxPlusYY(QCProgramBuilder& b) { void nestedControlledXxPlusYY(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cxx_plus_yy(0.123, 0.456, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cxx_plus_yy(0.123, 0.456, targets[0], targets[1], targets[2]); + }); } void trivialControlledXxPlusYY(QCProgramBuilder& b) { @@ -1173,12 +1263,16 @@ void trivialControlledXxPlusYY(QCProgramBuilder& b) { void inverseXxPlusYY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.xx_plus_yy(-0.123, 0.456, q[0], q[1]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.xx_plus_yy(-0.123, 0.456, qubits[0], qubits[1]); + }); } void inverseMultipleControlledXxPlusYY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcxx_plus_yy(-0.123, 0.456, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcxx_plus_yy(-0.123, 0.456, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void xxMinusYY(QCProgramBuilder& b) { @@ -1198,7 +1292,9 @@ void multipleControlledXxMinusYY(QCProgramBuilder& b) { void nestedControlledXxMinusYY(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cxx_minus_yy(0.123, 0.456, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cxx_minus_yy(0.123, 0.456, targets[0], targets[1], targets[2]); + }); } void trivialControlledXxMinusYY(QCProgramBuilder& b) { @@ -1208,12 +1304,17 @@ void trivialControlledXxMinusYY(QCProgramBuilder& b) { void inverseXxMinusYY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.xx_minus_yy(-0.123, 0.456, q[0], q[1]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.xx_minus_yy(-0.123, 0.456, qubits[0], qubits[1]); + }); } void inverseMultipleControlledXxMinusYY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcxx_minus_yy(-0.123, 0.456, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcxx_minus_yy(-0.123, 0.456, {qubits[0], qubits[1]}, qubits[2], + qubits[3]); + }); } void barrier(QCProgramBuilder& b) { @@ -1233,59 +1334,91 @@ void barrierMultipleQubits(QCProgramBuilder& b) { void singleControlledBarrier(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.ctrl(q[1], [&] { b.barrier(q[0]); }); + b.ctrl(q[1], q[0], [&](ValueRange targets) { b.barrier(targets[0]); }); } void inverseBarrier(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.barrier(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.barrier(qubits[0]); }); } void trivialCtrl(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.ctrl({}, [&]() { b.rxx(0.123, q[0], q[1]); }); + b.ctrl({}, {q[0], q[1]}, + [&](ValueRange targets) { b.rxx(0.123, targets[0], targets[1]); }); } void nestedCtrl(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.ctrl(q[0], [&]() { b.ctrl(q[1], [&]() { b.rxx(0.123, q[2], q[3]); }); }); + b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { + b.ctrl(targets[0], {targets[1], targets[2]}, [&](ValueRange innerTargets) { + b.rxx(0.123, innerTargets[0], innerTargets[1]); + }); + }); } void tripleNestedCtrl(QCProgramBuilder& b) { auto q = b.allocQubitRegister(5); - b.ctrl(q[0], [&]() { - b.ctrl(q[1], [&]() { b.ctrl(q[2], [&]() { b.rxx(0.123, q[3], q[4]); }); }); + b.ctrl(q[0], {q[1], q[2], q[3], q[4]}, [&](ValueRange targets) { + b.ctrl(targets[0], {targets[1], targets[2], targets[3]}, + [&](ValueRange innerTargets) { + b.ctrl(innerTargets[0], {innerTargets[1], innerTargets[2]}, + [&](ValueRange innerInnerTargets) { + b.rxx(0.123, innerInnerTargets[0], innerInnerTargets[1]); + }); + }); }); } void doubleNestedCtrlTwoQubits(QCProgramBuilder& b) { auto q = b.allocQubitRegister(6); - b.ctrl({q[0], q[1]}, - [&]() { b.ctrl({q[2], q[3]}, [&]() { b.rxx(0.123, q[4], q[5]); }); }); + b.ctrl({q[0], q[1]}, {q[2], q[3], q[4], q[5]}, [&](ValueRange targets) { + b.ctrl({targets[0], targets[1]}, {targets[2], targets[3]}, + [&](ValueRange innerTargets) { + b.rxx(0.123, innerTargets[0], innerTargets[1]); + }); + }); } void ctrlInvSandwich(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.ctrl(q[0], [&]() { - b.inv([&]() { b.ctrl(q[1], [&]() { b.rxx(-0.123, q[2], q[3]); }); }); + b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { + b.inv(targets, [&](ValueRange qubits) { + b.ctrl(qubits[0], {qubits[1], qubits[2]}, [&](ValueRange innerTargets) { + b.rxx(-0.123, innerTargets[0], innerTargets[1]); + }); + }); }); } void nestedInv(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.inv([&]() { b.rxx(0.123, q[0], q[1]); }); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.inv(qubits, [&](ValueRange innerQubits) { + b.rxx(0.123, innerQubits[0], innerQubits[1]); + }); + }); } void tripleNestedInv(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv( - [&]() { b.inv([&]() { b.inv([&]() { b.rxx(-0.123, q[0], q[1]); }); }); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.inv(qubits, [&](ValueRange innerQubits) { + b.inv(innerQubits, [&](ValueRange innerInnerQubits) { + b.rxx(-0.123, innerInnerQubits[0], innerInnerQubits[1]); + }); + }); + }); } void invCtrlSandwich(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { - b.ctrl(q[0], [&]() { b.inv([&]() { b.rxx(0.123, q[1], q[2]); }); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.ctrl(qubits[0], {qubits[1], qubits[2]}, [&](ValueRange targets) { + b.inv({targets[0], targets[1]}, [&](ValueRange innerQubits) { + b.rxx(0.123, innerQubits[0], innerQubits[1]); + }); + }); }); } @@ -1395,7 +1528,7 @@ void nestedForLoopCtrlOpWithSeparateQubit(QCProgramBuilder& b) { b.scfFor(0, 3, 1, [&](Value iv) { auto q0 = b.memrefLoad(reg.value, iv); b.h(q0); - b.ctrl(control, [&] { b.x(q0); }); + b.ctrl(control, q0, [&](ValueRange targets) { b.x(targets[0]); }); }); } @@ -1405,7 +1538,7 @@ void nestedForLoopCtrlOpWithExtractedQubit(QCProgramBuilder& b) { b.scfFor(1, 4, 1, [&](Value iv) { auto q0 = b.memrefLoad(reg.value, iv); b.h(q0); - b.ctrl(reg[0], [&] { b.x(q0); }); + b.ctrl(reg[0], q0, [&](ValueRange targets) { b.x(targets[0]); }); }); } From fcdfc1af0b03d587a41a3bddc7f4377a510b2c01 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Fri, 29 May 2026 00:50:38 +0200 Subject: [PATCH 02/17] Fix equivalence checking --- mlir/lib/Support/IRVerification.cpp | 45 ++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index eaac426f0a..8c3b413083 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -10,6 +10,7 @@ #include "mlir/Support/IRVerification.h" +#include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/QTensor/IR/QTensorUtils.h" #include @@ -469,7 +470,6 @@ static bool areOperationsEquivalent(Operation* lhs, Operation* rhs, if (!rhsConst) { return false; } - if (!areConstantAttributesEquivalent(lhsConst.getValue(), rhsConst.getValue())) { return false; @@ -513,17 +513,38 @@ static bool areOperationsEquivalent(Operation* lhs, Operation* rhs, return false; } - // Check operands according to value mapping - for (auto [lhsOperand, rhsOperand] : - llvm::zip(lhs->getOperands(), rhs->getOperands())) { - if (auto it = valueMap.find(lhsOperand); it != valueMap.end()) { - // Value already mapped, must match - if (it->second != rhsOperand) { + ValueRange lhsOperands; + ValueRange rhsOperands; + // TODO: Extend this + if (auto lhsCtrl = dyn_cast(lhs)) { + auto rhsCtrl = dyn_cast(rhs); + if (!rhsCtrl) { + return false; + } + if (lhsCtrl.getTargets().size() != rhsCtrl.getTargets().size()) { + return false; + } + for (auto [lhsTarget, lhsArg] : + llvm::zip(lhsCtrl.getTargets(), lhsCtrl.getBody()->getArguments())) { + auto rhsTarget = valueMap[lhsTarget]; + if (!llvm::is_contained(rhsCtrl.getTargets(), rhsTarget)) { return false; } - } else { - // Establish new mapping - valueMap[lhsOperand] = rhsOperand; + auto it = llvm::find(rhsCtrl.getTargets(), rhsTarget); + auto index = std::distance(rhsCtrl.getTargets().begin(), it); + valueMap[lhsArg] = rhsCtrl.getBody()->getArgument(index); + } + lhsOperands = lhsCtrl.getControls(); + rhsOperands = rhsCtrl.getControls(); + } else { + lhsOperands = lhs->getOperands(); + rhsOperands = rhs->getOperands(); + } + + // Check operands according to value mapping + for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOperands, rhsOperands)) { + if (!areValuesEquivalent(lhsOperand, rhsOperand, valueMap)) { + return false; } } @@ -725,7 +746,9 @@ static bool areBlocksEquivalent(Block& lhs, Block& rhs, if (lhsArg.getType() != rhsArg.getType()) { return false; } - valueMap[lhsArg] = rhsArg; + if (!valueMap.contains(lhsArg)) { + valueMap[lhsArg] = rhsArg; + } } // Collect all operations From a69cc622013a78f6305494a329d3c7e79b1062b2 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Fri, 29 May 2026 15:47:00 +0200 Subject: [PATCH 03/17] Fix linter errors --- mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp | 1 + mlir/lib/Conversion/QCToQCO/QCToQCO.cpp | 2 -- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 2 ++ mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 3 +++ mlir/lib/Dialect/QC/IR/QCOps.cpp | 7 +++++++ mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 1 + mlir/lib/Support/IRVerification.cpp | 2 ++ mlir/unittests/programs/qc_programs.cpp | 2 +- 8 files changed, 17 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp index 001844e1e6..09bb1b0949 100644 --- a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp +++ b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index 9c32fc302d..33a2df9217 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -1096,7 +1096,6 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - const auto numTargets = op.getNumTargets(); const auto qcControls = op.getControls(); const auto qcTargets = op.getTargets(); auto qcoControls = resolveMappedQubits(state, operation, qcControls); @@ -1154,7 +1153,6 @@ struct ConvertQCInvOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - const auto numTargets = op.getNumTargets(); const auto qcTargets = op.getTargets(); auto qcoTargets = resolveMappedQubits(state, operation, qcTargets); diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index cf943b1f99..a894891192 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -8,12 +8,14 @@ * Licensed under the MIT License */ +#include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/Utils/Utils.h" #include #include #include +#include #include #include #include diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index b935e5c823..3a60bde33f 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -8,10 +8,12 @@ * Licensed under the MIT License */ +#include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/Utils/Utils.h" #include +#include #include #include #include @@ -21,6 +23,7 @@ #include #include +#include #include using namespace mlir; diff --git a/mlir/lib/Dialect/QC/IR/QCOps.cpp b/mlir/lib/Dialect/QC/IR/QCOps.cpp index bf6551f924..6a72833861 100644 --- a/mlir/lib/Dialect/QC/IR/QCOps.cpp +++ b/mlir/lib/Dialect/QC/IR/QCOps.cpp @@ -13,6 +13,13 @@ #include "mlir/Dialect/QC/IR/QCDialect.h" // IWYU pragma: associated #include "mlir/Dialect/Utils/Utils.h" +#include +#include +#include +#include +#include +#include + // The following headers are needed for some template instantiations. // IWYU pragma: begin_keep #include diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index 1b6a98c07e..bb7a25925f 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -13,6 +13,7 @@ #include "mlir/Dialect/Utils/Utils.h" #include +#include #include #include #include diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 8c3b413083..07d723ec6b 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -34,11 +34,13 @@ #include #include #include +#include #include #include #include #include +#include #include using namespace mlir; diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 234c70cc9a..22646d0af6 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -204,7 +204,7 @@ void trivialControlledGlobalPhase(QCProgramBuilder& b) { } void inverseGlobalPhase(QCProgramBuilder& b) { - b.inv({}, [&](ValueRange qubits) { b.gphase(-0.123); }); + b.inv({}, [&](ValueRange /*qubits*/) { b.gphase(-0.123); }); } void inverseMultipleControlledGlobalPhase(QCProgramBuilder& b) { From 342234381b939de6fa5385e888404ebecc20ae85 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Fri, 29 May 2026 17:42:29 +0200 Subject: [PATCH 04/17] Add patterns for removing empty modifiers --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 21 +++++++++++++++---- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 21 +++++++++++++++---- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 21 +++++++++++++++---- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 21 +++++++++++++++---- mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp | 6 +++++- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 6 +++++- mlir/unittests/programs/qc_programs.cpp | 12 +++++++++++ mlir/unittests/programs/qc_programs.h | 6 ++++++ mlir/unittests/programs/qco_programs.cpp | 12 +++++++++++ mlir/unittests/programs/qco_programs.h | 6 ++++++ 10 files changed, 114 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index a894891192..b1147a6d81 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -149,6 +149,22 @@ struct ReduceCtrl final : OpRewritePattern { } }; +/** + * @brief Erase control modifiers that do not have any body unitaries. + */ +struct EraseEmptyCtrl final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CtrlOp op, + PatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 0) { + return failure(); + } + + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace size_t CtrlOp::getNumBodyUnitaries() { @@ -211,9 +227,6 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult CtrlOp::verify() { auto& block = *getBody(); - if (block.getOperations().size() < 2) { - return emitOpError("body region must have at least two operations"); - } if (!isa(block.back())) { return emitOpError( "last operation in body region must be a yield operation"); @@ -236,5 +249,5 @@ LogicalResult CtrlOp::verify() { void CtrlOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 3a60bde33f..fa9fdbfca1 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -337,6 +337,22 @@ struct CancelNestedInv final : OpRewritePattern { } }; +/** + * @brief Erase inverse modifiers that do not have any body unitaries. + */ +struct EraseEmptyInv final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(InvOp op, + PatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 0) { + return failure(); + } + + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace size_t InvOp::getNumBodyUnitaries() { @@ -381,9 +397,6 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult InvOp::verify() { auto& block = *getBody(); - if (block.getOperations().size() < 2) { - return emitOpError("body region must have at least two operations"); - } if (!isa(block.back())) { return emitOpError( "last operation in body region must be a yield operation"); @@ -394,5 +407,5 @@ LogicalResult InvOp::verify() { void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.add(context); + ReplaceWithKnownGates, EraseEmptyInv>(context); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index e86f3f7dc3..01a9281a45 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -170,6 +170,22 @@ struct ReduceCtrl final : OpRewritePattern { } }; +/** + * @brief Erase control modifiers that do not have any body unitaries. + */ +struct EraseEmptyCtrl final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CtrlOp op, + PatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 0) { + return failure(); + } + + rewriter.replaceOp(op, op.getOperands()); + return success(); + } +}; + } // namespace size_t CtrlOp::getNumBodyUnitaries() { @@ -292,9 +308,6 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult CtrlOp::verify() { auto& block = *getBody(); - if (block.getOperations().size() < 2) { - return emitOpError("body region must have at least two operations"); - } const auto numTargets = getNumTargets(); if (block.getArguments().size() != numTargets) { return emitOpError( @@ -360,7 +373,7 @@ LogicalResult CtrlOp::verify() { void CtrlOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional CtrlOp::getUnitaryMatrix() { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index bb7a25925f..9e7c4f5490 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -361,6 +361,22 @@ struct CancelNestedInv final : OpRewritePattern { } }; +/** + * @brief Erase inverse modifiers that do not have any body unitaries. + */ +struct EraseEmptyInv final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(InvOp op, + PatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 0) { + return failure(); + } + + rewriter.replaceOp(op, op.getOperands()); + return success(); + } +}; + } // namespace size_t InvOp::getNumBodyUnitaries() { @@ -437,9 +453,6 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult InvOp::verify() { auto& block = *getBody(); - if (block.getOperations().size() < 2) { - return emitOpError("body region must have at least two operations"); - } const auto numTargets = getNumTargets(); if (block.getArguments().size() != numTargets) { return emitOpError( @@ -485,7 +498,7 @@ LogicalResult InvOp::verify() { void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.add(context); + CancelNestedInv, EraseEmptyInv>(context); } std::optional InvOp::getUnitaryMatrix() { diff --git a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp index 97e4627363..f0c133086a 100644 --- a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp +++ b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp @@ -119,6 +119,8 @@ INSTANTIATE_TEST_SUITE_P( QCCtrlOpTest, QCTest, testing::Values(QCTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), MQT_NAMED_BUILDER(rxx)}, + QCTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), + MQT_NAMED_BUILDER(rxx)}, QCTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), MQT_NAMED_BUILDER(multipleControlledRxx)}, QCTestCase{"TripleNestedCtrl", @@ -136,7 +138,9 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QCInvOpTest, QCTest, - testing::Values(QCTestCase{"NestedInv", MQT_NAMED_BUILDER(nestedInv), + testing::Values(QCTestCase{"EmptyInv", MQT_NAMED_BUILDER(emptyInv), + MQT_NAMED_BUILDER(rxx)}, + QCTestCase{"NestedInv", MQT_NAMED_BUILDER(nestedInv), MQT_NAMED_BUILDER(rxx)}, QCTestCase{"TripleNestedInv", MQT_NAMED_BUILDER(tripleNestedInv), diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 413f29336d..0be0914f5e 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -223,6 +223,8 @@ INSTANTIATE_TEST_SUITE_P( QCOCtrlOpTest, QCOTest, testing::Values(QCOTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), MQT_NAMED_BUILDER(rxx)}, + QCOTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), + MQT_NAMED_BUILDER(rxx)}, QCOTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), MQT_NAMED_BUILDER(multipleControlledRxx)}, QCOTestCase{"TripleNestedCtrl", @@ -240,7 +242,9 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QCOInvOpTest, QCOTest, - testing::Values(QCOTestCase{"NestedInv", MQT_NAMED_BUILDER(nestedInv), + testing::Values(QCOTestCase{"EmptyInv", MQT_NAMED_BUILDER(emptyInv), + MQT_NAMED_BUILDER(rxx)}, + QCOTestCase{"NestedInv", MQT_NAMED_BUILDER(nestedInv), MQT_NAMED_BUILDER(rxx)}, QCOTestCase{"TripleNestedInv", MQT_NAMED_BUILDER(tripleNestedInv), diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 22646d0af6..5287be6150 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -1348,6 +1348,12 @@ void trivialCtrl(QCProgramBuilder& b) { [&](ValueRange targets) { b.rxx(0.123, targets[0], targets[1]); }); } +void emptyCtrl(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.rxx(0.123, q[0], q[1]); + b.ctrl({q[0]}, {q[1]}, [&](ValueRange /*targets*/) {}); +} + void nestedCtrl(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { @@ -1391,6 +1397,12 @@ void ctrlInvSandwich(QCProgramBuilder& b) { }); } +void emptyInv(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.rxx(0.123, q[0], q[1]); + b.inv({q[0], q[1]}, [&](ValueRange /*targets*/) {}); +} + void nestedInv(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); b.inv({q[0], q[1]}, [&](ValueRange qubits) { diff --git a/mlir/unittests/programs/qc_programs.h b/mlir/unittests/programs/qc_programs.h index e6569f7648..eeafac7cac 100644 --- a/mlir/unittests/programs/qc_programs.h +++ b/mlir/unittests/programs/qc_programs.h @@ -814,6 +814,9 @@ void inverseBarrier(QCProgramBuilder& b); /// Creates a circuit with a trivial ctrl modifier. void trivialCtrl(QCProgramBuilder& b); +/// Creates a circuit with an empty ctrl modifier. +void emptyCtrl(QCProgramBuilder& b); + /// Creates a circuit with nested ctrl modifiers. void nestedCtrl(QCProgramBuilder& b); @@ -828,6 +831,9 @@ void ctrlInvSandwich(QCProgramBuilder& b); // --- InvOp ---------------------------------------------------------------- // +/// Creates a circuit with an empty inverse modifier. +void emptyInv(QCProgramBuilder& b); + /// Creates a circuit with nested inverse modifiers. void nestedInv(QCProgramBuilder& b); diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index 0ad96fbb10..a985ecb11a 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -1936,6 +1936,12 @@ void trivialCtrl(QCOProgramBuilder& b) { }); } +void emptyCtrl(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + std::tie(q[0], q[1]) = b.rxx(0.123, q[0], q[1]); + b.ctrl(q[0], q[1], [&](ValueRange targets) { return targets; }); +} + void nestedCtrl(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(4); b.ctrl({q[0]}, {q[1], q[2], q[3]}, [&](ValueRange targets) { @@ -2003,6 +2009,12 @@ void ctrlInvSandwich(QCOProgramBuilder& b) { }); } +void emptyInv(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + std::tie(q[0], q[1]) = b.rxx(0.123, q[0], q[1]); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { return qubits; }); +} + void nestedInv(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(2); b.inv({q[0], q[1]}, [&](ValueRange qubits) { diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index b4197c5a7f..a8659701a8 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -960,6 +960,9 @@ void twoBarrier(QCOProgramBuilder& b); /// Creates a circuit with a trivial ctrl modifier. void trivialCtrl(QCOProgramBuilder& b); +/// Creates a circuit with an empty ctrl modifier. +void emptyCtrl(QCOProgramBuilder& b); + /// Creates a circuit with nested ctrl modifiers. void nestedCtrl(QCOProgramBuilder& b); @@ -974,6 +977,9 @@ void ctrlInvSandwich(QCOProgramBuilder& b); // --- InvOp ---------------------------------------------------------------- // +/// Creates a circuit with an empty inverse modifier. +void emptyInv(QCOProgramBuilder& b); + /// Creates a circuit with nested inverse modifiers. void nestedInv(QCOProgramBuilder& b); From c4e28828f90f6848a73e351a60e64ded1eab9e84 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Fri, 29 May 2026 18:11:08 +0200 Subject: [PATCH 05/17] Add test cases --- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 8 +- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 10 ++- .../Conversion/QCOToQC/test_qco_to_qc.cpp | 17 +++- .../Conversion/QCToQCO/test_qc_to_qco.cpp | 22 ++++- mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp | 31 +++---- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 39 +++++---- mlir/unittests/programs/qc_programs.cpp | 46 ++++++++++ mlir/unittests/programs/qc_programs.h | 17 ++++ mlir/unittests/programs/qco_programs.cpp | 84 +++++++++++++++++++ mlir/unittests/programs/qco_programs.h | 26 +++++- 10 files changed, 256 insertions(+), 44 deletions(-) diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 01a9281a45..2dba84815d 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -81,10 +81,12 @@ struct MergeNestedCtrl final : OpRewritePattern { IRMapping mapping; utils::prova(*innerCtrlBody, mapping, innerTargets, outerTargets, targets, targetArgs); - SmallVector yields; for (auto& op : innerCtrlBody->without_terminator()) { - auto results = rewriter.clone(op, mapping)->getResults(); - llvm::append_range(yields, results); + rewriter.clone(op, mapping); + } + SmallVector yields; + for (auto value : innerCtrlBody->getTerminator()->getOperands()) { + yields.push_back(mapping.lookup(value)); } return yields; }); diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index 9e7c4f5490..872a8a8902 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -70,11 +70,13 @@ struct MoveCtrlOutside final : OpRewritePattern { utils::prova(*innerCtrlBody, mapping, innerCtrlOp.getTargetsIn(), outerQubits, targets, qubitArgs); - SmallVector yields; for (auto& op : innerCtrlBody->without_terminator()) { - auto results = - rewriter.clone(op, mapping)->getResults(); - llvm::append_range(yields, results); + rewriter.clone(op, mapping); + } + SmallVector yields; + for (auto value : + innerCtrlBody->getTerminator()->getOperands()) { + yields.push_back(mapping.lookup(value)); } return yields; }) diff --git a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp index 4bd2b24615..aa3a428809 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp @@ -144,6 +144,17 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qc::allocDeallocPair)})); /// @} +/// \name QCOToQC/Modifiers/CtrlOp.cpp +/// @{ +INSTANTIATE_TEST_SUITE_P( + QCOCtrlOpTest, QCOToQCTest, + testing::Values(QCOToQCTestCase{"CtrlTwo", MQT_NAMED_BUILDER(qco::ctrlTwo), + MQT_NAMED_BUILDER(qc::ctrlTwo)}, + QCOToQCTestCase{"CtrlInvTwo", + MQT_NAMED_BUILDER(qco::ctrlInvTwo), + MQT_NAMED_BUILDER(qc::ctrlInvTwo)})); +/// @} + /// \name QCOToQC/Modifiers/InvOp.cpp /// @{ INSTANTIATE_TEST_SUITE_P( @@ -160,7 +171,11 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qc::dcx)}, QCOToQCTestCase{"InverseMultipleControlledDCX", MQT_NAMED_BUILDER(qco::inverseMultipleControlledDcx), - MQT_NAMED_BUILDER(qc::multipleControlledDcx)})); + MQT_NAMED_BUILDER(qc::multipleControlledDcx)}, + QCOToQCTestCase{"InvTwo", MQT_NAMED_BUILDER(qco::invTwo), + MQT_NAMED_BUILDER(qc::invTwo)}, + QCOToQCTestCase{"InvCtrlTwo", MQT_NAMED_BUILDER(qco::invCtrlTwo), + MQT_NAMED_BUILDER(qc::ctrlInvTwo)})); /// @} /// \name QCOToQC/Operations/StandardGates/BarrierOp.cpp diff --git a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp index 3f0df25542..71f47b0841 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp @@ -143,6 +143,17 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qco::allocSinkPair)})); /// @} +/// \name QCToQCO/Modifiers/CtrlOp.cpp +/// @{ +INSTANTIATE_TEST_SUITE_P( + QCCtrlOpTest, QCToQCOTest, + testing::Values(QCToQCOTestCase{"CtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), + MQT_NAMED_BUILDER(qco::ctrlTwo)}, + QCToQCOTestCase{"CtrlInvTwo", + MQT_NAMED_BUILDER(qc::ctrlInvTwo), + MQT_NAMED_BUILDER(qco::ctrlInvTwo)})); +/// @} + /// \name QCToQCO/Modifiers/InvOp.cpp /// @{ INSTANTIATE_TEST_SUITE_P( @@ -151,10 +162,13 @@ INSTANTIATE_TEST_SUITE_P( // iSWAP cannot be inverted with current canonicalization QCToQCOTestCase{"InverseiSWAP", MQT_NAMED_BUILDER(qc::inverseIswap), MQT_NAMED_BUILDER(qco::inverseIswap)}, - QCToQCOTestCase{ - "InverseMultipleControllediSWAP", - MQT_NAMED_BUILDER(qc::inverseMultipleControlledIswap), - MQT_NAMED_BUILDER(qco::inverseMultipleControlledIswap)})); + QCToQCOTestCase{"InverseMultipleControllediSWAP", + MQT_NAMED_BUILDER(qc::inverseMultipleControlledIswap), + MQT_NAMED_BUILDER(qco::inverseMultipleControlledIswap)}, + QCToQCOTestCase{"InvTwo", MQT_NAMED_BUILDER(qc::invTwo), + MQT_NAMED_BUILDER(qco::invTwo)}, + QCToQCOTestCase{"InvCtrlTwo", MQT_NAMED_BUILDER(qc::invCtrlTwo), + MQT_NAMED_BUILDER(qco::ctrlInvTwo)})); /// @} /// \name QCToQCO/Operations/StandardGates/BarrierOp.cpp diff --git a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp index f0c133086a..4d0f56912b 100644 --- a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp +++ b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp @@ -117,21 +117,22 @@ TEST_F(QCTest, BuilderRejectsMixedStaticAndDynamicQubitAllocationModes) { /// @{ INSTANTIATE_TEST_SUITE_P( QCCtrlOpTest, QCTest, - testing::Values(QCTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), - MQT_NAMED_BUILDER(rxx)}, - QCTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), - MQT_NAMED_BUILDER(rxx)}, - QCTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), - MQT_NAMED_BUILDER(multipleControlledRxx)}, - QCTestCase{"TripleNestedCtrl", - MQT_NAMED_BUILDER(tripleNestedCtrl), - MQT_NAMED_BUILDER(tripleControlledRxx)}, - QCTestCase{"CtrlInvSandwich", - MQT_NAMED_BUILDER(ctrlInvSandwich), - MQT_NAMED_BUILDER(multipleControlledRxx)}, - QCTestCase{"DoubleNestedCtrlTwoQubits", - MQT_NAMED_BUILDER(doubleNestedCtrlTwoQubits), - MQT_NAMED_BUILDER(fourControlledRxx)})); + testing::Values( + QCTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), + MQT_NAMED_BUILDER(rxx)}, + QCTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), + MQT_NAMED_BUILDER(rxx)}, + QCTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), + MQT_NAMED_BUILDER(multipleControlledRxx)}, + QCTestCase{"TripleNestedCtrl", MQT_NAMED_BUILDER(tripleNestedCtrl), + MQT_NAMED_BUILDER(tripleControlledRxx)}, + QCTestCase{"CtrlInvSandwich", MQT_NAMED_BUILDER(ctrlInvSandwich), + MQT_NAMED_BUILDER(multipleControlledRxx)}, + QCTestCase{"DoubleNestedCtrlTwoQubits", + MQT_NAMED_BUILDER(doubleNestedCtrlTwoQubits), + MQT_NAMED_BUILDER(fourControlledRxx)}, + QCTestCase{"NestedCtrlTwo", MQT_NAMED_BUILDER(nestedCtrlTwo), + MQT_NAMED_BUILDER(ctrlTwo)})); /// @} /// \name QC/Modifiers/InvOp.cpp diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 0be0914f5e..4a8fb691ad 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -221,21 +221,22 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QCOCtrlOpTest, QCOTest, - testing::Values(QCOTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), - MQT_NAMED_BUILDER(rxx)}, - QCOTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), - MQT_NAMED_BUILDER(rxx)}, - QCOTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), - MQT_NAMED_BUILDER(multipleControlledRxx)}, - QCOTestCase{"TripleNestedCtrl", - MQT_NAMED_BUILDER(tripleNestedCtrl), - MQT_NAMED_BUILDER(tripleControlledRxx)}, - QCOTestCase{"CtrlInvSandwich", - MQT_NAMED_BUILDER(ctrlInvSandwich), - MQT_NAMED_BUILDER(multipleControlledRxx)}, - QCOTestCase{"DoubleNestedCtrlTwoQubits", - MQT_NAMED_BUILDER(doubleNestedCtrlTwoQubits), - MQT_NAMED_BUILDER(fourControlledRxx)})); + testing::Values( + QCOTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), + MQT_NAMED_BUILDER(rxx)}, + QCOTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), + MQT_NAMED_BUILDER(rxx)}, + QCOTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), + MQT_NAMED_BUILDER(multipleControlledRxx)}, + QCOTestCase{"TripleNestedCtrl", MQT_NAMED_BUILDER(tripleNestedCtrl), + MQT_NAMED_BUILDER(tripleControlledRxx)}, + QCOTestCase{"CtrlInvSandwich", MQT_NAMED_BUILDER(ctrlInvSandwich), + MQT_NAMED_BUILDER(multipleControlledRxx)}, + QCOTestCase{"DoubleNestedCtrlTwoQubits", + MQT_NAMED_BUILDER(doubleNestedCtrlTwoQubits), + MQT_NAMED_BUILDER(fourControlledRxx)}, + QCOTestCase{"NestedCtrlTwo", MQT_NAMED_BUILDER(nestedCtrlTwo), + MQT_NAMED_BUILDER(ctrlTwo)})); /// @} /// \name QCO/Modifiers/InvOp.cpp @@ -251,7 +252,9 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(rxx)}, QCOTestCase{"InvControlSandwich", MQT_NAMED_BUILDER(invCtrlSandwich), - MQT_NAMED_BUILDER(singleControlledRxx)})); + MQT_NAMED_BUILDER(singleControlledRxx)}, + QCOTestCase{"InvCtrlTwo", MQT_NAMED_BUILDER(invCtrlTwo), + MQT_NAMED_BUILDER(ctrlInvTwo)})); /// @} /// \name QCO/Operations/StandardGates/BarrierOp.cpp @@ -963,6 +966,10 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(inverseMultipleControlledX), MQT_NAMED_BUILDER(multipleControlledX)}, QCOTestCase{"TwoX", MQT_NAMED_BUILDER(twoX), + MQT_NAMED_BUILDER(emptyQCO)}, + QCOTestCase{"ControlledTwoX", MQT_NAMED_BUILDER(controlledTwoX), + MQT_NAMED_BUILDER(emptyQCO)}, + QCOTestCase{"inverseTwoX", MQT_NAMED_BUILDER(twoX), MQT_NAMED_BUILDER(emptyQCO)})); /// @} diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 5287be6150..5b10a98af4 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -1397,6 +1397,34 @@ void ctrlInvSandwich(QCProgramBuilder& b) { }); } +void ctrlTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl({q[0], q[1]}, {q[2], q[3]}, [&](ValueRange targets) { + b.x(targets[0]); + b.rxx(0.123, targets[0], targets[1]); + }); +} + +void nestedCtrlTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { + b.ctrl(targets[0], {targets[1], targets[2]}, [&](ValueRange innerTargets) { + b.x(innerTargets[0]); + b.rxx(0.123, innerTargets[0], innerTargets[1]); + }); + }); +} + +void ctrlInvTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.ctrl(q[0], {q[1], q[2]}, [&](ValueRange targets) { + b.inv(targets, [&](ValueRange qubits) { + b.x(qubits[0]); + b.rxx(0.123, qubits[0], qubits[1]); + }); + }); +} + void emptyInv(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); b.rxx(0.123, q[0], q[1]); @@ -1434,6 +1462,24 @@ void invCtrlSandwich(QCProgramBuilder& b) { }); } +void invTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.x(qubits[0]); + b.rxx(0.123, qubits[0], qubits[1]); + }); +} + +void invCtrlTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.ctrl(qubits[0], {qubits[1], qubits[2]}, [&](ValueRange targets) { + b.x(targets[0]); + b.rxx(0.123, targets[0], targets[1]); + }); + }); +} + void simpleIf(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.h(q[0]); diff --git a/mlir/unittests/programs/qc_programs.h b/mlir/unittests/programs/qc_programs.h index eeafac7cac..14114b1ee2 100644 --- a/mlir/unittests/programs/qc_programs.h +++ b/mlir/unittests/programs/qc_programs.h @@ -829,6 +829,16 @@ void doubleNestedCtrlTwoQubits(QCProgramBuilder& b); /// Creates a circuit with control modifiers interleaved by an inverse modifier. void ctrlInvSandwich(QCProgramBuilder& b); +/// Creates a circuit with a control modifier applied to two gates. +void ctrlTwo(QCProgramBuilder& b); + +/// Creates a circuit with nested control modifiers applied to two gates. +void nestedCtrlTwo(QCProgramBuilder& b); + +/// Creates a circuit with a control modifier applied to a inverse modifier +/// applied to two gates. +void ctrlInvTwo(QCProgramBuilder& b); + // --- InvOp ---------------------------------------------------------------- // /// Creates a circuit with an empty inverse modifier. @@ -843,6 +853,13 @@ void tripleNestedInv(QCProgramBuilder& b); /// Creates a circuit with inverse modifiers interleaved by a control modifier. void invCtrlSandwich(QCProgramBuilder& b); +/// Creates a circuit with an inverse modifier applied to two gates. +void invTwo(QCProgramBuilder& b); + +/// Creates a circuit with an inverse modifier applied to a control modifier +/// applied to two gates. +void invCtrlTwo(QCProgramBuilder& b); + // --- IfOp ----------------------------------------------------------------- // /// Creates a circuit with a simple if operation with one qubit. diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index a985ecb11a..868e3a2a3b 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -301,6 +301,24 @@ void twoX(QCOProgramBuilder& b) { q[0] = b.x(q[0]); } +void controlledTwoX(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.ctrl(q[0], q[1], [&](ValueRange targets) { + auto q = b.x(targets[0]); + q = b.x(q); + return SmallVector{q}; + }); +} + +void inverseTwoX(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + b.inv(q[0], [&](ValueRange qubits) { + auto q = b.x(qubits[0]); + q = b.x(q); + return SmallVector{q}; + }); +} + void y(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.y(q[0]); @@ -2009,6 +2027,46 @@ void ctrlInvSandwich(QCOProgramBuilder& b) { }); } +void ctrlTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl({q[0], q[1]}, {q[2], q[3]}, [&](ValueRange targets) { + auto i0 = targets[0]; + auto i1 = targets[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); +} + +void nestedCtrlTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { + const auto& [controlsOut, targetsOut] = b.ctrl( + targets[0], {targets[1], targets[2]}, [&](ValueRange innerTargets) { + auto i0 = innerTargets[0]; + auto i1 = innerTargets[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); + return llvm::to_vector(llvm::concat(controlsOut, targetsOut)); + }); +} + +void ctrlInvTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.ctrl(q[0], {q[1], q[2]}, [&](ValueRange targets) { + auto inner = b.inv(targets, [&](ValueRange qubits) { + auto i0 = qubits[0]; + auto i1 = qubits[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); + return llvm::to_vector(inner); + }); +} + void emptyInv(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(2); std::tie(q[0], q[1]) = b.rxx(0.123, q[0], q[1]); @@ -2058,6 +2116,32 @@ void invCtrlSandwich(QCOProgramBuilder& b) { }); } +void invTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + auto i0 = qubits[0]; + auto i1 = qubits[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); +} + +void invCtrlTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + const auto& [controlsOut, targetsOut] = + b.ctrl({qubits[0]}, {qubits[1], qubits[2]}, [&](ValueRange targets) { + auto i0 = targets[0]; + auto i1 = targets[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); + return llvm::to_vector(llvm::concat(controlsOut, targetsOut)); + }); +} + void simpleIf(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(1); auto q0 = b.h(q[0]); diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index a8659701a8..1ec606d103 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -167,9 +167,16 @@ void inverseX(QCOProgramBuilder& b); /// Creates a circuit with an inverse modifier applied to a controlled X gate. void inverseMultipleControlledX(QCOProgramBuilder& b); -/// Creates a circuit with two X gates in a row. +/// Creates a circuit with two subsequent X gates. void twoX(QCOProgramBuilder& b); +/// Creates a circuit with a control modifier applied to two subsequent X gates. +void controlledTwoX(QCOProgramBuilder& b); + +/// Creates a circuit with an inverse modifier applied to two subsequent X +/// gates. +void inverseTwoX(QCOProgramBuilder& b); + // --- YOp ------------------------------------------------------------------ // /// Creates a circuit with just a Y gate. @@ -975,6 +982,16 @@ void doubleNestedCtrlTwoQubits(QCOProgramBuilder& b); /// Creates a circuit with control modifiers interleaved by an inverse modifier. void ctrlInvSandwich(QCOProgramBuilder& b); +/// Creates a circuit with a control modifier applied to two gates. +void ctrlTwo(QCOProgramBuilder& b); + +/// Creates a circuit with nested control modifiers applied to two gates. +void nestedCtrlTwo(QCOProgramBuilder& b); + +/// Creates a circuit with a control modifier applied to an inverse modifier +/// applied to two gates. +void ctrlInvTwo(QCOProgramBuilder& b); + // --- InvOp ---------------------------------------------------------------- // /// Creates a circuit with an empty inverse modifier. @@ -989,6 +1006,13 @@ void tripleNestedInv(QCOProgramBuilder& b); /// Creates a circuit with inverse modifiers interleaved by a control modifier. void invCtrlSandwich(QCOProgramBuilder& b); +/// Creates a circuit with an inverse modifier applied to two gates. +void invTwo(QCOProgramBuilder& b); + +/// Creates a circuit with an inverse modifier applied to a control modifier +/// applied to two gates. +void invCtrlTwo(QCOProgramBuilder& b); + // --- IfOp ---------------------------------------------------------------- // /// Creates a circuit with a simple if operation with one qubit. From 60249614384b009d57ce84da637369db1d88ab0f Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 1 Jun 2026 15:22:32 +0200 Subject: [PATCH 06/17] Add support for translating CompoundOperations --- .../TranslateQuantumComputationToQC.cpp | 176 +++++++++++++----- .../test_quantum_computation_translation.cpp | 6 + mlir/unittests/programs/qc_programs.cpp | 14 +- mlir/unittests/programs/qc_programs.h | 6 +- .../programs/quantum_computation_programs.cpp | 25 +++ .../programs/quantum_computation_programs.h | 8 + 6 files changed, 178 insertions(+), 57 deletions(-) diff --git a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp index 66d562ae82..727679d223 100644 --- a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp +++ b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp @@ -73,6 +73,27 @@ struct TranslationState { /// Whether the translation is currently processing an IfElseOperation bool inIfElse = false; + + /// Whether the translation is currently within a control modifier + bool inCtrlOp = false; + + /// Mapping from physical qubit index to block argument + DenseMap ctrlTargets{}; + + Value getQubit(size_t index) const { + if (!inCtrlOp) { + if (index >= qubits.size()) { + llvm::reportFatalInternalError("Qubit index out of bounds"); + } + return qubits[index]; + } else { + auto it = ctrlTargets.find(index); + if (it == ctrlTargets.end()) { + llvm::reportFatalInternalError("Qubit index out of bounds"); + } + return it->second; + } + }; }; } // namespace @@ -222,7 +243,7 @@ static void addMeasureOp(QCProgramBuilder& builder, const auto& classics = measureOp.getClassics(); for (size_t i = 0; i < targets.size(); ++i) { - const auto& qubit = state.qubits[targets[i]]; + const auto& qubit = state.getQubit(targets[i]); const auto bitIdx = static_cast(classics[i]); const auto& [mem, localIdx] = state.bitMap[bitIdx]; const auto& bit = mem[static_cast(localIdx)]; @@ -239,13 +260,13 @@ static void addMeasureOp(QCProgramBuilder& builder, * * @param builder The QCProgramBuilder used to create operations * @param operation The reset operation to translate - * @param qubits Flat vector of qubit values indexed by physical qubit index + * @param state The translation state */ static void addResetOp(QCProgramBuilder& builder, const ::qc::Operation& operation, - const SmallVector& qubits) { + TranslationState& state) { for (const auto& target : operation.getTargets()) { - auto qubit = qubits[target]; + auto qubit = state.getQubit(target); builder.reset(qubit); } } @@ -258,18 +279,21 @@ static void addResetOp(QCProgramBuilder& builder, * the qubit values corresponding to positive controls. * * @param operation The operation containing controls - * @param qubits Flat vector of qubit values indexed by physical qubit index + * @param state The translation state * @return Vector of qubit values corresponding to positive controls */ static SmallVector getControls(const ::qc::Operation& operation, - const SmallVector& qubits) { + TranslationState& state) { + if (state.inCtrlOp) { + return {}; + } SmallVector controls; for (const auto& [control, type] : operation.getControls()) { if (type == ::qc::Control::Type::Neg) { llvm::reportFatalInternalError( "Negative controls cannot be translated to QC at the moment"); } - controls.push_back(qubits[control]); + controls.push_back(state.getQubit(control)); } return controls; } @@ -286,13 +310,13 @@ static SmallVector getControls(const ::qc::Operation& operation, * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ - const auto& target = qubits[operation.getTargets()[0]]; \ - if (const auto& controls = getControls(operation, qubits); \ + TranslationState& state) { \ + const auto& target = state.getQubit(operation.getTargets()[0]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(target); \ } else { \ @@ -326,14 +350,14 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdg, sxdg) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param = operation.getParameter()[0]; \ - const auto& target = qubits[operation.getTargets()[0]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target = state.getQubit(operation.getTargets()[0]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param, target); \ } else { \ @@ -358,15 +382,15 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(P, p) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param1 = operation.getParameter()[0]; \ const auto& param2 = operation.getParameter()[1]; \ - const auto& target = qubits[operation.getTargets()[0]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target = state.getQubit(operation.getTargets()[0]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param1, param2, target); \ } else { \ @@ -391,16 +415,16 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2, u2) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param1 = operation.getParameter()[0]; \ const auto& param2 = operation.getParameter()[1]; \ const auto& param3 = operation.getParameter()[2]; \ - const auto& target = qubits[operation.getTargets()[0]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target = state.getQubit(operation.getTargets()[0]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param1, param2, param3, target); \ } else { \ @@ -424,14 +448,14 @@ DEFINE_ONE_TARGET_THREE_PARAMETER(U, u) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ - const auto& target0 = qubits[operation.getTargets()[0]]; \ - const auto& target1 = qubits[operation.getTargets()[1]]; \ - if (const auto& controls = getControls(operation, qubits); \ + TranslationState& state) { \ + const auto& target0 = state.getQubit(operation.getTargets()[0]); \ + const auto& target1 = state.getQubit(operation.getTargets()[1]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(target0, target1); \ } else { \ @@ -448,10 +472,10 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECR, ecr) static void addISWAPdgOp(QCProgramBuilder& builder, const ::qc::Operation& operation, - const SmallVector& qubits) { - auto target0 = qubits[operation.getTargets()[0]]; - auto target1 = qubits[operation.getTargets()[1]]; - if (const auto& controls = getControls(operation, qubits); controls.empty()) { + TranslationState& state) { + auto target0 = state.getQubit(operation.getTargets()[0]); + auto target1 = state.getQubit(operation.getTargets()[1]); + if (const auto& controls = getControls(operation, state); controls.empty()) { builder.inv({target0, target1}, [&](ValueRange qubits) { builder.iswap(qubits[0], qubits[1]); }); @@ -476,15 +500,15 @@ static void addISWAPdgOp(QCProgramBuilder& builder, * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param = operation.getParameter()[0]; \ - const auto& target0 = qubits[operation.getTargets()[0]]; \ - const auto& target1 = qubits[operation.getTargets()[1]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target0 = state.getQubit(operation.getTargets()[0]); \ + const auto& target1 = state.getQubit(operation.getTargets()[1]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param, target0, target1); \ } else { \ @@ -511,16 +535,16 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZ, rzz) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param1 = operation.getParameter()[0]; \ const auto& param2 = operation.getParameter()[1]; \ - const auto& target0 = qubits[operation.getTargets()[0]]; \ - const auto& target1 = qubits[operation.getTargets()[1]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target0 = state.getQubit(operation.getTargets()[0]); \ + const auto& target1 = state.getQubit(operation.getTargets()[1]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param1, param2, target0, target1); \ } else { \ @@ -537,10 +561,10 @@ DEFINE_TWO_TARGET_TWO_PARAMETER(XXminusYY, xx_minus_yy) static void addBarrierOp(QCProgramBuilder& builder, const ::qc::Operation& operation, - const SmallVector& qubits) { + TranslationState& state) { SmallVector targets; for (const auto& targetIdx : operation.getTargets()) { - targets.push_back(qubits[targetIdx]); + targets.push_back(state.getQubit(targetIdx)); } builder.barrier(targets); } @@ -550,6 +574,60 @@ static LogicalResult translateOperation(QCProgramBuilder& builder, const ::qc::Operation& operation, TranslationState& state); +// CompoundOp + +static LogicalResult addCompoundOp(QCProgramBuilder& builder, + const ::qc::Operation& operation, + TranslationState& state) { + const auto& compoundOp = + dynamic_cast(operation); + if (const auto& controls = getControls(operation, state); controls.empty()) { + for (const auto& op : compoundOp) { + if (failed(translateOperation(builder, *op, state))) { + return failure(); + } + } + } else { + // Collect targets + DenseMap targetMap; + for (const auto& op : compoundOp) { + if (dynamic_cast(op.get()) != nullptr) { + llvm::reportFatalInternalError("Nested CompoundOperations cannot be " + "translated to QC at the moment"); + } + for (const auto& target : op->getTargets()) { + if (!targetMap.contains(target)) { + targetMap[target] = state.getQubit(target); + } + } + } + SmallVector> sortedPairs(targetMap.begin(), + targetMap.end()); + std::sort(sortedPairs.begin(), sortedPairs.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); + SmallVector targets; + for (const auto& pair : sortedPairs) { + targets.push_back(pair.second); + } + // Build control modifier + builder.ctrl(controls, targets, [&](ValueRange targetArgs) { + state.inCtrlOp = true; + for (size_t i = 0; i < sortedPairs.size(); ++i) { + state.ctrlTargets[sortedPairs[i].first] = targetArgs[i]; + } + for (const auto& op : compoundOp) { + if (failed(translateOperation(builder, *op, state))) { + llvm::reportFatalInternalError("Failed to translate operation inside " + "controlled CompoundOperation"); + } + } + state.ctrlTargets.clear(); + state.inCtrlOp = false; + }); + } + return success(); +} + // IfElseOp static LogicalResult addIfElseOp(QCProgramBuilder& builder, @@ -626,7 +704,7 @@ static LogicalResult addIfElseOp(QCProgramBuilder& builder, #define ADD_OP_CASE(OP_CORE) \ case ::qc::OpType::OP_CORE: \ - add##OP_CORE##Op(builder, operation, qubits); \ + add##OP_CORE##Op(builder, operation, state); \ return success(); /** @@ -640,7 +718,6 @@ static LogicalResult addIfElseOp(QCProgramBuilder& builder, static LogicalResult translateOperation(QCProgramBuilder& builder, const ::qc::Operation& operation, TranslationState& state) { - const auto& qubits = state.qubits; switch (operation.getType()) { case ::qc::OpType::Measure: addMeasureOp(builder, operation, state); @@ -676,7 +753,12 @@ static LogicalResult translateOperation(QCProgramBuilder& builder, ADD_OP_CASE(XXminusYY) ADD_OP_CASE(Barrier) case ::qc::OpType::iSWAPdg: - addISWAPdgOp(builder, operation, qubits); + addISWAPdgOp(builder, operation, state); + return success(); + case ::qc::OpType::Compound: + if (failed(addCompoundOp(builder, operation, state))) { + return failure(); + } return success(); case ::qc::OpType::IfElse: if (failed(addIfElseOp(builder, operation, state))) { diff --git a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp index 0e5d53783b..893902448b 100644 --- a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp +++ b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp @@ -418,9 +418,15 @@ INSTANTIATE_TEST_SUITE_P( "BarrierMultipleQubits", MQT_NAMED_BUILDER(qc::barrierMultipleQubits), MQT_NAMED_BUILDER(mlir::qc::barrierMultipleQubits)}, + QuantumComputationTranslationTestCase{ + "CtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), + MQT_NAMED_BUILDER(mlir::qc::ctrlTwo)}, QuantumComputationTranslationTestCase{ "SimpleIf", MQT_NAMED_BUILDER(qc::simpleIf), MQT_NAMED_BUILDER(mlir::qc::simpleIf)}, + QuantumComputationTranslationTestCase{ + "IfTwoQubits", MQT_NAMED_BUILDER(qc::ifTwoQubits), + MQT_NAMED_BUILDER(mlir::qc::ifTwoQubits)}, QuantumComputationTranslationTestCase{ "IfElse", MQT_NAMED_BUILDER(qc::ifElse), MQT_NAMED_BUILDER(mlir::qc::ifElse)})); diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 5b10a98af4..7c7608963d 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -1487,13 +1487,6 @@ void simpleIf(QCProgramBuilder& b) { b.scfIf(cond, [&] { b.x(q[0]); }); } -void ifElse(QCProgramBuilder& b) { - auto q = b.allocQubitRegister(1); - b.h(q[0]); - auto cond = b.measure(q[0]); - b.scfIf(cond, [&] { b.x(q[0]); }, [&] { b.z(q[0]); }); -} - void ifTwoQubits(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); b.h(q[0]); @@ -1504,6 +1497,13 @@ void ifTwoQubits(QCProgramBuilder& b) { }); } +void ifElse(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + b.h(q[0]); + auto cond = b.measure(q[0]); + b.scfIf(cond, [&] { b.x(q[0]); }, [&] { b.z(q[0]); }); +} + void nestedIfOpForLoop(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); auto q0 = b.allocQubit(); diff --git a/mlir/unittests/programs/qc_programs.h b/mlir/unittests/programs/qc_programs.h index 14114b1ee2..2f08c5236e 100644 --- a/mlir/unittests/programs/qc_programs.h +++ b/mlir/unittests/programs/qc_programs.h @@ -865,12 +865,12 @@ void invCtrlTwo(QCProgramBuilder& b); /// Creates a circuit with a simple if operation with one qubit. void simpleIf(QCProgramBuilder& b); -/// Creates a circuit with an if operation with an else branch. -void ifElse(QCProgramBuilder& b); - /// Creates a circuit with an if operation with two qubits. void ifTwoQubits(QCProgramBuilder& b); +/// Creates a circuit with an if operation with an else branch. +void ifElse(QCProgramBuilder& b); + /// Creates a circuit with an if operation with a nested for operation with /// a register. void nestedIfOpForLoop(QCProgramBuilder& b); diff --git a/mlir/unittests/programs/quantum_computation_programs.cpp b/mlir/unittests/programs/quantum_computation_programs.cpp index 719fd50b17..db12525029 100644 --- a/mlir/unittests/programs/quantum_computation_programs.cpp +++ b/mlir/unittests/programs/quantum_computation_programs.cpp @@ -15,6 +15,7 @@ #include "ir/operations/StandardOperation.hpp" #include +#include namespace qc { @@ -538,6 +539,17 @@ void barrierMultipleQubits(QuantumComputation& comp) { comp.barrier({0, 1, 2}); } +void ctrlTwo(QuantumComputation& comp) { + const auto& q = comp.addQubitRegister(4, "q"); + CompoundOperation compound; + compound.emplace_back(2, X); + compound.emplace_back(Targets{2, 3}, RXX, + std::vector{0.123}); + compound.addControl(0); + compound.addControl(1); + comp.emplace_back(std::move(compound)); +} + void simpleIf(QuantumComputation& comp) { const auto& q = comp.addQubitRegister(1, "q"); const auto& c = comp.addClassicalRegister(1, "c"); @@ -546,6 +558,19 @@ void simpleIf(QuantumComputation& comp) { comp.if_(X, q[0], c[0]); } +void ifTwoQubits(QuantumComputation& comp) { + const auto& q = comp.addQubitRegister(2, "q"); + const auto& c = comp.addClassicalRegister(1, "c"); + comp.h(q[0]); + comp.measure(q[0], c[0]); + CompoundOperation compound; + compound.emplace_back(0, X); + compound.emplace_back(1, X); + IfElseOperation ifElse( + std::make_unique(std::move(compound)), nullptr, c[0]); + comp.emplace_back(std::move(ifElse)); +} + void ifElse(QuantumComputation& comp) { const auto& q = comp.addQubitRegister(1, "q"); const auto& c = comp.addClassicalRegister(1, "c"); diff --git a/mlir/unittests/programs/quantum_computation_programs.h b/mlir/unittests/programs/quantum_computation_programs.h index b21bba30d5..f0e1856d8f 100644 --- a/mlir/unittests/programs/quantum_computation_programs.h +++ b/mlir/unittests/programs/quantum_computation_programs.h @@ -385,11 +385,19 @@ void barrierTwoQubits(QuantumComputation& comp); /// Creates a circuit with a barrier on multiple qubits. void barrierMultipleQubits(QuantumComputation& comp); +// --- CtrlOp --------------------------------------------------------------- // + +/// Creates a circuit with a control modifier applied to two gates. +void ctrlTwo(QuantumComputation& comp); + // --- IfOp ----------------------------------------------------------------- // /// Creates a circuit with a simple if operation with one qubit. void simpleIf(QuantumComputation& comp); +/// Creates a circuit with an if operation with two qubits. +void ifTwoQubits(QuantumComputation& comp); + /// Creates a circuit with an if operation with an else branch. void ifElse(QuantumComputation& comp); From ccec22f71905f9612b58fae2da973c1f5a92ffe0 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 1 Jun 2026 16:57:11 +0200 Subject: [PATCH 07/17] Fix QC-to-QIR conversion --- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 24 ++++++++++++------- .../Compiler/test_compiler_pipeline.cpp | 5 +++- .../Conversion/QCToQIR/test_qc_to_qir.cpp | 8 +++++++ mlir/unittests/programs/qir_programs.cpp | 6 +++++ mlir/unittests/programs/qir_programs.h | 5 ++++ 5 files changed, 38 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 6c432f57d2..832820ca65 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -93,7 +93,7 @@ struct LoweringState : QIRMetadata { /// Modifier information int64_t inCtrlOp = 0; - DenseMap> controls; + SmallVector controls; /// Allocator and StringSaver for stable StringRefs llvm::BumpPtrAllocator allocator; @@ -174,7 +174,7 @@ convertUnitaryToCallOp(QCOpType& op, QCOpAdaptorType& adaptor, // Query state for modifier information const auto inCtrlOp = state.inCtrlOp; const SmallVector controls = - inCtrlOp != 0 ? state.controls[inCtrlOp] : SmallVector{}; + inCtrlOp != 0 ? state.controls : SmallVector{}; const size_t numCtrls = controls.size(); // Define argument types @@ -209,9 +209,9 @@ convertUnitaryToCallOp(QCOpType& op, QCOpAdaptorType& adaptor, operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); // Clean up modifier information - if (inCtrlOp != 0) { - state.controls.erase(inCtrlOp); - state.inCtrlOp--; + state.inCtrlOp--; + if (inCtrlOp == 0) { + state.controls.clear(); } // Replace operation with CallOp @@ -315,7 +315,7 @@ struct ConvertQCUnitaryOpQIR : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); const auto inCtrlOp = state.inCtrlOp; - const size_t numCtrls = inCtrlOp != 0 ? state.controls[inCtrlOp].size() : 0; + const size_t numCtrls = inCtrlOp != 0 ? state.controls.size() : 0; const auto fnName = GetFnName(numCtrls); return convertUnitaryToCallOp(op, adaptor, rewriter, this->getContext(), state, fnName, NumTargets, NumParams); @@ -863,12 +863,18 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(CtrlOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - // Update modifier information auto& state = getState(); - state.inCtrlOp++; + + if (state.inCtrlOp != 0) { + return rewriter.notifyMatchFailure(op, + "Nested CtrlOps are not supported"); + } + + // Update modifier information + state.inCtrlOp = op.getNumBodyUnitaries(); const SmallVector controls(adaptor.getControls().begin(), adaptor.getControls().end()); - state.controls[state.inCtrlOp] = controls; + state.controls = controls; // Inline block and remove operation rewriter.inlineBlockBefore(&op.getRegion().front(), op, diff --git a/mlir/unittests/Compiler/test_compiler_pipeline.cpp b/mlir/unittests/Compiler/test_compiler_pipeline.cpp index 44618eedae..d9072d6630 100644 --- a/mlir/unittests/Compiler/test_compiler_pipeline.cpp +++ b/mlir/unittests/Compiler/test_compiler_pipeline.cpp @@ -686,6 +686,9 @@ INSTANTIATE_TEST_SUITE_P( "MultipleControlledXXMinusYY", MQT_NAMED_BUILDER(qc::multipleControlledXxMinusYY), nullptr, MQT_NAMED_BUILDER(mlir::qc::multipleControlledXxMinusYY), - MQT_NAMED_BUILDER(mlir::qir::multipleControlledXxMinusYY)})); + MQT_NAMED_BUILDER(mlir::qir::multipleControlledXxMinusYY)}, + CompilerPipelineTestCase{"CtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), + nullptr, MQT_NAMED_BUILDER(mlir::qc::ctrlTwo), + MQT_NAMED_BUILDER(mlir::qir::ctrlTwo)})); } // namespace mqt::test::compiler diff --git a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp index cd8bcf6073..6de8bf8483 100644 --- a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp +++ b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp @@ -649,3 +649,11 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qc::allocDeallocPair), MQT_NAMED_BUILDER(qir::emptyQIR)})); /// @} + +/// \name QCToQIR/Modifiers/CtrlOp.cpp +/// @{ +INSTANTIATE_TEST_SUITE_P(QCToQIRCtrlOpTest, QCToQIRTest, + testing::Values(QCToQIRTestCase{ + "NestedCtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), + MQT_NAMED_BUILDER(qir::ctrlTwo)})); +/// @} diff --git a/mlir/unittests/programs/qir_programs.cpp b/mlir/unittests/programs/qir_programs.cpp index 6ae5023a09..68f209943f 100644 --- a/mlir/unittests/programs/qir_programs.cpp +++ b/mlir/unittests/programs/qir_programs.cpp @@ -605,4 +605,10 @@ void multipleControlledXxMinusYY(QIRProgramBuilder& b) { b.mcxx_minus_yy(0.123, 0.456, {q[0], q[1]}, q[2], q[3]); } +void ctrlTwo(QIRProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.mcx({q[0], q[1]}, q[2]); + b.mcrxx(0.123, {q[0], q[1]}, q[2], q[3]); +} + } // namespace mlir::qir diff --git a/mlir/unittests/programs/qir_programs.h b/mlir/unittests/programs/qir_programs.h index 92f6c54078..86a7f7c807 100644 --- a/mlir/unittests/programs/qir_programs.h +++ b/mlir/unittests/programs/qir_programs.h @@ -422,4 +422,9 @@ void singleControlledXxMinusYY(QIRProgramBuilder& b); /// Creates a circuit with a multi-controlled XXMinusYY gate. void multipleControlledXxMinusYY(QIRProgramBuilder& b); +// --- CtrlOp --------------------------------------------------------------- // + +/// Creates a circuit with a control modifier applied to two gates. +void ctrlTwo(QIRProgramBuilder& b); + } // namespace mlir::qir From e5eb892d884a3ea0c556a718700fa275d538fde8 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:10:01 +0200 Subject: [PATCH 08/17] Resolve TODO comments --- mlir/include/mlir/Dialect/Utils/Utils.h | 20 ++++++++++++++------ mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 4 ++-- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 5 +++-- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 18 ++---------------- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 16 +++------------- mlir/lib/Support/IRVerification.cpp | 1 - 6 files changed, 24 insertions(+), 40 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 546ecc479c..072c2c2368 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -161,7 +161,10 @@ static void printTargetAliasing(OpAsmPrinter& printer, Region& region, printer.printRegion(region, false); } -// TODO: Document +/** + * @brief Get the value corresponding to @p qubit from the block arguments @p + * qubits if @p qubit is a block argument, otherwise return @p qubit itself. + */ static Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { if (auto blockArg = dyn_cast(qubit)) { return qubits[blockArg.getArgNumber()]; @@ -169,10 +172,15 @@ static Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { return qubit; } -// TODO: Rename and document -static void prova(Block& block, IRMapping& mapping, ValueRange innerQubits, - ValueRange outerQubits, ValueRange newQubits, - ValueRange qubitArgs) { +/** + * @brief Create a mapping between block arguments and qubit values. + * + * @details This helper function is used to resolve block arguments for nested + * modifiers. + */ +static void populateMapping(Block& block, IRMapping& mapping, + ValueRange innerQubits, ValueRange outerQubits, + ValueRange newQubits, ValueRange qubitArgs) { for (auto arg : block.getArguments()) { auto innerQubit = innerQubits[arg.getArgNumber()]; auto outerQubit = getValueFromBlockArgument(innerQubit, outerQubits); @@ -180,7 +188,7 @@ static void prova(Block& block, IRMapping& mapping, ValueRange innerQubits, auto index = std::distance(newQubits.begin(), it); mapping.map(arg, qubitArgs[index]); } else { - llvm::reportFatalInternalError("TODO"); + llvm::reportFatalInternalError("Outer qubit not found in new qubits"); } } } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index b1147a6d81..d76aef7ade 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -72,8 +72,8 @@ struct MergeNestedCtrl final : OpRewritePattern { op, controls, targets, [&](ValueRange targetArgs) { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::prova(*innerCtrlBody, mapping, innerTargets, outerTargets, - targets, targetArgs); + utils::populateMapping(*innerCtrlBody, mapping, innerTargets, + outerTargets, targets, targetArgs); for (auto& op : innerCtrlBody->without_terminator()) { rewriter.clone(op, mapping); } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index fa9fdbfca1..fa1273d503 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -61,8 +61,9 @@ struct MoveCtrlOutside final : OpRewritePattern { rewriter, op.getLoc(), targetArgs, [&](ValueRange qubitArgs) { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::prova(*innerCtrlBody, mapping, innerCtrlOp.getTargets(), - outerQubits, targets, qubitArgs); + utils::populateMapping(*innerCtrlBody, mapping, + innerCtrlOp.getTargets(), outerQubits, + targets, qubitArgs); for (auto& op : innerCtrlBody->without_terminator()) { rewriter.clone(op, mapping); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 2dba84815d..69febe54e4 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -79,8 +79,8 @@ struct MergeNestedCtrl final : OpRewritePattern { [&](ValueRange targetArgs) -> SmallVector { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::prova(*innerCtrlBody, mapping, innerTargets, outerTargets, - targets, targetArgs); + utils::populateMapping(*innerCtrlBody, mapping, innerTargets, + outerTargets, targets, targetArgs); for (auto& op : innerCtrlBody->without_terminator()) { rewriter.clone(op, mapping); } @@ -333,7 +333,6 @@ LogicalResult CtrlOp::verify() { } SmallPtrSet uniqueQubitsIn; - SmallPtrSet uniqueTargetsIn; for (const auto& control : getControlsIn()) { if (!uniqueQubitsIn.insert(control).second) { return emitOpError("duplicate control qubit found"); @@ -343,21 +342,8 @@ LogicalResult CtrlOp::verify() { if (!uniqueQubitsIn.insert(target).second) { return emitOpError("duplicate target qubit found"); } - if (!uniqueTargetsIn.insert(target).second) { - return emitOpError("duplicate target qubit found"); - } } - // TODO: Re-enable - // for (size_t i = 0; i < getNumBodyUnitaries(); ++i) { - // auto bodyUnitary = getBodyUnitary(i); - // for (size_t j = 0; j < bodyUnitary.getNumQubits(); ++j) { - // if (!uniqueTargetsIn.contains(bodyUnitary.getInputQubit(j))) { - // return emitOpError("unitary is using an unknown input qubit"); - // } - // } - // } - SmallPtrSet uniqueQubitsOut; for (const auto& control : getControlsOut()) { if (!uniqueQubitsOut.insert(control).second) { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index 872a8a8902..0c8b4bac98 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -67,9 +67,9 @@ struct MoveCtrlOutside final : OpRewritePattern { [&](ValueRange qubitArgs) -> SmallVector { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::prova(*innerCtrlBody, mapping, - innerCtrlOp.getTargetsIn(), outerQubits, - targets, qubitArgs); + utils::populateMapping(*innerCtrlBody, mapping, + innerCtrlOp.getTargetsIn(), + outerQubits, targets, qubitArgs); for (auto& op : innerCtrlBody->without_terminator()) { rewriter.clone(op, mapping); } @@ -484,16 +484,6 @@ LogicalResult InvOp::verify() { } } - // TODO: Re-enable - // for (size_t i = 0; i < getNumBodyUnitaries(); ++i) { - // auto bodyUnitary = getBodyUnitary(i); - // for (size_t j = 0; j < bodyUnitary.getNumQubits(); ++j) { - // if (!uniqueQubitsIn.contains(bodyUnitary.getInputQubit(j))) { - // return emitOpError("unitary is using an unknown qubit"); - // } - // } - // } - return success(); } diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index 07d723ec6b..0221464606 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -517,7 +517,6 @@ static bool areOperationsEquivalent(Operation* lhs, Operation* rhs, ValueRange lhsOperands; ValueRange rhsOperands; - // TODO: Extend this if (auto lhsCtrl = dyn_cast(lhs)) { auto rhsCtrl = dyn_cast(rhs); if (!rhsCtrl) { From fed9ed86725efc786395feed55936bbb1fcea524 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:26:22 +0200 Subject: [PATCH 09/17] Remove remaining TODO comments in preparation for the Rabbit --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 2 -- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 3 --- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 3 --- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 4 ---- 4 files changed, 12 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index d76aef7ade..fad9a4b2e6 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -43,7 +43,6 @@ struct MergeNestedCtrl final : OpRewritePattern { return failure(); } - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -92,7 +91,6 @@ struct ReduceCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index fa1273d503..b34424cceb 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -40,7 +40,6 @@ struct MoveCtrlOutside final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -306,7 +305,6 @@ struct CancelNestedInv final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -315,7 +313,6 @@ struct CancelNestedInv final : OpRewritePattern { return failure(); } - // TODO: Relax this condition? if (innerInvOp.getNumBodyUnitaries() != 1) { return failure(); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 69febe54e4..8994661d86 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -49,7 +49,6 @@ struct MergeNestedCtrl final : OpRewritePattern { return failure(); } - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -104,7 +103,6 @@ struct ReduceCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -365,7 +363,6 @@ void CtrlOp::getCanonicalizationPatterns(RewritePatternSet& results, } std::optional CtrlOp::getUnitaryMatrix() { - // TODO: Relax this condition if (getNumBodyUnitaries() != 1) { return std::nullopt; } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index 0c8b4bac98..bf9da8144e 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -44,7 +44,6 @@ struct MoveCtrlOutside final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -331,7 +330,6 @@ struct CancelNestedInv final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - // TODO: Relax this condition? if (op.getNumBodyUnitaries() != 1) { return failure(); } @@ -340,7 +338,6 @@ struct CancelNestedInv final : OpRewritePattern { return failure(); } - // TODO: Relax this condition? if (innerInvOp.getNumBodyUnitaries() != 1) { return failure(); } @@ -494,7 +491,6 @@ void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, } std::optional InvOp::getUnitaryMatrix() { - // TODO: Relax this condition if (getNumBodyUnitaries() != 1) { return std::nullopt; } From d4c45e950ee74a9cd0dc79a6130ee21bf7e33477 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 1 Jun 2026 17:31:46 +0200 Subject: [PATCH 10/17] Fix linter errors --- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 2 +- .../TranslateQuantumComputationToQC.cpp | 28 +++++++++++-------- .../programs/quantum_computation_programs.cpp | 3 ++ 3 files changed, 20 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 832820ca65..3078d15b76 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -92,7 +92,7 @@ struct LoweringState : QIRMetadata { DenseMap resultPtrs; /// Modifier information - int64_t inCtrlOp = 0; + size_t inCtrlOp = 0; SmallVector controls; /// Allocator and StringSaver for stable StringRefs diff --git a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp index 727679d223..930817894d 100644 --- a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp +++ b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp @@ -12,6 +12,7 @@ #include "ir/QuantumComputation.hpp" #include "ir/Register.hpp" +#include "ir/operations/CompoundOperation.hpp" #include "ir/operations/Control.hpp" #include "ir/operations/IfElseOperation.hpp" #include "ir/operations/NonUnitaryOperation.hpp" @@ -19,6 +20,7 @@ #include "ir/operations/Operation.hpp" #include "mlir/Dialect/QC/Builder/QCProgramBuilder.h" +#include #include #include #include @@ -78,21 +80,21 @@ struct TranslationState { bool inCtrlOp = false; /// Mapping from physical qubit index to block argument - DenseMap ctrlTargets{}; + DenseMap ctrlTargets; - Value getQubit(size_t index) const { - if (!inCtrlOp) { - if (index >= qubits.size()) { - llvm::reportFatalInternalError("Qubit index out of bounds"); - } - return qubits[index]; - } else { + [[nodiscard]] Value getQubit(size_t index) const { + if (inCtrlOp) { auto it = ctrlTargets.find(index); if (it == ctrlTargets.end()) { llvm::reportFatalInternalError("Qubit index out of bounds"); } return it->second; } + + if (index >= qubits.size()) { + llvm::reportFatalInternalError("Qubit index out of bounds"); + } + return qubits[index]; }; }; @@ -603,8 +605,8 @@ static LogicalResult addCompoundOp(QCProgramBuilder& builder, } SmallVector> sortedPairs(targetMap.begin(), targetMap.end()); - std::sort(sortedPairs.begin(), sortedPairs.end(), - [](const auto& a, const auto& b) { return a.first < b.first; }); + llvm::sort(sortedPairs.begin(), sortedPairs.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); SmallVector targets; for (const auto& pair : sortedPairs) { targets.push_back(pair.second); @@ -845,8 +847,10 @@ OwningOpRef translateQuantumComputationToQC( // Allocate result map SmallVector results(quantumComputation.getNcbits(), nullptr); - TranslationState state{ - .qubits = qubits, .bitMap = bitMap, .results = std::move(results)}; + TranslationState state{.qubits = qubits, + .bitMap = bitMap, + .results = std::move(results), + .ctrlTargets = DenseMap{}}; // Translate operations if (translateOperations(builder, quantumComputation, state).failed()) { diff --git a/mlir/unittests/programs/quantum_computation_programs.cpp b/mlir/unittests/programs/quantum_computation_programs.cpp index db12525029..418798c68d 100644 --- a/mlir/unittests/programs/quantum_computation_programs.cpp +++ b/mlir/unittests/programs/quantum_computation_programs.cpp @@ -11,10 +11,13 @@ #include "quantum_computation_programs.h" #include "ir/QuantumComputation.hpp" +#include "ir/operations/CompoundOperation.hpp" +#include "ir/operations/IfElseOperation.hpp" #include "ir/operations/OpType.hpp" #include "ir/operations/StandardOperation.hpp" #include +#include #include namespace qc { From 925aaa57e98b0f6946f15bf80f8bab8a92601932 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Mon, 1 Jun 2026 23:58:41 +0200 Subject: [PATCH 11/17] Address the Rabbit's comments --- mlir/include/mlir/Dialect/Utils/Utils.h | 5 ++++- mlir/lib/Conversion/QCOToQC/QCOToQC.cpp | 3 ++- mlir/lib/Conversion/QCToQIR/QCToQIR.cpp | 8 +++++--- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 2 +- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 2 +- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 2 +- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 2 +- mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp | 2 +- 8 files changed, 16 insertions(+), 10 deletions(-) diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 072c2c2368..a885c2da2b 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -16,6 +16,7 @@ #include #include +#include #include namespace mlir::utils { @@ -178,9 +179,11 @@ static Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { * @details This helper function is used to resolve block arguments for nested * modifiers. */ -static void populateMapping(Block& block, IRMapping& mapping, +static void populateMapping(IRMapping& mapping, Block& block, ValueRange innerQubits, ValueRange outerQubits, ValueRange newQubits, ValueRange qubitArgs) { + assert(innerQubits.size() == block.getNumArguments() && + "Size of innerQubits must match number of block arguments"); for (auto arg : block.getArguments()) { auto innerQubit = innerQubits[arg.getArgNumber()]; auto outerQubit = getValueFromBlockArgument(innerQubit, outerQubits); diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 77758bd3ba..8cc099610d 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -128,7 +128,8 @@ static void inlineRegion(Region& sourceRegion, Region& targetRegion, ConversionPatternRewriter& rewriter) { rewriter.inlineRegionBefore(sourceRegion, targetRegion, targetRegion.end()); auto& block = targetRegion.front(); - + assert(block.getNumArguments() == offset + replacementValues.size() && + "Number of replacement values must match number of block arguments"); for (auto [arg, replacementVal] : llvm::zip_equal( block.getArguments().drop_front(offset), replacementValues)) { arg.replaceAllUsesWith(replacementVal); diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 3078d15b76..e7a6d2910f 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -209,9 +209,11 @@ convertUnitaryToCallOp(QCOpType& op, QCOpAdaptorType& adaptor, operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end()); // Clean up modifier information - state.inCtrlOp--; - if (inCtrlOp == 0) { - state.controls.clear(); + if (inCtrlOp != 0) { + state.inCtrlOp--; + if (state.inCtrlOp == 0) { + state.controls.clear(); + } } // Replace operation with CallOp diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index fad9a4b2e6..82b5dce077 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -71,7 +71,7 @@ struct MergeNestedCtrl final : OpRewritePattern { op, controls, targets, [&](ValueRange targetArgs) { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::populateMapping(*innerCtrlBody, mapping, innerTargets, + utils::populateMapping(mapping, *innerCtrlBody, innerTargets, outerTargets, targets, targetArgs); for (auto& op : innerCtrlBody->without_terminator()) { rewriter.clone(op, mapping); diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index b34424cceb..1417d6de75 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -60,7 +60,7 @@ struct MoveCtrlOutside final : OpRewritePattern { rewriter, op.getLoc(), targetArgs, [&](ValueRange qubitArgs) { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::populateMapping(*innerCtrlBody, mapping, + utils::populateMapping(mapping, *innerCtrlBody, innerCtrlOp.getTargets(), outerQubits, targets, qubitArgs); for (auto& op : innerCtrlBody->without_terminator()) { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 8994661d86..444f554393 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -78,7 +78,7 @@ struct MergeNestedCtrl final : OpRewritePattern { [&](ValueRange targetArgs) -> SmallVector { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::populateMapping(*innerCtrlBody, mapping, innerTargets, + utils::populateMapping(mapping, *innerCtrlBody, innerTargets, outerTargets, targets, targetArgs); for (auto& op : innerCtrlBody->without_terminator()) { rewriter.clone(op, mapping); diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index bf9da8144e..3456e0b11c 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -66,7 +66,7 @@ struct MoveCtrlOutside final : OpRewritePattern { [&](ValueRange qubitArgs) -> SmallVector { auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - utils::populateMapping(*innerCtrlBody, mapping, + utils::populateMapping(mapping, *innerCtrlBody, innerCtrlOp.getTargetsIn(), outerQubits, targets, qubitArgs); for (auto& op : innerCtrlBody->without_terminator()) { diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 4a8fb691ad..883c7d32d2 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -969,7 +969,7 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(emptyQCO)}, QCOTestCase{"ControlledTwoX", MQT_NAMED_BUILDER(controlledTwoX), MQT_NAMED_BUILDER(emptyQCO)}, - QCOTestCase{"inverseTwoX", MQT_NAMED_BUILDER(twoX), + QCOTestCase{"InverseTwoX", MQT_NAMED_BUILDER(inverseTwoX), MQT_NAMED_BUILDER(emptyQCO)})); /// @} From db79dac6cf29a5d584f61e89274f2cd58e2bb788 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 2 Jun 2026 00:28:33 +0200 Subject: [PATCH 12/17] Fix inverse cancellation --- mlir/include/mlir/Dialect/QCO/QCOUtils.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/QCO/QCOUtils.h b/mlir/include/mlir/Dialect/QCO/QCOUtils.h index 489fceb00e..e73bdd2650 100644 --- a/mlir/include/mlir/Dialect/QCO/QCOUtils.h +++ b/mlir/include/mlir/Dialect/QCO/QCOUtils.h @@ -35,7 +35,8 @@ removeInversePairOneTargetZeroParameter(OpType op, PatternRewriter& rewriter) { } // Unlink both operations - rewriter.replaceAllUsesWith(nextOp->getResult(0), op.getInputQubit(0)); + rewriter.replaceOp(op, op.getInputQubits()); + rewriter.replaceOp(nextOp, nextOp.getInputQubits()); return success(); } @@ -64,7 +65,8 @@ removeInversePairTwoTargetZeroParameter(OpType op, PatternRewriter& rewriter) { } // Unlink both operations - rewriter.replaceAllUsesWith(nextOp->getResults(), op.getOperands()); + rewriter.replaceOp(op, op.getInputQubits()); + rewriter.replaceOp(nextOp, nextOp.getInputQubits()); return success(); } @@ -95,8 +97,8 @@ removeTwoTargetZeroParameterPairWithSwappedTargets(OpType op, } // Unlink both operations - rewriter.replaceAllUsesWith(nextOp->getResults(), - {op.getInputQubit(1), op.getInputQubit(0)}); + rewriter.replaceOp(op, op.getInputQubits()); + rewriter.replaceOp(nextOp, nextOp.getInputQubits()); return success(); } From e14656327c9e50d1aefe4d3af497d6190ebfc2d7 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Tue, 2 Jun 2026 00:48:54 +0200 Subject: [PATCH 13/17] Improve translation of nested control modifiers --- mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp | 20 ++++++++--- .../TranslateQuantumComputationToQC.cpp | 34 ++++++++++++++----- .../Conversion/QCOToQC/test_qco_to_qc.cpp | 3 ++ .../Conversion/QCToQCO/test_qc_to_qco.cpp | 3 ++ .../test_quantum_computation_translation.cpp | 3 ++ mlir/unittests/programs/qc_programs.cpp | 8 +++++ mlir/unittests/programs/qc_programs.h | 4 +++ mlir/unittests/programs/qco_programs.cpp | 11 ++++++ mlir/unittests/programs/qco_programs.h | 4 +++ .../programs/quantum_computation_programs.cpp | 11 ++++++ .../programs/quantum_computation_programs.h | 4 +++ 11 files changed, 92 insertions(+), 13 deletions(-) diff --git a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp index 09bb1b0949..96598b2c82 100644 --- a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp +++ b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp @@ -880,18 +880,24 @@ struct ConvertQCOCtrlOpToJeff final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(CtrlOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 1) { + return rewriter.notifyMatchFailure( + op, + "Control modifiers with multiple body unitaries are not supported."); + } + auto& state = getState(); if (state.inCtrlOp) { return rewriter.notifyMatchFailure( - op, "Nested control operations are not supported. Run the " + op, "Nested control modifiers are not supported. Run the " "canonicalization pass before the conversion"); } if (state.inInvOp) { return rewriter.notifyMatchFailure( - op, "Control operations inside inversion operations are not " - "supported. Run the canonicalization pass before the conversion"); + op, "Control modifiers inside inversion modifiers are not supported. " + "Run the canonicalization pass before the conversion"); } // Set modifier information @@ -930,11 +936,17 @@ struct ConvertQCOInvOpToJeff final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(InvOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 1) { + return rewriter.notifyMatchFailure(op, + "Inversion modifiers with multiple " + "body unitaries are not supported."); + } + auto& state = getState(); if (state.inInvOp) { return rewriter.notifyMatchFailure( - op, "Nested inversion operations are not supported. Run the " + op, "Nested inversion modifiers are not supported. Run the " "canonicalization pass before the conversion"); } diff --git a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp index 930817894d..7a2cf23651 100644 --- a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp +++ b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/QC/Translation/TranslateQuantumComputationToQC.h" +#include "ir/Definitions.hpp" #include "ir/QuantumComputation.hpp" #include "ir/Register.hpp" #include "ir/operations/CompoundOperation.hpp" @@ -80,12 +81,15 @@ struct TranslationState { bool inCtrlOp = false; /// Mapping from physical qubit index to block argument - DenseMap ctrlTargets; + DenseMap targetArgs; + + /// Control qubits of the current CompoundOperation + DenseSet<::qc::Qubit> compoundControls; [[nodiscard]] Value getQubit(size_t index) const { if (inCtrlOp) { - auto it = ctrlTargets.find(index); - if (it == ctrlTargets.end()) { + auto it = targetArgs.find(index); + if (it == targetArgs.end()) { llvm::reportFatalInternalError("Qubit index out of bounds"); } return it->second; @@ -286,11 +290,11 @@ static void addResetOp(QCProgramBuilder& builder, */ static SmallVector getControls(const ::qc::Operation& operation, TranslationState& state) { - if (state.inCtrlOp) { - return {}; - } SmallVector controls; for (const auto& [control, type] : operation.getControls()) { + if (state.compoundControls.contains(control)) { + continue; + } if (type == ::qc::Control::Type::Neg) { llvm::reportFatalInternalError( "Negative controls cannot be translated to QC at the moment"); @@ -602,6 +606,18 @@ static LogicalResult addCompoundOp(QCProgramBuilder& builder, targetMap[target] = state.getQubit(target); } } + for (const auto& control : op->getControls()) { + if (compoundOp.getControls().contains(control)) { + continue; + } + const auto& qubit = control.qubit; + if (!targetMap.contains(qubit)) { + targetMap[qubit] = state.getQubit(qubit); + } + } + } + for (const auto& [control, _] : compoundOp.getControls()) { + state.compoundControls.insert(control); } SmallVector> sortedPairs(targetMap.begin(), targetMap.end()); @@ -615,7 +631,7 @@ static LogicalResult addCompoundOp(QCProgramBuilder& builder, builder.ctrl(controls, targets, [&](ValueRange targetArgs) { state.inCtrlOp = true; for (size_t i = 0; i < sortedPairs.size(); ++i) { - state.ctrlTargets[sortedPairs[i].first] = targetArgs[i]; + state.targetArgs[sortedPairs[i].first] = targetArgs[i]; } for (const auto& op : compoundOp) { if (failed(translateOperation(builder, *op, state))) { @@ -623,7 +639,7 @@ static LogicalResult addCompoundOp(QCProgramBuilder& builder, "controlled CompoundOperation"); } } - state.ctrlTargets.clear(); + state.targetArgs.clear(); state.inCtrlOp = false; }); } @@ -850,7 +866,7 @@ OwningOpRef translateQuantumComputationToQC( TranslationState state{.qubits = qubits, .bitMap = bitMap, .results = std::move(results), - .ctrlTargets = DenseMap{}}; + .targetArgs = DenseMap{}}; // Translate operations if (translateOperations(builder, quantumComputation, state).failed()) { diff --git a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp index aa3a428809..7dda9ccfda 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp @@ -150,6 +150,9 @@ INSTANTIATE_TEST_SUITE_P( QCOCtrlOpTest, QCOToQCTest, testing::Values(QCOToQCTestCase{"CtrlTwo", MQT_NAMED_BUILDER(qco::ctrlTwo), MQT_NAMED_BUILDER(qc::ctrlTwo)}, + QCOToQCTestCase{"CtrlTwoMixed", + MQT_NAMED_BUILDER(qco::ctrlTwoMixed), + MQT_NAMED_BUILDER(qc::ctrlTwoMixed)}, QCOToQCTestCase{"CtrlInvTwo", MQT_NAMED_BUILDER(qco::ctrlInvTwo), MQT_NAMED_BUILDER(qc::ctrlInvTwo)})); diff --git a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp index 71f47b0841..00b2c7fe7b 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp @@ -149,6 +149,9 @@ INSTANTIATE_TEST_SUITE_P( QCCtrlOpTest, QCToQCOTest, testing::Values(QCToQCOTestCase{"CtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), MQT_NAMED_BUILDER(qco::ctrlTwo)}, + QCToQCOTestCase{"CtrlTwoMixed", + MQT_NAMED_BUILDER(qc::ctrlTwoMixed), + MQT_NAMED_BUILDER(qco::ctrlTwoMixed)}, QCToQCOTestCase{"CtrlInvTwo", MQT_NAMED_BUILDER(qc::ctrlInvTwo), MQT_NAMED_BUILDER(qco::ctrlInvTwo)})); diff --git a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp index 893902448b..b47c9f97a7 100644 --- a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp +++ b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp @@ -421,6 +421,9 @@ INSTANTIATE_TEST_SUITE_P( QuantumComputationTranslationTestCase{ "CtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), MQT_NAMED_BUILDER(mlir::qc::ctrlTwo)}, + QuantumComputationTranslationTestCase{ + "CtrlTwoMixed", MQT_NAMED_BUILDER(qc::ctrlTwoMixed), + MQT_NAMED_BUILDER(mlir::qc::ctrlTwoMixed)}, QuantumComputationTranslationTestCase{ "SimpleIf", MQT_NAMED_BUILDER(qc::simpleIf), MQT_NAMED_BUILDER(mlir::qc::simpleIf)}, diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 7c7608963d..232b8cbdee 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -1405,6 +1405,14 @@ void ctrlTwo(QCProgramBuilder& b) { }); } +void ctrlTwoMixed(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl({q[0], q[1]}, {q[2], q[3]}, [&](ValueRange targets) { + b.cx(targets[0], targets[1]); + b.rxx(0.123, targets[0], targets[1]); + }); +} + void nestedCtrlTwo(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { diff --git a/mlir/unittests/programs/qc_programs.h b/mlir/unittests/programs/qc_programs.h index 2f08c5236e..dbf855b982 100644 --- a/mlir/unittests/programs/qc_programs.h +++ b/mlir/unittests/programs/qc_programs.h @@ -832,6 +832,10 @@ void ctrlInvSandwich(QCProgramBuilder& b); /// Creates a circuit with a control modifier applied to two gates. void ctrlTwo(QCProgramBuilder& b); +/// Creates a circuit with a control modifier applied to a controlled and a +/// non-controlled gate. +void ctrlTwoMixed(QCProgramBuilder& b); + /// Creates a circuit with nested control modifiers applied to two gates. void nestedCtrlTwo(QCProgramBuilder& b); diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index 868e3a2a3b..523f071f8a 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -2038,6 +2038,17 @@ void ctrlTwo(QCOProgramBuilder& b) { }); } +void ctrlTwoMixed(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl({q[0], q[1]}, {q[2], q[3]}, [&](ValueRange targets) { + auto i0 = targets[0]; + auto i1 = targets[1]; + std::tie(i0, i1) = b.cx(i0, i1); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); +} + void nestedCtrlTwo(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(4); b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index 1ec606d103..f562cfff8a 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -985,6 +985,10 @@ void ctrlInvSandwich(QCOProgramBuilder& b); /// Creates a circuit with a control modifier applied to two gates. void ctrlTwo(QCOProgramBuilder& b); +/// Creates a circuit with a control modifier applied to a controlled and a +/// non-controlled gate. +void ctrlTwoMixed(QCOProgramBuilder& b); + /// Creates a circuit with nested control modifiers applied to two gates. void nestedCtrlTwo(QCOProgramBuilder& b); diff --git a/mlir/unittests/programs/quantum_computation_programs.cpp b/mlir/unittests/programs/quantum_computation_programs.cpp index 418798c68d..f0b9b305cd 100644 --- a/mlir/unittests/programs/quantum_computation_programs.cpp +++ b/mlir/unittests/programs/quantum_computation_programs.cpp @@ -553,6 +553,17 @@ void ctrlTwo(QuantumComputation& comp) { comp.emplace_back(std::move(compound)); } +void ctrlTwoMixed(QuantumComputation& comp) { + const auto& q = comp.addQubitRegister(4, "q"); + CompoundOperation compound; + compound.emplace_back(2, 3, X); + compound.emplace_back(Targets{2, 3}, RXX, + std::vector{0.123}); + compound.addControl(0); + compound.addControl(1); + comp.emplace_back(std::move(compound)); +} + void simpleIf(QuantumComputation& comp) { const auto& q = comp.addQubitRegister(1, "q"); const auto& c = comp.addClassicalRegister(1, "c"); diff --git a/mlir/unittests/programs/quantum_computation_programs.h b/mlir/unittests/programs/quantum_computation_programs.h index f0e1856d8f..f6dab6e1c2 100644 --- a/mlir/unittests/programs/quantum_computation_programs.h +++ b/mlir/unittests/programs/quantum_computation_programs.h @@ -390,6 +390,10 @@ void barrierMultipleQubits(QuantumComputation& comp); /// Creates a circuit with a control modifier applied to two gates. void ctrlTwo(QuantumComputation& comp); +/// Creates a circuit with a control modifier applied to a controlled and a +/// non-controlled gate. +void ctrlTwoMixed(QuantumComputation& comp); + // --- IfOp ----------------------------------------------------------------- // /// Creates a circuit with a simple if operation with one qubit. From d84b58b6c309d9e5370fdf966e25f07c1fa6bde2 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 3 Jun 2026 00:10:13 +0200 Subject: [PATCH 14/17] Improve verifiers --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 5 +++++ mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 5 +++++ mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 16 +++++++++++----- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 16 +++++++++++----- 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index 82b5dce077..1fd3d13fc1 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -225,6 +225,11 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult CtrlOp::verify() { auto& block = *getBody(); + if (llvm::any_of(*getBody(), [](Operation& op) { + return isa(op); + })) { + return emitOpError("body must not contain non-unitary quantum operations"); + } if (!isa(block.back())) { return emitOpError( "last operation in body region must be a yield operation"); diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 1417d6de75..6215a81c1c 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -395,6 +395,11 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult InvOp::verify() { auto& block = *getBody(); + if (llvm::any_of(*getBody(), [](Operation& op) { + return isa(op); + })) { + return emitOpError("body must not contain non-unitary quantum operations"); + } if (!isa(block.back())) { return emitOpError( "last operation in body region must be a yield operation"); diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 444f554393..7f084b9c06 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -308,22 +308,28 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult CtrlOp::verify() { auto& block = *getBody(); + if (llvm::any_of(*getBody(), [](Operation& op) { + return isa(op); + })) { + return emitOpError("body must not contain non-unitary quantum operations"); + } + if (!isa(block.back())) { + return emitOpError( + "last operation in body region must be a yield operation"); + } + const auto numTargets = getNumTargets(); if (block.getArguments().size() != numTargets) { return emitOpError( "number of block arguments must match the number of targets"); } - const auto qubitType = QubitType::get(getContext()); + auto qubitType = QubitType::get(getContext()); for (size_t i = 0; i < numTargets; ++i) { if (block.getArgument(i).getType() != qubitType) { return emitOpError("block argument type at index ") << i << " does not match target type"; } } - if (!isa(block.back())) { - return emitOpError( - "last operation in body region must be a yield operation"); - } if (const auto numYieldOperands = block.back().getNumOperands(); numYieldOperands != numTargets) { return emitOpError("yield operation must yield ") diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index 3456e0b11c..cf2043b582 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -452,22 +452,28 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult InvOp::verify() { auto& block = *getBody(); + if (llvm::any_of(*getBody(), [](Operation& op) { + return isa(op); + })) { + return emitOpError("body must not contain non-unitary quantum operations"); + } + if (!isa(block.back())) { + return emitOpError( + "last operation in body region must be a yield operation"); + } + const auto numTargets = getNumTargets(); if (block.getArguments().size() != numTargets) { return emitOpError( "number of block arguments must match the number of targets"); } - const auto qubitType = QubitType::get(getContext()); + auto qubitType = QubitType::get(getContext()); for (size_t i = 0; i < numTargets; ++i) { if (block.getArgument(i).getType() != qubitType) { return emitOpError("block argument type at index ") << i << " does not match target type"; } } - if (!isa(block.back())) { - return emitOpError( - "last operation in body region must be a yield operation"); - } if (const auto numYieldOperands = block.back().getNumOperands(); numYieldOperands != numTargets) { return emitOpError("yield operation must yield ") From 570878fd0f5f17049dbaed0bc712c40c39496850 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 3 Jun 2026 00:20:16 +0200 Subject: [PATCH 15/17] Improve implementation of getNumBodyUnitaries() --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 9 ++------- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 9 ++------- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 9 ++------- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 9 ++------- 4 files changed, 8 insertions(+), 28 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index 1fd3d13fc1..bc7bd64f9d 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -166,13 +166,8 @@ struct EraseEmptyCtrl final : OpRewritePattern { } // namespace size_t CtrlOp::getNumBodyUnitaries() { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - count++; - } - } - return count; + return llvm::count_if( + *getBody(), [](Operation& op) { return isa(op); }); } UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 6215a81c1c..5794f08a35 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -354,13 +354,8 @@ struct EraseEmptyInv final : OpRewritePattern { } // namespace size_t InvOp::getNumBodyUnitaries() { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - count++; - } - } - return count; + return llvm::count_if( + *getBody(), [](Operation& op) { return isa(op); }); } UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 7f084b9c06..bca38f47a1 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -189,13 +189,8 @@ struct EraseEmptyCtrl final : OpRewritePattern { } // namespace size_t CtrlOp::getNumBodyUnitaries() { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - count++; - } - } - return count; + return llvm::count_if( + *getBody(), [](Operation& op) { return isa(op); }); } UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index cf2043b582..e892b3e900 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -379,13 +379,8 @@ struct EraseEmptyInv final : OpRewritePattern { } // namespace size_t InvOp::getNumBodyUnitaries() { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - count++; - } - } - return count; + return llvm::count_if( + *getBody(), [](Operation& op) { return isa(op); }); } UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { From c199cc9babcae95bc86e69e11023ac5094b83313 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 3 Jun 2026 00:26:34 +0200 Subject: [PATCH 16/17] Improve implementation of getBodyUnitary() --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 15 ++++++--------- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 15 ++++++--------- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 15 ++++++--------- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 15 ++++++--------- 4 files changed, 24 insertions(+), 36 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index bc7bd64f9d..abff17e8d6 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -171,16 +171,13 @@ size_t CtrlOp::getNumBodyUnitaries() { } UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - if (count == i) { - return cast(op); - } - count++; - } + auto unitaries = llvm::make_filter_range( + *getBody(), [](Operation& op) { return isa(op); }); + auto it = std::next(unitaries.begin(), i); + if (it == unitaries.end()) { + llvm::reportFatalUsageError("Unitary index out of bounds"); } - llvm::reportFatalUsageError("Unitary index out of bounds"); + return cast(*it); } Value CtrlOp::getQubit(const size_t i) { diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 5794f08a35..398342f7ad 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -359,16 +359,13 @@ size_t InvOp::getNumBodyUnitaries() { } UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - if (count == i) { - return cast(op); - } - count++; - } + auto unitaries = llvm::make_filter_range( + *getBody(), [](Operation& op) { return isa(op); }); + auto it = std::next(unitaries.begin(), i); + if (it == unitaries.end()) { + llvm::reportFatalUsageError("Unitary index out of bounds"); } - llvm::reportFatalUsageError("Invalid unitary index"); + return cast(*it); } void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index bca38f47a1..81ea40f784 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -194,16 +194,13 @@ size_t CtrlOp::getNumBodyUnitaries() { } UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - if (count == i) { - return cast(op); - } - count++; - } + auto unitaries = llvm::make_filter_range( + *getBody(), [](Operation& op) { return isa(op); }); + auto it = std::next(unitaries.begin(), i); + if (it == unitaries.end()) { + llvm::reportFatalUsageError("Unitary index out of bounds"); } - llvm::reportFatalUsageError("Unitary index out of bounds"); + return cast(*it); } Value CtrlOp::getInputQubit(const size_t i) { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index e892b3e900..cb7ad50acd 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -384,16 +384,13 @@ size_t InvOp::getNumBodyUnitaries() { } UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { - size_t count = 0; - for (auto& op : *getBody()) { - if (isa(op)) { - if (count == i) { - return cast(op); - } - count++; - } + auto unitaries = llvm::make_filter_range( + *getBody(), [](Operation& op) { return isa(op); }); + auto it = std::next(unitaries.begin(), i); + if (it == unitaries.end()) { + llvm::reportFatalUsageError("Unitary index out of bounds"); } - llvm::reportFatalUsageError("Unitary index out of bounds"); + return cast(*it); } Value InvOp::getInputQubit(const size_t i) { From 4c4b31ca15ea9e3fd239dd80fb3a8656384719c7 Mon Sep 17 00:00:00 2001 From: Daniel Haag <121057143+denialhaag@users.noreply.github.com> Date: Wed, 3 Jun 2026 00:50:24 +0200 Subject: [PATCH 17/17] Fix linter errors --- mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp | 3 ++- mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp | 4 +++- mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp | 3 ++- mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp | 3 ++- 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index abff17e8d6..059e7dc736 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -24,6 +24,7 @@ #include #include +#include using namespace mlir; using namespace mlir::qc; @@ -173,7 +174,7 @@ size_t CtrlOp::getNumBodyUnitaries() { UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { auto unitaries = llvm::make_filter_range( *getBody(), [](Operation& op) { return isa(op); }); - auto it = std::next(unitaries.begin(), i); + auto it = std::next(unitaries.begin(), static_cast(i)); if (it == unitaries.end()) { llvm::reportFatalUsageError("Unitary index out of bounds"); } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 398342f7ad..b1cf6e46ae 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/Utils/Utils.h" +#include #include #include #include @@ -24,6 +25,7 @@ #include #include +#include #include using namespace mlir; @@ -361,7 +363,7 @@ size_t InvOp::getNumBodyUnitaries() { UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { auto unitaries = llvm::make_filter_range( *getBody(), [](Operation& op) { return isa(op); }); - auto it = std::next(unitaries.begin(), i); + auto it = std::next(unitaries.begin(), static_cast(i)); if (it == unitaries.end()) { llvm::reportFatalUsageError("Unitary index out of bounds"); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 81ea40f784..d9a67fafca 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -28,6 +28,7 @@ #include #include #include +#include #include using namespace mlir; @@ -196,7 +197,7 @@ size_t CtrlOp::getNumBodyUnitaries() { UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { auto unitaries = llvm::make_filter_range( *getBody(), [](Operation& op) { return isa(op); }); - auto it = std::next(unitaries.begin(), i); + auto it = std::next(unitaries.begin(), static_cast(i)); if (it == unitaries.end()) { llvm::reportFatalUsageError("Unitary index out of bounds"); } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index cb7ad50acd..1f0e6d556c 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -386,7 +387,7 @@ size_t InvOp::getNumBodyUnitaries() { UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { auto unitaries = llvm::make_filter_range( *getBody(), [](Operation& op) { return isa(op); }); - auto it = std::next(unitaries.begin(), i); + auto it = std::next(unitaries.begin(), static_cast(i)); if (it == unitaries.end()) { llvm::reportFatalUsageError("Unitary index out of bounds"); }