Skip to content

[Issue]: While loop counter update issue #211

@ruanjm

Description

@ruanjm

Problem Description

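It looks like the counter that should be updated inside the while loop is not updated correctly, and MLIR verification fails with the error below. In the dumped `scf.while`, the 'after' region computes the incremented counter (`%93 = "arith.addi"(%68, %92)`) but never passes it back with an `scf.yield`. In addition, both regions reference the loop's initial operands `%68`/`%36` directly instead of the region block arguments (`%arg19`/`%arg20` in the 'before' region, `%arg11`/`%arg12` in the 'after' region). As a result, the loop condition would never observe an updated counter even if a yield were emitted: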

Traceback (most recent call last):
  File "/jruan/ws/aiter/op_tests/test_mla_persistent.py", line 996, in <module>
    ret = test_mla(
  File "/jruan/ws/aiter/aiter/test_common.py", line 128, in wrapper
    ret = func(*args, **kwargs)
  File "/jruan/ws/aiter/op_tests/test_mla_persistent.py", line 833, in test_mla
    err, us_asm_decode = test_absorb_decode_fp8()
  File "/jruan/ws/aiter/op_tests/test_mla_persistent.py", line 708, in test_absorb_decode_fp8
    attn_logits, attn_lse = aiter.mla.mla_decode_fwd(
  File "/jruan/ws/aiter/aiter/mla.py", line 379, in mla_decode_fwd
    flydsl_attn_reduce_v1(
  File "/jruan/ws/aiter/aiter/ops/flydsl/attn_reduce.py", line 155, in flydsl_attn_reduce_v1
    launch_attn_reduce_ps(
  File "/jruan/ws/FlyDSL/python/flydsl/compiler/jit_function.py", line 544, in __call__
    compiled_module = MlirCompiler.compile(module, chip=chip, func_name=self.func.__name__)
  File "/jruan/ws/FlyDSL/python/flydsl/compiler/jit_function.py", line 312, in compile
    module.operation.verify()
flydsl._mlir._mlir_libs._site_initialize.<locals>.MLIRError: Verification failed:
error: unknown: 'scf.while' op expects the 'after' region to terminate with 'scf.yield'
 note: unknown: see current operation:
  %69:2 = "scf.while"(%68, %36) ({
  ^bb0(%arg19: i32, %arg20: i32):
    %213 = "arith.cmpi"(%68, %36) <{predicate = 2 : i64}> : (i32, i32) -> i1
    "scf.condition"(%213, %arg19, %arg20) : (i1, i32, i32) -> ()
  }, {
  ^bb0(%arg11: i32, %arg12: i32):
    "gpu.barrier"() : () -> ()
    %70 = "arith.constant"() <{value = 128 : i32}> : () -> i32
    %71 = "arith.remsi"(%68, %70) : (i32, i32) -> i32
    %72 = "arith.constant"() <{value = 128 : i32}> : () -> i32
    %73 = "arith.floordivsi"(%68, %72) : (i32, i32) -> i32
    %74 = "arith.constant"() <{value = 1 : i32}> : () -> i32
    %75 = "arith.remsi"(%73, %74) : (i32, i32) -> i32
    %76 = "arith.constant"() <{value = 128 : i32}> : () -> i32
    %77 = "arith.floordivsi"(%68, %76) : (i32, i32) -> i32
    %78 = "arith.constant"() <{value = 1 : i32}> : () -> i32
    %79 = "arith.floordivsi"(%77, %78) : (i32, i32) -> i32
    %80 = "arith.constant"() <{value = 4 : i32}> : () -> i32
    %81 = "arith.muli"(%79, %80) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
    %82 = "arith.constant"() <{value = 0 : i32}> : () -> i32
    %83 = "arith.constant"() <{value = 0 : i32}> : () -> i32
    %84 = "rocdl.raw.ptr.buffer.load"(%4, %81, %82, %83) : (!llvm.ptr<8>, i32, i32, i32) -> vector<2xi32>
    %85 = "vector.extract"(%84) <{static_position = array<i64: 0>}> : (vector<2xi32>) -> i32
    %86 = "vector.extract"(%84) <{static_position = array<i64: 1>}> : (vector<2xi32>) -> i32
    %87 = "arith.cmpi"(%85, %41) <{predicate = 1 : i64}> : (i32, i32) -> i1
    %88 = "arith.subi"(%86, %85) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
    %89 = "arith.constant"() <{value = 1 : i32}> : () -> i32
    %90 = "arith.cmpi"(%88, %89) <{predicate = 4 : i64}> : (i32, i32) -> i1
    "scf.if"(%90) ({
      %94 = "memref.get_global"() <{name = @smem_storage}> : () -> memref<8192xi8, #gpu.address_space<workgroup>>
      %95 = "gpu.thread_id"() <{dimension = #gpu<dim x>}> : () -> index
      %96 = "arith.subi"(%86, %85) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
      %97 = "arith.index_cast"(%96) : (i32) -> index
      %98 = "arith.index_cast"(%85) : (i32) -> index
      %99 = "arith.constant"() <{value = 128 : index}> : () -> index
      "scf.for"(%95, %97, %99) ({
      ^bb0(%arg18: index):
        %203 = "arith.cmpi"(%arg18, %97) <{predicate = 6 : i64}> : (index, index) -> i1
        "scf.if"(%203) ({
          %204 = "arith.addi"(%98, %arg18) <{overflowFlags = #arith.overflow<none>}> : (index, index) -> index
          %205 = "arith.index_cast"(%204) : (index) -> i32
          %206 = "arith.constant"() <{value = 4 : i32}> : () -> i32
          %207 = "arith.muli"(%205, %206) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
          %208 = "arith.constant"() <{value = 0 : i32}> : () -> i32
          %209 = "arith.constant"() <{value = 0 : i32}> : () -> i32
          %210 = "rocdl.raw.ptr.buffer.load"(%14, %207, %208, %209) : (!llvm.ptr<8>, i32, i32, i32) -> i32
          %211 = "arith.constant"() <{value = 0 : index}> : () -> index
          %212 = "memref.view"(%94, %211) : (memref<8192xi8, #gpu.address_space<workgroup>>, index) -> memref<2048xi32, #gpu.address_space<workgroup>>
          "memref.store"(%210, %212, %arg18) : (i32, memref<2048xi32, #gpu.address_space<workgroup>>, index) -> ()
          "scf.yield"() : () -> ()
        }, {
        }) : (i1) -> ()
        "scf.yield"() : () -> ()
      }) : (index, index, index) -> ()
      "gpu.barrier"() : () -> ()
      %100 = "arith.constant"() <{value = 0 : index}> : () -> index
      %101 = "arith.constant"() <{value = 0 : index}> : () -> index
      %102 = "memref.view"(%94, %101) : (memref<8192xi8, #gpu.address_space<workgroup>>, index) -> memref<2048xi32, #gpu.address_space<workgroup>>
      %103 = "memref.load"(%102, %100) : (memref<2048xi32, #gpu.address_space<workgroup>>, index) -> i32
      %104 = "arith.constant"() <{value = 1 : index}> : () -> index
      %105 = "memref.load"(%102, %104) : (memref<2048xi32, #gpu.address_space<workgroup>>, index) -> i32
      %106 = "arith.constant"() <{value = 2 : i32}> : () -> i32
      %107 = "arith.muli"(%79, %106) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
      %108 = "arith.constant"() <{value = 4 : i32}> : () -> i32
      %109 = "arith.muli"(%107, %108) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
      %110 = "arith.constant"() <{value = 0 : i32}> : () -> i32
      %111 = "arith.constant"() <{value = 0 : i32}> : () -> i32
      %112 = "rocdl.raw.ptr.buffer.load"(%9, %109, %110, %111) : (!llvm.ptr<8>, i32, i32, i32) -> vector<2xi32>
      %113 = "vector.extract"(%112) <{static_position = array<i64: 0>}> : (vector<2xi32>) -> i32
      %114 = "vector.extract"(%112) <{static_position = array<i64: 1>}> : (vector<2xi32>) -> i32
      %115 = "arith.index_cast"(%95) : (index) -> i32
      %116 = "arith.addi"(%113, %75) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
      %117 = "arith.index_cast"(%116) : (i32) -> index
      %118 = "arith.index_cast"(%114) : (i32) -> index
      %119 = "arith.constant"() <{value = 1 : index}> : () -> index
      %120 = "arith.constant"() <{value = 1 : index}> : () -> index
      "scf.for"(%117, %118, %119) ({
      ^bb0(%arg13: index):
        %121 = "arith.index_cast"(%arg13) : (index) -> i32
        %122 = "arith.subi"(%121, %113) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %123 = "arith.addi"(%103, %122) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %124 = "arith.constant"() <{value = 128 : i32}> : () -> i32
        %125 = "arith.muli"(%123, %124) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %126 = "arith.addi"(%125, %71) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %127 = "arith.constant"() <{value = 512 : i32}> : () -> i32
        %128 = "arith.muli"(%125, %127) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %129 = "arith.constant"() <{value = 512 : i32}> : () -> i32
        %130 = "arith.muli"(%71, %129) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %131 = "arith.addi"(%128, %130) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %132 = "arith.constant"() <{value = 4 : i32}> : () -> i32
        %133 = "arith.muli"(%115, %132) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %134 = "arith.addi"(%131, %133) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %135 = "arith.constant"() <{value = 4 : i32}> : () -> i32
        %136 = "arith.muli"(%134, %135) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %137 = "arith.constant"() <{value = 0 : i32}> : () -> i32
        %138 = "arith.constant"() <{value = 0 : i32}> : () -> i32
        %139 = "rocdl.raw.ptr.buffer.load"(%34, %136, %137, %138) : (!llvm.ptr<8>, i32, i32, i32) -> vector<4xf32>
        %140 = "arith.constant"() <{value = 4 : i32}> : () -> i32
        %141 = "arith.muli"(%126, %140) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %142 = "arith.constant"() <{value = 0 : i32}> : () -> i32
        %143 = "arith.constant"() <{value = 0 : i32}> : () -> i32
        %144 = "rocdl.raw.ptr.buffer.load"(%29, %141, %142, %143) : (!llvm.ptr<8>, i32, i32, i32) -> f32
        %145 = "arith.constant"() <{value = 1.000000e+00 : f32}> : () -> f32
        %146 = "arith.index_cast"(%85) : (i32) -> index
        %147 = "arith.index_cast"(%86) : (i32) -> index
        %148 = "arith.addi"(%146, %120) <{overflowFlags = #arith.overflow<none>}> : (index, index) -> index
        %149:3 = "scf.for"(%148, %147, %120, %139, %144, %145) ({
        ^bb0(%arg14: index, %arg15: vector<4xf32>, %arg16: f32, %arg17: f32):
          %165 = "arith.subi"(%arg14, %146) <{overflowFlags = #arith.overflow<none>}> : (index, index) -> index
          %166 = "arith.constant"() <{value = 0 : index}> : () -> index
          %167 = "memref.view"(%94, %166) : (memref<8192xi8, #gpu.address_space<workgroup>>, index) -> memref<2048xi32, #gpu.address_space<workgroup>>
          %168 = "memref.load"(%167, %165) : (memref<2048xi32, #gpu.address_space<workgroup>>, index) -> i32
          %169 = "arith.addi"(%168, %122) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
          %170 = "arith.constant"() <{value = 128 : i32}> : () -> i32
          %171 = "arith.muli"(%169, %170) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
          %172 = "arith.addi"(%171, %71) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
          %173 = "arith.constant"() <{value = 512 : i32}> : () -> i32
          %174 = "arith.muli"(%171, %173) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
          %175 = "arith.constant"() <{value = 512 : i32}> : () -> i32
          %176 = "arith.muli"(%71, %175) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
          %177 = "arith.addi"(%174, %176) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
          %178 = "arith.constant"() <{value = 4 : i32}> : () -> i32
          %179 = "arith.muli"(%115, %178) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
          %180 = "arith.addi"(%177, %179) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
          %181 = "arith.constant"() <{value = 4 : i32}> : () -> i32
          %182 = "arith.muli"(%180, %181) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
          %183 = "arith.constant"() <{value = 0 : i32}> : () -> i32
          %184 = "arith.constant"() <{value = 0 : i32}> : () -> i32
          %185 = "rocdl.raw.ptr.buffer.load"(%34, %182, %183, %184) : (!llvm.ptr<8>, i32, i32, i32) -> vector<4xf32>
          %186 = "arith.constant"() <{value = 4 : i32}> : () -> i32
          %187 = "arith.muli"(%172, %186) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
          %188 = "arith.constant"() <{value = 0 : i32}> : () -> i32
          %189 = "arith.constant"() <{value = 0 : i32}> : () -> i32
          %190 = "rocdl.raw.ptr.buffer.load"(%29, %187, %188, %189) : (!llvm.ptr<8>, i32, i32, i32) -> f32
          %191 = "arith.maximumf"(%arg16, %190) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
          %192 = "arith.subf"(%arg16, %191) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
          %193 = "math.exp"(%192) <{fastmath = #arith.fastmath<none>}> : (f32) -> f32
          %194 = "arith.subf"(%190, %191) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
          %195 = "math.exp"(%194) <{fastmath = #arith.fastmath<none>}> : (f32) -> f32
          %196 = "vector.broadcast"(%193) : (f32) -> vector<4xf32>
          %197 = "arith.mulf"(%196, %arg15) <{fastmath = #arith.fastmath<none>}> : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
          %198 = "vector.broadcast"(%195) : (f32) -> vector<4xf32>
          %199 = "arith.mulf"(%198, %185) <{fastmath = #arith.fastmath<none>}> : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
          %200 = "arith.addf"(%197, %199) <{fastmath = #arith.fastmath<none>}> : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
          %201 = "arith.mulf"(%arg17, %193) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
          %202 = "arith.addf"(%201, %195) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
          "scf.yield"(%200, %191, %202) : (vector<4xf32>, f32, f32) -> ()
        }) : (index, index, index, vector<4xf32>, f32, f32) -> (vector<4xf32>, f32, f32)
        %150 = "arith.constant"() <{value = 1.000000e+00 : f32}> : () -> f32
        %151 = "arith.divf"(%150, %149#2) <{fastmath = #arith.fastmath<none>}> : (f32, f32) -> f32
        %152 = "vector.broadcast"(%151) : (f32) -> vector<4xf32>
        %153 = "arith.mulf"(%152, %149#0) <{fastmath = #arith.fastmath<none>}> : (vector<4xf32>, vector<4xf32>) -> vector<4xf32>
        %154 = "arith.muli"(%121, %arg7) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %155 = "arith.muli"(%71, %arg8) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %156 = "arith.addi"(%154, %155) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %157 = "arith.constant"() <{value = 4 : i32}> : () -> i32
        %158 = "arith.muli"(%115, %157) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %159 = "arith.addi"(%156, %158) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %160 = "arith.truncf"(%153) : (vector<4xf32>) -> vector<4xbf16>
        %161 = "arith.constant"() <{value = 2 : i32}> : () -> i32
        %162 = "arith.muli"(%159, %161) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
        %163 = "arith.constant"() <{value = 0 : i32}> : () -> i32
        %164 = "arith.constant"() <{value = 0 : i32}> : () -> i32
        "rocdl.raw.ptr.buffer.store"(%160, %24, %162, %163, %164) : (vector<4xbf16>, !llvm.ptr<8>, i32, i32, i32) -> ()
        "scf.yield"() : () -> ()
      }) : (index, index, index) -> ()
      "scf.yield"() : () -> ()
    }, {
    }) : (i1) -> ()
    %91 = "gpu.grid_dim"() <{dimension = #gpu<dim x>}> : () -> index
    %92 = "arith.index_cast"(%91) : (index) -> i32
    %93 = "arith.addi"(%68, %92) <{overflowFlags = #arith.overflow<none>}> : (i32, i32) -> i32
  }) : (i32, i32) -> (i32, i32)
 note: "/jruan/ws/aiter/aiter/ops/flydsl/kernels/attn_reduce.py":360:0: terminator here

While loop code:
https://github.com/ROCm/aiter/blob/11d0b63e1412c6f771d0d9d0225cdc07fb6ecd16/aiter/ops/flydsl/kernels/attn_reduce.py#L439

Operating System

all

CPU

all

GPU

all

ROCm Version

all

ROCm Component

No response

Steps to Reproduce

  1. Check out the `jruan/fdsl_issue_211` branch of aiter.
  2. python op_tests/test_mla_persistent.py -b 33 -c 2333 -n 128,1 -d fp8 -kvd fp8

(Optional for Linux users) Output of /opt/rocm/bin/rocminfo --support

No response

Additional Information

No response

Metadata

Metadata

Assignees

Labels

bug — Something isn't working

Type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions