Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions mlir/include/mlir/Dialect/QC/Builder/QCProgramBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -917,7 +917,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder {
* } : !qc.qubit
* ```
*/
QCProgramBuilder& ctrl(ValueRange controls, const function_ref<void()>& body);
QCProgramBuilder& ctrl(ValueRange controls, ValueRange targets,
const function_ref<void(ValueRange)>& body);

/**
* @brief Apply an inverse (i.e., adjoint) operation.
Expand All @@ -936,7 +937,8 @@ class QCProgramBuilder final : public ImplicitLocOpBuilder {
* }
* ```
*/
QCProgramBuilder& inv(const function_ref<void()>& body);
QCProgramBuilder& inv(ValueRange qubits,
const function_ref<void(ValueRange)>& body);

//===--------------------------------------------------------------------===//
// Deallocation
Expand Down
77 changes: 46 additions & 31 deletions mlir/include/mlir/Dialect/QC/IR/QCOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -916,7 +916,7 @@ def YieldOp : QCOp<"yield", traits = [Terminator]> {

def CtrlOp
: QCOp<"ctrl",
traits = [UnitaryOpInterface,
traits = [UnitaryOpInterface, AttrSizedOperandSegments,
SingleBlockImplicitTerminator<"::mlir::qc::YieldOp">,
RecursiveMemoryEffects]> {
let summary = "Add control qubits to a unitary operation";
Expand All @@ -937,30 +937,36 @@ def CtrlOp
```
}];

let arguments =
(ins Arg<Variadic<QubitType>,
"the control qubits", [MemRead, MemWrite]>:$controls);
let arguments = (ins Arg<Variadic<QubitType>,
"the control qubits", [MemRead, MemWrite]>:$controls,
Arg<Variadic<QubitType>,
"the target qubits", [MemRead, MemWrite]>:$targets);
let regions = (region SizedRegion<1>:$region);
let assemblyFormat =
"`(` $controls `)` $region attr-dict `:` type($controls)";
let assemblyFormat = [{
`(` $controls `)`
`targets`
custom<TargetAliasing>($region, $targets)
attr-dict `:`
`{` type($controls) `}` ( `,` `{` type($targets)^ `}` )?
}];

let extraClassDeclaration = [{
[[nodiscard]] UnitaryOpInterface getBodyUnitary();
size_t getNumBodyUnitaries();
[[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i);
size_t getNumQubits() { return getNumTargets() + getNumControls(); }
size_t getNumTargets() { return getBodyUnitary().getNumTargets(); }
size_t getNumTargets() { return getTargets().size(); }
size_t getNumControls() { return getControls().size(); }
Value getQubit(size_t i);
Value getTarget(size_t i) { return getBodyUnitary().getTarget(i); }
ValueRange getTargets() { return getBodyUnitary().getTargets(); }
Value getTarget(size_t i) { return getTargets()[i]; }
Value getControl(size_t i);
size_t getNumParams() { return getBodyUnitary().getNumParams(); }
Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); }
ValueRange getParameters() { return getBodyUnitary().getParameters(); }
size_t getNumParams() { return 0; }
Value getParameter(size_t i) { llvm::reportFatalUsageError("CtrlOp does not have parameters"); }
ValueRange getParameters() { llvm::reportFatalUsageError("CtrlOp does not have parameters"); }
static StringRef getBaseSymbol() { return "ctrl"; }
}];

let builders = [OpBuilder<(ins "ValueRange":$controls,
"const function_ref<void()>&":$bodyBuilder)>];
let builders = [OpBuilder<(ins "ValueRange":$controls, "ValueRange":$targets,
"const function_ref<void(ValueRange)>&":$body)>];

let hasCanonicalizer = 1;
let hasVerifier = 1;
Expand All @@ -983,26 +989,35 @@ def InvOp : QCOp<"inv",
```
}];

let arguments = (ins Arg<
Variadic<QubitType>,
"the qubits involved in the operation", [MemRead, MemWrite]>:$qubits);
let regions = (region SizedRegion<1>:$region);
let assemblyFormat = "$region attr-dict";

let extraClassDeclaration = [{
[[nodiscard]] UnitaryOpInterface getBodyUnitary();
size_t getNumQubits() { return getBodyUnitary().getNumQubits(); }
size_t getNumTargets() { return getBodyUnitary().getNumTargets(); }
size_t getNumControls() { return getBodyUnitary().getNumControls(); }
Value getQubit(size_t i) { return getBodyUnitary().getQubit(i); }
Value getTarget(size_t i) { return getBodyUnitary().getTarget(i); }
ValueRange getTargets() { return getBodyUnitary().getTargets(); }
Value getControl(size_t i) { return getBodyUnitary().getControl(i); }
ValueRange getControls() { return getBodyUnitary().getControls(); }
size_t getNumParams() { return getBodyUnitary().getNumParams(); }
Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); }
ValueRange getParameters() { return getBodyUnitary().getParameters(); }
let assemblyFormat = [{
custom<TargetAliasing>($region, $qubits)
attr-dict `:`
type($qubits)
}];

let extraClassDeclaration = [{
size_t getNumBodyUnitaries();
[[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i);
size_t getNumQubits() { return getNumTargets(); }
size_t getNumTargets() { return getQubits().size(); }
size_t getNumControls() { return 0; }
Value getQubit(size_t i) { return getTarget(i); }
Value getTarget(size_t i) { return getQubits()[i]; }
ValueRange getTargets() { return getQubits(); }
Value getControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); }
ValueRange getControls() { return {nullptr, 0}; }
size_t getNumParams() { return 0; }
Value getParameter(size_t i) { llvm::reportFatalUsageError("InvOp does not have parameters"); }
ValueRange getParameters() { return {nullptr, 0}; }
static StringRef getBaseSymbol() { return "inv"; }
}];

let builders = [OpBuilder<(ins "const function_ref<void()>&":$bodyBuilder)>];
let builders = [OpBuilder<(ins "ValueRange":$qubits,
"const function_ref<void(ValueRange)>&":$body)>];

let hasCanonicalizer = 1;
let hasVerifier = 1;
Expand Down
22 changes: 12 additions & 10 deletions mlir/include/mlir/Dialect/QCO/IR/QCOOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,8 @@ def CtrlOp
}];

let extraClassDeclaration = [{
UnitaryOpInterface getBodyUnitary();
size_t getNumBodyUnitaries();
[[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i);
size_t getNumQubits() { return getNumControls() + getNumTargets(); }
size_t getNumTargets() { return getTargetsIn().size(); }
size_t getNumControls() { return getControlsIn().size(); }
Expand All @@ -1120,9 +1121,9 @@ def CtrlOp
ResultRange getOutputControls() { return getControlsOut(); }
Value getInputForOutput(Value output);
Value getOutputForInput(Value input);
size_t getNumParams() { return getBodyUnitary().getNumParams(); }
Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); }
ValueRange getParameters() { return getBodyUnitary().getParameters(); }
size_t getNumParams() { return 0; }
Value getParameter(size_t i) { llvm::reportFatalUsageError("CtrlOp does not have parameters"); }
ValueRange getParameters() { return {nullptr, 0}; }
[[nodiscard]] static StringRef getBaseSymbol() { return "ctrl"; }
[[nodiscard]] std::optional<Eigen::MatrixXcd> getUnitaryMatrix();
}];
Expand Down Expand Up @@ -1173,7 +1174,8 @@ def InvOp
}];

let extraClassDeclaration = [{
UnitaryOpInterface getBodyUnitary();
size_t getNumBodyUnitaries();
[[nodiscard]] UnitaryOpInterface getBodyUnitary(size_t i);
size_t getNumQubits() { return getNumTargets(); }
size_t getNumTargets() { return getQubitsIn().size(); }
static size_t getNumControls() { return 0; }
Expand All @@ -1184,16 +1186,16 @@ def InvOp
ResultRange getOutputQubits() { return getQubitsOut(); }
Value getInputTarget(size_t i) { return getInputQubit(i); }
Value getOutputTarget(size_t i) { return getOutputQubit(i); }
static Value getInputControl(size_t i) { llvm::reportFatalUsageError("Operation does not have controls"); }
static Value getInputControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); }
static OperandRange getInputControls() { return {nullptr, 0}; }
static Value getOutputControl(size_t i) { llvm::reportFatalUsageError("Operation does not have controls"); }
static Value getOutputControl(size_t i) { llvm::reportFatalUsageError("InvOp does not have controls"); }
static ResultRange getOutputControls() { return {nullptr, 0}; }
ResultRange getOutputTargets() { return getOutputQubits(); }
Value getInputForOutput(Value output);
Value getOutputForInput(Value input);
size_t getNumParams() { return getBodyUnitary().getNumParams(); }
Value getParameter(size_t i) { return getBodyUnitary().getParameter(i); }
ValueRange getParameters() { return getBodyUnitary().getParameters(); }
size_t getNumParams() { return 0; }
Value getParameter(size_t i) { llvm::reportFatalUsageError("InvOp does not have parameters"); }
ValueRange getParameters() { return {nullptr, 0}; }
[[nodiscard]] static StringRef getBaseSymbol() { return "inv"; }
[[nodiscard]] std::optional<Eigen::MatrixXcd> getUnitaryMatrix();
}];
Expand Down
10 changes: 6 additions & 4 deletions mlir/include/mlir/Dialect/QCO/QCOUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ removeInversePairOneTargetZeroParameter(OpType op, PatternRewriter& rewriter) {
}

// Unlink both operations
rewriter.replaceAllUsesWith(nextOp->getResult(0), op.getInputQubit(0));
rewriter.replaceOp(op, op.getInputQubits());
rewriter.replaceOp(nextOp, nextOp.getInputQubits());

return success();
}
Expand Down Expand Up @@ -64,7 +65,8 @@ removeInversePairTwoTargetZeroParameter(OpType op, PatternRewriter& rewriter) {
}

// Unlink both operations
rewriter.replaceAllUsesWith(nextOp->getResults(), op.getOperands());
rewriter.replaceOp(op, op.getInputQubits());
rewriter.replaceOp(nextOp, nextOp.getInputQubits());

return success();
}
Expand Down Expand Up @@ -95,8 +97,8 @@ removeTwoTargetZeroParameterPairWithSwappedTargets(OpType op,
}

// Unlink both operations
rewriter.replaceAllUsesWith(nextOp->getResults(),
{op.getInputQubit(1), op.getInputQubit(0)});
rewriter.replaceOp(op, op.getInputQubits());
rewriter.replaceOp(nextOp, nextOp.getInputQubits());

return success();
}
Expand Down
118 changes: 118 additions & 0 deletions mlir/include/mlir/Dialect/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@

#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/IR/Builders.h>
#include <mlir/IR/IRMapping.h>
#include <mlir/IR/Location.h>
#include <mlir/IR/Value.h>

#include <cassert>
#include <variant>

namespace mlir::utils {
Expand Down Expand Up @@ -78,4 +80,120 @@ template <typename T>
return std::nullopt;
}

template <typename QubitType>
[[nodiscard]]
static ParseResult
parseTargetAliasing(OpAsmParser& parser, Region& region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand>& operands) {
// 1. Parse the opening parenthesis
if (parser.parseLParen()) {
return failure();
}

// Temporary storage for block arguments we are about to create
SmallVector<OpAsmParser::Argument> blockArgs;

// 2. Prepare to parse the list
if (failed(parser.parseOptionalRParen())) {
do {
OpAsmParser::Argument newArg; // The "new" variable name
OpAsmParser::UnresolvedOperand oldOperand; // The "old" input variable

// Parse "%new"
if (parser.parseArgument(newArg)) {
return failure();
}

// Parse "="
if (parser.parseEqual()) {
return failure();
}

// Parse "%old"
if (parser.parseOperand(oldOperand)) {
return failure();
}
operands.push_back(oldOperand);

// Hard-code QubitType since targets in qco.ctrl are always qubits.
// This avoids double-binding type($targets_in) in the assembly format
// while keeping the parser simple and the assembly format clean.
newArg.type = QubitType::get(parser.getBuilder().getContext());
blockArgs.push_back(newArg);

} while (succeeded(parser.parseOptionalComma()));

if (parser.parseRParen()) {
return failure();
}
}

// 4. Parse the Region
// We explicitly pass the blockArgs we just parsed so they become the entry
// block!
if (parser.parseRegion(region, blockArgs)) {
return failure();
}

return success();
}

static void printTargetAliasing(OpAsmPrinter& printer, Region& region,
OperandRange targetsIn) {
printer << "(";
if (region.empty()) {
printer << ") ";
printer.printRegion(region, false);
return;
}
Block& entryBlock = region.front();

const auto numTargets = targetsIn.size();
for (unsigned i = 0; i < numTargets; ++i) {
if (i > 0) {
printer << ", ";
}
printer.printOperand(entryBlock.getArgument(i));
printer << " = ";
printer.printOperand(targetsIn[i]);
}
printer << ") ";

printer.printRegion(region, false);
}

/**
* @brief Get the value corresponding to @p qubit from the block arguments @p
* qubits if @p qubit is a block argument, otherwise return @p qubit itself.
*/
static Value getValueFromBlockArgument(Value qubit, ValueRange qubits) {
if (auto blockArg = dyn_cast<BlockArgument>(qubit)) {
return qubits[blockArg.getArgNumber()];
}
return qubit;
}

/**
* @brief Create a mapping between block arguments and qubit values.
*
* @details This helper function is used to resolve block arguments for nested
* modifiers.
*/
static void populateMapping(IRMapping& mapping, Block& block,
ValueRange innerQubits, ValueRange outerQubits,
ValueRange newQubits, ValueRange qubitArgs) {
assert(innerQubits.size() == block.getNumArguments() &&
"Size of innerQubits must match number of block arguments");
for (auto arg : block.getArguments()) {
auto innerQubit = innerQubits[arg.getArgNumber()];
auto outerQubit = getValueFromBlockArgument(innerQubit, outerQubits);
if (auto it = llvm::find(newQubits, outerQubit); it != newQubits.end()) {
auto index = std::distance(newQubits.begin(), it);
mapping.map(arg, qubitArgs[index]);
} else {
llvm::reportFatalInternalError("Outer qubit not found in new qubits");
}
}
}

} // namespace mlir::utils
Loading
Loading