diff --git a/amd_triton_npu/backend/driver.py b/amd_triton_npu/backend/driver.py index 7887cff..efe63c7 100644 --- a/amd_triton_npu/backend/driver.py +++ b/amd_triton_npu/backend/driver.py @@ -410,7 +410,63 @@ def _replace_include(m): return result -def _get_transform_ir_string(): +def _detect_element_type(ir_str): + """Detect the primary element type from the provided Linalg IR string. + + Searches the IR text for the first ``memref<...xTYPE>`` occurrence and + returns the captured MLIR element type string (for example, ``"bf16"``, + ``"f32"``, ``"i8"``, or ``"i16"``). + Falls back to "bf16" if detection fails. + """ + import re + + # Match the first memref<...xTYPE> occurrence in the provided IR text. + match = re.search(r"memref<[^>]*x(\w+)>", ir_str) + if match: + return match.group(1) + return "bf16" + + +# Dtype-aware placeholder info: padding value and default vector size per NPU. +_DTYPE_PLACEHOLDER_INFO = { + "bf16": {"pad_val": "0.0 : bf16", "vector_size": {"npu1": 16, "npu2": 32}}, + "f32": {"pad_val": "0.0 : f32", "vector_size": {"npu1": 16, "npu2": 16}}, + "i8": {"pad_val": "0 : i8", "vector_size": {"npu1": 32, "npu2": 32}}, + "i16": {"pad_val": "0 : i16", "vector_size": {"npu1": 32, "npu2": 32}}, + "i32": {"pad_val": "0 : i32", "vector_size": {"npu1": 16, "npu2": 16}}, +} + + +def _substitute_dtype_placeholders(script, dtype, npu_version): + """Substitute dtype-aware placeholders in a transform script. + + Replaces @DTYPE@, @PAD_VAL@, and @VECTOR_SIZE@ with values derived + from the detected element type and target NPU version. + No-op if the script contains no placeholders (backward compatible). + """ + if ( + "@DTYPE@" not in script + and "@PAD_VAL@" not in script + and "@VECTOR_SIZE@" not in script + ): + return script + info = _DTYPE_PLACEHOLDER_INFO.get(dtype) + if info is None: + raise ValueError( + f"Unsupported element type '{dtype}' for transform script placeholder " + f"substitution. Supported types: {list(_DTYPE_PLACEHOLDER_INFO.keys())}. " + f"The script contains @DTYPE@/@PAD_VAL@/@VECTOR_SIZE@ placeholders that " + f"require a supported element type." + ) + script = script.replace("@DTYPE@", dtype) + script = script.replace("@PAD_VAL@", info["pad_val"]) + script = script.replace( + "@VECTOR_SIZE@", str(info["vector_size"].get(npu_version, 16)) + ) + return script + + +def _get_transform_ir_string(ir_str=None): """ Get the transform IR string for tiling operations. @@ -421,6 +477,12 @@ def _get_transform_ir_string(): If the script uses `transform.include`, the shared transform library (transform_library.mlir) is automatically injected. + If ir_str is provided, dtype-aware placeholders (@DTYPE@, @PAD_VAL@, + @VECTOR_SIZE@) are substituted before library injection. + + Args: + ir_str: Optional Linalg IR string for dtype detection. + Returns: str: The transform IR string to use for tiling """ @@ -436,6 +498,17 @@ def _get_transform_ir_string(): with open(custom_script_path, "r") as f: print(f"Using custom tiling script from: {custom_script_path}") user_script = f.read() + _PLACEHOLDERS = ("@DTYPE@", "@PAD_VAL@", "@VECTOR_SIZE@") + if ir_str is not None and any(p in user_script for p in _PLACEHOLDERS): + dtype = _detect_element_type( + ir_str if isinstance(ir_str, str) else str(ir_str) + ) + npu_version = ( + detect_npu_version() if "@VECTOR_SIZE@" in user_script else None + ) + user_script = _substitute_dtype_placeholders( + user_script, dtype, npu_version + ) return _inject_transform_library(user_script) # Default hardcoded transform IR string @@ -493,7 +566,7 @@ def _ttshared_to_air(mod, gridX, gridY, gridZ, actual_sizes=None): pm = air.passmanager.PassManager.parse(pipeline, context=air_context) pm.run(air_module.operation) # MLIR-AIR compilation step 2: tiling the launch body - transform_ir_string = _get_transform_ir_string() + transform_ir_string = _get_transform_ir_string(ir_str=mod) transform_ir = Module.parse(transform_ir_string, context=air_context) run_transform(transform_ir, air_module) # MLIR-AIR compilation step 3: converting to AIR diff --git a/amd_triton_npu/backend/transform_library/elementwise.mlir b/amd_triton_npu/backend/transform_library/elementwise.mlir index 26fda74..30c14f1 100644 --- a/amd_triton_npu/backend/transform_library/elementwise.mlir +++ b/amd_triton_npu/backend/transform_library/elementwise.mlir @@ -66,6 +66,85 @@ transform.named_sequence @pad_and_promote_unary_bf16( transform.yield } +// Unary variant for f32: 1 input + 1 output = 2 operands. +// Used with bf16-emulation (f32 data, bf16 compute on AIE cores). +transform.named_sequence @pad_and_promote_unary_f32( + %module: !transform.any_op {transform.readonly}) { + %op = transform.structured.match ops{["linalg.generic"]} in %module + : (!transform.any_op) -> !transform.any_op + %padded_op, %pad_op, %__ = transform.structured.pad %op { + padding_values=[0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 1], + nofold_flags=[1, 1], + copy_back_op="linalg.copy" + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op + : (!transform.any_op) -> !transform.any_op + %padded_input = transform.get_producer_of_operand %padded_op[0] + : (!transform.any_op) -> (!transform.any_op) + %padded_input_buffer, %padded_input_new = + transform.structured.bufferize_to_allocation %padded_input + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + %padded_result = transform.get_producer_of_operand %padded_op[1] + : (!transform.any_op) -> (!transform.any_op) + %padded_result_buffer, %padded_result_new = + transform.structured.bufferize_to_allocation %padded_result + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + transform.yield +} + +// Unary variant for i8: 1 input + 1 output = 2 operands. +transform.named_sequence @pad_and_promote_unary_i8( + %module: !transform.any_op {transform.readonly}) { + %op = transform.structured.match ops{["linalg.generic"]} in %module + : (!transform.any_op) -> !transform.any_op + %padded_op, %pad_op, %__ = transform.structured.pad %op { + padding_values=[0 : i8, 0 : i8], + padding_dimensions=[0, 1], + nofold_flags=[1, 1], + copy_back_op="linalg.copy" + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op + : (!transform.any_op) -> !transform.any_op + %padded_input = transform.get_producer_of_operand %padded_op[0] + : (!transform.any_op) -> (!transform.any_op) + %padded_input_buffer, %padded_input_new = + transform.structured.bufferize_to_allocation %padded_input + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + %padded_result = transform.get_producer_of_operand %padded_op[1] + : (!transform.any_op) -> (!transform.any_op) + %padded_result_buffer, %padded_result_new = + transform.structured.bufferize_to_allocation %padded_result + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + transform.yield +} + +// Unary variant for i16: 1 input + 1 output = 2 operands. +transform.named_sequence @pad_and_promote_unary_i16( + %module: !transform.any_op {transform.readonly}) { + %op = transform.structured.match ops{["linalg.generic"]} in %module + : (!transform.any_op) -> !transform.any_op + %padded_op, %pad_op, %__ = transform.structured.pad %op { + padding_values=[0 : i16, 0 : i16], + padding_dimensions=[0, 1], + nofold_flags=[1, 1], + copy_back_op="linalg.copy" + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op + : (!transform.any_op) -> !transform.any_op + %padded_input = transform.get_producer_of_operand %padded_op[0] + : (!transform.any_op) -> (!transform.any_op) + %padded_input_buffer, %padded_input_new = + transform.structured.bufferize_to_allocation %padded_input + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + %padded_result = transform.get_producer_of_operand %padded_op[1] + : (!transform.any_op) -> (!transform.any_op) + %padded_result_buffer, %padded_result_new = + transform.structured.bufferize_to_allocation %padded_result + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + transform.yield +} + // Binary variant: 2 inputs + 1 output = 3 operands (vec-add, axpy, swiglu). transform.named_sequence @pad_and_promote_binary_bf16( %module: !transform.any_op {transform.readonly}) { @@ -96,3 +175,97 @@ transform.named_sequence @pad_and_promote_binary_bf16( {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op transform.yield } + +// Binary variant for f32: 2 inputs + 1 output = 3 operands. +// Used with bf16-emulation (f32 data, bf16 compute on AIE cores). +transform.named_sequence @pad_and_promote_binary_f32( + %module: !transform.any_op {transform.readonly}) { + %op = transform.structured.match ops{["linalg.generic"]} in %module + : (!transform.any_op) -> !transform.any_op + %padded_op, %pad_op, %__ = transform.structured.pad %op { + padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32], + padding_dimensions=[0, 1, 2], + nofold_flags=[1, 1, 1], + copy_back_op="linalg.copy" + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op + : (!transform.any_op) -> !transform.any_op + %padded_lhs = transform.get_producer_of_operand %padded_op[0] + : (!transform.any_op) -> (!transform.any_op) + %padded_lhs_buffer, %padded_lhs_new = + transform.structured.bufferize_to_allocation %padded_lhs + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + %padded_rhs = transform.get_producer_of_operand %padded_op[1] + : (!transform.any_op) -> (!transform.any_op) + %padded_rhs_buffer, %padded_rhs_new = + transform.structured.bufferize_to_allocation %padded_rhs + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + %padded_result = transform.get_producer_of_operand %padded_op[2] + : (!transform.any_op) -> (!transform.any_op) + %padded_result_buffer, %padded_result_new = + transform.structured.bufferize_to_allocation %padded_result + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + transform.yield +} + +// Binary variant for i8: 2 inputs + 1 output = 3 operands. +transform.named_sequence @pad_and_promote_binary_i8( + %module: !transform.any_op {transform.readonly}) { + %op = transform.structured.match ops{["linalg.generic"]} in %module + : (!transform.any_op) -> !transform.any_op + %padded_op, %pad_op, %__ = transform.structured.pad %op { + padding_values=[0 : i8, 0 : i8, 0 : i8], + padding_dimensions=[0, 1, 2], + nofold_flags=[1, 1, 1], + copy_back_op="linalg.copy" + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op + : (!transform.any_op) -> !transform.any_op + %padded_lhs = transform.get_producer_of_operand %padded_op[0] + : (!transform.any_op) -> (!transform.any_op) + %padded_lhs_buffer, %padded_lhs_new = + transform.structured.bufferize_to_allocation %padded_lhs + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + %padded_rhs = transform.get_producer_of_operand %padded_op[1] + : (!transform.any_op) -> (!transform.any_op) + %padded_rhs_buffer, %padded_rhs_new = + transform.structured.bufferize_to_allocation %padded_rhs + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + %padded_result = transform.get_producer_of_operand %padded_op[2] + : (!transform.any_op) -> (!transform.any_op) + %padded_result_buffer, %padded_result_new = + transform.structured.bufferize_to_allocation %padded_result + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + transform.yield +} + +// Binary variant for i16: 2 inputs + 1 output = 3 operands. +transform.named_sequence @pad_and_promote_binary_i16( + %module: !transform.any_op {transform.readonly}) { + %op = transform.structured.match ops{["linalg.generic"]} in %module + : (!transform.any_op) -> !transform.any_op + %padded_op, %pad_op, %__ = transform.structured.pad %op { + padding_values=[0 : i16, 0 : i16, 0 : i16], + padding_dimensions=[0, 1, 2], + nofold_flags=[1, 1, 1], + copy_back_op="linalg.copy" + } : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op + : (!transform.any_op) -> !transform.any_op + %padded_lhs = transform.get_producer_of_operand %padded_op[0] + : (!transform.any_op) -> (!transform.any_op) + %padded_lhs_buffer, %padded_lhs_new = + transform.structured.bufferize_to_allocation %padded_lhs + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + %padded_rhs = transform.get_producer_of_operand %padded_op[1] + : (!transform.any_op) -> (!transform.any_op) + %padded_rhs_buffer, %padded_rhs_new = + transform.structured.bufferize_to_allocation %padded_rhs + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + %padded_result = transform.get_producer_of_operand %padded_op[2] + : (!transform.any_op) -> (!transform.any_op) + %padded_result_buffer, %padded_result_new = + transform.structured.bufferize_to_allocation %padded_result + {memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op + transform.yield +} diff --git a/examples/axpy/axpy.py b/examples/axpy/axpy.py index 9eb2738..90bc69d 100644 --- a/examples/axpy/axpy.py +++ b/examples/axpy/axpy.py @@ -1,14 +1,54 @@ # Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT +# AXPY benchmark: out = alpha * x + y +# Supports bf16 (default), f32 (via bf16-emulation), i8, and i16. + +import argparse import torch import triton import triton.language as tl -import sys, os +import sys +import os sys.path.append(os.path.abspath("..")) import benchmark +DTYPE_CONFIG = { + "bf16": { + "torch_dtype": torch.bfloat16, + "is_float": True, + "alpha": 2.0, + "atol": 1e-2, + "rtol": 1e-2, + "bf16_emulation": False, + }, + "f32": { + "torch_dtype": torch.float32, + "is_float": True, + "alpha": 2.0, + "atol": 1e-1, + "rtol": 5e-2, + "bf16_emulation": True, + }, + "i8": { + "torch_dtype": torch.int8, + "is_float": False, + "alpha": 2, + "atol": 0, + "rtol": 0, + "bf16_emulation": False, + }, + "i16": { + "torch_dtype": torch.int16, + "is_float": False, + "alpha": 2, + "atol": 0, + "rtol": 0, + "bf16_emulation": False, + }, +} + @triton.jit def axpy_kernel( @@ -29,13 +69,23 @@ def axpy_kernel( tl.store(OUT + offsets[:], out) -def bench_axpy(N, provider): +def bench_axpy(N, provider, cfg): device = "cpu" - dtype = torch.bfloat16 - alpha = 2.0 - x = torch.randn(N, device=device, dtype=dtype) - y = torch.randn(N, device=device, dtype=dtype) - out = torch.empty(N, device=device, dtype=dtype) + torch_dtype = cfg["torch_dtype"] + alpha = cfg["alpha"] + + if cfg["is_float"]: + x = torch.randn(N, device=device, dtype=torch_dtype) + y = torch.randn(N, device=device, dtype=torch_dtype) + else: + iinfo = torch.iinfo(torch_dtype) + # Keep values small enough that alpha*x+y doesn't overflow + quarter_max = iinfo.max // 4 + x = torch.randint(0, quarter_max, (N,), device=device, dtype=torch_dtype) + y = torch.randint(0, quarter_max, (N,), device=device, dtype=torch_dtype) + + out = torch.empty(N, device=device, dtype=torch_dtype) + if provider == "torch" or provider == "test": out_ref = alpha * x + y if provider == "triton" or provider == "test": @@ -51,10 +101,35 @@ def bench_axpy(N, provider): with open("tt.shared.mlir", "w") as f: f.write(str(compiled_kernel.asm["ttsharedir"])) if provider == "test": - torch.testing.assert_close(out, out_ref, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(out, out_ref, atol=cfg["atol"], rtol=cfg["rtol"]) if __name__ == "__main__": + parser = argparse.ArgumentParser(description="AXPY benchmark for AMD NPU") + parser.add_argument( + "--dtype", + type=str, + choices=list(DTYPE_CONFIG.keys()), + default="bf16", + help="Element data type (default: bf16)", + ) + parser.add_argument( + "--bf16-emulation", + dest="bf16_emulation", + default=False, + action="store_true", + help="Use f32 data type with bf16 emulation on AIE cores", + ) + args = parser.parse_args() + + if args.bf16_emulation: + args.dtype = "f32" + + cfg = DTYPE_CONFIG[args.dtype] + + if cfg["bf16_emulation"]: + os.environ["AMD_TRITON_NPU_BF16_EMULATION"] = "1" + benchmark.select_npu_backend() for N in [2**i for i in range(10, 16, 1)]: - bench_axpy(N, "test") + bench_axpy(N, "test", cfg) diff --git a/examples/axpy/transform_aie2.mlir b/examples/axpy/transform_aie2.mlir index 31e907d..2bea4be 100644 --- a/examples/axpy/transform_aie2.mlir +++ b/examples/axpy/transform_aie2.mlir @@ -3,8 +3,8 @@ //////////////////////////////////////////////////////////////////////////////// // Transform Script for AXPY (AIE2): out = alpha * x + y -// Binary op (2 inputs: x, y). Cast mulf and addf to bf16. -// No extern_func.o needed (native mulf/addf). +// Binary op (2 inputs: x, y). Cast mulf and addf to bf16 when float. +// Dtype-generic: uses @DTYPE@ and @VECTOR_SIZE@ placeholders. // Uses shared library sequences from transform_library.mlir (auto-injected). //////////////////////////////////////////////////////////////////////////////// @@ -20,7 +20,7 @@ module attributes {transform.with_named_sequence} { (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @pad_and_promote_binary_bf16 failures(propagate) + transform.include @pad_and_promote_binary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () @@ -29,7 +29,7 @@ module attributes {transform.with_named_sequence} { transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @vectorize_generics_at_16 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () %vh = transform.include @air_herd_mapping_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op diff --git a/examples/axpy/transform_aie2p.mlir b/examples/axpy/transform_aie2p.mlir index 3244ef5..df56af7 100644 --- a/examples/axpy/transform_aie2p.mlir +++ b/examples/axpy/transform_aie2p.mlir @@ -3,8 +3,8 @@ //////////////////////////////////////////////////////////////////////////////// // Transform Script for AXPY (AIE2P): out = alpha * x + y -// Binary op (2 inputs: x, y). Cast mulf and addf to bf16. -// No extern_func.o needed (native mulf/addf). +// Binary op (2 inputs: x, y). Cast mulf and addf to bf16 when float. +// Dtype-generic: uses @DTYPE@ and @VECTOR_SIZE@ placeholders. // Uses shared library sequences from transform_library.mlir (auto-injected). //////////////////////////////////////////////////////////////////////////////// @@ -20,7 +20,7 @@ module attributes {transform.with_named_sequence} { (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @pad_and_promote_binary_bf16 failures(propagate) + transform.include @pad_and_promote_binary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () @@ -29,7 +29,7 @@ module attributes {transform.with_named_sequence} { transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @vectorize_generics_at_16 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () %vh = transform.include @air_herd_mapping_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op diff --git a/examples/elementwise_arith/elementwise_arith.py b/examples/elementwise_arith/elementwise_arith.py new file mode 100644 index 0000000..cd4c678 --- /dev/null +++ b/examples/elementwise_arith/elementwise_arith.py @@ -0,0 +1,193 @@ +# Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: MIT + +# Elementwise arithmetic benchmark: sub, mul, div, square. +# Supports bf16 (default) and f32 (via bf16-emulation). +# Not all ops support all dtypes: +# sub: bf16, f32 +# mul: bf16, f32 +# div: f32 only (hardware constraint: arith.divf is f32-only on AIE2P) +# square: bf16, f32 (implemented as x * x) + +import argparse +import torch +import triton +import triton.language as tl +import sys +import os + +sys.path.append(os.path.abspath("..")) +import benchmark + +DTYPE_CONFIG = { + "bf16": { + "torch_dtype": torch.bfloat16, + "atol": 1e-2, + "rtol": 1e-2, + "bf16_emulation": False, + }, + "f32": { + "torch_dtype": torch.float32, + "atol": 1e-1, + "rtol": 5e-2, + "bf16_emulation": True, + }, +} + +# Which dtypes each op supports. +# Integer types (i16) fail at aircc for subi/muli on AIE2P (only addi works). +OP_DTYPES = { + "sub": ["bf16", "f32"], + "mul": ["bf16", "f32"], + "div": ["f32"], # arith.divf is f32-only on AIE2P; bf16 divf not supported + "square": ["bf16", "f32"], +} + + +# --- Triton kernels --- + + +@triton.jit +def sub_kernel(X, Y, OUT, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + offsets[:]) + y = tl.load(Y + offsets[:]) + tl.store(OUT + offsets[:], x - y) + + +@triton.jit +def mul_kernel(X, Y, OUT, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + offsets[:]) + y = tl.load(Y + offsets[:]) + tl.store(OUT + offsets[:], x * y) + + +@triton.jit +def div_kernel(X, Y, OUT, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + offsets[:]) + y = tl.load(Y + offsets[:]) + tl.store(OUT + offsets[:], x / y) + + +@triton.jit +def square_kernel(X, OUT, n_elements: tl.constexpr, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(X + offsets[:]) + tl.store(OUT + offsets[:], x * x) + + +# --- Kernel dispatch table --- + +KERNELS = { + "sub": sub_kernel, + "mul": mul_kernel, + "div": div_kernel, + "square": square_kernel, +} + +# --- Torch reference functions --- + +TORCH_REF = { + "sub": lambda x, y: x - y, + "mul": lambda x, y: x * y, + "div": lambda x, y: x / y, + "square": lambda x, y: x * x, +} + + +def bench_op(op, N, provider, cfg): + device = "cpu" + torch_dtype = cfg["torch_dtype"] + is_unary = op == "square" + + x = torch.randn(N, device=device, dtype=torch_dtype) + if not is_unary: + if op == "div": + # Avoid division by zero; use values in [0.5, 1.5] + y = 0.5 + torch.rand(N, device=device, dtype=torch_dtype) + else: + y = torch.randn(N, device=device, dtype=torch_dtype) + + out = torch.empty(N, device=device, dtype=torch_dtype) + + if provider == "torch" or provider == "test": + out_ref = TORCH_REF[op](x, y if not is_unary else None) + + if provider == "triton" or provider == "test": + grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),) + kernel = KERNELS[op] + if is_unary: + compiled_kernel = kernel[grid](x, out, N, BLOCK_SIZE=1024) + else: + compiled_kernel = kernel[grid](x, y, out, N, BLOCK_SIZE=1024) + with open("tt.shared.mlir", "w") as f: + f.write(str(compiled_kernel.asm["ttsharedir"])) + if provider == "test": + torch.testing.assert_close(out, out_ref, atol=cfg["atol"], rtol=cfg["rtol"]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Elementwise arithmetic benchmark for AMD NPU" + ) + parser.add_argument( + "--op", + type=str, + choices=list(KERNELS.keys()), + required=True, + help="Operation to benchmark", + ) + parser.add_argument( + "--dtype", + type=str, + choices=list(DTYPE_CONFIG.keys()), + default="bf16", + help="Element data type (default: bf16)", + ) + parser.add_argument( + "--bf16-emulation", + dest="bf16_emulation", + default=False, + action="store_true", + help="Use f32 data type with bf16 emulation on AIE cores", + ) + args = parser.parse_args() + + if args.bf16_emulation: + args.dtype = "f32" + + # Validate op + dtype combination + if args.dtype not in OP_DTYPES[args.op]: + supported = ", ".join(OP_DTYPES[args.op]) + print(f"Error: --op {args.op} does not support --dtype {args.dtype}.") + print(f"Supported dtypes for {args.op}: {supported}") + sys.exit(1) + + cfg = DTYPE_CONFIG[args.dtype] + + if cfg["bf16_emulation"]: + os.environ["AMD_TRITON_NPU_BF16_EMULATION"] = "1" + + # Select the right transform script based on op arity and NPU version. + # If AIR_TRANSFORM_TILING_SCRIPT is already set, respect it. + if not os.environ.get("AIR_TRANSFORM_TILING_SCRIPT"): + from triton.backends.amd_triton_npu.driver import detect_npu_version + + is_unary = args.op == "square" + script_dir = os.path.dirname(os.path.abspath(__file__)) + arity = "unary" if is_unary else "binary" + npu = detect_npu_version() + suffix = "aie2" if npu == "npu1" else "aie2p" + os.environ["AIR_TRANSFORM_TILING_SCRIPT"] = os.path.join( + script_dir, f"transform_{arity}_{suffix}.mlir" + ) + + benchmark.select_npu_backend() + for N in [2**i for i in range(10, 16, 1)]: + bench_op(args.op, N, "test", cfg) diff --git a/examples/elementwise_arith/transform_binary_aie2.mlir b/examples/elementwise_arith/transform_binary_aie2.mlir new file mode 100644 index 0000000..ccec81d --- /dev/null +++ b/examples/elementwise_arith/transform_binary_aie2.mlir @@ -0,0 +1,40 @@ +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT + +//////////////////////////////////////////////////////////////////////////////// +// Transform Script for Binary Elementwise Ops (AIE2): sub, mul, div +// Dtype-generic: uses @DTYPE@ and @VECTOR_SIZE@ placeholders. +// Uses shared library sequences from transform_library.mlir (auto-injected). +//////////////////////////////////////////////////////////////////////////////// + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg1: !transform.any_op {transform.readonly}) { + + transform.include @canonicalize_with_fold_dims failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @fuse_elementwise_and_canonicalize failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @flatten_tile_forall failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @canonicalize_with_cse failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @pad_and_promote_binary_@DTYPE@ failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @canonicalize_with_cse failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @one_shot_bufferize failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @post_bufferize_cleanup failures(propagate) + (%arg1) : (!transform.any_op) -> () + + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) + (%arg1) : (!transform.any_op) -> () + %vh = transform.include @air_herd_mapping_and_vectorize + failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op + transform.include @cast_bf16_only_ops failures(propagate) + (%vh) : (!transform.any_op) -> () + + transform.yield + } +} diff --git a/examples/elementwise_arith/transform_binary_aie2p.mlir b/examples/elementwise_arith/transform_binary_aie2p.mlir new file mode 100644 index 0000000..2d76c54 --- /dev/null +++ b/examples/elementwise_arith/transform_binary_aie2p.mlir @@ -0,0 +1,40 @@ +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT + +//////////////////////////////////////////////////////////////////////////////// +// Transform Script for Binary Elementwise Ops (AIE2P): sub, mul, div +// Dtype-generic: uses @DTYPE@ and @VECTOR_SIZE@ placeholders. +// Uses shared library sequences from transform_library.mlir (auto-injected). +//////////////////////////////////////////////////////////////////////////////// + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg1: !transform.any_op {transform.readonly}) { + + transform.include @canonicalize_with_fold_dims failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @fuse_elementwise_and_canonicalize failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @flatten_tile_forall failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @canonicalize_with_cse failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @pad_and_promote_binary_@DTYPE@ failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @canonicalize_with_cse failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @one_shot_bufferize failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @post_bufferize_cleanup failures(propagate) + (%arg1) : (!transform.any_op) -> () + + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) + (%arg1) : (!transform.any_op) -> () + %vh = transform.include @air_herd_mapping_and_vectorize + failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op + transform.include @cast_bf16_only_ops failures(propagate) + (%vh) : (!transform.any_op) -> () + + transform.yield + } +} diff --git a/examples/elementwise_arith/transform_unary_aie2.mlir b/examples/elementwise_arith/transform_unary_aie2.mlir new file mode 100644 index 0000000..2e09a8b --- /dev/null +++ b/examples/elementwise_arith/transform_unary_aie2.mlir @@ -0,0 +1,40 @@ +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT + +//////////////////////////////////////////////////////////////////////////////// +// Transform Script for Unary Elementwise Ops (AIE2): square +// Dtype-generic: uses @DTYPE@ and @VECTOR_SIZE@ placeholders. +// Uses shared library sequences from transform_library.mlir (auto-injected). +//////////////////////////////////////////////////////////////////////////////// + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg1: !transform.any_op {transform.readonly}) { + + transform.include @canonicalize_with_fold_dims failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @fuse_elementwise_and_canonicalize failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @flatten_tile_forall failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @canonicalize_with_cse failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @pad_and_promote_unary_@DTYPE@ failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @canonicalize_with_cse failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @one_shot_bufferize failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @post_bufferize_cleanup failures(propagate) + (%arg1) : (!transform.any_op) -> () + + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) + (%arg1) : (!transform.any_op) -> () + %vh = transform.include @air_herd_mapping_and_vectorize + failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op + transform.include @cast_bf16_only_ops failures(propagate) + (%vh) : (!transform.any_op) -> () + + transform.yield + } +} diff --git a/examples/elementwise_arith/transform_unary_aie2p.mlir b/examples/elementwise_arith/transform_unary_aie2p.mlir new file mode 100644 index 0000000..14bfd4c --- /dev/null +++ b/examples/elementwise_arith/transform_unary_aie2p.mlir @@ -0,0 +1,40 @@ +// Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. +// SPDX-License-Identifier: MIT + +//////////////////////////////////////////////////////////////////////////////// +// Transform Script for Unary Elementwise Ops (AIE2P): square +// Dtype-generic: uses @DTYPE@ and @VECTOR_SIZE@ placeholders. +// Uses shared library sequences from transform_library.mlir (auto-injected). +//////////////////////////////////////////////////////////////////////////////// + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main( + %arg1: !transform.any_op {transform.readonly}) { + + transform.include @canonicalize_with_fold_dims failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @fuse_elementwise_and_canonicalize failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @flatten_tile_forall failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @canonicalize_with_cse failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @pad_and_promote_unary_@DTYPE@ failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @canonicalize_with_cse failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @one_shot_bufferize failures(propagate) + (%arg1) : (!transform.any_op) -> () + transform.include @post_bufferize_cleanup failures(propagate) + (%arg1) : (!transform.any_op) -> () + + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) + (%arg1) : (!transform.any_op) -> () + %vh = transform.include @air_herd_mapping_and_vectorize + failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op + transform.include @cast_bf16_only_ops failures(propagate) + (%vh) : (!transform.any_op) -> () + + transform.yield + } +} diff --git a/examples/gelu/gelu.py b/examples/gelu/gelu.py index 304afbb..ceadacb 100644 --- a/examples/gelu/gelu.py +++ b/examples/gelu/gelu.py @@ -1,14 +1,34 @@ # Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT +# GELU benchmark: y = x * sigmoid(1.702 * x) +# Supports bf16 (default) and f32 (via bf16-emulation). + +import argparse import torch import triton import triton.language as tl -import sys, os +import sys +import os sys.path.append(os.path.abspath("..")) import benchmark +DTYPE_CONFIG = { + "bf16": { + "torch_dtype": torch.bfloat16, + "atol": 1e-1, + "rtol": 1e-1, + "bf16_emulation": False, + }, + "f32": { + "torch_dtype": torch.float32, + "atol": 2e-1, + "rtol": 1e-1, + "bf16_emulation": True, + }, +} + @triton.jit def gelu_kernel( @@ -30,14 +50,14 @@ def gelu_kernel( tl.store(Y + offsets[:], y) -def bench_gelu(N, provider): +def bench_gelu(N, provider, cfg): device = "cpu" - dtype = torch.bfloat16 - x = torch.randn(N, device=device, dtype=dtype) - y = torch.empty(N, device=device, dtype=dtype) + torch_dtype = cfg["torch_dtype"] + x = torch.randn(N, device=device, dtype=torch_dtype) + y = torch.empty(N, device=device, dtype=torch_dtype) if provider == "torch" or provider == "test": # Reference uses sigmoid approximation: x * sigmoid(1.702 * x) - y_ref = x * torch.sigmoid(1.702 * x.float()).to(dtype) + y_ref = x * torch.sigmoid(1.702 * x.float()).to(torch_dtype) if provider == "triton" or provider == "test": grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE"]),) compiled_kernel = gelu_kernel[grid]( @@ -49,10 +69,35 @@ def bench_gelu(N, provider): with open("tt.shared.mlir", "w") as f: f.write(str(compiled_kernel.asm["ttsharedir"])) if provider == "test": - torch.testing.assert_close(y, y_ref, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(y, y_ref, atol=cfg["atol"], rtol=cfg["rtol"]) if __name__ == "__main__": + parser = argparse.ArgumentParser(description="GELU benchmark for AMD NPU") + parser.add_argument( + "--dtype", + type=str, + choices=list(DTYPE_CONFIG.keys()), + default="bf16", + help="Element data type (default: bf16)", + ) + parser.add_argument( + "--bf16-emulation", + dest="bf16_emulation", + default=False, + action="store_true", + help="Use f32 data type with bf16 emulation on AIE cores", + ) + args = parser.parse_args() + + if args.bf16_emulation: + args.dtype = "f32" + + cfg = DTYPE_CONFIG[args.dtype] + + if cfg["bf16_emulation"]: + os.environ["AMD_TRITON_NPU_BF16_EMULATION"] = "1" + benchmark.select_npu_backend() for N in [2**i for i in range(10, 16, 1)]: - bench_gelu(N, "test") + bench_gelu(N, "test", cfg) diff --git a/examples/gelu/transform_aie2p.mlir b/examples/gelu/transform_aie2p.mlir index 2fa1afa..71de302 100644 --- a/examples/gelu/transform_aie2p.mlir +++ b/examples/gelu/transform_aie2p.mlir @@ -22,7 +22,7 @@ module attributes {transform.with_named_sequence} { (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @pad_and_promote_unary_bf16 failures(propagate) + transform.include @pad_and_promote_unary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () @@ -31,7 +31,7 @@ module attributes {transform.with_named_sequence} { transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @vectorize_generics_at_16 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () %vh = transform.include @air_herd_mapping_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op diff --git a/examples/generate_readme.py b/examples/generate_readme.py index a8a06ae..b64b9f5 100644 --- a/examples/generate_readme.py +++ b/examples/generate_readme.py @@ -60,49 +60,49 @@ "category": "Element-wise", "name": "ReLU", "path": "relu", - "datatypes": "bf16", + "datatypes": "bf16, f32, i8, i16", }, { "category": "Element-wise", "name": "Sigmoid", "path": "sigmoid", - "datatypes": "bf16", + "datatypes": "bf16, f32", }, { "category": "Element-wise", "name": "SiLU", "path": "silu", - "datatypes": "bf16", + "datatypes": "bf16, f32", }, { "category": "Element-wise", "name": "GELU", "path": "gelu", - "datatypes": "bf16", + "datatypes": "bf16, f32", }, { "category": "Element-wise", "name": "Leaky ReLU", "path": "leaky_relu", - "datatypes": "bf16", + "datatypes": "bf16, f32", }, { "category": "Element-wise", "name": "SwiGLU", "path": "swiglu", - "datatypes": "bf16", + "datatypes": "bf16, f32", }, { "category": "Element-wise", "name": "AXPY", "path": "axpy", - "datatypes": "bf16", + "datatypes": "bf16, f32, i8, i16", }, { "category": "Element-wise", "name": "Vector Add", "path": "vec-add", - "datatypes": "bf16", + "datatypes": "bf16, f32, i8, i16", }, { "category": "Normalization", @@ -146,6 +146,12 @@ "path": "multi_drivers", "datatypes": "bf16", }, + { + "category": "Element-wise", + "name": "Elementwise Arith (sub, mul, div, square)", + "path": "elementwise_arith", + "datatypes": "bf16, f32", + }, ] # Directories to ignore when verifying registry completeness @@ -155,10 +161,13 @@ def get_device_support(example_dir): """Check which device targets have transform files. + Checks for both exact names (transform_aie2.mlir) and prefixed + variants (transform_*_aie2.mlir) used by multi-op examples. + Returns (has_aie2, has_aie2p) as booleans. """ - has_aie2 = (example_dir / "transform_aie2.mlir").exists() - has_aie2p = (example_dir / "transform_aie2p.mlir").exists() + has_aie2 = bool(list(example_dir.glob("transform*_aie2.mlir"))) + has_aie2p = bool(list(example_dir.glob("transform*_aie2p.mlir"))) return has_aie2, has_aie2p diff --git a/examples/leaky_relu/leaky_relu.py b/examples/leaky_relu/leaky_relu.py index 088b6b1..5b9927c 100644 --- a/examples/leaky_relu/leaky_relu.py +++ b/examples/leaky_relu/leaky_relu.py @@ -1,16 +1,36 @@ # Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT +# Leaky ReLU benchmark: y = x if x >= 0, else alpha * x +# Supports bf16 (default) and f32 (via bf16-emulation). + +import argparse import torch import triton import triton.language as tl -import sys, os +import sys +import os sys.path.append(os.path.abspath("..")) import benchmark ALPHA = 0.01 # Standard leaky relu negative slope +DTYPE_CONFIG = { + "bf16": { + "torch_dtype": torch.bfloat16, + "atol": 1e-2, + "rtol": 1e-2, + "bf16_emulation": False, + }, + "f32": { + "torch_dtype": torch.float32, + "atol": 1e-1, + "rtol": 5e-2, + "bf16_emulation": True, + }, +} + @triton.jit def leaky_relu_kernel( @@ -31,11 +51,11 @@ def leaky_relu_kernel( tl.store(Y + offsets[:], y) -def bench_leaky_relu(N, provider): +def bench_leaky_relu(N, provider, cfg): device = "cpu" - dtype = torch.bfloat16 - x = torch.randn(N, device=device, dtype=dtype) - y = torch.empty(N, device=device, dtype=dtype) + torch_dtype = cfg["torch_dtype"] + x = torch.randn(N, device=device, dtype=torch_dtype) + y = torch.empty(N, device=device, dtype=torch_dtype) if provider == "torch" or provider == "test": y_ref = torch.nn.functional.leaky_relu(x, negative_slope=ALPHA) if provider == "triton" or provider == "test": @@ -49,10 +69,35 @@ def bench_leaky_relu(N, provider): with open("tt.shared.mlir", "w") as f: f.write(str(compiled_kernel.asm["ttsharedir"])) if provider == "test": - torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(y, y_ref, atol=cfg["atol"], rtol=cfg["rtol"]) if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Leaky ReLU benchmark for AMD NPU") + parser.add_argument( + "--dtype", + type=str, + choices=list(DTYPE_CONFIG.keys()), + default="bf16", + help="Element data type (default: bf16)", + ) + parser.add_argument( + "--bf16-emulation", + dest="bf16_emulation", + default=False, + action="store_true", + help="Use f32 data type with bf16 emulation on AIE cores", + ) + args = parser.parse_args() + + if args.bf16_emulation: + args.dtype = "f32" + + cfg = DTYPE_CONFIG[args.dtype] + + if cfg["bf16_emulation"]: + os.environ["AMD_TRITON_NPU_BF16_EMULATION"] = "1" + benchmark.select_npu_backend() for N in [2**i for i in range(10, 16, 1)]: - bench_leaky_relu(N, "test") + bench_leaky_relu(N, "test", cfg) diff --git a/examples/leaky_relu/transform_aie2.mlir b/examples/leaky_relu/transform_aie2.mlir index e0234a4..f804e9f 100644 --- a/examples/leaky_relu/transform_aie2.mlir +++ b/examples/leaky_relu/transform_aie2.mlir @@ -20,7 +20,7 @@ module attributes {transform.with_named_sequence} { (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @pad_and_promote_unary_bf16 failures(propagate) + transform.include @pad_and_promote_unary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () @@ -29,7 +29,7 @@ module attributes {transform.with_named_sequence} { transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @vectorize_generics_at_16 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () %vh = transform.include @air_herd_mapping_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op diff --git a/examples/leaky_relu/transform_aie2p.mlir b/examples/leaky_relu/transform_aie2p.mlir index 7ed2de4..bc2d3c9 100644 --- a/examples/leaky_relu/transform_aie2p.mlir +++ b/examples/leaky_relu/transform_aie2p.mlir @@ -20,7 +20,7 @@ module attributes {transform.with_named_sequence} { (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @pad_and_promote_unary_bf16 failures(propagate) + transform.include @pad_and_promote_unary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () @@ -29,7 +29,7 @@ module attributes {transform.with_named_sequence} { transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @vectorize_generics_at_16 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () %vh = transform.include @air_herd_mapping_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op diff --git a/examples/relu/relu.py b/examples/relu/relu.py index b873aab..66ab642 100644 --- a/examples/relu/relu.py +++ b/examples/relu/relu.py @@ -1,14 +1,50 @@ # Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT +# ReLU benchmark: y = max(x, 0) +# Supports bf16 (default), f32 (via bf16-emulation), i8, and i16. + +import argparse import torch import triton import triton.language as tl -import sys, os +import sys +import os sys.path.append(os.path.abspath("..")) import benchmark +DTYPE_CONFIG = { + "bf16": { + "torch_dtype": torch.bfloat16, + "is_float": True, + "atol": 1e-2, + "rtol": 1e-2, + "bf16_emulation": False, + }, + "f32": { + "torch_dtype": torch.float32, + "is_float": True, + "atol": 1e-1, + "rtol": 5e-2, + "bf16_emulation": True, + }, + "i8": { + "torch_dtype": torch.int8, + "is_float": False, + "atol": 0, + "rtol": 0, + "bf16_emulation": False, + }, + "i16": { + "torch_dtype": torch.int16, + "is_float": False, + "atol": 0, + "rtol": 0, + "bf16_emulation": False, + }, +} + @triton.jit def relu_kernel( @@ -22,15 +58,23 @@ def relu_kernel( offsets = block_start + tl.arange(0, BLOCK_SIZE) x = tl.load(X + offsets[:]) - y = tl.maximum(x, 0.0) + # x * 0 produces a dtype-compatible zero for both float and int types. + y = tl.maximum(x, x * 0) tl.store(Y + offsets[:], y) -def bench_relu(N, provider): +def bench_relu(N, provider, cfg): device = "cpu" - dtype = torch.bfloat16 - x = torch.randn(N, device=device, dtype=dtype) - y = torch.empty(N, device=device, dtype=dtype) + torch_dtype = cfg["torch_dtype"] + + if cfg["is_float"]: + x = torch.randn(N, device=device, dtype=torch_dtype) + else: + iinfo = torch.iinfo(torch_dtype) + x = torch.randint(iinfo.min, iinfo.max, (N,), device=device, dtype=torch_dtype) + + y = torch.empty(N, device=device, dtype=torch_dtype) + if provider == "torch" or provider == "test": y_ref = torch.relu(x) if provider == "triton" or provider == "test": @@ -44,10 +88,35 @@ def bench_relu(N, provider): with open("tt.shared.mlir", "w") as f: f.write(str(compiled_kernel.asm["ttsharedir"])) if provider == "test": - torch.testing.assert_close(y, y_ref, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(y, y_ref, atol=cfg["atol"], rtol=cfg["rtol"]) if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ReLU benchmark for AMD NPU") + parser.add_argument( + "--dtype", + type=str, + choices=list(DTYPE_CONFIG.keys()), + default="bf16", + help="Element data type (default: bf16)", + ) + parser.add_argument( + "--bf16-emulation", + dest="bf16_emulation", + default=False, + action="store_true", + help="Use f32 data type with bf16 emulation on AIE cores", + ) + args = parser.parse_args() + + if args.bf16_emulation: + args.dtype = "f32" + + cfg = DTYPE_CONFIG[args.dtype] + + if cfg["bf16_emulation"]: + os.environ["AMD_TRITON_NPU_BF16_EMULATION"] = "1" + benchmark.select_npu_backend() for N in [2**i for i in range(10, 16, 1)]: - bench_relu(N, "test") + bench_relu(N, "test", cfg) diff --git a/examples/relu/transform_aie2.mlir b/examples/relu/transform_aie2.mlir index ce5dc86..fbcf1df 100644 --- a/examples/relu/transform_aie2.mlir +++ b/examples/relu/transform_aie2.mlir @@ -4,7 +4,7 @@ //////////////////////////////////////////////////////////////////////////////// // Transform Script for ReLU (AIE2) // relu(x) = max(x, 0) -// No extern_func.o needed (native maxnumf). +// Dtype-generic: uses @DTYPE@ and @VECTOR_SIZE@ placeholders. // Uses shared library sequences from transform_library.mlir (auto-injected). //////////////////////////////////////////////////////////////////////////////// @@ -20,7 +20,7 @@ module attributes {transform.with_named_sequence} { (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @pad_and_promote_unary_bf16 failures(propagate) + transform.include @pad_and_promote_unary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () @@ -29,7 +29,7 @@ module attributes {transform.with_named_sequence} { transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @vectorize_generics_at_16 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () %vh = transform.include @air_herd_mapping_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op diff --git a/examples/relu/transform_aie2p.mlir b/examples/relu/transform_aie2p.mlir index eba1f17..7e4ba8e 100644 --- a/examples/relu/transform_aie2p.mlir +++ b/examples/relu/transform_aie2p.mlir @@ -4,8 +4,7 @@ //////////////////////////////////////////////////////////////////////////////// // Transform Script for ReLU (AIE2P) // relu(x) = max(x, 0) -// Strategy: fuse_elementwise_linalg -> unary pad+promote -> vectorize at 16 -// -> cast maxnumf to bf16. +// Dtype-generic: uses @DTYPE@ and @VECTOR_SIZE@ placeholders. // Uses shared library sequences from transform_library.mlir (auto-injected). //////////////////////////////////////////////////////////////////////////////// @@ -21,7 +20,7 @@ module attributes {transform.with_named_sequence} { (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @pad_and_promote_unary_bf16 failures(propagate) + transform.include @pad_and_promote_unary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () @@ -30,7 +29,7 @@ module attributes {transform.with_named_sequence} { transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @vectorize_generics_at_16 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () %vh = transform.include @air_herd_mapping_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op diff --git a/examples/sigmoid/sigmoid.py b/examples/sigmoid/sigmoid.py index 12b602c..d5922dd 100644 --- a/examples/sigmoid/sigmoid.py +++ b/examples/sigmoid/sigmoid.py @@ -1,14 +1,34 @@ # Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT +# Sigmoid benchmark: y = 1 / (1 + exp(-x)) +# Supports bf16 (default) and f32 (via bf16-emulation). + +import argparse import torch import triton import triton.language as tl -import sys, os +import sys +import os sys.path.append(os.path.abspath("..")) import benchmark +DTYPE_CONFIG = { + "bf16": { + "torch_dtype": torch.bfloat16, + "atol": 1e-1, + "rtol": 1e-1, + "bf16_emulation": False, + }, + "f32": { + "torch_dtype": torch.float32, + "atol": 2e-1, + "rtol": 1e-1, + "bf16_emulation": True, + }, +} + @triton.jit def sigmoid_kernel( @@ -34,11 +54,11 @@ def sigmoid_kernel( tl.store(Y + offsets[:], y) -def bench_sigmoid(N, provider): +def bench_sigmoid(N, provider, cfg): device = "cpu" - dtype = torch.bfloat16 - x = torch.randn(N, device=device, dtype=dtype) - y = torch.empty(N, device=device, dtype=dtype) + torch_dtype = cfg["torch_dtype"] + x = torch.randn(N, device=device, dtype=torch_dtype) + y = torch.empty(N, device=device, dtype=torch_dtype) if provider == "torch" or provider == "test": y_ref = torch.sigmoid(x) if provider == "triton" or provider == "test": @@ -52,10 +72,35 @@ def bench_sigmoid(N, provider): with open("tt.shared.mlir", "w") as f: f.write(str(compiled_kernel.asm["ttsharedir"])) if provider == "test": - torch.testing.assert_close(y, y_ref, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(y, y_ref, atol=cfg["atol"], rtol=cfg["rtol"]) if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Sigmoid benchmark for AMD NPU") + parser.add_argument( + "--dtype", + type=str, + choices=list(DTYPE_CONFIG.keys()), + default="bf16", + help="Element data type (default: bf16)", + ) + parser.add_argument( + "--bf16-emulation", + dest="bf16_emulation", + default=False, + action="store_true", + help="Use f32 data type with bf16 emulation on AIE cores", + ) + args = parser.parse_args() + + if args.bf16_emulation: + args.dtype = "f32" + + cfg = DTYPE_CONFIG[args.dtype] + + if cfg["bf16_emulation"]: + os.environ["AMD_TRITON_NPU_BF16_EMULATION"] = "1" + benchmark.select_npu_backend() for N in [2**i for i in range(10, 16, 1)]: - bench_sigmoid(N, "test") + bench_sigmoid(N, "test", cfg) diff --git a/examples/sigmoid/transform_aie2p.mlir b/examples/sigmoid/transform_aie2p.mlir index 2494c2b..8fe2d8e 100644 --- a/examples/sigmoid/transform_aie2p.mlir +++ b/examples/sigmoid/transform_aie2p.mlir @@ -6,8 +6,9 @@ // // sigmoid(x) = 1 / (1 + exp(-x)) // -// Strategy: fuse_elementwise_linalg -> unary pad+promote -> vectorize at 16 +// Strategy: fuse_elementwise_linalg -> unary pad+promote -> vectorize // -> cast exp, subf, addf, mulf to bf16; divf stays f32. +// Dtype-generic: uses @DTYPE@ and @VECTOR_SIZE@ placeholders. // // Uses shared library sequences from transform_library.mlir (auto-injected). //////////////////////////////////////////////////////////////////////////////// @@ -16,43 +17,25 @@ module attributes {transform.with_named_sequence} { transform.named_sequence @__transform_main( %arg1: !transform.any_op {transform.readonly}) { - // Phase 1: Initial canonicalization transform.include @canonicalize_with_fold_dims failures(propagate) (%arg1) : (!transform.any_op) -> () - - // Phase 2: Fuse elementwise chain (extf + subf + exp + addf + divf + truncf) transform.include @fuse_elementwise_and_canonicalize failures(propagate) (%arg1) : (!transform.any_op) -> () - - // Phase 3: Flatten + tile forall [256] transform.include @flatten_tile_forall failures(propagate) (%arg1) : (!transform.any_op) -> () - - // Phase 4: Canonicalization transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - - // Phase 5: Pad and promote to L1 (unary: 1 input + 1 output) - transform.include @pad_and_promote_unary_bf16 failures(propagate) + transform.include @pad_and_promote_unary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () - - // Phase 6: Canonicalization transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - - // Phase 7: Bufferization transform.include @one_shot_bufferize failures(propagate) (%arg1) : (!transform.any_op) -> () - - // Phase 8: Post-bufferization cleanup transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - // Phase 9: Vectorization tiling (16-lane for bf16) - transform.include @vectorize_generics_at_16 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () - - // Phase 10: AIR herd mapping + vectorization %vh = transform.include @air_herd_mapping_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op transform.include @cast_bf16_only_ops failures(propagate) diff --git a/examples/silu/silu.py b/examples/silu/silu.py index 59b0aa0..05d55df 100644 --- a/examples/silu/silu.py +++ b/examples/silu/silu.py @@ -1,14 +1,34 @@ # Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT +# SiLU benchmark: y = x * sigmoid(x) +# Supports bf16 (default) and f32 (via bf16-emulation). + +import argparse import torch import triton import triton.language as tl -import sys, os +import sys +import os sys.path.append(os.path.abspath("..")) import benchmark +DTYPE_CONFIG = { + "bf16": { + "torch_dtype": torch.bfloat16, + "atol": 1e-1, + "rtol": 1e-1, + "bf16_emulation": False, + }, + "f32": { + "torch_dtype": torch.float32, + "atol": 2e-1, + "rtol": 1e-1, + "bf16_emulation": True, + }, +} + @triton.jit def silu_kernel( @@ -30,11 +50,11 @@ def silu_kernel( tl.store(Y + offsets[:], y) -def bench_silu(N, provider): +def bench_silu(N, provider, cfg): device = "cpu" - dtype = torch.bfloat16 - x = torch.randn(N, device=device, dtype=dtype) - y = torch.empty(N, device=device, dtype=dtype) + torch_dtype = cfg["torch_dtype"] + x = torch.randn(N, device=device, dtype=torch_dtype) + y = torch.empty(N, device=device, dtype=torch_dtype) if provider == "torch" or provider == "test": y_ref = torch.nn.functional.silu(x) if provider == "triton" or provider == "test": @@ -48,10 +68,35 @@ def bench_silu(N, provider): with open("tt.shared.mlir", "w") as f: f.write(str(compiled_kernel.asm["ttsharedir"])) if provider == "test": - torch.testing.assert_close(y, y_ref, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(y, y_ref, atol=cfg["atol"], rtol=cfg["rtol"]) if __name__ == "__main__": + parser = argparse.ArgumentParser(description="SiLU benchmark for AMD NPU") + parser.add_argument( + "--dtype", + type=str, + choices=list(DTYPE_CONFIG.keys()), + default="bf16", + help="Element data type (default: bf16)", + ) + parser.add_argument( + "--bf16-emulation", + dest="bf16_emulation", + default=False, + action="store_true", + help="Use f32 data type with bf16 emulation on AIE cores", + ) + args = parser.parse_args() + + if args.bf16_emulation: + args.dtype = "f32" + + cfg = DTYPE_CONFIG[args.dtype] + + if cfg["bf16_emulation"]: + os.environ["AMD_TRITON_NPU_BF16_EMULATION"] = "1" + benchmark.select_npu_backend() for N in [2**i for i in range(10, 16, 1)]: - bench_silu(N, "test") + bench_silu(N, "test", cfg) diff --git a/examples/silu/transform_aie2.mlir b/examples/silu/transform_aie2.mlir index 3f16514..78784f9 100644 --- a/examples/silu/transform_aie2.mlir +++ b/examples/silu/transform_aie2.mlir @@ -20,7 +20,7 @@ module attributes {transform.with_named_sequence} { (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @pad_and_promote_unary_bf16 failures(propagate) + transform.include @pad_and_promote_unary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () @@ -29,7 +29,7 @@ module attributes {transform.with_named_sequence} { transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @vectorize_generics_at_16 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () %vh = transform.include @air_herd_mapping_with_extern_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op diff --git a/examples/silu/transform_aie2p.mlir b/examples/silu/transform_aie2p.mlir index 53de42f..acc0aea 100644 --- a/examples/silu/transform_aie2p.mlir +++ b/examples/silu/transform_aie2p.mlir @@ -21,7 +21,7 @@ module attributes {transform.with_named_sequence} { (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @pad_and_promote_unary_bf16 failures(propagate) + transform.include @pad_and_promote_unary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () @@ -30,7 +30,7 @@ module attributes {transform.with_named_sequence} { transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @vectorize_generics_at_16 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () %vh = transform.include @air_herd_mapping_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op diff --git a/examples/swiglu/swiglu.py b/examples/swiglu/swiglu.py index 180e856..65157fe 100644 --- a/examples/swiglu/swiglu.py +++ b/examples/swiglu/swiglu.py @@ -1,14 +1,34 @@ # Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT +# SwiGLU benchmark: out = SiLU(gate) * up = gate * sigmoid(gate) * up +# Supports bf16 (default) and f32 (via bf16-emulation). + +import argparse import torch import triton import triton.language as tl -import sys, os +import sys +import os sys.path.append(os.path.abspath("..")) import benchmark +DTYPE_CONFIG = { + "bf16": { + "torch_dtype": torch.bfloat16, + "atol": 1e-1, + "rtol": 1e-1, + "bf16_emulation": False, + }, + "f32": { + "torch_dtype": torch.float32, + "atol": 2e-1, + "rtol": 1e-1, + "bf16_emulation": True, + }, +} + @triton.jit def swiglu_kernel( @@ -33,12 +53,12 @@ def swiglu_kernel( tl.store(OUT + offsets[:], out) -def bench_swiglu(N, provider): +def bench_swiglu(N, provider, cfg): device = "cpu" - dtype = torch.bfloat16 - gate = torch.randn(N, device=device, dtype=dtype) - up = torch.randn(N, device=device, dtype=dtype) - out = torch.empty(N, device=device, dtype=dtype) + torch_dtype = cfg["torch_dtype"] + gate = torch.randn(N, device=device, dtype=torch_dtype) + up = torch.randn(N, device=device, dtype=torch_dtype) + out = torch.empty(N, device=device, dtype=torch_dtype) if provider == "torch" or provider == "test": out_ref = torch.nn.functional.silu(gate) * up if provider == "triton" or provider == "test": @@ -53,10 +73,35 @@ def bench_swiglu(N, provider): with open("tt.shared.mlir", "w") as f: f.write(str(compiled_kernel.asm["ttsharedir"])) if provider == "test": - torch.testing.assert_close(out, out_ref, atol=1e-1, rtol=1e-1) + torch.testing.assert_close(out, out_ref, atol=cfg["atol"], rtol=cfg["rtol"]) if __name__ == "__main__": + parser = argparse.ArgumentParser(description="SwiGLU benchmark for AMD NPU") + parser.add_argument( + "--dtype", + type=str, + choices=list(DTYPE_CONFIG.keys()), + default="bf16", + help="Element data type (default: bf16)", + ) + parser.add_argument( + "--bf16-emulation", + dest="bf16_emulation", + default=False, + action="store_true", + help="Use f32 data type with bf16 emulation on AIE cores", + ) + args = parser.parse_args() + + if args.bf16_emulation: + args.dtype = "f32" + + cfg = DTYPE_CONFIG[args.dtype] + + if cfg["bf16_emulation"]: + os.environ["AMD_TRITON_NPU_BF16_EMULATION"] = "1" + benchmark.select_npu_backend() for N in [2**i for i in range(10, 16, 1)]: - bench_swiglu(N, "test") + bench_swiglu(N, "test", cfg) diff --git a/examples/swiglu/transform_aie2.mlir b/examples/swiglu/transform_aie2.mlir index 0de74b4..94c07ff 100644 --- a/examples/swiglu/transform_aie2.mlir +++ b/examples/swiglu/transform_aie2.mlir @@ -21,7 +21,7 @@ module attributes {transform.with_named_sequence} { (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @pad_and_promote_binary_bf16 failures(propagate) + transform.include @pad_and_promote_binary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () @@ -30,7 +30,7 @@ module attributes {transform.with_named_sequence} { transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @vectorize_generics_at_16 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () %vh = transform.include @air_herd_mapping_with_extern_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op diff --git a/examples/swiglu/transform_aie2p.mlir b/examples/swiglu/transform_aie2p.mlir index ee1c6b2..7d799d3 100644 --- a/examples/swiglu/transform_aie2p.mlir +++ b/examples/swiglu/transform_aie2p.mlir @@ -20,7 +20,7 @@ module attributes {transform.with_named_sequence} { (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @pad_and_promote_binary_bf16 failures(propagate) + transform.include @pad_and_promote_binary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () @@ -29,7 +29,7 @@ module attributes {transform.with_named_sequence} { transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @vectorize_generics_at_16 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () %vh = transform.include @air_herd_mapping_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op diff --git a/examples/vec-add/transform_aie2.mlir b/examples/vec-add/transform_aie2.mlir index b192305..5fdcf4f 100644 --- a/examples/vec-add/transform_aie2.mlir +++ b/examples/vec-add/transform_aie2.mlir @@ -4,8 +4,10 @@ //////////////////////////////////////////////////////////////////////////////// // Transform Script for Vector Addition (AIE2) // Simple elementwise add: out = a + b -// Binary op (2 inputs + 1 output). No fusion needed. Vec tile = 16 (AIE2). -// No type casts needed (bf16 add is native). +// Binary op (2 inputs + 1 output). No fusion needed. +// No type casts needed (bf16/i8/i16 add is native; f32 uses bf16-emulation). +// Dtype-generic: uses @DTYPE@ and @VECTOR_SIZE@ placeholders substituted +// by the driver based on the IR element type and NPU version. // Uses shared library sequences from transform_library.mlir (auto-injected). //////////////////////////////////////////////////////////////////////////////// @@ -18,7 +20,7 @@ module attributes {transform.with_named_sequence} { (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @pad_and_promote_binary_bf16 failures(propagate) + transform.include @pad_and_promote_binary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () @@ -27,7 +29,7 @@ module attributes {transform.with_named_sequence} { transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @vectorize_generics_at_16 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () %vh = transform.include @air_herd_mapping_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op diff --git a/examples/vec-add/transform_aie2p.mlir b/examples/vec-add/transform_aie2p.mlir index c9bae4f..9dad749 100644 --- a/examples/vec-add/transform_aie2p.mlir +++ b/examples/vec-add/transform_aie2p.mlir @@ -4,8 +4,10 @@ //////////////////////////////////////////////////////////////////////////////// // Transform Script for Vector Addition (AIE2P) // Simple elementwise add: out = a + b -// Binary op (2 inputs + 1 output). No fusion needed. Vec tile = 32 (AIE2P). -// No type casts needed (bf16 add is native). +// Binary op (2 inputs + 1 output). No fusion needed. +// No type casts needed (bf16/i8/i16 add is native; f32 uses bf16-emulation). +// Dtype-generic: uses @DTYPE@ and @VECTOR_SIZE@ placeholders substituted +// by the driver based on the IR element type and NPU version. // Uses shared library sequences from transform_library.mlir (auto-injected). //////////////////////////////////////////////////////////////////////////////// @@ -18,7 +20,7 @@ module attributes {transform.with_named_sequence} { (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @pad_and_promote_binary_bf16 failures(propagate) + transform.include @pad_and_promote_binary_@DTYPE@ failures(propagate) (%arg1) : (!transform.any_op) -> () transform.include @canonicalize_with_cse failures(propagate) (%arg1) : (!transform.any_op) -> () @@ -27,7 +29,7 @@ module attributes {transform.with_named_sequence} { transform.include @post_bufferize_cleanup failures(propagate) (%arg1) : (!transform.any_op) -> () - transform.include @vectorize_generics_at_32 failures(propagate) + transform.include @vectorize_generics_at_@VECTOR_SIZE@ failures(propagate) (%arg1) : (!transform.any_op) -> () %vh = transform.include @air_herd_mapping_and_vectorize failures(propagate) (%arg1) : (!transform.any_op) -> !transform.any_op diff --git a/examples/vec-add/vec-add.py b/examples/vec-add/vec-add.py index c5452dd..fafb087 100644 --- a/examples/vec-add/vec-add.py +++ b/examples/vec-add/vec-add.py @@ -1,17 +1,51 @@ # Copyright (C) 2026, Advanced Micro Devices, Inc. All rights reserved. # SPDX-License-Identifier: MIT -# this is a benchmark for adding vectors with maximum block size -# to check the performance of tl.dot operation +# Vector addition benchmark supporting multiple data types. +# Supports bf16 (default), f32 (via bf16-emulation), i8, and i16. +import argparse import torch import triton import triton.language as tl -import sys, os +import sys +import os sys.path.append(os.path.abspath("..")) import benchmark +# Dtype configuration: torch type, whether it's a float, tolerances. +DTYPE_CONFIG = { + "bf16": { + "torch_dtype": torch.bfloat16, + "is_float": True, + "atol": 1e-2, + "rtol": 1e-2, + "bf16_emulation": False, + }, + "f32": { + "torch_dtype": torch.float32, + "is_float": True, + "atol": 1e-1, + "rtol": 5e-2, + "bf16_emulation": True, # f32 addf not native on AIE; requires bf16-emulation + }, + "i8": { + "torch_dtype": torch.int8, + "is_float": False, + "atol": 0, + "rtol": 0, + "bf16_emulation": False, + }, + "i16": { + "torch_dtype": torch.int16, + "is_float": False, + "atol": 0, + "rtol": 0, + "bf16_emulation": False, + }, +} + @triton.jit def vecadd( @@ -25,8 +59,6 @@ def vecadd( block_start = pid * BLOCK_SIZE_N offsets = block_start + tl.arange(0, BLOCK_SIZE_N) - # mask = offsets < n_elements #AMK - in triton example, do we need? - a_block = tl.load(A + offsets[:]) b_block = tl.load(B + offsets[:]) @@ -35,35 +67,69 @@ def vecadd( tl.store(C + offsets[:], c_block) -# @benchmark.measure() -def bench_vecadd(N, provider): +def bench_vecadd(N, provider, cfg): device = "cpu" - dtype_in = torch.bfloat16 - dtype_out = ( - torch.bfloat16 - ) # torch.float32 won't work due to unsupported `%33 = fpext <8 x bfloat> %32 to <8 x float>` - a = torch.randn(N, device=device, dtype=dtype_in) - b = torch.randn(N, device=device, dtype=dtype_in) - c = torch.empty(N, device=device, dtype=dtype_out) + torch_dtype = cfg["torch_dtype"] + + if cfg["is_float"]: + a = torch.randn(N, device=device, dtype=torch_dtype) + b = torch.randn(N, device=device, dtype=torch_dtype) + else: + # Clamp to half-max to avoid overflow on addition + iinfo = torch.iinfo(torch_dtype) + half_max = iinfo.max // 2 + a = torch.randint(0, half_max, (N,), device=device, dtype=torch_dtype) + b = torch.randint(0, half_max, (N,), device=device, dtype=torch_dtype) + + c = torch.empty(N, device=device, dtype=torch_dtype) + if provider == "torch" or provider == "test": c_ref = torch.add(a, b) if provider == "triton" or provider == "test": - # 2D launch kernel where each block gets its own program. grid = lambda META: (triton.cdiv(N, META["BLOCK_SIZE_N"]),) compiled_kernel = vecadd[grid]( a, b, c, N, - BLOCK_SIZE_N=1024, # TODO: small tile sizes currently face errors due to lock race condition at memtiles + BLOCK_SIZE_N=1024, ) with open("tt.shared.mlir", "w") as f: f.write(str(compiled_kernel.asm["ttsharedir"])) if provider == "test": - torch.testing.assert_close(c, c_ref, atol=1e-2, rtol=1e-2) + torch.testing.assert_close(c, c_ref, atol=cfg["atol"], rtol=cfg["rtol"]) if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Vector addition benchmark for AMD NPU" + ) + parser.add_argument( + "--dtype", + type=str, + choices=list(DTYPE_CONFIG.keys()), + default="bf16", + help="Element data type (default: bf16)", + ) + parser.add_argument( + "--bf16-emulation", + dest="bf16_emulation", + default=False, + action="store_true", + help="Use f32 data type with bf16 emulation on AIE cores", + ) + args = parser.parse_args() + + # --bf16-emulation is shorthand for --dtype f32 + if args.bf16_emulation: + args.dtype = "f32" + + cfg = DTYPE_CONFIG[args.dtype] + + # Enable bf16 emulation env var when needed + if cfg["bf16_emulation"]: + os.environ["AMD_TRITON_NPU_BF16_EMULATION"] = "1" + benchmark.select_npu_backend() for N in [2**i for i in range(10, 16, 1)]: - bench_vecadd(N, "test") + bench_vecadd(N, "test", cfg)