We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent edf27e1, commit 45b6e68 (Copy full SHA for 45b6e68)
1 file changed
src/flag_gems/experimental_ops/rmsnorm.py
@@ -10,7 +10,7 @@
10
11
12
@triton.jit
13
-def rmsnorm(
+def rmsnorm_kernel(
14
input_ptr, # *Pointer* to the input tensor flattened to 2D [M, N]
15
weight_ptr, # *Pointer* to the weight tensor [N]
16
output_ptr, # *Pointer* to the output tensor flattened to 2D [M, N]
@@ -48,10 +48,6 @@ def rmsnorm(
48
col_start += BLOCK_SIZE
49
50
51
-# Keep a handle to the Triton kernel before defining the Python wrapper with the same name
52
-rmsnorm_kernel = rmsnorm
53
-
54
55
def rmsnorm(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
56
assert input_tensor.is_cuda and weight.is_cuda, "Tensors must be on CUDA device."
57
assert (
0 commit comments