diff --git a/examples/fp8_gemm.py b/examples/fp8_gemm.py index 73bd94308..2ed74b297 100644 --- a/examples/fp8_gemm.py +++ b/examples/fp8_gemm.py @@ -10,6 +10,7 @@ # %% from __future__ import annotations +import functools import os from typing import Callable @@ -62,22 +63,24 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: # %% def reference_fp8_gemm_pytorch( - x_fp8: torch.Tensor, y_fp8: torch.Tensor + x_fp8: torch.Tensor, + y_fp8: torch.Tensor, + scale_a: torch.Tensor, + scale_b: torch.Tensor, ) -> torch.Tensor: """ Reference implementation using torch._scaled_mm. Args: x_fp8 (torch.Tensor): Input tensor in FP8 format. y_fp8 (torch.Tensor): Input tensor in FP8 format. + scale_a (torch.Tensor): Scale factor for x_fp8. + scale_b (torch.Tensor): Scale factor for y_fp8. Returns: torch.Tensor: Output tensor in half-precision format. """ # torch._scaled_mm requires column-major for second operand - y_fp8_t = y_fp8.T.contiguous().T - scale_a = torch.tensor(1.0, device=x_fp8.device) - scale_b = torch.tensor(1.0, device=x_fp8.device) return torch._scaled_mm( - x_fp8, y_fp8_t, scale_a, scale_b, use_fast_accum=False, out_dtype=HALF_DTYPE + x_fp8, y_fp8, scale_a, scale_b, use_fast_accum=False, out_dtype=HALF_DTYPE ) @@ -117,8 +120,16 @@ def check(m: int, k: int, n: int) -> None: y = torch.randn([k, n], device=DEVICE, dtype=torch.float32) # Convert to FP8 format (e4m3fn is commonly used for forward pass) x_fp8 = x.to(torch.float8_e4m3fn) - y_fp8 = y.to(torch.float8_e4m3fn) - run_example(fp8_gemm, reference_fp8_gemm_pytorch, (x_fp8, y_fp8)) + y_fp8 = y.to(torch.float8_e4m3fn).T.contiguous().T + + scale_a = torch.tensor(1.0, device=x_fp8.device) + scale_b = torch.tensor(1.0, device=x_fp8.device) + + run_example( + fp8_gemm, + functools.partial(reference_fp8_gemm_pytorch, scale_a=scale_a, scale_b=scale_b), + (x_fp8, y_fp8), + ) # %% diff --git a/test/test_examples.py b/test/test_examples.py index 1017bc59b..fa5f6414b 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -301,13 +301,15 @@ def test_fp8_gemm(self): # Convert to FP8 format x_fp8 = x.to(torch.float8_e4m3fn) - y_fp8 = y.to(torch.float8_e4m3fn) + y_fp8 = y.to(torch.float8_e4m3fn).T.contiguous().T args = (x_fp8, y_fp8) # Import the reference implementation mod = import_path(EXAMPLES_DIR / "fp8_gemm.py") - expected = mod.reference_fp8_gemm_pytorch(x_fp8, y_fp8) + scale_a = torch.tensor(1.0, device=DEVICE) + scale_b = torch.tensor(1.0, device=DEVICE) + expected = mod.reference_fp8_gemm_pytorch(x_fp8, y_fp8, scale_a, scale_b) check_example( "fp8_gemm",