
Commit 3cecf1e

properly codegen hl.triton_kernel

stack-info: PR: #1797, branch: shunting314/stack/18

Parent: 67e2803

5 files changed: 3 additions & 16 deletions


examples/distributed/allreduce_bias_rmsnorm.py

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@
     config=helion.Config(
         block_sizes=[8],
         num_warps=8,
+        reduction_loops=[1024],
     ),
     static_shapes=True,
 )
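
For reference, a minimal sketch of where the new entry sits in a pinned Helion config (only the Config arguments come from the diff; the kernel name, signature, and body below are hypothetical):

import helion

@helion.kernel(
    config=helion.Config(
        block_sizes=[8],
        num_warps=8,
        reduction_loops=[1024],  # entry added by this commit
    ),
    static_shapes=True,
)
def allreduce_bias_rmsnorm_kernel(x, bias, weight):  # hypothetical signature
    ...  # kernel body elided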

helion/_compiler/roll_reduction.py

Lines changed: 1 addition & 0 deletions
@@ -399,6 +399,7 @@ def process(self, graph: torch.fx.Graph) -> torch.fx.Graph:
         if (
             not all((n in self.available) for n in node.all_input_nodes)
             or node.op == "output"
+            or (node.is_impure() and self.inner_count > 0)
         ):
             self.start_new_graph()
         new_node = self.outer_graph.create_node(
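
The new third clause makes the reduction roller split the graph whenever a side-effecting node appears while a rolled reduction is already open (inner_count > 0), so impure ops are never pulled inside the rolled loop. A paraphrase of the updated guard as a standalone predicate (the method name is hypothetical; the condition itself is from the diff):

import torch

def must_start_new_graph(self, node: torch.fx.Node) -> bool:
    return (
        # an input of this node is not yet materialized in the outer graph
        not all(n in self.available for n in node.all_input_nodes)
        # the output node always terminates the current graph
        or node.op == "output"
        # new: impure nodes may not join an already-open rolled reduction
        or (node.is_impure() and self.inner_count > 0)
    )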

helion/_utils.py

Lines changed: 0 additions & 8 deletions
@@ -225,11 +225,3 @@ def all_gather_object(obj: T) -> list[T]:
     object_list = [None] * dist.get_world_size()
     dist.all_gather_object(object_list, obj)
     return object_list  # pyrefly: ignore
-
-
-def autotune_for_distributed_kernel() -> bool:
-    """
-    Remove this once these issues regarding distributed kernels are fixed:
-    - https://github.com/pytorch/helion/issues/1642
-    """
-    return os.getenv("HELION_AUTOTUNE_FOR_DISTRIBUTED_KERNEL") == "1"
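
The surviving all_gather_object helper is unchanged by this commit; a hedged usage sketch (assumes torch.distributed is already initialized in the calling process):

import torch.distributed as dist
from helion._utils import all_gather_object

# each rank contributes one object and gets back the list from all ranks,
# e.g. [{'rank': 0}, {'rank': 1}, ...] on every rank
gathered = all_gather_object({"rank": dist.get_rank()})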

helion/autotuner/config_spec.py

Lines changed: 0 additions & 5 deletions
@@ -31,7 +31,6 @@
 from .config_fragment import PowerOfTwoFragment
 from .config_fragment import assert_integer_power_of_two
 import helion
-from helion._utils import autotune_for_distributed_kernel

 if TYPE_CHECKING:
     from collections.abc import Callable
@@ -931,10 +930,6 @@ def _flat_config(
         default = min(default, base.max_reduction_threads)
         value = fn(BlockSizeFragment(low, high, default))
         assert isinstance(value, int)
-        if autotune_for_distributed_kernel():
-            # workaround https://github.com/pytorch/helion/issues/1642
-            return None
-
         if not (low <= value <= high):
             raise InvalidConfig(
                 f"Invalid value for reduction loop {low} <= {value} <= {high}"

test/test_distributed.py

Lines changed: 1 addition & 3 deletions
@@ -94,7 +94,6 @@ def setUpClass(cls) -> None:
                 "HELION_DIST_CHECK_CONFIG_CONSISTANCY": "1",
                 "HELION_CAP_AUTOTUNE_NUM_NEIGHBORS": "50",
                 "HELION_CAP_REBENCHMARK_REPEAT": "50",
-                "HELION_AUTOTUNE_FOR_DISTRIBUTED_KERNEL": "1",
             },
         )
     )
@@ -264,8 +263,7 @@ def test_allreduce_bias_rmsnorm(self, kernel_name, autotuner):
         kernel = getattr(mod, kernel_name).fn
         if autotuner == "fixed":
             fixed_config = helion.Config(
-                block_sizes=[8],
-                num_warps=8,
+                block_sizes=[8], num_warps=8, reduction_loops=[1024]
             )

             kernel = helion.kernel(
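
A sketch of the test's "fixed" path, assuming helion.kernel accepts the function positionally alongside a config keyword (as the truncated call above suggests; the kernel function below is a stand-in for getattr(mod, kernel_name).fn):

import helion

fixed_config = helion.Config(block_sizes=[8], num_warps=8, reduction_loops=[1024])

def kernel_fn(x):  # hypothetical placeholder kernel
    ...

kernel = helion.kernel(kernel_fn, config=fixed_config)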
