Skip to content
1 change: 0 additions & 1 deletion backends/cuda/benchmarks/benchmark_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import torch
import torch.nn.functional as F

from executorch.backends.cuda.triton.kernels.sdpa import (
sdpa as triton_sdpa,
sdpa_decode_splitk as triton_splitk,
Expand Down
90 changes: 33 additions & 57 deletions backends/cuda/triton/kernels/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,7 +1080,6 @@ def _sdpa_decode_splitk_kernel(
K_ptr,
V_ptr,
O_partial_ptr,
M_partial_ptr,
L_partial_ptr,
Mask_ptr,
B,
Expand All @@ -1102,13 +1101,14 @@ def _sdpa_decode_splitk_kernel(
stride_op_b,
stride_op_h,
stride_op_d,
stride_mp_s,
stride_mp_b,
stride_mp_h,
stride_lp_s,
stride_lp_b,
stride_lp_h,
stride_mb,
stride_mq,
stride_mk,
sm_scale: tl.float32,
phi: tl.float32,
chunk_size,
HAS_MASK: tl.constexpr,
BLOCK_N: tl.constexpr,
Expand Down Expand Up @@ -1138,7 +1138,7 @@ def _sdpa_decode_splitk_kernel(
)
q = tl.load(q_ptrs, mask=g_valid[:, None], other=0.0).to(tl.bfloat16)

m_i = tl.full([BLOCK_G], -float("inf"), dtype=tl.float32)
# FlashDecoding++ async softmax: use unified max phi instead of tracking m_i
l_i = tl.zeros([BLOCK_G], dtype=tl.float32)
acc = tl.zeros([BLOCK_G, HEAD_DIM], dtype=tl.float32)

Expand Down Expand Up @@ -1175,15 +1175,10 @@ def _sdpa_decode_splitk_kernel(
mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32)
)

# Online softmax update
m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32))
safe_diff = tl.where(
m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf")
)
# FlashDecoding++ async softmax: subtract unified phi instead of local max
safe_diff = tl.where(qk > -float("inf"), qk - phi, -float("inf"))
p_f32 = tl.exp(safe_diff).to(tl.float32)
l_ij = tl.sum(p_f32, axis=1).to(tl.float32)
safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0)
alpha = tl.exp(safe_alpha_diff).to(tl.float32)

v_ptrs = V_ptr + (
b * stride_vb
Expand All @@ -1194,9 +1189,8 @@ def _sdpa_decode_splitk_kernel(
v = tl.load(v_ptrs, mask=n_valid[:, None], other=0.0).to(tl.bfloat16)

p_bf16 = p_f32.to(tl.bfloat16)
acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32)
l_i = (l_i * alpha + l_ij).to(tl.float32)
m_i = m_ij
acc = (acc + tl.dot(p_bf16, v)).to(tl.float32)
l_i = (l_i + l_ij).to(tl.float32)

# Store partial results for valid groups only
h_q_all = h_kv * NUM_GROUPS + offs_g # [BLOCK_G]
Expand All @@ -1208,31 +1202,25 @@ def _sdpa_decode_splitk_kernel(
)
tl.store(o_ptrs, acc, mask=g_valid[:, None])

ml_ptrs = M_partial_ptr + (
split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h
)
tl.store(ml_ptrs, m_i, mask=g_valid)

ll_ptrs = L_partial_ptr + (
split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h
split_id * stride_lp_s + b * stride_lp_b + h_q_all * stride_lp_h
)
tl.store(ll_ptrs, l_i, mask=g_valid)


@triton.jit
def _sdpa_decode_reduce_kernel(
O_partial_ptr,
M_partial_ptr,
L_partial_ptr,
O_ptr,
num_splits,
stride_op_s,
stride_op_b,
stride_op_h,
stride_op_d,
stride_mp_s,
stride_mp_b,
stride_mp_h,
stride_lp_s,
stride_lp_b,
stride_lp_h,
stride_ob,
stride_oh,
stride_om,
Expand All @@ -1242,40 +1230,25 @@ def _sdpa_decode_reduce_kernel(
pid = tl.program_id(axis=0)
offs_d = tl.arange(0, HEAD_DIM)

# pid indexes into flattened (B, H_q). Partial buffers are allocated
# contiguous in _launch_decode_splitk, so pid * stride_*_h is valid.
# Find global max across all splits
m_global = tl.full([1], -float("inf"), dtype=tl.float32)
for s in tl.range(0, num_splits):
m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h
m_s = tl.load(m_ptr)
m_global = tl.maximum(m_global, m_s)

# Accumulate rescaled outputs
# FlashDecoding++ async softmax: no rescaling needed, just sum partials
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
l_global = tl.zeros([1], dtype=tl.float32)

for s in tl.range(0, num_splits):
m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h
l_ptr = L_partial_ptr + s * stride_mp_s + pid * stride_mp_h
l_ptr = L_partial_ptr + s * stride_lp_s + pid * stride_lp_h
o_ptrs = O_partial_ptr + (
s * stride_op_s + pid * stride_op_h + offs_d * stride_op_d
)

m_s = tl.load(m_ptr)
l_s = tl.load(l_ptr)
o_s = tl.load(o_ptrs)

safe_diff = tl.where(m_global > -float("inf"), m_s - m_global, 0.0)
scale = tl.exp(safe_diff).to(tl.float32)
acc += o_s * scale
l_global += l_s * scale
acc += o_s
l_global += l_s

inv_l = tl.where(l_global > 0, 1.0 / l_global, 0.0)
acc = acc * inv_l

# pid = b*H_q + h_q. For contiguous output [B, H_q, 1, D] with L_q=1,
# stride_ob == H_q * stride_oh, so pid * stride_oh is correct.
# This relies on `out` being freshly allocated and contiguous.
o_out_ptrs = O_ptr + pid * stride_oh + offs_d * stride_od
tl.store(o_out_ptrs, acc.to(tl.bfloat16))

Expand All @@ -1297,16 +1270,14 @@ def _launch_decode_splitk(
stride_mq: int,
stride_mk: int,
num_groups: int,
phi: float,
) -> None:
num_splits = min(max(triton.cdiv(L_kv, 256), 1), 128)
chunk_size = triton.cdiv(L_kv, num_splits)

O_partial = torch.empty(
(num_splits, B, H_q, D), device=query.device, dtype=torch.float32
)
M_partial = torch.full(
(num_splits, B, H_q), -float("inf"), device=query.device, dtype=torch.float32
)
L_partial = torch.zeros(
(num_splits, B, H_q), device=query.device, dtype=torch.float32
)
Expand All @@ -1316,15 +1287,14 @@ def _launch_decode_splitk(
stride_vb, stride_vh, stride_vn, stride_vd = value.stride()
stride_ob, stride_oh, stride_om, stride_od = out.stride()
stride_op_s, stride_op_b, stride_op_h, stride_op_d = O_partial.stride()
stride_mp_s, stride_mp_b, stride_mp_h = M_partial.stride()
stride_lp_s, stride_lp_b, stride_lp_h = L_partial.stride()

grid_split = (num_splits, B * H_kv)
wrap_triton(_sdpa_decode_splitk_kernel)[grid_split](
query,
key,
value,
O_partial,
M_partial,
L_partial,
Mask_ptr if HAS_MASK else 0,
B,
Expand All @@ -1346,13 +1316,14 @@ def _launch_decode_splitk(
stride_op_b,
stride_op_h,
stride_op_d,
stride_mp_s,
stride_mp_b,
stride_mp_h,
stride_lp_s,
stride_lp_b,
stride_lp_h,
stride_mb,
stride_mq,
stride_mk,
sm_scale,
phi,
chunk_size,
HAS_MASK=HAS_MASK,
HEAD_DIM=D,
Expand All @@ -1363,17 +1334,16 @@ def _launch_decode_splitk(
grid_reduce = (B * H_q,)
wrap_triton(_sdpa_decode_reduce_kernel)[grid_reduce](
O_partial,
M_partial,
L_partial,
out,
num_splits,
stride_op_s,
stride_op_b,
stride_op_h,
stride_op_d,
stride_mp_s,
stride_mp_b,
stride_mp_h,
stride_lp_s,
stride_lp_b,
stride_lp_h,
stride_ob,
stride_oh,
stride_om,
Expand All @@ -1394,9 +1364,13 @@ def sdpa_decode_splitk(
is_causal: bool = False,
scale: float = 0.0,
enable_gqa: bool = False,
phi: float = 5.0,
) -> torch.Tensor:
"""Split-K flash-decoding SDPA for L_q=1 (decode step).

Uses FlashDecoding++ async softmax with unified maximum value (phi)
to eliminate per-split max tracking and cross-split rescaling.

Signature mirrors sdpa() for drop-in use with torch.cond dispatch.
enable_gqa is accepted but ignored — GQA is handled natively via
H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1.
Expand Down Expand Up @@ -1452,6 +1426,7 @@ def sdpa_decode_splitk(
stride_mq,
stride_mk,
num_groups,
phi,
)
return out

Expand All @@ -1466,6 +1441,7 @@ def _sdpa_decode_splitk_abstract(
is_causal: bool = False,
scale: float = 0.0,
enable_gqa: bool = False,
phi: float = 5.0,
) -> torch.Tensor:
assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype"
B, H_q, L_q, D = query.shape
Expand Down
1 change: 0 additions & 1 deletion examples/models/qwen3_5_moe/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import torch
import torch.nn as nn

from torch.nn import functional as F


Expand Down
Loading