Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion helion/_compiler/compile_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import typing
from typing import TYPE_CHECKING
from typing import Protocol
import warnings

import sympy
import torch
Expand All @@ -22,6 +23,7 @@
from torch._inductor.runtime.runtime_utils import next_power_of_2
from torch._subclasses import FakeTensor
from torch._subclasses import FakeTensorMode
import torch.distributed as dist
from torch.fx.experimental.symbolic_shapes import DimDynamic
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
Expand Down Expand Up @@ -162,9 +164,30 @@ def __init__(
self.device_load_count = (
0 # Track number of loads in all device code for eviction policy tuning
)
if settings.autotune_force_persistent:
if settings.autotune_force_persistent or dist.is_initialized():
for pid_type in ("flat", "xyz"):
self.config_spec.disallow_pid_type(pid_type)

if dist.is_initialized():
from torch._C._distributed_c10d import _SymmetricMemory

from .._dist_utils import max_num_blocks_for_symm_mem
from ..runtime import get_num_sm

num_sms = get_num_sm(device, reserved_sms=settings.persistent_reserved_sms)
# Floor to previous power of two since PowerOfTwoFragment requires pow2 bounds
raw_max = min(
max_num_blocks_for_symm_mem() // num_sms,
self.config_spec.max_num_sm_multiplier,
)
newmax = max(1, 1 << (raw_max.bit_length() - 1))
if newmax < self.config_spec.max_num_sm_multiplier:
warnings.warn(
f"max_num_sm_multipler is reduced from {self.config_spec.max_num_sm_multiplier} to {newmax} due to the restriction of _SymmetricMemory.signal_pad_size={_SymmetricMemory.signal_pad_size}. Increase the signal pad size to allow autotuner to choose among all possible values in the range.",
stacklevel=1,
)
self.config_spec.max_num_sm_multiplier = newmax

self.has_barrier: bool = False

def specialize_expr(self, expr: sympy.Expr) -> sympy.Expr:
Expand Down
15 changes: 15 additions & 0 deletions helion/_dist_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from __future__ import annotations
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious should this be merged into helion/runtime/dist_utils.py added in #1771 ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we'd better split compile time and runtime utils?

btw, I also plan to move the other compile-time dist utils from _utils to this file


import torch
from torch._C._distributed_c10d import _SymmetricMemory
import torch.distributed as dist


def max_num_blocks_for_symm_mem() -> int:
"""
Return the max number of blocks allowed due to the restriction of
signal pad size in symm memory.
"""
assert dist.is_initialized()
signal_pad_size = _SymmetricMemory.signal_pad_size
return signal_pad_size // torch.int32.itemsize // dist.get_world_size()
3 changes: 2 additions & 1 deletion helion/autotuner/config_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def __init__(
self.static_ranges: BlockIdSequence[StaticRangeSpec] = BlockIdSequence()

self.allowed_pid_types: tuple[PidTypeLiteral, ...] = tuple(VALID_PID_TYPES)
self.max_num_sm_multiplier: int = MAX_NUM_SM_MULTIPLIER
self.grid_block_ids: list[int] = []
self.load_eviction_policies = ListOf(
EnumFragment(choices=get_valid_eviction_policies(self.backend_name)),
Expand Down Expand Up @@ -629,7 +630,7 @@ def _flat_fields(
if self.supports_config_key("num_sm_multiplier"):
fields["num_sm_multiplier"] = PowerOfTwoFragment(
MIN_NUM_SM_MULTIPLIER,
MAX_NUM_SM_MULTIPLIER,
self.max_num_sm_multiplier,
DEFAULT_NUM_SM_MULTIPLIER,
)
if self.supports_config_key("load_eviction_policies"):
Expand Down
23 changes: 23 additions & 0 deletions test/test_config_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,29 @@ def test_autotune_force_persistent_limits_config_spec(self) -> None:
("persistent_blocked", "persistent_interleaved"),
)

def test_distributed_limits_pid_types_to_persistent(self) -> None:
settings = helion.Settings()
with (
patch("torch.distributed.is_initialized", return_value=True),
patch("helion._dist_utils.max_num_blocks_for_symm_mem", return_value=10000),
):
env = CompileEnvironment(torch.device("cuda", 0), settings)
self.assertEqual(
env.config_spec.allowed_pid_types,
("persistent_blocked", "persistent_interleaved"),
)

def test_persistent_block_limit_caps_num_sm_multiplier(self) -> None:
# max_blocks=10000, 200 SMs -> 10000 // 200 = 50 -> floor pow2 = 32
settings = helion.Settings()
with (
patch("torch.distributed.is_initialized", return_value=True),
patch("helion._dist_utils.max_num_blocks_for_symm_mem", return_value=10000),
patch("helion.runtime.get_num_sm", return_value=200),
):
env = CompileEnvironment(torch.device("cuda", 0), settings)
self.assertEqual(env.config_spec.max_num_sm_multiplier, 32)

def test_backend_env_var_accepts_cute(self) -> None:
with patch.dict(
os.environ,
Expand Down
Loading