Skip to content
Open
77 changes: 75 additions & 2 deletions amd_triton_npu/backend/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,63 @@ def _replace_include(m):
return result


def _get_transform_ir_string():
def _detect_element_type(ir_str):
"""Detect the primary element type from the provided Linalg IR string.

Searches the IR text for the first ``memref<...xTYPE>`` occurrence and
returns the captured MLIR element type string (for example, ``"bf16"``,
``"f32"``, ``"i8"``, or ``"i16"``).
Falls back to "bf16" if detection fails.
"""
import re

# Match the first memref<...xTYPE> occurrence in the provided IR text.
match = re.search(r"memref<[^>]*x(\w+)>", ir_str)
if match:
return match.group(1)
return "bf16"


# Dtype-aware placeholder info: padding value and default vector size per NPU.
_DTYPE_PLACEHOLDER_INFO = {
"bf16": {"pad_val": "0.0 : bf16", "vector_size": {"npu1": 16, "npu2": 32}},
"f32": {"pad_val": "0.0 : f32", "vector_size": {"npu1": 16, "npu2": 16}},
"i8": {"pad_val": "0 : i8", "vector_size": {"npu1": 32, "npu2": 32}},
"i16": {"pad_val": "0 : i16", "vector_size": {"npu1": 32, "npu2": 32}},
"i32": {"pad_val": "0 : i32", "vector_size": {"npu1": 16, "npu2": 16}},
}


def _substitute_dtype_placeholders(script, dtype, npu_version):
"""Substitute dtype-aware placeholders in a transform script.

Replaces @DTYPE@, @PAD_VAL@, and @VECTOR_SIZE@ with values derived
from the detected element type and target NPU version.
No-op if the script contains no placeholders (backward compatible).
"""
if (
"@DTYPE@" not in script
and "@PAD_VAL@" not in script
and "@VECTOR_SIZE@" not in script
):
return script
info = _DTYPE_PLACEHOLDER_INFO.get(dtype)
if info is None:
raise ValueError(
f"Unsupported element type '{dtype}' for transform script placeholder "
f"substitution. Supported types: {list(_DTYPE_PLACEHOLDER_INFO.keys())}. "
f"The script contains @DTYPE@/@PAD_VAL@/@VECTOR_SIZE@ placeholders that "
f"require a supported element type."
)
script = script.replace("@DTYPE@", dtype)
script = script.replace("@PAD_VAL@", info["pad_val"])
script = script.replace(
"@VECTOR_SIZE@", str(info["vector_size"].get(npu_version, 16))
)
return script


def _get_transform_ir_string(ir_str=None):
"""
Get the transform IR string for tiling operations.

Expand All @@ -421,6 +477,12 @@ def _get_transform_ir_string():
If the script uses `transform.include`, the shared transform library
(transform_library.mlir) is automatically injected.

If ir_str is provided, dtype-aware placeholders (@DTYPE@, @PAD_VAL@,
@VECTOR_SIZE@) are substituted before library injection.

Args:
ir_str: Optional Linalg IR string for dtype detection.

Returns:
str: The transform IR string to use for tiling
"""
Expand All @@ -436,6 +498,17 @@ def _get_transform_ir_string():
with open(custom_script_path, "r") as f:
print(f"Using custom tiling script from: {custom_script_path}")
user_script = f.read()
_PLACEHOLDERS = ("@DTYPE@", "@PAD_VAL@", "@VECTOR_SIZE@")
if ir_str is not None and any(p in user_script for p in _PLACEHOLDERS):
dtype = _detect_element_type(
ir_str if isinstance(ir_str, str) else str(ir_str)
)
npu_version = (
detect_npu_version() if "@VECTOR_SIZE@" in user_script else None
)
user_script = _substitute_dtype_placeholders(
user_script, dtype, npu_version
)
return _inject_transform_library(user_script)

# Default hardcoded transform IR string
Expand Down Expand Up @@ -493,7 +566,7 @@ def _ttshared_to_air(mod, gridX, gridY, gridZ, actual_sizes=None):
pm = air.passmanager.PassManager.parse(pipeline, context=air_context)
pm.run(air_module.operation)
# MLIR-AIR compilation step 2: tiling the launch body
transform_ir_string = _get_transform_ir_string()
transform_ir_string = _get_transform_ir_string(ir_str=mod)
transform_ir = Module.parse(transform_ir_string, context=air_context)
run_transform(transform_ir, air_module)
# MLIR-AIR compilation step 3: converting to AIR
Expand Down
173 changes: 173 additions & 0 deletions amd_triton_npu/backend/transform_library/elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,85 @@ transform.named_sequence @pad_and_promote_unary_bf16(
transform.yield
}

// Unary variant for f32: 1 input + 1 output = 2 operands.
// Used with bf16-emulation (f32 data, bf16 compute on AIE cores).
transform.named_sequence @pad_and_promote_unary_f32(
%module: !transform.any_op {transform.readonly}) {
// Locate the linalg.generic op to pad and promote.
%op = transform.structured.match ops{["linalg.generic"]} in %module
: (!transform.any_op) -> !transform.any_op
// Pad both operands (input, output) along dims 0 and 1 with f32 zeros;
// nofold keeps the pads materialized, and results are copied back.
%padded_op, %pad_op, %__ = transform.structured.pad %op {
padding_values=[0.0 : f32, 0.0 : f32],
padding_dimensions=[0, 1],
nofold_flags=[1, 1],
copy_back_op="linalg.copy"
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
// Turn the produced pad into destination-passing style for bufferization.
%pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op
: (!transform.any_op) -> !transform.any_op
// Promote the padded input (operand 0) into an allocation in memory
// space 2 (presumably AIE-local memory — confirm against the backend).
%padded_input = transform.get_producer_of_operand %padded_op[0]
: (!transform.any_op) -> (!transform.any_op)
%padded_input_buffer, %padded_input_new =
transform.structured.bufferize_to_allocation %padded_input
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
// Likewise promote the padded result (operand 1).
%padded_result = transform.get_producer_of_operand %padded_op[1]
: (!transform.any_op) -> (!transform.any_op)
%padded_result_buffer, %padded_result_new =
transform.structured.bufferize_to_allocation %padded_result
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
transform.yield
}

// Unary variant for i8: 1 input + 1 output = 2 operands.
transform.named_sequence @pad_and_promote_unary_i8(
%module: !transform.any_op {transform.readonly}) {
// Locate the linalg.generic op to pad and promote.
%op = transform.structured.match ops{["linalg.generic"]} in %module
: (!transform.any_op) -> !transform.any_op
// Pad both operands (input, output) along dims 0 and 1 with i8 zeros;
// nofold keeps the pads materialized, and results are copied back.
%padded_op, %pad_op, %__ = transform.structured.pad %op {
padding_values=[0 : i8, 0 : i8],
padding_dimensions=[0, 1],
nofold_flags=[1, 1],
copy_back_op="linalg.copy"
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
// Turn the produced pad into destination-passing style for bufferization.
%pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op
: (!transform.any_op) -> !transform.any_op
// Promote the padded input (operand 0) into an allocation in memory
// space 2 (presumably AIE-local memory — confirm against the backend).
%padded_input = transform.get_producer_of_operand %padded_op[0]
: (!transform.any_op) -> (!transform.any_op)
%padded_input_buffer, %padded_input_new =
transform.structured.bufferize_to_allocation %padded_input
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
// Likewise promote the padded result (operand 1).
%padded_result = transform.get_producer_of_operand %padded_op[1]
: (!transform.any_op) -> (!transform.any_op)
%padded_result_buffer, %padded_result_new =
transform.structured.bufferize_to_allocation %padded_result
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
transform.yield
}

// Unary variant for i16: 1 input + 1 output = 2 operands.
transform.named_sequence @pad_and_promote_unary_i16(
%module: !transform.any_op {transform.readonly}) {
// Locate the linalg.generic op to pad and promote.
%op = transform.structured.match ops{["linalg.generic"]} in %module
: (!transform.any_op) -> !transform.any_op
// Pad both operands (input, output) along dims 0 and 1 with i16 zeros;
// nofold keeps the pads materialized, and results are copied back.
%padded_op, %pad_op, %__ = transform.structured.pad %op {
padding_values=[0 : i16, 0 : i16],
padding_dimensions=[0, 1],
nofold_flags=[1, 1],
copy_back_op="linalg.copy"
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
// Turn the produced pad into destination-passing style for bufferization.
%pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op
: (!transform.any_op) -> !transform.any_op
// Promote the padded input (operand 0) into an allocation in memory
// space 2 (presumably AIE-local memory — confirm against the backend).
%padded_input = transform.get_producer_of_operand %padded_op[0]
: (!transform.any_op) -> (!transform.any_op)
%padded_input_buffer, %padded_input_new =
transform.structured.bufferize_to_allocation %padded_input
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
// Likewise promote the padded result (operand 1).
%padded_result = transform.get_producer_of_operand %padded_op[1]
: (!transform.any_op) -> (!transform.any_op)
%padded_result_buffer, %padded_result_new =
transform.structured.bufferize_to_allocation %padded_result
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
transform.yield
}

// Binary variant: 2 inputs + 1 output = 3 operands (vec-add, axpy, swiglu).
transform.named_sequence @pad_and_promote_binary_bf16(
%module: !transform.any_op {transform.readonly}) {
Expand Down Expand Up @@ -96,3 +175,97 @@ transform.named_sequence @pad_and_promote_binary_bf16(
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
transform.yield
}

// Binary variant for f32: 2 inputs + 1 output = 3 operands.
// Used with bf16-emulation (f32 data, bf16 compute on AIE cores).
transform.named_sequence @pad_and_promote_binary_f32(
%module: !transform.any_op {transform.readonly}) {
// Locate the linalg.generic op to pad and promote.
%op = transform.structured.match ops{["linalg.generic"]} in %module
: (!transform.any_op) -> !transform.any_op
// Pad all three operands (lhs, rhs, output) with f32 zeros; nofold keeps
// the pads materialized, and results are copied back via linalg.copy.
%padded_op, %pad_op, %__ = transform.structured.pad %op {
padding_values=[0.0 : f32, 0.0 : f32, 0.0 : f32],
padding_dimensions=[0, 1, 2],
nofold_flags=[1, 1, 1],
copy_back_op="linalg.copy"
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
// Turn the produced pad into destination-passing style for bufferization.
%pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op
: (!transform.any_op) -> !transform.any_op
// Promote padded lhs (operand 0) into an allocation in memory space 2
// (presumably AIE-local memory — confirm against the backend).
%padded_lhs = transform.get_producer_of_operand %padded_op[0]
: (!transform.any_op) -> (!transform.any_op)
%padded_lhs_buffer, %padded_lhs_new =
transform.structured.bufferize_to_allocation %padded_lhs
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
// Promote padded rhs (operand 1).
%padded_rhs = transform.get_producer_of_operand %padded_op[1]
: (!transform.any_op) -> (!transform.any_op)
%padded_rhs_buffer, %padded_rhs_new =
transform.structured.bufferize_to_allocation %padded_rhs
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
// Promote padded result (operand 2).
%padded_result = transform.get_producer_of_operand %padded_op[2]
: (!transform.any_op) -> (!transform.any_op)
%padded_result_buffer, %padded_result_new =
transform.structured.bufferize_to_allocation %padded_result
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
transform.yield
}

// Binary variant for i8: 2 inputs + 1 output = 3 operands.
transform.named_sequence @pad_and_promote_binary_i8(
%module: !transform.any_op {transform.readonly}) {
// Locate the linalg.generic op to pad and promote.
%op = transform.structured.match ops{["linalg.generic"]} in %module
: (!transform.any_op) -> !transform.any_op
// Pad all three operands (lhs, rhs, output) with i8 zeros; nofold keeps
// the pads materialized, and results are copied back via linalg.copy.
%padded_op, %pad_op, %__ = transform.structured.pad %op {
padding_values=[0 : i8, 0 : i8, 0 : i8],
padding_dimensions=[0, 1, 2],
nofold_flags=[1, 1, 1],
copy_back_op="linalg.copy"
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
// Turn the produced pad into destination-passing style for bufferization.
%pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op
: (!transform.any_op) -> !transform.any_op
// Promote padded lhs (operand 0) into an allocation in memory space 2
// (presumably AIE-local memory — confirm against the backend).
%padded_lhs = transform.get_producer_of_operand %padded_op[0]
: (!transform.any_op) -> (!transform.any_op)
%padded_lhs_buffer, %padded_lhs_new =
transform.structured.bufferize_to_allocation %padded_lhs
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
// Promote padded rhs (operand 1).
%padded_rhs = transform.get_producer_of_operand %padded_op[1]
: (!transform.any_op) -> (!transform.any_op)
%padded_rhs_buffer, %padded_rhs_new =
transform.structured.bufferize_to_allocation %padded_rhs
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
// Promote padded result (operand 2).
%padded_result = transform.get_producer_of_operand %padded_op[2]
: (!transform.any_op) -> (!transform.any_op)
%padded_result_buffer, %padded_result_new =
transform.structured.bufferize_to_allocation %padded_result
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
transform.yield
}

// Binary variant for i16: 2 inputs + 1 output = 3 operands.
transform.named_sequence @pad_and_promote_binary_i16(
%module: !transform.any_op {transform.readonly}) {
// Locate the linalg.generic op to pad and promote.
%op = transform.structured.match ops{["linalg.generic"]} in %module
: (!transform.any_op) -> !transform.any_op
// Pad all three operands (lhs, rhs, output) with i16 zeros; nofold keeps
// the pads materialized, and results are copied back via linalg.copy.
%padded_op, %pad_op, %__ = transform.structured.pad %op {
padding_values=[0 : i16, 0 : i16, 0 : i16],
padding_dimensions=[0, 1, 2],
nofold_flags=[1, 1, 1],
copy_back_op="linalg.copy"
} : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
// Turn the produced pad into destination-passing style for bufferization.
%pad_dps = transform.structured.rewrite_in_destination_passing_style %pad_op
: (!transform.any_op) -> !transform.any_op
// Promote padded lhs (operand 0) into an allocation in memory space 2
// (presumably AIE-local memory — confirm against the backend).
%padded_lhs = transform.get_producer_of_operand %padded_op[0]
: (!transform.any_op) -> (!transform.any_op)
%padded_lhs_buffer, %padded_lhs_new =
transform.structured.bufferize_to_allocation %padded_lhs
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
// Promote padded rhs (operand 1).
%padded_rhs = transform.get_producer_of_operand %padded_op[1]
: (!transform.any_op) -> (!transform.any_op)
%padded_rhs_buffer, %padded_rhs_new =
transform.structured.bufferize_to_allocation %padded_rhs
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
// Promote padded result (operand 2).
%padded_result = transform.get_producer_of_operand %padded_op[2]
: (!transform.any_op) -> (!transform.any_op)
%padded_result_buffer, %padded_result_new =
transform.structured.bufferize_to_allocation %padded_result
{memory_space = 2, bufferize_destination_only, emit_dealloc} : !transform.any_op
transform.yield
}
Loading
Loading