137 changes: 137 additions & 0 deletions helion/_dist_utils.py
@@ -1,8 +1,31 @@
from __future__ import annotations

import contextlib
from dataclasses import dataclass
import os
import random
from typing import Generator
from typing import TypeVar

import torch
from torch import Tensor
from torch._C._distributed_c10d import _SymmetricMemory
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import helion
from helion import exc

T = TypeVar("T")


def all_gather_object(obj: T) -> list[T]:
if not dist.is_initialized():
return [obj]

object_list = [None] * dist.get_world_size()
dist.all_gather_object(object_list, obj)
return object_list # pyrefly: ignore


def max_num_blocks_for_symm_mem() -> int:
@@ -13,3 +36,117 @@ def max_num_blocks_for_symm_mem() -> int:
assert dist.is_initialized()
signal_pad_size = _SymmetricMemory.signal_pad_size
return signal_pad_size // torch.int32.itemsize // dist.get_world_size()


def is_master_rank() -> bool:
"""
Return True for rank 0 in a distributed workload, or
always return True in a non-distributed workload.
"""
return not dist.is_initialized() or dist.get_rank() == 0


def is_symm_mem_tensor(t: Tensor) -> bool:
if not isinstance(t, Tensor) or not dist.is_initialized():
return False

# TODO(shunting): support group other than WORLD?
try:
assert dist.group.WORLD is not None
return symm_mem.rendezvous(t, group=dist.group.WORLD.group_name) is not None
except RuntimeError:
# PyTorch currently raises a RuntimeError if the tensor passed
# to rendezvous is not a symmetric-memory tensor
return False


def get_signal_pad_ptrs_dev(t: Tensor) -> int:
assert dist.group.WORLD is not None
hdl = symm_mem.rendezvous(t, group=dist.group.WORLD.group_name)
return hdl.signal_pad_ptrs_dev


def check_config_consistancy(config: helion.Config, print_config: bool = False) -> None:
"""
Check the consistency of configs across ranks.
"""
if (
os.getenv("HELION_DIST_CHECK_CONFIG_CONSISTANCY") != "1"
or not dist.is_initialized()
):
return

all_configs = [None] * dist.get_world_size()
dist.all_gather_object(all_configs, config)
if dist.get_rank() == 0:
# do the check on rank 0
if all_configs != all_configs[:1] * len(all_configs):
if print_config:
for idx, c in enumerate(all_configs):
print("FAIL", idx, c)
raise exc.InconsistantConfigsAcrossRanks
if print_config:
for idx, c in enumerate(all_configs):
print("PASS", idx, c)


def print_with_rank(*args: object, **kwargs: object) -> None:
if dist.is_initialized():
print(f"Rank{dist.get_rank()}: ", end="")
print(*args, **kwargs) # pyrefly: ignore[no-matching-overload]


@dataclass
class SeedEnsemble:
torch_seed: int
py_random_seed: int

@staticmethod
def get_seeds() -> SeedEnsemble:
"""
There is no way to get the current seed in PyTorch; we can only get
the initial seed.

This method instead derives fresh seeds by incrementing the
initial seed by 1.
"""
seed = torch.initial_seed()
return SeedEnsemble(
seed + 1,
seed + 1,
)

@staticmethod
def set_seeds(seeds: SeedEnsemble) -> None:
torch.manual_seed(seeds.torch_seed)
random.seed(seeds.py_random_seed)

@classmethod
def update_seeds_with_rank(cls) -> None:
seed = torch.initial_seed() + 1 + dist.get_rank()
cls.set_seeds(SeedEnsemble(seed, seed))


@contextlib.contextmanager
def sync_seed(need_diverse_seeds_after: bool = True) -> Generator[None, None, None]:
"""
Sync seeds across ranks.

If need_diverse_seeds_after is True, we make sure different
ranks have different seeds after the call. This ensures different
ranks can generate independent random tensors.
"""
if not dist.is_initialized():
yield
return

from helion._testing import sync_object

seeds = sync_object(SeedEnsemble.get_seeds())

try:
SeedEnsemble.set_seeds(seeds)
yield
finally:
if need_diverse_seeds_after:
SeedEnsemble.update_seeds_with_rank()
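For context on how the new seed helpers compose, here is a minimal sketch (assuming a two-rank torchrun job; the gloo backend and tensor shapes are my choices, not from this PR) of sync_seed keeping ranks aligned inside the context and diverging them afterwards:

# Run with: torchrun --nproc-per-node=2 seed_demo.py
import torch
import torch.distributed as dist

from helion._dist_utils import print_with_rank
from helion._dist_utils import sync_seed

dist.init_process_group("gloo")

with sync_seed():
    # Every rank was re-seeded from the same synced SeedEnsemble,
    # so this tensor is identical across ranks.
    shared = torch.randn(4)

# need_diverse_seeds_after defaults to True, so each rank now holds a
# rank-dependent seed (initial seed + 1 + rank) and draws independently.
independent = torch.randn(4)
print_with_rank("shared:", shared.tolist(), "independent:", independent.tolist())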
2 changes: 1 addition & 1 deletion helion/_testing.py
@@ -29,8 +29,8 @@
from ._compat import requires_torch_version
from ._compat import supports_amd_cdna_tunables
from ._compat import supports_tensor_descriptor
from ._dist_utils import is_master_rank
from ._utils import counters
from ._utils import is_master_rank
from .autotuner.benchmarking import sync_object as sync_object
from .runtime.settings import _get_backend
from helion.autotuner.base_search import _clone_args
138 changes: 1 addition & 137 deletions helion/_utils.py
@@ -1,29 +1,16 @@
from __future__ import annotations

import collections
import contextlib
from dataclasses import dataclass
import functools
import os
import random
from typing import Generator
from typing import Sequence
from typing import TypeVar

import torch
from torch import Tensor
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import helion
from helion import exc
T = TypeVar("T")

counters: collections.defaultdict[str, collections.Counter[str]] = (
collections.defaultdict(collections.Counter)
)

T = TypeVar("T")


def cdiv(a: int, b: int) -> int:
"""Ceiling division: returns ceil(a / b)."""
@@ -102,126 +89,3 @@ def _extract_slice(obj: object) -> object:
if isinstance(index, tuple):
return tuple(_extract_slice(idx) for idx in index)
return _extract_slice(index)


def is_master_rank() -> bool:
"""
Either return True for rank 0 in a distributed workload or
always return true for non-distributed workload.
"""
return not dist.is_initialized() or dist.get_rank() == 0


def is_symm_mem_tensor(t: Tensor) -> bool:
if not isinstance(t, Tensor) or not dist.is_initialized():
return False

# TODO(shunting): support group other than WORLD?
try:
assert dist.group.WORLD is not None
return symm_mem.rendezvous(t, group=dist.group.WORLD.group_name) is not None
except RuntimeError:
# PyTorch right now throws a RuntimeError if the tensor passed
# to rendezvious is not from symm-mem
return False


def get_signal_pad_ptrs_dev(t: Tensor) -> int:
assert dist.group.WORLD is not None
hdl = symm_mem.rendezvous(t, group=dist.group.WORLD.group_name)
return hdl.signal_pad_ptrs_dev


def check_config_consistancy(config: helion.Config, print_config: bool = False) -> None:
"""
Check the consistency of configs across ranks.
"""
if (
os.getenv("HELION_DIST_CHECK_CONFIG_CONSISTANCY") != "1"
or not dist.is_initialized()
):
return

all_configs = [None] * dist.get_world_size()
dist.all_gather_object(all_configs, config)
if dist.get_rank() == 0:
# do the check on rank 0
if all_configs != all_configs[:1] * len(all_configs):
if print_config:
for idx, c in enumerate(all_configs):
print("FAIL", idx, c)
raise exc.InconsistantConfigsAcrossRanks
if print_config:
for idx, c in enumerate(all_configs):
print("PASS", idx, c)


def print_with_rank(*args: object, **kwargs: object) -> None:
if dist.is_initialized():
print(f"Rank{dist.get_rank()}: ", end="")
print(*args, **kwargs) # pyrefly: ignore[no-matching-overload]


@dataclass
class SeedEnsemble:
torch_seed: int
py_random_seed: int

@staticmethod
def get_seeds() -> SeedEnsemble:
"""
There is no way to get current seed in PyTorch. We can only get
the initial seed.

This method instead re-initialize the seed by incrementing the
initial seed by 1
"""
seed = torch.initial_seed()
return SeedEnsemble(
seed + 1,
seed + 1,
)

@staticmethod
def set_seeds(seeds: SeedEnsemble) -> None:
torch.manual_seed(seeds.torch_seed)
random.seed(seeds.py_random_seed)

@classmethod
def update_seeds_with_rank(cls) -> None:
seed = torch.initial_seed() + 1 + dist.get_rank()
cls.set_seeds(SeedEnsemble(seed, seed))


@contextlib.contextmanager
def sync_seed(need_diverse_seeds_after: bool = True) -> Generator[None, None, None]:
"""
Sync seeds across ranks.

If need_diverse_seeds_after is True, we make sure different
ranks have different seeds after the call. This ensures different
rank can generate independent random tensors.
"""
if not dist.is_initialized():
yield
return

from helion._testing import sync_object

seeds = sync_object(SeedEnsemble.get_seeds())

try:
SeedEnsemble.set_seeds(seeds)
yield
finally:
if need_diverse_seeds_after:
SeedEnsemble.update_seeds_with_rank()


def all_gather_object(obj: T) -> list[T]:
if not dist.is_initialized():
return [obj]

object_list = [None] * dist.get_world_size()
dist.all_gather_object(object_list, obj)
return object_list # pyrefly: ignore
8 changes: 4 additions & 4 deletions helion/autotuner/base_search.py
@@ -66,10 +66,10 @@
from .metrics import AutotuneMetrics
from .metrics import _run_post_autotune_hooks
from .progress_bar import iter_with_progress
from helion._utils import all_gather_object
from helion._utils import get_signal_pad_ptrs_dev
from helion._utils import is_master_rank
from helion._utils import is_symm_mem_tensor
from helion._dist_utils import all_gather_object
from helion._dist_utils import get_signal_pad_ptrs_dev
from helion._dist_utils import is_master_rank
from helion._dist_utils import is_symm_mem_tensor

if TYPE_CHECKING:
from collections.abc import Sequence
2 changes: 1 addition & 1 deletion helion/autotuner/config_generation.py
@@ -12,7 +12,7 @@
from .config_fragment import Category
from .config_fragment import ConfigSpecFragment
from .config_fragment import PowerOfTwoFragment
from helion._utils import sync_seed
from helion._dist_utils import sync_seed

if TYPE_CHECKING:
from collections.abc import Mapping
2 changes: 1 addition & 1 deletion helion/autotuner/de_surrogate_hybrid.py
@@ -32,7 +32,7 @@

from .differential_evolution import DifferentialEvolutionSearch
from .effort_profile import DIFFERENTIAL_EVOLUTION_DEFAULTS
from helion._utils import sync_seed
from helion._dist_utils import sync_seed

if TYPE_CHECKING:
from collections.abc import Sequence
2 changes: 1 addition & 1 deletion helion/autotuner/differential_evolution.py
@@ -10,7 +10,7 @@
from .base_search import population_statistics
from .effort_profile import DIFFERENTIAL_EVOLUTION_DEFAULTS
from .pattern_search import InitialPopulationStrategy
from helion._utils import sync_seed
from helion._dist_utils import sync_seed

if TYPE_CHECKING:
from collections.abc import Iterator
2 changes: 1 addition & 1 deletion helion/autotuner/logger.py
@@ -28,7 +28,7 @@
from torch._inductor.runtime.triton_compat import OutOfResources
from torch._inductor.runtime.triton_compat import PTXASError

from helion._utils import is_master_rank
from helion._dist_utils import is_master_rank

if TYPE_CHECKING:
from _csv import _writer as CsvWriter
2 changes: 1 addition & 1 deletion helion/autotuner/progress_bar.py
@@ -18,7 +18,7 @@
from rich.text import Text
import torch

from helion._utils import is_master_rank
from helion._dist_utils import is_master_rank

if TYPE_CHECKING:
from collections.abc import Iterable
2 changes: 1 addition & 1 deletion helion/autotuner/surrogate_pattern_search.py
@@ -13,7 +13,7 @@
from .effort_profile import PATTERN_SEARCH_DEFAULTS
from .pattern_search import InitialPopulationStrategy
from .pattern_search import PatternSearch
from helion._utils import sync_seed
from helion._dist_utils import sync_seed

if TYPE_CHECKING:
from collections.abc import Iterator
2 changes: 1 addition & 1 deletion helion/runtime/kernel.py
@@ -46,8 +46,8 @@
from .._compiler.output_header import assert_no_conflicts
from .._compiler.output_header import get_needed_imports
from .._compiler.variable_origin import ArgumentOrigin
from .._dist_utils import check_config_consistancy as dist_check_config_consistancy
from .._logging import LazyString
from .._utils import check_config_consistancy as dist_check_config_consistancy
from .._utils import counters
from ..autotuner.base_search import _AutotunableKernel
from ..language.constexpr import ConstExpr