diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c3b5c34c8..bdc629138c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,7 +18,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel - ✨ Add conversions between `jeff` and QCO ([#1479], [#1548], [#1565], [#1637], [#1676], [#1706]) ([**@denialhaag**], [**@burgholzer**]) - ✨ Add a `place-and-route` pass for mapping circuits to architectures with restricted topologies ([#1537], [#1547], [#1568], [#1581], [#1583], [#1588], [#1600], [#1664], [#1709], [#1716]) ([**@MatthiasReumann**], [**@burgholzer**]) - ✨ Add initial infrastructure for new QC and QCO MLIR dialects - ([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1567], [#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623], [#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700], [#1717], [#1730]) + ([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1567], [#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623], [#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700], [#1717], [#1728], [#1730]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**], [**@simon1hofmann**]) ### Changed @@ -404,6 +404,7 @@ _πŸ“š Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool [#1737]: https://github.com/munich-quantum-toolkit/core/pull/1737 [#1730]: https://github.com/munich-quantum-toolkit/core/pull/1730 +[#1728]: https://github.com/munich-quantum-toolkit/core/pull/1728 [#1720]: https://github.com/munich-quantum-toolkit/core/pull/1720 [#1719]: https://github.com/munich-quantum-toolkit/core/pull/1719 [#1718]: https://github.com/munich-quantum-toolkit/core/pull/1718 diff --git a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td index 7a0e26ea27..98745e1255 100644 --- a/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td +++ b/mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td @@ -133,6 +133,7 @@ def ExtractOp let results = (outs 1DTensorOf<[QubitType]>:$out_tensor, QubitType:$result); let assemblyFormat = "$tensor `[` $index `]` attr-dict `:` type($tensor)"; + let hasCanonicalizer = 1; let hasVerifier = 1; } diff --git a/mlir/lib/Dialect/QTensor/IR/CMakeLists.txt b/mlir/lib/Dialect/QTensor/IR/CMakeLists.txt index f19fa71111..97d7ec241c 100644 --- a/mlir/lib/Dialect/QTensor/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/QTensor/IR/CMakeLists.txt @@ -18,6 +18,7 @@ add_mlir_dialect_library( MLIRQTensorOpsIncGen LINK_LIBS PRIVATE + MLIRQTensorUtils MLIRIR MLIRDialectUtils MLIRArithDialect diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp index 04d6b834c7..7a4ec3b9f0 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp @@ -9,15 +9,78 @@ */ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" +#include "mlir/Dialect/QTensor/IR/QTensorUtils.h" +#include "mlir/Dialect/QTensor/Utils/TensorIterator.h" #include #include #include +#include #include +#include + using namespace mlir; using namespace mlir::qtensor; +namespace { +/** + * @brief Remove an (extract, insert) pair when the extracted qubit is + * reinserted unchanged at the same constant index. + */ +struct RemoveExtractInsertPairPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ExtractOp extract, + PatternRewriter& rewriter) const override { + // Check: Extract has constant index. + if (!getConstantIntValue(extract.getIndex())) { + return failure(); + } + + // Search for an insert operation on the tensor-chain with the same constant + // index as the matched extract operation. + TensorIterator it(extract.getOutTensor()); + for (; it != std::default_sentinel; ++it) { + if (!isa(it.operation())) { + continue; + } + + auto insert = cast(it.operation()); + + // Check: Insert has constant index. + if (!getConstantIntValue(insert.getIndex())) { + return failure(); + } + + // Check: Same constant index. + if (!areEquivalentIndices(insert.getIndex(), extract.getIndex())) { + continue; + } + + // Check: The inserted qubit value is the extracted one. If so, the + // qubit has not been used and both operations can be safely removed. + + if (extract.getResult() == insert.getScalar()) { + + // β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β” + // ... ─tensor─▢│extract(i)│─▢ ... ─▢│insert(i)│─▢result─▢ ... + // β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β–²β”€β”€β”€β”€β”˜ + // └──result = scalarβ”€β”€β”€β”˜ + // ------------------- ⬇ (transformed) ⬇ ------------------- + // ... ─tensor = result─▢ ... + + rewriter.replaceOp(insert, insert.getDest()); + rewriter.replaceOp(extract, {extract.getTensor(), nullptr}); + return success(); + } + } + + return failure(); + } +}; +} // namespace + LogicalResult ExtractOp::verify() { auto tensorDim = getTensor().getType().getDimSize(0); auto index = getConstantIntValue(getIndex()); @@ -32,3 +95,8 @@ LogicalResult ExtractOp::verify() { } return success(); } + +void ExtractOp::getCanonicalizationPatterns(RewritePatternSet& results, + MLIRContext* context) { + results.add(context); +} diff --git a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp index 9ec4bd5d07..68b6dafed9 100644 --- a/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp +++ b/mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp @@ -10,6 +10,7 @@ #include "mlir/Dialect/QTensor/IR/QTensorOps.h" #include "mlir/Dialect/QTensor/IR/QTensorUtils.h" +#include "mlir/Dialect/QTensor/Utils/TensorIterator.h" #include #include @@ -19,108 +20,148 @@ #include #include +#include + using namespace mlir; using namespace mlir::qtensor; /** * @brief Checks whether removing an extract-insert pair is linearity-safe. */ -static bool isRemovableExtractInsertPair(InsertOp insertOp, - ExtractOp extractOp) { - return insertOp.getScalar() == extractOp.getResult() && - areEquivalentIndices(insertOp.getIndex(), extractOp.getIndex()); +static bool isRemovableExtractInsertPair(InsertOp insert, ExtractOp extract) { + return insert.getScalar() == extract.getResult() && + areEquivalentIndices(insert.getIndex(), extract.getIndex()); } /** * @brief Folds an insert operation after a matching extract operation into the * original tensor. */ -static Value foldInsertAfterExtract(InsertOp insertOp) { - auto extractOp = insertOp.getScalar().getDefiningOp(); - if (!extractOp) { +static Value foldInsertAfterExtract(InsertOp insert) { + auto extract = insert.getScalar().getDefiningOp(); + if (!extract) { return nullptr; } - if (insertOp.getDest() != extractOp.getOutTensor()) { + if (insert.getDest() != extract.getOutTensor()) { return nullptr; } - if (!isRemovableExtractInsertPair(insertOp, extractOp)) { + if (!isRemovableExtractInsertPair(insert, extract)) { return nullptr; } - return extractOp.getTensor(); + return extract.getTensor(); } +namespace { /** - * @brief Finds the extract operation corresponding to a given insert operation. - * - * @details The function traverses the tensor chain of the insert operation - * until it finds the matching extract operation. + * @brief Remove an (insert, extract) pair when the inserted qubit has been + * extracted previously with the same constant index. + * @pre Assumes each qubit is extracted and inserted with the same index. */ -static ExtractOp findMatchingExtractInTensorChain(InsertOp insertOp) { - auto current = insertOp.getDest(); - auto insertIndex = insertOp.getIndex(); +struct RemoveInsertExtractPairPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - if (!getConstantIntValue(insertIndex)) { - return nullptr; - } + LogicalResult matchAndRewrite(InsertOp insert, + PatternRewriter& rewriter) const override { + // Check: Insert has constant index. + if (!getConstantIntValue(insert.getIndex())) { + return failure(); + } - while (auto* definingOp = current.getDefiningOp()) { - if (auto nestedInsertOp = dyn_cast(definingOp)) { - auto nestedInsertIndex = nestedInsertOp.getIndex(); - if (!getConstantIntValue(nestedInsertIndex)) { - return nullptr; - } - // A more recent write to the same index shadows all older extracts - if (areEquivalentIndices(nestedInsertIndex, insertIndex)) { - return nullptr; + // Search for an extract operation on the tensor-chain with the same + // constant index as the matched insert operation. + TensorIterator it(insert.getResult()); + for (; it != std::default_sentinel; ++it) { + if (!isa(it.operation())) { + continue; } - current = nestedInsertOp.getDest(); - continue; - } - if (auto extractOp = dyn_cast(definingOp)) { - auto extractIndex = extractOp.getIndex(); - if (!getConstantIntValue(extractIndex)) { - return nullptr; + + auto extract = cast(it.operation()); + + // Check: Extract has constant index. + if (!getConstantIntValue(extract.getIndex())) { + return failure(); } - if (areEquivalentIndices(extractIndex, insertIndex)) { - return extractOp; + + // Check: Same constant index. + if (!areEquivalentIndices(extract.getIndex(), insert.getIndex())) { + continue; } - current = extractOp.getTensor(); - continue; + + // β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + // ... ─t = dest──▢│insert(i)│─▢ ... ─▢tensor─▢│extract(i)│─outTensor─▢... + // β””β”€β”€β”€β”€β–²β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”¬β”€β”€β”€β”€β”€β”˜ + // ... ─scalarβ”€β”˜ β””result─▢ ... + // ------------------------- ⬇ (transformed) ⬇ ------------------------- + // ... ─t = outTensor─▢ ... + // ... ─scalar = result─▢ ... (Assumption applied.) + + rewriter.replaceOp(extract, {extract.getTensor(), insert.getScalar()}); + rewriter.replaceOp(insert, insert.getDest()); + + return success(); } - break; - } - return nullptr; -} -namespace { + return failure(); + } +}; /** - * @brief Remove matching extract-insert pairs. + * @brief If possible, move insert after extract in tensor chain. + * @pre Assumes that the extract and insertion index of any qubit is equivalent. */ -struct RemoveExtractInsertPair final : OpRewritePattern { +struct BubbleDownInsertPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(InsertOp op, + LogicalResult matchAndRewrite(InsertOp insert, PatternRewriter& rewriter) const override { - auto extractOp = findMatchingExtractInTensorChain(op); - if (!extractOp) { + if (!getConstantIntValue(insert.getIndex())) { + return failure(); + } + + auto next = std::next(TensorIterator(insert.getResult())); + if (next == std::default_sentinel) { + return failure(); + } + + if (!isa(next.operation())) { return failure(); } - if (!isRemovableExtractInsertPair(op, extractOp)) { + auto extract = cast(next.operation()); + if (!getConstantIntValue(extract.getIndex())) { return failure(); } - rewriter.replaceOp(op, op.getDest()); - rewriter.replaceOp(extractOp, {extractOp.getTensor(), nullptr}); + if (areEquivalentIndices(extract.getIndex(), insert.getIndex())) { + return failure(); + } + + // i != j + // β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” + // ... ─t = dest─▢│insert(i)│─result = tensor─▢│extract(j)│─outTensor─▢ ... + // β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + // -------------------------- ⬇ (transformed) ⬇ -------------------------- + // β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β” + // ... ─t = tensor─▢│extract(j)│─outTensor = dest─▢│insert(i)│─result─▢ ... + // β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + + const Value t = insert.getDest(); + const Value outTensor = extract.getOutTensor(); + const Value result = insert.getResult(); + + rewriter.moveOpAfter(insert, extract); + rewriter.modifyOpInPlace(extract, + [&] { extract.getTensorMutable().assign(t); }); + rewriter.modifyOpInPlace( + insert, [&] { insert.getDestMutable().assign(outTensor); }); + rewriter.replaceAllUsesExcept(outTensor, result, insert); return success(); } }; - } // namespace LogicalResult InsertOp::verify() { @@ -148,5 +189,5 @@ OpFoldResult InsertOp::fold(FoldAdaptor /*adaptor*/) { void InsertOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } diff --git a/mlir/lib/Dialect/QTensor/Utils/CMakeLists.txt b/mlir/lib/Dialect/QTensor/Utils/CMakeLists.txt index 6b94307613..59cbc69bec 100644 --- a/mlir/lib/Dialect/QTensor/Utils/CMakeLists.txt +++ b/mlir/lib/Dialect/QTensor/Utils/CMakeLists.txt @@ -17,7 +17,7 @@ add_mlir_dialect_library( MLIRQTensorOpsIncGen LINK_LIBS PUBLIC - MLIRQTensorDialect) + MLIRQCODialect) mqt_mlir_target_use_project_options(MLIRQTensorUtils) diff --git a/mlir/unittests/Dialect/QCO/Transforms/Mapping/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/Mapping/CMakeLists.txt index 248d9ec9f5..35abaebfcf 100644 --- a/mlir/unittests/Dialect/QCO/Transforms/Mapping/CMakeLists.txt +++ b/mlir/unittests/Dialect/QCO/Transforms/Mapping/CMakeLists.txt @@ -10,7 +10,7 @@ set(target_name mqt-core-mlir-unittest-mapping) add_executable(${target_name} test_mapping.cpp) target_link_libraries(${target_name} PRIVATE GTest::gtest_main MLIRParser MLIRQCOProgramBuilder - MLIRQCOUtils MLIRQCOTransforms) + MLIRQTensorUtils MLIRQCOTransforms) mqt_mlir_configure_unittest_target(${target_name}) diff --git a/mlir/unittests/Dialect/QCO/Utils/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Utils/CMakeLists.txt index a8070a6984..e5815170b6 100644 --- a/mlir/unittests/Dialect/QCO/Utils/CMakeLists.txt +++ b/mlir/unittests/Dialect/QCO/Utils/CMakeLists.txt @@ -8,7 +8,7 @@ set(qco_utils_target mqt-core-mlir-unittest-qco-utils) add_executable(${qco_utils_target} test_drivers.cpp test_wireiterator.cpp) -target_link_libraries(${qco_utils_target} PRIVATE GTest::gtest_main MLIRQCODialect MLIRQCOUtils +target_link_libraries(${qco_utils_target} PRIVATE GTest::gtest_main MLIRQCOUtils MLIRQCOProgramBuilder) mqt_mlir_configure_unittest_target(${qco_utils_target}) diff --git a/mlir/unittests/Dialect/QIR/IR/CMakeLists.txt b/mlir/unittests/Dialect/QIR/IR/CMakeLists.txt index 2d5ed5d601..ef5a487b6e 100644 --- a/mlir/unittests/Dialect/QIR/IR/CMakeLists.txt +++ b/mlir/unittests/Dialect/QIR/IR/CMakeLists.txt @@ -9,8 +9,9 @@ set(target_name mqt-core-mlir-unittest-qir-ir) add_executable(${target_name} test_qir_ir.cpp) -target_link_libraries(${target_name} PRIVATE MLIRParser MLIRSupportMQT GTest::gtest_main - MLIRQIRProgramBuilder MLIRQIRPrograms) +target_link_libraries( + ${target_name} PRIVATE MLIRParser MLIRSupportMQT MLIRSCFDialect GTest::gtest_main + MLIRQIRProgramBuilder MLIRQIRPrograms) mqt_mlir_configure_unittest_target(${target_name})