Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions docs/api/language.md
Original file line number Diff line number Diff line change
Expand Up @@ -315,21 +315,6 @@ See {func}`~helion.language.cumprod` for details.
.. autofunction:: tile_id
```

## Synchronization


### signal()

```{eval-rst}
.. autofunction:: signal
```

### wait()

```{eval-rst}
.. autofunction:: wait
```

## Utilities

### device_print()
Expand Down
19 changes: 14 additions & 5 deletions examples/distributed/all_gather_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,19 @@
from torch._C._distributed_c10d import _SymmetricMemory
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
import triton
import triton.language as tl

import helion
from helion._testing import DEVICE
from helion._testing import run_example
import helion.language as hl
from helion.runtime.triton_helpers import triton_wait_signal


@triton.jit
def _wait_progress_at_idx(progress: tl.tensor, idx: int) -> None:
    """Spin-wait until the progress flag at ``progress[idx]`` is signalled.

    Thin ``@triton.jit`` wrapper so the wait can be invoked from a Helion
    kernel via ``hl.triton_kernel``. Delegates entirely to the project helper
    ``triton_wait_signal``, polling the single flag at ``progress + idx``.
    """
    # NOTE(review): positional-argument meanings are inferred, not visible
    # here — the ``1`` presumably is the expected flag value (the hl.wait
    # call this replaced passed ``signal=1``), and "acquire"/"gpu"/"ld"/False
    # look like memory-order, scope, polling op, and a no-reset flag.
    # Confirm against triton_wait_signal's signature before relying on this.
    triton_wait_signal(progress + idx, 1, 0, "acquire", "gpu", "ld", False)


# %%
Expand Down Expand Up @@ -119,12 +127,13 @@ def helion_matmul_w_progress(
M_per_rank = a_shared.size(0)
for tile_m, tile_n in hl.tile([M, N]):
acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
hl.wait(
progress,
[
hl.triton_kernel(
_wait_progress_at_idx,
args=(
progress,
tile_m.begin // (M_per_rank // SPLITS_PER_RANK),
],
signal=1,
),
output_like=None,
)
for tile_k in hl.tile(K):
acc = torch.addmm(acc, a[tile_m, tile_k], b[tile_k, tile_n])
Expand Down
4 changes: 0 additions & 4 deletions helion/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@
from .scan_ops import associative_scan as associative_scan
from .scan_ops import cumprod as cumprod
from .scan_ops import cumsum as cumsum
from .signal_wait import signal as signal
from .signal_wait import wait as wait
from .stack_tensor import StackTensor as StackTensor
from .stack_tensor import stacktensor_like as stacktensor_like
from .tile_ops import tile_begin as tile_begin
Expand All @@ -62,6 +60,4 @@
atomic_or,
atomic_xchg,
atomic_xor,
wait,
signal,
)
Loading
Loading