64 changes: 44 additions & 20 deletions finegrained-fp8/tests/test_moe.py
@@ -48,6 +48,7 @@ def id(self):


PROBLEMS = [
# ── Small problems (correctness only, no speedup expectations) ──
Problem(
S=8,
E=4,
@@ -56,7 +57,6 @@ def id(self):
TOP_K=1,
scale_layout="block",
block_size=(128, 128),
expectation=Expectations(batched_ms=0.0346, grouped_ms=0.1291),
),
Problem(
S=32,
@@ -75,43 +75,66 @@
TOP_K=2,
scale_layout="block",
block_size=(128, 128),
expectation=Expectations(batched_ms=0.0344, grouped_ms=0.1283),
),
Problem(
S=128,
E=16,
N=1024,
K=2048,
TOP_K=2,
scale_layout="per_tensor_1d",
block_size=None,
),
Problem(
S=64,
E=8,
N=512,
K=1024,
TOP_K=4,
scale_layout="per_tensor_111",
block_size=None,
),
# ── Qwen3-30B-A3B (E=128, H=2048, I=768, top_k=8) ──
# gate_up: N=1536, K=2048 — down: N=2048, K=768
Problem(
S=256,
E=128,
N=1536,
K=2048,
TOP_K=8,
scale_layout="block",
block_size=(128, 128),
expectation=Expectations(batched_ms=0.0348, grouped_ms=0.1272),
expectation=Expectations(batched_ms=0.1641, grouped_ms=0.1596),
),
Problem(
S=128,
E=16,
N=1024,
K=2048,
TOP_K=2,
S=256,
E=128,
N=2048,
K=768,
TOP_K=8,
scale_layout="block",
block_size=(128, 128),
expectation=Expectations(batched_ms=0.0558, grouped_ms=0.1306),
expectation=Expectations(batched_ms=0.0956, grouped_ms=0.1582),
),
Problem(
S=128,
E=16,
N=1024,
S=1024,
E=128,
N=1536,
K=2048,
TOP_K=2,
scale_layout="per_tensor_1d",
TOP_K=8,
scale_layout="block",
block_size=(128, 128),
expectation=Expectations(batched_ms=0.5731, grouped_ms=0.1904),
),
Problem(
S=64,
E=8,
N=512,
K=1024,
TOP_K=4,
scale_layout="per_tensor_111",
S=1024,
E=128,
N=2048,
K=768,
TOP_K=8,
scale_layout="block",
block_size=(128, 128),
expectation=Expectations(batched_ms=0.3151, grouped_ms=0.1571),
),
]
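
For reference, the Qwen3-30B-A3B problem shapes above follow directly from the model's hidden size and per-expert intermediate size; a minimal sketch of the derivation, assuming the conventional fused gate+up weight layout:

H, I = 2048, 768                 # hidden size, per-expert intermediate size
gate_up = {"N": 2 * I, "K": H}   # fused gate+up projection -> N=1536, K=2048
down = {"N": H, "K": I}          # down projection          -> N=2048, K=768
assert (gate_up["N"], gate_up["K"]) == (1536, 2048)
assert (down["N"], down["K"]) == (2048, 768)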

@@ -347,6 +370,7 @@ def fn(A_sorted, B_fp8, Bs, offsets, tokens_per_expert):
offsets,
tokens_per_expert,
COMPILE_PROBLEM.block_size,
allow_sync=False,
)

compiled = torch.compile(fn, mode="max-autotune", fullgraph=True)
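The compiled path above passes allow_sync=False because the exact-grid variant reads the tile count back to the host, which torch.compile with fullgraph=True rejects by default. A standalone sketch of that failure mode (not part of the test file; exact behavior depends on the PyTorch version and dynamo config):

import torch

@torch.compile(fullgraph=True)
def needs_readback(x):
    # Data-dependent host readback: by default dynamo cannot capture .item()
    # into a full graph, which is the sync that allow_sync=False avoids.
    n = int(x.sum().item())
    return torch.zeros(n)

try:
    needs_readback(torch.ones(8))
except Exception as err:  # typically torch._dynamo.exc.Unsupported
    print(f"fullgraph capture rejected the readback: {type(err).__name__}")
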
16 changes: 16 additions & 0 deletions finegrained-fp8/torch-ext/finegrained_fp8/batched.py
@@ -19,6 +19,14 @@
from torch.library import triton_op, wrap_triton


@triton.autotune(
configs=[
triton.Config({}, num_warps=w, num_stages=s)
for w in [2, 4, 8, 16]
for s in [2, 3, 4, 5]
],
key=["N", "K"],
)
@triton.jit
def w8a8_block_fp8_matmul_batched_kernel(
A, # (S, K) raw BF16/FP16 activations
@@ -103,6 +111,14 @@ def w8a8_block_fp8_matmul_batched_kernel(
tl.store(c_ptrs, c)


@triton.autotune(
configs=[
triton.Config({}, num_warps=w, num_stages=s)
for w in [2, 4, 8, 16]
for s in [2, 3, 4, 5]
],
key=["N", "K"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_batched_kernel(
A, # (S, K) pre-quantized FP8 activations
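The two decorators added above sweep 4 warp counts by 4 pipeline depths, and key=["N", "K"] means the winning config is re-selected only when the weight shape changes and is cached afterwards. A small mirror of the declared search space (illustrative only):

import triton

configs = [
    triton.Config({}, num_warps=w, num_stages=s)
    for w in [2, 4, 8, 16]
    for s in [2, 3, 4, 5]
]
assert len(configs) == 16  # one of these is picked per distinct (N, K)
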
65 changes: 51 additions & 14 deletions finegrained-fp8/torch-ext/finegrained_fp8/grouped.py
@@ -19,6 +19,14 @@
from torch.library import triton_op, wrap_triton


@triton.autotune(
configs=[
triton.Config({}, num_warps=w, num_stages=s)
for w in [2, 4, 8, 16]
for s in [2, 3, 4, 5]
],
key=["N", "K"],
)
@triton.jit
def w8a8_block_fp8_matmul_grouped_kernel(
A, # (S, K) raw BF16/FP16 activations, sorted/grouped by expert id
@@ -118,9 +126,7 @@ def w8a8_block_fp8_matmul_grouped_kernel(
a = (a_raw / tl.maximum(a_s[:, None], 1e-12)).to(tl.float8e4nv)
# ---- same as baseline from here ----
b = tl.load(b_ptrs)
k_start = k * block_k
offs_ks = k_start // block_k
b_s = tl.load(Bs_ptrs + offs_ks * stride_Bsk)
b_s = tl.load(Bs_ptrs + k * stride_Bsk)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
a_ptrs += block_k * stride_ak
b_ptrs += block_k * stride_bk
@@ -137,6 +143,14 @@
tl.store(c_ptrs, c, mask=c_mask)

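The one-line change above drops a redundant round trip through k_start: the K-loop variable k already counts block_k-sized steps, so (k * block_k) // block_k is just k. A trivial check of the identity:

block_k = 128
for k in range(8):
    k_start = k * block_k
    offs_ks = k_start // block_k
    assert offs_ks == k  # the old scale index reduces to the new direct index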

@triton.autotune(
configs=[
triton.Config({}, num_warps=w, num_stages=s)
for w in [2, 4, 8, 16]
for s in [2, 3, 4, 5]
],
key=["N", "K"],
)
@triton.jit
def w8a8_tensor_fp8_matmul_grouped_kernel(
A, # (S, K) pre-quantized FP8 activations
@@ -162,6 +176,7 @@
block_k: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
NUM_EXPERTS: tl.constexpr,
NUM_EXPERTS_BIT_LENGTH: tl.constexpr,
):
"""Tensor-scale grouped FP8 expert matmul kernel.

@@ -177,7 +192,7 @@

lo = 0
hi = NUM_EXPERTS
for _ in tl.static_range(NUM_EXPERTS.bit_length()):
for _ in tl.static_range(NUM_EXPERTS_BIT_LENGTH):
mid = (lo + hi) >> 1
mid_val = tl.load(TileOffsets + mid)
is_left = mid_val <= pid_m
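
Passing NUM_EXPERTS_BIT_LENGTH from the host gives tl.static_range a plain compile-time integer instead of relying on a Python int method resolving on a tl.constexpr inside the jitted kernel. The trip count is sufficient because E.bit_length() is at least ceil(log2(E)), the number of halvings a binary search over E cumulative tile offsets needs; a small host-side check:

import math

for E in (4, 6, 8, 128):
    steps = math.ceil(math.log2(E))  # halvings needed to isolate one expert
    assert steps <= E.bit_length()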
@@ -241,6 +256,7 @@ def _w8a8_block_fp8_matmul_grouped(
offsets: torch.Tensor,
tokens_per_expert: torch.Tensor,
block_size: list[int] | None,
allow_sync: bool = False,
) -> torch.Tensor:
"""Internal block-scale grouped FP8 matmul op.

@@ -299,11 +315,15 @@
BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
# Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
# Using a static upper bound keeps the grid size data-independent, which is
# required for cuda-graph compatibility. Programs beyond the real tile count
# exit immediately via the early-return guard inside the kernel.
max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
if allow_sync:
# Exact tile count via CPU/GPU sync — no wasted programs.
max_M_tiles = int(tile_offsets[-1].item())
else:
# Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
# Using a static upper bound keeps the grid size data-independent, which is
# required for cuda-graph compatibility. Programs beyond the real tile count
# exit immediately via the early-return guard inside the kernel.
max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E

grid = (max_M_tiles, triton.cdiv(N, block_n))
wrap_triton(w8a8_block_fp8_matmul_grouped_kernel)[grid](
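
A minimal host-side sketch of the two grid-sizing strategies that allow_sync switches between (shapes are illustrative; in the real CUDA launch the .item() readback is a device sync):

import torch
import triton

S, E, BLOCK_SIZE_M = 256, 128, 16
# Random routing of S tokens across E experts (sums to S by construction).
tokens_per_expert = torch.bincount(torch.randint(0, E, (S,)), minlength=E)
tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
tile_offsets = torch.cumsum(tiles_per_expert, dim=0)

# allow_sync=True: exact program count, at the cost of a host readback.
exact = int(tile_offsets[-1].item())

# allow_sync=False: data-independent bound; surplus programs exit early in-kernel.
bound = triton.cdiv(S, BLOCK_SIZE_M) + E
assert exact <= bound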
@@ -345,6 +365,7 @@ def _w8a8_tensor_fp8_matmul_grouped(
offsets: torch.Tensor,
tokens_per_expert: torch.Tensor,
block_size: list[int] | None,
allow_sync: bool = False,
) -> torch.Tensor:
"""Tensor-scale grouped FP8 matmul for sorted routed experts.

@@ -387,7 +408,15 @@
BLOCK_SIZE_M = min(max(triton.next_power_of_2((S + E - 1) // E), 16), 128)
tiles_per_expert = (tokens_per_expert + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
tile_offsets = torch.cumsum(tiles_per_expert, dim=0).to(torch.int32)
max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E
if allow_sync:
# Exact tile count via CPU/GPU sync — no wasted programs.
max_M_tiles = int(tile_offsets[-1].item())
else:
# Upper bound on M-tiles: sum_e ceil(M_e / BLOCK_M) <= ceil(S / BLOCK_M) + E.
# Using a static upper bound keeps the grid size data-independent, which is
# required for cuda-graph compatibility. Programs beyond the real tile count
# exit immediately via the early-return guard inside the kernel.
max_M_tiles = triton.cdiv(S, BLOCK_SIZE_M) + E

qA, As = fp8_act_quant(A, K)
grid = (max_M_tiles, triton.cdiv(N, block_n))
@@ -415,6 +444,7 @@ def _w8a8_tensor_fp8_matmul_grouped(
block_k=block_k,
BLOCK_SIZE_M=BLOCK_SIZE_M,
NUM_EXPERTS=E,
NUM_EXPERTS_BIT_LENGTH=E.bit_length(),
)

return C
@@ -427,6 +457,7 @@ def w8a8_block_fp8_matmul_grouped(
offsets: torch.Tensor,
tokens_per_expert: torch.Tensor,
block_size: list[int] | None,
allow_sync: bool = False,
) -> torch.Tensor:
"""Grouped W8A8 FP8 matmul for MoE expert dispatch with fused activation quantization.

@@ -443,12 +474,15 @@
offsets: Cumulative token counts per expert ``[E]`` (i.e. ``cumsum(tokens_per_expert)``).
tokens_per_expert: Number of tokens routed to each expert ``[E]``.
block_size: ``[block_n, block_k]`` quantization block dimensions, e.g. ``[128, 128]``.
allow_sync: If True, read back the exact tile count from the GPU
(avoids wasted programs). If False (the default), use a data-independent
upper bound (required for CUDA-graph / torch.compile compatibility).

Returns:
Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
"""
return torch.ops.finegrained_fp8.w8a8_block_fp8_matmul_grouped(
A, B, Bs, offsets, tokens_per_expert, block_size
A, B, Bs, offsets, tokens_per_expert, block_size, allow_sync
)


@@ -459,6 +493,7 @@ def w8a8_tensor_fp8_matmul_grouped(
offsets: torch.Tensor,
tokens_per_expert: torch.Tensor,
block_size: list[int] | None,
allow_sync: bool = False,
) -> torch.Tensor:
"""Tensor-scale grouped W8A8 FP8 matmul for MoE expert dispatch.

@@ -469,12 +504,13 @@
offsets: Cumulative token counts per expert ``[E]``.
tokens_per_expert: Number of tokens routed to each expert ``[E]``.
block_size: Kept for API consistency; tensor path derives tile sizes from ``N`` and ``K``.
allow_sync: If True, sync for the exact grid; if False (the default), use the data-independent upper bound.

Returns:
Output tensor ``[S, N]`` in the same dtype as ``A``, in expert-sorted order.
"""
return torch.ops.finegrained_fp8.w8a8_tensor_fp8_matmul_grouped(
A, B, Bs, offsets, tokens_per_expert, block_size
A, B, Bs, offsets, tokens_per_expert, block_size, allow_sync
)


@@ -485,6 +521,7 @@ def w8a8_fp8_matmul_grouped(
offsets: torch.Tensor,
tokens_per_expert: torch.Tensor,
block_size: list[int] | None,
allow_sync: bool = False,
) -> torch.Tensor:
"""Unified grouped W8A8 FP8 matmul dispatcher.

@@ -500,9 +537,9 @@
block_size[0] == B.size(1) and block_size[1] == B.size(2)
):
return w8a8_tensor_fp8_matmul_grouped(
A, B, Bs, offsets, tokens_per_expert, block_size
A, B, Bs, offsets, tokens_per_expert, block_size, allow_sync
)

return w8a8_block_fp8_matmul_grouped(
A, B, Bs, offsets, tokens_per_expert, block_size
A, B, Bs, offsets, tokens_per_expert, block_size, allow_sync
)
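
A hedged usage sketch of the dispatcher on the block-scale path; the scale-tensor layout, dtypes, and import path are assumptions based on the docstrings in this diff, not a verified API contract:

import torch
from finegrained_fp8.grouped import w8a8_fp8_matmul_grouped  # module path as in this diff

S, E, N, K, blk = 256, 8, 1024, 2048, 128
A = torch.randn(S, K, device="cuda", dtype=torch.bfloat16)        # expert-sorted activations
B = torch.randn(E, N, K, device="cuda").to(torch.float8_e4m3fn)   # per-expert FP8 weights
Bs = torch.rand(E, N // blk, K // blk, device="cuda")             # assumed block-scale layout
tokens_per_expert = torch.full((E,), S // E, device="cuda", dtype=torch.int32)
offsets = torch.cumsum(tokens_per_expert, dim=0).to(torch.int32)

# Default allow_sync=False keeps the launch CUDA-graph / torch.compile friendly;
# eager callers can pass allow_sync=True to avoid launching surplus programs.
C = w8a8_fp8_matmul_grouped(A, B, Bs, offsets, tokens_per_expert, [blk, blk])
assert C.shape == (S, N) and C.dtype == A.dtype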