Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
b00d65f
initial port
latentCall145 Mar 2, 2026
c486d4c
trim down interpreter to just what nki examples needs
latentCall145 Mar 2, 2026
e26809a
get rid of non-essential tests for now
latentCall145 Mar 2, 2026
20ea47b
push NKIBeta2Trace into NKITrace
latentCall145 Mar 2, 2026
f0bfb58
verify kernels against trainium
latentCall145 Mar 6, 2026
2e16f00
Merge branch 'main' into nki-beta-2-pt2
mark14wu Mar 6, 2026
11843c7
add new tests
latentCall145 Mar 8, 2026
875dc48
make examples work
latentCall145 Mar 8, 2026
5e02059
debloat
latentCall145 Mar 8, 2026
c65b6a8
Merge branch 'nki-beta-2-pt2' of github.com:Deep-Learning-Profiling-T…
latentCall145 Mar 8, 2026
df6dbc8
recapitalize
latentCall145 Mar 8, 2026
f5d36c1
more tests; debloat
latentCall145 Mar 8, 2026
a5c096b
more tests
latentCall145 Mar 8, 2026
e12f667
type hints
latentCall145 Mar 8, 2026
a452a39
pin nki version
latentCall145 Mar 8, 2026
b5195a2
conditional import
latentCall145 Mar 8, 2026
1d5919a
comments
latentCall145 Mar 8, 2026
e3fb949
nki.trace before running kernel to catch some frontend errors
latentCall145 Mar 10, 2026
0db44a0
fix beta 2 examples
latentCall145 Mar 10, 2026
93d86ea
stuff
latentCall145 Mar 10, 2026
8fe631d
better match NKI compiler
latentCall145 Mar 10, 2026
5021a33
fixes to NKI compiler edge cases
latentCall145 Mar 10, 2026
fb8fc79
lint
latentCall145 Mar 10, 2026
badaf3b
tighten tolerance
latentCall145 Mar 10, 2026
9e7048c
edge case testing
latentCall145 Mar 10, 2026
99ae69a
debloat
latentCall145 Mar 10, 2026
aaa1caa
pre-trace arg
latentCall145 Mar 10, 2026
79e8b55
BHND tiled attn
latentCall145 Mar 10, 2026
0357158
import triton_viz only when needed
latentCall145 Mar 10, 2026
490f04d
move nki beta 2 examples to its own folder
latentCall145 Mar 10, 2026
c155694
BHLD rope
latentCall145 Mar 10, 2026
735d267
Merge branch 'main' into nki-beta-2-pt2
latentCall145 Mar 10, 2026
fa31bed
don't change legacy nki examples
latentCall145 Mar 10, 2026
288530e
Merge branch 'nki-beta-2-pt2' of github.com:Deep-Learning-Profiling-T…
latentCall145 Mar 10, 2026
23f7e21
more descriptive err msgs
latentCall145 Mar 10, 2026
7a4d6f7
vibe-coded mlp kernel (works on trainium) first try
latentCall145 Mar 10, 2026
86541b9
appease the llm
latentCall145 Mar 10, 2026
00b492c
Merge branch 'main' into nki-beta-2-pt2
latentCall145 Mar 12, 2026
62a8b44
Merge branch 'main' into nki-beta-2-pt2
latentCall145 Mar 15, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 81 additions & 0 deletions examples/nki_beta2/add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import nki
import nki.isa as nisa
import nki.language as nl
import numpy as np
import torch

TRITON_VIZ_ENABLED = True  # if True, run via the triton_viz NKI Beta 2 interpreter; otherwise run on XLA (Trainium) hardware
PRE_TRACE = True  # if True, run the NKI Beta 2 tracer before the interpreter to catch frontend errors. May be set to False, though that gives fewer guarantees of matching NKI compiler behavior.


def nki_tensor_add_kernel(a_input, b_input, result):
    """NKI kernel computing the element-wise sum of two input tensors.

    Args:
        a_input: First input tensor (in HBM).
        b_input: Second input tensor (in HBM); must match ``a_input.shape``.
        result: Output tensor (in HBM), overwritten with ``a_input + b_input``.

    Returns:
        The ``result`` tensor (also written in place).
    """

    # Check both input tensor shapes are the same for element-wise operation.
    assert a_input.shape == b_input.shape

    # Check the first dimension's size to ensure it does not exceed the
    # on-chip partition limit, since this simple kernel does not tile inputs.

    assert a_input.shape[0] <= nl.tile_size.pmax

    # Allocate space for the input tensors in SBUF and copy the inputs from HBM
    # to SBUF with DMA copy. Note: 'sbuf' is a keyword in NKI, injected by the
    # tracer rather than defined in this module (hence the noqa markers).
    a_tile = sbuf.view(dtype=a_input.dtype, shape=a_input.shape)  # noqa: F821
    nisa.dma_copy(dst=a_tile, src=a_input)

    b_tile = sbuf.view(dtype=b_input.dtype, shape=b_input.shape)  # noqa: F821
    nisa.dma_copy(dst=b_tile, src=b_input)

    # Allocate space for the result and use tensor_tensor to perform
    # element-wise addition. Note: the first argument of 'tensor_tensor'
    # is the destination tensor.
    c_tile = sbuf.view(dtype=a_input.dtype, shape=a_input.shape)  # noqa: F821
    nisa.tensor_tensor(dst=c_tile, data1=a_tile, data2=b_tile, op=nl.add)

    # DMA the summed tile from SBUF back into the caller-provided HBM tensor.
    nisa.dma_copy(dst=result, src=c_tile)
    return result


def _run_with_xla(kernel, kernel_grid, *arrays):
    """Execute a single beta2 kernel launch on an XLA (Trainium) device.

    Args:
        kernel: The NKI kernel function to compile.
        kernel_grid: Launch-grid tuple for the kernel.
        *arrays: Positional kernel arguments; converted to device tensors.

    Returns:
        The kernel's result copied back to the host as a numpy array.
    """
    import torch
    import torch_xla

    xla_device = torch_xla.device()
    device_args = [torch.as_tensor(arr, device=xla_device) for arr in arrays]
    jitted = nki.jit(kernel, platform_target="trn1")
    out = jitted[kernel_grid](*device_args)
    # Flush the lazy XLA graph so the result is materialized before reading.
    torch_xla.sync()
    return out.cpu().numpy()


def _run_demo():
    """Demo: element-wise add of two 4x3 tensors, checked against torch."""
    grid = (1,)
    lhs = torch.ones((4, 3), dtype=torch.float32)
    rhs = torch.ones((4, 3), dtype=torch.float32)
    out = torch.empty((4, 3), dtype=torch.float32)
    args = (lhs, rhs, out)
    expected = lhs + rhs  # host-side reference result

    if TRITON_VIZ_ENABLED:
        import triton_viz

        traced = triton_viz.trace("tracer", backend="nki_beta2")(
            nki_tensor_add_kernel
        )
        traced[grid](*args, pre_trace=PRE_TRACE)
        assert np.allclose(expected, out)
        print("☑️ Actual equals expected!")
        triton_viz.launch(share=False)
    else:  # Note: no official NKI Beta 2 CPU interpreter exists so run on XLA
        device_out = _run_with_xla(nki_tensor_add_kernel, grid, *args)
        assert np.allclose(expected, device_out)
        print("☑️ Actual equals expected!")


if __name__ == "__main__":
    # Allow running this example directly as a script.
    _run_demo()
127 changes: 127 additions & 0 deletions examples/nki_beta2/matmul.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import nki
import nki.isa as nisa
import nki.language as nl

import numpy as np

TRITON_VIZ_ENABLED = True  # if True, run via the triton_viz NKI Beta 2 interpreter; otherwise run on XLA (Trainium) hardware
PRE_TRACE = True  # if True, run the NKI Beta 2 tracer before the interpreter to catch frontend errors. May be set to False, though that gives fewer guarantees of matching NKI compiler behavior.


def matmul_kernel(lhsT, rhs, result):
    """Compute tiled matrix multiplication ``result = lhsT.T @ rhs``.

    Args:
        lhsT: Input matrix with shape ``[K, M]`` (left operand, pre-transposed
            so the contraction dimension K is the partition dimension).
        rhs: Input matrix with shape ``[K, N]``.
        result: Output matrix with shape ``[M, N]`` written in place.

    Returns:
        The ``result`` tensor (also written in place).
    """

    # Verify that the lhsT and rhs have the same contraction dimension.
    K, M = lhsT.shape
    K_, N = rhs.shape
    assert K == K_, "lhsT and rhs must have the same contraction dimension"

    # Lookup the device matrix multiply dimensions.
    TILE_M = nl.tile_size.gemm_stationary_fmax  # 128
    TILE_K = nl.tile_size.pmax  # 128
    TILE_N = nl.tile_size.gemm_moving_fmax  # 512

    # Verify that the input matrices are a multiple of the tile dimensions;
    # this kernel does not handle partial edge tiles.
    assert (
        M % TILE_M == 0
    ), f"Expected M, {M}, to be a multiple of stationary free-dimension max, {TILE_M}"
    assert (
        N % TILE_N == 0
    ), f"Expected N, {N}, to be a multiple of moving free-dimension max, {TILE_N}"
    assert (
        K % TILE_K == 0
    ), f"Expected K, {K}, to be a multiple of the partition dimension max, {TILE_K}"

    # Use affine_range to loop over tiles
    for m in nl.affine_range(M // TILE_M):
        for n in nl.affine_range(N // TILE_N):
            # Allocate a tensor in PSUM to accumulate the (m, n) output tile.
            res_psum = nl.ndarray((TILE_M, TILE_N), nl.float32, buffer=nl.psum)

            for k in nl.affine_range(K // TILE_K):
                # Declare the tiles on SBUF
                lhsT_tile = nl.ndarray(
                    (TILE_K, TILE_M), dtype=lhsT.dtype, buffer=nl.sbuf
                )
                rhs_tile = nl.ndarray((TILE_K, TILE_N), dtype=rhs.dtype, buffer=nl.sbuf)

                # Load tiles from lhsT and rhs
                nisa.dma_copy(
                    dst=lhsT_tile,
                    src=lhsT[
                        k * TILE_K : (k + 1) * TILE_K, m * TILE_M : (m + 1) * TILE_M
                    ],
                )
                nisa.dma_copy(
                    dst=rhs_tile,
                    src=rhs[
                        k * TILE_K : (k + 1) * TILE_K, n * TILE_N : (n + 1) * TILE_N
                    ],
                )

                # Accumulate partial-sums into PSUM over the K tiles.
                nisa.nc_matmul(dst=res_psum, stationary=lhsT_tile, moving=rhs_tile)

            # Copy the result from PSUM back to SBUF, and cast to expected output data-type
            res_sb = nl.ndarray(res_psum.shape, dtype=result.dtype, buffer=nl.sbuf)
            nisa.tensor_copy(dst=res_sb, src=res_psum)

            # Copy the result from SBUF to HBM.
            nisa.dma_copy(
                dst=result[
                    m * TILE_M : (m + 1) * TILE_M, n * TILE_N : (n + 1) * TILE_N
                ],
                src=res_sb,
            )
    return result


def _run_with_xla(kernel, kernel_grid, *arrays):
    """Execute a single beta2 kernel launch on an XLA device.

    Args:
        kernel: The NKI kernel function to compile.
        kernel_grid: Launch-grid tuple for the kernel.
        *arrays: Positional kernel arguments; converted to device tensors.

    Returns:
        The kernel's result copied back to the host as a numpy array.
    """
    import torch
    import torch_xla

    xla_device = torch_xla.device()
    device_args = [torch.as_tensor(arr, device=xla_device) for arr in arrays]
    jitted = nki.jit(kernel, platform_target="trn1")
    out = jitted[kernel_grid](*device_args)
    # Flush the lazy XLA graph so the result is materialized before reading.
    torch_xla.sync()
    return out.cpu().numpy()


def _run_demo():
    """Run the matmul example with lhsT ``[128, 128]`` and rhs ``[128, 512]``."""
    grid = (1,)
    m, k, n = 128, 128, 512
    lhsT = np.arange(k * m, dtype=np.float32).reshape(k, m)
    rhs = np.arange(k * n, dtype=np.float32).reshape(k, n)

    out = np.empty((m, n), dtype=lhsT.dtype)
    args = (lhsT, rhs, out)
    expected = lhsT.T @ rhs  # host-side reference result

    if TRITON_VIZ_ENABLED:
        import triton_viz

        traced = triton_viz.trace("tracer", backend="nki_beta2")(matmul_kernel)
        traced[grid](*args, pre_trace=PRE_TRACE)
        assert np.allclose(expected, out)
        print("☑️ Actual equals expected!")
        triton_viz.launch(share=False)
    else:
        device_out = _run_with_xla(matmul_kernel, grid, *args)
        assert np.allclose(expected, device_out)
        print("☑️ Actual equals expected!")


if __name__ == "__main__":
    # Allow running this example directly as a script.
    _run_demo()
188 changes: 188 additions & 0 deletions examples/nki_beta2/mlp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
import nki
import nki.isa as nisa
import nki.language as nl

import numpy as np

TRITON_VIZ_ENABLED = True  # if True, run via the triton_viz NKI Beta 2 interpreter; otherwise run on XLA (Trainium) hardware
PRE_TRACE = True  # if True, run the NKI Beta 2 tracer before the interpreter to catch frontend errors. May be set to False, though that gives fewer guarantees of matching NKI compiler behavior.
# Tile sizes used by mlp_kernel; every kernel dimension must be padded to a
# multiple of these (see _round_up/_pad_matrix in _run_demo). Presumably
# chosen to match the device tile-size limits in nl.tile_size — TODO confirm.
TILE_M = 128  # batch (output-row) tile
TILE_K = 128  # input/contraction-dimension tile
TILE_H = 128  # hidden-dimension tile
TILE_N = 512  # output-column tile


def mlp_kernel(x_t, w1, w2, out):
    """Compute the tiled two-layer MLP ``out = relu(x_t.T @ w1) @ w2``.

    (The kernel transposes the hidden activations internally only to feed
    them as the stationary matmul operand; mathematically the result is
    ``relu(x_t.T @ w1) @ w2``.)

    Args:
        x_t: Transposed input activations with shape ``[in_dim, batch]``.
        w1: First-layer weights with shape ``[in_dim, hidden]``.
        w2: Second-layer weights with shape ``[hidden, out_dim]``.
        out: Output matrix with shape ``[batch, out_dim]`` written in place.

    Returns:
        The ``out`` tensor (also written in place).

    All dimensions must be multiples of the corresponding TILE_* constants;
    callers are expected to zero-pad non-aligned inputs.
    """
    k_dim, batch = x_t.shape
    k_dim_w1, hidden = w1.shape
    hidden_w2, out_dim = w2.shape
    assert k_dim == k_dim_w1, "x_t and w1 must share the input dimension"
    assert hidden == hidden_w2, "w1 and w2 must share the hidden dimension"

    tile_m = TILE_M
    tile_k = TILE_K
    tile_h = TILE_H
    tile_n = TILE_N
    # This kernel has no partial-tile handling, so every dim must be aligned.
    assert batch % tile_m == 0, f"Expected batch ({batch}) to be a multiple of {tile_m}"
    assert (
        k_dim % tile_k == 0
    ), f"Expected input dim ({k_dim}) to be a multiple of {tile_k}"
    assert (
        hidden % tile_h == 0
    ), f"Expected hidden dim ({hidden}) to be a multiple of {tile_h}"
    assert (
        out_dim % tile_n == 0
    ), f"Expected output dim ({out_dim}) to be a multiple of {tile_n}"

    for batch_idx in nl.affine_range(batch // tile_m):
        batch_start = batch_idx * tile_m
        for out_idx in nl.affine_range(out_dim // tile_n):
            out_start = out_idx * tile_n
            # PSUM accumulator for one (batch, out) output tile, summed over
            # the hidden tiles below.
            out_psum = nl.ndarray((tile_m, tile_n), dtype=nl.float32, buffer=nl.psum)

            for hidden_idx in nl.affine_range(hidden // tile_h):
                hidden_start = hidden_idx * tile_h
                # PSUM accumulator for one (batch, hidden) first-layer tile,
                # summed over the contraction (k) tiles below.
                hidden_psum = nl.ndarray(
                    (tile_m, tile_h), dtype=nl.float32, buffer=nl.psum
                )

                for k_idx in nl.affine_range(k_dim // tile_k):
                    k_start = k_idx * tile_k
                    # SBUF tiles for the first-layer operands.
                    x_tile = nl.ndarray(
                        (tile_k, tile_m), dtype=x_t.dtype, buffer=nl.sbuf
                    )
                    w1_tile = nl.ndarray(
                        (tile_k, tile_h), dtype=w1.dtype, buffer=nl.sbuf
                    )
                    nisa.dma_copy(
                        dst=x_tile,
                        src=x_t[
                            nl.ds(k_start, tile_k),
                            nl.ds(batch_start, tile_m),
                        ],
                    )
                    nisa.dma_copy(
                        dst=w1_tile,
                        src=w1[
                            nl.ds(k_start, tile_k),
                            nl.ds(hidden_start, tile_h),
                        ],
                    )
                    # Accumulate the (x_t.T @ w1) partial sums into PSUM.
                    nisa.nc_matmul(dst=hidden_psum, stationary=x_tile, moving=w1_tile)

                hidden_tile = nl.ndarray(
                    (tile_m, tile_h), dtype=out.dtype, buffer=nl.sbuf
                )
                hidden_t_psum = nl.ndarray(
                    (tile_h, tile_m), dtype=out.dtype, buffer=nl.psum
                )
                hidden_t = nl.ndarray((tile_h, tile_m), dtype=out.dtype, buffer=nl.sbuf)
                w2_tile = nl.ndarray((tile_h, tile_n), dtype=w2.dtype, buffer=nl.sbuf)

                # Move the first-layer tile to SBUF and apply ReLU in place
                # (elementwise max with 0.0).
                nisa.tensor_copy(dst=hidden_tile, src=hidden_psum)
                nisa.tensor_scalar(
                    dst=hidden_tile,
                    data=hidden_tile,
                    op0=nl.maximum,
                    operand0=0.0,
                )
                # Transpose [tile_m, tile_h] -> [tile_h, tile_m] so the hidden
                # activations can serve as the stationary operand of the
                # second matmul (contraction dim on the partition axis).
                nisa.nc_transpose(dst=hidden_t_psum, data=hidden_tile)
                nisa.tensor_copy(dst=hidden_t, src=hidden_t_psum)
                nisa.dma_copy(
                    dst=w2_tile,
                    src=w2[
                        nl.ds(hidden_start, tile_h),
                        nl.ds(out_start, tile_n),
                    ],
                )
                # Accumulate the second-layer partial sums into out_psum.
                nisa.nc_matmul(dst=out_psum, stationary=hidden_t, moving=w2_tile)

            # Stage the finished output tile in SBUF and DMA it back to HBM.
            out_tile = nl.ndarray((tile_m, tile_n), dtype=out.dtype, buffer=nl.sbuf)
            nisa.tensor_copy(dst=out_tile, src=out_psum)
            nisa.dma_copy(
                dst=out[
                    nl.ds(batch_start, tile_m),
                    nl.ds(out_start, tile_n),
                ],
                src=out_tile,
            )
    return out


def _round_up(size, tile):
"""Round ``size`` up to the nearest ``tile`` multiple."""
return ((size + tile - 1) // tile) * tile


def _pad_matrix(x, shape):
"""Zero-pad a 2D matrix into ``shape``."""
padded = np.zeros(shape, dtype=x.dtype)
padded[: x.shape[0], : x.shape[1]] = x
return padded


def _run_with_xla(kernel, kernel_grid, *arrays):
    """Execute a single beta2 kernel launch on an XLA device.

    Args:
        kernel: The NKI kernel function to compile.
        kernel_grid: Launch-grid tuple for the kernel.
        *arrays: Positional kernel arguments; converted to device tensors.

    Returns:
        The kernel's result copied back to the host as a numpy array.
    """
    import torch
    import torch_xla

    xla_device = torch_xla.device()
    device_args = [torch.as_tensor(arr, device=xla_device) for arr in arrays]
    jitted = nki.jit(kernel, platform_target="trn1")
    out = jitted[kernel_grid](*device_args)
    # Flush the lazy XLA graph so the result is materialized before reading.
    torch_xla.sync()
    return out.cpu().numpy()


def _run_demo():
    """Run the tiled MLP example on non-tile-aligned 2D matrices."""
    grid = (1,)
    batch, in_dim, hidden, out_dim = 190, 170, 222, 333
    batch_pad = _round_up(batch, TILE_M)
    in_dim_pad = _round_up(in_dim, TILE_K)
    hidden_pad = _round_up(hidden, TILE_H)
    out_dim_pad = _round_up(out_dim, TILE_N)

    x = np.linspace(-1.0, 1.0, batch * in_dim, dtype=np.float32).reshape(batch, in_dim)
    w1 = np.linspace(-0.5, 0.5, in_dim * hidden, dtype=np.float32).reshape(
        in_dim, hidden
    )
    w2 = np.linspace(-0.25, 0.75, hidden * out_dim, dtype=np.float32).reshape(
        hidden, out_dim
    )
    # Spike one entry per matrix to skew the distributions and avoid
    # degenerate cancellation effects (e.g. dot(randn, randn) ~ 0).
    x[0, 0] = 100
    w1[0, 0] = 100
    w2[0, 0] = 100

    x_t = _pad_matrix(x.T, (in_dim_pad, batch_pad))
    w1_padded = _pad_matrix(w1, (in_dim_pad, hidden_pad))
    w2_padded = _pad_matrix(w2, (hidden_pad, out_dim_pad))
    out = np.empty((batch_pad, out_dim_pad), dtype=np.float32)
    args = (x_t, w1_padded, w2_padded, out)
    expected = np.maximum(x @ w1, 0.0) @ w2  # host-side reference MLP

    if TRITON_VIZ_ENABLED:
        import triton_viz

        traced = triton_viz.trace("tracer", backend="nki_beta2")(mlp_kernel)
        traced[grid](*args, pre_trace=PRE_TRACE)
        assert np.allclose(expected, out[:batch, :out_dim], atol=1e-4, rtol=1e-4)
        print("☑️ Actual equals expected!")
        triton_viz.launch(share=False)
    else:
        out = _run_with_xla(mlp_kernel, grid, *args)
        assert np.allclose(expected, out[:batch, :out_dim], atol=1e-4, rtol=1e-4)
        print("☑️ Actual equals expected!")


if __name__ == "__main__":
    # Allow running this example directly as a script.
    _run_demo()
Loading
Loading