diff --git a/backends/cuda/benchmarks/benchmark_sdpa.py b/backends/cuda/benchmarks/benchmark_sdpa.py index 3c117f4574f..ce11d78ce6d 100644 --- a/backends/cuda/benchmarks/benchmark_sdpa.py +++ b/backends/cuda/benchmarks/benchmark_sdpa.py @@ -21,7 +21,6 @@ import torch import torch.nn.functional as F - from executorch.backends.cuda.triton.kernels.sdpa import ( sdpa as triton_sdpa, sdpa_decode_splitk as triton_splitk, diff --git a/backends/cuda/triton/kernels/sdpa.py b/backends/cuda/triton/kernels/sdpa.py index 4ee169405b5..9f42a474b36 100644 --- a/backends/cuda/triton/kernels/sdpa.py +++ b/backends/cuda/triton/kernels/sdpa.py @@ -1080,7 +1080,6 @@ def _sdpa_decode_splitk_kernel( K_ptr, V_ptr, O_partial_ptr, - M_partial_ptr, L_partial_ptr, Mask_ptr, B, @@ -1102,13 +1101,14 @@ def _sdpa_decode_splitk_kernel( stride_op_b, stride_op_h, stride_op_d, - stride_mp_s, - stride_mp_b, - stride_mp_h, + stride_lp_s, + stride_lp_b, + stride_lp_h, stride_mb, stride_mq, stride_mk, sm_scale: tl.float32, + phi: tl.float32, chunk_size, HAS_MASK: tl.constexpr, BLOCK_N: tl.constexpr, @@ -1138,7 +1138,7 @@ def _sdpa_decode_splitk_kernel( ) q = tl.load(q_ptrs, mask=g_valid[:, None], other=0.0).to(tl.bfloat16) - m_i = tl.full([BLOCK_G], -float("inf"), dtype=tl.float32) + # FlashDecoding++ async softmax: use unified max phi instead of tracking m_i l_i = tl.zeros([BLOCK_G], dtype=tl.float32) acc = tl.zeros([BLOCK_G, HEAD_DIM], dtype=tl.float32) @@ -1175,15 +1175,10 @@ def _sdpa_decode_splitk_kernel( mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32) ) - # Online softmax update - m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32)) - safe_diff = tl.where( - m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf") - ) + # FlashDecoding++ async softmax: subtract unified phi instead of local max + safe_diff = tl.where(qk > -float("inf"), qk - phi, -float("inf")) p_f32 = tl.exp(safe_diff).to(tl.float32) l_ij = tl.sum(p_f32, axis=1).to(tl.float32) - safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0) - alpha = tl.exp(safe_alpha_diff).to(tl.float32) v_ptrs = V_ptr + ( b * stride_vb @@ -1194,9 +1189,8 @@ def _sdpa_decode_splitk_kernel( v = tl.load(v_ptrs, mask=n_valid[:, None], other=0.0).to(tl.bfloat16) p_bf16 = p_f32.to(tl.bfloat16) - acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32) - l_i = (l_i * alpha + l_ij).to(tl.float32) - m_i = m_ij + acc = (acc + tl.dot(p_bf16, v)).to(tl.float32) + l_i = (l_i + l_ij).to(tl.float32) # Store partial results for valid groups only h_q_all = h_kv * NUM_GROUPS + offs_g # [BLOCK_G] @@ -1208,13 +1202,8 @@ def _sdpa_decode_splitk_kernel( ) tl.store(o_ptrs, acc, mask=g_valid[:, None]) - ml_ptrs = M_partial_ptr + ( - split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h - ) - tl.store(ml_ptrs, m_i, mask=g_valid) - ll_ptrs = L_partial_ptr + ( - split_id * stride_mp_s + b * stride_mp_b + h_q_all * stride_mp_h + split_id * stride_lp_s + b * stride_lp_b + h_q_all * stride_lp_h ) tl.store(ll_ptrs, l_i, mask=g_valid) @@ -1222,7 +1211,6 @@ def _sdpa_decode_splitk_kernel( @triton.jit def _sdpa_decode_reduce_kernel( O_partial_ptr, - M_partial_ptr, L_partial_ptr, O_ptr, num_splits, @@ -1230,9 +1218,9 @@ def _sdpa_decode_reduce_kernel( stride_op_b, stride_op_h, stride_op_d, - stride_mp_s, - stride_mp_b, - stride_mp_h, + stride_lp_s, + stride_lp_b, + stride_lp_h, stride_ob, stride_oh, stride_om, @@ -1242,40 +1230,25 @@ def _sdpa_decode_reduce_kernel( pid = tl.program_id(axis=0) offs_d = tl.arange(0, HEAD_DIM) - # pid indexes into flattened (B, H_q). Partial buffers are allocated - # contiguous in _launch_decode_splitk, so pid * stride_*_h is valid. - # Find global max across all splits - m_global = tl.full([1], -float("inf"), dtype=tl.float32) - for s in tl.range(0, num_splits): - m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h - m_s = tl.load(m_ptr) - m_global = tl.maximum(m_global, m_s) - - # Accumulate rescaled outputs + # FlashDecoding++ async softmax: no rescaling needed, just sum partials acc = tl.zeros([HEAD_DIM], dtype=tl.float32) l_global = tl.zeros([1], dtype=tl.float32) + for s in tl.range(0, num_splits): - m_ptr = M_partial_ptr + s * stride_mp_s + pid * stride_mp_h - l_ptr = L_partial_ptr + s * stride_mp_s + pid * stride_mp_h + l_ptr = L_partial_ptr + s * stride_lp_s + pid * stride_lp_h o_ptrs = O_partial_ptr + ( s * stride_op_s + pid * stride_op_h + offs_d * stride_op_d ) - m_s = tl.load(m_ptr) l_s = tl.load(l_ptr) o_s = tl.load(o_ptrs) - safe_diff = tl.where(m_global > -float("inf"), m_s - m_global, 0.0) - scale = tl.exp(safe_diff).to(tl.float32) - acc += o_s * scale - l_global += l_s * scale + acc += o_s + l_global += l_s inv_l = tl.where(l_global > 0, 1.0 / l_global, 0.0) acc = acc * inv_l - # pid = b*H_q + h_q. For contiguous output [B, H_q, 1, D] with L_q=1, - # stride_ob == H_q * stride_oh, so pid * stride_oh is correct. - # This relies on `out` being freshly allocated and contiguous. o_out_ptrs = O_ptr + pid * stride_oh + offs_d * stride_od tl.store(o_out_ptrs, acc.to(tl.bfloat16)) @@ -1297,6 +1270,7 @@ def _launch_decode_splitk( stride_mq: int, stride_mk: int, num_groups: int, + phi: float, ) -> None: num_splits = min(max(triton.cdiv(L_kv, 256), 1), 128) chunk_size = triton.cdiv(L_kv, num_splits) @@ -1304,9 +1278,6 @@ def _launch_decode_splitk( O_partial = torch.empty( (num_splits, B, H_q, D), device=query.device, dtype=torch.float32 ) - M_partial = torch.full( - (num_splits, B, H_q), -float("inf"), device=query.device, dtype=torch.float32 - ) L_partial = torch.zeros( (num_splits, B, H_q), device=query.device, dtype=torch.float32 ) @@ -1316,7 +1287,7 @@ def _launch_decode_splitk( stride_vb, stride_vh, stride_vn, stride_vd = value.stride() stride_ob, stride_oh, stride_om, stride_od = out.stride() stride_op_s, stride_op_b, stride_op_h, stride_op_d = O_partial.stride() - stride_mp_s, stride_mp_b, stride_mp_h = M_partial.stride() + stride_lp_s, stride_lp_b, stride_lp_h = L_partial.stride() grid_split = (num_splits, B * H_kv) wrap_triton(_sdpa_decode_splitk_kernel)[grid_split]( @@ -1324,7 +1295,6 @@ def _launch_decode_splitk( key, value, O_partial, - M_partial, L_partial, Mask_ptr if HAS_MASK else 0, B, @@ -1346,13 +1316,14 @@ def _launch_decode_splitk( stride_op_b, stride_op_h, stride_op_d, - stride_mp_s, - stride_mp_b, - stride_mp_h, + stride_lp_s, + stride_lp_b, + stride_lp_h, stride_mb, stride_mq, stride_mk, sm_scale, + phi, chunk_size, HAS_MASK=HAS_MASK, HEAD_DIM=D, @@ -1363,7 +1334,6 @@ def _launch_decode_splitk( grid_reduce = (B * H_q,) wrap_triton(_sdpa_decode_reduce_kernel)[grid_reduce]( O_partial, - M_partial, L_partial, out, num_splits, @@ -1371,9 +1341,9 @@ def _launch_decode_splitk( stride_op_b, stride_op_h, stride_op_d, - stride_mp_s, - stride_mp_b, - stride_mp_h, + stride_lp_s, + stride_lp_b, + stride_lp_h, stride_ob, stride_oh, stride_om, @@ -1394,9 +1364,13 @@ def sdpa_decode_splitk( is_causal: bool = False, scale: float = 0.0, enable_gqa: bool = False, + phi: float = 5.0, ) -> torch.Tensor: """Split-K flash-decoding SDPA for L_q=1 (decode step). + Uses FlashDecoding++ async softmax with unified maximum value (phi) + to eliminate per-split max tracking and cross-split rescaling. + Signature mirrors sdpa() for drop-in use with torch.cond dispatch. enable_gqa is accepted but ignored — GQA is handled natively via H_q // H_kv grouping; no packed-GQA tradeoff exists at L_q=1. @@ -1452,6 +1426,7 @@ def sdpa_decode_splitk( stride_mq, stride_mk, num_groups, + phi, ) return out @@ -1466,6 +1441,7 @@ def _sdpa_decode_splitk_abstract( is_causal: bool = False, scale: float = 0.0, enable_gqa: bool = False, + phi: float = 5.0, ) -> torch.Tensor: assert query.dtype == key.dtype == value.dtype, "Q, K, V must have the same dtype" B, H_q, L_q, D = query.shape diff --git a/examples/models/qwen3_5_moe/model.py b/examples/models/qwen3_5_moe/model.py index a801c2908bd..fd0f7060a1b 100644 --- a/examples/models/qwen3_5_moe/model.py +++ b/examples/models/qwen3_5_moe/model.py @@ -20,7 +20,6 @@ import torch import torch.nn as nn - from torch.nn import functional as F