Skip to content

Commit 30e89b2

Browse files
committed
remove hl.signal/wait
stack-info: PR: #1791, branch: shunting314/stack/17
1 parent cdf4e12 commit 30e89b2

9 files changed

Lines changed: 17 additions & 943 deletions

File tree

docs/api/language.md

Lines changed: 0 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -315,21 +315,6 @@ See {func}`~helion.language.cumprod` for details.
315315
.. autofunction:: tile_id
316316
```
317317

318-
## Synchronization
319-
320-
321-
### signal()
322-
323-
```{eval-rst}
324-
.. autofunction:: signal
325-
```
326-
327-
### wait()
328-
329-
```{eval-rst}
330-
.. autofunction:: wait
331-
```
332-
333318
## Utilities
334319

335320
### device_print()

examples/distributed/all_gather_matmul.py

Lines changed: 14 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -19,11 +19,19 @@
1919
from torch._C._distributed_c10d import _SymmetricMemory
2020
import torch.distributed as dist
2121
import torch.distributed._symmetric_memory as symm_mem
22+
import triton
23+
import triton.language as tl
2224

2325
import helion
2426
from helion._testing import DEVICE
2527
from helion._testing import run_example
2628
import helion.language as hl
29+
from helion.runtime.triton_helpers import triton_wait_signal
30+
31+
32+
@triton.jit
33+
def _wait_progress_at_idx(progress: tl.tensor, idx: int) -> None:
34+
triton_wait_signal(progress + idx, 1, 0, "acquire", "gpu", "ld", False)
2735

2836

2937
# %%
@@ -119,12 +127,13 @@ def helion_matmul_w_progress(
119127
M_per_rank = a_shared.size(0)
120128
for tile_m, tile_n in hl.tile([M, N]):
121129
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
122-
hl.wait(
123-
progress,
124-
[
130+
hl.triton_kernel(
131+
_wait_progress_at_idx,
132+
args=(
133+
progress,
125134
tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
126-
],
127-
signal=1,
135+
),
136+
output_like=None,
128137
)
129138
for tile_k in hl.tile(K):
130139
acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])

helion/language/__init__.py

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -35,8 +35,6 @@
3535
from .scan_ops import associative_scan as associative_scan
3636
from .scan_ops import cumprod as cumprod
3737
from .scan_ops import cumsum as cumsum
38-
from .signal_wait import signal as signal
39-
from .signal_wait import wait as wait
4038
from .stack_tensor import StackTensor as StackTensor
4139
from .stack_tensor import stacktensor_like as stacktensor_like
4240
from .tile_ops import tile_begin as tile_begin
@@ -62,6 +60,4 @@
6260
atomic_or,
6361
atomic_xchg,
6462
atomic_xor,
65-
wait,
66-
signal,
6763
)

0 commit comments

Comments
 (0)