Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
e760329
Add RemoveExtractAfterInsert canonicalization pattern
MatthiasReumann May 20, 2026
dd935d7
Add TensorIterator
MatthiasReumann May 20, 2026
872eb4d
Add qtensor-utils unit test
MatthiasReumann May 20, 2026
554cca7
Use TensorIterator
MatthiasReumann May 20, 2026
46ed449
Add BubbleDownInsertPattern pattern
MatthiasReumann May 20, 2026
18338b5
Update InsertOp.cpp
MatthiasReumann May 21, 2026
bdd6b70
Minor update to comments
MatthiasReumann May 21, 2026
12fea23
🎨 pre-commit fixes
pre-commit-ci[bot] May 21, 2026
845f15e
Fix CMake
MatthiasReumann May 21, 2026
4aae1ab
Merge branch 'enh/additional-tensor-canonicalization' of https://gith…
MatthiasReumann May 21, 2026
0cbd687
Add missing library in CMakeLists.txt
MatthiasReumann May 21, 2026
dad7fac
Update library dependencies
MatthiasReumann May 21, 2026
4c56de8
🎨 pre-commit fixes
pre-commit-ci[bot] May 21, 2026
cd92343
Move to ExtractOp
MatthiasReumann May 22, 2026
b52fa37
Merge branch 'enh/additional-tensor-canonicalization' of https://gith…
MatthiasReumann May 22, 2026
05a3d11
🎨 pre-commit fixes
pre-commit-ci[bot] May 22, 2026
4a7d7c0
Merge branch 'main' into enh/additional-tensor-canonicalization
MatthiasReumann May 27, 2026
6d6f419
Merge branch 'enh/additional-tensor-canonicalization' of https://gith…
MatthiasReumann May 27, 2026
2ee8b3f
Fix build
MatthiasReumann May 27, 2026
91de0bd
🎨 pre-commit fixes
pre-commit-ci[bot] May 27, 2026
fb619f4
Merge branch 'main' into enh/additional-tensor-canonicalization
MatthiasReumann May 27, 2026
26e3b19
🎨 pre-commit fixes
pre-commit-ci[bot] May 27, 2026
00f6d08
Remove unused includes
MatthiasReumann May 27, 2026
2f7f915
Merge branch 'enh/additional-tensor-canonicalization' of https://gith…
MatthiasReumann May 27, 2026
38fcf2e
Update CHANGELOG.md
MatthiasReumann May 27, 2026
51002a0
Update mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp
MatthiasReumann May 28, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel
- ✨ Add conversions between `jeff` and QCO ([#1479], [#1548], [#1565], [#1637], [#1676], [#1706]) ([**@denialhaag**], [**@burgholzer**])
- ✨ Add a `place-and-route` pass for mapping circuits to architectures with restricted topologies ([#1537], [#1547], [#1568], [#1581], [#1583], [#1588], [#1600], [#1664], [#1709], [#1716]) ([**@MatthiasReumann**], [**@burgholzer**])
- ✨ Add initial infrastructure for new QC and QCO MLIR dialects
([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1567], [#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623], [#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700], [#1717], [#1730])
([#1264], [#1330], [#1402], [#1428], [#1430], [#1436], [#1443], [#1446], [#1464], [#1465], [#1470], [#1471], [#1472], [#1474], [#1475], [#1506], [#1510], [#1513], [#1521], [#1542], [#1548], [#1550], [#1554], [#1567], [#1569], [#1570], [#1572], [#1573], [#1580], [#1602], [#1620], [#1623], [#1624], [#1626], [#1627], [#1635], [#1638], [#1673], [#1675], [#1700], [#1717], [#1728], [#1730])
([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**], [**@li-mingbao**], [**@Ectras**], [**@MatthiasReumann**], [**@simon1hofmann**])

### Changed
Expand Down Expand Up @@ -404,6 +404,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool

[#1737]: https://github.com/munich-quantum-toolkit/core/pull/1737
[#1730]: https://github.com/munich-quantum-toolkit/core/pull/1730
[#1728]: https://github.com/munich-quantum-toolkit/core/pull/1728
[#1720]: https://github.com/munich-quantum-toolkit/core/pull/1720
[#1719]: https://github.com/munich-quantum-toolkit/core/pull/1719
[#1718]: https://github.com/munich-quantum-toolkit/core/pull/1718
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/QTensor/IR/QTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def ExtractOp
let results = (outs 1DTensorOf<[QubitType]>:$out_tensor, QubitType:$result);
let assemblyFormat = "$tensor `[` $index `]` attr-dict `:` type($tensor)";

let hasCanonicalizer = 1;
let hasVerifier = 1;
}

Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/QTensor/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ add_mlir_dialect_library(
MLIRQTensorOpsIncGen
LINK_LIBS
PRIVATE
MLIRQTensorUtils
MLIRIR
MLIRDialectUtils
MLIRArithDialect
Expand Down
68 changes: 68 additions & 0 deletions mlir/lib/Dialect/QTensor/IR/Operations/ExtractOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,78 @@
*/

#include "mlir/Dialect/QTensor/IR/QTensorOps.h"
#include "mlir/Dialect/QTensor/IR/QTensorUtils.h"
#include "mlir/Dialect/QTensor/Utils/TensorIterator.h"

#include <mlir/Dialect/Utils/StaticValueUtils.h>
#include <mlir/IR/BuiltinTypeInterfaces.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/Support/LLVM.h>

#include <iterator>

using namespace mlir;
using namespace mlir::qtensor;

namespace {
/**
* @brief Remove an (extract, insert) pair when the extracted qubit is
* reinserted unchanged at the same constant index.
*/
struct RemoveExtractInsertPairPattern final : OpRewritePattern<ExtractOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(ExtractOp extract,
PatternRewriter& rewriter) const override {
// Check: Extract has constant index.
if (!getConstantIntValue(extract.getIndex())) {
return failure();
}

// Search for an insert operation on the tensor-chain with the same constant
// index as the matched extract operation.
TensorIterator it(extract.getOutTensor());
for (; it != std::default_sentinel; ++it) {
if (!isa<InsertOp>(it.operation())) {
continue;
}

auto insert = cast<InsertOp>(it.operation());

// Check: Insert has constant index.
if (!getConstantIntValue(insert.getIndex())) {
return failure();
}

// Check: Same constant index.
if (!areEquivalentIndices(insert.getIndex(), extract.getIndex())) {
continue;
}

// Check: The inserted qubit value is the extracted one. If so, the
// qubit has not been used and both operations can be safely removed.

if (extract.getResult() == insert.getScalar()) {

// ┌──────────┐ ┌─────────┐
// ... ─tensor─▶│extract(i)│─▶ ... ─▶│insert(i)│─▶result─▶ ...
// └────┬─────┘ └────▲────┘
// └──result = scalar───┘
// ------------------- ⬇ (transformed) ⬇ -------------------
// ... ─tensor = result─▶ ...

rewriter.replaceOp(insert, insert.getDest());
rewriter.replaceOp(extract, {extract.getTensor(), nullptr});
return success();
}
}

return failure();
}
};
} // namespace

LogicalResult ExtractOp::verify() {
auto tensorDim = getTensor().getType().getDimSize(0);
auto index = getConstantIntValue(getIndex());
Expand All @@ -32,3 +95,8 @@ LogicalResult ExtractOp::verify() {
}
return success();
}

void ExtractOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
results.add<RemoveExtractInsertPairPattern>(context);
}
151 changes: 96 additions & 55 deletions mlir/lib/Dialect/QTensor/IR/Operations/InsertOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "mlir/Dialect/QTensor/IR/QTensorOps.h"
#include "mlir/Dialect/QTensor/IR/QTensorUtils.h"
#include "mlir/Dialect/QTensor/Utils/TensorIterator.h"

#include <mlir/Dialect/Utils/StaticValueUtils.h>
#include <mlir/IR/BuiltinTypeInterfaces.h>
Expand All @@ -19,108 +20,148 @@
#include <mlir/IR/Value.h>
#include <mlir/Support/LLVM.h>

#include <iterator>

using namespace mlir;
using namespace mlir::qtensor;

/**
* @brief Checks whether removing an extract-insert pair is linearity-safe.
*/
static bool isRemovableExtractInsertPair(InsertOp insertOp,
ExtractOp extractOp) {
return insertOp.getScalar() == extractOp.getResult() &&
areEquivalentIndices(insertOp.getIndex(), extractOp.getIndex());
static bool isRemovableExtractInsertPair(InsertOp insert, ExtractOp extract) {
return insert.getScalar() == extract.getResult() &&
areEquivalentIndices(insert.getIndex(), extract.getIndex());
}

/**
* @brief Folds an insert operation after a matching extract operation into the
* original tensor.
*/
static Value foldInsertAfterExtract(InsertOp insertOp) {
auto extractOp = insertOp.getScalar().getDefiningOp<ExtractOp>();
if (!extractOp) {
static Value foldInsertAfterExtract(InsertOp insert) {
auto extract = insert.getScalar().getDefiningOp<ExtractOp>();
if (!extract) {
return nullptr;
}

if (insertOp.getDest() != extractOp.getOutTensor()) {
if (insert.getDest() != extract.getOutTensor()) {
return nullptr;
}

if (!isRemovableExtractInsertPair(insertOp, extractOp)) {
if (!isRemovableExtractInsertPair(insert, extract)) {
return nullptr;
}

return extractOp.getTensor();
return extract.getTensor();
}

namespace {
/**
* @brief Finds the extract operation corresponding to a given insert operation.
*
* @details The function traverses the tensor chain of the insert operation
* until it finds the matching extract operation.
* @brief Remove an (insert, extract) pair when the inserted qubit has been
* extracted previously with the same constant index.
* @pre Assumes each qubit is extracted and inserted with the same index.
*/
static ExtractOp findMatchingExtractInTensorChain(InsertOp insertOp) {
auto current = insertOp.getDest();
auto insertIndex = insertOp.getIndex();
struct RemoveInsertExtractPairPattern final : OpRewritePattern<InsertOp> {
using OpRewritePattern::OpRewritePattern;

if (!getConstantIntValue(insertIndex)) {
return nullptr;
}
LogicalResult matchAndRewrite(InsertOp insert,
PatternRewriter& rewriter) const override {
// Check: Insert has constant index.
if (!getConstantIntValue(insert.getIndex())) {
return failure();
}

while (auto* definingOp = current.getDefiningOp()) {
if (auto nestedInsertOp = dyn_cast<InsertOp>(definingOp)) {
auto nestedInsertIndex = nestedInsertOp.getIndex();
if (!getConstantIntValue(nestedInsertIndex)) {
return nullptr;
}
// A more recent write to the same index shadows all older extracts
if (areEquivalentIndices(nestedInsertIndex, insertIndex)) {
return nullptr;
// Search for an extract operation on the tensor-chain with the same
// constant index as the matched insert operation.
TensorIterator it(insert.getResult());
for (; it != std::default_sentinel; ++it) {
if (!isa<ExtractOp>(it.operation())) {
continue;
}
current = nestedInsertOp.getDest();
continue;
}
if (auto extractOp = dyn_cast<ExtractOp>(definingOp)) {
auto extractIndex = extractOp.getIndex();
if (!getConstantIntValue(extractIndex)) {
return nullptr;

auto extract = cast<ExtractOp>(it.operation());

// Check: Extract has constant index.
if (!getConstantIntValue(extract.getIndex())) {
return failure();
}
if (areEquivalentIndices(extractIndex, insertIndex)) {
return extractOp;

// Check: Same constant index.
if (!areEquivalentIndices(extract.getIndex(), insert.getIndex())) {
continue;
}
current = extractOp.getTensor();
continue;

// ┌─────────┐ ┌──────────┐
// ... ─t = dest──▶│insert(i)│─▶ ... ─▶tensor─▶│extract(i)│─outTensor─▶...
// └────▲────┘ └────┬─────┘
// ... ─scalar─┘ └result─▶ ...
// ------------------------- ⬇ (transformed) ⬇ -------------------------
// ... ─t = outTensor─▶ ...
// ... ─scalar = result─▶ ... (Assumption applied.)

rewriter.replaceOp(extract, {extract.getTensor(), insert.getScalar()});
rewriter.replaceOp(insert, insert.getDest());

return success();
}
break;
}
return nullptr;
}

namespace {
return failure();
}
};

/**
* @brief Remove matching extract-insert pairs.
* @brief If possible, move insert after extract in tensor chain.
* @pre Assumes that the extract and insertion index of any qubit is equivalent.
*/
struct RemoveExtractInsertPair final : OpRewritePattern<InsertOp> {
struct BubbleDownInsertPattern final : OpRewritePattern<InsertOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(InsertOp op,
LogicalResult matchAndRewrite(InsertOp insert,
PatternRewriter& rewriter) const override {
auto extractOp = findMatchingExtractInTensorChain(op);
if (!extractOp) {
if (!getConstantIntValue(insert.getIndex())) {
return failure();
}

auto next = std::next(TensorIterator(insert.getResult()));
if (next == std::default_sentinel) {
return failure();
}

if (!isa<ExtractOp>(next.operation())) {
return failure();
}

if (!isRemovableExtractInsertPair(op, extractOp)) {
auto extract = cast<ExtractOp>(next.operation());
if (!getConstantIntValue(extract.getIndex())) {
return failure();
}

rewriter.replaceOp(op, op.getDest());
rewriter.replaceOp(extractOp, {extractOp.getTensor(), nullptr});
if (areEquivalentIndices(extract.getIndex(), insert.getIndex())) {
return failure();
}

// i != j
// ┌─────────┐ ┌──────────┐
// ... ─t = dest─▶│insert(i)│─result = tensor─▶│extract(j)│─outTensor─▶ ...
// └─────────┘ └──────────┘
// -------------------------- ⬇ (transformed) ⬇ --------------------------
// ┌──────────┐ ┌─────────┐
// ... ─t = tensor─▶│extract(j)│─outTensor = dest─▶│insert(i)│─result─▶ ...
// └──────────┘ └─────────┘

const Value t = insert.getDest();
const Value outTensor = extract.getOutTensor();
const Value result = insert.getResult();

rewriter.moveOpAfter(insert, extract);
rewriter.modifyOpInPlace(extract,
[&] { extract.getTensorMutable().assign(t); });
rewriter.modifyOpInPlace(
insert, [&] { insert.getDestMutable().assign(outTensor); });
rewriter.replaceAllUsesExcept(outTensor, result, insert);

return success();
}
};

} // namespace

LogicalResult InsertOp::verify() {
Expand Down Expand Up @@ -148,5 +189,5 @@ OpFoldResult InsertOp::fold(FoldAdaptor /*adaptor*/) {

void InsertOp::getCanonicalizationPatterns(RewritePatternSet& results,
MLIRContext* context) {
results.add<RemoveExtractInsertPair>(context);
results.add<RemoveInsertExtractPairPattern, BubbleDownInsertPattern>(context);
}
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/QTensor/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ add_mlir_dialect_library(
MLIRQTensorOpsIncGen
LINK_LIBS
PUBLIC
MLIRQTensorDialect)
MLIRQCODialect)

mqt_mlir_target_use_project_options(MLIRQTensorUtils)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ set(target_name mqt-core-mlir-unittest-mapping)
add_executable(${target_name} test_mapping.cpp)

target_link_libraries(${target_name} PRIVATE GTest::gtest_main MLIRParser MLIRQCOProgramBuilder
MLIRQCOUtils MLIRQCOTransforms)
MLIRQTensorUtils MLIRQCOTransforms)

mqt_mlir_configure_unittest_target(${target_name})

Expand Down
2 changes: 1 addition & 1 deletion mlir/unittests/Dialect/QCO/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

set(qco_utils_target mqt-core-mlir-unittest-qco-utils)
add_executable(${qco_utils_target} test_drivers.cpp test_wireiterator.cpp)
target_link_libraries(${qco_utils_target} PRIVATE GTest::gtest_main MLIRQCODialect MLIRQCOUtils
target_link_libraries(${qco_utils_target} PRIVATE GTest::gtest_main MLIRQCOUtils
MLIRQCOProgramBuilder)
mqt_mlir_configure_unittest_target(${qco_utils_target})

Expand Down
5 changes: 3 additions & 2 deletions mlir/unittests/Dialect/QIR/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
set(target_name mqt-core-mlir-unittest-qir-ir)
add_executable(${target_name} test_qir_ir.cpp)

target_link_libraries(${target_name} PRIVATE MLIRParser MLIRSupportMQT GTest::gtest_main
MLIRQIRProgramBuilder MLIRQIRPrograms)
target_link_libraries(
${target_name} PRIVATE MLIRParser MLIRSupportMQT MLIRSCFDialect GTest::gtest_main
MLIRQIRProgramBuilder MLIRQIRPrograms)

mqt_mlir_configure_unittest_target(${target_name})

Expand Down
Loading