Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions test/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextlib
from datetime import timedelta
import functools
import os
import unittest

Expand All @@ -23,6 +24,7 @@
from helion._dist_utils import sync_seed
from helion._testing import EXAMPLES_DIR
from helion._testing import TestCase
from helion._testing import assert_close_with_mismatch_tolerance
from helion._testing import import_path
from helion._testing import onlyBackends
from helion._testing import skipIfRocm
Expand Down Expand Up @@ -429,14 +431,44 @@ def do_test_matmul_reduce_scatter(self, kernel, ref_kernel):
@skipIfRocm("Distributed example requires CUDA/NCCL")
@skipIfXPU("Distributed operations require CCL, not yet fully integrated")
@skip_if_lt_x_gpu(4)
def test_fp8_matmul_reduce_scatter(self):
@parametrize("autotuner", ["fixed", "LFBOTreeSearch"])
def test_fp8_matmul_reduce_scatter(self, autotuner):
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
self.skipTest("FP8 requires CUDA compute capability >= 9.0")
self._init_process()

mod = import_path(EXAMPLES_DIR / "distributed" / "fp8_matmul_reduce_scatter.py")

kernel = mod.fp8_matmul_reduce_scatter_kernel.fn
_SymmetricMemory.signal_pad_size = 1024 * 1024 * 16

accuracy_check_fn = functools.partial(
assert_close_with_mismatch_tolerance, **mod.tolerance
)

if autotuner == "fixed":
kernel = helion.kernel(
config=helion.Config(
block_sizes=[64, 64, 32],
num_warps=8,
num_stages=3,
),
static_shapes=True,
ignore_warnings=[helion.exc.TensorOperationInWrapper],
autotune_baseline_accuracy_check_fn=accuracy_check_fn,
)(kernel)
context = contextlib.nullcontext()
else:
kernel = helion.kernel(
kernel,
static_shapes=True,
ignore_warnings=[helion.exc.TensorOperationInWrapper],
autotune_baseline_accuracy_check_fn=accuracy_check_fn,
)
context = unittest.mock.patch.dict(
os.environ, {"HELION_AUTOTUNER": autotuner}
)

M, N, K = 512, 768, 1024

torch.manual_seed(42 + self.rank)
Expand All @@ -457,17 +489,18 @@ def test_fp8_matmul_reduce_scatter(self):
symm_mem_buffer = symm_mem.empty(M, N, dtype=torch.bfloat16, device=self.device)
symm_mem_hdl = symm_mem.rendezvous(symm_mem_buffer, dist.group.WORLD.group_name)

result = mod.fp8_matmul_reduce_scatter_kernel(
a,
b,
scale_a,
scale_b,
symm_mem_buffer,
symm_mem_hdl.signal_pad_ptrs_dev,
RANK=symm_mem_hdl.rank,
WORLD_SIZE=symm_mem_hdl.world_size,
GROUP_NAME=dist.group.WORLD.group_name,
)
with context:
result = kernel(
a,
b,
scale_a,
scale_b,
symm_mem_buffer,
symm_mem_hdl.signal_pad_ptrs_dev,
RANK=symm_mem_hdl.rank,
WORLD_SIZE=symm_mem_hdl.world_size,
GROUP_NAME=dist.group.WORLD.group_name,
)

expected = mod.reference_fp8_matmul_reduce_scatter(a, b, scale_a, scale_b)

Expand Down
Loading