diff --git a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h index ab9532cfb4..c2483d9f06 100644 --- a/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h +++ b/mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h @@ -917,7 +917,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * } : !qc.qubit * ``` */ - QCProgramBuilder& ctrl(ValueRange controls, const function_ref& body); + QCProgramBuilder& ctrl(ValueRange controls, ValueRange targets, + const function_ref& body); /** * @brief Apply an inverse (i.e., adjoint) operation. @@ -936,7 +937,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder { * } * ``` */ - QCProgramBuilder& inv(const function_ref& body); + QCProgramBuilder& inv(ValueRange qubits, + const function_ref& body); //===--------------------------------------------------------------------===// // Deallocation diff --git a/mlir/include/mlir/Dialect/QC/IR/QCOps.td b/mlir/include/mlir/Dialect/QC/IR/QCOps.td index 8e76c9c3ba..cce7265131 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCOps.td +++ b/mlir/include/mlir/Dialect/QC/IR/QCOps.td @@ -916,7 +916,7 @@ def YieldOp : QCOp<"yield", traits = [Terminator]> { def CtrlOp : QCOp<"ctrl", - traits = [UnitaryOpInterface, + traits = [UnitaryOpInterface, AttrSizedOperandSegments, SingleBlockImplicitTerminator<"::mlir::qc::YieldOp">, RecursiveMemoryEffects]> { let summary = "Add control qubits to a unitary operation"; @@ -937,30 +937,36 @@ def CtrlOp ``` }]; - let arguments = - (ins Arg, - "the control qubits", [MemRead, MemWrite]>:$controls); + let arguments = (ins Arg, + "the control qubits", [MemRead, MemWrite]>:$controls, + Arg, + "the target qubits", [MemRead, MemWrite]>:$targets); let regions = (region SizedRegion<1>:$region); - let assemblyFormat = - "`(` $controls `)` $region attr-dict `:` type($controls)"; + let assemblyFormat = [{ + `(` $controls `)` + `targets` + custom($region, $targets) + attr-dict `:` + `{` type($controls) `}` ( `,` `{` type($targets)^ `}` )? + }]; let extraClassDeclaration = [{ - [[nodiscard]] UnitaryOpInterface getBodyUnitary(); + size_t getNumBodyUnitaries(); + [[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i); size_t getNumQubits() { return getNumTargets() + getNumControls(); } - size_t getNumTargets() { return getBodyUnitary().getNumTargets(); } + size_t getNumTargets() { return getTargets().size(); } size_t getNumControls() { return getControls().size(); } Value getQubit(size_t i); - Value getTarget(size_t i) { return getBodyUnitary().getTarget(i); } - ValueRange getTargets() { return getBodyUnitary().getTargets(); } + Value getTarget(size_t i) { return getTargets()[i]; } Value getControl(size_t i); - size_t getNumParams() { return getBodyUnitary().getNumParams(); } - Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); } - ValueRange getParameters() { return getBodyUnitary().getParameters(); } + size_t getNumParams() { return 0; } + Value getParameter(size_t i) { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } + ValueRange getParameters() { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } static StringRef getBaseSymbol() { return "ctrl"; } }]; - let builders = [OpBuilder<(ins "ValueRange":$controls, - "const function_ref&":$bodyBuilder)>]; + let builders = [OpBuilder<(ins "ValueRange":$controls, "ValueRange":$targets, + "const function_ref&":$body)>]; let hasCanonicalizer = 1; let hasVerifier = 1; @@ -983,26 +989,35 @@ def InvOp : QCOp<"inv", ``` }]; + let arguments = (ins Arg< + Variadic, + "the qubits involved in the operation", [MemRead, MemWrite]>:$qubits); let regions = (region SizedRegion<1>:$region); - let assemblyFormat = "$region attr-dict"; - - let extraClassDeclaration = [{ - [[nodiscard]] UnitaryOpInterface getBodyUnitary(); - size_t getNumQubits() { return getBodyUnitary().getNumQubits(); } - size_t getNumTargets() { return getBodyUnitary().getNumTargets(); } - size_t getNumControls() { return getBodyUnitary().getNumControls(); } - Value getQubit(size_t i) { return getBodyUnitary().getQubit(i); } - Value getTarget(size_t i) { return getBodyUnitary().getTarget(i); } - ValueRange getTargets() { return getBodyUnitary().getTargets(); } - Value getControl(size_t i) { return getBodyUnitary().getControl(i); } - ValueRange getControls() { return getBodyUnitary().getControls(); } - size_t getNumParams() { return getBodyUnitary().getNumParams(); } - Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); } - ValueRange getParameters() { return getBodyUnitary().getParameters(); } + let assemblyFormat = [{ + custom($region, $qubits) + attr-dict `:` + type($qubits) + }]; + + let extraClassDeclaration = [{ + size_t getNumBodyUnitaries(); + [[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i); + size_t getNumQubits() { return getNumTargets(); } + size_t getNumTargets() { return getQubits().size(); } + size_t getNumControls() { return 0; } + Value getQubit(size_t i) { return getTarget(i); } + Value getTarget(size_t i) { return getQubits()[i]; } + ValueRange getTargets() { return getQubits(); } + Value getControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); } + ValueRange getControls() { return {nullptr, 0}; } + size_t getNumParams() { return 0; } + Value getParameter(size_t i) { llvm::reportFatalUsageError("InvOp does not have parameters"); } + ValueRange getParameters() { return {nullptr, 0}; } static StringRef getBaseSymbol() { return "inv"; } }]; - let builders = [OpBuilder<(ins "const function_ref&":$bodyBuilder)>]; + let builders = [OpBuilder<(ins "ValueRange":$qubits, + "const function_ref&":$body)>]; let hasCanonicalizer = 1; let hasVerifier = 1; diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td index a5bbfb7f51..78e15ecc86 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td +++ b/mlir/include/mlir/Dialect/QCO/IR/QCOOps.td @@ -1102,7 +1102,8 @@ def CtrlOp }]; let extraClassDeclaration = [{ - UnitaryOpInterface getBodyUnitary(); + size_t getNumBodyUnitaries(); + [[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i); size_t getNumQubits() { return getNumControls() + getNumTargets(); } size_t getNumTargets() { return getTargetsIn().size(); } size_t getNumControls() { return getControlsIn().size(); } @@ -1120,9 +1121,9 @@ def CtrlOp ResultRange getOutputControls() { return getControlsOut(); } Value getInputForOutput(Value output); Value getOutputForInput(Value input); - size_t getNumParams() { return getBodyUnitary().getNumParams(); } - Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); } - ValueRange getParameters() { return getBodyUnitary().getParameters(); } + size_t getNumParams() { return 0; } + Value getParameter(size_t i) { llvm::reportFatalUsageError("CtrlOp does not have parameters"); } + ValueRange getParameters() { return {nullptr, 0}; } [[nodiscard]] static StringRef getBaseSymbol() { return "ctrl"; } [[nodiscard]] std::optional getUnitaryMatrix(); }]; @@ -1173,7 +1174,8 @@ def InvOp }]; let extraClassDeclaration = [{ - UnitaryOpInterface getBodyUnitary(); + size_t getNumBodyUnitaries(); + [[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i); size_t getNumQubits() { return getNumTargets(); } size_t getNumTargets() { return getQubitsIn().size(); } static size_t getNumControls() { return 0; } @@ -1184,16 +1186,16 @@ def InvOp ResultRange getOutputQubits() { return getQubitsOut(); } Value getInputTarget(size_t i) { return getInputQubit(i); } Value getOutputTarget(size_t i) { return getOutputQubit(i); } - static Value getInputControl(size_t i) { llvm::reportFatalUsageError("Operation does not have controls"); } + static Value getInputControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); } static OperandRange getInputControls() { return {nullptr, 0}; } - static Value getOutputControl(size_t i) { llvm::reportFatalUsageError("Operation does not have controls"); } + static Value getOutputControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); } static ResultRange getOutputControls() { return {nullptr, 0}; } ResultRange getOutputTargets() { return getOutputQubits(); } Value getInputForOutput(Value output); Value getOutputForInput(Value input); - size_t getNumParams() { return getBodyUnitary().getNumParams(); } - Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); } - ValueRange getParameters() { return getBodyUnitary().getParameters(); } + size_t getNumParams() { return 0; } + Value getParameter(size_t i) { llvm::reportFatalUsageError("InvOp does not have parameters"); } + ValueRange getParameters() { return {nullptr, 0}; } [[nodiscard]] static StringRef getBaseSymbol() { return "inv"; } [[nodiscard]] std::optional getUnitaryMatrix(); }]; diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 3d976a5a63..a885c2da2b 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -12,9 +12,11 @@ #include #include +#include #include #include +#include #include namespace mlir::utils { @@ -78,4 +80,120 @@ template return std::nullopt; } +template +[[nodiscard]] +static ParseResult +parseTargetAliasing(OpAsmParser& parser, Region& region, + SmallVectorImpl& operands) { + // 1. Parse the opening parenthesis + if (parser.parseLParen()) { + return failure(); + } + + // Temporary storage for block arguments we are about to create + SmallVector blockArgs; + + // 2. Prepare to parse the list + if (failed(parser.parseOptionalRParen())) { + do { + OpAsmParser::Argument newArg; // The "new" variable name + OpAsmParser::UnresolvedOperand oldOperand; // The "old" input variable + + // Parse "%new" + if (parser.parseArgument(newArg)) { + return failure(); + } + + // Parse "=" + if (parser.parseEqual()) { + return failure(); + } + + // Parse "%old" + if (parser.parseOperand(oldOperand)) { + return failure(); + } + operands.push_back(oldOperand); + + // Hard-code QubitType since targets in qco.ctrl are always qubits. + // This avoids double-binding type($targets_in) in the assembly format + // while keeping the parser simple and the assembly format clean. + newArg.type = QubitType::get(parser.getBuilder().getContext()); + blockArgs.push_back(newArg); + + } while (succeeded(parser.parseOptionalComma())); + + if (parser.parseRParen()) { + return failure(); + } + } + + // 4. Parse the Region + // We explicitly pass the blockArgs we just parsed so they become the entry + // block! + if (parser.parseRegion(region, blockArgs)) { + return failure(); + } + + return success(); +} + +static void printTargetAliasing(OpAsmPrinter& printer, Region& region, + OperandRange targetsIn) { + printer << "("; + if (region.empty()) { + printer << ") "; + printer.printRegion(region, false); + return; + } + Block& entryBlock = region.front(); + + const auto numTargets = targetsIn.size(); + for (unsigned i = 0; i < numTargets; ++i) { + if (i > 0) { + printer << ", "; + } + printer.printOperand(entryBlock.getArgument(i)); + printer << " = "; + printer.printOperand(targetsIn[i]); + } + printer << ") "; + + printer.printRegion(region, false); +} + +/** + * @brief Get the value corresponding to @p qubit from the block arguments @p + * qubits if @p qubit is a block argument, otherwise return @p qubit itself. + */ +static Value getValueFromBlockArgument(Value qubit, ValueRange qubits) { + if (auto blockArg = dyn_cast(qubit)) { + return qubits[blockArg.getArgNumber()]; + } + return qubit; +} + +/** + * @brief Create a mapping between block arguments and qubit values. + * + * @details This helper function is used to resolve block arguments for nested + * modifiers. + */ +static void populateMapping(IRMapping& mapping, Block& block, + ValueRange innerQubits, ValueRange outerQubits, + ValueRange newQubits, ValueRange qubitArgs) { + assert(innerQubits.size() == block.getNumArguments() && + "Size of innerQubits must match number of block arguments"); + for (auto arg : block.getArguments()) { + auto innerQubit = innerQubits[arg.getArgNumber()]; + auto outerQubit = getValueFromBlockArgument(innerQubit, outerQubits); + if (auto it = llvm::find(newQubits, outerQubit); it != newQubits.end()) { + auto index = std::distance(newQubits.begin(), it); + mapping.map(arg, qubitArgs[index]); + } else { + llvm::reportFatalInternalError("Outer qubit not found in new qubits"); + } + } +} + } // namespace mlir::utils diff --git a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp index a486e82c5d..96598b2c82 100644 --- a/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp +++ b/mlir/lib/Conversion/QCOToJeff/QCOToJeff.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -155,19 +156,30 @@ static void handleResult(Operation* op, ConversionPatternRewriter& rewriter, * @brief Target operands: `adaptor.getOperands()` at the matched op, or * `state.targetsIn` while lowering inside `qco.ctrl` / `qco.inv`. * - * @param state Lowering state. - * @param adaptor Operand adaptor for the matched op. + * @param op The operation being converted. + * @param adaptor The operation adaptor of the operation. + * @param state The lowering state. * @tparam NumParams Number of parameters to drop from the end of the operand * list. - * @tparam OpAdaptor Adaptor with `getOperands()`. - * @return ValueRange The target operands. + * @tparam OpType The type of the operation. + * @tparam OpAdaptorType The type of the operation adaptor. + * @return The target operands. */ -template -[[nodiscard]] static ValueRange getEffectiveTargetOperands(LoweringState& state, - OpAdaptor adaptor) { - return state.inModifier() - ? ValueRange(state.targetsIn) - : ValueRange(adaptor.getOperands().drop_back(NumParams)); +template +[[nodiscard]] static SmallVector +getEffectiveTargetOperands(OpType op, OpAdaptorType adaptor, + LoweringState& state) { + if (!state.inModifier()) { + return adaptor.getOperands().drop_back(NumParams); + } + + SmallVector targets; + for (auto targetArg : op->getOperands().drop_back(NumParams)) { + auto target = + state.targetsIn[cast(targetArg).getArgNumber()]; + targets.push_back(target); + } + return targets; } /** @@ -190,10 +202,10 @@ convertJeffGate(QCOOpType op, typename QCOOpType::Adaptor adaptor, std::index_sequence /*targetIndices*/, std::index_sequence /*paramIndices*/) { constexpr std::size_t numParams = sizeof...(ParamIndices); - ValueRange targets = getEffectiveTargetOperands(state, adaptor); + auto targets = getEffectiveTargetOperands(op, adaptor, state); assert(targets.size() >= sizeof...(TargetIndices) && "Not enough operands available for conversion"); - ValueRange params = op.getParameters(); + auto params = op.getParameters(); auto jeffOp = JeffOpType::create( rewriter, op.getLoc(), targets[TargetIndices]..., params[ParamIndices]..., @@ -336,7 +348,7 @@ static LogicalResult moveRegion(Region& source, Region& dest, ConversionPatternRewriter& rewriter, const TypeConverter* typeConverter) { rewriter.inlineRegionBefore(source, dest, dest.end()); - Block* block = &dest.front(); + auto* block = &dest.front(); TypeConverter::SignatureConversion sc(block->getNumArguments()); if (failed( typeConverter->convertSignatureArgs(block->getArgumentTypes(), sc))) { @@ -728,7 +740,7 @@ struct ConvertQCOCustomGateToJeff final } } - ValueRange targets = getEffectiveTargetOperands(state, adaptor); + auto targets = getEffectiveTargetOperands(op, adaptor, state); assert(targets.size() >= NumTargets && "Not enough operands available for conversion"); @@ -764,7 +776,7 @@ struct ConvertQCOPPRGateToJeff final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); - ValueRange targets = getEffectiveTargetOperands<1>(state, adaptor); + auto targets = getEffectiveTargetOperands<1>(op, adaptor, state); assert(targets.size() >= 2 && "Not enough operands available for conversion"); createPPROp(op, rewriter, state, targets, {p0_, p1_}); @@ -798,7 +810,7 @@ struct ConvertQCOU2OpToJeff final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - ValueRange targets = getEffectiveTargetOperands<2>(state, adaptor); + auto targets = getEffectiveTargetOperands<2>(op, adaptor, state); assert(!targets.empty() && "Not enough operands available for conversion"); auto target = targets.front(); @@ -840,11 +852,8 @@ struct ConvertQCOBarrierOpToJeff final matchAndRewrite(BarrierOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { auto& state = getState(); - - ValueRange targets = getEffectiveTargetOperands<0>(state, adaptor); - + auto targets = getEffectiveTargetOperands<0>(op, adaptor, state); createCustomOp(op, rewriter, state, targets, {}, false, "barrier"); - return success(); } }; @@ -871,18 +880,24 @@ struct ConvertQCOCtrlOpToJeff final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(CtrlOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 1) { + return rewriter.notifyMatchFailure( + op, + "Control modifiers with multiple body unitaries are not supported."); + } + auto& state = getState(); if (state.inCtrlOp) { return rewriter.notifyMatchFailure( - op, "Nested control operations are not supported. Run the " + op, "Nested control modifiers are not supported. Run the " "canonicalization pass before the conversion"); } if (state.inInvOp) { return rewriter.notifyMatchFailure( - op, "Control operations inside inversion operations are not " - "supported. Run the canonicalization pass before the conversion"); + op, "Control modifiers inside inversion modifiers are not supported. " + "Run the canonicalization pass before the conversion"); } // Set modifier information @@ -921,11 +936,17 @@ struct ConvertQCOInvOpToJeff final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(InvOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 1) { + return rewriter.notifyMatchFailure(op, + "Inversion modifiers with multiple " + "body unitaries are not supported."); + } + auto& state = getState(); if (state.inInvOp) { return rewriter.notifyMatchFailure( - op, "Nested inversion operations are not supported. Run the " + op, "Nested inversion modifiers are not supported. Run the " "canonicalization pass before the conversion"); } @@ -934,6 +955,13 @@ struct ConvertQCOInvOpToJeff final : StatefulOpConversionPattern { state.invOp = op; if (state.targetsIn.empty()) { state.targetsIn = llvm::to_vector(adaptor.getQubitsIn()); + } else { + auto outerQubits = state.targetsIn; + SmallVector innerQubits; + for (auto arg : op.getBody()->getArguments()) { + innerQubits.push_back(outerQubits[arg.getArgNumber()]); + } + state.targetsIn = std::move(innerQubits); } // Inline region diff --git a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp index 97d7071f69..8cc099610d 100644 --- a/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp +++ b/mlir/lib/Conversion/QCOToQC/QCOToQC.cpp @@ -120,25 +120,21 @@ class StatefulOpConversionPattern : public OpConversionPattern { * @param sourceRegion Source region where the operations are moved from * @param targetRegion Target region where the operations are moved to * @param offset Offset to the arguments that are dropped - * @param numArgs Number of arguments that are dropped * @param replacementValues Values to replace the uses of the arguments * @param rewriter PatternRewriter of the current conversion pass */ static void inlineRegion(Region& sourceRegion, Region& targetRegion, - unsigned int offset, unsigned int numArgs, - ValueRange replacementValues, + unsigned int offset, ValueRange replacementValues, ConversionPatternRewriter& rewriter) { - assert(replacementValues.size() == numArgs && - "replacementValues size must match numArgs"); - rewriter.inlineRegionBefore(sourceRegion, targetRegion, targetRegion.end()); auto& block = targetRegion.front(); - + assert(block.getNumArguments() == offset + replacementValues.size() && + "Number of replacement values must match number of block arguments"); for (auto [arg, replacementVal] : llvm::zip_equal( block.getArguments().drop_front(offset), replacementValues)) { arg.replaceAllUsesWith(replacementVal); } - block.eraseArguments(offset, numArgs); + block.eraseArguments(offset, replacementValues.size()); } #define GEN_PASS_DEF_QCOTOQC @@ -645,16 +641,19 @@ struct ConvertQCOCtrlOp final : OpConversionPattern { LogicalResult matchAndRewrite(qco::CtrlOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - // Get QC controls - auto qcControls = adaptor.getControlsIn(); - // Create qc.ctrl operation - auto qcOp = qc::CtrlOp::create(rewriter, op.getLoc(), qcControls); - - // Inline the region and replace the blockarguments - inlineRegion(op.getRegion(), qcOp.getRegion(), 0, - adaptor.getTargetsIn().size(), adaptor.getTargetsIn(), - rewriter); + auto qcOp = qc::CtrlOp::create( + rewriter, op.getLoc(), adaptor.getControlsIn(), adaptor.getTargetsIn()); + + auto& dstRegion = qcOp.getRegion(); + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); + auto* block = &dstRegion.front(); + TypeConverter::SignatureConversion sc(block->getNumArguments()); + if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), + sc))) { + return failure(); + } + rewriter.applySignatureConversion(block, sc); // Replace the output qubits with the same QC references rewriter.replaceOp(op, adaptor.getOperands()); @@ -687,11 +686,17 @@ struct ConvertQCOInvOp final : OpConversionPattern { matchAndRewrite(qco::InvOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { // Create qc.inv operation - auto qcOp = qc::InvOp::create(rewriter, op.getLoc()); - - // Inline the region and replace the blockarguments - inlineRegion(op.getRegion(), qcOp.getRegion(), 0, - adaptor.getOperands().size(), adaptor.getQubitsIn(), rewriter); + auto qcOp = qc::InvOp::create(rewriter, op.getLoc(), adaptor.getQubitsIn()); + + auto& dstRegion = qcOp.getRegion(); + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); + auto* block = &dstRegion.front(); + TypeConverter::SignatureConversion sc(block->getNumArguments()); + if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), + sc))) { + return failure(); + } + rewriter.applySignatureConversion(block, sc); // Replace the output qubits with the same QC references rewriter.replaceOp(op, adaptor.getOperands()); @@ -764,9 +769,9 @@ struct ConvertQCOSCFForOp final : OpConversionPattern { // Erase default block rewriter.eraseBlock(&newFor.getRegion().front()); - // Inline the region and replace the blockarguments - inlineRegion(op.getRegion(), newFor.getRegion(), 1, - adaptor.getInitArgs().size(), adaptor.getInitArgs(), rewriter); + // Inline the region and replace the block arguments + inlineRegion(op.getRegion(), newFor.getRegion(), 1, adaptor.getInitArgs(), + rewriter); rewriter.replaceOp(op, adaptor.getInitArgs()); @@ -810,11 +815,11 @@ struct ConvertQCOSCFWhileOp final : OpConversionPattern { auto newWhileOp = scf::WhileOp::create(rewriter, op->getLoc(), TypeRange{}, ValueRange{}); - // Inline the regions and replace the blockarguments - inlineRegion(op.getBefore(), newWhileOp.getBefore(), 0, - adaptor.getInits().size(), adaptor.getInits(), rewriter); - inlineRegion(op.getAfter(), newWhileOp.getAfter(), 0, - adaptor.getInits().size(), adaptor.getInits(), rewriter); + // Inline the regions and replace the block arguments + inlineRegion(op.getBefore(), newWhileOp.getBefore(), 0, adaptor.getInits(), + rewriter); + inlineRegion(op.getAfter(), newWhileOp.getAfter(), 0, adaptor.getInits(), + rewriter); rewriter.replaceOp(op, adaptor.getInits()); @@ -855,15 +860,13 @@ struct ConvertQCOIfOp final : OpConversionPattern { // Erase the default empty then block rewriter.eraseBlock(&newThenRegion.front()); - // Inline the region and replace the blockarguments + // Inline the region and replace the block arguments inlineRegion(op.getThenRegion(), newThenRegion, 0, - adaptor.getOperands().size() - 1, adaptor.getOperands().drop_front(1), rewriter); // Inline the else block if it has more than just the yield operation if (oldElseRegion.front().getOperations().size() > 1) { inlineRegion(oldElseRegion, newIf.getElseRegion(), 0, - adaptor.getOperands().size() - 1, adaptor.getOperands().drop_front(1), rewriter); } diff --git a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp index b4a4b0a151..33a2df9217 100644 --- a/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp +++ b/mlir/lib/Conversion/QCToQCO/QCToQCO.cpp @@ -391,22 +391,6 @@ static void popModifierFrame(LoweringState& state) { state.modifierFrames.pop_back(); } -/** @brief Adds entry block aliases for modifier target values. */ -template -[[nodiscard]] static ValueRange addModifierAliases(OpType op, - const size_t numTargets, - PatternRewriter& rewriter) { - auto& entryBlock = op.getRegion().front(); - const auto opLoc = op.getLoc(); - const auto qubitType = qco::QubitType::get(op.getContext()); - rewriter.modifyOpInPlace(op, [&] { - for (size_t i = 0; i < numTargets; ++i) { - entryBlock.addArgument(qubitType, opLoc); - } - }); - return entryBlock.getArguments().take_back(numTargets); -} - /** * @brief Inserts extracted qubits that are not required by @p target back into * their tensors. @@ -525,7 +509,8 @@ collectQubitValuesInsideSCFOps(Operation* op, LoweringState* state) { // Iterate through all operations of the current region for (auto& operation : region.front().getOperations()) { // Recursively walk through nested regions - if (operation.getNumRegions() > 0) { + if (operation.getNumRegions() > 0 && + !isa(operation)) { auto [qubits, registers] = collectQubitValuesInsideSCFOps(&operation, state); auto& regionQubitMap = state->regionQubitMap[op]; @@ -1111,7 +1096,6 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - const auto numTargets = op.getNumTargets(); const auto qcControls = op.getControls(); const auto qcTargets = op.getTargets(); auto qcoControls = resolveMappedQubits(state, operation, qcControls); @@ -1124,16 +1108,20 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { assignMappedQubits(state, operation, qcControls, qcoOp.getControlsOut()); assignMappedQubits(state, operation, qcTargets, qcoOp.getTargetsOut()); - // Clone body region from QC to QCO + auto qcArgs = op.getRegion().front().getArguments(); + + // Inline region auto& dstRegion = qcoOp.getRegion(); - rewriter.cloneRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); + auto* block = &dstRegion.front(); + TypeConverter::SignatureConversion sc(block->getNumArguments()); + if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), + sc))) { + return failure(); + } + rewriter.applySignatureConversion(block, sc); - // Create block arguments for QCO targets - auto& entryBlock = dstRegion.front(); - assert(entryBlock.getNumArguments() == 0 && - "QC ctrl region unexpectedly has entry block arguments"); - pushModifierFrame(state, qcTargets, - addModifierAliases(qcoOp, numTargets, rewriter)); + pushModifierFrame(state, qcArgs, qcoOp.getRegion().front().getArguments()); rewriter.eraseOp(op); return success(); @@ -1165,7 +1153,6 @@ struct ConvertQCInvOp final : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = getState(); auto* operation = op.getOperation(); - const auto numTargets = op.getNumTargets(); const auto qcTargets = op.getTargets(); auto qcoTargets = resolveMappedQubits(state, operation, qcTargets); @@ -1174,16 +1161,20 @@ struct ConvertQCInvOp final : StatefulOpConversionPattern { assignMappedQubits(state, operation, qcTargets, qcoOp.getOutputTargets()); - // Clone body region from QC to QCO + auto qcArgs = op.getRegion().front().getArguments(); + + // Inline region auto& dstRegion = qcoOp.getRegion(); - rewriter.cloneRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - - // Create block arguments for target qubits and seed the nested frame. - auto& entryBlock = dstRegion.front(); - assert(entryBlock.getNumArguments() == 0 && - "QC inv region unexpectedly has entry block arguments"); - pushModifierFrame(state, qcTargets, - addModifierAliases(qcoOp, numTargets, rewriter)); + rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.begin()); + auto* block = &dstRegion.front(); + TypeConverter::SignatureConversion sc(block->getNumArguments()); + if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), + sc))) { + return failure(); + } + rewriter.applySignatureConversion(block, sc); + + pushModifierFrame(state, qcArgs, qcoOp.getRegion().front().getArguments()); rewriter.eraseOp(op); return success(); diff --git a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp index 98942d998c..e7a6d2910f 100644 --- a/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp +++ b/mlir/lib/Conversion/QCToQIR/QCToQIR.cpp @@ -92,8 +92,8 @@ struct LoweringState : QIRMetadata { DenseMap resultPtrs; /// Modifier information - int64_t inCtrlOp = 0; - DenseMap> controls; + size_t inCtrlOp = 0; + SmallVector controls; /// Allocator and StringSaver for stable StringRefs llvm::BumpPtrAllocator allocator; @@ -174,7 +174,7 @@ convertUnitaryToCallOp(QCOpType& op, QCOpAdaptorType& adaptor, // Query state for modifier information const auto inCtrlOp = state.inCtrlOp; const SmallVector controls = - inCtrlOp != 0 ? state.controls[inCtrlOp] : SmallVector{}; + inCtrlOp != 0 ? state.controls : SmallVector{}; const size_t numCtrls = controls.size(); // Define argument types @@ -210,8 +210,10 @@ convertUnitaryToCallOp(QCOpType& op, QCOpAdaptorType& adaptor, // Clean up modifier information if (inCtrlOp != 0) { - state.controls.erase(inCtrlOp); state.inCtrlOp--; + if (state.inCtrlOp == 0) { + state.controls.clear(); + } } // Replace operation with CallOp @@ -315,7 +317,7 @@ struct ConvertQCUnitaryOpQIR : StatefulOpConversionPattern { ConversionPatternRewriter& rewriter) const override { auto& state = this->getState(); const auto inCtrlOp = state.inCtrlOp; - const size_t numCtrls = inCtrlOp != 0 ? state.controls[inCtrlOp].size() : 0; + const size_t numCtrls = inCtrlOp != 0 ? state.controls.size() : 0; const auto fnName = GetFnName(numCtrls); return convertUnitaryToCallOp(op, adaptor, rewriter, this->getContext(), state, fnName, NumTargets, NumParams); @@ -863,16 +865,22 @@ struct ConvertQCCtrlOp final : StatefulOpConversionPattern { LogicalResult matchAndRewrite(CtrlOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { - // Update modifier information auto& state = getState(); - state.inCtrlOp++; + + if (state.inCtrlOp != 0) { + return rewriter.notifyMatchFailure(op, + "Nested CtrlOps are not supported"); + } + + // Update modifier information + state.inCtrlOp = op.getNumBodyUnitaries(); const SmallVector controls(adaptor.getControls().begin(), adaptor.getControls().end()); - state.controls[state.inCtrlOp] = controls; + state.controls = controls; - // Inline region and remove operation - rewriter.inlineBlockBefore(&op.getRegion().front(), op->getBlock(), - op->getIterator()); + // Inline block and remove operation + rewriter.inlineBlockBefore(&op.getRegion().front(), op, + adaptor.getTargets()); rewriter.eraseOp(op); return success(); } diff --git a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp index 5eb022c03a..4bde38d301 100644 --- a/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp +++ b/mlir/lib/Dialect/QC/Builder/QCProgramBuilder.cpp @@ -223,7 +223,8 @@ QCProgramBuilder& QCProgramBuilder::reset(Value qubit) { const std::variant&(PARAM), ValueRange controls) { \ checkFinalized(); \ auto param = variantToValue(*this, getLoc(), PARAM); \ - CtrlOp::create(*this, controls, [&] { OP_CLASS::create(*this, param); }); \ + ctrl(controls, ValueRange{}, \ + [&](ValueRange /*targets*/) { OP_NAME(param); }); \ return *this; \ } @@ -247,7 +248,7 @@ DEFINE_ZERO_TARGET_ONE_PARAMETER(GPhaseOp, gphase, theta) QCProgramBuilder& QCProgramBuilder::mc##OP_NAME(ValueRange controls, \ Value target) { \ checkFinalized(); \ - CtrlOp::create(*this, controls, [&] { OP_CLASS::create(*this, target); }); \ + ctrl(controls, target, [&](ValueRange targets) { OP_NAME(targets[0]); }); \ return *this; \ } @@ -285,8 +286,8 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdgOp, sxdg) Value target) { \ checkFinalized(); \ auto param = variantToValue(*this, getLoc(), PARAM); \ - CtrlOp::create(*this, controls, \ - [&] { OP_CLASS::create(*this, target, param); }); \ + ctrl(controls, target, \ + [&](ValueRange targets) { OP_NAME(param, targets[0]); }); \ return *this; \ } @@ -321,8 +322,8 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(POp, p, theta) checkFinalized(); \ auto param1 = variantToValue(*this, getLoc(), PARAM1); \ auto param2 = variantToValue(*this, getLoc(), PARAM2); \ - CtrlOp::create(*this, controls, \ - [&] { OP_CLASS::create(*this, target, param1, param2); }); \ + ctrl(controls, target, \ + [&](ValueRange targets) { OP_NAME(param1, param2, targets[0]); }); \ return *this; \ } @@ -360,8 +361,8 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2Op, u2, phi, lambda) auto param1 = variantToValue(*this, getLoc(), PARAM1); \ auto param2 = variantToValue(*this, getLoc(), PARAM2); \ auto param3 = variantToValue(*this, getLoc(), PARAM3); \ - CtrlOp::create(*this, controls, [&] { \ - OP_CLASS::create(*this, target, param1, param2, param3); \ + ctrl(controls, target, [&](ValueRange targets) { \ + OP_NAME(param1, param2, param3, targets[0]); \ }); \ return *this; \ } @@ -386,8 +387,8 @@ DEFINE_ONE_TARGET_THREE_PARAMETER(UOp, u, theta, phi, lambda) QCProgramBuilder& QCProgramBuilder::mc##OP_NAME( \ ValueRange controls, Value qubit0, Value qubit1) { \ checkFinalized(); \ - CtrlOp::create(*this, controls, \ - [&] { OP_CLASS::create(*this, qubit0, qubit1); }); \ + ctrl(controls, ValueRange{qubit0, qubit1}, \ + [&](ValueRange targets) { OP_NAME(targets[0], targets[1]); }); \ return *this; \ } @@ -418,8 +419,8 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECROp, ecr) Value qubit0, Value qubit1) { \ checkFinalized(); \ auto param = variantToValue(*this, getLoc(), PARAM); \ - CtrlOp::create(*this, controls, \ - [&] { OP_CLASS::create(*this, qubit0, qubit1, param); }); \ + ctrl(controls, ValueRange{qubit0, qubit1}, \ + [&](ValueRange targets) { OP_NAME(param, targets[0], targets[1]); }); \ return *this; \ } @@ -455,8 +456,8 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZOp, rzz, theta) checkFinalized(); \ auto param1 = variantToValue(*this, getLoc(), PARAM1); \ auto param2 = variantToValue(*this, getLoc(), PARAM2); \ - CtrlOp::create(*this, controls, [&] { \ - OP_CLASS::create(*this, qubit0, qubit1, param1, param2); \ + ctrl(controls, ValueRange{qubit0, qubit1}, [&](ValueRange targets) { \ + OP_NAME(param1, param2, targets[0], targets[1]); \ }); \ return *this; \ } @@ -478,16 +479,19 @@ QCProgramBuilder& QCProgramBuilder::barrier(ValueRange qubits) { // Modifiers //===----------------------------------------------------------------------===// -QCProgramBuilder& QCProgramBuilder::ctrl(ValueRange controls, - const function_ref& body) { +QCProgramBuilder& +QCProgramBuilder::ctrl(ValueRange controls, ValueRange targets, + const function_ref& body) { checkFinalized(); - CtrlOp::create(*this, controls, body); + CtrlOp::create(*this, controls, targets, body); return *this; } -QCProgramBuilder& QCProgramBuilder::inv(const function_ref& body) { +QCProgramBuilder& +QCProgramBuilder::inv(ValueRange qubits, + const function_ref& body) { checkFinalized(); - InvOp::create(*this, body); + InvOp::create(*this, qubits, body); return *this; } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp index d457fa9a35..059e7dc736 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/CtrlOp.cpp @@ -8,11 +8,15 @@ * Licensed under the MIT License */ +#include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QC/IR/QCOps.h" +#include "mlir/Dialect/Utils/Utils.h" #include #include #include +#include +#include #include #include #include @@ -20,6 +24,7 @@ #include #include +#include using namespace mlir; using namespace mlir::qc; @@ -33,22 +38,46 @@ struct MergeNestedCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - auto* bodyUnitary = op.getBodyUnitary().getOperation(); - auto bodyCtrlOp = dyn_cast(bodyUnitary); - if (!bodyCtrlOp) { + // Require at least one control + // Trivial case is handled by ReduceCtrl + if (op.getNumControls() == 0) { return failure(); } - // add the inner controls as operands to the outer one - op->insertOperands(op.getNumOperands(), bodyCtrlOp.getControls()); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); + if (!innerCtrlOp) { + return failure(); + } - // Move the inner unitary op into the outer one's body region and replace - // the outer one with the inner one's results - const OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(bodyUnitary); - auto* innerUnitaryOp = bodyCtrlOp.getBodyUnitary().getOperation(); - rewriter.moveOpBefore(innerUnitaryOp, bodyUnitary); - rewriter.replaceOp(bodyUnitary, innerUnitaryOp->getResults()); + auto outerControls = op.getControls(); + auto outerTargets = op.getTargets(); + auto innerTargets = innerCtrlOp.getTargets(); + + SmallVector controls; + SmallVector targets; + llvm::append_range(controls, outerControls); + for (auto [arg, qubit] : + llvm::zip_equal(op.getBody()->getArguments(), outerTargets)) { + if (llvm::is_contained(innerTargets, arg)) { + targets.push_back(qubit); + } else { + controls.push_back(qubit); + } + } + + rewriter.replaceOpWithNewOp( + op, controls, targets, [&](ValueRange targetArgs) { + auto* innerCtrlBody = innerCtrlOp.getBody(); + IRMapping mapping; + utils::populateMapping(mapping, *innerCtrlBody, innerTargets, + outerTargets, targets, targetArgs); + for (auto& op : innerCtrlBody->without_terminator()) { + rewriter.clone(op, mapping); + } + }); return success(); } @@ -63,16 +92,29 @@ struct ReduceCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - auto* bodyUnitary = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); + // Inline ops from empty control modifiers, IdOp and BarrierOp - if (op.getNumControls() == 0 || isa(bodyUnitary)) { - rewriter.moveOpBefore(bodyUnitary, op); - rewriter.replaceOp(op, bodyUnitary->getResults()); + if (op.getNumControls() == 0 || isa(innerOp)) { + const auto numTargets = op.getNumTargets(); + auto outerTargets = op.getTargets(); + SmallVector targets; + for (auto target : innerOp->getOperands().take_front(numTargets)) { + targets.push_back( + utils::getValueFromBlockArgument(target, outerTargets)); + } + + rewriter.moveOpBefore(innerOp, op); + innerOp->setOperands(0, numTargets, targets); + rewriter.eraseOp(op); return success(); } // The remaining code explicitly handles GPhaseOp and nothing else - auto gPhaseOp = dyn_cast(bodyUnitary); + auto gPhaseOp = dyn_cast(innerOp); if (!gPhaseOp) { return failure(); } @@ -84,30 +126,59 @@ struct ReduceCtrl final : OpRewritePattern { return success(); } - // Remove the last control and replace with a single POp with the removed - // control as target - auto controls = op.getControls(); - auto target = controls.back(); - controls = controls.drop_back(); - op->setOperands(controls); + // Adjust the segment sizes of the control and target operands + const auto opSegmentsAttrName = CtrlOp::getOperandSegmentSizeAttr(); + auto segmentsAttr = + op->getAttrOfType(opSegmentsAttrName); + auto newSegments = DenseI32ArrayAttr::get( + rewriter.getContext(), {segmentsAttr[0] - 1, segmentsAttr[1] + 1}); + op->setAttr(opSegmentsAttrName, newSegments); + + // Add a block argument for the target qubit + auto arg = op.getBody()->addArgument(QubitType::get(rewriter.getContext()), + op.getLoc()); + // Replace the current GPhaseOp with a PhaseOp const OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(gPhaseOp); - rewriter.replaceOpWithNewOp(gPhaseOp, target, gPhaseOp.getTheta()); + POp::create(rewriter, gPhaseOp.getLoc(), arg, gPhaseOp.getTheta()); + rewriter.eraseOp(gPhaseOp); return success(); } }; +/** + * @brief Erase control modifiers that do not have any body unitaries. + */ +struct EraseEmptyCtrl final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CtrlOp op, + PatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 0) { + return failure(); + } + + rewriter.eraseOp(op); + return success(); + } +}; + } // namespace -UnitaryOpInterface CtrlOp::getBodyUnitary() { - // In principle, the body region should only contain exactly two operations, - // the actual unitary operation and a yield operation. However, the region may - // also contain constants and arithmetic operations, e.g., created as part of - // canonicalization. Thus, the only safe way to access the unitary operation - // is to get the second operation from the back of the region. - return cast(*(++getBody()->rbegin())); +size_t CtrlOp::getNumBodyUnitaries() { + return llvm::count_if( + *getBody(), [](Operation& op) { return isa(op); }); +} + +UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { + auto unitaries = llvm::make_filter_range( + *getBody(), [](Operation& op) { return isa(op); }); + auto it = std::next(unitaries.begin(), static_cast(i)); + if (it == unitaries.end()) { + llvm::reportFatalUsageError("Unitary index out of bounds"); + } + return cast(*it); } Value CtrlOp::getQubit(const size_t i) { @@ -116,9 +187,9 @@ Value CtrlOp::getQubit(const size_t i) { return getControls()[i]; } if (numControls <= i && i < getNumQubits()) { - return getBodyUnitary().getQubit(i - numControls); + return getTarget(i - numControls); } - llvm::reportFatalUsageError("Invalid qubit index"); + llvm::reportFatalUsageError("Qubit index out of bounds"); } Value CtrlOp::getControl(const size_t i) { @@ -129,37 +200,33 @@ Value CtrlOp::getControl(const size_t i) { } void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, - ValueRange controls, - const function_ref& bodyBuilder) { - const OpBuilder::InsertionGuard guard(odsBuilder); - odsState.addOperands(controls); - auto* region = odsState.addRegion(); - auto& block = region->emplaceBlock(); + ValueRange controls, ValueRange targets, + const function_ref& body) { + build(odsBuilder, odsState, controls, targets); + auto& block = odsState.regions.front()->emplaceBlock(); + + auto qubitType = QubitType::get(odsBuilder.getContext()); + for (size_t i = 0; i < targets.size(); ++i) { + block.addArgument(qubitType, odsState.location); + } + const OpBuilder::InsertionGuard guard(odsBuilder); odsBuilder.setInsertionPointToStart(&block); - bodyBuilder(); + body(block.getArguments()); YieldOp::create(odsBuilder, odsState.location); } LogicalResult CtrlOp::verify() { auto& block = *getBody(); - if (block.getOperations().size() < 2) { - return emitOpError("body region must have at least two operations"); + if (llvm::any_of(*getBody(), [](Operation& op) { + return isa(op); + })) { + return emitOpError("body must not contain non-unitary quantum operations"); } if (!isa(block.back())) { return emitOpError( "last operation in body region must be a yield operation"); } - auto iter = ++block.rbegin(); - if (!isa(*iter)) { - return emitOpError( - "second to last operation in body region must be a unitary operation"); - } - for (auto it = ++iter; it != block.rend(); ++it) { - if (isa(*it)) { - return emitOpError("body region may only contain a single unitary op"); - } - } SmallPtrSet uniqueQubits; for (const auto& control : getControls()) { @@ -167,11 +234,9 @@ LogicalResult CtrlOp::verify() { return emitOpError("duplicate control qubit found"); } } - auto bodyUnitary = getBodyUnitary(); - const auto numQubits = bodyUnitary.getNumQubits(); - for (size_t i = 0; i < numQubits; i++) { - if (!uniqueQubits.insert(bodyUnitary.getQubit(i)).second) { - return emitOpError("duplicate qubit found"); + for (const auto& target : getTargets()) { + if (!uniqueQubits.insert(target).second) { + return emitOpError("duplicate target qubit found"); } } @@ -180,5 +245,5 @@ LogicalResult CtrlOp::verify() { void CtrlOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } diff --git a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp index 065fe431be..b1cf6e46ae 100644 --- a/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QC/IR/Modifiers/InvOp.cpp @@ -8,17 +8,24 @@ * Licensed under the MIT License */ +#include "mlir/Dialect/QC/IR/QCDialect.h" #include "mlir/Dialect/QC/IR/QCOps.h" +#include "mlir/Dialect/Utils/Utils.h" +#include #include +#include #include #include #include +#include #include #include #include #include +#include +#include #include using namespace mlir; @@ -33,20 +40,36 @@ namespace { struct MoveCtrlOutside final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(InvOp invOp, + LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto bodyUnitary = invOp.getBodyUnitary(); - auto innerCtrlOp = dyn_cast(bodyUnitary.getOperation()); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); if (!innerCtrlOp) { return failure(); } - auto controls = innerCtrlOp.getControls(); - rewriter.replaceOpWithNewOp(invOp, controls, [&] { - InvOp::create(rewriter, invOp.getLoc(), [&] { - rewriter.clone(*innerCtrlOp.getBodyUnitary().getOperation()); - }); - }); + const auto numControls = innerCtrlOp.getNumControls(); + const auto numTargets = innerCtrlOp.getNumTargets(); + auto outerQubits = op.getQubits(); + auto controls = outerQubits.take_front(numControls); + auto targets = outerQubits.take_back(numTargets); + + rewriter.replaceOpWithNewOp( + op, controls, targets, [&](ValueRange targetArgs) { + InvOp::create( + rewriter, op.getLoc(), targetArgs, [&](ValueRange qubitArgs) { + auto* innerCtrlBody = innerCtrlOp.getBody(); + IRMapping mapping; + utils::populateMapping(mapping, *innerCtrlBody, + innerCtrlOp.getTargets(), outerQubits, + targets, qubitArgs); + for (auto& op : innerCtrlBody->without_terminator()) { + rewriter.clone(op, mapping); + } + }); + }); return success(); } @@ -62,13 +85,24 @@ struct InlineSelfAdjoint final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerOp = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); if (!isa(innerOp)) { return failure(); } + const auto numQubits = op.getNumQubits(); + auto outerQubits = op.getQubits(); + SmallVector qubits; + for (auto qubit : innerOp->getOperands().take_front(numQubits)) { + qubits.push_back(utils::getValueFromBlockArgument(qubit, outerQubits)); + } + rewriter.moveOpBefore(innerOp, op); + innerOp->setOperands(0, numQubits, qubits); rewriter.replaceOp(op, innerOp->getResults()); return success(); } @@ -85,143 +119,181 @@ struct ReplaceWithKnownGates final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerOp = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); + + auto loc = op.getLoc(); + auto outerQubits = op.getQubits(); return TypeSwitch(innerOp) .Case([&](auto g) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), g.getTheta()); + Value negTheta = arith::NegFOp::create(rewriter, loc, g.getTheta()); rewriter.replaceOpWithNewOp(op, negTheta); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto t) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(t.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto tdg) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(tdg.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto s) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(s.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto sdg) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(sdg.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto sx) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(sx.getTarget(0), outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(0)); + .Case([&](auto sxdg) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(sxdg.getTarget(0), outerQubits)); return success(); }) .Case([&](auto p) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), p.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, p.getTheta()); + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(p.getTarget(0), outerQubits), + negTheta); return success(); }) .Case([&](auto r) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), r.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta, - r.getPhi()); + auto negTheta = arith::NegFOp::create(rewriter, loc, r.getTheta()); + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(r.getTarget(0), outerQubits), + negTheta, r.getPhi()); return success(); }) .Case([&](auto rx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rx.getTarget(0), outerQubits), + negTheta); return success(); }) .Case([&](auto u) { - Value newPhi = - arith::NegFOp::create(rewriter, op.getLoc(), u.getLambda()); - Value newLambda = - arith::NegFOp::create(rewriter, op.getLoc(), u.getPhi()); - Value newTheta = - arith::NegFOp::create(rewriter, op.getLoc(), u.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), newTheta, - newPhi, newLambda); + Value newPhi = arith::NegFOp::create(rewriter, loc, u.getLambda()); + Value newLambda = arith::NegFOp::create(rewriter, loc, u.getPhi()); + Value newTheta = arith::NegFOp::create(rewriter, loc, u.getTheta()); + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(u.getTarget(0), outerQubits), + newTheta, newPhi, newLambda); return success(); }) - .Case([&](auto u) { - auto pi = arith::ConstantOp::create( - rewriter, op.getLoc(), - rewriter.getF64FloatAttr(std::numbers::pi)); - Value newPhi = - arith::NegFOp::create(rewriter, op.getLoc(), u.getLambda()); - newPhi = arith::SubFOp::create(rewriter, op.getLoc(), newPhi, pi); - Value newLambda = - arith::NegFOp::create(rewriter, op.getLoc(), u.getPhi()); - newLambda = - arith::AddFOp::create(rewriter, op.getLoc(), newLambda, pi); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), newPhi, - newLambda); + .Case([&](auto u2) { + Value pi = arith::ConstantOp::create( + rewriter, loc, rewriter.getF64FloatAttr(std::numbers::pi)); + Value newPhi = arith::NegFOp::create(rewriter, loc, u2.getLambda()); + newPhi = arith::SubFOp::create(rewriter, loc, newPhi, pi); + Value newLambda = arith::NegFOp::create(rewriter, loc, u2.getPhi()); + newLambda = arith::AddFOp::create(rewriter, loc, newLambda, pi); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(u2.getTarget(0), outerQubits), + newPhi, newLambda); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getTarget(1), - op.getTarget(0)); + .Case([&](auto dcx) { + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(dcx.getTarget(1), outerQubits), + utils::getValueFromBlockArgument(dcx.getTarget(0), outerQubits)); return success(); }) .Case([&](auto rxx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rxx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rxx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rxx.getTarget(0), outerQubits), + utils::getValueFromBlockArgument(rxx.getTarget(1), outerQubits), + negTheta); return success(); }) .Case([&](auto ry) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), ry.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, ry.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(ry.getTarget(0), outerQubits), + negTheta); return success(); }) .Case([&](auto ryy) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), ryy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, ryy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(ryy.getTarget(0), outerQubits), + utils::getValueFromBlockArgument(ryy.getTarget(1), outerQubits), + negTheta); return success(); }) .Case([&](auto rz) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rz.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rz.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rz.getTarget(0), outerQubits), + negTheta); return success(); }) .Case([&](auto rzx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rzx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rzx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rzx.getTarget(0), outerQubits), + utils::getValueFromBlockArgument(rzx.getTarget(1), outerQubits), + negTheta); return success(); }) .Case([&](auto rzz) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rzz.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rzz.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rzz.getTarget(0), outerQubits), + utils::getValueFromBlockArgument(rzz.getTarget(1), outerQubits), + negTheta); return success(); }) .Case([&](auto xxminusyy) { - Value negTheta = arith::NegFOp::create(rewriter, op.getLoc(), - xxminusyy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta, - xxminusyy.getBeta()); + Value negTheta = + arith::NegFOp::create(rewriter, loc, xxminusyy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(xxminusyy.getTarget(0), + outerQubits), + utils::getValueFromBlockArgument(xxminusyy.getTarget(1), + outerQubits), + negTheta, xxminusyy.getBeta()); return success(); }) .Case([&](auto xxplusyy) { Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), xxplusyy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getTarget(0), - op.getTarget(1), negTheta, - xxplusyy.getBeta()); + arith::NegFOp::create(rewriter, loc, xxplusyy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(xxplusyy.getTarget(0), + outerQubits), + utils::getValueFromBlockArgument(xxplusyy.getTarget(1), + outerQubits), + negTheta, xxplusyy.getBeta()); return success(); }) .Default([&](auto) { return failure(); }); @@ -233,68 +305,104 @@ struct ReplaceWithKnownGates final : OpRewritePattern { */ struct CancelNestedInv final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(InvOp invOp, + LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto innerUnitary = invOp.getBodyUnitary(); - auto innerInvOp = dyn_cast(innerUnitary.getOperation()); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerInvOp = dyn_cast(op.getBodyUnitary(0).getOperation()); if (!innerInvOp) { return failure(); } - auto* innerInnerUnitary = innerInvOp.getBodyUnitary().getOperation(); - rewriter.moveOpBefore(innerInnerUnitary, invOp); - rewriter.replaceOp(invOp, innerInnerUnitary->getResults()); + if (innerInvOp.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerInnerOp = innerInvOp.getBodyUnitary(0).getOperation(); + + const auto numQubits = op.getNumQubits(); + auto outerQubits = op.getQubits(); + auto innerQubits = innerInvOp.getQubits(); + SmallVector qubits; + for (auto qubit : innerInnerOp->getOperands().take_front(numQubits)) { + auto innerQubit = utils::getValueFromBlockArgument(qubit, innerQubits); + qubits.push_back( + utils::getValueFromBlockArgument(innerQubit, outerQubits)); + } + + rewriter.moveOpBefore(innerInnerOp, op); + innerInnerOp->setOperands(0, numQubits, qubits); + rewriter.replaceOp(op, innerInnerOp->getResults()); + return success(); + } +}; + +/** + * @brief Erase inverse modifiers that do not have any body unitaries. + */ +struct EraseEmptyInv final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(InvOp op, + PatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 0) { + return failure(); + } + rewriter.eraseOp(op); return success(); } }; } // namespace -UnitaryOpInterface InvOp::getBodyUnitary() { - // In principle, the body region should only contain exactly two operations, - // the actual unitary operation and a yield operation. However, the region may - // also contain constants and arithmetic operations, e.g., created as part of - // canonicalization. Thus, the only safe way to access the unitary operation - // is to get the second operation from the back of the region. - return cast(*(++getBody()->rbegin())); +size_t InvOp::getNumBodyUnitaries() { + return llvm::count_if( + *getBody(), [](Operation& op) { return isa(op); }); +} + +UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { + auto unitaries = llvm::make_filter_range( + *getBody(), [](Operation& op) { return isa(op); }); + auto it = std::next(unitaries.begin(), static_cast(i)); + if (it == unitaries.end()) { + llvm::reportFatalUsageError("Unitary index out of bounds"); + } + return cast(*it); } void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, - const function_ref& bodyBuilder) { - const OpBuilder::InsertionGuard guard(odsBuilder); - auto* region = odsState.addRegion(); - auto& block = region->emplaceBlock(); + ValueRange qubits, + const function_ref& body) { + build(odsBuilder, odsState, qubits); + auto& block = odsState.regions.front()->emplaceBlock(); + auto qubitType = QubitType::get(odsBuilder.getContext()); + for (size_t i = 0; i < qubits.size(); ++i) { + block.addArgument(qubitType, odsState.location); + } + + const OpBuilder::InsertionGuard guard(odsBuilder); odsBuilder.setInsertionPointToStart(&block); - bodyBuilder(); + body(block.getArguments()); YieldOp::create(odsBuilder, odsState.location); } LogicalResult InvOp::verify() { auto& block = *getBody(); - if (block.getOperations().size() < 2) { - return emitOpError("body region must have at least two operations"); + if (llvm::any_of(*getBody(), [](Operation& op) { + return isa(op); + })) { + return emitOpError("body must not contain non-unitary quantum operations"); } if (!isa(block.back())) { return emitOpError( "last operation in body region must be a yield operation"); } - auto iter = ++block.rbegin(); - if (!isa(*iter)) { - return emitOpError( - "second to last operation in body region must be a unitary operation"); - } - for (auto it = ++iter; it != block.rend(); ++it) { - if (isa(*it)) { - return emitOpError("body region may only contain a single unitary op"); - } - } return success(); } void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.add(context); + ReplaceWithKnownGates, EraseEmptyInv>(context); } diff --git a/mlir/lib/Dialect/QC/IR/QCOps.cpp b/mlir/lib/Dialect/QC/IR/QCOps.cpp index 5b93c2ebaa..6a72833861 100644 --- a/mlir/lib/Dialect/QC/IR/QCOps.cpp +++ b/mlir/lib/Dialect/QC/IR/QCOps.cpp @@ -11,6 +11,14 @@ #include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/QC/IR/QCDialect.h" // IWYU pragma: associated +#include "mlir/Dialect/Utils/Utils.h" + +#include +#include +#include +#include +#include +#include // The following headers are needed for some template instantiations. // IWYU pragma: begin_keep @@ -21,6 +29,17 @@ using namespace mlir; using namespace mlir::qc; +static ParseResult +parseTargetAliasing(OpAsmParser& parser, Region& region, + SmallVectorImpl& operands) { + return utils::parseTargetAliasing(parser, region, operands); +} + +static void printTargetAliasing(OpAsmPrinter& printer, Operation* /*op*/, + Region& region, OperandRange targetsIn) { + utils::printTargetAliasing(printer, region, targetsIn); +} + //===----------------------------------------------------------------------===// // Dialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp index ac70cd4269..7a2cf23651 100644 --- a/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp +++ b/mlir/lib/Dialect/QC/Translation/TranslateQuantumComputationToQC.cpp @@ -10,8 +10,10 @@ #include "mlir/Dialect/QC/Translation/TranslateQuantumComputationToQC.h" +#include "ir/Definitions.hpp" #include "ir/QuantumComputation.hpp" #include "ir/Register.hpp" +#include "ir/operations/CompoundOperation.hpp" #include "ir/operations/Control.hpp" #include "ir/operations/IfElseOperation.hpp" #include "ir/operations/NonUnitaryOperation.hpp" @@ -19,6 +21,7 @@ #include "ir/operations/Operation.hpp" #include "mlir/Dialect/QC/Builder/QCProgramBuilder.h" +#include #include #include #include @@ -73,6 +76,30 @@ struct TranslationState { /// Whether the translation is currently processing an IfElseOperation bool inIfElse = false; + + /// Whether the translation is currently within a control modifier + bool inCtrlOp = false; + + /// Mapping from physical qubit index to block argument + DenseMap targetArgs; + + /// Control qubits of the current CompoundOperation + DenseSet<::qc::Qubit> compoundControls; + + [[nodiscard]] Value getQubit(size_t index) const { + if (inCtrlOp) { + auto it = targetArgs.find(index); + if (it == targetArgs.end()) { + llvm::reportFatalInternalError("Qubit index out of bounds"); + } + return it->second; + } + + if (index >= qubits.size()) { + llvm::reportFatalInternalError("Qubit index out of bounds"); + } + return qubits[index]; + }; }; } // namespace @@ -222,7 +249,7 @@ static void addMeasureOp(QCProgramBuilder& builder, const auto& classics = measureOp.getClassics(); for (size_t i = 0; i < targets.size(); ++i) { - const auto& qubit = state.qubits[targets[i]]; + const auto& qubit = state.getQubit(targets[i]); const auto bitIdx = static_cast(classics[i]); const auto& [mem, localIdx] = state.bitMap[bitIdx]; const auto& bit = mem[static_cast(localIdx)]; @@ -239,13 +266,13 @@ static void addMeasureOp(QCProgramBuilder& builder, * * @param builder The QCProgramBuilder used to create operations * @param operation The reset operation to translate - * @param qubits Flat vector of qubit values indexed by physical qubit index + * @param state The translation state */ static void addResetOp(QCProgramBuilder& builder, const ::qc::Operation& operation, - const SmallVector& qubits) { + TranslationState& state) { for (const auto& target : operation.getTargets()) { - auto qubit = qubits[target]; + auto qubit = state.getQubit(target); builder.reset(qubit); } } @@ -258,18 +285,21 @@ static void addResetOp(QCProgramBuilder& builder, * the qubit values corresponding to positive controls. * * @param operation The operation containing controls - * @param qubits Flat vector of qubit values indexed by physical qubit index + * @param state The translation state * @return Vector of qubit values corresponding to positive controls */ static SmallVector getControls(const ::qc::Operation& operation, - const SmallVector& qubits) { + TranslationState& state) { SmallVector controls; for (const auto& [control, type] : operation.getControls()) { + if (state.compoundControls.contains(control)) { + continue; + } if (type == ::qc::Control::Type::Neg) { llvm::reportFatalInternalError( "Negative controls cannot be translated to QC at the moment"); } - controls.push_back(qubits[control]); + controls.push_back(state.getQubit(control)); } return controls; } @@ -286,13 +316,13 @@ static SmallVector getControls(const ::qc::Operation& operation, * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ - const auto& target = qubits[operation.getTargets()[0]]; \ - if (const auto& controls = getControls(operation, qubits); \ + TranslationState& state) { \ + const auto& target = state.getQubit(operation.getTargets()[0]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(target); \ } else { \ @@ -326,14 +356,14 @@ DEFINE_ONE_TARGET_ZERO_PARAMETER(SXdg, sxdg) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param = operation.getParameter()[0]; \ - const auto& target = qubits[operation.getTargets()[0]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target = state.getQubit(operation.getTargets()[0]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param, target); \ } else { \ @@ -358,15 +388,15 @@ DEFINE_ONE_TARGET_ONE_PARAMETER(P, p) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param1 = operation.getParameter()[0]; \ const auto& param2 = operation.getParameter()[1]; \ - const auto& target = qubits[operation.getTargets()[0]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target = state.getQubit(operation.getTargets()[0]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param1, param2, target); \ } else { \ @@ -391,16 +421,16 @@ DEFINE_ONE_TARGET_TWO_PARAMETER(U2, u2) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param1 = operation.getParameter()[0]; \ const auto& param2 = operation.getParameter()[1]; \ const auto& param3 = operation.getParameter()[2]; \ - const auto& target = qubits[operation.getTargets()[0]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target = state.getQubit(operation.getTargets()[0]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param1, param2, param3, target); \ } else { \ @@ -424,14 +454,14 @@ DEFINE_ONE_TARGET_THREE_PARAMETER(U, u) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ - const auto& target0 = qubits[operation.getTargets()[0]]; \ - const auto& target1 = qubits[operation.getTargets()[1]]; \ - if (const auto& controls = getControls(operation, qubits); \ + TranslationState& state) { \ + const auto& target0 = state.getQubit(operation.getTargets()[0]); \ + const auto& target1 = state.getQubit(operation.getTargets()[1]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(target0, target1); \ } else { \ @@ -448,14 +478,18 @@ DEFINE_TWO_TARGET_ZERO_PARAMETER(ECR, ecr) static void addISWAPdgOp(QCProgramBuilder& builder, const ::qc::Operation& operation, - const SmallVector& qubits) { - auto target0 = qubits[operation.getTargets()[0]]; - auto target1 = qubits[operation.getTargets()[1]]; - if (const auto& controls = getControls(operation, qubits); controls.empty()) { - builder.inv([&] { builder.iswap(target0, target1); }); + TranslationState& state) { + auto target0 = state.getQubit(operation.getTargets()[0]); + auto target1 = state.getQubit(operation.getTargets()[1]); + if (const auto& controls = getControls(operation, state); controls.empty()) { + builder.inv({target0, target1}, [&](ValueRange qubits) { + builder.iswap(qubits[0], qubits[1]); + }); } else { - builder.ctrl(controls, [&] { - builder.inv([&] { builder.iswap(target0, target1); }); + builder.ctrl(controls, {target0, target1}, [&](ValueRange targets) { + builder.inv(targets, [&](ValueRange qubits) { + builder.iswap(qubits[0], qubits[1]); + }); }); } } @@ -472,15 +506,15 @@ static void addISWAPdgOp(QCProgramBuilder& builder, * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param = operation.getParameter()[0]; \ - const auto& target0 = qubits[operation.getTargets()[0]]; \ - const auto& target1 = qubits[operation.getTargets()[1]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target0 = state.getQubit(operation.getTargets()[0]); \ + const auto& target1 = state.getQubit(operation.getTargets()[1]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param, target0, target1); \ } else { \ @@ -507,16 +541,16 @@ DEFINE_TWO_TARGET_ONE_PARAMETER(RZZ, rzz) * \ * @param builder The QCProgramBuilder used to create operations \ * @param operation The OP_CORE operation to translate \ - * @param qubits Flat vector of qubit values indexed by physical qubit index \ + * @param state The translation state \ */ \ static void add##OP_CORE##Op(QCProgramBuilder& builder, \ const ::qc::Operation& operation, \ - const SmallVector& qubits) { \ + TranslationState& state) { \ const auto& param1 = operation.getParameter()[0]; \ const auto& param2 = operation.getParameter()[1]; \ - const auto& target0 = qubits[operation.getTargets()[0]]; \ - const auto& target1 = qubits[operation.getTargets()[1]]; \ - if (const auto& controls = getControls(operation, qubits); \ + const auto& target0 = state.getQubit(operation.getTargets()[0]); \ + const auto& target1 = state.getQubit(operation.getTargets()[1]); \ + if (const auto& controls = getControls(operation, state); \ controls.empty()) { \ builder.OP_QC(param1, param2, target0, target1); \ } else { \ @@ -533,10 +567,10 @@ DEFINE_TWO_TARGET_TWO_PARAMETER(XXminusYY, xx_minus_yy) static void addBarrierOp(QCProgramBuilder& builder, const ::qc::Operation& operation, - const SmallVector& qubits) { + TranslationState& state) { SmallVector targets; for (const auto& targetIdx : operation.getTargets()) { - targets.push_back(qubits[targetIdx]); + targets.push_back(state.getQubit(targetIdx)); } builder.barrier(targets); } @@ -546,6 +580,72 @@ static LogicalResult translateOperation(QCProgramBuilder& builder, const ::qc::Operation& operation, TranslationState& state); +// CompoundOp + +static LogicalResult addCompoundOp(QCProgramBuilder& builder, + const ::qc::Operation& operation, + TranslationState& state) { + const auto& compoundOp = + dynamic_cast(operation); + if (const auto& controls = getControls(operation, state); controls.empty()) { + for (const auto& op : compoundOp) { + if (failed(translateOperation(builder, *op, state))) { + return failure(); + } + } + } else { + // Collect targets + DenseMap targetMap; + for (const auto& op : compoundOp) { + if (dynamic_cast(op.get()) != nullptr) { + llvm::reportFatalInternalError("Nested CompoundOperations cannot be " + "translated to QC at the moment"); + } + for (const auto& target : op->getTargets()) { + if (!targetMap.contains(target)) { + targetMap[target] = state.getQubit(target); + } + } + for (const auto& control : op->getControls()) { + if (compoundOp.getControls().contains(control)) { + continue; + } + const auto& qubit = control.qubit; + if (!targetMap.contains(qubit)) { + targetMap[qubit] = state.getQubit(qubit); + } + } + } + for (const auto& [control, _] : compoundOp.getControls()) { + state.compoundControls.insert(control); + } + SmallVector> sortedPairs(targetMap.begin(), + targetMap.end()); + llvm::sort(sortedPairs.begin(), sortedPairs.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); + SmallVector targets; + for (const auto& pair : sortedPairs) { + targets.push_back(pair.second); + } + // Build control modifier + builder.ctrl(controls, targets, [&](ValueRange targetArgs) { + state.inCtrlOp = true; + for (size_t i = 0; i < sortedPairs.size(); ++i) { + state.targetArgs[sortedPairs[i].first] = targetArgs[i]; + } + for (const auto& op : compoundOp) { + if (failed(translateOperation(builder, *op, state))) { + llvm::reportFatalInternalError("Failed to translate operation inside " + "controlled CompoundOperation"); + } + } + state.targetArgs.clear(); + state.inCtrlOp = false; + }); + } + return success(); +} + // IfElseOp static LogicalResult addIfElseOp(QCProgramBuilder& builder, @@ -622,7 +722,7 @@ static LogicalResult addIfElseOp(QCProgramBuilder& builder, #define ADD_OP_CASE(OP_CORE) \ case ::qc::OpType::OP_CORE: \ - add##OP_CORE##Op(builder, operation, qubits); \ + add##OP_CORE##Op(builder, operation, state); \ return success(); /** @@ -636,7 +736,6 @@ static LogicalResult addIfElseOp(QCProgramBuilder& builder, static LogicalResult translateOperation(QCProgramBuilder& builder, const ::qc::Operation& operation, TranslationState& state) { - const auto& qubits = state.qubits; switch (operation.getType()) { case ::qc::OpType::Measure: addMeasureOp(builder, operation, state); @@ -672,7 +771,12 @@ static LogicalResult translateOperation(QCProgramBuilder& builder, ADD_OP_CASE(XXminusYY) ADD_OP_CASE(Barrier) case ::qc::OpType::iSWAPdg: - addISWAPdgOp(builder, operation, qubits); + addISWAPdgOp(builder, operation, state); + return success(); + case ::qc::OpType::Compound: + if (failed(addCompoundOp(builder, operation, state))) { + return failure(); + } return success(); case ::qc::OpType::IfElse: if (failed(addIfElseOp(builder, operation, state))) { @@ -759,8 +863,10 @@ OwningOpRef translateQuantumComputationToQC( // Allocate result map SmallVector results(quantumComputation.getNcbits(), nullptr); - TranslationState state{ - .qubits = qubits, .bitMap = bitMap, .results = std::move(results)}; + TranslationState state{.qubits = qubits, + .bitMap = bitMap, + .results = std::move(results), + .targetArgs = DenseMap{}}; // Translate operations if (translateOperations(builder, quantumComputation, state).failed()) { diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp index 25fc88d084..d9a67fafca 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/CtrlOp.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/Utils/Utils.h" #include #include @@ -27,6 +28,7 @@ #include #include #include +#include #include using namespace mlir; @@ -42,38 +44,53 @@ struct MergeNestedCtrl final : OpRewritePattern { LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - // Require at least one positive control + // Require at least one control // Trivial case is handled by ReduceCtrl - const auto numOuterControls = op.getNumControls(); - if (numOuterControls == 0) { + if (op.getNumControls() == 0) { return failure(); } - auto bodyCtrlOp = dyn_cast(op.getBodyUnitary().getOperation()); - if (!bodyCtrlOp) { + if (op.getNumBodyUnitaries() != 1) { return failure(); } - const auto numInnerControls = bodyCtrlOp.getNumControls(); - auto outerControls = op.getControlsIn(); + auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); + if (!innerCtrlOp) { + return failure(); + } + auto outerTargets = op.getTargetsIn(); - auto newAdditionalControls = outerTargets.take_front(numInnerControls); - auto newTargets = outerTargets.drop_front(numInnerControls); - auto newControls = llvm::to_vector( - llvm::concat(outerControls, newAdditionalControls)); + auto outerControls = op.getControlsIn(); + auto innerTargets = innerCtrlOp.getTargetsIn(); + + SmallVector controls; + SmallVector targets; + llvm::append_range(controls, outerControls); + for (auto [arg, qubit] : + llvm::zip_equal(op.getBody()->getArguments(), outerTargets)) { + if (llvm::is_contained(innerTargets, arg)) { + targets.push_back(qubit); + } else { + controls.push_back(qubit); + } + } rewriter.replaceOpWithNewOp( - op, newControls, newTargets, - [&](ValueRange newTargetArgs) -> SmallVector { + op, controls, targets, + [&](ValueRange targetArgs) -> SmallVector { + auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - auto* innerBody = bodyCtrlOp.getBody(); - for (size_t i = 0; i < bodyCtrlOp.getNumTargets(); ++i) { - mapping.map(innerBody->getArgument(i), newTargetArgs[i]); + utils::populateMapping(mapping, *innerCtrlBody, innerTargets, + outerTargets, targets, targetArgs); + for (auto& op : innerCtrlBody->without_terminator()) { + rewriter.clone(op, mapping); } - - return rewriter - .clone(*bodyCtrlOp.getBodyUnitary().getOperation(), mapping) - ->getResults(); + SmallVector yields; + for (auto value : innerCtrlBody->getTerminator()->getOperands()) { + yields.push_back(mapping.lookup(value)); + } + return yields; }); + return success(); } }; @@ -87,20 +104,31 @@ struct ReduceCtrl final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CtrlOp op, PatternRewriter& rewriter) const override { - auto* bodyUnitary = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); + // Inline ops from empty control modifiers, IdOp and BarrierOp - if (op.getNumControls() == 0 || isa(bodyUnitary)) { - rewriter.moveOpBefore(bodyUnitary, op); - bodyUnitary->setOperands(0, op.getNumTargets(), op.getTargetsIn()); + if (op.getNumControls() == 0 || isa(innerOp)) { + const auto numTargets = op.getNumTargets(); + auto outerTargets = op.getTargetsIn(); + SmallVector targets; + for (auto target : innerOp->getOperands().take_front(numTargets)) { + targets.push_back( + utils::getValueFromBlockArgument(target, outerTargets)); + } + + rewriter.moveOpBefore(innerOp, op); + innerOp->setOperands(0, numTargets, targets); rewriter.replaceAllUsesWith(op.getControlsOut(), op.getControlsIn()); - rewriter.replaceAllUsesWith(op.getTargetsOut(), - bodyUnitary->getResults()); + rewriter.replaceAllUsesWith(op.getTargetsOut(), innerOp->getResults()); rewriter.eraseOp(op); return success(); } // The remaining code explicitly handles GPhaseOp and nothing else - auto gPhaseOp = dyn_cast(bodyUnitary); + auto gPhaseOp = dyn_cast(innerOp); if (!gPhaseOp) { return failure(); } @@ -136,22 +164,44 @@ struct ReduceCtrl final : OpRewritePattern { auto yieldOp = cast(op.getBody()->back()); yieldOp->setOperands(pOp->getResults()); - // erase the GPhaseOp + // Erase the GPhaseOp rewriter.eraseOp(gPhaseOp); return success(); } }; +/** + * @brief Erase control modifiers that do not have any body unitaries. + */ +struct EraseEmptyCtrl final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CtrlOp op, + PatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 0) { + return failure(); + } + + rewriter.replaceOp(op, op.getOperands()); + return success(); + } +}; + } // namespace -UnitaryOpInterface CtrlOp::getBodyUnitary() { - // In principle, the body region should only contain exactly two operations, - // the actual unitary operation and a yield operation. However, the region may - // also contain constants and arithmetic operations, e.g., created as part of - // canonicalization. Thus, the only safe way to access the unitary operation - // is to get the second operation from the back of the region. - return cast(*(++getBody()->rbegin())); +size_t CtrlOp::getNumBodyUnitaries() { + return llvm::count_if( + *getBody(), [](Operation& op) { return isa(op); }); +} + +UnitaryOpInterface CtrlOp::getBodyUnitary(const size_t i) { + auto unitaries = llvm::make_filter_range( + *getBody(), [](Operation& op) { return isa(op); }); + auto it = std::next(unitaries.begin(), static_cast(i)); + if (it == unitaries.end()) { + llvm::reportFatalUsageError("Unitary index out of bounds"); + } + return cast(*it); } Value CtrlOp::getInputQubit(const size_t i) { @@ -162,7 +212,7 @@ Value CtrlOp::getInputQubit(const size_t i) { if (numControls <= i && i < getNumQubits()) { return getTargetsIn()[i - numControls]; } - llvm::reportFatalUsageError("Invalid qubit index"); + llvm::reportFatalUsageError("Qubit index out of bounds"); } Value CtrlOp::getOutputQubit(const size_t i) { @@ -173,7 +223,7 @@ Value CtrlOp::getOutputQubit(const size_t i) { if (numControls <= i && i < getNumQubits()) { return getTargetsOut()[i - numControls]; } - llvm::reportFatalUsageError("Invalid qubit index"); + llvm::reportFatalUsageError("Qubit index out of bounds"); } Value CtrlOp::getInputTarget(const size_t i) { @@ -238,7 +288,7 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, controls, targets); auto& block = odsState.regions.front()->emplaceBlock(); - const auto qubitType = QubitType::get(odsBuilder.getContext()); + auto qubitType = QubitType::get(odsBuilder.getContext()); for (size_t i = 0; i < targets.size(); ++i) { block.addArgument(qubitType, odsState.location); } @@ -251,40 +301,33 @@ void CtrlOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult CtrlOp::verify() { auto& block = *getBody(); - if (block.getOperations().size() < 2) { - return emitOpError("body region must have at least two operations"); + if (llvm::any_of(*getBody(), [](Operation& op) { + return isa(op); + })) { + return emitOpError("body must not contain non-unitary quantum operations"); } + if (!isa(block.back())) { + return emitOpError( + "last operation in body region must be a yield operation"); + } + const auto numTargets = getNumTargets(); if (block.getArguments().size() != numTargets) { return emitOpError( "number of block arguments must match the number of targets"); } - const auto qubitType = QubitType::get(getContext()); + auto qubitType = QubitType::get(getContext()); for (size_t i = 0; i < numTargets; ++i) { if (block.getArgument(i).getType() != qubitType) { return emitOpError("block argument type at index ") << i << " does not match target type"; } } - if (!isa(block.back())) { - return emitOpError( - "last operation in body region must be a yield operation"); - } if (const auto numYieldOperands = block.back().getNumOperands(); numYieldOperands != numTargets) { return emitOpError("yield operation must yield ") << numTargets << " values, but found " << numYieldOperands; } - auto iter = ++block.rbegin(); - if (!isa(*iter)) { - return emitOpError( - "second to last operation in body region must be a unitary operation"); - } - for (auto it = ++iter; it != block.rend(); ++it) { - if (isa(*it)) { - return emitOpError("body region may only contain a single unitary op"); - } - } SmallPtrSet uniqueQubitsIn; for (const auto& control : getControlsIn()) { @@ -298,36 +341,14 @@ LogicalResult CtrlOp::verify() { } } - auto bodyUnitary = getBodyUnitary(); - if (bodyUnitary.getNumQubits() != numTargets) { - return emitOpError("body unitary must operate on exactly ") - << numTargets << " target qubits, but found " - << bodyUnitary.getNumQubits(); - } - const auto numQubits = bodyUnitary.getNumQubits(); - for (size_t i = 0; i < numQubits; i++) { - if (bodyUnitary.getInputQubit(i) != block.getArgument(i)) { - return emitOpError("body unitary must use target alias block argument ") - << i << " (and not the original target operand)"; - } - } - - // Also require yield to forward the unitary's outputs in-order. - for (size_t i = 0; i < numTargets; ++i) { - if (block.back().getOperand(i) != bodyUnitary.getOutputQubit(i)) { - return emitOpError("yield operand ") - << i << " must be the body unitary output qubit " << i; - } - } - SmallPtrSet uniqueQubitsOut; for (const auto& control : getControlsOut()) { if (!uniqueQubitsOut.insert(control).second) { return emitOpError("duplicate control qubit found"); } } - for (size_t i = 0; i < numQubits; i++) { - if (!uniqueQubitsOut.insert(bodyUnitary.getOutputQubit(i)).second) { + for (size_t i = 0; i < numTargets; i++) { + if (!uniqueQubitsOut.insert(block.back().getOperand(i)).second) { return emitOpError("duplicate qubit found"); } } @@ -337,15 +358,19 @@ LogicalResult CtrlOp::verify() { void CtrlOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } std::optional CtrlOp::getUnitaryMatrix() { - auto&& bodyUnitary = getBodyUnitary(); + if (getNumBodyUnitaries() != 1) { + return std::nullopt; + } + + auto bodyUnitary = getBodyUnitary(0); if (!bodyUnitary) { return std::nullopt; } - auto&& targetMatrix = bodyUnitary.getUnitaryMatrix(); + auto targetMatrix = bodyUnitary.getUnitaryMatrix(); if (!targetMatrix) { return std::nullopt; } diff --git a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp index d82a64f819..1f0e6d556c 100644 --- a/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Modifiers/InvOp.cpp @@ -10,8 +10,10 @@ #include "mlir/Dialect/QCO/IR/QCODialect.h" #include "mlir/Dialect/QCO/IR/QCOOps.h" +#include "mlir/Dialect/Utils/Utils.h" #include +#include #include #include #include @@ -25,6 +27,7 @@ #include #include +#include #include #include @@ -40,36 +43,42 @@ namespace { struct MoveCtrlOutside final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(InvOp invOp, + LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto bodyUnitary = invOp.getBodyUnitary(); - auto innerCtrlOp = dyn_cast(bodyUnitary.getOperation()); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerCtrlOp = dyn_cast(op.getBodyUnitary(0).getOperation()); if (!innerCtrlOp) { return failure(); } const auto numControls = innerCtrlOp.getNumControls(); const auto numTargets = innerCtrlOp.getNumTargets(); - auto invTargets = invOp.getInputQubits(); - auto controls = invTargets.take_front(numControls); - auto targets = invTargets.take_back(numTargets); + auto outerQubits = op.getQubitsIn(); + auto controls = outerQubits.take_front(numControls); + auto targets = outerQubits.take_back(numTargets); rewriter.replaceOpWithNewOp( - invOp, controls, targets, - [&](ValueRange newTargetArgs) -> SmallVector { + op, controls, targets, + [&](ValueRange targetArgs) -> SmallVector { return InvOp::create( - rewriter, invOp.getLoc(), newTargetArgs, - [&](ValueRange invArgs) -> SmallVector { + rewriter, op.getLoc(), targetArgs, + [&](ValueRange qubitArgs) -> SmallVector { + auto* innerCtrlBody = innerCtrlOp.getBody(); IRMapping mapping; - auto* innerBody = innerCtrlOp.getBody(); - for (size_t i = 0; i < innerCtrlOp.getNumTargets(); - ++i) { - mapping.map(innerBody->getArgument(i), invArgs[i]); + utils::populateMapping(mapping, *innerCtrlBody, + innerCtrlOp.getTargetsIn(), + outerQubits, targets, qubitArgs); + for (auto& op : innerCtrlBody->without_terminator()) { + rewriter.clone(op, mapping); + } + SmallVector yields; + for (auto value : + innerCtrlBody->getTerminator()->getOperands()) { + yields.push_back(mapping.lookup(value)); } - auto* cloned = rewriter.clone( - *innerCtrlOp.getBodyUnitary().getOperation(), - mapping); - return cloned->getResults(); + return yields; }) .getResults(); }); @@ -88,14 +97,24 @@ struct InlineSelfAdjoint final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerOp = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); if (!isa(innerOp)) { return failure(); } + const auto numQubits = op.getNumQubits(); + auto outerQubits = op.getInputQubits(); + SmallVector qubits; + for (auto qubit : innerOp->getOperands().take_front(numQubits)) { + qubits.push_back(utils::getValueFromBlockArgument(qubit, outerQubits)); + } + rewriter.moveOpBefore(innerOp, op); - innerOp->setOperands(0, op.getNumQubits(), op.getInputQubits()); + innerOp->setOperands(0, numQubits, qubits); rewriter.replaceOp(op, innerOp->getResults()); return success(); } @@ -112,138 +131,192 @@ struct ReplaceWithKnownGates final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerOp = op.getBodyUnitary().getOperation(); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerOp = op.getBodyUnitary(0).getOperation(); + + auto loc = op.getLoc(); + auto outerQubits = op.getInputQubits(); return TypeSwitch(innerOp) .Case([&](auto g) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), g.getTheta()); + Value negTheta = arith::NegFOp::create(rewriter, loc, g.getTheta()); rewriter.replaceOpWithNewOp(op, negTheta); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto t) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(t.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto tdg) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(tdg.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto s) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(s.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto sdg) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(sdg.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto sx) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(sx.getInputTarget(0), + outerQubits)); return success(); }) - .Case([&](auto) { - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0)); + .Case([&](auto sxdg) { + rewriter.replaceOpWithNewOp( + op, utils::getValueFromBlockArgument(sxdg.getInputTarget(0), + outerQubits)); return success(); }) .Case([&](auto p) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), p.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, p.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(p.getInputTarget(0), + outerQubits), + negTheta); return success(); }) .Case([&](auto r) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), r.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta, - r.getPhi()); + Value negTheta = arith::NegFOp::create(rewriter, loc, r.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(r.getInputTarget(0), + outerQubits), + negTheta, r.getPhi()); return success(); }) .Case([&](auto rx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rx.getInputTarget(0), + outerQubits), + negTheta); return success(); }) .Case([&](auto u) { - Value newPhi = - arith::NegFOp::create(rewriter, op.getLoc(), u.getLambda()); - Value newLambda = - arith::NegFOp::create(rewriter, op.getLoc(), u.getPhi()); - Value newTheta = - arith::NegFOp::create(rewriter, op.getLoc(), u.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), newTheta, - newPhi, newLambda); + Value newPhi = arith::NegFOp::create(rewriter, loc, u.getLambda()); + Value newLambda = arith::NegFOp::create(rewriter, loc, u.getPhi()); + Value newTheta = arith::NegFOp::create(rewriter, loc, u.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(u.getInputTarget(0), + outerQubits), + newTheta, newPhi, newLambda); return success(); }) - .Case([&](auto u) { + .Case([&](auto u2) { auto pi = arith::ConstantOp::create( - rewriter, op.getLoc(), - rewriter.getF64FloatAttr(std::numbers::pi)); - Value newPhi = - arith::NegFOp::create(rewriter, op.getLoc(), u.getLambda()); - newPhi = arith::SubFOp::create(rewriter, op.getLoc(), newPhi, pi); - Value newLambda = - arith::NegFOp::create(rewriter, op.getLoc(), u.getPhi()); - newLambda = - arith::AddFOp::create(rewriter, op.getLoc(), newLambda, pi); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), newPhi, - newLambda); + rewriter, loc, rewriter.getF64FloatAttr(std::numbers::pi)); + Value newPhi = arith::NegFOp::create(rewriter, loc, u2.getLambda()); + newPhi = arith::SubFOp::create(rewriter, loc, newPhi, pi); + Value newLambda = arith::NegFOp::create(rewriter, loc, u2.getPhi()); + newLambda = arith::AddFOp::create(rewriter, loc, newLambda, pi); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(u2.getInputTarget(0), + outerQubits), + newPhi, newLambda); return success(); }) .Case([&](auto rxx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rxx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rxx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rxx.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(rxx.getInputTarget(1), + outerQubits), + negTheta); return success(); }) .Case([&](auto ry) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), ry.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, ry.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(ry.getInputTarget(0), + outerQubits), + negTheta); return success(); }) .Case([&](auto ryy) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), ryy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, ryy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(ryy.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(ryy.getInputTarget(1), + outerQubits), + negTheta); return success(); }) .Case([&](auto rz) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rz.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rz.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rz.getInputTarget(0), + outerQubits), + negTheta); return success(); }) .Case([&](auto rzx) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rzx.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rzx.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rzx.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(rzx.getInputTarget(1), + outerQubits), + negTheta); return success(); }) .Case([&](auto rzz) { - Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), rzz.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), negTheta); + Value negTheta = arith::NegFOp::create(rewriter, loc, rzz.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(rzz.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(rzz.getInputTarget(1), + outerQubits), + negTheta); return success(); }) .Case([&](auto xxminusyy) { - Value negTheta = arith::NegFOp::create(rewriter, op.getLoc(), - xxminusyy.getTheta()); + Value negTheta = + arith::NegFOp::create(rewriter, loc, xxminusyy.getTheta()); rewriter.replaceOpWithNewOp( - op, op.getInputTarget(0), op.getInputTarget(1), negTheta, - xxminusyy.getBeta()); + op, + utils::getValueFromBlockArgument(xxminusyy.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(xxminusyy.getInputTarget(1), + outerQubits), + negTheta, xxminusyy.getBeta()); return success(); }) .Case([&](auto xxplusyy) { Value negTheta = - arith::NegFOp::create(rewriter, op.getLoc(), xxplusyy.getTheta()); - rewriter.replaceOpWithNewOp(op, op.getInputTarget(0), - op.getInputTarget(1), - negTheta, xxplusyy.getBeta()); + arith::NegFOp::create(rewriter, loc, xxplusyy.getTheta()); + rewriter.replaceOpWithNewOp( + op, + utils::getValueFromBlockArgument(xxplusyy.getInputTarget(0), + outerQubits), + utils::getValueFromBlockArgument(xxplusyy.getInputTarget(1), + outerQubits), + negTheta, xxplusyy.getBeta()); return success(); }) .Default([&](auto) { return failure(); }); @@ -258,30 +331,67 @@ struct CancelNestedInv final : OpRewritePattern { LogicalResult matchAndRewrite(InvOp op, PatternRewriter& rewriter) const override { - auto* innerUnitary = op.getBodyUnitary().getOperation(); - auto innerInvOp = dyn_cast(innerUnitary); + if (op.getNumBodyUnitaries() != 1) { + return failure(); + } + auto innerInvOp = dyn_cast(op.getBodyUnitary(0).getOperation()); if (!innerInvOp) { return failure(); } - auto* innerInnerUnitary = innerInvOp.getBodyUnitary().getOperation(); - rewriter.moveOpBefore(innerInnerUnitary, op); - innerInnerUnitary->setOperands(0, op.getNumQubits(), op.getInputQubits()); - rewriter.replaceOp(op, innerInnerUnitary->getResults()); + if (innerInvOp.getNumBodyUnitaries() != 1) { + return failure(); + } + auto* innerInnerOp = innerInvOp.getBodyUnitary(0).getOperation(); + + const auto numQubits = op.getNumQubits(); + auto outerQubits = op.getInputQubits(); + auto innerQubits = innerInvOp.getInputQubits(); + SmallVector qubits; + for (auto qubit : innerInnerOp->getOperands().take_front(numQubits)) { + auto innerQubit = utils::getValueFromBlockArgument(qubit, innerQubits); + qubits.push_back( + utils::getValueFromBlockArgument(innerQubit, outerQubits)); + } + rewriter.moveOpBefore(innerInnerOp, op); + innerInnerOp->setOperands(0, numQubits, qubits); + rewriter.replaceOp(op, innerInnerOp->getResults()); + return success(); + } +}; + +/** + * @brief Erase inverse modifiers that do not have any body unitaries. + */ +struct EraseEmptyInv final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(InvOp op, + PatternRewriter& rewriter) const override { + if (op.getNumBodyUnitaries() != 0) { + return failure(); + } + + rewriter.replaceOp(op, op.getOperands()); return success(); } }; } // namespace -UnitaryOpInterface InvOp::getBodyUnitary() { - // In principle, the body region should only contain exactly two operations, - // the actual unitary operation and a yield operation. However, the region may - // also contain constants and arithmetic operations, e.g., created as part of - // canonicalization. Thus, the only safe way to access the unitary operation - // is to get the second operation from the back of the region. - return cast(*(++getBody()->rbegin())); +size_t InvOp::getNumBodyUnitaries() { + return llvm::count_if( + *getBody(), [](Operation& op) { return isa(op); }); +} + +UnitaryOpInterface InvOp::getBodyUnitary(const size_t i) { + auto unitaries = llvm::make_filter_range( + *getBody(), [](Operation& op) { return isa(op); }); + auto it = std::next(unitaries.begin(), static_cast(i)); + if (it == unitaries.end()) { + llvm::reportFatalUsageError("Unitary index out of bounds"); + } + return cast(*it); } Value InvOp::getInputQubit(const size_t i) { @@ -322,7 +432,7 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, build(odsBuilder, odsState, qubits); auto& block = odsState.regions.front()->emplaceBlock(); - const auto qubitType = QubitType::get(odsBuilder.getContext()); + auto qubitType = QubitType::get(odsBuilder.getContext()); for (size_t i = 0; i < qubits.size(); ++i) { block.addArgument(qubitType, odsState.location); } @@ -335,60 +445,38 @@ void InvOp::build(OpBuilder& odsBuilder, OperationState& odsState, LogicalResult InvOp::verify() { auto& block = *getBody(); - if (block.getOperations().size() < 2) { - return emitOpError("body region must have at least two operations"); + if (llvm::any_of(*getBody(), [](Operation& op) { + return isa(op); + })) { + return emitOpError("body must not contain non-unitary quantum operations"); + } + if (!isa(block.back())) { + return emitOpError( + "last operation in body region must be a yield operation"); } + const auto numTargets = getNumTargets(); if (block.getArguments().size() != numTargets) { return emitOpError( "number of block arguments must match the number of targets"); } - const auto qubitType = QubitType::get(getContext()); + auto qubitType = QubitType::get(getContext()); for (size_t i = 0; i < numTargets; ++i) { if (block.getArgument(i).getType() != qubitType) { return emitOpError("block argument type at index ") << i << " does not match target type"; } } - if (!isa(block.back())) { - return emitOpError( - "last operation in body region must be a yield operation"); - } if (const auto numYieldOperands = block.back().getNumOperands(); numYieldOperands != numTargets) { return emitOpError("yield operation must yield ") << numTargets << " values, but found " << numYieldOperands; } - auto iter = ++block.rbegin(); - if (!isa(*iter)) { - return emitOpError( - "second to last operation in body region must be a unitary operation"); - } - for (auto it = ++iter; it != block.rend(); ++it) { - if (isa(*it)) { - return emitOpError("body region may only contain a single unitary op"); - } - } - auto bodyUnitary = getBodyUnitary(); - if (bodyUnitary.getNumQubits() != numTargets) { - return emitOpError("body unitary must operate on exactly ") - << numTargets << " target qubits, but found " - << bodyUnitary.getNumQubits(); - } - const auto numQubits = bodyUnitary.getNumQubits(); - for (size_t i = 0; i < numQubits; i++) { - if (bodyUnitary.getInputQubit(i) != block.getArgument(i)) { - return emitOpError("body unitary must use target alias block argument ") - << i << " (and not the original target operand)"; - } - } - - // Also require yield to forward the unitary's outputs in-order. - for (size_t i = 0; i < numTargets; ++i) { - if (block.back().getOperand(i) != bodyUnitary.getOutputQubit(i)) { - return emitOpError("yield operand ") - << i << " must be the body unitary output qubit " << i; + SmallPtrSet uniqueQubitsIn; + for (const auto& target : getQubitsIn()) { + if (!uniqueQubitsIn.insert(target).second) { + return emitOpError("duplicate qubit found"); } } @@ -398,15 +486,19 @@ LogicalResult InvOp::verify() { void InvOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { results.add(context); + CancelNestedInv, EraseEmptyInv>(context); } std::optional InvOp::getUnitaryMatrix() { - auto&& bodyUnitary = getBodyUnitary(); + if (getNumBodyUnitaries() != 1) { + return std::nullopt; + } + + auto bodyUnitary = getBodyUnitary(0); if (!bodyUnitary) { return std::nullopt; } - auto&& targetMatrix = bodyUnitary.getUnitaryMatrix(); + auto targetMatrix = bodyUnitary.getUnitaryMatrix(); if (!targetMatrix) { return std::nullopt; } diff --git a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp index a3ce816081..f1cb23a849 100644 --- a/mlir/lib/Dialect/QCO/IR/QCOOps.cpp +++ b/mlir/lib/Dialect/QCO/IR/QCOOps.cpp @@ -11,6 +11,7 @@ #include "mlir/Dialect/QCO/IR/QCOOps.h" #include "mlir/Dialect/QCO/IR/QCODialect.h" // IWYU pragma: associated +#include "mlir/Dialect/Utils/Utils.h" #include #include @@ -37,57 +38,12 @@ using namespace mlir::qco; static ParseResult parseTargetAliasing(OpAsmParser& parser, Region& region, SmallVectorImpl& operands) { - // 1. Parse the opening parenthesis - if (parser.parseLParen()) { - return failure(); - } - - // Temporary storage for block arguments we are about to create - SmallVector blockArgs; - - // 2. Prepare to parse the list - if (failed(parser.parseOptionalRParen())) { - do { - OpAsmParser::Argument newArg; // The "new" variable name - OpAsmParser::UnresolvedOperand oldOperand; // The "old" input variable - - // Parse "%new" - if (parser.parseArgument(newArg)) { - return failure(); - } - - // Parse "=" - if (parser.parseEqual()) { - return failure(); - } - - // Parse "%old" - if (parser.parseOperand(oldOperand)) { - return failure(); - } - operands.push_back(oldOperand); - - // Hard-code QubitType since targets in qco.ctrl are always qubits. - // This avoids double-binding type($targets_in) in the assembly format - // while keeping the parser simple and the assembly format clean. - newArg.type = QubitType::get(parser.getBuilder().getContext()); - blockArgs.push_back(newArg); - - } while (succeeded(parser.parseOptionalComma())); - - if (parser.parseRParen()) { - return failure(); - } - } - - // 4. Parse the Region - // We explicitly pass the blockArgs we just parsed so they become the entry - // block! - if (parser.parseRegion(region, blockArgs)) { - return failure(); - } + return utils::parseTargetAliasing(parser, region, operands); +} - return success(); +static void printTargetAliasing(OpAsmPrinter& printer, Operation* /*op*/, + Region& region, OperandRange targetsIn) { + utils::printTargetAliasing(printer, region, targetsIn); } ParseResult IfOp::parse(::mlir::OpAsmParser& parser, @@ -213,30 +169,6 @@ void IfOp::print(OpAsmPrinter& p) { p.printOptionalAttrDict((*this)->getAttrs()); } -static void printTargetAliasing(OpAsmPrinter& printer, Operation* /*op*/, - Region& region, OperandRange targetsIn) { - printer << "("; - if (region.empty()) { - printer << ") "; - printer.printRegion(region, false); - return; - } - Block& entryBlock = region.front(); - - const auto numTargets = targetsIn.size(); - for (unsigned i = 0; i < numTargets; ++i) { - if (i > 0) { - printer << ", "; - } - printer.printOperand(entryBlock.getArgument(i)); - printer << " = "; - printer.printOperand(targetsIn[i]); - } - printer << ") "; - - printer.printRegion(region, false); -} - //===----------------------------------------------------------------------===// // Dialect //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp b/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp index 0ca22726a1..3d874533b6 100644 --- a/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp +++ b/mlir/lib/Dialect/QCO/Transforms/Optimizations/HadamardLifting.cpp @@ -162,7 +162,8 @@ struct LiftHadamardAboveCNOTPattern final : OpRewritePattern { if (!cnotGate) { return failure(); } - if (!isa(cnotGate.getBodyUnitary()) || + if (cnotGate.getNumBodyUnitaries() != 1 || + !isa(cnotGate.getBodyUnitary(0)) || cnotGate.getOutputTarget(0) != inQubitHadamard) { return failure(); } diff --git a/mlir/lib/Support/IRVerification.cpp b/mlir/lib/Support/IRVerification.cpp index eaac426f0a..0221464606 100644 --- a/mlir/lib/Support/IRVerification.cpp +++ b/mlir/lib/Support/IRVerification.cpp @@ -10,6 +10,7 @@ #include "mlir/Support/IRVerification.h" +#include "mlir/Dialect/QC/IR/QCOps.h" #include "mlir/Dialect/QTensor/IR/QTensorUtils.h" #include @@ -33,11 +34,13 @@ #include #include #include +#include #include #include #include #include +#include #include using namespace mlir; @@ -469,7 +472,6 @@ static bool areOperationsEquivalent(Operation* lhs, Operation* rhs, if (!rhsConst) { return false; } - if (!areConstantAttributesEquivalent(lhsConst.getValue(), rhsConst.getValue())) { return false; @@ -513,17 +515,37 @@ static bool areOperationsEquivalent(Operation* lhs, Operation* rhs, return false; } - // Check operands according to value mapping - for (auto [lhsOperand, rhsOperand] : - llvm::zip(lhs->getOperands(), rhs->getOperands())) { - if (auto it = valueMap.find(lhsOperand); it != valueMap.end()) { - // Value already mapped, must match - if (it->second != rhsOperand) { + ValueRange lhsOperands; + ValueRange rhsOperands; + if (auto lhsCtrl = dyn_cast(lhs)) { + auto rhsCtrl = dyn_cast(rhs); + if (!rhsCtrl) { + return false; + } + if (lhsCtrl.getTargets().size() != rhsCtrl.getTargets().size()) { + return false; + } + for (auto [lhsTarget, lhsArg] : + llvm::zip(lhsCtrl.getTargets(), lhsCtrl.getBody()->getArguments())) { + auto rhsTarget = valueMap[lhsTarget]; + if (!llvm::is_contained(rhsCtrl.getTargets(), rhsTarget)) { return false; } - } else { - // Establish new mapping - valueMap[lhsOperand] = rhsOperand; + auto it = llvm::find(rhsCtrl.getTargets(), rhsTarget); + auto index = std::distance(rhsCtrl.getTargets().begin(), it); + valueMap[lhsArg] = rhsCtrl.getBody()->getArgument(index); + } + lhsOperands = lhsCtrl.getControls(); + rhsOperands = rhsCtrl.getControls(); + } else { + lhsOperands = lhs->getOperands(); + rhsOperands = rhs->getOperands(); + } + + // Check operands according to value mapping + for (auto [lhsOperand, rhsOperand] : llvm::zip(lhsOperands, rhsOperands)) { + if (!areValuesEquivalent(lhsOperand, rhsOperand, valueMap)) { + return false; } } @@ -725,7 +747,9 @@ static bool areBlocksEquivalent(Block& lhs, Block& rhs, if (lhsArg.getType() != rhsArg.getType()) { return false; } - valueMap[lhsArg] = rhsArg; + if (!valueMap.contains(lhsArg)) { + valueMap[lhsArg] = rhsArg; + } } // Collect all operations diff --git a/mlir/unittests/Compiler/test_compiler_pipeline.cpp b/mlir/unittests/Compiler/test_compiler_pipeline.cpp index 44618eedae..d9072d6630 100644 --- a/mlir/unittests/Compiler/test_compiler_pipeline.cpp +++ b/mlir/unittests/Compiler/test_compiler_pipeline.cpp @@ -686,6 +686,9 @@ INSTANTIATE_TEST_SUITE_P( "MultipleControlledXXMinusYY", MQT_NAMED_BUILDER(qc::multipleControlledXxMinusYY), nullptr, MQT_NAMED_BUILDER(mlir::qc::multipleControlledXxMinusYY), - MQT_NAMED_BUILDER(mlir::qir::multipleControlledXxMinusYY)})); + MQT_NAMED_BUILDER(mlir::qir::multipleControlledXxMinusYY)}, + CompilerPipelineTestCase{"CtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), + nullptr, MQT_NAMED_BUILDER(mlir::qc::ctrlTwo), + MQT_NAMED_BUILDER(mlir::qir::ctrlTwo)})); } // namespace mqt::test::compiler diff --git a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp index 4bd2b24615..7dda9ccfda 100644 --- a/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp +++ b/mlir/unittests/Conversion/QCOToQC/test_qco_to_qc.cpp @@ -144,6 +144,20 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qc::allocDeallocPair)})); /// @} +/// \name QCOToQC/Modifiers/CtrlOp.cpp +/// @{ +INSTANTIATE_TEST_SUITE_P( + QCOCtrlOpTest, QCOToQCTest, + testing::Values(QCOToQCTestCase{"CtrlTwo", MQT_NAMED_BUILDER(qco::ctrlTwo), + MQT_NAMED_BUILDER(qc::ctrlTwo)}, + QCOToQCTestCase{"CtrlTwoMixed", + MQT_NAMED_BUILDER(qco::ctrlTwoMixed), + MQT_NAMED_BUILDER(qc::ctrlTwoMixed)}, + QCOToQCTestCase{"CtrlInvTwo", + MQT_NAMED_BUILDER(qco::ctrlInvTwo), + MQT_NAMED_BUILDER(qc::ctrlInvTwo)})); +/// @} + /// \name QCOToQC/Modifiers/InvOp.cpp /// @{ INSTANTIATE_TEST_SUITE_P( @@ -160,7 +174,11 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qc::dcx)}, QCOToQCTestCase{"InverseMultipleControlledDCX", MQT_NAMED_BUILDER(qco::inverseMultipleControlledDcx), - MQT_NAMED_BUILDER(qc::multipleControlledDcx)})); + MQT_NAMED_BUILDER(qc::multipleControlledDcx)}, + QCOToQCTestCase{"InvTwo", MQT_NAMED_BUILDER(qco::invTwo), + MQT_NAMED_BUILDER(qc::invTwo)}, + QCOToQCTestCase{"InvCtrlTwo", MQT_NAMED_BUILDER(qco::invCtrlTwo), + MQT_NAMED_BUILDER(qc::ctrlInvTwo)})); /// @} /// \name QCOToQC/Operations/StandardGates/BarrierOp.cpp diff --git a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp index 3f0df25542..00b2c7fe7b 100644 --- a/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp +++ b/mlir/unittests/Conversion/QCToQCO/test_qc_to_qco.cpp @@ -143,6 +143,20 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qco::allocSinkPair)})); /// @} +/// \name QCToQCO/Modifiers/CtrlOp.cpp +/// @{ +INSTANTIATE_TEST_SUITE_P( + QCCtrlOpTest, QCToQCOTest, + testing::Values(QCToQCOTestCase{"CtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), + MQT_NAMED_BUILDER(qco::ctrlTwo)}, + QCToQCOTestCase{"CtrlTwoMixed", + MQT_NAMED_BUILDER(qc::ctrlTwoMixed), + MQT_NAMED_BUILDER(qco::ctrlTwoMixed)}, + QCToQCOTestCase{"CtrlInvTwo", + MQT_NAMED_BUILDER(qc::ctrlInvTwo), + MQT_NAMED_BUILDER(qco::ctrlInvTwo)})); +/// @} + /// \name QCToQCO/Modifiers/InvOp.cpp /// @{ INSTANTIATE_TEST_SUITE_P( @@ -151,10 +165,13 @@ INSTANTIATE_TEST_SUITE_P( // iSWAP cannot be inverted with current canonicalization QCToQCOTestCase{"InverseiSWAP", MQT_NAMED_BUILDER(qc::inverseIswap), MQT_NAMED_BUILDER(qco::inverseIswap)}, - QCToQCOTestCase{ - "InverseMultipleControllediSWAP", - MQT_NAMED_BUILDER(qc::inverseMultipleControlledIswap), - MQT_NAMED_BUILDER(qco::inverseMultipleControlledIswap)})); + QCToQCOTestCase{"InverseMultipleControllediSWAP", + MQT_NAMED_BUILDER(qc::inverseMultipleControlledIswap), + MQT_NAMED_BUILDER(qco::inverseMultipleControlledIswap)}, + QCToQCOTestCase{"InvTwo", MQT_NAMED_BUILDER(qc::invTwo), + MQT_NAMED_BUILDER(qco::invTwo)}, + QCToQCOTestCase{"InvCtrlTwo", MQT_NAMED_BUILDER(qc::invCtrlTwo), + MQT_NAMED_BUILDER(qco::ctrlInvTwo)})); /// @} /// \name QCToQCO/Operations/StandardGates/BarrierOp.cpp diff --git a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp index cd8bcf6073..6de8bf8483 100644 --- a/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp +++ b/mlir/unittests/Conversion/QCToQIR/test_qc_to_qir.cpp @@ -649,3 +649,11 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(qc::allocDeallocPair), MQT_NAMED_BUILDER(qir::emptyQIR)})); /// @} + +/// \name QCToQIR/Modifiers/CtrlOp.cpp +/// @{ +INSTANTIATE_TEST_SUITE_P(QCToQIRCtrlOpTest, QCToQIRTest, + testing::Values(QCToQIRTestCase{ + "NestedCtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), + MQT_NAMED_BUILDER(qir::ctrlTwo)})); +/// @} diff --git a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp index 97e4627363..4d0f56912b 100644 --- a/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp +++ b/mlir/unittests/Dialect/QC/IR/test_qc_ir.cpp @@ -117,26 +117,31 @@ TEST_F(QCTest, BuilderRejectsMixedStaticAndDynamicQubitAllocationModes) { /// @{ INSTANTIATE_TEST_SUITE_P( QCCtrlOpTest, QCTest, - testing::Values(QCTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), - MQT_NAMED_BUILDER(rxx)}, - QCTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), - MQT_NAMED_BUILDER(multipleControlledRxx)}, - QCTestCase{"TripleNestedCtrl", - MQT_NAMED_BUILDER(tripleNestedCtrl), - MQT_NAMED_BUILDER(tripleControlledRxx)}, - QCTestCase{"CtrlInvSandwich", - MQT_NAMED_BUILDER(ctrlInvSandwich), - MQT_NAMED_BUILDER(multipleControlledRxx)}, - QCTestCase{"DoubleNestedCtrlTwoQubits", - MQT_NAMED_BUILDER(doubleNestedCtrlTwoQubits), - MQT_NAMED_BUILDER(fourControlledRxx)})); + testing::Values( + QCTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), + MQT_NAMED_BUILDER(rxx)}, + QCTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), + MQT_NAMED_BUILDER(rxx)}, + QCTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), + MQT_NAMED_BUILDER(multipleControlledRxx)}, + QCTestCase{"TripleNestedCtrl", MQT_NAMED_BUILDER(tripleNestedCtrl), + MQT_NAMED_BUILDER(tripleControlledRxx)}, + QCTestCase{"CtrlInvSandwich", MQT_NAMED_BUILDER(ctrlInvSandwich), + MQT_NAMED_BUILDER(multipleControlledRxx)}, + QCTestCase{"DoubleNestedCtrlTwoQubits", + MQT_NAMED_BUILDER(doubleNestedCtrlTwoQubits), + MQT_NAMED_BUILDER(fourControlledRxx)}, + QCTestCase{"NestedCtrlTwo", MQT_NAMED_BUILDER(nestedCtrlTwo), + MQT_NAMED_BUILDER(ctrlTwo)})); /// @} /// \name QC/Modifiers/InvOp.cpp /// @{ INSTANTIATE_TEST_SUITE_P( QCInvOpTest, QCTest, - testing::Values(QCTestCase{"NestedInv", MQT_NAMED_BUILDER(nestedInv), + testing::Values(QCTestCase{"EmptyInv", MQT_NAMED_BUILDER(emptyInv), + MQT_NAMED_BUILDER(rxx)}, + QCTestCase{"NestedInv", MQT_NAMED_BUILDER(nestedInv), MQT_NAMED_BUILDER(rxx)}, QCTestCase{"TripleNestedInv", MQT_NAMED_BUILDER(tripleNestedInv), diff --git a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp index 0e5d53783b..b47c9f97a7 100644 --- a/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp +++ b/mlir/unittests/Dialect/QC/Translation/test_quantum_computation_translation.cpp @@ -418,9 +418,18 @@ INSTANTIATE_TEST_SUITE_P( "BarrierMultipleQubits", MQT_NAMED_BUILDER(qc::barrierMultipleQubits), MQT_NAMED_BUILDER(mlir::qc::barrierMultipleQubits)}, + QuantumComputationTranslationTestCase{ + "CtrlTwo", MQT_NAMED_BUILDER(qc::ctrlTwo), + MQT_NAMED_BUILDER(mlir::qc::ctrlTwo)}, + QuantumComputationTranslationTestCase{ + "CtrlTwoMixed", MQT_NAMED_BUILDER(qc::ctrlTwoMixed), + MQT_NAMED_BUILDER(mlir::qc::ctrlTwoMixed)}, QuantumComputationTranslationTestCase{ "SimpleIf", MQT_NAMED_BUILDER(qc::simpleIf), MQT_NAMED_BUILDER(mlir::qc::simpleIf)}, + QuantumComputationTranslationTestCase{ + "IfTwoQubits", MQT_NAMED_BUILDER(qc::ifTwoQubits), + MQT_NAMED_BUILDER(mlir::qc::ifTwoQubits)}, QuantumComputationTranslationTestCase{ "IfElse", MQT_NAMED_BUILDER(qc::ifElse), MQT_NAMED_BUILDER(mlir::qc::ifElse)})); diff --git a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp index 413f29336d..883c7d32d2 100644 --- a/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp +++ b/mlir/unittests/Dialect/QCO/IR/test_qco_ir.cpp @@ -221,33 +221,40 @@ INSTANTIATE_TEST_SUITE_P( /// @{ INSTANTIATE_TEST_SUITE_P( QCOCtrlOpTest, QCOTest, - testing::Values(QCOTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), - MQT_NAMED_BUILDER(rxx)}, - QCOTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), - MQT_NAMED_BUILDER(multipleControlledRxx)}, - QCOTestCase{"TripleNestedCtrl", - MQT_NAMED_BUILDER(tripleNestedCtrl), - MQT_NAMED_BUILDER(tripleControlledRxx)}, - QCOTestCase{"CtrlInvSandwich", - MQT_NAMED_BUILDER(ctrlInvSandwich), - MQT_NAMED_BUILDER(multipleControlledRxx)}, - QCOTestCase{"DoubleNestedCtrlTwoQubits", - MQT_NAMED_BUILDER(doubleNestedCtrlTwoQubits), - MQT_NAMED_BUILDER(fourControlledRxx)})); + testing::Values( + QCOTestCase{"TrivialCtrl", MQT_NAMED_BUILDER(trivialCtrl), + MQT_NAMED_BUILDER(rxx)}, + QCOTestCase{"EmptyCtrl", MQT_NAMED_BUILDER(emptyCtrl), + MQT_NAMED_BUILDER(rxx)}, + QCOTestCase{"NestedCtrl", MQT_NAMED_BUILDER(nestedCtrl), + MQT_NAMED_BUILDER(multipleControlledRxx)}, + QCOTestCase{"TripleNestedCtrl", MQT_NAMED_BUILDER(tripleNestedCtrl), + MQT_NAMED_BUILDER(tripleControlledRxx)}, + QCOTestCase{"CtrlInvSandwich", MQT_NAMED_BUILDER(ctrlInvSandwich), + MQT_NAMED_BUILDER(multipleControlledRxx)}, + QCOTestCase{"DoubleNestedCtrlTwoQubits", + MQT_NAMED_BUILDER(doubleNestedCtrlTwoQubits), + MQT_NAMED_BUILDER(fourControlledRxx)}, + QCOTestCase{"NestedCtrlTwo", MQT_NAMED_BUILDER(nestedCtrlTwo), + MQT_NAMED_BUILDER(ctrlTwo)})); /// @} /// \name QCO/Modifiers/InvOp.cpp /// @{ INSTANTIATE_TEST_SUITE_P( QCOInvOpTest, QCOTest, - testing::Values(QCOTestCase{"NestedInv", MQT_NAMED_BUILDER(nestedInv), + testing::Values(QCOTestCase{"EmptyInv", MQT_NAMED_BUILDER(emptyInv), + MQT_NAMED_BUILDER(rxx)}, + QCOTestCase{"NestedInv", MQT_NAMED_BUILDER(nestedInv), MQT_NAMED_BUILDER(rxx)}, QCOTestCase{"TripleNestedInv", MQT_NAMED_BUILDER(tripleNestedInv), MQT_NAMED_BUILDER(rxx)}, QCOTestCase{"InvControlSandwich", MQT_NAMED_BUILDER(invCtrlSandwich), - MQT_NAMED_BUILDER(singleControlledRxx)})); + MQT_NAMED_BUILDER(singleControlledRxx)}, + QCOTestCase{"InvCtrlTwo", MQT_NAMED_BUILDER(invCtrlTwo), + MQT_NAMED_BUILDER(ctrlInvTwo)})); /// @} /// \name QCO/Operations/StandardGates/BarrierOp.cpp @@ -959,6 +966,10 @@ INSTANTIATE_TEST_SUITE_P( MQT_NAMED_BUILDER(inverseMultipleControlledX), MQT_NAMED_BUILDER(multipleControlledX)}, QCOTestCase{"TwoX", MQT_NAMED_BUILDER(twoX), + MQT_NAMED_BUILDER(emptyQCO)}, + QCOTestCase{"ControlledTwoX", MQT_NAMED_BUILDER(controlledTwoX), + MQT_NAMED_BUILDER(emptyQCO)}, + QCOTestCase{"InverseTwoX", MQT_NAMED_BUILDER(inverseTwoX), MQT_NAMED_BUILDER(emptyQCO)})); /// @} diff --git a/mlir/unittests/programs/qc_programs.cpp b/mlir/unittests/programs/qc_programs.cpp index 373452252a..232b8cbdee 100644 --- a/mlir/unittests/programs/qc_programs.cpp +++ b/mlir/unittests/programs/qc_programs.cpp @@ -62,7 +62,7 @@ void staticQubitsWithCtrl(QCProgramBuilder& b) { void staticQubitsWithInv(QCProgramBuilder& b) { auto q0 = b.staticQubit(0); - b.inv([&]() { b.t(q0); }); + b.inv({q0}, [&](ValueRange qubits) { b.t(qubits[0]); }); } void staticQubitsWithDuplicates(QCProgramBuilder& b) { @@ -75,7 +75,7 @@ void staticQubitsWithDuplicates(QCProgramBuilder& b) { b.p(std::numbers::pi / 2., q1a); b.rzz(0.123, q0b, q1b); b.cx(q0b, q1b); - b.inv([&]() { b.t(q0a); }); + b.inv({q0a}, [&](ValueRange qubits) { b.t(qubits[0]); }); } void staticQubitsCanonical(QCProgramBuilder& b) { @@ -86,7 +86,7 @@ void staticQubitsCanonical(QCProgramBuilder& b) { b.p(std::numbers::pi / 2., q1); b.rzz(0.123, q0, q1); b.cx(q0, q1); - b.inv([&]() { b.t(q0); }); + b.inv({q0}, [&](ValueRange qubits) { b.t(qubits[0]); }); } void allocDeallocPair(QCProgramBuilder& b) { @@ -194,7 +194,8 @@ void multipleControlledGlobalPhase(QCProgramBuilder& b) { void nestedControlledGlobalPhase(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.ctrl(q[0], [&] { b.cgphase(0.123, q[1]); }); + b.ctrl(q[0], {q[1]}, + [&](ValueRange targets) { b.cgphase(0.123, targets[0]); }); } void trivialControlledGlobalPhase(QCProgramBuilder& b) { @@ -203,12 +204,13 @@ void trivialControlledGlobalPhase(QCProgramBuilder& b) { } void inverseGlobalPhase(QCProgramBuilder& b) { - b.inv([&]() { b.gphase(-0.123); }); + b.inv({}, [&](ValueRange /*qubits*/) { b.gphase(-0.123); }); } void inverseMultipleControlledGlobalPhase(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcgphase(-0.123, {q[0], q[1], q[2]}); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcgphase(-0.123, qubits); }); } void identity(QCProgramBuilder& b) { @@ -228,7 +230,8 @@ void multipleControlledIdentity(QCProgramBuilder& b) { void nestedControlledIdentity(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.ctrl(q[2], [&] { b.cid(q[1], q[0]); }); + b.ctrl(q[2], {q[0], q[1]}, + [&](ValueRange targets) { b.cid(targets[1], targets[0]); }); } void trivialControlledIdentity(QCProgramBuilder& b) { @@ -238,12 +241,13 @@ void trivialControlledIdentity(QCProgramBuilder& b) { void inverseIdentity(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.id(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.id(qubits[0]); }); } void inverseMultipleControlledIdentity(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcid({q[2], q[1]}, q[0]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcid({qubits[0], qubits[1]}, qubits[2]); }); } void x(QCProgramBuilder& b) { @@ -263,7 +267,8 @@ void multipleControlledX(QCProgramBuilder& b) { void nestedControlledX(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cx(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cx(targets[0], targets[1]); }); } void trivialControlledX(QCProgramBuilder& b) { @@ -282,12 +287,13 @@ void repeatedControlledX(QCProgramBuilder& b) { void inverseX(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.x(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.x(qubits[0]); }); } void inverseMultipleControlledX(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcx({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcx({qubits[0], qubits[1]}, qubits[2]); }); } void y(QCProgramBuilder& b) { @@ -307,7 +313,8 @@ void multipleControlledY(QCProgramBuilder& b) { void nestedControlledY(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cy(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cy(targets[0], targets[1]); }); } void trivialControlledY(QCProgramBuilder& b) { @@ -317,12 +324,13 @@ void trivialControlledY(QCProgramBuilder& b) { void inverseY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.y(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.y(qubits[0]); }); } void inverseMultipleControlledY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcy({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcy({qubits[0], qubits[1]}, qubits[2]); }); } void z(QCProgramBuilder& b) { @@ -342,7 +350,8 @@ void multipleControlledZ(QCProgramBuilder& b) { void nestedControlledZ(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cz(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cz(targets[0], targets[1]); }); } void trivialControlledZ(QCProgramBuilder& b) { @@ -352,12 +361,13 @@ void trivialControlledZ(QCProgramBuilder& b) { void inverseZ(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.z(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.z(qubits[0]); }); } void inverseMultipleControlledZ(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcz({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcz({qubits[0], qubits[1]}, qubits[2]); }); } void h(QCProgramBuilder& b) { @@ -377,7 +387,8 @@ void multipleControlledH(QCProgramBuilder& b) { void nestedControlledH(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.ch(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.ch(targets[0], targets[1]); }); } void trivialControlledH(QCProgramBuilder& b) { @@ -387,12 +398,13 @@ void trivialControlledH(QCProgramBuilder& b) { void inverseH(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.h(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.h(qubits[0]); }); } void inverseMultipleControlledH(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mch({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mch({qubits[0], qubits[1]}, qubits[2]); }); } void hWithoutRegister(QCProgramBuilder& b) { @@ -417,7 +429,8 @@ void multipleControlledS(QCProgramBuilder& b) { void nestedControlledS(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cs(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cs(targets[0], targets[1]); }); } void trivialControlledS(QCProgramBuilder& b) { @@ -427,12 +440,13 @@ void trivialControlledS(QCProgramBuilder& b) { void inverseS(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.s(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.s(qubits[0]); }); } void inverseMultipleControlledS(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcs({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcs({qubits[0], qubits[1]}, qubits[2]); }); } void sdg(QCProgramBuilder& b) { @@ -452,7 +466,8 @@ void multipleControlledSdg(QCProgramBuilder& b) { void nestedControlledSdg(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.csdg(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.csdg(targets[0], targets[1]); }); } void trivialControlledSdg(QCProgramBuilder& b) { @@ -462,12 +477,13 @@ void trivialControlledSdg(QCProgramBuilder& b) { void inverseSdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.sdg(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.sdg(qubits[0]); }); } void inverseMultipleControlledSdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcsdg({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcsdg({qubits[0], qubits[1]}, qubits[2]); }); } void t_(QCProgramBuilder& b) { @@ -487,7 +503,8 @@ void multipleControlledT(QCProgramBuilder& b) { void nestedControlledT(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.ct(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.ct(targets[0], targets[1]); }); } void trivialControlledT(QCProgramBuilder& b) { @@ -497,12 +514,13 @@ void trivialControlledT(QCProgramBuilder& b) { void inverseT(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.t(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.t(qubits[0]); }); } void inverseMultipleControlledT(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mct({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mct({qubits[0], qubits[1]}, qubits[2]); }); } void tdg(QCProgramBuilder& b) { @@ -522,7 +540,8 @@ void multipleControlledTdg(QCProgramBuilder& b) { void nestedControlledTdg(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.ctdg(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.ctdg(targets[0], targets[1]); }); } void trivialControlledTdg(QCProgramBuilder& b) { @@ -532,12 +551,13 @@ void trivialControlledTdg(QCProgramBuilder& b) { void inverseTdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.tdg(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.tdg(qubits[0]); }); } void inverseMultipleControlledTdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mctdg({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mctdg({qubits[0], qubits[1]}, qubits[2]); }); } void sx(QCProgramBuilder& b) { @@ -557,7 +577,8 @@ void multipleControlledSx(QCProgramBuilder& b) { void nestedControlledSx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.csx(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.csx(targets[0], targets[1]); }); } void trivialControlledSx(QCProgramBuilder& b) { @@ -567,12 +588,13 @@ void trivialControlledSx(QCProgramBuilder& b) { void inverseSx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.sx(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.sx(qubits[0]); }); } void inverseMultipleControlledSx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcsx({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, + [&](ValueRange qubits) { b.mcsx({qubits[0], qubits[1]}, qubits[2]); }); } void sxdg(QCProgramBuilder& b) { @@ -592,7 +614,8 @@ void multipleControlledSxdg(QCProgramBuilder& b) { void nestedControlledSxdg(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.csxdg(reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.csxdg(targets[0], targets[1]); }); } void trivialControlledSxdg(QCProgramBuilder& b) { @@ -602,12 +625,14 @@ void trivialControlledSxdg(QCProgramBuilder& b) { void inverseSxdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.sxdg(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.sxdg(qubits[0]); }); } void inverseMultipleControlledSxdg(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcsxdg({q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcsxdg({qubits[0], qubits[1]}, qubits[2]); + }); } void rx(QCProgramBuilder& b) { @@ -627,7 +652,8 @@ void multipleControlledRx(QCProgramBuilder& b) { void nestedControlledRx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.crx(0.123, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.crx(0.123, targets[0], targets[1]); }); } void trivialControlledRx(QCProgramBuilder& b) { @@ -637,12 +663,14 @@ void trivialControlledRx(QCProgramBuilder& b) { void inverseRx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.rx(-0.123, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.rx(-0.123, qubits[0]); }); } void inverseMultipleControlledRx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcrx(-0.123, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcrx(-0.123, {qubits[0], qubits[1]}, qubits[2]); + }); } void ry(QCProgramBuilder& b) { @@ -662,7 +690,8 @@ void multipleControlledRy(QCProgramBuilder& b) { void nestedControlledRy(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cry(0.456, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cry(0.456, targets[0], targets[1]); }); } void trivialControlledRy(QCProgramBuilder& b) { @@ -672,12 +701,14 @@ void trivialControlledRy(QCProgramBuilder& b) { void inverseRy(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.ry(-0.456, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.ry(-0.456, qubits[0]); }); } void inverseMultipleControlledRy(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcry(-0.456, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcry(-0.456, {qubits[0], qubits[1]}, qubits[2]); + }); } void rz(QCProgramBuilder& b) { @@ -697,7 +728,8 @@ void multipleControlledRz(QCProgramBuilder& b) { void nestedControlledRz(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.crz(0.789, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.crz(0.789, targets[0], targets[1]); }); } void trivialControlledRz(QCProgramBuilder& b) { @@ -707,12 +739,14 @@ void trivialControlledRz(QCProgramBuilder& b) { void inverseRz(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.rz(-0.789, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.rz(-0.789, qubits[0]); }); } void inverseMultipleControlledRz(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcrz(-0.789, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcrz(-0.789, {qubits[0], qubits[1]}, qubits[2]); + }); } void p(QCProgramBuilder& b) { @@ -732,7 +766,8 @@ void multipleControlledP(QCProgramBuilder& b) { void nestedControlledP(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cp(0.123, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, + [&](ValueRange targets) { b.cp(0.123, targets[0], targets[1]); }); } void trivialControlledP(QCProgramBuilder& b) { @@ -742,12 +777,14 @@ void trivialControlledP(QCProgramBuilder& b) { void inverseP(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.p(-0.123, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.p(-0.123, qubits[0]); }); } void inverseMultipleControlledP(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcp(-0.123, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcp(-0.123, {qubits[0], qubits[1]}, qubits[2]); + }); } void r(QCProgramBuilder& b) { @@ -767,7 +804,9 @@ void multipleControlledR(QCProgramBuilder& b) { void nestedControlledR(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cr(0.123, 0.456, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, [&](ValueRange targets) { + b.cr(0.123, 0.456, targets[0], targets[1]); + }); } void trivialControlledR(QCProgramBuilder& b) { @@ -777,12 +816,14 @@ void trivialControlledR(QCProgramBuilder& b) { void inverseR(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.r(-0.123, 0.456, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.r(-0.123, 0.456, qubits[0]); }); } void inverseMultipleControlledR(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcr(-0.123, 0.456, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcr(-0.123, 0.456, {qubits[0], qubits[1]}, qubits[2]); + }); } void u2(QCProgramBuilder& b) { @@ -802,7 +843,9 @@ void multipleControlledU2(QCProgramBuilder& b) { void nestedControlledU2(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cu2(0.234, 0.567, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, [&](ValueRange targets) { + b.cu2(0.234, 0.567, targets[0], targets[1]); + }); } void trivialControlledU2(QCProgramBuilder& b) { @@ -813,13 +856,16 @@ void trivialControlledU2(QCProgramBuilder& b) { void inverseU2(QCProgramBuilder& b) { constexpr double pi = std::numbers::pi; auto q = b.allocQubitRegister(1); - b.inv([&]() { b.u2(-0.567 + pi, -0.234 - pi, q[0]); }); + b.inv(q[0], + [&](ValueRange qubits) { b.u2(-0.567 + pi, -0.234 - pi, qubits[0]); }); } void inverseMultipleControlledU2(QCProgramBuilder& b) { constexpr double pi = std::numbers::pi; auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcu2(-0.567 + pi, -0.234 - pi, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcu2(-0.567 + pi, -0.234 - pi, {qubits[0], qubits[1]}, qubits[2]); + }); } void u(QCProgramBuilder& b) { @@ -839,7 +885,9 @@ void multipleControlledU(QCProgramBuilder& b) { void nestedControlledU(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); - b.ctrl(reg[0], [&] { b.cu(0.1, 0.2, 0.3, reg[1], reg[2]); }); + b.ctrl(reg[0], {reg[1], reg[2]}, [&](ValueRange targets) { + b.cu(0.1, 0.2, 0.3, targets[0], targets[1]); + }); } void trivialControlledU(QCProgramBuilder& b) { @@ -849,12 +897,14 @@ void trivialControlledU(QCProgramBuilder& b) { void inverseU(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.u(-0.1, -0.3, -0.2, q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.u(-0.1, -0.3, -0.2, qubits[0]); }); } void inverseMultipleControlledU(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { b.mcu(-0.1, -0.3, -0.2, {q[0], q[1]}, q[2]); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.mcu(-0.1, -0.3, -0.2, {qubits[0], qubits[1]}, qubits[2]); + }); } void swap(QCProgramBuilder& b) { @@ -874,7 +924,9 @@ void multipleControlledSwap(QCProgramBuilder& b) { void nestedControlledSwap(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cswap(reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cswap(targets[0], targets[1], targets[2]); + }); } void trivialControlledSwap(QCProgramBuilder& b) { @@ -884,12 +936,14 @@ void trivialControlledSwap(QCProgramBuilder& b) { void inverseSwap(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.swap(q[0], q[1]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { b.swap(qubits[0], qubits[1]); }); } void inverseMultipleControlledSwap(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcswap({q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcswap({qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void iswap(QCProgramBuilder& b) { @@ -909,7 +963,9 @@ void multipleControlledIswap(QCProgramBuilder& b) { void nestedControlledIswap(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.ciswap(reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.ciswap(targets[0], targets[1], targets[2]); + }); } void trivialControlledIswap(QCProgramBuilder& b) { @@ -919,12 +975,15 @@ void trivialControlledIswap(QCProgramBuilder& b) { void inverseIswap(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.iswap(q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.iswap(qubits[0], qubits[1]); }); } void inverseMultipleControlledIswap(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mciswap({q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mciswap({qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void dcx(QCProgramBuilder& b) { @@ -944,7 +1003,9 @@ void multipleControlledDcx(QCProgramBuilder& b) { void nestedControlledDcx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cdcx(reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cdcx(targets[0], targets[1], targets[2]); + }); } void trivialControlledDcx(QCProgramBuilder& b) { @@ -954,12 +1015,14 @@ void trivialControlledDcx(QCProgramBuilder& b) { void inverseDcx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.dcx(q[1], q[0]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { b.dcx(qubits[1], qubits[0]); }); } void inverseMultipleControlledDcx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcdcx({q[0], q[1]}, q[3], q[2]); }); + b.inv({q[0], q[1], q[3], q[2]}, [&](ValueRange qubits) { + b.mcdcx({qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void ecr(QCProgramBuilder& b) { @@ -979,7 +1042,9 @@ void multipleControlledEcr(QCProgramBuilder& b) { void nestedControlledEcr(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cecr(reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cecr(targets[0], targets[1], targets[2]); + }); } void trivialControlledEcr(QCProgramBuilder& b) { @@ -989,12 +1054,14 @@ void trivialControlledEcr(QCProgramBuilder& b) { void inverseEcr(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.ecr(q[0], q[1]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { b.ecr(qubits[0], qubits[1]); }); } void inverseMultipleControlledEcr(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcecr({q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcecr({qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void rxx(QCProgramBuilder& b) { @@ -1014,7 +1081,9 @@ void multipleControlledRxx(QCProgramBuilder& b) { void nestedControlledRxx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.crxx(0.123, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.crxx(0.123, targets[0], targets[1], targets[2]); + }); } void trivialControlledRxx(QCProgramBuilder& b) { @@ -1024,18 +1093,22 @@ void trivialControlledRxx(QCProgramBuilder& b) { void inverseRxx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.rxx(-0.123, q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.rxx(-0.123, qubits[0], qubits[1]); }); } void inverseMultipleControlledRxx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcrxx(-0.123, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcrxx(-0.123, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void tripleControlledRxx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(5); b.mcrxx(0.123, {q[0], q[1], q[2]}, q[3], q[4]); } + void fourControlledRxx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(6); b.mcrxx(0.123, {q[0], q[1], q[2], q[3]}, q[4], q[5]); @@ -1058,7 +1131,9 @@ void multipleControlledRyy(QCProgramBuilder& b) { void nestedControlledRyy(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cryy(0.123, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cryy(0.123, targets[0], targets[1], targets[2]); + }); } void trivialControlledRyy(QCProgramBuilder& b) { @@ -1068,12 +1143,15 @@ void trivialControlledRyy(QCProgramBuilder& b) { void inverseRyy(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.ryy(-0.123, q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.ryy(-0.123, qubits[0], qubits[1]); }); } void inverseMultipleControlledRyy(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcryy(-0.123, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcryy(-0.123, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void rzx(QCProgramBuilder& b) { @@ -1093,7 +1171,9 @@ void multipleControlledRzx(QCProgramBuilder& b) { void nestedControlledRzx(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.crzx(0.123, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.crzx(0.123, targets[0], targets[1], targets[2]); + }); } void trivialControlledRzx(QCProgramBuilder& b) { @@ -1103,12 +1183,15 @@ void trivialControlledRzx(QCProgramBuilder& b) { void inverseRzx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.rzx(-0.123, q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.rzx(-0.123, qubits[0], qubits[1]); }); } void inverseMultipleControlledRzx(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcrzx(-0.123, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcrzx(-0.123, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void rzz(QCProgramBuilder& b) { @@ -1128,7 +1211,9 @@ void multipleControlledRzz(QCProgramBuilder& b) { void nestedControlledRzz(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.crzz(0.123, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.crzz(0.123, targets[0], targets[1], targets[2]); + }); } void trivialControlledRzz(QCProgramBuilder& b) { @@ -1138,12 +1223,15 @@ void trivialControlledRzz(QCProgramBuilder& b) { void inverseRzz(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.rzz(-0.123, q[0], q[1]); }); + b.inv({q[0], q[1]}, + [&](ValueRange qubits) { b.rzz(-0.123, qubits[0], qubits[1]); }); } void inverseMultipleControlledRzz(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcrzz(-0.123, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcrzz(-0.123, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void xxPlusYY(QCProgramBuilder& b) { @@ -1163,7 +1251,9 @@ void multipleControlledXxPlusYY(QCProgramBuilder& b) { void nestedControlledXxPlusYY(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cxx_plus_yy(0.123, 0.456, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cxx_plus_yy(0.123, 0.456, targets[0], targets[1], targets[2]); + }); } void trivialControlledXxPlusYY(QCProgramBuilder& b) { @@ -1173,12 +1263,16 @@ void trivialControlledXxPlusYY(QCProgramBuilder& b) { void inverseXxPlusYY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.xx_plus_yy(-0.123, 0.456, q[0], q[1]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.xx_plus_yy(-0.123, 0.456, qubits[0], qubits[1]); + }); } void inverseMultipleControlledXxPlusYY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcxx_plus_yy(-0.123, 0.456, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcxx_plus_yy(-0.123, 0.456, {qubits[0], qubits[1]}, qubits[2], qubits[3]); + }); } void xxMinusYY(QCProgramBuilder& b) { @@ -1198,7 +1292,9 @@ void multipleControlledXxMinusYY(QCProgramBuilder& b) { void nestedControlledXxMinusYY(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(4); - b.ctrl(reg[0], [&] { b.cxx_minus_yy(0.123, 0.456, reg[1], reg[2], reg[3]); }); + b.ctrl(reg[0], {reg[1], reg[2], reg[3]}, [&](ValueRange targets) { + b.cxx_minus_yy(0.123, 0.456, targets[0], targets[1], targets[2]); + }); } void trivialControlledXxMinusYY(QCProgramBuilder& b) { @@ -1208,12 +1304,17 @@ void trivialControlledXxMinusYY(QCProgramBuilder& b) { void inverseXxMinusYY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.xx_minus_yy(-0.123, 0.456, q[0], q[1]); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.xx_minus_yy(-0.123, 0.456, qubits[0], qubits[1]); + }); } void inverseMultipleControlledXxMinusYY(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.inv([&]() { b.mcxx_minus_yy(-0.123, 0.456, {q[0], q[1]}, q[2], q[3]); }); + b.inv({q[0], q[1], q[2], q[3]}, [&](ValueRange qubits) { + b.mcxx_minus_yy(-0.123, 0.456, {qubits[0], qubits[1]}, qubits[2], + qubits[3]); + }); } void barrier(QCProgramBuilder& b) { @@ -1233,74 +1334,165 @@ void barrierMultipleQubits(QCProgramBuilder& b) { void singleControlledBarrier(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.ctrl(q[1], [&] { b.barrier(q[0]); }); + b.ctrl(q[1], q[0], [&](ValueRange targets) { b.barrier(targets[0]); }); } void inverseBarrier(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); - b.inv([&]() { b.barrier(q[0]); }); + b.inv(q[0], [&](ValueRange qubits) { b.barrier(qubits[0]); }); } void trivialCtrl(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.ctrl({}, [&]() { b.rxx(0.123, q[0], q[1]); }); + b.ctrl({}, {q[0], q[1]}, + [&](ValueRange targets) { b.rxx(0.123, targets[0], targets[1]); }); +} + +void emptyCtrl(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.rxx(0.123, q[0], q[1]); + b.ctrl({q[0]}, {q[1]}, [&](ValueRange /*targets*/) {}); } void nestedCtrl(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.ctrl(q[0], [&]() { b.ctrl(q[1], [&]() { b.rxx(0.123, q[2], q[3]); }); }); + b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { + b.ctrl(targets[0], {targets[1], targets[2]}, [&](ValueRange innerTargets) { + b.rxx(0.123, innerTargets[0], innerTargets[1]); + }); + }); } void tripleNestedCtrl(QCProgramBuilder& b) { auto q = b.allocQubitRegister(5); - b.ctrl(q[0], [&]() { - b.ctrl(q[1], [&]() { b.ctrl(q[2], [&]() { b.rxx(0.123, q[3], q[4]); }); }); + b.ctrl(q[0], {q[1], q[2], q[3], q[4]}, [&](ValueRange targets) { + b.ctrl(targets[0], {targets[1], targets[2], targets[3]}, + [&](ValueRange innerTargets) { + b.ctrl(innerTargets[0], {innerTargets[1], innerTargets[2]}, + [&](ValueRange innerInnerTargets) { + b.rxx(0.123, innerInnerTargets[0], innerInnerTargets[1]); + }); + }); }); } void doubleNestedCtrlTwoQubits(QCProgramBuilder& b) { auto q = b.allocQubitRegister(6); - b.ctrl({q[0], q[1]}, - [&]() { b.ctrl({q[2], q[3]}, [&]() { b.rxx(0.123, q[4], q[5]); }); }); + b.ctrl({q[0], q[1]}, {q[2], q[3], q[4], q[5]}, [&](ValueRange targets) { + b.ctrl({targets[0], targets[1]}, {targets[2], targets[3]}, + [&](ValueRange innerTargets) { + b.rxx(0.123, innerTargets[0], innerTargets[1]); + }); + }); } void ctrlInvSandwich(QCProgramBuilder& b) { auto q = b.allocQubitRegister(4); - b.ctrl(q[0], [&]() { - b.inv([&]() { b.ctrl(q[1], [&]() { b.rxx(-0.123, q[2], q[3]); }); }); + b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { + b.inv(targets, [&](ValueRange qubits) { + b.ctrl(qubits[0], {qubits[1], qubits[2]}, [&](ValueRange innerTargets) { + b.rxx(-0.123, innerTargets[0], innerTargets[1]); + }); + }); + }); +} + +void ctrlTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl({q[0], q[1]}, {q[2], q[3]}, [&](ValueRange targets) { + b.x(targets[0]); + b.rxx(0.123, targets[0], targets[1]); }); } +void ctrlTwoMixed(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl({q[0], q[1]}, {q[2], q[3]}, [&](ValueRange targets) { + b.cx(targets[0], targets[1]); + b.rxx(0.123, targets[0], targets[1]); + }); +} + +void nestedCtrlTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { + b.ctrl(targets[0], {targets[1], targets[2]}, [&](ValueRange innerTargets) { + b.x(innerTargets[0]); + b.rxx(0.123, innerTargets[0], innerTargets[1]); + }); + }); +} + +void ctrlInvTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.ctrl(q[0], {q[1], q[2]}, [&](ValueRange targets) { + b.inv(targets, [&](ValueRange qubits) { + b.x(qubits[0]); + b.rxx(0.123, qubits[0], qubits[1]); + }); + }); +} + +void emptyInv(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.rxx(0.123, q[0], q[1]); + b.inv({q[0], q[1]}, [&](ValueRange /*targets*/) {}); +} + void nestedInv(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv([&]() { b.inv([&]() { b.rxx(0.123, q[0], q[1]); }); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.inv(qubits, [&](ValueRange innerQubits) { + b.rxx(0.123, innerQubits[0], innerQubits[1]); + }); + }); } void tripleNestedInv(QCProgramBuilder& b) { auto q = b.allocQubitRegister(2); - b.inv( - [&]() { b.inv([&]() { b.inv([&]() { b.rxx(-0.123, q[0], q[1]); }); }); }); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.inv(qubits, [&](ValueRange innerQubits) { + b.inv(innerQubits, [&](ValueRange innerInnerQubits) { + b.rxx(-0.123, innerInnerQubits[0], innerInnerQubits[1]); + }); + }); + }); } void invCtrlSandwich(QCProgramBuilder& b) { auto q = b.allocQubitRegister(3); - b.inv([&]() { - b.ctrl(q[0], [&]() { b.inv([&]() { b.rxx(0.123, q[1], q[2]); }); }); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.ctrl(qubits[0], {qubits[1], qubits[2]}, [&](ValueRange targets) { + b.inv({targets[0], targets[1]}, [&](ValueRange innerQubits) { + b.rxx(0.123, innerQubits[0], innerQubits[1]); + }); + }); }); } -void simpleIf(QCProgramBuilder& b) { - auto q = b.allocQubitRegister(1); - b.h(q[0]); - auto cond = b.measure(q[0]); - b.scfIf(cond, [&] { b.x(q[0]); }); +void invTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + b.x(qubits[0]); + b.rxx(0.123, qubits[0], qubits[1]); + }); } -void ifElse(QCProgramBuilder& b) { +void invCtrlTwo(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + b.ctrl(qubits[0], {qubits[1], qubits[2]}, [&](ValueRange targets) { + b.x(targets[0]); + b.rxx(0.123, targets[0], targets[1]); + }); + }); +} + +void simpleIf(QCProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.h(q[0]); auto cond = b.measure(q[0]); - b.scfIf(cond, [&] { b.x(q[0]); }, [&] { b.z(q[0]); }); + b.scfIf(cond, [&] { b.x(q[0]); }); } void ifTwoQubits(QCProgramBuilder& b) { @@ -1313,6 +1505,13 @@ void ifTwoQubits(QCProgramBuilder& b) { }); } +void ifElse(QCProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + b.h(q[0]); + auto cond = b.measure(q[0]); + b.scfIf(cond, [&] { b.x(q[0]); }, [&] { b.z(q[0]); }); +} + void nestedIfOpForLoop(QCProgramBuilder& b) { auto reg = b.allocQubitRegister(3); auto q0 = b.allocQubit(); @@ -1395,7 +1594,7 @@ void nestedForLoopCtrlOpWithSeparateQubit(QCProgramBuilder& b) { b.scfFor(0, 3, 1, [&](Value iv) { auto q0 = b.memrefLoad(reg.value, iv); b.h(q0); - b.ctrl(control, [&] { b.x(q0); }); + b.ctrl(control, q0, [&](ValueRange targets) { b.x(targets[0]); }); }); } @@ -1405,7 +1604,7 @@ void nestedForLoopCtrlOpWithExtractedQubit(QCProgramBuilder& b) { b.scfFor(1, 4, 1, [&](Value iv) { auto q0 = b.memrefLoad(reg.value, iv); b.h(q0); - b.ctrl(reg[0], [&] { b.x(q0); }); + b.ctrl(reg[0], q0, [&](ValueRange targets) { b.x(targets[0]); }); }); } diff --git a/mlir/unittests/programs/qc_programs.h b/mlir/unittests/programs/qc_programs.h index e6569f7648..dbf855b982 100644 --- a/mlir/unittests/programs/qc_programs.h +++ b/mlir/unittests/programs/qc_programs.h @@ -814,6 +814,9 @@ void inverseBarrier(QCProgramBuilder& b); /// Creates a circuit with a trivial ctrl modifier. void trivialCtrl(QCProgramBuilder& b); +/// Creates a circuit with an empty ctrl modifier. +void emptyCtrl(QCProgramBuilder& b); + /// Creates a circuit with nested ctrl modifiers. void nestedCtrl(QCProgramBuilder& b); @@ -826,8 +829,25 @@ void doubleNestedCtrlTwoQubits(QCProgramBuilder& b); /// Creates a circuit with control modifiers interleaved by an inverse modifier. void ctrlInvSandwich(QCProgramBuilder& b); +/// Creates a circuit with a control modifier applied to two gates. +void ctrlTwo(QCProgramBuilder& b); + +/// Creates a circuit with a control modifier applied to a controlled and a +/// non-controlled gate. +void ctrlTwoMixed(QCProgramBuilder& b); + +/// Creates a circuit with nested control modifiers applied to two gates. +void nestedCtrlTwo(QCProgramBuilder& b); + +/// Creates a circuit with a control modifier applied to a inverse modifier +/// applied to two gates. +void ctrlInvTwo(QCProgramBuilder& b); + // --- InvOp ---------------------------------------------------------------- // +/// Creates a circuit with an empty inverse modifier. +void emptyInv(QCProgramBuilder& b); + /// Creates a circuit with nested inverse modifiers. void nestedInv(QCProgramBuilder& b); @@ -837,17 +857,24 @@ void tripleNestedInv(QCProgramBuilder& b); /// Creates a circuit with inverse modifiers interleaved by a control modifier. void invCtrlSandwich(QCProgramBuilder& b); +/// Creates a circuit with an inverse modifier applied to two gates. +void invTwo(QCProgramBuilder& b); + +/// Creates a circuit with an inverse modifier applied to a control modifier +/// applied to two gates. +void invCtrlTwo(QCProgramBuilder& b); + // --- IfOp ----------------------------------------------------------------- // /// Creates a circuit with a simple if operation with one qubit. void simpleIf(QCProgramBuilder& b); -/// Creates a circuit with an if operation with an else branch. -void ifElse(QCProgramBuilder& b); - /// Creates a circuit with an if operation with two qubits. void ifTwoQubits(QCProgramBuilder& b); +/// Creates a circuit with an if operation with an else branch. +void ifElse(QCProgramBuilder& b); + /// Creates a circuit with an if operation with a nested for operation with /// a register. void nestedIfOpForLoop(QCProgramBuilder& b); diff --git a/mlir/unittests/programs/qco_programs.cpp b/mlir/unittests/programs/qco_programs.cpp index 0ad96fbb10..523f071f8a 100644 --- a/mlir/unittests/programs/qco_programs.cpp +++ b/mlir/unittests/programs/qco_programs.cpp @@ -301,6 +301,24 @@ void twoX(QCOProgramBuilder& b) { q[0] = b.x(q[0]); } +void controlledTwoX(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.ctrl(q[0], q[1], [&](ValueRange targets) { + auto q = b.x(targets[0]); + q = b.x(q); + return SmallVector{q}; + }); +} + +void inverseTwoX(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(1); + b.inv(q[0], [&](ValueRange qubits) { + auto q = b.x(qubits[0]); + q = b.x(q); + return SmallVector{q}; + }); +} + void y(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(1); b.y(q[0]); @@ -1936,6 +1954,12 @@ void trivialCtrl(QCOProgramBuilder& b) { }); } +void emptyCtrl(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + std::tie(q[0], q[1]) = b.rxx(0.123, q[0], q[1]); + b.ctrl(q[0], q[1], [&](ValueRange targets) { return targets; }); +} + void nestedCtrl(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(4); b.ctrl({q[0]}, {q[1], q[2], q[3]}, [&](ValueRange targets) { @@ -2003,6 +2027,63 @@ void ctrlInvSandwich(QCOProgramBuilder& b) { }); } +void ctrlTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl({q[0], q[1]}, {q[2], q[3]}, [&](ValueRange targets) { + auto i0 = targets[0]; + auto i1 = targets[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); +} + +void ctrlTwoMixed(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl({q[0], q[1]}, {q[2], q[3]}, [&](ValueRange targets) { + auto i0 = targets[0]; + auto i1 = targets[1]; + std::tie(i0, i1) = b.cx(i0, i1); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); +} + +void nestedCtrlTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.ctrl(q[0], {q[1], q[2], q[3]}, [&](ValueRange targets) { + const auto& [controlsOut, targetsOut] = b.ctrl( + targets[0], {targets[1], targets[2]}, [&](ValueRange innerTargets) { + auto i0 = innerTargets[0]; + auto i1 = innerTargets[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); + return llvm::to_vector(llvm::concat(controlsOut, targetsOut)); + }); +} + +void ctrlInvTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.ctrl(q[0], {q[1], q[2]}, [&](ValueRange targets) { + auto inner = b.inv(targets, [&](ValueRange qubits) { + auto i0 = qubits[0]; + auto i1 = qubits[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); + return llvm::to_vector(inner); + }); +} + +void emptyInv(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + std::tie(q[0], q[1]) = b.rxx(0.123, q[0], q[1]); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { return qubits; }); +} + void nestedInv(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(2); b.inv({q[0], q[1]}, [&](ValueRange qubits) { @@ -2046,6 +2127,32 @@ void invCtrlSandwich(QCOProgramBuilder& b) { }); } +void invTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(2); + b.inv({q[0], q[1]}, [&](ValueRange qubits) { + auto i0 = qubits[0]; + auto i1 = qubits[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); +} + +void invCtrlTwo(QCOProgramBuilder& b) { + auto q = b.allocQubitRegister(3); + b.inv({q[0], q[1], q[2]}, [&](ValueRange qubits) { + const auto& [controlsOut, targetsOut] = + b.ctrl({qubits[0]}, {qubits[1], qubits[2]}, [&](ValueRange targets) { + auto i0 = targets[0]; + auto i1 = targets[1]; + i0 = b.x(i0); + std::tie(i0, i1) = b.rxx(0.123, i0, i1); + return SmallVector{i0, i1}; + }); + return llvm::to_vector(llvm::concat(controlsOut, targetsOut)); + }); +} + void simpleIf(QCOProgramBuilder& b) { auto q = b.allocQubitRegister(1); auto q0 = b.h(q[0]); diff --git a/mlir/unittests/programs/qco_programs.h b/mlir/unittests/programs/qco_programs.h index b4197c5a7f..f562cfff8a 100644 --- a/mlir/unittests/programs/qco_programs.h +++ b/mlir/unittests/programs/qco_programs.h @@ -167,9 +167,16 @@ void inverseX(QCOProgramBuilder& b); /// Creates a circuit with an inverse modifier applied to a controlled X gate. void inverseMultipleControlledX(QCOProgramBuilder& b); -/// Creates a circuit with two X gates in a row. +/// Creates a circuit with two subsequent X gates. void twoX(QCOProgramBuilder& b); +/// Creates a circuit with a control modifier applied to two subsequent X gates. +void controlledTwoX(QCOProgramBuilder& b); + +/// Creates a circuit with an inverse modifier applied to two subsequent X +/// gates. +void inverseTwoX(QCOProgramBuilder& b); + // --- YOp ------------------------------------------------------------------ // /// Creates a circuit with just a Y gate. @@ -960,6 +967,9 @@ void twoBarrier(QCOProgramBuilder& b); /// Creates a circuit with a trivial ctrl modifier. void trivialCtrl(QCOProgramBuilder& b); +/// Creates a circuit with an empty ctrl modifier. +void emptyCtrl(QCOProgramBuilder& b); + /// Creates a circuit with nested ctrl modifiers. void nestedCtrl(QCOProgramBuilder& b); @@ -972,8 +982,25 @@ void doubleNestedCtrlTwoQubits(QCOProgramBuilder& b); /// Creates a circuit with control modifiers interleaved by an inverse modifier. void ctrlInvSandwich(QCOProgramBuilder& b); +/// Creates a circuit with a control modifier applied to two gates. +void ctrlTwo(QCOProgramBuilder& b); + +/// Creates a circuit with a control modifier applied to a controlled and a +/// non-controlled gate. +void ctrlTwoMixed(QCOProgramBuilder& b); + +/// Creates a circuit with nested control modifiers applied to two gates. +void nestedCtrlTwo(QCOProgramBuilder& b); + +/// Creates a circuit with a control modifier applied to an inverse modifier +/// applied to two gates. +void ctrlInvTwo(QCOProgramBuilder& b); + // --- InvOp ---------------------------------------------------------------- // +/// Creates a circuit with an empty inverse modifier. +void emptyInv(QCOProgramBuilder& b); + /// Creates a circuit with nested inverse modifiers. void nestedInv(QCOProgramBuilder& b); @@ -983,6 +1010,13 @@ void tripleNestedInv(QCOProgramBuilder& b); /// Creates a circuit with inverse modifiers interleaved by a control modifier. void invCtrlSandwich(QCOProgramBuilder& b); +/// Creates a circuit with an inverse modifier applied to two gates. +void invTwo(QCOProgramBuilder& b); + +/// Creates a circuit with an inverse modifier applied to a control modifier +/// applied to two gates. +void invCtrlTwo(QCOProgramBuilder& b); + // --- IfOp ---------------------------------------------------------------- // /// Creates a circuit with a simple if operation with one qubit. diff --git a/mlir/unittests/programs/qir_programs.cpp b/mlir/unittests/programs/qir_programs.cpp index 6ae5023a09..68f209943f 100644 --- a/mlir/unittests/programs/qir_programs.cpp +++ b/mlir/unittests/programs/qir_programs.cpp @@ -605,4 +605,10 @@ void multipleControlledXxMinusYY(QIRProgramBuilder& b) { b.mcxx_minus_yy(0.123, 0.456, {q[0], q[1]}, q[2], q[3]); } +void ctrlTwo(QIRProgramBuilder& b) { + auto q = b.allocQubitRegister(4); + b.mcx({q[0], q[1]}, q[2]); + b.mcrxx(0.123, {q[0], q[1]}, q[2], q[3]); +} + } // namespace mlir::qir diff --git a/mlir/unittests/programs/qir_programs.h b/mlir/unittests/programs/qir_programs.h index 92f6c54078..86a7f7c807 100644 --- a/mlir/unittests/programs/qir_programs.h +++ b/mlir/unittests/programs/qir_programs.h @@ -422,4 +422,9 @@ void singleControlledXxMinusYY(QIRProgramBuilder& b); /// Creates a circuit with a multi-controlled XXMinusYY gate. void multipleControlledXxMinusYY(QIRProgramBuilder& b); +// --- CtrlOp --------------------------------------------------------------- // + +/// Creates a circuit with a control modifier applied to two gates. +void ctrlTwo(QIRProgramBuilder& b); + } // namespace mlir::qir diff --git a/mlir/unittests/programs/quantum_computation_programs.cpp b/mlir/unittests/programs/quantum_computation_programs.cpp index 719fd50b17..f0b9b305cd 100644 --- a/mlir/unittests/programs/quantum_computation_programs.cpp +++ b/mlir/unittests/programs/quantum_computation_programs.cpp @@ -11,10 +11,14 @@ #include "quantum_computation_programs.h" #include "ir/QuantumComputation.hpp" +#include "ir/operations/CompoundOperation.hpp" +#include "ir/operations/IfElseOperation.hpp" #include "ir/operations/OpType.hpp" #include "ir/operations/StandardOperation.hpp" #include +#include +#include namespace qc { @@ -538,6 +542,28 @@ void barrierMultipleQubits(QuantumComputation& comp) { comp.barrier({0, 1, 2}); } +void ctrlTwo(QuantumComputation& comp) { + const auto& q = comp.addQubitRegister(4, "q"); + CompoundOperation compound; + compound.emplace_back(2, X); + compound.emplace_back(Targets{2, 3}, RXX, + std::vector{0.123}); + compound.addControl(0); + compound.addControl(1); + comp.emplace_back(std::move(compound)); +} + +void ctrlTwoMixed(QuantumComputation& comp) { + const auto& q = comp.addQubitRegister(4, "q"); + CompoundOperation compound; + compound.emplace_back(2, 3, X); + compound.emplace_back(Targets{2, 3}, RXX, + std::vector{0.123}); + compound.addControl(0); + compound.addControl(1); + comp.emplace_back(std::move(compound)); +} + void simpleIf(QuantumComputation& comp) { const auto& q = comp.addQubitRegister(1, "q"); const auto& c = comp.addClassicalRegister(1, "c"); @@ -546,6 +572,19 @@ void simpleIf(QuantumComputation& comp) { comp.if_(X, q[0], c[0]); } +void ifTwoQubits(QuantumComputation& comp) { + const auto& q = comp.addQubitRegister(2, "q"); + const auto& c = comp.addClassicalRegister(1, "c"); + comp.h(q[0]); + comp.measure(q[0], c[0]); + CompoundOperation compound; + compound.emplace_back(0, X); + compound.emplace_back(1, X); + IfElseOperation ifElse( + std::make_unique(std::move(compound)), nullptr, c[0]); + comp.emplace_back(std::move(ifElse)); +} + void ifElse(QuantumComputation& comp) { const auto& q = comp.addQubitRegister(1, "q"); const auto& c = comp.addClassicalRegister(1, "c"); diff --git a/mlir/unittests/programs/quantum_computation_programs.h b/mlir/unittests/programs/quantum_computation_programs.h index b21bba30d5..f6dab6e1c2 100644 --- a/mlir/unittests/programs/quantum_computation_programs.h +++ b/mlir/unittests/programs/quantum_computation_programs.h @@ -385,11 +385,23 @@ void barrierTwoQubits(QuantumComputation& comp); /// Creates a circuit with a barrier on multiple qubits. void barrierMultipleQubits(QuantumComputation& comp); +// --- CtrlOp --------------------------------------------------------------- // + +/// Creates a circuit with a control modifier applied to two gates. +void ctrlTwo(QuantumComputation& comp); + +/// Creates a circuit with a control modifier applied to a controlled and a +/// non-controlled gate. +void ctrlTwoMixed(QuantumComputation& comp); + // --- IfOp ----------------------------------------------------------------- // /// Creates a circuit with a simple if operation with one qubit. void simpleIf(QuantumComputation& comp); +/// Creates a circuit with an if operation with two qubits. +void ifTwoQubits(QuantumComputation& comp); + /// Creates a circuit with an if operation with an else branch. void ifElse(QuantumComputation& comp);