Thank you for the nice implementation! It seems that `dim=192` is not in `supports_dim`. Why is that the case? Could you add support for `dim=192`?

I tried this script:
```python
import torch

from lightning_attn.ops import lightning_attn_func
from lightning_attn.utils import _build_slope_tensor

dtype = torch.bfloat16
device = torch.device("cuda")
# head dims d = e = 192 are what trigger the assertion below
b, h, n, d, e = 2, 12, 2048, 192, 192

q = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
k = torch.randn((b, h, n, d), dtype=dtype, device=device).requires_grad_()
v = torch.randn((b, h, n, e), dtype=dtype, device=device).requires_grad_()
s = _build_slope_tensor(h).to(q.device).to(torch.float32)

o = lightning_attn_func(q, k, v, s)
print(o.shape)
```
and got this error:

```
    o = lightning_attn_func(q, k, v, s)
  File "/opt/tiger/mariana/lightning-attention-main/lightning_attn/ops/lightning_attn_interface.py", line 10, in lightning_attn_func
    assert d in supports_dim and e in supports_dim
AssertionError
```
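In the meantime, a zero-padding workaround seems to work in principle. Below is a minimal sketch, assuming 256 is in `supports_dim` (I have not checked which sizes are actually supported); `lightning_attn_padded` and `target_dim` are my own names, not part of the library. Zero-padding `q` and `k` on the head dim leaves the `q @ k^T` scores unchanged, and the extra zero channels of `v` only produce extra output channels that can be sliced off.

```python
import torch
import torch.nn.functional as F

from lightning_attn.ops import lightning_attn_func


def lightning_attn_padded(q, k, v, s, target_dim=256):
    """Hypothetical workaround: run lightning_attn_func at a supported
    head dim by zero-padding, then slice the output back.

    target_dim=256 is an assumption; replace it with a size that is
    actually in supports_dim.
    """
    d, e = q.shape[-1], v.shape[-1]
    # F.pad pads the last dimension with (left, right) zeros
    q = F.pad(q, (0, target_dim - d))
    k = F.pad(k, (0, target_dim - d))
    v = F.pad(v, (0, target_dim - e))
    o = lightning_attn_func(q, k, v, s)
    # drop the output channels that came from the zero-padded part of v
    return o[..., :e]
```

Still, native support for `dim=192` would avoid the wasted compute and memory from padding, so it would be great to have it in `supports_dim` directly.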