Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions examples/fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# %%
from __future__ import annotations

import functools
import os
from typing import Callable

Expand Down Expand Up @@ -62,22 +63,24 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:

# %%
def reference_fp8_gemm_pytorch(
x_fp8: torch.Tensor, y_fp8: torch.Tensor
x_fp8: torch.Tensor,
y_fp8: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
) -> torch.Tensor:
"""
Reference implementation using torch._scaled_mm.
Args:
x_fp8 (torch.Tensor): Input tensor in FP8 format.
y_fp8 (torch.Tensor): Input tensor in FP8 format.
scale_a (torch.Tensor): Scale factor for x_fp8.
scale_b (torch.Tensor): Scale factor for y_fp8.
Returns:
torch.Tensor: Output tensor in half-precision format.
"""
# torch._scaled_mm requires column-major for second operand
y_fp8_t = y_fp8.T.contiguous().T
scale_a = torch.tensor(1.0, device=x_fp8.device)
scale_b = torch.tensor(1.0, device=x_fp8.device)
return torch._scaled_mm(
x_fp8, y_fp8_t, scale_a, scale_b, use_fast_accum=False, out_dtype=HALF_DTYPE
x_fp8, y_fp8, scale_a, scale_b, use_fast_accum=False, out_dtype=HALF_DTYPE
)


Expand Down Expand Up @@ -117,8 +120,16 @@ def check(m: int, k: int, n: int) -> None:
y = torch.randn([k, n], device=DEVICE, dtype=torch.float32)
# Convert to FP8 format (e4m3fn is commonly used for forward pass)
x_fp8 = x.to(torch.float8_e4m3fn)
y_fp8 = y.to(torch.float8_e4m3fn)
run_example(fp8_gemm, reference_fp8_gemm_pytorch, (x_fp8, y_fp8))
y_fp8 = y.to(torch.float8_e4m3fn).T.contiguous().T
Copy link
Copy Markdown
Contributor

@jansel jansel Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens if you remove this line? This patch is changing the layout used by the non-reference versions as well, so it is not apples-to-apples with the prior version.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing this line will cause torch._scaled_mm to fail. The kernel requires matrix B to be column-major.

I think the transpose call would usually be fused with preceding ops in practice? But I can check what it looks like in vLLM.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are two different kernels:

  1. fp8 gemm with both args contiguous
  2. fp8 gemm with second arg transposed

The issue is that eager mode has a kernel for 2, but no kernel for 1. If we are measuring 1, then you don't get to pre-compute anything -- doing it in eager mode requires two kernels, and we should measure both.


scale_a = torch.tensor(1.0, device=x_fp8.device)
scale_b = torch.tensor(1.0, device=x_fp8.device)

run_example(
fp8_gemm,
functools.partial(reference_fp8_gemm_pytorch, scale_a=scale_a, scale_b=scale_b),
(x_fp8, y_fp8),
)


# %%
Expand Down
6 changes: 4 additions & 2 deletions test/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,13 +301,15 @@ def test_fp8_gemm(self):

# Convert to FP8 format
x_fp8 = x.to(torch.float8_e4m3fn)
y_fp8 = y.to(torch.float8_e4m3fn)
y_fp8 = y.to(torch.float8_e4m3fn).T.contiguous().T

args = (x_fp8, y_fp8)

# Import the reference implementation
mod = import_path(EXAMPLES_DIR / "fp8_gemm.py")
expected = mod.reference_fp8_gemm_pytorch(x_fp8, y_fp8)
scale_a = torch.tensor(1.0, device=DEVICE)
scale_b = torch.tensor(1.0, device=DEVICE)
expected = mod.reference_fp8_gemm_pytorch(x_fp8, y_fp8, scale_a, scale_b)

check_example(
"fp8_gemm",
Expand Down
Loading