Skip to content

add WhileCyclicTranposePropagate transform pattern#2450

Open
mofeing wants to merge 9 commits into
mainfrom
ss/transpose-in-while-boundaries
Open

add WhileCyclicTranposePropagate transform pattern#2450
mofeing wants to merge 9 commits into
mainfrom
ss/transpose-in-while-boundaries

Conversation

@mofeing
Copy link
Copy Markdown
Collaborator

@mofeing mofeing commented Apr 27, 2026

Co-authored-by: Copilot <copilot@github.com>
@mofeing mofeing marked this pull request as draft April 27, 2026 11:34
mofeing and others added 6 commits April 27, 2026 07:15
Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <copilot@github.com>
Co-authored-by: Copilot <copilot@github.com>
@mofeing mofeing changed the title add WhileBodyBoundaryTranposePropagate add WhileCyclicTranposePropagate Apr 27, 2026
@mofeing mofeing changed the title add WhileCyclicTranposePropagate add WhileCyclicTranposePropagate transform pattern Apr 27, 2026
@mofeing mofeing marked this pull request as ready for review April 27, 2026 16:10
@mofeing mofeing requested a review from wsmoses April 27, 2026 16:10
@wsmoses
Copy link
Copy Markdown
Member

wsmoses commented Apr 28, 2026

@avik-pal can you review?

@mofeing
Copy link
Copy Markdown
Collaborator Author

mofeing commented Apr 28, 2026

just detected a couple of bugs that make transpose_transpose go on an infinite loop. I'm finishing the fixes.

@mofeing
Copy link
Copy Markdown
Collaborator Author

mofeing commented Apr 28, 2026

okay, so the problem I'm running into is that if I add it to the same transform sequence, it goes on an infinite loop on the test/lit_tests/while-transpose-test.mlir test because it conflicts with while_transpose or transpose_transpose.

but if I run the same test on different steps, it works perfectly.

MWE

this gets stuck

./bazel-bin/enzymexlamlir-opt \
    --enzyme-hlo-generate-td="patterns=transpose_while;transpose_transpose;while_cyclic_transpose_propagate(1);transpose_transpose;transpose_is_reshape" --transform-interpreter --enzyme-hlo-remove-transform \
    --allow-unregistered-dialect test/lit_tests/while-transpose-test.mlir --verify-each=1

but this works and generates the correct code

./bazel-bin/enzymexlamlir-opt \
    --enzyme-hlo-generate-td="patterns=transpose_while;transpose_transpose" --transform-interpreter --enzyme-hlo-remove-transform \
    --enzyme-hlo-generate-td="patterns=while_cyclic_transpose_propagate(1)" --transform-interpreter --enzyme-hlo-remove-transform \
    --enzyme-hlo-generate-td="patterns=transpose_transpose;transpose_is_reshape" --transform-interpreter --enzyme-hlo-remove-transform \
    --allow-unregistered-dialect test/lit_tests/while-transpose-test.mlir --verify-each=1
module {
  func.func @test_while_transpose_elimination(%arg0: tensor<2x3xf32>, %arg1: tensor<i64>) -> tensor<3x2xf32> {
    %c = stablehlo.constant dense<1> : tensor<i64>
    %c_0 = stablehlo.constant dense<true> : tensor<i1>
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32>
    %1:2 = stablehlo.while(%iterArg = %0, %iterArg_1 = %arg1) : tensor<3x2xf32>, tensor<i64>
    cond {
      stablehlo.return %c_0 : tensor<i1>
    } do {
      %2 = stablehlo.transpose %iterArg, dims = [1, 0] : (tensor<3x2xf32>) -> tensor<2x3xf32>
      %3 = "unregistered.custom_transform"(%2) : (tensor<2x3xf32>) -> tensor<3x2xf32>
      %4 = stablehlo.add %iterArg_1, %c : tensor<i64>
      stablehlo.return %3, %4 : tensor<3x2xf32>, tensor<i64>
    }
    return %1#0 : tensor<3x2xf32>
  }
}

@avik-pal
Copy link
Copy Markdown
Collaborator

There is a debug flag in mlir that will dump what each pattern rewriter is doing. Would be worth checking its output

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants