"""Naive Triton softmax (softmax_naive.py).

Each kernel program instance reduces one full row of an (M, N) input in a
single tile, so the whole row must fit in one power-of-two tile.  The file
also contains a pytest correctness check against ``torch.softmax`` and a
simple memory-throughput benchmark comparing the two implementations.
"""
import triton
from triton import language as tl
import torch
@triton.jit
def softmax_kernel(
    output_ptr,
    input_ptr,
    M,
    N,
    TILE_N: tl.constexpr,
):
    # Row-wise numerically stable softmax: one program instance handles one
    # full row of the (M, N) input.  TILE_N is a compile-time power of two
    # >= N, so a single tile covers the whole row (hence "naive").
    pid_m = tl.program_id(0)  # row index handled by this program
    n_offsets = tl.arange(0, TILE_N)
    # Assumes row-major contiguous layout: element (m, n) lives at m*N + n.
    offset = pid_m * N + n_offsets
    input_ptrs = input_ptr + offset
    mask = n_offsets < N  # guard the tail lanes when TILE_N > N
    # Masked lanes read -inf so they contribute exp(-inf) == 0 to the sum.
    # The load is cast to the output element type before reduction.
    inp = tl.load(input_ptrs, mask=mask, other=-float("inf")).to(output_ptr.dtype.element_ty)
    m = tl.max(inp, 0)  # row max, subtracted for numerical stability
    e = tl.exp(inp - m)
    z = tl.sum(e, 0)  # row normalizer
    out = e / z
    output_ptrs = output_ptr + offset
    tl.store(output_ptrs, out, mask=mask)
def softmax(x):
    """Row-wise softmax over the last dimension of a 2-D CUDA tensor.

    Args:
        x: tensor of shape (M, N) on a CUDA device.

    Returns:
        A tensor of the same shape and dtype with softmax applied along
        dim -1.

    Raises:
        AssertionError: if ``x`` is not 2-D.
    """
    assert x.ndim == 2, "expected a 2-D (M, N) tensor"
    # The kernel addresses rows as pid_m * N + n, which is only valid for a
    # row-major contiguous layout; a non-contiguous input would silently
    # produce wrong results otherwise.
    x = x.contiguous()
    M, N = x.shape
    out = torch.empty_like(x)
    # One tile must cover an entire row, and tl.arange requires a power of
    # two, so round N up.
    TILE_N = triton.next_power_of_2(N)
    grid = (M, 1, 1)
    softmax_kernel[grid](out, x, M, N, TILE_N)
    return out
import pytest
@pytest.mark.parametrize("n", [512, 1024, 32 * 1024, 128 * 1024])
@pytest.mark.parametrize("m", [10, 128, 1024])
def test_softmax(m, n):
    """Check the Triton softmax against torch.softmax on random input."""
    inputs = torch.randn((m, n), device="cuda")
    actual = softmax(inputs)
    expected = torch.softmax(inputs, dim=-1)
    torch.testing.assert_close(actual, expected)
def benchmark_softmax(m, n):
    """Measure median memory throughput of the naive vs torch softmax.

    Returns a ``(naive_gbps, torch_gbps)`` pair for an (m, n) random
    CUDA input.
    """
    x = torch.randn((m, n), device="cuda")

    def throughput(ms):
        # Each element is read once and written once, hence the factor 2.
        bytes_moved = x.numel() * x.element_size() * 2
        return bytes_moved * 1e-9 / (ms * 1e-3)

    naive_ms = triton.testing.do_bench(lambda: softmax(x), return_mode="median")
    torch_ms = triton.testing.do_bench(
        lambda: torch.softmax(x, dim=-1), return_mode="median"
    )
    return throughput(naive_ms), throughput(torch_ms)
import pandas as pd
def run_benchmark():
    """Sweep (m, n) sizes, print a throughput table, and save it to Excel."""
    pre_sizes = [10, 128, 1024, 4096]
    reduce_sizes = [
        512, 1024, 2048, 4096, 8192, 16 * 1024, 32 * 1024, 64 * 1024, 128 * 1024,
    ]
    records = [
        (m, n, *benchmark_softmax(m, n))
        for m in pre_sizes
        for n in reduce_sizes
    ]
    df = pd.DataFrame.from_records(
        records, columns=["pre_size", "reduce_size", "naive", "torch"]
    )
    print(df)
    # NOTE(review): to_excel requires openpyxl at runtime — confirm it is
    # installed in the target environment.
    df.to_excel("naive.xlsx")
def run_an_example(m, n):
    """Run the Triton softmax once on an (m, n) random CUDA tensor."""
    sample = torch.randn((m, n), device="cuda")
    softmax(sample)


if __name__ == "__main__":
    run_an_example(4096, 4 * 1024)