Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions lightning_attn/ops/triton/lightning_attn2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,22 @@
import triton
import triton.language as tl

import os

if "BLOCK_SIZE" not in os.environ:
os.environ["BLOCK_SIZE"] = "64"
else:
BLOCK_SIZE = int(os.environ["BLOCK_SIZE"])

if "CBLOCK_SIZE" not in os.environ:
os.environ["CBLOCK_SIZE"] = "32"
else:
CBLOCK_SIZE = int(os.environ["CBLOCK_SIZE"])


assert BLOCK_SIZE % 16 == 0, "BLOCK_SIZE must be multiple of 16"
assert CBLOCK_SIZE % 16 == 0, "CBLOCK_SIZE must be multiple of 16"


@triton.jit
def _fwd_kernel(
Expand Down Expand Up @@ -407,7 +423,7 @@ def forward(ctx, q, k, v, s):
e = v.shape[-1]
o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)

BLOCK = 64
BLOCK = BLOCK_SIZE
NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK)
# parallel over channel
BLOCK_MODEL = min(triton.next_power_of_2(e), 32)
Expand Down Expand Up @@ -451,10 +467,10 @@ def backward(ctx, do):
e = v.shape[-1]

# block size
BLOCK = 64
BLOCK = BLOCK_SIZE
NUM_BLOCK = triton.cdiv(n, BLOCK)
# compute block size
CBLOCK = 32
CBLOCK = CBLOCK_SIZE
NUM_CBLOCK = BLOCK // CBLOCK

# for intra part, compute in parallel
Expand Down