Skip to content

Commit 6011775

Browse files
committed
[Test] Wrap softmax module
1 parent 3616431 commit 6011775

1 file changed

Lines changed: 11 additions & 2 deletions

File tree

tests/test_softmax.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,17 @@ def test_softmax(device, size=(128, 128), dim=1):
4242
#cpu_y = softmax3(x2, cpu_max, cpu_sum)
4343
#test_result("Softmax", y, cpu_y)
4444

45-
opt_fn = torch.compile(dynamic=False)(torch.nn.functional.softmax)
46-
y = opt_fn(x1, dim=dim)
45+
class SoftmaxModule(torch.nn.Module):
    """Module wrapper around ``torch.nn.functional.softmax``.

    Wrapping the functional op in an ``nn.Module`` lets it be handed to
    ``torch.compile`` as a module with a fixed reduction dimension.
    """

    def __init__(self, dim):
        super().__init__()
        # Dimension along which forward() normalizes.
        self.dim = dim

    def forward(self, x):
        # Delegate directly to the functional implementation.
        softmax = torch.nn.functional.softmax
        return softmax(x, dim=self.dim)
52+
53+
softmax_module = SoftmaxModule(dim=dim).to(device)
54+
opt_fn = torch.compile(dynamic=False)(softmax_module)
55+
y = opt_fn(x1)
4756
cpu_y = torch.nn.functional.softmax(x2, dim=dim)
4857
test_result("Softmax", y, cpu_y)
4958

0 commit comments

Comments
 (0)