Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/nki/matmul.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# CODEX NOTE: IGNORE THIS FILE, IT IS THE DEPRECATED NKI INTERPRETER
from neuronxcc import nki
import neuronxcc.nki.language as nl

Expand Down
124 changes: 124 additions & 0 deletions examples/nki/matmul_beta2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import nki
import nki.isa as nisa
import nki.language as nl

import numpy as np
import triton_viz

TRITON_VIZ_ENABLED = True


def matmul_kernel(lhsT, rhs, result):
    """Compute tiled matrix multiplication ``result = lhsT.T @ rhs``.

    The stationary operand is supplied pre-transposed (``lhsT``) because the
    tensor engine consumes the contraction dimension ``K`` along partitions.
    Partial products are accumulated in PSUM over the ``K`` tiles, then copied
    out through SBUF to HBM.

    Args:
        lhsT: Input matrix with shape ``[K, M]``.
        rhs: Input matrix with shape ``[K, N]``.
        result: Output matrix with shape ``[M, N]`` written in place.
    """

    # Verify that the lhsT and rhs have the same contraction dimension.
    K, M = lhsT.shape
    K_, N = rhs.shape
    assert K == K_, "lhsT and rhs must have the same contraction dimension"

    # Lookup the device matrix multiply dimensions.
    TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
    TILE_K = nl.tile_size.pmax  # 128
    TILE_N = nl.tile_size.gemm_moving_fmax  # 512

    # Verify that the input matrices are a multiple of the tile dimensions.
    # This simple kernel does not handle ragged edge tiles.
    assert (
        M % TILE_M == 0
    ), f"Expected M, {M}, to be a multiple of stationary free-dimension max, {TILE_M}"
    assert (
        N % TILE_N == 0
    ), f"Expected N, {N}, to be a multiple of moving free-dimension max, {TILE_N}"
    assert (
        K % TILE_K == 0
    ), f"Expected K, {K}, to be a multiple of the partition dimension max, {TILE_K}"

    # Use affine_range to loop over tiles
    # (iterations are independent, which lets the compiler pipeline them).
    for m in nl.affine_range(M // TILE_M):
        for n in nl.affine_range(N // TILE_N):
            # Allocate a tensor in PSUM
            # (float32 accumulator shared by all K iterations of this (m, n) tile).
            res_psum = nl.ndarray((TILE_M, TILE_N), nl.float32, buffer=nl.psum)

            for k in nl.affine_range(K // TILE_K):
                # Declare the tiles on SBUF
                lhsT_tile = nl.ndarray(
                    (TILE_K, TILE_M), dtype=lhsT.dtype, buffer=nl.sbuf
                )
                rhs_tile = nl.ndarray((TILE_K, TILE_N), dtype=rhs.dtype, buffer=nl.sbuf)

                # Load tiles from lhsT and rhs
                nisa.dma_copy(
                    dst=lhsT_tile,
                    src=lhsT[
                        k * TILE_K : (k + 1) * TILE_K, m * TILE_M : (m + 1) * TILE_M
                    ],
                )
                nisa.dma_copy(
                    dst=rhs_tile,
                    src=rhs[
                        k * TILE_K : (k + 1) * TILE_K, n * TILE_N : (n + 1) * TILE_N
                    ],
                )

                # Accumulate partial-sums into PSUM
                # NOTE(review): nc_matmul is presumed to accumulate into dst
                # across the k loop — confirm against the nisa.nc_matmul docs.
                nisa.nc_matmul(dst=res_psum, stationary=lhsT_tile, moving=rhs_tile)

            # Copy the result from PSUM back to SBUF, and cast to expected output data-type
            res_sb = nl.ndarray(res_psum.shape, dtype=result.dtype, buffer=nl.sbuf)
            nisa.tensor_copy(dst=res_sb, src=res_psum)

            # Copy the result from SBUF to HBM.
            nisa.dma_copy(
                dst=result[
                    m * TILE_M : (m + 1) * TILE_M, n * TILE_N : (n + 1) * TILE_N
                ],
                src=res_sb,
            )


def _run_with_xla(kernel, kernel_grid, *arrays):
    """Execute one beta2 kernel launch on an XLA device.

    Every array is staged onto the device, the jitted kernel mutates the
    device tensors in place, and their post-launch contents come back as a
    list of numpy arrays in the same order as ``arrays``.
    """
    import torch
    from torch_xla.core import xla_model as xm

    xla_dev = xm.xla_device()
    device_tensors = []
    for host_array in arrays:
        device_tensors.append(torch.as_tensor(host_array, device=xla_dev))
    jitted = nki.jit(kernel, kernel_return=False)
    jitted[kernel_grid](*device_tensors)
    # Flush the pending XLA graph so the kernel actually executes.
    xm.mark_step()
    return [t.cpu().numpy() for t in device_tensors]


def _run_demo():
    """Run the matmul example with lhsT ``[128, 128]`` and rhs ``[128, 512]``."""
    grid = (1, 1, 1)
    m_dim, k_dim, n_dim = 128, 128, 512
    lhs_small = np.arange(k_dim * m_dim, dtype=np.float32).reshape(k_dim, m_dim)
    rhs_small = np.arange(k_dim * n_dim, dtype=np.float32).reshape(k_dim, n_dim)

    result = np.empty((m_dim, n_dim), dtype=lhs_small.dtype)
    args = (lhs_small, rhs_small, result)
    # Host-side reference value for the in-place kernel output.
    expected = lhs_small.T @ rhs_small

    if TRITON_VIZ_ENABLED:
        traced = triton_viz.trace("tracer", backend="nki")(matmul_kernel)
        traced[grid](*args)
        assert np.allclose(expected, result)
        print("actual equals expected")
        triton_viz.launch(share=False)
    else:
        _, _, result = _run_with_xla(matmul_kernel, grid, *args)
        assert np.allclose(expected, result)
        print("actual equals expected")


if __name__ == "__main__":
    _run_demo()
78 changes: 78 additions & 0 deletions examples/nki/nki2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import nki
import nki.language as nl
import nki.isa as nisa
import numpy as np
import torch
import triton_viz

TRITON_VIZ_ENABLED = True


def nki_tensor_add_kernel(a_input, b_input, result):
    """
    NKI kernel to compute element-wise addition of two input tensors.

    Args:
        a_input: First input tensor in HBM.
        b_input: Second input tensor in HBM, same shape as ``a_input``.
        result: Output tensor in HBM, written in place with the element-wise
            sum of the two inputs.
    """

    # Check both input tensor shapes are the same for element-wise operation.
    assert a_input.shape == b_input.shape

    # Check the first dimension's size to ensure it does not exceed on-chip
    # memory tile size, since this simple kernel does not tile inputs.

    assert a_input.shape[0] <= nl.tile_size.pmax

    # Allocate space for the input tensors in SBUF and copy the inputs from HBM
    # to SBUF with DMA copy. Note: 'sbuf' is a keyword in NKI — it is injected
    # into the kernel's scope by the tracer/compiler, hence the noqa markers.
    a_tile = sbuf.view(dtype=a_input.dtype, shape=a_input.shape)  # noqa: F821
    nisa.dma_copy(dst=a_tile, src=a_input)

    b_tile = sbuf.view(dtype=b_input.dtype, shape=b_input.shape)  # noqa: F821
    nisa.dma_copy(dst=b_tile, src=b_input)

    # Allocate space for the result and use tensor_tensor to perform
    # element-wise addition. Note: the first argument of 'tensor_tensor'
    # is the destination tensor.
    c_tile = sbuf.view(dtype=a_input.dtype, shape=a_input.shape)  # noqa: F821
    nisa.tensor_tensor(dst=c_tile, data1=a_tile, data2=b_tile, op=nl.add)

    # Copy the sum from SBUF directly into the caller-provided HBM tensor.
    # (An intermediate 'hbm.view' allocation is unnecessary here; like 'sbuf',
    # 'hbm' is a keyword in NKI.)
    nisa.dma_copy(dst=result, src=c_tile)


def _run_demo():
    """Run the element-wise tensor-add example on ``[4, 3]`` inputs.

    Fixes relative to the previous revision of the XLA branch:
      * all kernel operands (including ``result``) are staged on the XLA
        device — previously ``kernel_args`` was built before ``a``/``b`` were
        rebound with ``.to(device)``, so stale CPU tensors were passed;
      * the kernel is launched through ``[kernel_grid]`` like the other
        beta2 examples;
      * the kernel is jitted with ``kernel_return=False`` since it writes
        ``result`` in place and returns nothing (the old ``.cpu()`` on the
        call's return value could not work);
      * the print messages name the kernel actually being run.
    """
    kernel_grid = (1, 1, 1)
    a = torch.ones((4, 3))
    b = torch.ones((4, 3))
    result = torch.empty((4, 3))
    kernel_args = (a, b, result)
    expected = a + b  # host-side reference value

    if TRITON_VIZ_ENABLED:
        print("Executing nki_tensor_add_kernel with NKI interpreter...")
        # Pin pmax for the interpreter; without a Trainium device it may be unset.
        nl.tile_size.pmax = 128  # nl.tile_size.pmax = None if no Trn device found?
        traced_kernel = triton_viz.trace("tracer", backend="nki")(nki_tensor_add_kernel)
        kernel_instance = traced_kernel[kernel_grid]
        kernel_instance(*kernel_args)
        assert np.allclose(expected, result)
        print("actual equals expected")
        triton_viz.launch(share=False)
    else:  # NOTE: we must have trainium device to run this (no CPU interpreter for NKI Beta 2 yet)
        print("Executing NKI JIT-ed nki_tensor_add_kernel...")
        from torch_xla.core import xla_model as xm

        device = xm.xla_device()
        # Move every operand to the XLA device; rebinding only a/b would leave
        # the argument tuple pointing at the original CPU tensors.
        device_args = [tensor.to(device) for tensor in kernel_args]
        compiled_kernel = nki.jit(nki_tensor_add_kernel, kernel_return=False)
        compiled_kernel[kernel_grid](*device_args)
        xm.mark_step()
        c = device_args[2].cpu()
        print(np.max(np.abs((expected - c).numpy())))
        assert np.allclose(expected, c)
        print("actual equals expected")


if __name__ == "__main__":
    _run_demo()
1 change: 1 addition & 0 deletions examples/nki/rmsnorm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# CODEX NOTE: IGNORE THIS FILE, IT IS THE DEPRECATED NKI INTERPRETER
import neuronxcc.nki as nki
import neuronxcc.nki.language as nl
from triton_viz.clients import Tracer
Expand Down
105 changes: 105 additions & 0 deletions examples/nki/rmsnorm_beta2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import nki
import nki.isa as nisa
import nki.language as nl

import numpy as np
import triton_viz

TRITON_VIZ_ENABLED = True


def rmsnorm_kernel(x, gamma, out, eps=1e-6):
    """Compute RMSNorm on rows of ``x``.

    Rows are processed ``tile_p`` (the partition-dimension max) at a time;
    ``gamma`` is staged into SBUF once and reused by every row tile.

    Args:
        x: Input tensor with shape ``[batch, dim]``.
        gamma: Scale tensor with shape ``[dim]``.
        out: Output tensor with shape ``[batch, dim]`` written in place.
        eps: Stability epsilon added before rsqrt.
    """
    batch, dim = x.shape
    tile_p = nl.tile_size.pmax
    # Ragged row tiles are not handled by this simple kernel.
    assert batch % tile_p == 0, f"Expected batch ({batch}) to be a multiple of {tile_p}"
    assert gamma.shape == (dim,), f"Expected gamma shape ({dim},), got {gamma.shape}"

    # Stage gamma once in SBUF as a (1, dim) row for broadcasting below.
    gamma_tile = nl.ndarray((1, dim), dtype=gamma.dtype, buffer=nl.sbuf)
    nisa.dma_copy(dst=gamma_tile, src=gamma.reshape((1, dim)))

    for tile_idx in nl.affine_range(batch // tile_p):
        row_start = tile_idx * tile_p
        # Load tile_p rows of x from HBM into SBUF.
        x_tile = nl.ndarray((tile_p, dim), dtype=x.dtype, buffer=nl.sbuf)
        nisa.dma_copy(dst=x_tile, src=x[nl.ds(row_start, tile_p), :])

        # Square each element (in float32 for accumulation accuracy).
        sq_tile = nl.ndarray((tile_p, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.activation(dst=sq_tile, op=np.square, data=x_tile)

        # Row-wise sum of squares, then scale by 1/dim to get the mean square.
        sq_mean = nl.ndarray((tile_p, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_reduce(dst=sq_mean, op=nl.add, data=sq_tile, axis=1, keepdims=True)
        nisa.tensor_scalar(
            dst=sq_mean, data=sq_mean, op0=nl.multiply, operand0=1.0 / float(dim)
        )

        # inv_rms = rsqrt(mean_sq + eps).
        # NOTE(review): this presumes activation's bias is added to the data
        # before applying op — confirm against the nisa.activation docs.
        inv_rms = nl.ndarray((tile_p, 1), dtype=nl.float32, buffer=nl.sbuf)
        nisa.activation(dst=inv_rms, op=nl.rsqrt, data=sq_mean, bias=eps)

        # Normalize: x * inv_rms, broadcasting the (tile_p, 1) column.
        norm_tile = nl.ndarray((tile_p, dim), dtype=nl.float32, buffer=nl.sbuf)
        nisa.tensor_tensor(
            dst=norm_tile,
            data1=x_tile,
            data2=inv_rms.broadcast_to((tile_p, dim)),
            op=nl.multiply,
        )

        # Apply the gamma scale (broadcast across rows) and cast to out.dtype.
        out_tile = nl.ndarray((tile_p, dim), dtype=out.dtype, buffer=nl.sbuf)
        nisa.tensor_tensor(
            dst=out_tile,
            data1=norm_tile,
            data2=gamma_tile.broadcast_to((tile_p, dim)),
            op=nl.multiply,
        )
        # Store the finished rows back to HBM.
        nisa.dma_copy(dst=out[nl.ds(row_start, tile_p), :], src=out_tile)


def _numpy_rmsnorm(x, gamma, eps=1e-6):
mean_sq = np.mean(x * x, axis=1, keepdims=True)
return (x / np.sqrt(mean_sq + eps)) * gamma.reshape((1, -1))


def _run_with_xla(kernel, kernel_grid, *arrays):
    """Launch a single beta2 kernel on an XLA device and collect its outputs.

    Stages each of ``arrays`` on the device, runs the jitted kernel (which
    mutates the device tensors in place), then returns their contents as
    numpy arrays, preserving argument order.
    """
    import torch
    from torch_xla.core import xla_model as xm

    dev = xm.xla_device()
    staged = [torch.as_tensor(arr, device=dev) for arr in arrays]
    launcher = nki.jit(kernel, kernel_return=False)[kernel_grid]
    launcher(*staged)
    xm.mark_step()  # materialize the lazy XLA graph
    host_results = []
    for tensor in staged:
        host_results.append(tensor.cpu().numpy())
    return host_results


def _run_demo():
    """Run the RMSNorm example with x ``[256, 128]`` and gamma ``[128]``."""
    grid = (1, 1, 1)
    batch, dim = 256, 128
    x = np.linspace(-2.0, 2.0, batch * dim, dtype=np.float32).reshape(batch, dim)
    gamma = np.linspace(0.5, 1.5, dim, dtype=np.float32)
    out = np.empty_like(x)
    args = (x, gamma, out)
    # Pure-numpy reference to validate the in-place kernel output.
    expected = _numpy_rmsnorm(x, gamma)

    if TRITON_VIZ_ENABLED:
        traced = triton_viz.trace("tracer", backend="nki")(rmsnorm_kernel)
        traced[grid](*args)
        assert np.allclose(expected, out)
        print("actual equals expected")
        triton_viz.launch(share=False)
    else:
        _, _, out = _run_with_xla(rmsnorm_kernel, grid, *args)
        assert np.allclose(expected, out)
        print("actual equals expected")


if __name__ == "__main__":
    _run_demo()
1 change: 1 addition & 0 deletions examples/nki/rope.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# CODEX NOTE: IGNORE THIS FILE, IT IS THE DEPRECATED NKI INTERPRETER
import os
from typing import Tuple

Expand Down
Loading
Loading