Skip to content

[AMD 350X] Helion DCPP kernel has numerical accuracy discrepancy compared to inductor's #1661

@MengjiaoZhou

Description

@MengjiaoZhou

We are currently onboarding the DCPP kernel to AMD 350X with Helion. However, while running tritonbench, we found that the Helion DCPP kernel has a numerical accuracy difference (greatest absolute difference 1.0 > threshold 0.2) compared to Inductor's. We do not see this difference on NVIDIA H100 and B200.
One guess is that Triton on AMD performs the internal bf16 (the DCPP kernel's input dtype) -> fp32 type conversion differently.
We are looking for guidance or suggestions on how to close this accuracy gap.

Accuracy comparison

forward
Sometimes passes the accuracy test; sometimes fails with the same magnitude of difference as the backward-mode test.

backward
Mismatched elements: 37 / 4194304 (0.0%)
Greatest absolute difference: 1.0 at index (3395900,) (up to 0.2 allowed)
Greatest relative difference: 21.875 at index (3497852,) (up to 0.01 allowed)

Versions

triton: ovr_config//triton:beta
ROCm: rocm_arch=mi350 -m rcclx_dev -m rocm70

Triton Code

Forward

Helion generated triton code

from __future__ import annotations
import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher
# Tile sizes along the M dimension for the two reduction loops below,
# chosen by Helion's autotuner as compile-time constants.
_BLOCK_SIZE_1 = tl.constexpr(32)
_BLOCK_SIZE_4 = tl.constexpr(32)
# Forward Triton kernel generated by Helion. One program instance handles one
# batch element b (grid = B = 1152):
#   pass 1: xty[b] = x[b]^T @ y[b], accumulated in fp32 over M in blocks of
#           _BLOCK_SIZE_1, then rounded to bf16 and stored to global memory;
#   pass 2: out[b] = x[b] @ xty[b], fp32 accumulation, stored as bf16.
# NOTE(review): the intermediate xty is rounded to bf16 (v_0) between the two
# matmuls. The Inductor reference also stores its intermediate as bf16, so the
# remaining numeric differences are in the matmul implementations themselves
# (Triton tl.dot with input_precision='ieee' here vs. extern rocBLAS bmm +
# Inductor's template kernel) — worth auditing on AMD. TODO confirm.
@triton.jit
def _helion_helion_dcpp_mm_fwd_impl(x, y, xty, out, _RDIM_SIZE_2: tl.constexpr, _RDIM_SIZE_3: tl.constexpr):
    # src[dot_compress.py:189]: for tile_b in hl.tile(B, block_size=1):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    # Full-range index vectors over D (=256) and K (=64); those dims are not tiled.
    indices_9 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
    indices_10 = tl.arange(0, _RDIM_SIZE_3).to(tl.int32)
    # src[dot_compress.py:190]: acc = hl.zeros([D, K], dtype=torch.float32)
    acc = tl.full([256, 64], 0.0, tl.float32)
    # src[dot_compress.py:191]: for tile_m in hl.tile(M):
    # src[dot_compress.py:192]:     x_tile = x[tile_b.begin, tile_m, :]
    # src[dot_compress.py:193]:     y_tile = y[tile_b.begin, tile_m, :]
    # src[dot_compress.py:191-198]: ...
    # Pass 1: reduce over M (=3219, not a multiple of the block size, hence the mask).
    for offset_7 in tl.range(0, 3219, _BLOCK_SIZE_1):
        indices_7 = offset_7 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
        mask_1 = indices_7 < 3219
        acc_copy = acc
        acc_copy_0 = acc_copy
        # src[dot_compress.py:192]: x_tile = x[tile_b.begin, tile_m, :]
        x_tile = tl.load(x + (offset_0 * 824064 + indices_7[:, None] * 256 + indices_9[None, :] * 1), mask_1[:, None], other=0)
        # src[dot_compress.py:193]: y_tile = y[tile_b.begin, tile_m, :]
        y_tile = tl.load(y + (offset_0 * 206016 + indices_7[:, None] * 64 + indices_10[None, :] * 1), mask_1[:, None], other=0)
        # src[dot_compress.py:196]: x_tile.t(),
        permute = tl.permute(x_tile, [1, 0])
        # src[dot_compress.py:194]: acc = torch.addmm(
        # src[dot_compress.py:195]:     acc,
        # src[dot_compress.py:196]:     x_tile.t(),
        # src[dot_compress.py:194-198]: ...
        # bf16 x bf16 -> fp32 dot with IEEE input precision (no TF32-style downconvert).
        acc = tl.dot(tl.cast(permute, tl.bfloat16), tl.cast(y_tile, tl.bfloat16), acc=acc_copy_0, input_precision='ieee', out_dtype=tl.float32)
    # src[dot_compress.py:199]: acc2 = acc.to(x.dtype)
    # fp32 accumulator rounded to bf16 before reuse in pass 2.
    v_0 = tl.cast(acc, tl.bfloat16)
    # src[dot_compress.py:201]: xty[tile_b.begin, :, :] = acc2
    tl.store(xty + (offset_0 * 16384 + indices_9[:, None] * 64 + indices_10[None, :] * 1), v_0, None)
    # src[dot_compress.py:204]: for tile_m2 in hl.tile(M):
    # src[dot_compress.py:205]:     out[tile_b.begin, tile_m2, :] = torch.matmul(
    # src[dot_compress.py:206]:         x[tile_b.begin, tile_m2, :], xty[tile_b.begin, :, :]
    # src[dot_compress.py:204-207]: ...
    # Pass 2: out = x @ xty, re-reading the just-stored bf16 intermediate.
    for offset_8 in tl.range(0, 3219, _BLOCK_SIZE_4):
        indices_8 = offset_8 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
        mask_4 = indices_8 < 3219
        # src[dot_compress.py:206]: x[tile_b.begin, tile_m2, :], xty[tile_b.begin, :, :]
        load = tl.load(x + (offset_0 * 824064 + indices_8[:, None] * 256 + indices_9[None, :] * 1), mask_4[:, None], other=0)
        load_1 = tl.load(xty + (offset_0 * 16384 + indices_9[:, None] * 64 + indices_10[None, :] * 1), None)
        # src[dot_compress.py:205]: out[tile_b.begin, tile_m2, :] = torch.matmul(
        # src[dot_compress.py:206]:     x[tile_b.begin, tile_m2, :], xty[tile_b.begin, :, :]
        # src[dot_compress.py:207]: ).to(out.dtype)
        mm = tl.cast(tl.dot(tl.cast(load, tl.bfloat16), tl.cast(load_1, tl.bfloat16), input_precision='ieee', out_dtype=tl.float32), tl.bfloat16)
        tl.store(out + (offset_0 * 206016 + indices_8[:, None] * 64 + indices_10[None, :] * 1), mm, mask_4[:, None])
def helion_dcpp_mm_fwd_impl(x: torch.Tensor, y: torch.Tensor, z: Optional[torch.Tensor]=None, *, _launcher=_default_launcher):
    """
    Helion kernel for the dot compress forward pass.
    Computes the fused operation: out = x @ (x^T @ y) or out = x @ (x^T @ y + z) if z is not None
    This kernel fuses two batch matrix multiplications into a single optimized
    kernel, reducing memory bandwidth requirements and improving performance.
    The intermediate result X^T @ Y is computed first and accumulated in
    registers before being used for the final matrix multiplication.
    The kernel is auto-tuned with multiple configurations optimized for
    different input sizes on B200 hardware with bfloat16 precision.
    Args:
        x: Input tensor of shape (B, M, D) where:
           - B is the batch size
           - M is the number of embeddings
           - D is the embedding dimension
        y: Input tensor of shape (B, M, K) where:
           - K is the number of compressed embeddings
        z: Optional input tensor of shape (B, D, K) to be added to the intermediate
    Returns:
        A tuple of:
            - out: Output tensor of shape (B, M, K)
            - xty: Intermediate tensor of shape (B, D, K), the result of x^T @ y
                   or x^T @ y + z if z is not None (saved for backward pass)
    Note:
        D and K are expected to be small relative to M for optimal performance.
        Reference: https://fburl.com/code/jccltp1y
    """
    # NOTE: `Optional` is not imported in this generated file; the annotation
    # only works because of `from __future__ import annotations` at the top.
    # NOTE(review): `z` is accepted but never forwarded to the launcher below
    # (args are x, y, xty, out), so the "+ z" path described in the docstring
    # is not exercised by this specialization — confirm against the Helion source.
    # src[dot_compress.py:180]: B, M, D = x.shape
    B, M, D = x.shape
    # src[dot_compress.py:181]: K = y.shape[2]
    K = y.shape[2]
    # src[dot_compress.py:182]: D = hl.specialize(D)
    # D and K are specialized (baked in) to constants for this compiled variant.
    D = 256
    # src[dot_compress.py:183]: K = hl.specialize(K)
    K = 64
    # src[dot_compress.py:184]: out = torch.empty((B, M, K), device=x.device, dtype=x.dtype)
    out = torch.empty((B, M, K), device=x.device, dtype=x.dtype)
    # src[dot_compress.py:185]: xty = torch.empty((B, D, K), device=x.device, dtype=x.dtype)
    xty = torch.empty((B, D, K), device=x.device, dtype=x.dtype)
    # src[dot_compress.py:186]: assert D <= 1024, "DIM is required to be smaller than 1024"
    assert D <= 1024, 'DIM is required to be smaller than 1024'
    # src[dot_compress.py:187]: assert K <= 64, "NUM_COMPRESS_EMB is required to be smaller than 64"
    assert K <= 64, 'NUM_COMPRESS_EMB is required to be smaller than 64'
    # src[dot_compress.py:189]: for tile_b in hl.tile(B, block_size=1):
    _RDIM_SIZE_2 = 256
    _RDIM_SIZE_3 = 64
    # src[dot_compress.py:189]: for tile_b in hl.tile(B, block_size=1):
    # src[dot_compress.py:190]:     acc = hl.zeros([D, K], dtype=torch.float32)
    # src[dot_compress.py:191]:     for tile_m in hl.tile(M):
    # src[dot_compress.py:189-207]: ...
    # Grid of 1152 programs: one per batch element (block_size=1 on B).
    _launcher(_helion_helion_dcpp_mm_fwd_impl, (1152,), x, y, xty, out, _RDIM_SIZE_2, _RDIM_SIZE_3, num_warps=4, num_stages=1, waves_per_eu=1, matrix_instr_nonkdim=0)
    # src[dot_compress.py:208]: return out, xty
    return (out, xty)
def call():
    """Repro entry point: build random bf16 CUDA inputs matching the
    benchmarked shapes/strides and invoke the Helion forward kernel."""
    from torch._dynamo.testing import rand_strided
    # src[dot_compress.py:143]: def helion_dcpp_mm_fwd_impl(
    # src[dot_compress.py:144]:     x: torch.Tensor,
    # src[dot_compress.py:145]:     y: torch.Tensor,
    # src[dot_compress.py:143-208]: ...
    x = rand_strided(size=(1152, 3219, 256), stride=(824064, 256, 1), dtype=torch.bfloat16, device='cuda:0')
    y = rand_strided(size=(1152, 3219, 64), stride=(206016, 64, 1), dtype=torch.bfloat16, device='cuda:0')
    # The code generator emitted the placeholder string 'UNSUPPORTED TYPE -
    # REPLACE' here, which is not a valid argument. `z` is Optional, so pass
    # None to make the repro script runnable as-is.
    z = None
    helion_dcpp_mm_fwd_impl(x, y, z)
if __name__ == '__main__':
    call()

Inductor generated triton code

from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
from torch._C import _cuda_getCurrentRawStream as get_raw_stream
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

# Shorthand aliases for torch internals used by the generated wrapper below.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Guard helpers that validate tensor metadata and reinterpret storage at call time.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compiler instance that asynchronously builds the Triton kernels defined below.
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /var/tmp/torchinductor_mengjiao/i7/ci7y45y2agln2svmop746feanzn46n3e56sa3bgntzvxegecwprg.py
# Topologically Sorted Source Nodes: [out], Original ATen: [aten.bmm]
# Source node to ATen node mapping:
#   out => bmm_1
# Graph fragment:
#   %primals_1 : Tensor "bf16[1152, 3219, 256][824064, 256, 1]cuda:0" = PlaceHolder[target=primals_1]
#   %bmm : Tensor "bf16[1152, 256, 64][16384, 64, 1]cuda:0" = PlaceHolder[target=bmm]
#   %bmm_1 : Tensor "bf16[1152, 3219, 64][206016, 64, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%primals_1, %bmm), kwargs = {})
#   return %bmm_1
# Inductor-autotuned bmm template kernel: out = primals_1 @ (x^T @ y result),
# i.e. the second matmul of the forward pass, with fp32 accumulation
# (ACC_TYPE=tl.float32) over bf16 operands and a bf16 store.
# NOTE: the kernel source below is a string handed to AsyncCompile; comments
# cannot be added inside it without changing the compiled kernel text, so it
# is reproduced byte-for-byte.
triton_tem_fused_bmm_0 = async_compile.triton('triton_tem_fused_bmm_0', '''
import triton
import triton.language as tl

import triton.language.extra.tlx as tlx  # noqa: F401

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(

num_stages=2,
num_warps=8,
triton_meta={'signature': {'arg_A': '*bf16', 'arg_B': '*bf16', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='hip', index=0, multi_processor_count=256, cc='gfx950', major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=64), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'matrix_instr_nonkdim': 16, 'kpack': 2},
inductor_meta={'kernel_name': 'triton_tem_fused_bmm_0', 'backend_hash': 'AF70BC814F3C641B0867B40FECBA1755585B56CBEE96FC8F140050135BD22880', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 32, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'is_hip': True, 'is_fbcode': True, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'EVEN_K': True, 'USE_FAST_ACCUM': False, 'ACC_TYPE': 'tl.float32', 'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'matrix_instr_nonkdim': 16, 'waves_per_eu': 0, 'kpack': 2, 'GROUP_M': 4, 'ALLOW_TF32': False}},

)
@triton.jit
def triton_tem_fused_bmm_0(arg_A, arg_B, out_ptr0):
    EVEN_K : tl.constexpr = True
    USE_FAST_ACCUM : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    BLOCK_M : tl.constexpr = 256
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 64
    matrix_instr_nonkdim : tl.constexpr = 16
    waves_per_eu : tl.constexpr = 0
    kpack : tl.constexpr = 2
    GROUP_M : tl.constexpr = 4
    ALLOW_TF32 : tl.constexpr = False
    INDEX_DTYPE : tl.constexpr = tl.int32
    A = arg_A
    B = arg_B

    M = 3219
    N = 64
    K = 256

    stride_aq = 824064
    stride_am = 256
    stride_ak = 1

    stride_bq = 16384
    stride_bk = 64
    stride_bn = 1

    # based on triton.ops.matmul
    pid = tl.program_id(0).to(INDEX_DTYPE)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N

    rk = tl.arange(0, BLOCK_K)

    idx_q = tl.program_id(1).to(INDEX_DTYPE)  # batch dimension for BMM
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_q = tl.program_id(1).to(INDEX_DTYPE)  # batch dimension for BMM
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    xindex = idx_n + 64*idx_m + 206016*idx_q
    tl.store(out_ptr0 + (tl.broadcast_to(xindex, [BLOCK_M, BLOCK_N])), acc, mask)
''', device_str='cuda')


# Block until all asynchronous Triton compilations finish, then drop the compiler.
async_compile.wait(globals())
del async_compile

def call(args):
    # Inductor forward wrapper. Computes:
    #   buf0 = x^T @ y   via extern rocBLAS/hipBLAS bmm (bf16 output),
    #   buf1 = x @ buf0  via the autotuned Triton template kernel above,
    # and returns buf1 plus zero-copy transposed views of the inputs and the
    # intermediate, saved for the backward pass.
    primals_1, primals_2 = args
    args.clear()
    assert_size_stride(primals_1, (1152, 3219, 256), (824064, 256, 1))
    assert_size_stride(primals_2, (1152, 3219, 64), (206016, 64, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        buf0 = empty_strided_cuda((1152, 256, 64), (16384, 64, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [permute, xty], Original ATen: [aten.permute, aten.bmm]
        # x is transposed by reinterpreting strides (no copy) before the bmm.
        extern_kernels.bmm(reinterpret_tensor(primals_1, (1152, 256, 3219), (824064, 1, 256), 0), primals_2, out=buf0)
        buf1 = empty_strided_cuda((1152, 3219, 64), (206016, 64, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [out], Original ATen: [aten.bmm]
        stream0 = get_raw_stream(0)
        # Grid: 13 M-tiles x 1152 batch elements x 1.
        triton_tem_fused_bmm_0.run(primals_1, buf0, buf1, 13, 1152, 1, stream=stream0)
    return (buf1, reinterpret_tensor(primals_1, (1152, 256, 3219), (824064, 1, 256), 0), reinterpret_tensor(buf0, (1152, 64, 256), (16384, 1, 64), 0), reinterpret_tensor(primals_2, (1152, 64, 3219), (206016, 1, 64), 0), )


def get_args():
    """Build the two random bf16 CUDA input tensors with the benchmarked
    shapes and strides (x: (B, M, D), d_out/y-shaped: (B, M, K))."""
    from torch._dynamo.testing import rand_strided
    specs = [
        ((1152, 3219, 256), (824064, 256, 1)),
        ((1152, 3219, 64), (206016, 64, 1)),
    ]
    return [
        rand_strided(shape, stride, device='cuda:0', dtype=torch.bfloat16)
        for shape, stride in specs
    ]


def benchmark_compiled_module(args, times=10, repeat=10):
    """Time the compiled module's `call` on copies of *args* and return the
    printed performance measurement."""
    from torch._inductor.utils import print_performance

    def run_once():
        # `call` clears its argument list, so hand it a fresh copy each time.
        return call(list(args))

    return print_performance(run_once, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    args = get_args()
    compiled_module_main('None', lambda times, repeat: benchmark_compiled_module(args, times=times, repeat=repeat))

backward
Helion generated triton code

from __future__ import annotations
import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher
# Tile sizes along M for the three reduction/streaming loops of the backward
# kernel, chosen by Helion's autotuner as compile-time constants.
_BLOCK_SIZE_1 = tl.constexpr(16)
_BLOCK_SIZE_4 = tl.constexpr(16)
_BLOCK_SIZE_5 = tl.constexpr(16)
# Backward Triton kernel generated by Helion. One program per batch element b:
#   pass 1: db = x[b]^T @ d_out[b], fp32 accumulation over M, then rounded to
#           bf16 (db2);
#   pass 2: dy[b] = x[b] @ db2, stored as bf16;
#   pass 3: dx[b] = d_out[b] @ b[b]^T + (db2 @ y[b]^T)^T, where each product is
#           rounded to bf16 before the final bf16 add.
# NOTE(review): the reported greatest-abs-diff of 1.0 occurs in backward mode.
# The bf16 rounding points here (v_0, dx1, dxt, and the bf16 add v_1) are the
# candidates to audit — e.g. keeping the two dx products in fp32 until after
# the add. Whether AMD's bf16->fp32 dot path differs from NVIDIA's cannot be
# determined from this listing. TODO confirm.
@triton.jit
def _helion_helion_dcpp_bwd_impl(x, d_out, dy, b, y, dx, _RDIM_SIZE_2: tl.constexpr, _RDIM_SIZE_3: tl.constexpr):
    # src[dot_compress.py:335]: for tile_b in hl.tile(B, block_size=1):
    pid_0 = tl.program_id(0)
    offset_0 = pid_0
    # Full-range index vectors over D (=256) and K (=64); those dims are not tiled.
    indices_13 = tl.arange(0, _RDIM_SIZE_2).to(tl.int32)
    indices_14 = tl.arange(0, _RDIM_SIZE_3).to(tl.int32)
    # src[dot_compress.py:337]: db = hl.zeros([D, K], dtype=torch.float32)
    db = tl.full([256, 64], 0.0, tl.float32)
    # src[dot_compress.py:338]: for tile_m in hl.tile(M):
    # src[dot_compress.py:339]:     db = torch.addmm(
    # src[dot_compress.py:340]:         db, x[tile_b.begin, tile_m, :].t(), d_out[tile_b.begin, tile_m, :]
    # src[dot_compress.py:338-341]: ...
    # Pass 1: reduce over M (=3219, ragged tail handled by the mask).
    for offset_10 in tl.range(0, 3219, _BLOCK_SIZE_1):
        indices_10 = offset_10 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
        mask_1 = indices_10 < 3219
        db_copy = db
        db_copy_0 = db_copy
        # src[dot_compress.py:340]: db, x[tile_b.begin, tile_m, :].t(), d_out[tile_b.begin, tile_m, :]
        load = tl.load(x + (offset_0 * 824064 + indices_10[:, None] * 256 + indices_13[None, :] * 1), mask_1[:, None], other=0)
        permute = tl.permute(load, [1, 0])
        load_1 = tl.load(d_out + (offset_0 * 206016 + indices_10[:, None] * 64 + indices_14[None, :] * 1), mask_1[:, None], other=0)
        # src[dot_compress.py:339]: db = torch.addmm(
        # src[dot_compress.py:340]:     db, x[tile_b.begin, tile_m, :].t(), d_out[tile_b.begin, tile_m, :]
        # src[dot_compress.py:341]: )
        db = tl.dot(tl.cast(permute, tl.bfloat16), tl.cast(load_1, tl.bfloat16), acc=db_copy_0, input_precision='ieee', out_dtype=tl.float32)
    # src[dot_compress.py:342]: db2 = db.to(x.dtype)
    # fp32 accumulator rounded to bf16 before reuse in passes 2 and 3.
    v_0 = tl.cast(db, tl.bfloat16)
    # src[dot_compress.py:349]: for tile_m2 in hl.tile(M):
    # src[dot_compress.py:350]:     dy[tile_b.begin, tile_m2, :] = torch.matmul(
    # src[dot_compress.py:351]:         x[tile_b.begin, tile_m2, :], db2
    # src[dot_compress.py:349-352]: ...
    # Pass 2: dy = x @ db2.
    for offset_11 in tl.range(0, 3219, _BLOCK_SIZE_4):
        indices_11 = offset_11 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
        mask_4 = indices_11 < 3219
        v_0_copy = v_0
        v_0_copy_0 = v_0_copy
        # src[dot_compress.py:351]: x[tile_b.begin, tile_m2, :], db2
        load_2 = tl.load(x + (offset_0 * 824064 + indices_11[:, None] * 256 + indices_13[None, :] * 1), mask_4[:, None], other=0)
        # src[dot_compress.py:350]: dy[tile_b.begin, tile_m2, :] = torch.matmul(
        # src[dot_compress.py:351]:     x[tile_b.begin, tile_m2, :], db2
        # src[dot_compress.py:352]: )
        mm = tl.cast(tl.dot(tl.cast(load_2, tl.bfloat16), tl.cast(v_0_copy_0, tl.bfloat16), input_precision='ieee', out_dtype=tl.float32), tl.bfloat16)
        tl.store(dy + (offset_0 * 206016 + indices_11[:, None] * 64 + indices_14[None, :] * 1), mm, mask_4[:, None])
    # src[dot_compress.py:355]: bt = b[tile_b.begin, :, :].t()
    # Load the saved forward intermediate b (bf16) and transpose it in registers.
    load_3 = tl.load(b + (offset_0 * 16384 + indices_13[:, None] * 64 + indices_14[None, :] * 1), None)
    bt = tl.permute(load_3, [1, 0])
    # src[dot_compress.py:356]: for tile_m3 in hl.tile(M):
    # src[dot_compress.py:357]:     dx1 = torch.matmul(d_out[tile_b.begin, tile_m3, :], bt)
    # src[dot_compress.py:358]:     dxt = torch.matmul(db2, y[tile_b.begin, tile_m3, :].t())
    # src[dot_compress.py:356-359]: ...
    # Pass 3: dx = d_out @ b^T + (db2 @ y^T)^T, added in bf16.
    for offset_12 in tl.range(0, 3219, _BLOCK_SIZE_5):
        indices_12 = offset_12 + tl.arange(0, _BLOCK_SIZE_5).to(tl.int32)
        mask_5 = indices_12 < 3219
        bt_copy = bt
        v_0_copy_1 = v_0
        bt_copy_0 = bt_copy
        v_0_copy_1_0 = v_0_copy_1
        # src[dot_compress.py:357]: dx1 = torch.matmul(d_out[tile_b.begin, tile_m3, :], bt)
        load_4 = tl.load(d_out + (offset_0 * 206016 + indices_12[:, None] * 64 + indices_14[None, :] * 1), mask_5[:, None], other=0)
        dx1 = tl.cast(tl.dot(tl.cast(load_4, tl.bfloat16), tl.cast(bt_copy_0, tl.bfloat16), input_precision='ieee', out_dtype=tl.float32), tl.bfloat16)
        # src[dot_compress.py:358]: dxt = torch.matmul(db2, y[tile_b.begin, tile_m3, :].t())
        load_5 = tl.load(y + (offset_0 * 206016 + indices_12[:, None] * 64 + indices_14[None, :] * 1), mask_5[:, None], other=0)
        permute_1 = tl.permute(load_5, [1, 0])
        dxt = tl.cast(tl.dot(tl.cast(v_0_copy_1_0, tl.bfloat16), tl.cast(permute_1, tl.bfloat16), input_precision='ieee', out_dtype=tl.float32), tl.bfloat16)
        # src[dot_compress.py:359]: dx[tile_b.begin, tile_m3, :] = (dx1 + dxt.t()).to(dx.dtype)
        permute_2 = tl.permute(dxt, [1, 0])
        v_1 = dx1 + permute_2
        tl.store(dx + (offset_0 * 824064 + indices_12[:, None] * 256 + indices_13[None, :] * 1), v_1, mask_5[:, None])
def helion_dcpp_bwd_impl(d_out: torch.Tensor, x: torch.Tensor, y: torch.Tensor, b: torch.Tensor, z: Optional[torch.Tensor]=None, *, _launcher=_default_launcher):
    """
    Helion kernel for the dot compress backward pass.
    Computes gradients for the dot compress operation with respect to
    inputs x, y and z (optional).
    The backward pass computes:
        - db = x^T @ dout
        - dy = x @ db
        - dx = dout @ b^T + y @ db^T
        - dz = db (if z is not None)
    The kernel is auto-tuned with multiple configurations optimized for
    different input sizes on B200 hardware with bfloat16 precision.
    Args:
        d_out: Gradient of loss w.r.t. output, shape (B, M, K).
        x: Input tensor from forward pass, shape (B, M, D).
        y: Input tensor from forward pass, shape (B, M, K).
        b: Intermediate tensor (x^T @ y) or (x^T @ y + z) if z is not None
           from forward pass, shape (B, D, K).
        z: (Optional) Input tensor from forward pass, shape (B, D, K).
    Returns:
        A tuple of:
            - dx: Gradient w.r.t. x, shape (B, M, D)
            - dy: Gradient w.r.t. y, shape (B, M, K)
            - dz: Gradient w.r.t. z, shape (B, D, K), None if z was None
    """
    # NOTE: `Optional` is not imported in this generated file; the annotation
    # only works because of `from __future__ import annotations` at the top.
    # src[dot_compress.py:325]: B, M, D = x.shape
    B, M, D = x.shape
    # src[dot_compress.py:326]: K = y.shape[2]
    K = y.shape[2]
    # src[dot_compress.py:327]: D = hl.specialize(D)
    # D and K are specialized (baked in) to constants for this compiled variant.
    D = 256
    # src[dot_compress.py:328]: K = hl.specialize(K)
    K = 64
    # src[dot_compress.py:330]: dx = torch.empty((B, M, D), device=x.device, dtype=x.dtype)
    dx = torch.empty((B, M, D), device=x.device, dtype=x.dtype)
    # src[dot_compress.py:331]: dy = torch.empty((B, M, K), device=y.device, dtype=y.dtype)
    dy = torch.empty((B, M, K), device=y.device, dtype=y.dtype)
    # src[dot_compress.py:333]: dz = torch.empty((B, D, K), device=x.device, dtype=x.dtype)
    # NOTE(review): dz is allocated here but never passed to the launcher below
    # (kernel args are x, d_out, dy, b, y, dx), so when z is not None this
    # returns uninitialized memory as dz, despite the docstring's "dz = db".
    # Confirm against the Helion source whether dz is written elsewhere.
    dz = torch.empty((B, D, K), device=x.device, dtype=x.dtype)
    # src[dot_compress.py:335]: for tile_b in hl.tile(B, block_size=1):
    _RDIM_SIZE_2 = 256
    _RDIM_SIZE_3 = 64
    # src[dot_compress.py:335]: for tile_b in hl.tile(B, block_size=1):
    # src[dot_compress.py:336]:     # compute db = x^T @ dout
    # src[dot_compress.py:337]:     db = hl.zeros([D, K], dtype=torch.float32)
    # src[dot_compress.py:335-359]: ...
    # Grid of 1152 programs: one per batch element (block_size=1 on B).
    _launcher(_helion_helion_dcpp_bwd_impl, (1152,), x, d_out, dy, b, y, dx, _RDIM_SIZE_2, _RDIM_SIZE_3, num_warps=4, num_stages=1, waves_per_eu=1, matrix_instr_nonkdim=0)
    # src[dot_compress.py:361]: if z is not None:
    # src[dot_compress.py:362]:     return dx, dy, dz
    # src[dot_compress.py:363]: else:
    # src[dot_compress.py:361-364]: ...
    if z is not None:
        # src[dot_compress.py:362]: return dx, dy, dz
        return (dx, dy, dz)
    else:
        # src[dot_compress.py:364]: return dx, dy, None
        return (dx, dy, None)
def call():
    """Repro entry point: build random bf16 CUDA inputs matching the
    benchmarked shapes/strides and invoke the Helion backward kernel."""
    from torch._dynamo.testing import rand_strided
    # src[dot_compress.py:289]: def helion_dcpp_bwd_impl(
    # src[dot_compress.py:290]:     d_out: torch.Tensor,
    # src[dot_compress.py:291]:     x: torch.Tensor,
    # src[dot_compress.py:289-364]: ...
    d_out = rand_strided(size=(1152, 3219, 64), stride=(206016, 64, 1), dtype=torch.bfloat16, device='cuda:0')
    x = rand_strided(size=(1152, 3219, 256), stride=(824064, 256, 1), dtype=torch.bfloat16, device='cuda:0')
    y = rand_strided(size=(1152, 3219, 64), stride=(206016, 64, 1), dtype=torch.bfloat16, device='cuda:0')
    b = rand_strided(size=(1152, 256, 64), stride=(16384, 64, 1), dtype=torch.bfloat16, device='cuda:0')
    # The code generator emitted the placeholder string 'UNSUPPORTED TYPE -
    # REPLACE' here, which is not a valid argument. `z` is Optional, so pass
    # None to make the repro script runnable as-is.
    z = None
    helion_dcpp_bwd_impl(d_out, x, y, b, z)
if __name__ == '__main__':
    call()

Inductor generated triton code

from ctypes import c_void_p, c_long, c_int
import torch
import math
import random
import os
import tempfile
from math import inf, nan
from cmath import nanj
from torch._inductor.hooks import run_intermediate_hooks
from torch._inductor.utils import maybe_profile
from torch._inductor.codegen.memory_planning import _align as align
from torch import device, empty_strided
from torch._inductor.async_compile import AsyncCompile
from torch._inductor.select_algorithm import extern_kernels
import triton
import triton.language as tl
from torch._inductor.runtime.triton_heuristics import start_graph, end_graph
from torch._C import _cuda_getCurrentRawStream as get_raw_stream

# Shorthand aliases for torch internals used by the generated wrapper below.
aten = torch.ops.aten
inductor_ops = torch.ops.inductor
_quantized = torch.ops._quantized
# Guard helpers that validate tensor metadata and reinterpret storage at call time.
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
assert_alignment = torch._C._dynamo.guards.assert_alignment
empty_strided_cpu = torch._C._dynamo.guards._empty_strided_cpu
empty_strided_cpu_pinned = torch._C._dynamo.guards._empty_strided_cpu_pinned
empty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda
empty_strided_xpu = torch._C._dynamo.guards._empty_strided_xpu
empty_strided_mtia = torch._C._dynamo.guards._empty_strided_mtia
reinterpret_tensor = torch._C._dynamo.guards._reinterpret_tensor
alloc_from_pool = torch.ops.inductor._alloc_from_pool
# Compiler instance that asynchronously builds the Triton kernels defined below.
async_compile = AsyncCompile()
empty_strided_p2p = torch._C._distributed_c10d._SymmetricMemory.empty_strided_p2p


# kernel path: /var/tmp/torchinductor_mengjiao/u7/cu7bfpex2bd6ohilcqfmln4beie43j7wm5br6hx7436vegehl7vo.py
# Topologically Sorted Source Nodes: [permute_3, bmm_4], Original ATen: [aten.transpose, aten.bmm]
# Source node to ATen node mapping:
#   bmm_4 => bmm_4
#   permute_3 => permute_3
# Graph fragment:
#   %permute : Tensor "bf16[1152, 256, 3219][824064, 1, 256]cuda:0" = PlaceHolder[target=permute]
#   %bmm_2 : Tensor "bf16[1152, 256, 64][16384, 64, 1]cuda:0" = PlaceHolder[target=bmm_2]
#   %permute_3 : Tensor "bf16[1152, 3219, 256][824064, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%permute, [0, 2, 1]), kwargs = {})
#   %bmm_4 : Tensor "bf16[1152, 3219, 64][206016, 64, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.bmm.default](args = (%permute_3, %bmm_2), kwargs = {})
#   return %bmm_4
# Inductor-autotuned bmm template kernel for the backward pass:
# bmm_4 = permute(x)^T-view @ bmm_2 (dy = x @ db), fp32 accumulation
# (ACC_TYPE=tl.float32) over bf16 operands and a bf16 store. Identical
# template body to the forward bmm kernel; only the kernel name differs.
# NOTE: the kernel source below is a string handed to AsyncCompile; comments
# cannot be added inside it without changing the compiled kernel text, so it
# is reproduced byte-for-byte.
triton_tem_fused_bmm_transpose_0 = async_compile.triton('triton_tem_fused_bmm_transpose_0', '''
import triton
import triton.language as tl

import triton.language.extra.tlx as tlx  # noqa: F401

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties

@triton_heuristics.template(

num_stages=2,
num_warps=8,
triton_meta={'signature': {'arg_A': '*bf16', 'arg_B': '*bf16', 'out_ptr0': '*bf16'}, 'device': DeviceProperties(type='hip', index=0, multi_processor_count=256, cc='gfx950', major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=64), 'constants': {}, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]]}], 'matrix_instr_nonkdim': 16, 'kpack': 2},
inductor_meta={'kernel_name': 'triton_tem_fused_bmm_transpose_0', 'backend_hash': 'AF70BC814F3C641B0867B40FECBA1755585B56CBEE96FC8F140050135BD22880', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 32, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'is_hip': True, 'is_fbcode': True, 'grid_type': 'FixedGrid', 'fixed_grid': ['_grid_0', '_grid_1', '_grid_2'], 'extra_launcher_args': ['_grid_0', '_grid_1', '_grid_2'], 'config_args': {'EVEN_K': True, 'USE_FAST_ACCUM': False, 'ACC_TYPE': 'tl.float32', 'BLOCK_M': 256, 'BLOCK_N': 64, 'BLOCK_K': 64, 'matrix_instr_nonkdim': 16, 'waves_per_eu': 0, 'kpack': 2, 'GROUP_M': 4, 'ALLOW_TF32': False}},

)
@triton.jit
def triton_tem_fused_bmm_transpose_0(arg_A, arg_B, out_ptr0):
    EVEN_K : tl.constexpr = True
    USE_FAST_ACCUM : tl.constexpr = False
    ACC_TYPE : tl.constexpr = tl.float32
    BLOCK_M : tl.constexpr = 256
    BLOCK_N : tl.constexpr = 64
    BLOCK_K : tl.constexpr = 64
    matrix_instr_nonkdim : tl.constexpr = 16
    waves_per_eu : tl.constexpr = 0
    kpack : tl.constexpr = 2
    GROUP_M : tl.constexpr = 4
    ALLOW_TF32 : tl.constexpr = False
    INDEX_DTYPE : tl.constexpr = tl.int32
    A = arg_A
    B = arg_B

    M = 3219
    N = 64
    K = 256

    stride_aq = 824064
    stride_am = 256
    stride_ak = 1

    stride_bq = 16384
    stride_bk = 64
    stride_bn = 1

    # based on triton.ops.matmul
    pid = tl.program_id(0).to(INDEX_DTYPE)
    grid_m = (M + BLOCK_M - 1) // BLOCK_M
    grid_n = (N + BLOCK_N - 1) // BLOCK_N

    # re-order program ID for better L2 performance
    width = GROUP_M * grid_n
    group_id = pid // width
    group_size = min(grid_m - group_id * GROUP_M, GROUP_M)
    pid_m = group_id * GROUP_M + (pid % group_size)
    pid_n = (pid % width) // (group_size)
    tl.assume(pid_m >= 0)
    tl.assume(pid_n >= 0)

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    if (stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1):
        ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)
    else:
        ram = rm % M
    if (stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1):
        rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)
    else:
        rbn = rn % N

    rk = tl.arange(0, BLOCK_K)

    idx_q = tl.program_id(1).to(INDEX_DTYPE)  # batch dimension for BMM
    A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak + idx_q*stride_aq)
    B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn + idx_q*stride_bq)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)
    for k in range(K, 0, -BLOCK_K):
        if EVEN_K:
            a = tl.load(A)
            b = tl.load(B)
        else:
            a = tl.load(A, mask=rk[None, :] < k, other=0.)
            b = tl.load(B, mask=rk[:, None] < k, other=0.)
        acc += tl.dot(a, b, allow_tf32=ALLOW_TF32)
        A += BLOCK_K * stride_ak
        B += BLOCK_K * stride_bk

    # rematerialize rm and rn to save registers
    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    idx_q = tl.program_id(1).to(INDEX_DTYPE)  # batch dimension for BMM
    idx_m = rm[:, None]
    idx_n = rn[None, :]
    mask = (idx_m < M) & (idx_n < N)

    # inductor generates a suffix
    xindex = idx_n + 64*idx_m + 206016*idx_q
    tl.store(out_ptr0 + (tl.broadcast_to(xindex, [BLOCK_M, BLOCK_N])), acc, mask)
''', device_str='cuda')


# kernel path: /var/tmp/torchinductor_mengjiao/c2/cc2tzxsihja7up6uwtuezwism6hsficj26xqb6256iphybq23ko5.py
# Topologically Sorted Source Nodes: [permute_5, add], Original ATen: [aten.permute, aten.add]
# Source node to ATen node mapping:
#   add => add
#   permute_5 => permute_5
# Graph fragment:
#   %bmm_3 : Tensor "bf16[1152, 3219, 256][824064, 256, 1]cuda:0" = PlaceHolder[target=bmm_3]
#   %bmm_5 : Tensor "bf16[1152, 256, 3219][824064, 3219, 1]cuda:0" = PlaceHolder[target=bmm_5]
#   %permute_5 : Tensor "bf16[1152, 3219, 256][824064, 1, 3219]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%bmm_5, [0, 2, 1]), kwargs = {})
#   %add : Tensor "bf16[1152, 3219, 256][824064, 256, 1]cuda:0"[num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%bmm_3, %permute_5), kwargs = {})
#   return %add
# Inductor-generated pointwise kernel implementing `add(bmm_3, permute(bmm_5, [0, 2, 1]))`
# in a single elementwise pass over the (1152, 3219, 256) output.
#   - `in_out_ptr0` holds bmm_3 (contiguous, strides 256*y3 + x2) and is mutated in place.
#   - `in_ptr0` holds bmm_5; its last two dims are read transposed via the
#     `y0 + 3219*x2 + 824064*y1` index.
# Both bf16 operands are upcast to fp32 (`.to(tl.float32)`) before the add and the
# sum is stored back through a *bf16 pointer, so the result is rounded to bf16 —
# NOTE(review): a plausible place for small platform-dependent rounding
# differences (cf. the AMD-vs-NVIDIA accuracy report above) — confirm.
triton_poi_fused_add_permute_1 = async_compile.triton('triton_poi_fused_add_permute_1', '''
import triton
import triton.language as tl

import triton.language.extra.tlx as tlx  # noqa: F401

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

@triton_heuristics.pointwise(
    size_hints={'y': 4194304, 'x': 256}, tile_hint=TileHint.SQUARE,
    filename=__file__,
    triton_meta={'signature': {'in_out_ptr0': '*bf16', 'in_ptr0': '*bf16', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='hip', index=0, multi_processor_count=256, cc='gfx950', major=9, regs_per_multiprocessor=65536, max_threads_per_multi_processor=2048, max_threads_per_block=1024, warp_size=64), 'constants': {}, 'native_matmul': False, 'enable_fp_fusion': True, 'launch_pdl': False, 'disable_ftz': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}]},
    inductor_meta={'grid_type': 'Grid2DWithYZOverflow', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused_add_permute_1', 'mutated_arg_names': ['in_out_ptr0'], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 2, 'num_store': 1, 'num_reduction': 0, 'backend_hash': 'AF70BC814F3C641B0867B40FECBA1755585B56CBEE96FC8F140050135BD22880', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': True, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 32, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'is_hip': True, 'is_fbcode': True},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused_add_permute_1(in_out_ptr0, in_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 3708288
    xnumel = 256
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = xindex < xnumel
    x2 = xindex
    y3 = yindex
    y0 = (yindex % 3219)
    y1 = yindex // 3219
    tmp0 = tl.load(in_out_ptr0 + (x2 + 256*y3), xmask & ymask, eviction_policy='evict_last').to(tl.float32)
    tmp1 = tl.load(in_ptr0 + (y0 + 3219*x2 + 824064*y1), xmask & ymask, eviction_policy='evict_last').to(tl.float32)
    tmp2 = tmp0 + tmp1
    tl.debug_barrier()
    tl.store(in_out_ptr0 + (x2 + 256*y3), tmp2, xmask & ymask)
''', device_str='cuda')


# Block until every asynchronously-compiled Triton kernel registered above is
# ready, then drop the compile helper so it is not kept alive at module scope.
async_compile.wait(globals())
del async_compile

def call(args):
    # Backward-pass entry point generated by Inductor.
    #
    # args: list of [permute, permute_2, permute_4, tangents_1] bf16 CUDA
    #       tensors with the exact strides asserted below.  The list is
    #       cleared so this function owns the inputs and can `del` them as
    #       soon as each is last used, keeping peak memory down.
    # Returns: (buf4, buf2) — buf4 is `bmm_3 + permute(bmm_5)` computed by the
    #       fused Triton kernel (in place over buf1's storage), buf2 is the
    #       output of the templated bmm+transpose kernel.
    permute, permute_2, permute_4, tangents_1 = args
    args.clear()
    assert_size_stride(permute, (1152, 256, 3219), (824064, 1, 256))
    assert_size_stride(permute_2, (1152, 64, 256), (16384, 1, 64))
    assert_size_stride(permute_4, (1152, 64, 3219), (206016, 1, 64))
    assert_size_stride(tangents_1, (1152, 3219, 64), (206016, 64, 1))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0)
        # buf0 = permute @ tangents_1  (external/vendor bmm, bf16)
        buf0 = empty_strided_cuda((1152, 256, 64), (16384, 64, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [bmm_2], Original ATen: [aten.bmm]
        extern_kernels.bmm(permute, tangents_1, out=buf0)
        # buf1 = tangents_1 @ permute_2  (external/vendor bmm, bf16)
        buf1 = empty_strided_cuda((1152, 3219, 256), (824064, 256, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [bmm_3], Original ATen: [aten.bmm]
        extern_kernels.bmm(tangents_1, permute_2, out=buf1)
        del permute_2
        del tangents_1
        # buf2: Triton template kernel fusing transpose(permute) @ buf0.
        buf2 = empty_strided_cuda((1152, 3219, 64), (206016, 64, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [permute_3, bmm_4], Original ATen: [aten.transpose, aten.bmm]
        stream0 = get_raw_stream(0)
        triton_tem_fused_bmm_transpose_0.run(permute, buf0, buf2, 13, 1152, 1, stream=stream0)
        del permute
        # buf3 = buf0 @ permute_4  (external/vendor bmm, bf16)
        buf3 = empty_strided_cuda((1152, 256, 3219), (824064, 3219, 1), torch.bfloat16)
        # Topologically Sorted Source Nodes: [bmm_5], Original ATen: [aten.bmm]
        extern_kernels.bmm(buf0, permute_4, out=buf3)
        del buf0
        del permute_4
        # Reuse buf1's storage for the fused add: buf4 <- buf1 + permute(buf3).
        buf4 = buf1; del buf1  # reuse
        # Topologically Sorted Source Nodes: [permute_5, add], Original ATen: [aten.permute, aten.add]
        stream0 = get_raw_stream(0)
        triton_poi_fused_add_permute_1.run(buf4, buf3, 3708288, 256, stream=stream0)
        del buf3
    return (buf4, buf2, )


def get_args():
    """Build random bf16 CUDA inputs whose shapes/strides match call()'s asserts.

    Returns the list [permute, permute_2, permute_4, tangents_1] in the exact
    order call() unpacks them.
    """
    from torch._dynamo.testing import rand_strided

    def _bf16(shape, stride):
        # All four inputs share device and dtype; only shape/stride vary.
        return rand_strided(shape, stride, device='cuda:0', dtype=torch.bfloat16)

    return [
        _bf16((1152, 256, 3219), (824064, 1, 256)),   # permute
        _bf16((1152, 64, 256), (16384, 1, 64)),       # permute_2
        _bf16((1152, 64, 3219), (206016, 1, 64)),     # permute_4
        _bf16((1152, 3219, 64), (206016, 64, 1)),     # tangents_1
    ]


def benchmark_compiled_module(args, times=10, repeat=10):
    """Time repeated invocations of call() and report via print_performance.

    A fresh copy of *args* is passed on every run because call() clears the
    list it receives.
    """
    from torch._inductor.utils import print_performance

    def _invoke():
        return call(list(args))

    return print_performance(_invoke, times=times, repeat=repeat)


if __name__ == "__main__":
    from torch._inductor.wrapper_benchmark import compiled_module_main
    args = get_args()
    compiled_module_main('None', lambda times, repeat: benchmark_compiled_module(args, times=times, repeat=repeat))

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions