Skip to content

Commit 45b6e68

Browse files
authored
Tweak rmsnorm operator regarding kernel name (flagos-ai#2056)
1 parent edf27e1 commit 45b6e68

1 file changed

Lines changed: 1 addition & 5 deletions

File tree

src/flag_gems/experimental_ops/rmsnorm.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010

1111

1212
@triton.jit
13-
def rmsnorm(
13+
def rmsnorm_kernel(
1414
input_ptr, # *Pointer* to the input tensor flattened to 2D [M, N]
1515
weight_ptr, # *Pointer* to the weight tensor [N]
1616
output_ptr, # *Pointer* to the output tensor flattened to 2D [M, N]
@@ -48,10 +48,6 @@ def rmsnorm(
4848
col_start += BLOCK_SIZE
4949

5050

51-
# Keep a handle to the Triton kernel before defining the Python wrapper with the same name
52-
rmsnorm_kernel = rmsnorm
53-
54-
5551
def rmsnorm(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
5652
assert input_tensor.is_cuda and weight.is_cuda, "Tensors must be on CUDA device."
5753
assert (

0 commit comments

Comments
 (0)