We are currently onboarding the DCPP kernel to AMD_350x with Helion. However, while running tritonbench, we found that the Helion DCPP kernel has a numerical accuracy difference (1.0 > threshold 0.2) compared to Inductor's. We do not see this difference on NVIDIA H100 or B200.
One guess we had is that Triton on AMD behaves differently for the internal bf16 (the DCPP kernel's input dtype) -> fp32 type conversion.
We are looking for guidance or suggestions on how to close the accuracy difference.
Accuracy comparison
forward
Sometimes passes the accuracy test; sometimes fails with the same difference as the backward-mode test.
backward
Mismatched elements: 37 / 4194304 (0.0%)
Greatest absolute difference: 1.0 at index (3395900,) (up to 0.2 allowed)
Greatest relative difference: 21.875 at index (3497852,) (up to 0.01 allowed)
Versions
triton: ovr_config//triton:beta
ROCm: rocm_arch=mi350 -m rcclx_dev -m rocm70
Triton Code
Forward
Helion generated triton code
from __future__ import annotations
import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher
_BLOCK_SIZE_1 = tl.constexpr(32)  # M-tile width for the x^T @ y accumulation loop
_BLOCK_SIZE_4 = tl.constexpr(32)  # M-tile width for the x @ xty output loop


@triton.jit
def _helion_helion_dcpp_mm_fwd_impl(x, y, xty, out, _RDIM_SIZE_2: tl.constexpr, _RDIM_SIZE_3: tl.constexpr):
    """Fused DCPP forward: xty[b] = x[b]^T @ y[b], then out[b] = x[b] @ xty[b].

    One program instance per batch element b (grid = B = 1152). The kernel is
    specialized for x:(B, 3219, 256), y/out:(B, 3219, 64), xty:(B, 256, 64);
    all strides below are hard-coded for those shapes.
    _RDIM_SIZE_2 = D = 256, _RDIM_SIZE_3 = K = 64.
    """
    # src[dot_compress.py:189]: for tile_b in hl.tile(B, block_size=1):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0  # batch index b
    indices_9 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)   # D axis, 0..255
    indices_10 = tl.arange(0, _RDIM_SIZE_3).to(tl.int32)  # K axis, 0..63
    # src[dot_compress.py:190]: acc = hl.zeros([D, K], dtype=torch.float32)
    acc = tl.full([256, 64], 0.0, tl.float32)  # fp32 accumulator for x^T @ y
    # src[dot_compress.py:191]: for tile_m in hl.tile(M):
    # src[dot_compress.py:192]: x_tile = x[tile_b.begin, tile_m, :]
    # src[dot_compress.py:193]: y_tile = y[tile_b.begin, tile_m, :]
    # src[dot_compress.py:191-198]: ...
    for offset_7 in tl.range(0, 3219, _BLOCK_SIZE_1):
        indices_7 = offset_7 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
        mask_1 = indices_7 < 3219  # M = 3219 is not a multiple of the tile size
        acc_copy = acc
        acc_copy_0 = acc_copy
        # src[dot_compress.py:192]: x_tile = x[tile_b.begin, tile_m, :]
        x_tile = tl.load(x + (offset_0 * 824064 + indices_7[:, None] * 256 + indices_9[None, :] * 1), mask_1[:, None], other=0)
        # src[dot_compress.py:193]: y_tile = y[tile_b.begin, tile_m, :]
        y_tile = tl.load(y + (offset_0 * 206016 + indices_7[:, None] * 64 + indices_10[None, :] * 1), mask_1[:, None], other=0)
        # src[dot_compress.py:196]: x_tile.t(),
        permute = tl.permute(x_tile, [1, 0])
        # src[dot_compress.py:194]: acc = torch.addmm(
        # src[dot_compress.py:195]: acc,
        # src[dot_compress.py:196]: x_tile.t(),
        # src[dot_compress.py:194-198]: ...
        # bf16 x bf16 dot with fp32 accumulation; input_precision='ieee'
        # forbids further input down-conversion by the backend.
        acc = tl.dot(tl.cast(permute, tl.bfloat16), tl.cast(y_tile, tl.bfloat16), acc=acc_copy_0, input_precision='ieee', out_dtype=tl.float32)
    # src[dot_compress.py:199]: acc2 = acc.to(x.dtype)
    # NOTE(review): acc is rounded to bf16 here and the second matmul consumes
    # the bf16 copy. The Inductor path does the same (its intermediate buf0 is
    # bf16), so this shared rounding alone should not explain an AMD-only
    # mismatch, but it is a precision cliff worth keeping in mind.
    v_0 = tl.cast(acc, tl.bfloat16)
    # src[dot_compress.py:201]: xty[tile_b.begin, :, :] = acc2
    tl.store(xty + (offset_0 * 16384 + indices_9[:, None] * 64 + indices_10[None, :] * 1), v_0, None)
    # src[dot_compress.py:204]: for tile_m2 in hl.tile(M):
    # src[dot_compress.py:205]: out[tile_b.begin, tile_m2, :] = torch.matmul(
    # src[dot_compress.py:206]: x[tile_b.begin, tile_m2, :], xty[tile_b.begin, :, :]
    # src[dot_compress.py:204-207]: ...
    for offset_8 in tl.range(0, 3219, _BLOCK_SIZE_4):
        indices_8 = offset_8 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
        mask_4 = indices_8 < 3219
        # src[dot_compress.py:206]: x[tile_b.begin, tile_m2, :], xty[tile_b.begin, :, :]
        load = tl.load(x + (offset_0 * 824064 + indices_8[:, None] * 256 + indices_9[None, :] * 1), mask_4[:, None], other=0)
        # Reload of the full D x K xty tile this same program just stored (no mask needed).
        load_1 = tl.load(xty + (offset_0 * 16384 + indices_9[:, None] * 64 + indices_10[None, :] * 1), None)
        # src[dot_compress.py:205]: out[tile_b.begin, tile_m2, :] = torch.matmul(
        # src[dot_compress.py:206]: x[tile_b.begin, tile_m2, :], xty[tile_b.begin, :, :]
        # src[dot_compress.py:207]: ).to(out.dtype)
        mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='ieee', out_dtype=tl.float32), tl.bfloat16)
        tl.store(out + (offset_0 * 206016 + indices_8[:, None] * 64 + indices_10[None, :] * 1), mm, mask_4[:, None])
def helion_dcpp_mm_fwd_impl(x: torch.Tensor, y: torch.Tensor, z: Optional[torch.Tensor]=None, *, _launcher=_default_launcher) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Helion kernel for the dot compress forward pass.
    Computes the fused operation: out = x @ (x^T @ y) or out = x @ (x^T @ y + z) if z is not None
    This kernel fuses two batch matrix multiplications into a single optimized
    kernel, reducing memory bandwidth requirements and improving performance.
    The intermediate result X^T @ Y is computed first and accumulated in
    registers before being used for the final matrix multiplication.
    The kernel is auto-tuned with multiple configurations optimized for
    different input sizes on B200 hardware with bfloat16 precision.
    Args:
        x: Input tensor of shape (B, M, D) where:
            - B is the batch size
            - M is the number of embeddings
            - D is the embedding dimension
        y: Input tensor of shape (B, M, K) where:
            - K is the number of compressed embeddings
        z: Optional input tensor of shape (B, D, K) to be added to the intermediate
    Returns:
        A tuple of:
            - out: Output tensor of shape (B, M, K)
            - xty: Intermediate tensor of shape (B, D, K), the result of x^T @ y
              or x^T @ y + z if z is not None (saved for backward pass)
    Note:
        D and K are expected to be small relative to M for optimal performance.
        Reference: https://fburl.com/code/jccltp1y

    Review notes (on the generated code, not the Helion source):
        - `Optional` is used in the signature but never imported. Harmless at
          runtime because of `from __future__ import annotations` (annotations
          stay unevaluated), but `typing.get_type_hints` would raise NameError.
        - This specialization never forwards `z` to the kernel (the launcher
          call below passes only x, y, xty, out), so `z` is effectively
          ignored — presumably traced with z=None; confirm before passing a
          real `z`.
    """
    # src[dot_compress.py:180]: B, M, D = x.shape
    B, M, D = x.shape
    # src[dot_compress.py:181]: K = y.shape[2]
    K = y.shape[2]
    # src[dot_compress.py:182]: D = hl.specialize(D)
    D = 256  # hl.specialize froze D for this compiled variant
    # src[dot_compress.py:183]: K = hl.specialize(K)
    K = 64   # hl.specialize froze K for this compiled variant
    # src[dot_compress.py:184]: out = torch.empty((B, M, K), device=x.device, dtype=x.dtype)
    out = torch.empty((B, M, K), device=x.device, dtype=x.dtype)
    # src[dot_compress.py:185]: xty = torch.empty((B, D, K), device=x.device, dtype=x.dtype)
    xty = torch.empty((B, D, K), device=x.device, dtype=x.dtype)
    # src[dot_compress.py:186]: assert D <= 1024, "DIM is required to be smaller than 1024"
    assert D <= 1024, 'DIM is required to be smaller than 1024'
    # src[dot_compress.py:187]: assert K <= 64, "NUM_COMPRESS_EMB is required to be smaller than 64"
    assert K <= 64, 'NUM_COMPRESS_EMB is required to be smaller than 64'
    # src[dot_compress.py:189]: for tile_b in hl.tile(B, block_size=1):
    _RDIM_SIZE_2 = 256  # = D, full reduction dim per program
    _RDIM_SIZE_3 = 64   # = K
    # src[dot_compress.py:189]: for tile_b in hl.tile(B, block_size=1):
    # src[dot_compress.py:190]: acc = hl.zeros([D, K], dtype=torch.float32)
    # src[dot_compress.py:191]: for tile_m in hl.tile(M):
    # src[dot_compress.py:189-207]: ...
    # One program per batch element (grid = 1152 = B).
    _launcher(_helion_helion_dcpp_mm_fwd_impl, (1152,), x, y, xty, out, _RDIM_SIZE_2, _RDIM_SIZE_3, num_warps=4, num_stages=1, waves_per_eu=1, matrix_instr_nonkdim=0)
    # src[dot_compress.py:208]: return out, xty
    return (out, xty)
def call():
    """Standalone repro: build random bf16 inputs with the traced shapes and
    strides, then invoke the compiled forward once.

    Fix: the generator emitted the placeholder string
    'UNSUPPORTED TYPE - REPLACE' for the optional tensor `z`, which made the
    repro unrunnable as pasted. `z` is Optional[torch.Tensor] and this
    specialization ignores it, so pass None.
    """
    from torch._dynamo.testing import rand_strided
    # src[dot_compress.py:143]: def helion_dcpp_mm_fwd_impl(
    # src[dot_compress.py:144]: x: torch.Tensor,
    # src[dot_compress.py:145]: y: torch.Tensor,
    # src[dot_compress.py:143-208]: ...
    x = rand_strided(size=(1152, 3219, 256), stride=(824064, 256, 1), dtype=torch.bfloat16, device='cuda:0')
    y = rand_strided(size=(1152, 3219, 64), stride=(206016, 64, 1), dtype=torch.bfloat16, device='cuda:0')
    z = None  # was 'UNSUPPORTED TYPE - REPLACE'; z is optional and unused here
    helion_dcpp_mm_fwd_impl(x, y, z)


if __name__ == '__main__':
    call()
Inductor generated triton code
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# Standard Inductor wrapper prologue: aliases into private PyTorch internals
# used by the generated `call` below (guard helpers, strided allocators,
# zero-copy reinterpretation) plus the async Triton compiler instance.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# kernel path: /var/tmp/torchinductor_mengjiao/i7/ci7y45y2agln2svmop746feanzn46n3e56sa3bgntzvxegecwprg.py
# Topologically Sorted Source Nodes: [out], Original ATen: [aten.bmm]
# Source node to ATen node mapping:
# out => bmm_1
# Graph fragment:
# %primals_1 : Tensor "bf16[1152, 3219, 256][824064, 256, 1]cuda:0" = PlaceHolder[target=primals_1]
# %bmm : Tensor "bf16[1152, 256, 64][16384, 64, 1]cuda:0" = PlaceHolder[target=bmm]
# %bmm_1 : Tensor "bf16[1152, 3219, 64][206016, 64, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%primals_1, %bmm), kwargs = {})
# return %bmm_1
triton_tem_fused_bmm_0 = async_compile.triton('triton_tem_fused_bmm_0', '''
import triton
import triton.language as tl
import triton.language.extra.tlx as tlx # noqa: F401
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
@triton_heuristics.template(
num_stages=2,
num_warps=8,
triton_meta={'signature': {'arg_A': '*bf16', 'arg_B': '*bf16', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='hip', index=0, multi_processor_count=256, cc='gfx950', major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=64), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'matrix_instr_nonkdim': 16, 'kpack': 2},
inductor_meta={'kernel_name': 'triton_tem_fused_bmm_0', 'backend_hash': 'AF70BC814F3C641B0867B40FECBA1755585B56CBEE96FC8F140050135BD22880', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 32, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'is_hip': True, 'is_fbcode': True, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'EVEN_K': True, 'USE_FAST_ACCUM': False, 'ACC_TYPE': 'tl.float32', 'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'matrix_instr_nonkdim': 16, 'waves_per_eu': 0, 'kpack': 2, 'GROUP_M': 4, 'ALLOW_TF32': False}},
)
@triton.jit
def triton_tem_fused_bmm_0(arg_A, arg_B, out_ptr0):
EVEN_K : tl.constexpr = True
USE_FAST_ACCUM : tl.constexpr = False
ACC_TYPE : tl.constexpr = tl.float32
BLOCK_M : tl.constexpr = 256
BLOCK_N : tl.constexpr = 64
BLOCK_K : tl.constexpr = 64
matrix_instr_nonkdim : tl.constexpr = 16
waves_per_eu : tl.constexpr = 0
kpack : tl.constexpr = 2
GROUP_M : tl.constexpr = 4
ALLOW_TF32 : tl.constexpr = False
INDEX_DTYPE : tl.constexpr = tl.int32
A = arg_A
B = arg_B
M = 3219
N = 64
K = 256
stride_aq = 824064
stride_am = 256
stride_ak = 1
stride_bq = 16384
stride_bk = 64
stride_bn = 1
# based on triton.ops.matmul
pid = tl.program_id(0).to(INDEX_DTYPE)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
tl.assume(pid_m >= 0)
tl.assume(pid_n >= 0)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
else:
ram = rm % M
if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
else:
rbn = rn % N
rk = tl.arange(0, BLOCK_K)
idx_q = tl.program_id(1).to(INDEX_DTYPE) # batch dimension for BMM
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_q = tl.program_id(1).to(INDEX_DTYPE) # batch dimension for BMM
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
xindex = idx_n + 64*idx_m + 206016*idx_q
tl.store(out_ptr0 + (tl.broadcast_to(xindex, [BLOCK_M, BLOCK_N])), acc, mask)
''', device_str='cuda')
async_compile.wait(globals())
del async_compile
def call(args):
    """Compiled forward graph.

    buf0 = primals_1^T @ primals_2 via extern_kernels.bmm on a zero-copy
    transposed view; buf1 = primals_1 @ buf0 via the Triton template kernel.
    Returns buf1 plus transposed views of the inputs/intermediate saved for
    the backward graph.

    NOTE(review): the first bmm goes through extern_kernels.bmm (vendor BLAS),
    while the Helion version computes x^T @ y with an in-kernel Triton dot
    loop — a genuine implementation difference between the two accuracy
    baselines being compared.
    """
    primals_1, primals_2 = args
    args.clear()
    assert_size_stride(primals_1, (1152, 3219, 256), (824064, 256, 1))
    assert_size_stride(primals_2, (1152, 3219, 64), (206016, 64, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1152, 256, 64), (16384, 64, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [permute, xty], Original ATen: [aten.permute, aten.bmm]
        extern_kernels.bmm(reinterpret_tensor(primals_1, (1152, 256, 3219), (824064, 1, 256), 0), primals_2, out=buf0)
        buf1 = empty_strided_cuda((1152, 3219, 64), (206016, 64, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [out], Original ATen: [aten.bmm]
        stream0 = get_raw_stream(0)
        triton_tem_fused_bmm_0.run(primals_1, buf0, buf1, 13, 1152, 1, stream=stream0)
    return (buf1, reinterpret_tensor(primals_1, (1152, 256, 3219), (824064, 1, 256), 0), reinterpret_tensor(buf0, (1152, 64, 256), (16384, 1, 64), 0), reinterpret_tensor(primals_2, (1152, 64, 3219), (206016, 1, 64), 0), )
def get_args():
    """Build random bf16 inputs with the exact shapes/strides the graph was traced with."""
    from torch._dynamo.testing import rand_strided
    primals_1 = rand_strided((1152, 3219, 256), (824064, 256, 1), device='cuda:0', dtype=torch.bfloat16)
    primals_2 = rand_strided((1152, 3219, 64), (206016, 64, 1), device='cuda:0', dtype=torch.bfloat16)
    return [primals_1, primals_2]
def benchmark_compiled_module(args, times=10, repeat=10):
    """Time the compiled graph by invoking `call` through Inductor's
    print_performance harness; `args` is copied per run because `call`
    clears its argument list."""
    from torch._inductor.utils import print_performance

    def run_once():
        return call(list(args))

    return print_performance(run_once, times=times, repeat=repeat)
if __name__ == "__main__":
from torch._inductor.wrapper_benchmark import compiled_module_main
args = get_args()
compiled_module_main('None', lambda times, repeat: benchmark_compiled_module(args, times=times, repeat=repeat))
backward
Helion generated triton code
from __future__ import annotations
import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher
_BLOCK_SIZE_1 = tl.constexpr(16)  # M-tile width for the db accumulation loop
_BLOCK_SIZE_4 = tl.constexpr(16)  # M-tile width for the dy loop
_BLOCK_SIZE_5 = tl.constexpr(16)  # M-tile width for the dx loop


@triton.jit
def _helion_helion_dcpp_bwd_impl(x, d_out, dy, b, y, dx, _RDIM_SIZE_2: tl.constexpr, _RDIM_SIZE_3: tl.constexpr):
    """Fused DCPP backward for one batch element (grid = B = 1152):

        db    = x[b]^T @ d_out[b]   (fp32 accumulate, rounded once to bf16)
        dy[b] = x[b] @ db
        dx[b] = d_out[b] @ b[b]^T + (db @ y[b]^T)^T

    _RDIM_SIZE_2 = D = 256, _RDIM_SIZE_3 = K = 64; strides are hard-coded for
    x/dx:(B, 3219, 256), d_out/y/dy:(B, 3219, 64), b:(B, 256, 64).
    NOTE(review): db is never stored to global memory by this kernel — see the
    dz note on the host wrapper.
    """
    # src[dot_compress.py:335]: for tile_b in hl.tile(B, block_size=1):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0  # batch index
    indices_13 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)  # D axis, 0..255
    indices_14 = tl.arange(0, _RDIM_SIZE_3).to(tl.int32)  # K axis, 0..63
    # src[dot_compress.py:337]: db = hl.zeros([D, K], dtype=torch.float32)
    db = tl.full([256, 64], 0.0, tl.float32)  # fp32 accumulator for x^T @ d_out
    # src[dot_compress.py:338]: for tile_m in hl.tile(M):
    # src[dot_compress.py:339]: db = torch.addmm(
    # src[dot_compress.py:340]: db, x[tile_b.begin, tile_m, :].t(), d_out[tile_b.begin, tile_m, :]
    # src[dot_compress.py:338-341]: ...
    for offset_10 in tl.range(0, 3219, _BLOCK_SIZE_1):
        indices_10 = offset_10 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
        mask_1 = indices_10 < 3219  # M = 3219 is not tile-aligned
        db_copy = db
        db_copy_0 = db_copy
        # src[dot_compress.py:340]: db, x[tile_b.begin, tile_m, :].t(), d_out[tile_b.begin, tile_m, :]
        load = tl.load(x + (offset_0 * 824064 + indices_10[:, None] * 256 + indices_13[None, :] * 1), mask_1[:, None], other=0)
        permute = tl.permute(load, [1, 0])
        load_1 = tl.load(d_out + (offset_0 * 206016 + indices_10[:, None] * 64 + indices_14[None, :] * 1), mask_1[:, None], other=0)
        # src[dot_compress.py:339]: db = torch.addmm(
        # src[dot_compress.py:340]: db, x[tile_b.begin, tile_m, :].t(), d_out[tile_b.begin, tile_m, :]
        # src[dot_compress.py:341]: )
        db = tl.dot(tl.cast(permute, tl.bfloat16), tl.cast(load_1, tl.bfloat16), acc=db_copy_0, input_precision='ieee', out_dtype=tl.float32)
    # src[dot_compress.py:342]: db2 = db.to(x.dtype)
    v_0 = tl.cast(db, tl.bfloat16)  # db2: bf16 copy reused by both loops below
    # src[dot_compress.py:349]: for tile_m2 in hl.tile(M):
    # src[dot_compress.py:350]: dy[tile_b.begin, tile_m2, :] = torch.matmul(
    # src[dot_compress.py:351]: x[tile_b.begin, tile_m2, :], db2
    # src[dot_compress.py:349-352]: ...
    for offset_11 in tl.range(0, 3219, _BLOCK_SIZE_4):
        indices_11 = offset_11 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
        mask_4 = indices_11 < 3219
        v_0_copy = v_0
        v_0_copy_0 = v_0_copy
        # src[dot_compress.py:351]: x[tile_b.begin, tile_m2, :], db2
        load_2 = tl.load(x + (offset_0 * 824064 + indices_11[:, None] * 256 + indices_13[None, :] * 1), mask_4[:, None], other=0)
        # src[dot_compress.py:350]: dy[tile_b.begin, tile_m2, :] = torch.matmul(
        # src[dot_compress.py:351]: x[tile_b.begin, tile_m2, :], db2
        # src[dot_compress.py:352]: )
        mm = tl.cast(tl.dot(tl.cast(load_2, tl.bfloat16), tl.cast(v_0_copy_0, tl.bfloat16), input_precision='ieee', out_dtype=tl.float32), tl.bfloat16)
        tl.store(dy + (offset_0 * 206016 + indices_11[:, None] * 64 + indices_14[None, :] * 1), mm, mask_4[:, None])
    # src[dot_compress.py:355]: bt = b[tile_b.begin, :, :].t()
    load_3 = tl.load(b + (offset_0 * 16384 + indices_13[:, None] * 64 + indices_14[None, :] * 1), None)
    bt = tl.permute(load_3, [1, 0])
    # src[dot_compress.py:356]: for tile_m3 in hl.tile(M):
    # src[dot_compress.py:357]: dx1 = torch.matmul(d_out[tile_b.begin, tile_m3, :], bt)
    # src[dot_compress.py:358]: dxt = torch.matmul(db2, y[tile_b.begin, tile_m3, :].t())
    # src[dot_compress.py:356-359]: ...
    for offset_12 in tl.range(0, 3219, _BLOCK_SIZE_5):
        indices_12 = offset_12 + tl.arange(0, _BLOCK_SIZE_5).to(tl.int32)
        mask_5 = indices_12 < 3219
        bt_copy = bt
        v_0_copy_1 = v_0
        bt_copy_0 = bt_copy
        v_0_copy_1_0 = v_0_copy_1
        # src[dot_compress.py:357]: dx1 = torch.matmul(d_out[tile_b.begin, tile_m3, :], bt)
        load_4 = tl.load(d_out + (offset_0 * 206016 + indices_12[:, None] * 64 + indices_14[None, :] * 1), mask_5[:, None], other=0)
        dx1 = tl.cast(tl.dot(tl.cast(load_4, tl.bfloat16), tl.cast(bt_copy_0, tl.bfloat16), input_precision='ieee', out_dtype=tl.float32), tl.bfloat16)
        # src[dot_compress.py:358]: dxt = torch.matmul(db2, y[tile_b.begin, tile_m3, :].t())
        load_5 = tl.load(y + (offset_0 * 206016 + indices_12[:, None] * 64 + indices_14[None, :] * 1), mask_5[:, None], other=0)
        permute_1 = tl.permute(load_5, [1, 0])
        dxt = tl.cast(tl.dot(tl.cast(v_0_copy_1_0, tl.bfloat16), tl.cast(permute_1, tl.bfloat16), input_precision='ieee', out_dtype=tl.float32), tl.bfloat16)
        # src[dot_compress.py:359]: dx[tile_b.begin, tile_m3, :] = (dx1 + dxt.t()).to(dx.dtype)
        permute_2 = tl.permute(dxt, [1, 0])
        # NOTE(review): dx1 and dxt were each already rounded to bf16 above,
        # and this add runs at bf16 precision. The Inductor counterpart
        # (triton_poi_fused_add_permute_1) upcasts both bf16 terms to fp32,
        # adds in fp32, and rounds once on store. The extra rounding plus a
        # bf16-precision sum here is a plausible contributor to the reported
        # backward mismatch — consider keeping dx1/dxt in fp32 until after
        # the sum and casting once.
        v_1 = dx1 + permute_2
        tl.store(dx + (offset_0 * 824064 + indices_12[:, None] * 256 + indices_13[None, :] * 1), v_1, mask_5[:, None])
def helion_dcpp_bwd_impl(d_out: torch.Tensor, x: torch.Tensor, y: torch.Tensor, b: torch.Tensor, z: Optional[torch.Tensor]=None, *, _launcher=_default_launcher) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """
    Helion kernel for the dot compress backward pass.
    Computes gradients for the dot compress operation with respect to
    inputs x, y and z (optional).
    The backward pass computes:
        - db = x^T @ dout
        - dy = x @ db
        - dx = dout @ b^T + y @ db^T
        - dz = db (if z is not None)
    The kernel is auto-tuned with multiple configurations optimized for
    different input sizes on B200 hardware with bfloat16 precision.
    Args:
        d_out: Gradient of loss w.r.t. output, shape (B, M, K).
        x: Input tensor from forward pass, shape (B, M, D).
        y: Input tensor from forward pass, shape (B, M, K).
        b: Intermediate tensor (x^T @ y) or (x^T @ y + z) if z is not None
           from forward pass, shape (B, D, K).
        z: (Optional) Input tensor from forward pass, shape (B, D, K).
    Returns:
        A tuple of:
            - dx: Gradient w.r.t. x, shape (B, M, D)
            - dy: Gradient w.r.t. y, shape (B, M, K)
            - dz: Gradient w.r.t. z, shape (B, D, K), None if z was None

    Review notes (on the generated code, not the Helion source):
        - DEFECT: `dz` is allocated below but never passed to the kernel (the
          launcher call forwards only x, d_out, dy, b, y, dx), so when
          `z is not None` the returned `dz` is uninitialized memory. Per the
          docstring dz should equal db — confirm against the Helion source.
        - `Optional` is used in the signature but never imported; harmless at
          runtime only because of `from __future__ import annotations`.
    """
    # src[dot_compress.py:325]: B, M, D = x.shape
    B, M, D = x.shape
    # src[dot_compress.py:326]: K = y.shape[2]
    K = y.shape[2]
    # src[dot_compress.py:327]: D = hl.specialize(D)
    D = 256  # hl.specialize froze D for this compiled variant
    # src[dot_compress.py:328]: K = hl.specialize(K)
    K = 64   # hl.specialize froze K for this compiled variant
    # src[dot_compress.py:330]: dx = torch.empty((B, M, D), device=x.device, dtype=x.dtype)
    dx = torch.empty((B, M, D), device=x.device, dtype=x.dtype)
    # src[dot_compress.py:331]: dy = torch.empty((B, M, K), device=y.device, dtype=y.dtype)
    dy = torch.empty((B, M, K), device=y.device, dtype=y.dtype)
    # src[dot_compress.py:333]: dz = torch.empty((B, D, K), device=x.device, dtype=x.dtype)
    dz = torch.empty((B, D, K), device=x.device, dtype=x.dtype)  # never written — see DEFECT note above
    # src[dot_compress.py:335]: for tile_b in hl.tile(B, block_size=1):
    _RDIM_SIZE_2 = 256  # = D
    _RDIM_SIZE_3 = 64   # = K
    # src[dot_compress.py:335]: for tile_b in hl.tile(B, block_size=1):
    # src[dot_compress.py:336]: # compute db = x^T @ dout
    # src[dot_compress.py:337]: db = hl.zeros([D, K], dtype=torch.float32)
    # src[dot_compress.py:335-359]: ...
    # One program per batch element (grid = 1152 = B).
    _launcher(_helion_helion_dcpp_bwd_impl, (1152,), x, d_out, dy, b, y, dx, _RDIM_SIZE_2, _RDIM_SIZE_3, num_warps=4, num_stages=1, waves_per_eu=1, matrix_instr_nonkdim=0)
    # src[dot_compress.py:361]: if z is not None:
    # src[dot_compress.py:362]: return dx, dy, dz
    # src[dot_compress.py:363]: else:
    # src[dot_compress.py:361-364]: ...
    if z is not None:
        # src[dot_compress.py:362]: return dx, dy, dz
        return (dx, dy, dz)
    else:
        # src[dot_compress.py:364]: return dx, dy, None
        return (dx, dy, None)
def call():
    """Standalone repro for the backward kernel with random bf16 inputs.

    Fix: the generator emitted the placeholder string
    'UNSUPPORTED TYPE - REPLACE' for the optional tensor `z`, which made the
    repro unrunnable as pasted. `z` is Optional[torch.Tensor] and here only
    selects which tuple variant is returned, so pass None.
    """
    from torch._dynamo.testing import rand_strided
    # src[dot_compress.py:289]: def helion_dcpp_bwd_impl(
    # src[dot_compress.py:290]: d_out: torch.Tensor,
    # src[dot_compress.py:291]: x: torch.Tensor,
    # src[dot_compress.py:289-364]: ...
    d_out = rand_strided(size=(1152, 3219, 64), stride=(206016, 64, 1), dtype=torch.bfloat16, device='cuda:0')
    x = rand_strided(size=(1152, 3219, 256), stride=(824064, 256, 1), dtype=torch.bfloat16, device='cuda:0')
    y = rand_strided(size=(1152, 3219, 64), stride=(206016, 64, 1), dtype=torch.bfloat16, device='cuda:0')
    b = rand_strided(size=(1152, 256, 64), stride=(16384, 64, 1), dtype=torch.bfloat16, device='cuda:0')
    z = None  # was 'UNSUPPORTED TYPE - REPLACE'; z is optional
    helion_dcpp_bwd_impl(d_out, x, y, b, z)


if __name__ == '__main__':
    call()
Inductor generated triton code
from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
# Standard Inductor wrapper prologue (backward module): aliases into private
# PyTorch internals used by the generated `call` below, plus the async Triton
# compiler instance.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p
# kernel path: /var/tmp/torchinductor_mengjiao/u7/cu7bfpex2bd6ohilcqfmln4beie43j7wm5br6hx7436vegehl7vo.py
# Topologically Sorted Source Nodes: [permute_3, bmm_4], Original ATen: [aten.transpose, aten.bmm]
# Source node to ATen node mapping:
# bmm_4 => bmm_4
# permute_3 => permute_3
# Graph fragment:
# %permute : Tensor "bf16[1152, 256, 3219][824064, 1, 256]cuda:0" = PlaceHolder[target=permute]
# %bmm_2 : Tensor "bf16[1152, 256, 64][16384, 64, 1]cuda:0" = PlaceHolder[target=bmm_2]
# %permute_3 : Tensor "bf16[1152, 3219, 256][824064, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute, [0, 2, 1]), kwargs = {})
# %bmm_4 : Tensor "bf16[1152, 3219, 64][206016, 64, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%permute_3, %bmm_2), kwargs = {})
# return %bmm_4
triton_tem_fused_bmm_transpose_0 = async_compile.triton('triton_tem_fused_bmm_transpose_0', '''
import triton
import triton.language as tl
import triton.language.extra.tlx as tlx # noqa: F401
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
@triton_heuristics.template(
num_stages=2,
num_warps=8,
triton_meta={'signature': {'arg_A': '*bf16', 'arg_B': '*bf16', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='hip', index=0, multi_processor_count=256, cc='gfx950', major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=64), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'matrix_instr_nonkdim': 16, 'kpack': 2},
inductor_meta={'kernel_name': 'triton_tem_fused_bmm_transpose_0', 'backend_hash': 'AF70BC814F3C641B0867B40FECBA1755585B56CBEE96FC8F140050135BD22880', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 32, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'is_hip': True, 'is_fbcode': True, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'EVEN_K': True, 'USE_FAST_ACCUM': False, 'ACC_TYPE': 'tl.float32', 'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'matrix_instr_nonkdim': 16, 'waves_per_eu': 0, 'kpack': 2, 'GROUP_M': 4, 'ALLOW_TF32': False}},
)
@triton.jit
def triton_tem_fused_bmm_transpose_0(arg_A, arg_B, out_ptr0):
EVEN_K : tl.constexpr = True
USE_FAST_ACCUM : tl.constexpr = False
ACC_TYPE : tl.constexpr = tl.float32
BLOCK_M : tl.constexpr = 256
BLOCK_N : tl.constexpr = 64
BLOCK_K : tl.constexpr = 64
matrix_instr_nonkdim : tl.constexpr = 16
waves_per_eu : tl.constexpr = 0
kpack : tl.constexpr = 2
GROUP_M : tl.constexpr = 4
ALLOW_TF32 : tl.constexpr = False
INDEX_DTYPE : tl.constexpr = tl.int32
A = arg_A
B = arg_B
M = 3219
N = 64
K = 256
stride_aq = 824064
stride_am = 256
stride_ak = 1
stride_bq = 16384
stride_bk = 64
stride_bn = 1
# based on triton.ops.matmul
pid = tl.program_id(0).to(INDEX_DTYPE)
grid_m = (M + BLOCK_M - 1) // BLOCK_M
grid_n = (N + BLOCK_N - 1) // BLOCK_N
# re-order program ID for better L2 performance
width = GROUP_M * grid_n
group_id = pid // width
group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
pid_m = group_id * GROUP_M + (pid % group_size)
pid_n = (pid % width) // (group_size)
tl.assume(pid_m >= 0)
tl.assume(pid_n >= 0)
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
else:
ram = rm % M
if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
else:
rbn = rn % N
rk = tl.arange(0, BLOCK_K)
idx_q = tl.program_id(1).to(INDEX_DTYPE) # batch dimension for BMM
A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
for k in range(K, 0, -BLOCK_K):
if EVEN_K:
a = tl.load(A)
b = tl.load(B)
else:
a = tl.load(A, mask=rk[None, :] < k, other=0.)
b = tl.load(B, mask=rk[:, None] < k, other=0.)
acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
A += BLOCK_K * stride_ak
B += BLOCK_K * stride_bk
# rematerialize rm and rn to save registers
rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
idx_q = tl.program_id(1).to(INDEX_DTYPE) # batch dimension for BMM
idx_m = rm[:, None]
idx_n = rn[None, :]
mask = (idx_m < M) & (idx_n < N)
# inductor generates a suffix
xindex = idx_n + 64*idx_m + 206016*idx_q
tl.store(out_ptr0 + (tl.broadcast_to(xindex, [BLOCK_M, BLOCK_N])), acc, mask)
''', device_str='cuda')
# kernel path: /var/tmp/torchinductor_mengjiao/c2/cc2tzxsihja7up6uwtuezwism6hsficj26xqb6256iphybq23ko5.py
# Topologically Sorted Source Nodes: [permute_5, add], Original ATen: [aten.permute, aten.add]
# Source node to ATen node mapping:
# add => add
# permute_5 => permute_5
# Graph fragment:
# %bmm_3 : Tensor "bf16[1152, 3219, 256][824064, 256, 1]cuda:0" = PlaceHolder[target=bmm_3]
# %bmm_5 : Tensor "bf16[1152, 256, 3219][824064, 3219, 1]cuda:0" = PlaceHolder[target=bmm_5]
# %permute_5 : Tensor "bf16[1152, 3219, 256][824064, 1, 3219]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%bmm_5, [0, 2, 1]), kwargs = {})
# %add : Tensor "bf16[1152, 3219, 256][824064, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%bmm_3, %permute_5), kwargs = {})
# return %add
triton_poi_fused_add_permute_1 = async_compile.triton('triton_poi_fused_add_permute_1', '''
import triton
import triton.language as tl
import triton.language.extra.tlx as tlx # noqa: F401
from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()
@triton_heuristics.pointwise(
size_hints={'y': 4194304, 'x': 256}, tile_hint=TileHint.SQUARE,
filename=__file__,
triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='hip', index=0, multi_processor_count=256, cc='gfx950', major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=64), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
inductor_meta={'grid_type': 'Grid2DWithYZOverflow', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_permute_1', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'AF70BC814F3C641B0867B40FECBA1755585B56CBEE96FC8F140050135BD22880', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 32, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'is_hip': True, 'is_fbcode': True},
min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_permute_1(in_out_ptr0, in_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
ynumel = 3708288
xnumel = 256
yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
ymask = yindex < ynumel
xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
xmask = xindex < xnumel
x2 = xindex
y3 = yindex
y0 = (yindex % 3219)
y1 = yindex // 3219
tmp0 = tl.load(in_out_ptr0 + (x2 + 256*y3), xmask & ymask, eviction_policy='evict_last').to(tl.float32)
tmp1 = tl.load(in_ptr0 + (y0 + 3219*x2 + 824064*y1), xmask & ymask, eviction_policy='evict_last').to(tl.float32)
tmp2 = tmp0 + tmp1
tl.debug_barrier()
tl.store(in_out_ptr0 + (x2 + 256*y3), tmp2, xmask & ymask)
''', device_str='cuda')
# Block until every async Triton compilation launched above has finished,
# then drop the compiler handle so it is not kept alive at module scope.
async_compile.wait(globals())
del async_compile
def call(args):
    """Execute the compiled backward graph for the DCPP kernel.

    args: [permute, permute_2, permute_4, tangents_1] — bf16 CUDA tensors;
    presumably the saved forward activations plus the incoming gradient
    (TODO confirm against the autograd graph). The list is cleared so the
    caller's references are released as early as possible.

    Returns (buf4, buf2): the two gradient outputs of this backward graph.
    NOTE: statement order is load-bearing — `del` sites and the buf1->buf4
    in-place reuse assume buffers are dead exactly where they are freed.
    """
    permute, permute_2, permute_4, tangents_1 = args
    args.clear()
    # Guard: the generated kernels below bake in these exact shapes/strides.
    assert_size_stride(permute, (1152, 256, 3219), (824064, 1, 256))
    assert_size_stride(permute_2, (1152, 64, 256), (16384, 1, 64))
    assert_size_stride(permute_4, (1152, 64, 3219), (206016, 1, 64))
    assert_size_stride(tangents_1, (1152, 3219, 64), (206016, 64, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1152, 256, 64), (16384, 64, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [bmm_2], Original ATen: [aten.bmm]
        extern_kernels.bmm(permute, tangents_1, out=buf0)
        buf1 = empty_strided_cuda((1152, 3219, 256), (824064, 256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [bmm_3], Original ATen: [aten.bmm]
        extern_kernels.bmm(tangents_1, permute_2, out=buf1)
        del permute_2
        del tangents_1
        buf2 = empty_strided_cuda((1152, 3219, 64), (206016, 64, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [permute_3, bmm_4], Original ATen: [aten.transpose, aten.bmm]
        stream0 = get_raw_stream(0)
        # Templated (autotuned) bmm-with-transpose kernel defined earlier in this file.
        triton_tem_fused_bmm_transpose_0.run(permute, buf0, buf2, 13, 1152, 1, stream=stream0)
        del permute
        buf3 = empty_strided_cuda((1152, 256, 3219), (824064, 3219, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [bmm_5], Original ATen: [aten.bmm]
        extern_kernels.bmm(buf0, permute_4, out=buf3)
        del buf0
        del permute_4
        # buf1 is dead after the fused add consumes it in place — reuse its storage.
        buf4 = buf1; del buf1  # reuse
        # Topologically Sorted Source Nodes: [permute_5, add], Original ATen: [aten.permute, aten.add]
        stream0 = get_raw_stream(0)
        # Fused epilogue: buf4 += permute(buf3) over 3708288 x 256 elements.
        triton_poi_fused_add_permute_1.run(buf4, buf3, 3708288, 256, stream=stream0)
        del buf3
    return (buf4, buf2, )
def get_args():
    """Build random bf16 CUDA inputs matching the traced graph's exact
    shapes and strides, in the order `call` expects:
    [permute, permute_2, permute_4, tangents_1].
    """
    from torch._dynamo.testing import rand_strided
    # (shape, stride) pairs for each graph input, in positional order.
    specs = [
        ((1152, 256, 3219), (824064, 1, 256)),   # permute
        ((1152, 64, 256), (16384, 1, 64)),       # permute_2
        ((1152, 64, 3219), (206016, 1, 64)),     # permute_4
        ((1152, 3219, 64), (206016, 64, 1)),     # tangents_1
    ]
    return [
        rand_strided(shape, stride, device='cuda:0', dtype=torch.bfloat16)
        for shape, stride in specs
    ]
def benchmark_compiled_module(args, times=10, repeat=10):
    """Time the compiled module and report its performance.

    Each invocation passes a fresh shallow copy of `args`, since `call`
    clears the list it receives.
    """
    from torch._inductor.utils import print_performance

    def run_once():
        return call(list(args))

    return print_performance(run_once, times=times, repeat=repeat)
# Script entry point: materialize random inputs once and hand the benchmark
# closure to Inductor's standard benchmark harness.
if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    args = get_args()
    compiled_module_main('None', lambda times, repeat: benchmark_compiled_module(args, times=times, repeat=repeat))
We are currently onboarding the DCPP kernel to AMD_350x with Helion. However, while running tritonbench, we found that the Helion DCPP kernel has a numerical accuracy difference (1.0 > threshold 0.2) compared to Inductor's. We do not see this difference on NVIDIA H100 or B200.
One guess we had is that Triton on AMD performs the internal bf16 -> fp32 type conversion differently (bf16 is the DCPP kernel's input dtype).
We are looking for guidance or suggestions on how to close the accuracy difference.
Accuracy comparison
forward
Sometimes passes the accuracy test; other times it fails with the same magnitude of difference as the backward-mode test.
backward
Mismatched elements: 37 / 4194304 (0.0%)
Greatest absolute difference: 1.0 at index (3395900,) (up to 0.2 allowed)
Greatest relative difference: 21.875 at index (3497852,) (up to 0.01 allowed)
Versions
triton: ovr_config//triton:beta
ROCm: rocm_arch=mi350 -m rcclx_dev -m rocm70
Triton Code
Forward
Helion generated triton code
Inductor generated triton code
backward
Helion generated triton code
Inductor generated triton code