Skip to content

[mlir][shard] Small fixes to partition pass#185050

Merged
fschlimb merged 2 commits intollvm:mainfrom
fschlimb:fix-partition
Mar 9, 2026
Merged

[mlir][shard] Small fixes to partition pass#185050
fschlimb merged 2 commits intollvm:mainfrom
fschlimb:fix-partition

Conversation

@fschlimb
Copy link
Contributor

@fschlimb fschlimb commented Mar 6, 2026

  • Empty functions (with no blocks) should be skipped by partition pass, not error-flagged
  • fixed ShardingInterfaceImpl of bufferization.materialize_in_destination

Enables llvm/lighthouse#65

@fschlimb fschlimb requested review from Copilot and tkarna March 6, 2026 17:01
@llvmbot llvmbot added mlir mlir:bufferization Bufferization infrastructure labels Mar 6, 2026
@llvmbot
Copy link
Member

llvmbot commented Mar 6, 2026

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

Author: Frank Schlimbach (fschlimb)

Changes
  • Empty functions (with no blocks) should be skipped by partition pass, not error-flagged
  • fixed ShardingInterfaceImpl of bufferization.materialize_in_destination

Enables llvm/lighthouse#65


Full diff: https://github.com/llvm/llvm-project/pull/185050.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp (+11-13)
  • (modified) mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp (+10-6)
  • (modified) mlir/test/Dialect/Shard/sharding-propagation-failed.mlir (+8-1)
diff --git a/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
index 7d6d2a8378813..40a26cf6334a2 100644
--- a/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
+++ b/mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp
@@ -12,22 +12,20 @@
 #include "mlir/IR/DialectRegistry.h"
 
 using namespace mlir;
-
-/// Variadic helper function.
-template <typename... OpTypes>
-static void registerAll(MLIRContext *ctx) {
-  (OpTypes::template attachInterface<
-       shard::IndependentParallelIteratorDomainShardingInterface<OpTypes>>(
-       *ctx),
-   ...);
-}
+using namespace mlir::bufferization;
+using namespace mlir::shard;
 
 void mlir::bufferization::shard_ext::registerShardingInterfaceExternalModels(
     DialectRegistry &registry) {
 
-  registry.addExtension(+[](MLIRContext *ctx,
-                            bufferization::BufferizationDialect *dialect) {
-    registerAll<bufferization::AllocTensorOp, bufferization::DeallocTensorOp,
-                bufferization::MaterializeInDestinationOp>(ctx);
+  registry.addExtension(+[](MLIRContext *ctx, BufferizationDialect *dialect) {
+    AllocTensorOp::attachInterface<
+        IndependentParallelIteratorDomainShardingInterface<AllocTensorOp>>(
+        *ctx);
+    DeallocTensorOp::attachInterface<
+        IndependentParallelIteratorDomainShardingInterface<DeallocTensorOp>>(
+        *ctx);
+    MaterializeInDestinationOp::attachInterface<
+        ElementwiseShardingInterface<MaterializeInDestinationOp>>(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
index cff02d4f03143..1e7deda5c6377 100644
--- a/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp
@@ -364,12 +364,19 @@ struct ShardingPropagation
     FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
     Region &region = funcOp.getFunctionBody();
-    OpBuilder builder(ctx);
+
+    if (region.empty())
+      return;
+
+    Block &block = region.front();
+    // Nothing to propagate if there is no sharding annotation in the block.
+    if (block.getOps<shard::ShardOp>().empty())
+      return;
+
     if (!region.hasOneBlock()) {
       funcOp.emitOpError() << "only one block is supported!";
       return signalPassFailure();
     }
-    Block &block = region.front();
 
     LLVM_DEBUG(
         DBGS() << "print all the ops' iterator types and indexing maps in the "
@@ -379,10 +386,7 @@ struct ShardingPropagation
             shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
         });
 
-    // Nothing to propagate if there is no sharding annotation in the block.
-    if (block.getOps<shard::ShardOp>().empty())
-      return;
-
+    OpBuilder builder(ctx);
     auto traverse = [&](auto &&range, OpBuilder &builder,
                         const char *order) -> bool {
       for (Operation &op : range) {
diff --git a/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir b/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
index b5eb98d859c36..3459c1c9f6edc 100644
--- a/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
+++ b/mlir/test/Dialect/Shard/sharding-propagation-failed.mlir
@@ -1,4 +1,11 @@
 // RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s -verify-diagnostics
 
+shard.grid @grid(shape = 1) {sym_visibility = "private"}
 // expected-error @+1 {{'func.func' op only one block is supported!}}
-func.func private @no_block_function(i64)
+func.func @multi_block_function(%arg0 : tensor<6x6xi32>) -> tensor<6x6xi32> {
+    %sharding = shard.sharding @grid split_axes = [[0]] : !shard.sharding
+    %sharded = shard.shard %arg0 to %sharding : tensor<6x6xi32>
+    cf.br ^bb1
+  ^bb1:
+    return %sharded : tensor<6x6xi32>
+}

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adjusts Shard sharding propagation behavior to avoid flagging empty functions as failures and updates Bufferization dialect sharding interface registrations (notably for bufferization.materialize_in_destination) to support downstream sharding behavior.

Changes:

  • Skip sharding propagation on empty function bodies (no blocks) instead of emitting an error.
  • Reorder early-exit logic in sharding propagation based on presence of sharding annotations.
  • Update Bufferization external sharding interface registration, switching MaterializeInDestinationOp to an elementwise sharding model.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File Description
mlir/test/Dialect/Shard/sharding-propagation-failed.mlir Updates negative test to trigger the “multi-block not supported” diagnostic and adds grid/sharding setup.
mlir/lib/Dialect/Shard/Transforms/ShardingPropagation.cpp Skips empty regions and adds early returns when no sharding is present before doing propagation work.
mlir/lib/Dialect/Bufferization/Extensions/ShardingExtensions.cpp Adjusts external sharding interface attachments; uses elementwise sharding for MaterializeInDestinationOp.

Copy link
Contributor

@tkarna tkarna left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good to me

@fschlimb fschlimb merged commit 05781f4 into llvm:main Mar 9, 2026
17 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:bufferization Bufferization infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants