From b820a58053ed900d700196e82bb2b235f147d640 Mon Sep 17 00:00:00 2001 From: "xijun.gong" Date: Tue, 28 Apr 2026 12:36:10 +0000 Subject: [PATCH] Add the new vendor backend ENFLAME MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Description Add the new vendor backend ENFLAME ## Type of change - [x] New feature (non-breaking change which adds functionality) ## Changes Please list the changes introduced in this PR: - Add enflame ops register - Add enflame backend implementation - Register enflame ops in builtin_ops.py ## Requirements - The `migration` module is required; to use this backend, install the `migration` wheel package # Checklist: - [x] I have read and followed the [contributing guidelines](https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst) - [x] The functionality is complete - [x] I have commented my code, particularly in hard-to-understand areas - [x] I have made corresponding changes to the documentation - [x] My changes generate no new warnings - [x] I have added tests that prove my fix is effective or that my feature works - [x] New and existing unit tests pass locally with my changes --- .../core/backends/vendor/enflame/__init__.py | 7 + .../core/backends/vendor/enflame/enflame.py | 1582 +++++++++++++++++ .../vendor/enflame/flash_attention.py | 127 ++ .../backends/vendor/enflame/register_ops.py | 956 ++++++++++ transformer_engine/plugin/core/builtin_ops.py | 9 + 5 files changed, 2681 insertions(+) create mode 100755 transformer_engine/plugin/core/backends/vendor/enflame/__init__.py create mode 100755 transformer_engine/plugin/core/backends/vendor/enflame/enflame.py create mode 100755 transformer_engine/plugin/core/backends/vendor/enflame/flash_attention.py create mode 100755 transformer_engine/plugin/core/backends/vendor/enflame/register_ops.py diff --git a/transformer_engine/plugin/core/backends/vendor/enflame/__init__.py b/transformer_engine/plugin/core/backends/vendor/enflame/__init__.py 
new file mode 100755 index 0000000000..34f7ccaf26 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/enflame/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from .enflame import EnflameBackend + +__all__ = ["EnflameBackend"] \ No newline at end of file diff --git a/transformer_engine/plugin/core/backends/vendor/enflame/enflame.py b/transformer_engine/plugin/core/backends/vendor/enflame/enflame.py new file mode 100755 index 0000000000..73896bde93 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/enflame/enflame.py @@ -0,0 +1,1582 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from typing import Any, Dict, List, Optional, Tuple, Union + +import ctypes +from pathlib import Path +import importlib.util +import platform +import os, sys +import functools +import inspect + +import torch + + +from ....ops import * + +def _get_tex(): + from migration.patches.transformer_engine import v2_9_0 + return v2_9_0 + +def _check_enflame_available() -> bool: + from torch_gcu import transfer_to_gcu + if not torch.cuda.is_available(): + return False + return True + +class EnflameBackend(TEFLBackendBase): + @staticmethod + def check_available() -> bool: + return _check_enflame_available() + + def __init__(self): + self._tex = None + + def _get_tex(self): + if self._tex is None: + self._tex = _get_tex() + return self._tex + + def is_available(self) -> bool: + return _check_enflame_available() + + def get_attention_backend(self, attention_params=None): + # Import the enflame get_attention_backend function + try: + from transformer_engine_torch.pytorch.attention.dot_product_attention import utils + + return utils.get_attention_backend(attention_params) + + except ImportError as e: + raise RuntimeError( + f"Failed to import enflame FlashAttention: {e}. 
" + "Please ensure flash-attn is installed and transformer_engine_torch is available." + ) + except Exception as e: + raise RuntimeError( + f"Failed to get_attention_backend: {e}. Attention_params: {self.attention_params}" + ) + + def quantize( + self, + tensor: torch.Tensor, + quantizer: Any, + output: Optional[torch.Tensor] = None, + noop: Optional[torch.Tensor] = None, + ) -> Any: + tex = self._get_tex() + try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + return tex.quantize(tensor, quantizer, output, noop) + + def dequantize( + self, + input: Any, + otype: DType, + ) -> Any: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.dequantize(input, otype) + + def bgrad_quantize( + self, + input: torch.Tensor, + quantizer: Any, + ) -> List[Any]: + tex = self._get_tex() + + # Normalize quantizer.dtype to this backend's `tex.DType`. 
+ try: + if quantizer is not None and hasattr(quantizer, "dtype") and hasattr(tex, "DType"): + qdtype = quantizer.dtype + if qdtype is not None: + quantizer.dtype = tex.DType(int(qdtype)) + except Exception: + pass + + return tex.bgrad_quantize(input, quantizer) + + def generic_gemm( + self, + A: Any, + transA: bool, + B: Any, + transB: bool, + D: Any, + quantizer: Any, + output_dtype: Optional[DType], + bias: Optional[torch.Tensor], + bias_type: DType, + gelu: bool, + gelu_in: Optional[torch.Tensor], + grad: bool, + workspace: torch.Tensor, + workspace_size: int, + accumulate: bool, + use_split_accumulator: bool, + comm_overlap: Optional[Any] = None, + comm_type: Optional[CommOverlapType] = None, + extra_output: Optional[torch.Tensor] = None, + bulk_overlap: bool = False, + alpha: float = 1.0, + beta: Optional[float] = None, + ) -> List[Any]: + tex = self._get_tex() + + bias_type = tex.DType(int(bias_type)) if bias_type is not None else None + comm_type = tex.CommOverlapType(int(comm_type)) if comm_type is not None else None + output_dtype = tex.DType(int(output_dtype)) if output_dtype is not None else None + return tex.generic_gemm( + A, + transA, + B, + transB, + D, + quantizer, + output_dtype, + bias, + bias_type, + gelu, + gelu_in, + grad, + workspace, + workspace_size, + accumulate, + use_split_accumulator, + comm_overlap, + comm_type, + extra_output, + bulk_overlap, + alpha, + beta, + ) + + # GELU and variants # + def gelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.gelu(input, quantizer) + + def geglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.geglu(input, quantizer) + + def qgelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.qgelu(input, quantizer) + + def qgeglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.qgeglu(input, quantizer) + + # ReLU and variants # + def relu(self, input: 
torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.relu(input, quantizer) + + def reglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.reglu(input, quantizer) + + def srelu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.srelu(input, quantizer) + + def sreglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.sreglu(input, quantizer) + + # SwiGLU and variants # + def silu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.silu(input, quantizer) + + def swiglu(self, input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.swiglu(input, quantizer) + + def clamped_swiglu( + self, + input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + tex = self._get_tex() + return tex.clamped_swiglu(input, quantizer, limit, alpha) + + # Backward of GELU and variants # + def dgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dgelu(grad, fwd_input, quantizer) + + def dgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dgeglu(grad, fwd_input, quantizer) + + def dqgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dqgelu(grad, fwd_input, quantizer) + + def dqgeglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dqgeglu(grad, fwd_input, quantizer) + + # Backward of ReLU and variants # + def drelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.drelu(grad, fwd_input, quantizer) + + def dreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dreglu(grad, 
fwd_input, quantizer) + + def dsrelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsrelu(grad, fwd_input, quantizer) + + def dsreglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsreglu(grad, fwd_input, quantizer) + + # Backward of SiLU and variants # + def dsilu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dsilu(grad, fwd_input, quantizer) + + def dswiglu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> Any: + tex = self._get_tex() + return tex.dswiglu(grad, fwd_input, quantizer) + + def clamped_dswiglu( + self, + grad: torch.Tensor, + fwd_input: torch.Tensor, + quantizer: Any, + limit: float = 7.0, + alpha: float = 1.702, + ) -> Any: + tex = self._get_tex() + return tex.clamped_dswiglu(grad, fwd_input, quantizer, limit, alpha) + + # DBias + DAct fusions # + def dbias_dgelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dgelu(grad, fwd_input, quantizer) + + def dbias_dsilu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dsilu(grad, fwd_input, quantizer) + + def dbias_drelu(self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any) -> List[Any]: + tex = self._get_tex() + return tex.dbias_drelu(grad, fwd_input, quantizer) + + def dbias_dqgelu( + self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any + ) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dqgelu(grad, fwd_input, quantizer) + + def dbias_dsrelu( + self, grad: torch.Tensor, fwd_input: torch.Tensor, quantizer: Any + ) -> List[Any]: + tex = self._get_tex() + return tex.dbias_dsrelu(grad, fwd_input, quantizer) + + # Permutation functions + def moe_permute_fwd( + self, + input: torch.Tensor, + dtype: DType, + 
indices: torch.Tensor, + num_out_tokens: int, + workspace: List[torch.Tensor], + max_expanded_token_num: int, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_permute_fwd( + input, dtype, indices, num_out_tokens, workspace, max_expanded_token_num + ) + + def moe_permute_bwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_permute_bwd(input, dtype, row_id_map, prob, num_tokens, topK) + + def moe_unpermute_fwd( + self, + input: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + num_tokens: int, + topK: int, + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_unpermute_fwd(input, dtype, row_id_map, prob, num_tokens, topK) + + def moe_unpermute_bwd( + self, + input_bwd: torch.Tensor, + input_fwd: torch.Tensor, + dtype: DType, + row_id_map: torch.Tensor, + prob: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.moe_unpermute_bwd(input_bwd, input_fwd, dtype, row_id_map, prob) + + # Softmax functions + def scaled_softmax_forward( + self, + input: torch.Tensor, + scale: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_softmax_forward(input, scale) + + def scaled_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_softmax_backward(output_grad_, softmax_results_, scale_factor) + + def scaled_masked_softmax_forward( + self, + input: torch.Tensor, + mask: torch.Tensor, + scale_factor: float, + ) -> 
torch.Tensor: + tex = self._get_tex() + return tex.scaled_masked_softmax_forward(input, mask, scale_factor) + + def scaled_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_masked_softmax_backward(output_grad_, softmax_results_, scale_factor) + + def scaled_upper_triang_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_upper_triang_masked_softmax_forward(input, scale_factor) + + def scaled_upper_triang_masked_softmax_backward( + self, + output_grads_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_upper_triang_masked_softmax_backward( + output_grads_, softmax_results_, scale_factor + ) + + def scaled_aligned_causal_masked_softmax_forward( + self, + input: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_aligned_causal_masked_softmax_forward(input, scale_factor) + + def scaled_aligned_causal_masked_softmax_backward( + self, + output_grad_: torch.Tensor, + softmax_results_: torch.Tensor, + scale_factor: float, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.scaled_aligned_causal_masked_softmax_backward( + output_grad_, softmax_results_, scale_factor + ) + + # Other granular functions + def layernorm_fwd( + self, + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.layernorm_fwd( + input, weight, bias, eps, ln_out, quantizer, otype, sm_margin, zero_centered_gamma + ) + + def layernorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + mu: 
torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.layernorm_bwd(dz, x, mu, rsigma, gamma, sm_margin, zero_centered_gamma) + + def rmsnorm_fwd( + self, + input: Any, + weight: Any, + eps: float, + ln_out: Any, + quantizer: Any, + otype: DType, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + otype = tex.DType(int(otype)) if otype is not None else None + return tex.rmsnorm_fwd( + input, weight, eps, ln_out, quantizer, otype, sm_margin, zero_centered_gamma + ) + + def rmsnorm_bwd( + self, + dz: torch.Tensor, + x: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma) + + def rmsnorm_bwd_add( + self, + dz: torch.Tensor, + x: torch.Tensor, + add: torch.Tensor, + rsigma: torch.Tensor, + gamma: torch.Tensor, + sm_margin: int, + zero_centered_gamma: bool, + ) -> List[Any]: + tex = self._get_tex() + return tex.rmsnorm_bwd_add(dz, x, add, rsigma, gamma, sm_margin, zero_centered_gamma) + + def multi_tensor_quantize( + self, + tensor_list: List[torch.Tensor], + quantizer_list: List[Any], + ) -> List[Any]: + tex = self._get_tex() + return tex.multi_tensor_quantize(tensor_list, quantizer_list) + + def split_quantize( + self, + tensor: torch.Tensor, + split_sections: List[int], + quantizer_list: List[Any], + ) -> List[Any]: + tex = self._get_tex() + return tex.split_quantize(tensor, split_sections, quantizer_list) + + def te_general_grouped_gemm( + self, + A: List[Any], + transa: bool, + B: List[Any], + transb: bool, + D: Optional[List[torch.Tensor]], + D_type: DType, + m_splits: List[int], + bias: List[torch.Tensor], + bias_type: DType, + single_output: bool, + pre_gelu_out: List[torch.Tensor], + grad: bool, + workspace: List[torch.Tensor], + 
workspaceSizes: int, + accumulate: bool, + use_split_accumulator: bool, + math_sm_count: int, + ) -> Optional[List[torch.Tensor]]: + tex = self._get_tex() + D_type = tex.DType(int(D_type)) if D_type is not None else None + bias_type = tex.DType(int(bias_type)) if bias_type is not None else None + return tex.te_general_grouped_gemm( + A, + transa, + B, + transb, + D, + D_type, + m_splits, + bias, + bias_type, + single_output, + pre_gelu_out, + grad, + workspace, + workspaceSizes, + accumulate, + use_split_accumulator, + math_sm_count, + ) + + def fp8_transpose( + self, + input: torch.Tensor, + dtype: DType, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + dtype = tex.DType(int(dtype)) if dtype is not None else None + return tex.fp8_transpose(input, dtype, out) + + def swap_first_dims( + self, + tensor: torch.Tensor, + out: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + return tex.swap_first_dims(tensor, out) + + def get_fused_attn_backend( + self, + is_training: bool, + q_dtype: DType, + kv_dtype: DType, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + p_dropout: float, + num_attn_heads: int, + num_gqa_groups: int, + max_seqlen_q: int, + max_seqlen_kv: int, + head_dim_qk: int, + head_dim_v: int, + window_size_left: int, + window_size_right: int, + return_max_logit: bool, + ) -> NVTE_Fused_Attn_Backend: + tex = self._get_tex() + + q_dtype = tex.DType(int(q_dtype)) if q_dtype is not None else None + kv_dtype = tex.DType(int(kv_dtype)) if kv_dtype is not None else None + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else 
None + ) + + result = tex.get_fused_attn_backend( + is_training, + q_dtype, + kv_dtype, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + p_dropout, + num_attn_heads, + num_gqa_groups, + max_seqlen_q, + max_seqlen_kv, + head_dim_qk, + head_dim_v, + window_size_left, + window_size_right, + return_max_logit, + ) + return NVTE_Fused_Attn_Backend(result) + + def compute_amax( + self, + input: torch.Tensor, + amax: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.compute_amax(input, amax) + + def fused_amax_and_scale_update_after_reduction( + self, + amax_reduction_buffer: torch.Tensor, + amax_histories: List[torch.Tensor], + scales: List[torch.Tensor], + amax_compute_algo: str, + fp8_dtype: DType, + margin: float, + ) -> None: + tex = self._get_tex() + fp8_dtype = tex.DType(int(fp8_dtype)) if fp8_dtype is not None else None + return tex.fused_amax_and_scale_update_after_reduction( + amax_reduction_buffer, amax_histories, scales, amax_compute_algo, fp8_dtype, margin + ) + + def fp8_block_scaling_compute_partial_amax( + self, + tensor: torch.Tensor, + amax: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + ) -> None: + tex = self._get_tex() + return tex.fp8_block_scaling_compute_partial_amax( + tensor, amax, h, w, start_offset, block_len + ) + + def fp8_block_scaling_partial_cast( + self, + inp: torch.Tensor, + out: torch.Tensor, + scale: torch.Tensor, + h: int, + w: int, + start_offset: int, + block_len: int, + out_dtype: DType, + ) -> None: + tex = self._get_tex() + out_dtype = tex.DType(int(out_dtype)) if out_dtype is not None else None + return tex.fp8_block_scaling_partial_cast( + inp, out, scale, h, w, start_offset, block_len, out_dtype + ) + + def fused_multi_row_padding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + padded_input_row_list: List[int], + ) -> None: + tex = self._get_tex() + return tex.fused_multi_row_padding(input, output, input_row_list, 
padded_input_row_list) + + def fused_multi_row_unpadding( + self, + input: torch.Tensor, + output: torch.Tensor, + input_row_list: List[int], + unpadded_input_row_list: List[int], + ) -> None: + tex = self._get_tex() + return tex.fused_multi_row_unpadding(input, output, input_row_list, unpadded_input_row_list) + + # attention kernels + def fa_prepare_fwd( + self, + qkvi: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fa_prepare_fwd(qkvi) + + def fa_prepare_bwd( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fa_prepare_bwd(q, k, v) + + def fused_attn_fwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + is_training: bool, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + fake_dtype: torch.dtype, + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + page_table_k: Optional[torch.Tensor], + page_table_v: Optional[torch.Tensor], + s_quantizer: Any, + o_quantizer: Any, + Bias: Optional[torch.Tensor], + SoftmaxOffset: Optional[torch.Tensor], + rng_gen: Optional[torch.Generator], + rng_elts_per_thread: int, + return_max_logit: bool, + ) -> List[Any]: + tex = self._get_tex() + + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + + return tex.fused_attn_fwd( + max_seqlen_q, + max_seqlen_kv, + is_training, + attn_scale, + p_dropout, + set_zero, + 
qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + fake_dtype, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + page_table_k, + page_table_v, + s_quantizer, + o_quantizer, + Bias, + SoftmaxOffset, + rng_gen, + rng_elts_per_thread, + return_max_logit, + ) + + def fused_attn_bwd( + self, + max_seqlen_q: int, + max_seqlen_kv: int, + attn_scale: float, + p_dropout: float, + set_zero: bool, + qkv_layout: NVTE_QKV_Layout, + bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, + softmax_type: NVTE_Softmax_Type, + window_size: List[int], + deterministic: bool, + cu_seqlens_q: torch.Tensor, + cu_seqlens_kv: torch.Tensor, + Q: Any, + K: Any, + V: Any, + O: Any, + dO: Any, + fake_dtype: torch.dtype, + dqkv_type: DType, + Aux_CTX_Tensors: List[torch.Tensor], + cu_seqlens_q_padded: Optional[torch.Tensor], + cu_seqlens_kv_padded: Optional[torch.Tensor], + s_quantizer: Any, + dp_quantizer: Any, + dqkv_quantizer: Any, + ) -> List[Any]: + tex = self._get_tex() + + qkv_layout = tex.NVTE_QKV_Layout(int(qkv_layout)) if qkv_layout is not None else None + bias_type = tex.NVTE_Bias_Type(int(bias_type)) if bias_type is not None else None + attn_mask_type = ( + tex.NVTE_Mask_Type(int(attn_mask_type)) if attn_mask_type is not None else None + ) + softmax_type = ( + tex.NVTE_Softmax_Type(int(softmax_type)) if softmax_type is not None else None + ) + dqkv_type = tex.DType(int(dqkv_type)) if dqkv_type is not None else None + + return tex.fused_attn_bwd( + max_seqlen_q, + max_seqlen_kv, + attn_scale, + p_dropout, + set_zero, + qkv_layout, + bias_type, + attn_mask_type, + softmax_type, + window_size, + deterministic, + cu_seqlens_q, + cu_seqlens_kv, + Q, + K, + V, + O, + dO, + fake_dtype, + dqkv_type, + Aux_CTX_Tensors, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + s_quantizer, + dp_quantizer, + dqkv_quantizer, + ) + + def copy_to_kv_cache( + self, + new_k: torch.Tensor, + new_v: torch.Tensor, + k_cache: 
torch.Tensor, + v_cache: torch.Tensor, + page_table: torch.Tensor, + cu_new_lens: torch.Tensor, + cu_cached_lens: torch.Tensor, + qkv_format: NVTE_QKV_Format, + b: int, + max_ctx_len: int, + max_seq_len: int, + max_pages_per_seq: int, + is_non_paged: bool, + ) -> None: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.copy_to_kv_cache( + new_k, + new_v, + k_cache, + v_cache, + page_table, + cu_new_lens, + cu_cached_lens, + qkv_format, + b, + max_ctx_len, + max_seq_len, + max_pages_per_seq, + is_non_paged, + ) + + def convert_thd_to_bshd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + b: int, + max_seq_len: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.convert_thd_to_bshd(tensor, cu_seqlens, b, max_seq_len) + + def convert_bshd_to_thd( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + t: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.convert_bshd_to_thd(tensor, cu_seqlens, t) + + # fused apply rope + def fused_rope_forward( + self, + input: torch.Tensor, + freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_rope_forward( + input, freqs, start_positions, qkv_format, interleaved, cu_seqlens, cp_size, cp_rank + ) + + def fused_rope_backward( + self, + output_grads: torch.Tensor, + freqs: torch.Tensor, + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cu_seqlens: Optional[torch.Tensor], + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_rope_backward( + output_grads, freqs, qkv_format, interleaved, cu_seqlens, cp_size, 
cp_rank + ) + + def fused_qkv_rope_forward( + self, + qkv_input: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + start_positions: Optional[torch.Tensor], + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_qkv_rope_forward( + qkv_input, + q_freqs, + k_freqs, + start_positions, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, + ) + + def fused_qkv_rope_backward( + self, + q_grad_out: torch.Tensor, + k_grad_out: torch.Tensor, + v_grad_out: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + qkv_format: NVTE_QKV_Format, + interleaved: bool, + cp_size: int, + cp_rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + qkv_format = tex.NVTE_QKV_Format(int(qkv_format)) if qkv_format is not None else None + return tex.fused_qkv_rope_backward( + q_grad_out, + k_grad_out, + v_grad_out, + q_freqs, + k_freqs, + qkv_split_arg_list, + qkv_format, + interleaved, + cp_size, + cp_rank, + ) + + # fused router + def fused_topk_with_score_function_fwd( + self, + logits: torch.Tensor, + topk: int, + use_pre_softmax: bool, + num_groups: Optional[int], + group_topk: Optional[int], + scaling_factor: Optional[float], + score_function: str, + expert_bias: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_topk_with_score_function_fwd( + logits, + topk, + use_pre_softmax, + num_groups, + group_topk, + scaling_factor, + score_function, + expert_bias, + ) + + def fused_topk_with_score_function_bwd( + self, + num_tokens: int, + num_experts: int, + routing_map: torch.Tensor, + intermediate_output: torch.Tensor, + grad_probs: torch.Tensor, + topk: int, + use_pre_softmax: bool, + 
scaling_factor: Optional[float], + score_function: str, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_topk_with_score_function_bwd( + num_tokens, + num_experts, + routing_map, + intermediate_output, + grad_probs, + topk, + use_pre_softmax, + scaling_factor, + score_function, + ) + + def fused_score_for_moe_aux_loss_fwd( + self, + logits: torch.Tensor, + topk: int, + score_function: str, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_score_for_moe_aux_loss_fwd( + logits, + topk, + score_function, + ) + + def fused_score_for_moe_aux_loss_bwd( + self, + num_tokens: int, + num_experts: int, + intermediate_output: torch.Tensor, + grad_scores: torch.Tensor, + topk: int, + score_function: str, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_score_for_moe_aux_loss_bwd( + num_tokens, + num_experts, + intermediate_output, + grad_scores, + topk, + score_function, + ) + + def fused_moe_aux_loss_fwd( + self, + probs: torch.Tensor, + tokens_per_expert: torch.Tensor, + total_num_tokens: int, + num_experts: int, + num_rows: int, + num_cols: int, + topk: int, + coeff: float, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.fused_moe_aux_loss_fwd( + probs, + tokens_per_expert, + total_num_tokens, + num_experts, + num_rows, + num_cols, + topk, + coeff, + ) + + def fused_moe_aux_loss_bwd( + self, + Const_buf: torch.Tensor, + tokens_per_expert: torch.Tensor, + num_rows: int, + num_cols: int, + grad_aux_loss: torch.Tensor, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.fused_moe_aux_loss_bwd( + Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss + ) + + # Dropout + def dropout_fwd( + self, + input: torch.Tensor, + dropout_probability: float, + out: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.dropout_fwd(input, dropout_probability, out) + + def dropout_bwd( + self, + grad_output: torch.Tensor, 
+ mask: torch.Tensor, + dropout_probability: float, + grad_input: Optional[torch.Tensor], + ) -> torch.Tensor: + tex = self._get_tex() + return tex.dropout_bwd(grad_output, mask, dropout_probability, grad_input) + + # Misc + def get_cublasLt_version(self) -> int: + tex = self._get_tex() + return tex.get_cublasLt_version() + + def get_cudnn_version(self) -> int: + tex = self._get_tex() + return tex.get_cudnn_version() + + def get_num_cublas_streams(self) -> int: + tex = self._get_tex() + return tex.get_num_cublas_streams() + + # Support THD format for Context Parallel + def thd_read_half_tensor( + self, + tensor: torch.Tensor, + cu_seqlens: torch.Tensor, + half_idx: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_read_half_tensor(tensor, cu_seqlens, half_idx) + + def thd_second_half_lse_correction( + self, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + ) -> None: + tex = self._get_tex() + return tex.thd_second_half_lse_correction(lse, lse_per_step, cu_seqlens, lse_packed) + + def thd_read_second_half_lse( + self, + lse: torch.Tensor, + cu_seqlens: torch.Tensor, + lse_packed: bool, + second_half_lse_seqlen: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_read_second_half_lse(lse, cu_seqlens, lse_packed, second_half_lse_seqlen) + + def thd_out_correction( + self, + out: torch.Tensor, + out_per_step: torch.Tensor, + lse: torch.Tensor, + lse_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + only_second_half: bool, + lse_packed: bool, + ) -> None: + tex = self._get_tex() + return tex.thd_out_correction( + out, out_per_step, lse, lse_per_step, cu_seqlens, only_second_half, lse_packed + ) + + def thd_grad_correction( + self, + grad: torch.Tensor, + grad_per_step: torch.Tensor, + cu_seqlens: torch.Tensor, + first_half: str, + second_half: str, + ) -> None: + tex = self._get_tex() + return tex.thd_grad_correction(grad, grad_per_step, cu_seqlens, first_half, second_half) + + def 
thd_get_partitioned_indices( + self, + cu_seqlens: torch.Tensor, + total_tokens: int, + world_size: int, + rank: int, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.thd_get_partitioned_indices(cu_seqlens, total_tokens, world_size, rank) + + # nvshmem functions + def init_nvshmem_backend( + self, + process_group: Any, + ) -> None: + tex = self._get_tex() + return tex.init_nvshmem_backend(process_group) + + def create_nvshmem_tensor( + self, + shape: List[int], + dtype: torch.dtype, + ) -> torch.Tensor: + tex = self._get_tex() + return tex.create_nvshmem_tensor(shape, dtype) + + def nvshmem_send_on_current_stream( + self, + src: torch.Tensor, + dst: torch.Tensor, + peer: int, + signal: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.nvshmem_send_on_current_stream(src, dst, peer, signal) + + def nvshmem_wait_on_current_stream( + self, + signal: torch.Tensor, + wait_kind: str, + ) -> None: + tex = self._get_tex() + return tex.nvshmem_wait_on_current_stream(signal, wait_kind) + + def nvshmem_finalize(self) -> None: + tex = self._get_tex() + return tex.nvshmem_finalize() + + # multi-tensor functions + def multi_tensor_scale( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + scale: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_scale(chunk_size, noop_flag, tensor_lists, scale) + + def multi_tensor_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return tex.multi_tensor_l2norm(chunk_size, noop_flag, tensor_lists, per_tensor) + + def multi_tensor_unscale_l2norm( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + inv_scale: torch.Tensor, + per_tensor: Optional[bool] = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + tex = self._get_tex() + return 
tex.multi_tensor_unscale_l2norm( + chunk_size, noop_flag, tensor_lists, inv_scale, per_tensor + ) + + def multi_tensor_adam( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_param_remainder( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_param_remainder( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + ) + + def multi_tensor_adam_fp8( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: float, + beta1: float, + beta2: float, + epsilon: float, + step: int, + mode: int, + bias_correction: int, + weight_decay: float, + fp8_dtype: DType, + ) -> None: + tex = self._get_tex() + fp8_dtype = tex.DType(int(fp8_dtype)) if fp8_dtype is not None else None + return tex.multi_tensor_adam_fp8( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + fp8_dtype, + ) + + def multi_tensor_adam_capturable( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return 
tex.multi_tensor_adam_capturable( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + def multi_tensor_adam_capturable_master( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + lr: torch.Tensor, + beta1: float, + beta2: float, + epsilon: float, + step: torch.Tensor, + mode: int, + bias_correction: int, + weight_decay: float, + inv_scale: torch.Tensor, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_adam_capturable_master( + chunk_size, + noop_flag, + tensor_lists, + lr, + beta1, + beta2, + epsilon, + step, + mode, + bias_correction, + weight_decay, + inv_scale, + ) + + def multi_tensor_sgd( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + wd: float, + momentum: float, + dampening: float, + lr: float, + nesterov: bool, + first_run: bool, + wd_after_momentum: bool, + scale: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_sgd( + chunk_size, + noop_flag, + tensor_lists, + wd, + momentum, + dampening, + lr, + nesterov, + first_run, + wd_after_momentum, + scale, + ) + + def multi_tensor_compute_scale_and_scale_inv( + self, + chunk_size: int, + noop_flag: torch.Tensor, + tensor_lists: List[List[torch.Tensor]], + max_fp8: float, + force_pow_2_scales: bool, + epsilon: float, + ) -> None: + tex = self._get_tex() + return tex.multi_tensor_compute_scale_and_scale_inv( + chunk_size, noop_flag, tensor_lists, max_fp8, force_pow_2_scales, epsilon + ) + + # Comm+GEMM Overlap + def bulk_overlap_ag_with_external_gemm( + self, + allgather_communicator: CommOverlap, + send_stream: Any, + recv_stream: Any, + ) -> Any: + tex = self._get_tex() + return tex.bulk_overlap_ag_with_external_gemm( + allgather_communicator, send_stream, recv_stream + ) + + ############## class func ################################# + def get_flash_attention_class(self): + from 
.flash_attention import FlashAttentionENFLAME + + return FlashAttentionENFLAME + + def create_fp8_tensor_meta(self) -> FP8TensorMeta: + tex = self._get_tex() + return tex.FP8TensorMeta() + + def create_comm_overlap_helper( + self, + world_group: Optional[Any] = None, + intra_node_group: Optional[Any] = None, + ) -> "CommOverlapHelper": + tex = self._get_tex() + return tex.CommOverlapHelper(world_group, intra_node_group) + + def create_comm_overlap( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + num_splits: int = 3, + num_max_streams: int = 3, + comm_cga_size: int = 2, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 16, + set_sm_margin: bool = True, + atomic_gemm: bool = False, + rs_overlap_first_gemm: bool = False, + ) -> "CommOverlap": + tex = self._get_tex() + return tex.CommOverlap( + buffer_shape, + buffer_dtype, + helper, + tp_size, + num_splits, + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + atomic_gemm, + rs_overlap_first_gemm, + ) + + def create_comm_overlap_p2p( + self, + buffer_shape: List[int], + buffer_dtype: torch.dtype, + helper: Any, + tp_size: int, + comm_type: Any, + num_max_streams: int = 3, + comm_cga_size: int = 1, + gemm_priority: int = 0, + comm_priority: int = 0, + num_comm_sm: int = 1, + set_sm_margin: bool = False, + atomic_gemm: bool = False, + use_ce: bool = True, + aggregate: bool = False, + ) -> "CommOverlapP2P": + tex = self._get_tex() + return tex.CommOverlapP2P( + buffer_shape, + buffer_dtype, + helper, + tp_size, + comm_type, + num_max_streams, + comm_cga_size, + gemm_priority, + comm_priority, + num_comm_sm, + set_sm_margin, + atomic_gemm, + use_ce, + aggregate, + ) diff --git a/transformer_engine/plugin/core/backends/vendor/enflame/flash_attention.py b/transformer_engine/plugin/core/backends/vendor/enflame/flash_attention.py new file mode 100755 index 0000000000..c7ef5cbdd2 --- /dev/null +++ 
b/transformer_engine/plugin/core/backends/vendor/enflame/flash_attention.py @@ -0,0 +1,127 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +from contextlib import nullcontext +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +from transformer_engine.plugin.core.ops import FlashAttentionBase + + +class FlashAttentionENFLAME(FlashAttentionBase): + def __init__( + self, + softmax_scale: float, + attention_dropout: float = 0.0, + attention_dropout_ctx: Optional[Callable] = None, + attention_type: str = "self", + layer_number: Optional[int] = None, + deterministic: bool = False, + ) -> None: + super().__init__( + softmax_scale=softmax_scale, + attention_dropout=attention_dropout, + attention_dropout_ctx=attention_dropout_ctx, + attention_type=attention_type, + layer_number=layer_number, + deterministic=deterministic, + ) + + # Store initialization parameters for lazy loading + self._init_params = { + "softmax_scale": softmax_scale, + "attention_dropout": attention_dropout, + "attention_dropout_ctx": attention_dropout_ctx or nullcontext, + "attention_type": attention_type, + "layer_number": layer_number, + "deterministic": deterministic, + } + self._enflame_flash_attn = None + + def _ensure_enflame_flash_attn(self): + """Lazy initialization of enflame FlashAttention.""" + if self._enflame_flash_attn is not None: + return + + try: + # Import here to avoid circular dependency issues + # transformer_engine_torch must be registered before this import + from migration.patches.transformer_engine.v2_9_0.pytorch.attention.dot_product_attention.backends import ( + FlashAttention as FlashAttentionEnflame, + ) + + if FlashAttentionEnflame is None: + raise RuntimeError( + "FlashAttention class is None - flash-attn may not be installed correctly" + ) + + self._enflame_flash_attn = FlashAttentionEnflame(**self._init_params) + + except ImportError as e: + raise RuntimeError( + f"Failed to import
enflame FlashAttention: {e}. " + "Please ensure flash-attn is installed and transformer_engine_torch is available." + ) + except Exception as e: + raise RuntimeError( + f"Failed to initialize enflame FlashAttention: {e}. Init params: {self._init_params}" + ) + + @property + def backend_name(self) -> str: + return "enflame" + + def _forward_impl( + self, + query_layer: torch.Tensor, + key_layer: torch.Tensor, + value_layer: torch.Tensor, + attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, + qkv_layout: str = "sbh3d", + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_kv: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + max_seqlen_kv: Optional[int] = None, + attn_mask_type: str = "causal", + window_size: Optional[Tuple[int, int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + cp_group: Optional[Any] = None, + cp_global_ranks: Optional[List[int]] = None, + cp_stream: Optional[torch.cuda.Stream] = None, + cp_comm_type: str = "p2p", + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers: Optional[Any] = None, + inference_params: Optional[Any] = None, + flash_attention_backend: Optional[Any] = None, + fp8_output: bool = False, + ) -> torch.Tensor: + # Ensure enflame flash attention is initialized + self._ensure_enflame_flash_attn() + + return self._enflame_flash_attn( + query_layer=query_layer, + key_layer=key_layer, + value_layer=value_layer, + attention_mask=attention_mask, + qkv_layout=qkv_layout, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, + attn_mask_type=attn_mask_type, + window_size=window_size, + alibi_slopes=alibi_slopes, + cp_group=cp_group, + cp_global_ranks=cp_global_ranks, + cp_stream=cp_stream, + cp_comm_type=cp_comm_type, + fp8=fp8, + fp8_meta=fp8_meta, + quantizers=quantizers, + inference_params=inference_params, + flash_attention_backend=flash_attention_backend, + 
+ fp8_output=fp8_output, + ) diff --git a/transformer_engine/plugin/core/backends/vendor/enflame/register_ops.py b/transformer_engine/plugin/core/backends/vendor/enflame/register_ops.py new file mode 100755 index 0000000000..53744e4d66 --- /dev/null +++ b/transformer_engine/plugin/core/backends/vendor/enflame/register_ops.py @@ -0,0 +1,956 @@ +# Copyright (c) 2025, BAAI. All rights reserved. +# +# See LICENSE for license information. + +""" +Enflame vendor backend operator registrations. + +This module registers all VENDOR (Enflame) implementations from +transformer_engine_torch. +""" + +from __future__ import annotations + +import functools + +from ....types import OpImpl, BackendImplKind + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all Enflame (VENDOR) operator implementations.
+ + Args: + registry: Registry to register into + """ + # Import Enflame backend to get all the wrapped tex functions + from .enflame import EnflameBackend + + # Create a backend instance to access the methods + backend = EnflameBackend() + + # Check if Enflame is available before registering + if not backend.is_available(): + return + + # Bind is_available to all methods + is_avail = backend.is_available + + impls = [ + # Normalization + OpImpl( + op_name="rmsnorm_fwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_fwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="rmsnorm_bwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_bwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="rmsnorm_bwd_add", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.rmsnorm_bwd_add, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="layernorm_fwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.layernorm_fwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="layernorm_bwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.layernorm_bwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + # GEMM + OpImpl( + op_name="generic_gemm", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.generic_gemm, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="te_general_grouped_gemm", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.te_general_grouped_gemm, is_avail), + vendor="ENFLAME", + priority=100, + ), + # Quantization + OpImpl( + op_name="quantize", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.quantize, is_avail), + 
vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dequantize", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dequantize, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="bgrad_quantize", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bgrad_quantize, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="split_quantize", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.split_quantize, is_avail), + vendor="ENFLAME", + priority=100, + ), + # Activations - Forward + OpImpl( + op_name="gelu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.gelu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="geglu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.geglu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="qgelu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.qgelu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="qgeglu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.qgeglu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="relu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.relu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="reglu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.reglu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="srelu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.srelu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="sreglu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + 
fn=_bind_is_available(backend.sreglu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="silu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.silu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="swiglu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swiglu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="clamped_swiglu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.clamped_swiglu, is_avail), + vendor="ENFLAME", + priority=100, + ), + # Activations - Backward + OpImpl( + op_name="dgelu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dgelu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dgeglu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dgeglu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dqgelu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dqgelu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dqgeglu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dqgeglu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="drelu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.drelu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dreglu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dreglu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dsrelu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsrelu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dsreglu", + impl_id="vendor.enflame", + 
kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsreglu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dsilu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dsilu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dswiglu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dswiglu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="clamped_dswiglu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.clamped_dswiglu, is_avail), + vendor="ENFLAME", + priority=100, + ), + # Activations - Bias + Backward + OpImpl( + op_name="dbias_dgelu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dgelu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dbias_dsilu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dsilu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dbias_drelu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_drelu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dbias_dqgelu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dqgelu, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dbias_dsrelu", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dbias_dsrelu, is_avail), + vendor="ENFLAME", + priority=100, + ), + # Softmax + OpImpl( + op_name="scaled_softmax_forward", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_softmax_forward, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="scaled_softmax_backward", + impl_id="vendor.enflame", + 
kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_softmax_backward, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="scaled_masked_softmax_forward", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_masked_softmax_forward, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="scaled_masked_softmax_backward", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_masked_softmax_backward, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_forward", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_forward, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="scaled_upper_triang_masked_softmax_backward", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_upper_triang_masked_softmax_backward, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_forward", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_forward, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="scaled_aligned_causal_masked_softmax_backward", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.scaled_aligned_causal_masked_softmax_backward, is_avail), + vendor="ENFLAME", + priority=100, + ), + # MOE operations + OpImpl( + op_name="moe_permute_fwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_permute_fwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="moe_permute_bwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_permute_bwd, is_avail), + 
vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="moe_unpermute_fwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_unpermute_fwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="moe_unpermute_bwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.moe_unpermute_bwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + # Fused attention + OpImpl( + op_name="get_fused_attn_backend", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_fused_attn_backend, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fused_attn_fwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_attn_fwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fused_attn_bwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_attn_bwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fa_prepare_fwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fa_prepare_fwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fa_prepare_bwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fa_prepare_bwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + # KV cache + OpImpl( + op_name="copy_to_kv_cache", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.copy_to_kv_cache, is_avail), + vendor="ENFLAME", + priority=100, + ), + # Tensor format conversions + OpImpl( + op_name="convert_thd_to_bshd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_thd_to_bshd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="convert_bshd_to_thd", + impl_id="vendor.enflame", + 
kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.convert_bshd_to_thd, is_avail), + vendor="ENFLAME", + priority=100, + ), + # RoPE (Rotary Position Embedding) + OpImpl( + op_name="fused_rope_forward", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_rope_forward, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fused_rope_backward", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_rope_backward, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fused_qkv_rope_forward", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_qkv_rope_forward, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fused_qkv_rope_backward", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_qkv_rope_backward, is_avail), + vendor="ENFLAME", + priority=100, + ), + # TopK and MOE aux loss + OpImpl( + op_name="fused_topk_with_score_function_fwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_topk_with_score_function_fwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fused_topk_with_score_function_bwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_topk_with_score_function_bwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fused_score_for_moe_aux_loss_fwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_score_for_moe_aux_loss_fwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fused_score_for_moe_aux_loss_bwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_score_for_moe_aux_loss_bwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + 
op_name="fused_moe_aux_loss_fwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_moe_aux_loss_fwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fused_moe_aux_loss_bwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_moe_aux_loss_bwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + # Dropout + OpImpl( + op_name="dropout_fwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dropout_fwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="dropout_bwd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.dropout_bwd, is_avail), + vendor="ENFLAME", + priority=100, + ), + # FP8 operations + OpImpl( + op_name="fp8_transpose", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_transpose, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="swap_first_dims", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.swap_first_dims, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="compute_amax", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.compute_amax, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fused_amax_and_scale_update_after_reduction", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_amax_and_scale_update_after_reduction, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fp8_block_scaling_compute_partial_amax", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_block_scaling_compute_partial_amax, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fp8_block_scaling_partial_cast", + impl_id="vendor.enflame", + 
kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fp8_block_scaling_partial_cast, is_avail), + vendor="ENFLAME", + priority=100, + ), + # Padding operations + OpImpl( + op_name="fused_multi_row_padding", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_multi_row_padding, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="fused_multi_row_unpadding", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.fused_multi_row_unpadding, is_avail), + vendor="ENFLAME", + priority=100, + ), + # Library version getters + OpImpl( + op_name="get_cublasLt_version", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_cublasLt_version, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="get_cudnn_version", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_cudnn_version, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="get_num_cublas_streams", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_num_cublas_streams, is_avail), + vendor="ENFLAME", + priority=100, + ), + # THD (Tensor, Hidden, Dimension) operations + OpImpl( + op_name="thd_read_half_tensor", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_read_half_tensor, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="thd_second_half_lse_correction", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_second_half_lse_correction, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="thd_read_second_half_lse", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_read_second_half_lse, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="thd_out_correction", 
+ impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_out_correction, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="thd_grad_correction", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_grad_correction, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="thd_get_partitioned_indices", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.thd_get_partitioned_indices, is_avail), + vendor="ENFLAME", + priority=100, + ), + # NVSHMEM operations + OpImpl( + op_name="init_nvshmem_backend", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.init_nvshmem_backend, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="create_nvshmem_tensor", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_nvshmem_tensor, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="nvshmem_send_on_current_stream", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_send_on_current_stream, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="nvshmem_wait_on_current_stream", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_wait_on_current_stream, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="nvshmem_finalize", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.nvshmem_finalize, is_avail), + vendor="ENFLAME", + priority=100, + ), + # Multi-tensor operations + OpImpl( + op_name="multi_tensor_quantize", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_quantize, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="multi_tensor_scale", + 
impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_scale, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="multi_tensor_l2norm", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_l2norm, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="multi_tensor_unscale_l2norm", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_unscale_l2norm, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_param_remainder", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_param_remainder, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_fp8", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_fp8, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_capturable", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_capturable, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="multi_tensor_adam_capturable_master", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_adam_capturable_master, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="multi_tensor_sgd", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_sgd, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="multi_tensor_compute_scale_and_scale_inv", + impl_id="vendor.enflame", + 
kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.multi_tensor_compute_scale_and_scale_inv, is_avail), + vendor="ENFLAME", + priority=100, + ), + # Communication overlap operations + OpImpl( + op_name="bulk_overlap_ag_with_external_gemm", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.bulk_overlap_ag_with_external_gemm, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="create_fp8_tensor_meta", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_fp8_tensor_meta, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap_helper", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap_helper, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap, is_avail), + vendor="ENFLAME", + priority=100, + ), + OpImpl( + op_name="create_comm_overlap_p2p", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.create_comm_overlap_p2p, is_avail), + vendor="ENFLAME", + priority=100, + ), + # FlashAttention class getter + OpImpl( + op_name="get_flash_attention_class", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_flash_attention_class, is_avail), + vendor="ENFLAME", + priority=100, + ), + # Attention backend selection + OpImpl( + op_name="get_attention_backend", + impl_id="vendor.enflame", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.get_attention_backend, is_avail), + vendor="ENFLAME", + priority=100, + ), + ] + + registry.register_many(impls) diff --git a/transformer_engine/plugin/core/builtin_ops.py b/transformer_engine/plugin/core/builtin_ops.py index c991d4fc51..8735958f7d 100644 --- 
a/transformer_engine/plugin/core/builtin_ops.py +++ b/transformer_engine/plugin/core/builtin_ops.py @@ -95,3 +95,12 @@ def register_builtins(registry: OpRegistry) -> None: except Exception as e: # MUSA may not be available, this is expected pass + # Register enflame (VENDOR) implementations + try: + from .backends.vendor.enflame.register_ops import register_builtins as register_enflame + + register_enflame(registry) + except Exception as e: + # enflame may not be available, this is expected + pass +