From 8998280c7206850438796d03b1ef6c17addcdee9 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Fri, 12 Jun 2026 22:13:08 +0000 Subject: [PATCH 01/22] start draft --- build_tools/pytorch.py | 28 +- pyproject.toml | 5 +- setup.py | 3 + transformer_engine/common/CMakeLists.txt | 34 + transformer_engine/common/CuTeDSL/__init__.py | 19 + .../common/CuTeDSL/cast/__init__.py | 8 + .../common/CuTeDSL/cast/mxfp8/__init__.py | 8 + .../common/CuTeDSL/cast/mxfp8/mxfp8_utils.py | 991 ++++++++++++++++ .../CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 1021 +++++++++++++++++ transformer_engine/common/CuTeDSL/utils.py | 16 + .../common/cast/dispatch/quantize.cuh | 27 +- .../cast/mxfp8/quantize_mxfp8_cutedsl.cuh | 251 ++++ transformer_engine/common/tvm_ffi_bridge.h | 259 +++++ transformer_engine/pytorch/__init__.py | 10 + 14 files changed, 2673 insertions(+), 7 deletions(-) create mode 100644 transformer_engine/common/CuTeDSL/__init__.py create mode 100644 transformer_engine/common/CuTeDSL/cast/__init__.py create mode 100644 transformer_engine/common/CuTeDSL/cast/mxfp8/__init__.py create mode 100644 transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py create mode 100644 transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py create mode 100644 transformer_engine/common/CuTeDSL/utils.py create mode 100644 transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh create mode 100644 transformer_engine/common/tvm_ffi_bridge.h diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index e2e6d09c29..ee35e70df0 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -4,6 +4,7 @@ """PyTorch related extensions.""" +import importlib.util import os from pathlib import Path from importlib import metadata @@ -58,6 +59,25 @@ def setup_pytorch_extension( ] ) + # apache-tvm-ffi: headers for the C++ API (Module / Function / TensorView) + # and libtvm_ffi.so for symbol resolution. Used by tvm_ffi_bridge.h / + # applyTVMFunction. Python registers AOT-compiled CuTeDSL kernels into + # the global registry; TE C++ looks them up via Function::GetGlobalRequired. + tvm_ffi_spec = importlib.util.find_spec("tvm_ffi") + if tvm_ffi_spec is None or not tvm_ffi_spec.submodule_search_locations: + raise RuntimeError( + "apache-tvm-ffi package not found; install it (e.g. " + "`pip install apache-tvm-ffi`) — required for the TVM FFI bridge." + ) + tvm_ffi_root = Path(tvm_ffi_spec.submodule_search_locations[0]) + tvm_ffi_include = tvm_ffi_root / "include" + tvm_ffi_lib_dir = tvm_ffi_root / "lib" + if not tvm_ffi_include.is_dir() or not (tvm_ffi_lib_dir / "libtvm_ffi.so").exists(): + raise RuntimeError( + f"apache-tvm-ffi assets missing at {tvm_ffi_root} (need include/ and lib/libtvm_ffi.so)" + ) + include_dirs.append(tvm_ffi_include) + # Compiler flags cxx_flags = ["-O3", "-fvisibility=hidden"] if debug_build_enabled(): @@ -77,8 +97,11 @@ def setup_pytorch_extension( setup_mpi_flags(include_dirs, cxx_flags) - library_dirs = [] - libraries = [] + library_dirs = [tvm_ffi_lib_dir] + libraries = ["tvm_ffi"] + # rpath pinned to the pip install dir so the loader finds libtvm_ffi.so + # without LD_LIBRARY_PATH at runtime. + extra_link_args = [f"-Wl,-rpath,{tvm_ffi_lib_dir}"] if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))): assert ( os.getenv("NVSHMEM_HOME") is not None @@ -102,6 +125,7 @@ def setup_pytorch_extension( sources=[str(src) for src in sources], include_dirs=[str(inc) for inc in include_dirs], extra_compile_args={"cxx": cxx_flags}, + extra_link_args=extra_link_args, libraries=[str(lib) for lib in libraries], library_dirs=[str(lib_dir) for lib_dir in library_dirs], ) diff --git a/pyproject.toml b/pyproject.toml index 4a8fded172..826a0e54a7 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,10 @@ # See LICENSE for license information. [build-system] -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] +# apache-tvm-ffi is required at configure/compile/link time: the common C++ +# library finds it via find_package(tvm_ffi) and links libtvm_ffi.so (the +# CuTeDSL quant backend bridge). It is also a runtime dependency (see setup.py). +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1", "apache-tvm-ffi>=0.1.12"] # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" diff --git a/setup.py b/setup.py index 64ed120268..ed6fe977b4 100644 --- a/setup.py +++ b/setup.py @@ -136,6 +136,9 @@ def setup_requirements() -> Tuple[List[str], List[str]]: "importlib-metadata>=1.0", "packaging", cusolvermp_pypi_package_name(), + # The core C++ library links libtvm_ffi.so (CuTeDSL quant backend bridge), + # so apache-tvm-ffi is required at runtime by every TE install. + "apache-tvm-ffi>=0.1.12", ] test_reqs: List[str] = ["pytest>=8.2.1"] diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index edb8c5e109..faebee00f4 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -106,6 +106,24 @@ set(CUTLASS_TOOLS_INCLUDE_DIR # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) +# tvm-ffi: the quantize dispatch layer bridges to JIT-compiled CuTeDSL kernels +# through tvm-ffi (see common/tvm_ffi_bridge.h). Locate the tvm_ffi package that +# ships with the Python install and use its exported CMake config (provides the +# tvm_ffi::shared imported target with headers + libtvm_ffi.so). +execute_process( + COMMAND ${Python_EXECUTABLE} -c "import tvm_ffi.libinfo as li; print(li.find_cmake_path())" + OUTPUT_VARIABLE TVM_FFI_CMAKE_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE TVM_FFI_CMAKE_QUERY) +if(NOT TVM_FFI_CMAKE_QUERY EQUAL 0) + message(FATAL_ERROR + "Could not import the tvm_ffi Python package (with '${Python_EXECUTABLE}'), " + "which Transformer Engine requires to build the CuTeDSL quantize backend " + "bridge (common/tvm_ffi_bridge.h). Install it into this Python environment: " + "`pip install apache-tvm-ffi`.") +endif() +find_package(tvm_ffi CONFIG REQUIRED PATHS "${TVM_FFI_CMAKE_DIR}") + function(find_nccl_version OUT_VERSION OUT_INCLUDE_DIR) find_path(_nvte_nccl_include_dir NAMES nccl.h @@ -360,6 +378,22 @@ target_link_libraries(transformer_engine PUBLIC CUDA::cudart CUDNN::cudnn_all) +# CuTeDSL quantize backend bridge. PRIVATE: tvm_ffi_bridge.h is an internal +# header (not in the installed public include dir), so the symbols and headers +# are only needed to compile transformer_engine itself, not by downstream +# consumers. The INTERFACE include dirs of tvm_ffi::shared still apply to our +# own TUs, which is what fixes the not-found error. +target_link_libraries(transformer_engine PRIVATE tvm_ffi::shared) + +# libtvm_ffi.so ships inside the tvm_ffi Python package (not a system lib dir), +# so add its directory to the RPATH; otherwise the runtime loader can't satisfy +# the DT_NEEDED on libtvm_ffi.so and dlopen fails with "cannot open shared +# object file". Applied to both the build tree and the installed library. +get_target_property(TVM_FFI_SHARED_LOCATION tvm_ffi::shared IMPORTED_LOCATION) +get_filename_component(TVM_FFI_LIB_DIR "${TVM_FFI_SHARED_LOCATION}" DIRECTORY) +set_property(TARGET transformer_engine APPEND PROPERTY BUILD_RPATH "${TVM_FFI_LIB_DIR}") +set_property(TARGET transformer_engine APPEND PROPERTY INSTALL_RPATH "${TVM_FFI_LIB_DIR}") + target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine SYSTEM PRIVATE diff --git a/transformer_engine/common/CuTeDSL/__init__.py b/transformer_engine/common/CuTeDSL/__init__.py new file mode 100644 index 0000000000..5621c01e64 --- /dev/null +++ b/transformer_engine/common/CuTeDSL/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""CuTeDSL kernels for Transformer Engine. + +Importing this package has a side effect: it registers the CuTeDSL kernel +entrypoints (e.g. ``get_mxfp8_quantization_function``) as TVM-FFI global +functions. The C++ dispatcher probes for those names via +``tvm::ffi::Function::GetGlobal`` — finding one means the process is running +inside a Python environment with the CuTeDSL toolchain available, so the kernel +may be compiled on demand; not finding it means a plain C++ environment, and +the dispatcher falls back to the CUDA C++ kernel. + +Importing requires the optional CuTeDSL toolchain (cutlass, tvm_ffi). Callers +that want graceful degradation should guard the import in a try/except. +""" + +from . import cast # noqa: F401 (import side effect: registers global funcs) diff --git a/transformer_engine/common/CuTeDSL/cast/__init__.py b/transformer_engine/common/CuTeDSL/cast/__init__.py new file mode 100644 index 0000000000..c4890ee489 --- /dev/null +++ b/transformer_engine/common/CuTeDSL/cast/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""CuTeDSL cast/quantization kernels. Importing pulls in each kernel module so +its TVM-FFI entrypoint is registered.""" + +from . import mxfp8 # noqa: F401 (import side effect: registers global funcs) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/__init__.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/__init__.py new file mode 100644 index 0000000000..c42df11c01 --- /dev/null +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""MXFP8 CuTeDSL kernels. Importing ``quantize_mxfp8`` runs its module body, +which registers the ``get_mxfp8_quantization_function`` TVM-FFI global func.""" + +from . import quantize_mxfp8 # noqa: F401 (import side effect: registers the global func) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py new file mode 100644 index 0000000000..0ee407f6b9 --- /dev/null +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py @@ -0,0 +1,991 @@ +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass import Float32, Int64, Int32, Int16, Uint8, Uint32 +from cutlass._mlir.dialects import arith as mlir_arith +from cutlass._mlir.dialects import llvm +from cutlass.base_dsl.compiler import GPUArch +from cutlass.cute.runtime import make_ptr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass.cute.arch import cvt_f32_bf16 + +from types import SimpleNamespace + +# FP8E4M3 max representable value +FP8E4M3_MAX_NORM = 448.0 +FP8E4M3_MAX_NORM_RCP = 1.0 / FP8E4M3_MAX_NORM +FP8E5M2_MAX_NORM = 57344.0 +FP8E5M2_MAX_NORM_RCP = 1.0 / FP8E5M2_MAX_NORM + +# NVFP4 (fp4e2m1) — 4-bit float, max representable value is 6.0 +FP4_E2M1_MAX = 6.0 +FP4_E2M1_MAX_RCP = 1.0 / FP4_E2M1_MAX +# Largest finite f32 — used to clamp the per-block scale inverse against +# division-by-zero (which produces +inf and then NaN downstream). +FP32_MAX = 3.4028234663852886e38 + +FP32_MANTISSA_BITS = 23 + + +@dsl_user_op +def _bitcast_f32_to_i32(val: Float32, *, loc=None, ip=None) -> Int32: + return Int32(mlir_arith.bitcast(T.i32(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def _bitcast_i32_to_f32(val: Int32, *, loc=None, ip=None) -> Float32: + return Float32(mlir_arith.bitcast(T.f32(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def fabs_f32(val: Float32, *, loc=None, ip=None) -> Float32: + val_i32 = _bitcast_f32_to_i32(val, loc=loc, ip=ip) + abs_i32 = val_i32 & Int32(0x7FFFFFFF) + return _bitcast_i32_to_f32(abs_i32, loc=loc, ip=ip) + + +@dsl_user_op +def float_to_e8m0(val: Float32, *, loc=None, ip=None) -> Int32: + """Branchless float->E8M0: add mantissa mask to round up, clamp to 254.""" + val_i32 = _bitcast_f32_to_i32(val, loc=loc, ip=ip) + rounded = val_i32 + Int32(0x7FFFFF) + exponent = (rounded >> Int32(FP32_MANTISSA_BITS)) & Int32(0xFF) + return Int32(mlir_arith.minsi( + exponent.ir_value(loc=loc, ip=ip), + Int32(254).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def exp2f_rcp(biased_exp: Int32, *, loc=None, ip=None) -> Float32: + """2^(127 - biased_exp) with special-case handling.""" + new_exp = (Int32(254) - biased_exp) << Int32(FP32_MANTISSA_BITS) + result = _bitcast_i32_to_f32(new_exp, loc=loc, ip=ip) + for (cmp_val, repl_bits) in [(255, 0x7FFFFFFF), (254, 0x00400000), (0, 0x7F000000)]: + cond = mlir_arith.cmpi(mlir_arith.CmpIPredicate.eq, + biased_exp.ir_value(loc=loc, ip=ip), + Int32(cmp_val).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + alt = _bitcast_i32_to_f32(Int32(repl_bits), loc=loc, ip=ip) + result = Float32(mlir_arith.select( + cond, alt.ir_value(loc=loc, ip=ip), + result.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return result + + +@dsl_user_op +def cvt_f32_to_fp8e4m3(val: Float32, *, loc=None, ip=None) -> Int32: + """float32 -> fp8e4m3fn via PTX cvt.rn.satfinite.e4m3x2.f32.""" + zero = Float32(0.0) + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], + "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + result_i32 = Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return result_i32 & Int32(0xFF) + + +@dsl_user_op +def cvt_f32_to_fp8e5m2(val: Float32, *, loc=None, ip=None) -> Int32: + """float32 -> fp8e5m2 via PTX cvt.rn.satfinite.e5m2x2.f32.""" + zero = Float32(0.0) + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], + "cvt.rn.satfinite.e5m2x2.f32 $0, $1, $2;", + "=h,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + result_i32 = Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return result_i32 & Int32(0xFF) + + +@dsl_user_op +def fma_f32(a: Float32, b: Float32, c: Float32, *, loc=None, ip=None) -> Float32: + """`fma.rn.f32 d, a, b, c;` — single-instruction fused multiply-add + matching nvcc's FFMA. Used for explicit `partial += a * b` patterns + where we need the same rounding as TE's compiler-fused FFMA.""" + return Float32(llvm.inline_asm( + T.f32(), + [a.ir_value(loc=loc, ip=ip), + b.ir_value(loc=loc, ip=ip), + c.ir_value(loc=loc, ip=ip)], + "fma.rn.f32 $0, $1, $2, $3;", + "=f,f,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + +@dsl_user_op +def tanh_approx(val: Float32, *, loc=None, ip=None) -> Float32: + """`tanh.approx.f32` — fast tanh approximation. Matches CUDA `__tanhf`.""" + return Float32(llvm.inline_asm( + T.f32(), + [val.ir_value(loc=loc, ip=ip)], + "tanh.approx.f32 $0, $1;", + "=f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + +@dsl_user_op +def pack_f32x2(lo: Float32, hi: Float32, *, loc=None, ip=None) -> Int64: + """Pack two f32 scalars into a single 64-bit register (`floatx2` layout). + + Low 32 bits = `lo`, high 32 bits = `hi`. Uses `mov.b64 %dst, {%lo, %hi};` + which lowers to a single register move — no actual memory traffic. + """ + return Int64(llvm.inline_asm( + T.i64(), + [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)], + "mov.b64 $0, {$1, $2};", + "=l,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + +@dsl_user_op +def pack_i32x2(lo: Int32, hi: Int32, *, loc=None, ip=None) -> Int64: + """i32 sibling of `pack_f32x2` — concat two i32 into a single b64 register. + Used by NVFP4 to glue two `(bf16,bf16)`/`(f16,f16)` Int32 packs into the + `Int64` operand the `mul_cvt.*x4` PTX expects.""" + return Int64(llvm.inline_asm( + T.i64(), + [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)], + "mov.b64 $0, {$1, $2};", + "=l,r,r", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + +@dsl_user_op +def _trunc_i32_to_i16(val: Int32, *, loc=None, ip=None) -> Int16: + """Narrow an Int32 to Int16 by keeping the low 16 bits. + + Lives here because the existing arith-dialect narrowing pattern requires + loc/ip kwargs (see other `mlir_arith.trunci` callers); wrapping it as a + `@dsl_user_op` lets `@cute.jit` bodies use it without plumbing those in.""" + return Int16(mlir_arith.trunci( + T.i16(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def cvt_fp8e4m3_to_f32(byte_i32: Int32, *, loc=None, ip=None) -> Float32: + """One fp8e4m3 byte (low 8 bits of `byte_i32`) → f32. + + PTX has no direct `cvt.f32.e4m3` for a scalar; route through the packed + `cvt.rn.f16x2.e4m3x2` and then `cvt.f32.f16`. The high byte of the .b16 + register is forced to zero so the discarded high f16 lane is well-defined.""" + asm = ( + "{\n" + ".reg .b32 masked; .reg .b16 b16; .reg .b16 b16_hi;\n\t" + ".reg .b32 f16pair; .reg .b16 lo_f16; .reg .b16 hi_f16;\n\t" + "and.b32 masked, $1, 0xFF;\n\t" + "mov.b32 {b16, b16_hi}, masked;\n\t" + "cvt.rn.f16x2.e4m3x2 f16pair, b16;\n\t" + "mov.b32 {lo_f16, hi_f16}, f16pair;\n\t" + "cvt.f32.f16 $0, lo_f16;\n\t" + "}" + ) + return Float32(llvm.inline_asm( + T.f32(), + [byte_i32.ir_value(loc=loc, ip=ip)], + asm, + "=f,r", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + +# --------------------------------------------------------------------------- +# 16-bit packed input PTX kit (bf16 / f16) +# +# bf16 and f16 share the same fast-path shape: packed-x2 amax via +# `max.xorsign.abs.x2`, then per-lane widen-to-f32 + `mul.f32x2` + +# `cvt.rn.satfinite.x2.f32`. Only the opcodes differ. Build one PTX kit +# per format at module load and let the kernel pick the right kit at JIT +# trace time via `cfg.DTYPE` — equivalent to a C++ template arg specialization +# on `IType`, with no runtime branch. +# --------------------------------------------------------------------------- +def _build_packed16_kit(in_fmt: str): + """Build a kit of PTX wrappers for a 16-bit input format. + + `in_fmt` is the PTX format string ('bf16' or 'f16'). Returns a namespace + with the per-format ops the rowwise/colwise inner loops need: + + abs_max_x2(Int32, Int32) -> Int32 # `max.xorsign.abs.x2` + abs_max_scalar(Int16, Int16) -> Int16 # `max.xorsign.abs.` + bits_to_f32(Int16) -> Float32 # widen one 16-bit element + x2_lo_to_f32(Int32) -> Float32 # extract+widen low half + x2_hi_to_f32(Int32) -> Float32 # extract+widen high half + mul_cvt_to_fp8x2(fp8_dtype) -> callable(Int32, Int64)->Int32 + # fused x2 * f32x2 -> fp8x2 + """ + + @dsl_user_op + def abs_max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32: + return Int32(llvm.inline_asm( + T.i32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.xorsign.abs.{in_fmt}x2 $0, $1, $2;", + "=r,r,r", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + @dsl_user_op + def max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32: + return Int32(llvm.inline_asm( + T.i32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.{in_fmt}x2 $0, $1, $2;", + "=r,r,r", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + @dsl_user_op + def abs_max_scalar(a: Int16, b: Int16, *, loc=None, ip=None) -> Int16: + return Int16(llvm.inline_asm( + T.i16(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.xorsign.abs.{in_fmt} $0, $1, $2;", + "=h,h,h", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + if in_fmt == "bf16": + # bf16 == top 16 bits of f32 — widening is a free bit-shift. + @dsl_user_op + def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32: + i32 = Int32(mlir_arith.extui( + T.i32(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip) + + @dsl_user_op + def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: + return _bitcast_i32_to_f32( + (bits & Int32(0xFFFF)) << Int32(16), loc=loc, ip=ip) + + @dsl_user_op + def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: + # `(x >> 16) << 16` ≡ `x & 0xFFFF0000`, sidestepping signed-literal + # issues. Sign bits from the arith-right shift get zeroed by the + # left shift. + return _bitcast_i32_to_f32( + (bits >> Int32(16)) << Int32(16), loc=loc, ip=ip) + + @dsl_user_op + def truncate_f32(val: Float32, *, loc=None, ip=None) -> Float32: + """Round f32 to bf16 precision (round-to-nearest-even), keep f32. + Matches C++'s `static_cast(static_cast(elt))`.""" + bf16_bits = Int16(llvm.inline_asm( + T.i16(), [val.ir_value(loc=loc, ip=ip)], + "cvt.rn.bf16.f32 $0, $1;", + "=h,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + i32 = Int32(mlir_arith.extui( + T.i32(), bf16_bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip) + else: + # f16 has its own bit layout; widening requires `cvt.f32.f16`. + @dsl_user_op + def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32: + return Float32(llvm.inline_asm( + T.f32(), [bits.ir_value(loc=loc, ip=ip)], + "cvt.f32.f16 $0, $1;", + "=f,h", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + @dsl_user_op + def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: + lo_i16 = Int16(mlir_arith.trunci( + T.i16(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return bits_to_f32(lo_i16, loc=loc, ip=ip) + + @dsl_user_op + def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: + hi_shifted = bits >> Int32(16) + hi_i16 = Int16(mlir_arith.trunci( + T.i16(), hi_shifted.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return bits_to_f32(hi_i16, loc=loc, ip=ip) + + @dsl_user_op + def truncate_f32(val: Float32, *, loc=None, ip=None) -> Float32: + """Round f32 to f16 precision, keep f32.""" + f16_bits = Int16(llvm.inline_asm( + T.i16(), [val.ir_value(loc=loc, ip=ip)], + "cvt.rn.f16.f32 $0, $1;", + "=h,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + return Float32(llvm.inline_asm( + T.f32(), [f16_bits.ir_value(loc=loc, ip=ip)], + "cvt.f32.f16 $0, $1;", + "=f,h", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + def _build_mul_cvt(out_fmt: str, relu: bool = False): + """Build a fused `x2 * f32x2 → fp8x2` PTX wrapper. + + The shape is identical across (in_fmt, out_fmt) combos — only the + widening opcode (`cvt.f32.`) and the final saturating cvt + (`cvt.rn.satfinite.x2.f32`) differ. + """ + out_op = "e4m3x2" if out_fmt == "e4m3" else "e5m2x2" + asm = ( + "{\n" + ".reg.b64 vp0; .reg.b64 vp1;\n\t" + ".reg.b32 v1; .reg.b32 v2;\n\t" + ".reg.b16 vb1; .reg.b16 vb2;\n\t" + "mov.b32 {vb1, vb2}, $1;\n\t" + f"cvt.f32.{in_fmt} v1, vb1;\n\t" + f"cvt.f32.{in_fmt} v2, vb2;\n\t" + "mov.b64 vp0, {v1, v2};\n\t" + "mul.f32x2 vp1, vp0, $2;\n\t" + "mov.b64 {v2, v1}, vp1;\n\t" + f"cvt.rn.satfinite{".relu" if relu else ""}.{out_op}.f32 $0, v1, v2;\n\t" + "}" + ) + + @dsl_user_op + def fn(val_2x: Int32, scale_2x: Int64, *, loc=None, ip=None) -> Int32: + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [val_2x.ir_value(loc=loc, ip=ip), + scale_2x.ir_value(loc=loc, ip=ip)], + asm, + "=h,r,l", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return fn + + def mul_cvt_to_fp8x2(fp8_dtype: str, relu: bool = False): + if fp8_dtype == "e5m2": + return _build_mul_cvt("e5m2", relu) + return _build_mul_cvt("e4m3", relu) + + # NVFP4 fused cast: x4 × f32x2 → fp4e2m1x4 (4 fp4 packed in 16 + # bits). Same shape as `mul_cvt_to_fp8x2` but produces 4 elements at a + # time because the `cvt.rn.satfinite.e2m1x2.f32` PTX consumes pairs and + # writes a single byte (high nibble = first input, low nibble = second). + # The shuffled `mov.b64 {v1, v0}, v01` lines after the muls undo the + # PTX's hi/lo packing so the resulting byte is naturally + # `(fp4(elt1) << 4) | fp4(elt0)` — matches TE's C++ asm. + @dsl_user_op + def mul_cvt_to_fp4x4(in_4x: Int64, scale_2x: Int64, *, loc=None, ip=None) -> Int32: + asm = ( + "{\n" + ".reg.b64 v01; .reg.b64 v23;\n\t" + ".reg.b16 i0; .reg.b16 i1; .reg.b16 i2; .reg.b16 i3;\n\t" + ".reg.b32 v0; .reg.b32 v1; .reg.b32 v2; .reg.b32 v3;\n\t" + ".reg.b8 f0; .reg.b8 f1;\n\t" + "mov.b64 {i0, i1, i2, i3}, $1;\n\t" + f"cvt.f32.{in_fmt} v0, i0;\n\t" + f"cvt.f32.{in_fmt} v1, i1;\n\t" + f"cvt.f32.{in_fmt} v2, i2;\n\t" + f"cvt.f32.{in_fmt} v3, i3;\n\t" + "mov.b64 v01, {v0, v1};\n\t" + "mov.b64 v23, {v2, v3};\n\t" + "mul.f32x2 v01, v01, $2;\n\t" + "mul.f32x2 v23, v23, $2;\n\t" + "mov.b64 {v1, v0}, v01;\n\t" + "mov.b64 {v3, v2}, v23;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 $0, {f0, f1, f0, f1};\n\t" + "}" + ) + return Int32(llvm.inline_asm( + T.i32(), + [in_4x.ir_value(loc=loc, ip=ip), scale_2x.ir_value(loc=loc, ip=ip)], + asm, + "=r,l,l", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + return SimpleNamespace( + abs_max_x2=abs_max_x2, + max_x2=max_x2, + abs_max_scalar=abs_max_scalar, + bits_to_f32=bits_to_f32, + x2_lo_to_f32=x2_lo_to_f32, + x2_hi_to_f32=x2_hi_to_f32, + truncate_f32=truncate_f32, + mul_cvt_to_fp8x2=mul_cvt_to_fp8x2, + mul_cvt_to_fp4x4=mul_cvt_to_fp4x4, + ) + + +_BF16_KIT = _build_packed16_kit("bf16") +_F16_KIT = _build_packed16_kit("f16") + + +def _is_packed16(dtype) -> bool: + """True if `dtype` is one of the 16-bit packed input formats.""" + return dtype is cutlass.BFloat16 or dtype is cutlass.Float16 + + +def _packed16_kit(dtype): + """Trace-time selector — pick a Packed16Kit for the input dtype.""" + if dtype is cutlass.Float16: + return _F16_KIT + return _BF16_KIT + + +# --------------------------------------------------------------------------- +# Forward-activation registry +# +# Each entry is a Float32 → Float32 callable applied per element before the +# MXFP8 amax + cast. Selection is by Python string at JIT trace time, so the +# const-expr machinery treats `cfg.ACTIVATION` like a C++ template argument +# — no runtime branch in the inner loop, separate kernel cached per choice. +# +# Math primitives match CUDA fast-math intrinsics so outputs are bit-exact +# with PyTorch's CUDA implementations of the same activations: +# tanh -> tanh.approx.f32 (== __tanhf) +# exp(x) -> exp2.approx.f32(x · log2(e)) (== __expf) +# --------------------------------------------------------------------------- +def _act_relu(x: Float32) -> Float32: + return cute.arch.fmax(x, Float32(0.0)) + + +def _act_gelu(x: Float32) -> Float32: + """Tanh-approximation GELU. Constants and operator grouping match TE's + `transformer_engine/common/util/math.h::gelu` exactly (factored form + `x · (0.5 + 0.5·tanh(x·(a + b·x²)))`) so quantized output is bit-exact + against the C++ fused IS_ACT path. Uses `cute.math.tanh(fastmath=False)` + rather than the `tanh.approx.f32` PTX intrinsic — TE compiles activation + kernels without `--use_fast_math` by default, so its `tanhf` is the + IEEE-precise expansion.""" + A = Float32(0.79788456) # sqrt(2/π) truncated to TE's 8-digit literal + B = Float32(0.03567741) # = sqrt(2/π) · 0.044715, same truncation + return x * (Float32(0.5) + Float32(0.5) * cute.math.tanh(x * (A + B * x * x))) + + +def _act_silu(x: Float32) -> Float32: + """SiLU/Swish: x · σ(x) = x / (1 + e^-x). + Matches TE's `silu` (`val / (1 + expf(-val))`).""" + return x / (Float32(1.0) + cute.arch.exp(-x)) + + +def _act_qgelu(x: Float32) -> Float32: + """Quick GELU: x · σ(1.702·x). Matches TE `qgelu_with_alpha(val, 1.702)` = + `cval · (1 / (1 + expf(-1.702·cval)))` (multiply by sigmoid, not a divide).""" + z = Float32(1.702) * x + return x * (Float32(1.0) / (Float32(1.0) + cute.arch.exp(-z))) + + +def _act_srelu(x: Float32) -> Float32: + """Squared ReLU: x>0 ? x·x : 0 == (max(0,x))². Matches TE `srelu`.""" + r = cute.arch.fmax(x, Float32(0.0)) + return r * r + + +SUPPORTED_ACTIVATIONS = { + "relu": _act_relu, + "gelu": _act_gelu, + "silu": _act_silu, + "qgelu": _act_qgelu, + "srelu": _act_srelu, +} + + +# --------------------------------------------------------------------------- +# Backward-activation (dact) registry +# +# Each entry is the derivative act'(x) as a Float32 → Float32 callable, matching +# the corresponding `d` in transformer_engine/common/util/math.h. The dact +# kernel computes `grad · act'(x)` per element before the MXFP8 amax + cast. +# Primitives mirror the forward registry (cute.math.tanh fastmath=False for +# gelu, cute.arch.exp for the sigmoid) so output is bit-exact with the C++ path. +# --------------------------------------------------------------------------- +@dsl_user_op +def _dact_drelu(x: Float32, *, loc=None, ip=None) -> Float32: + """drelu: x > 0 ? 1 : 0. Matches math.h `drelu` (NaN → 0 via ordered compare).""" + cond = mlir_arith.cmpf(mlir_arith.CmpFPredicate.OGT, + x.ir_value(loc=loc, ip=ip), + Float32(0.0).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + return Float32(mlir_arith.select(cond, + Float32(1.0).ir_value(loc=loc, ip=ip), + Float32(0.0).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +def _dact_dsrelu(x: Float32) -> Float32: + """dsrelu: fmax(2x, 0). Matches math.h `dsrelu`.""" + return cute.arch.fmax(Float32(2.0) * x, Float32(0.0)) + + +def _sigmoid(x: Float32) -> Float32: + """σ(x) = 1 / (1 + e^-x), same exp intrinsic as the forward silu/qgelu.""" + return Float32(1.0) / (Float32(1.0) + cute.arch.exp(-x)) + + +def _dact_dsilu(x: Float32) -> Float32: + """dsilu: x·σ(x)·(1-σ(x)) + σ(x). Matches math.h `dsilu` + (`cval·dsigmoid + sigmoid`, dsigmoid = s·(1-s)).""" + s = _sigmoid(x) + return x * (s * (Float32(1.0) - s)) + s + + +def _dact_dqgelu(x: Float32) -> Float32: + """dqgelu (alpha=1.702): a·x·dσ(a·x) + σ(a·x). Matches math.h + `dqgelu_with_alpha(val, 1.702)`.""" + a = Float32(1.702) + ax = a * x + s = _sigmoid(ax) + return a * x * (s * (Float32(1.0) - s)) + s + + +def _dact_dgelu(x: Float32) -> Float32: + """dgelu (tanh approximation). Matches math.h `dgelu` term-for-term; + same tanh argument as the forward `_act_gelu`.""" + t = cute.math.tanh( + Float32(0.79788456) * x * (Float32(1.0) + Float32(0.044715) * x * x), + fastmath=False, + ) + return (Float32(0.5) * x + * ((Float32(1.0) - t * t) * (Float32(0.79788456) + Float32(0.1070322243) * x * x)) + + Float32(0.5) * (Float32(1.0) + t)) + + +SUPPORTED_DACTIVATIONS = { + "drelu": _dact_drelu, + "dgelu": _dact_dgelu, + "dsilu": _dact_dsilu, + "dqgelu": _dact_dqgelu, + "dsrelu": _dact_dsrelu, +} + + +@dsl_user_op +def cvt_f32x2_to_fp8e4m3x2(val_hi: Float32, val_lo: Float32, relu: bool = False, + *, loc=None, ip=None) -> Int32: + """Convert two float32 values to two packed fp8e4m3fn bytes in one instruction. + + Returns an int32 where bits [7:0] = fp8(val_lo), bits [15:8] = fp8(val_hi). + This mirrors ptx::mul_cvt_2x which converts 2 values in one instruction. + """ + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def cvt_f32x2_to_fp8e5m2x2(val_hi: Float32, val_lo: Float32, relu: bool = False, + *, loc=None, ip=None) -> Int32: + """e5m2 sibling of `cvt_f32x2_to_fp8e4m3x2`.""" + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e5m2x2.f32 $0, $1, $2;", + "=h,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def mul_cvt_f32x4_to_fp4x4(in01: Int64, in23: Int64, scale_2x: Int64, + *, loc=None, ip=None) -> Int32: + """f32x4 sibling of `kit.mul_cvt_to_fp4x4` — for the NVFP4 colwise path + where elements live on a strided column and we've already widened to f32 + for the amax reduction. `in01` = pack(f32_0, f32_1), `in23` similarly.""" + asm = ( + "{\n" + ".reg.b64 v01; .reg.b64 v23;\n\t" + ".reg.b32 v0; .reg.b32 v1; .reg.b32 v2; .reg.b32 v3;\n\t" + ".reg.b8 f0; .reg.b8 f1;\n\t" + "mov.b64 {v0, v1}, $1;\n\t" + "mov.b64 {v2, v3}, $2;\n\t" + "mov.b64 v01, {v0, v1};\n\t" + "mov.b64 v23, {v2, v3};\n\t" + "mul.f32x2 v01, v01, $3;\n\t" + "mul.f32x2 v23, v23, $3;\n\t" + "mov.b64 {v1, v0}, v01;\n\t" + "mov.b64 {v3, v2}, v23;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 $0, {f0, f1, f0, f1};\n\t" + "}" + ) + return Int32(llvm.inline_asm( + T.i32(), + [in01.ir_value(loc=loc, ip=ip), + in23.ir_value(loc=loc, ip=ip), + scale_2x.ir_value(loc=loc, ip=ip)], + asm, + "=r,l,l,l", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + +def _cvt_f32_to_fp8(fp8_dtype: str): + """Const-expr dispatch: pick the f32→fp8 scalar PTX op based on output dtype. + + `fp8_dtype` is the Python string from `cfg.FP8_DTYPE`, evaluated at JIT + trace time; the unused branch is never traced. + """ + if fp8_dtype == "e5m2": + return cvt_f32_to_fp8e5m2 + return cvt_f32_to_fp8e4m3 + + +def _cvt_f32x2_to_fp8x2(fp8_dtype: str): + """Const-expr dispatch for the packed f32x2→fp8x2 cvt.""" + if fp8_dtype == "e5m2": + return cvt_f32x2_to_fp8e5m2x2 + return cvt_f32x2_to_fp8e4m3x2 + +@cute.jit +def quantize_rowwise_mxfp8( + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) + mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) + max_norm_rcp, + tile_row_start, # Int32 — global row index of this stage's row 0 + # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores + # for irregular shapes. + tile_col_start, # Int32 — global col index of this CTA's col 0 + # (= bidx * TILE_X). Same purpose. + M, N, # Int32 — full tensor extents; OOB threads skip their + # scale store. + ACTIVATION, + DTYPE, + ROWWISE, + COLWISE, + FP8_DTYPE, + TILE_Y, + SCALE_DIM, + WAVES, + THREADS_PER_WARP, + THREADS_PER_BANK, + PACK_SIZE, + WITH_ACT=False, # forward: apply activation to the element + WITH_DACT=False, # backward: out = grad · act'(act_input) + sA_tile=None, # (TILE_Y, TILE_X) activation-input smem tile (dact only) + DBIAS_REDUCTION=False, # rowwise-only dbias: accumulate per-column partials + dbias_acc=None, # rmem Float32[SCALE_DIM]; += this row's pre-truncate elt per column +): + tidx, _, _ = cute.arch.thread_idx() + + # Match the C++ reference's thread layout: pairs of adjacent lanes + # share a row (lanes 2k / 2k+1 both own row k), each pair covering + # the two 32-element scale blocks of that row. Express as a cute + # layout mapping `(tid_Y, tid_X) -> tidx` with stride (2, 1): + # linear(tidx) = tid_Y*2 + tid_X, so `get_flat_coord` inverts to + # `(tidx // 2, tidx % 2)` — semantically clearer than the raw + # divmod, and readily reusable if we later partition via TiledCopy. + # print(f"sX_tile: {sX_tile}") + # print(f"sO_row_tile: {sO_row_tile}") + # print(f"mS_row_stage: {mS_row_stage}") + + tiler, tv_layout = cute.make_layout_tv( + thr_layout=cute.make_layout((TILE_Y, 2), stride=(2, 1)), + val_layout=cute.make_layout((1, SCALE_DIM), stride=(0, 1)) + ) + # print(f"tv_layout: {tv_layout}") + # print(f"tiler: {tiler}") + + sX_tv = cute.composition(sX_tile, tv_layout) + sO_tv = cute.composition(sO_row_tile, tv_layout) + + # I/O Elements that belong to this thread + sX_thread = sX_tv[tidx, None] # shape (32,) bf16 + sO_thread = sO_tv[tidx, None] # shape (32,) uint8 + + # See https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=tv-2-%2832%2C+2%29%3A%282%2C1%29-%281%2C+32%29%3A%280%2C1%29 + # print(f"sX_thread: {sX_thread}") + # print(f"sO_thread: {sO_thread}") + + sO_thread_u32_ptr = cute.recast_ptr(sO_thread.iterator, dtype=Uint32) + # Each wave it writes 32 bytes = 8 uint32s, so in 4 waves we write all 32 quantized elements. + sO_thread_u32 = cute.make_tensor( + sO_thread_u32_ptr, + cute.make_layout((SCALE_DIM // 4,), stride=(1,)), # 1 uint32 is 4 fp8 elements + ) + # print(f"sO_thread_u32: {sO_thread_u32}") + + FUSE_RELU = cutlass.const_expr(ACTIVATION == "relu") + # For this fast paht we can read in pack of 2 instead of reading individual f16 / bf16 element. + # dbias needs the per-element fp32 values to accumulate, so it forces the slow path. + _row_fast = (_is_packed16(DTYPE) and (ACTIVATION is None or FUSE_RELU) + and not DBIAS_REDUCTION) + + if cutlass.const_expr(_row_fast): + # If no activation, f16 / bf16 and rowwise quantization, we can read 2 f16 / bf16 at once in a pack + # and use max.xorsign.abs.f16x2 / max.xorsign.abs.bf16x2 to compute + kit = _packed16_kit(DTYPE) + sX_thread_rw_i32 = cute.make_tensor( + cute.recast_ptr(sX_thread.iterator, dtype=Int32), + cute.make_layout((1, SCALE_DIM // 2), stride=(0, 1)), # 1 int32 is 2 fp16/bf16 elements + ) + # print(f"sX_thread_rw_i32: {sX_thread_rw_i32}") + # Each wave we read 2 packed i32, which is 4 fp16/bf16 elements (PACK_SIZE) + # In total we have 8 waves where each wave reads + in_r = [[None, None] for _ in range(WAVES)] + bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group + offset = bank_group * 2 # Each bank group will read 2 i32 from their bank + for w in cutlass.range_constexpr(WAVES): + idx = (w * 2 + offset) % (SCALE_DIM // 2) + in_r[w][0] = sX_thread_rw_i32[0, idx] + in_r[w][1] = sX_thread_rw_i32[0, idx + 1] + + # 1. Packed-x2 amax — 2 PTX per wave, 16 total per thread. + # Accumulates `|elt|` in both lanes (with xorsign-drifted signs); + # final horizontal max reduces the two lanes to a single f32. + amax_2x = Int32(0) + # Each wave will use max.xorsign.abs.f16x2 or max.xorsign.abs.bf16x2 to compare 2 packed elements in parallel + for w in cutlass.range_constexpr(WAVES): + if cutlass.const_expr(FUSE_RELU): + # If we fuse relu then we don't want to do abs since negative value will be set to 0 and they will lose comparison automatically + amax_2x = kit.max_x2(amax_2x, in_r[w][0]) + amax_2x = kit.max_x2(amax_2x, in_r[w][1]) + else: + amax_2x = kit.abs_max_x2(amax_2x, in_r[w][0]) + amax_2x = kit.abs_max_x2(amax_2x, in_r[w][1]) + if cutlass.const_expr(FUSE_RELU): + # Compare the 2 packed max without abs + amax_r = cute.arch.fmax( + kit.x2_lo_to_f32(amax_2x), + kit.x2_hi_to_f32(amax_2x), + ) + # For relu the max is at least 0 + if cutlass.const_expr(FUSE_RELU): + amax_r = cute.arch.fmax(amax_r, Float32(0.0)) + else: + # Compare the 2 packed abs max + amax_r = cute.arch.fmax( + fabs_f32(kit.x2_lo_to_f32(amax_2x)), + fabs_f32(kit.x2_hi_to_f32(amax_2x)), + ) + else: + # Since we need to do computation on individual f16 / bf16 elements, we can't read in pack + sX_thread_rw = cute.make_tensor( + sX_thread.iterator, + cute.make_layout((1, SCALE_DIM), stride=(0, 1)), + ) + in_r = [[None] * PACK_SIZE for _ in range(WAVES)] + bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group + offset = bank_group * 4 # Each bank group will read 4 f16 from their bank + + if cutlass.const_expr(WITH_DACT): + # Backward: out = grad · act'(act_input). sX is grad, sA is act_input. + dop = SUPPORTED_DACTIVATIONS[ACTIVATION] + sA_thread = cute.composition(sA_tile, tv_layout)[tidx, None] + sA_thread_rw = cute.make_tensor( + sA_thread.iterator, + cute.make_layout((1, SCALE_DIM), stride=(0, 1)), + ) + elif cutlass.const_expr(WITH_ACT): + op = SUPPORTED_ACTIVATIONS[ACTIVATION] + + if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): + kit_act = _packed16_kit(DTYPE) + amax_r = Float32(0.0) + for w in cutlass.range_constexpr(WAVES): + idx = (w * PACK_SIZE + offset) % SCALE_DIM + for e in cutlass.range_constexpr(PACK_SIZE): + x = Float32(sX_thread_rw[0, idx + e]) # grad + if cutlass.const_expr(WITH_DACT): + # out = grad · act'(act_input) + x = x * dop(Float32(sA_thread_rw[0, idx + e])) + # If IS_ACT, apply activation function to x in f32 + elif cutlass.const_expr(WITH_ACT): + # If it's relu, we can handle it later + if not cutlass.const_expr(FUSE_RELU): + x = op(x) + # dbias: accumulate this row's column (idx+e) value BEFORE the bf16 + # truncation (matches CUDA's `thread_dbias_rowwise[j] += elt`). idx+e + # is a multiple-of-PACK_SIZE group + e, so it stays within [0, SCALE_DIM). + if cutlass.const_expr(DBIAS_REDUCTION): + dbias_acc[idx + e] = dbias_acc[idx + e] + x + # If 16-bit input with activation, truncate to IType + if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): + x = kit_act.truncate_f32(x) # TODO: Why not just qunatize from f32? + in_r[w][e] = x + if cutlass.const_expr(FUSE_RELU): + amax_r = cute.arch.fmax(amax_r, x) # For relu cases, we don't need abs since negative values will be 0 so they lose comparison automatically + else: + amax_r = cute.arch.fmax(amax_r, fabs_f32(x)) + if cutlass.const_expr(FUSE_RELU): + amax_r = cute.arch.fmax(amax_r, Float32(0.0)) # If relu, the amax is at least 0 + + # 2. E8M0 scale → gmem. mS_row's layout already encodes the swizzle + # when cfg.WITH_GEMM_SWIZZLED_SCALES=True, so 2D access just works. + biased_exp_r = float_to_e8m0(amax_r * max_norm_rcp) + # mS_row_stage has logical shape (32, 2) and we have 64 threads where each is mapped to one scale factor + # The TV layout is equivalent to https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=tv-2-%2832%2C+2%29%3A%282%2C+1%29-%281%29 + # but it's too trival so let's just index it directly without using layout + # Note this is the logical layout, which is on top of the swizzled / non-swizzled scale factor layout that mappes the logical index to the physical offset + # Irregular shapes: skip the scale store if this thread's logical row / + # col-block lies past the input's actual extents. TMA already zero-fills + # OOB input reads and drops OOB output writes; only the direct scale-byte + # gmem store needs an explicit guard. + scale_row = tile_row_start + tidx // 2 + scale_col_first_elt = tile_col_start + (tidx % 2) * SCALE_DIM + if scale_row < M and scale_col_first_elt < N: + mS_row_stage[(tidx // 2, tidx % 2)] = Uint8(biased_exp_r) + + # 3. scale + packed fp8 cast → smem as one u32 per wave. + inv_scale_r = exp2f_rcp(biased_exp_r) # f32 reciprocal of the scale + # Fetch the conversion function based on the FP8 format + cvt_f32x2 = _cvt_f32x2_to_fp8x2(FP8_DTYPE) + if cutlass.const_expr(_row_fast): + kit_cast = _packed16_kit(DTYPE) + mul_cvt_x2 = kit_cast.mul_cvt_to_fp8x2(FP8_DTYPE, FUSE_RELU) + # Pack `(inv_scale_r, inv_scale_r)` as a single 64-bit f32x2 once; + # the per-wave mul_cvt consumes this directly. + scale_2x = pack_f32x2(inv_scale_r, inv_scale_r) + + bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group + offset = bank_group * 4 # Each bank group will write 4 fp8 to + for w in cutlass.range_constexpr(WAVES): + idx = (w * 4 + offset) % SCALE_DIM + idx = idx // 4 + if cutlass.const_expr(_row_fast): + # One fused PTX per x2 pair: x2 × f32x2 → fp8x2. + # Byte layout: byte[0]=fp8(lo * s), byte[1]=fp8(hi * s). + p01 = mul_cvt_x2(in_r[w][0], scale_2x) + p23 = mul_cvt_x2(in_r[w][1], scale_2x) + else: + # cvt PTX semantics: `cvt.rn.satfinite..f32 d, a, b` gives + # d[15:8]=fp8(a), d[7:0]=fp8(b). Pass (v1, v0) so the u16 low + # byte ends up as fp8(v0) and the high byte as fp8(v1). + v0 = in_r[w][0] * inv_scale_r + v1 = in_r[w][1] * inv_scale_r + v2 = in_r[w][2] * inv_scale_r + v3 = in_r[w][3] * inv_scale_r + p01 = cvt_f32x2(v1, v0, FUSE_RELU) # u16 little-endian: v0,v1 + p23 = cvt_f32x2(v3, v2, FUSE_RELU) # u16 little-endian: v2,v3 + quad = (p23 << Int32(16)) | p01 + sO_thread_u32[idx] = Uint32(quad) + + return amax_r + +@cute.jit +def quantize_colwise_mxfp8( + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) + mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) + max_norm_rcp, + tile_row_start, # Int32 — global row index of this stage's row 0 + # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores + # for irregular shapes. + tile_col_start, # Int32 — global col index of this CTA's col 0 + # (= bidx * TILE_X). + M, N, # Int32 — full tensor extents. + ACTIVATION, + DTYPE, + FP8_DTYPE, + SWIZZLE, + TILE_X, + TILE_Y, + SCALE_DIM, + WITH_ACT=False, # forward: apply activation to the element + WITH_DACT=False, # backward: out = grad · act'(act_input) + sA_tile=None, # (TILE_Y, TILE_X) activation-input smem tile (dact only) + WITH_DBIAS=False, # also return this thread's column sum (pre-truncate) +): + tidx, _, _ = cute.arch.thread_idx() + + # print(f"sX_tile: {sX_tile}") + # print(f"sO_col_tile: {sO_col_tile}") + # print(f"mS_col_stage: {mS_col_stage}") + + tiler, tv_layout = cute.make_layout_tv( + thr_layout=cute.make_layout((1, TILE_X), stride=(TILE_X, 1)), + val_layout=cute.make_layout((SCALE_DIM, 1), stride=(1, 1)) + ) + # print(f"tv_layout: {tv_layout}") + + sX_tv = cute.composition(sX_tile, tv_layout) + sO_tv = cute.composition(sO_col_tile, tv_layout) + + # I/O Elements that belong to this thread + sX_thread = sX_tv[tidx, None] + sO_thread = sO_tv[tidx, None] + + # See https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=tv-2-%281%2C+64%29%3A%2864%2C+1%29-%2832%2C+1%29%3A%281%2C+1%29 + # print(f"sX_thread: {sX_thread}") # shape (32,) bf16 + # print(f"sO_thread: {sO_thread}") # shape (32,) uint8 + + # dbias needs the per-element fp32 values to sum, so it takes the f32 path + # (never the i16 fast path) — matching CUDA, whose f16 fast path requires + # `!IS_DBIAS` (quantize_mxfp8.cuh:219). + HALF_PRECISION_PATH = _is_packed16(DTYPE) and ACTIVATION is None and not WITH_DBIAS + dbias_partial = Float32(0.0) + + # 0. Load the 32-element column from smem into registers once (matches + # C++'s `in_colwise_IType[i]` cache). Amax and cast both reuse these. + if cutlass.const_expr(HALF_PRECISION_PATH): + kit = _packed16_kit(DTYPE) + # Per-thread Int16 view of the column. Same byte address as + # `sX_thread` (bf16/fp16 are 16-bit, same width as Int16); the + # element stride is TILE_X because the column elements are + # TILE_X apart in the row-major tile. + sX_thread_i16 = cute.make_tensor( + cute.recast_ptr(sX_thread.iterator, dtype=Int16), + cute.make_layout((SCALE_DIM,), stride=(TILE_X,)), + ) + amax_bits = Int16(0) + for i in cutlass.range_constexpr(SCALE_DIM): + amax_bits = kit.abs_max_scalar(amax_bits, sX_thread_i16[i]) + amax_c = fabs_f32(kit.bits_to_f32(amax_bits)) + else: + # Materialize the column into f32 registers — widen on read so + # bf16/fp16 inputs become real fp32 values (a pointer recast to + # Float32 would not widen; it would reinterpret the 16-bit bytes + # as half of a 32-bit float). + sX_thread_f32 = cute.make_rmem_tensor( + layout_or_shape=cute.make_layout((SCALE_DIM,), stride=(1,)), + dtype=Float32, + ) + for i in cutlass.range_constexpr(SCALE_DIM): + sX_thread_f32[i] = Float32(sX_thread[i]) + # Apply activation (fwd) or grad·act'(act_input) (bwd dact) in f32. + if cutlass.const_expr(WITH_DACT): + dop = SUPPORTED_DACTIVATIONS[ACTIVATION] + sA_thread = cute.composition(sA_tile, tv_layout)[tidx, None] + for i in cutlass.range_constexpr(SCALE_DIM): + sX_thread_f32[i] = sX_thread_f32[i] * dop(Float32(sA_thread[i])) + elif cutlass.const_expr(WITH_ACT): + op = SUPPORTED_ACTIVATIONS[ACTIVATION] + for i in cutlass.range_constexpr(SCALE_DIM): + sX_thread_f32[i] = op(sX_thread_f32[i]) + # dbias = column sum of the (post-act/dact) value, taken BEFORE the bf16 + # truncation — matches CUDA's `partial_dbias_colwise += elt`. + if cutlass.const_expr(WITH_DBIAS): + for i in cutlass.range_constexpr(SCALE_DIM): + dbias_partial += sX_thread_f32[i] + # Numerical truncation through IType so amax/cast match C++. + # Only needed when 16-bit input + activation; without activation + # the widening was already exact. + if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): + kit_act = _packed16_kit(DTYPE) + for i in cutlass.range_constexpr(SCALE_DIM): + sX_thread_f32[i] = kit_act.truncate_f32(sX_thread_f32[i]) + amax_c = Float32(0.0) + for i in cutlass.range_constexpr(SCALE_DIM): + amax_c = cute.arch.fmax(amax_c, fabs_f32(sX_thread_f32[i])) + + # 2. E8M0 scale → gmem. mS_col's layout already encodes the swizzle + # when cfg.WITH_GEMM_SWIZZLED_SCALES=True, so 2D access just works. + # Irregular shapes: skip when this stage's row range or this thread's + # column lies past the input extents. TILE_Y == SCALE_DIM so each stage + # is exactly one scale-row; valid iff `tile_row_start < M`. + biased_exp_c = float_to_e8m0(amax_c * max_norm_rcp) + scale_col = tile_col_start + tidx + if tile_row_start < M and scale_col < N: + if cutlass.const_expr(SWIZZLE): + mS_col_stage[(0, tidx % 32, tidx // 32)] = Uint8(biased_exp_c) + else: + mS_col_stage[(0, tidx)] = Uint8(biased_exp_c) + + # 3. scale + FP8 cast → smem (one byte per (row, tidx)). Caller + # flushes the whole (TILE_Y, TILE_X) tile with a TMA S2G. + inv_scale_c = exp2f_rcp(biased_exp_c) + cvt_to_fp8 = _cvt_f32_to_fp8(FP8_DTYPE) + if cutlass.const_expr(HALF_PRECISION_PATH): + kit_cast = _packed16_kit(DTYPE) + for i in cutlass.range_constexpr(SCALE_DIM): + v_f32 = kit_cast.bits_to_f32(sX_thread_i16[i]) + sO_thread[i] = Uint8(cvt_to_fp8(v_f32 * inv_scale_c)) + else: + for i in cutlass.range_constexpr(SCALE_DIM): + sO_thread[i] = Uint8(cvt_to_fp8(sX_thread_f32[i] * inv_scale_c)) + + return amax_c, dbias_partial diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py new file mode 100644 index 0000000000..2a57bfb4f4 --- /dev/null +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -0,0 +1,1021 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""MXFP8 quantization kernel implemented in CuTeDSL. + +Replicates the core logic of quantize_mxfp8.cuh: given a 2D tensor of BF16/FP16 +values, quantize to MXFP8 format (FP8E4M3 data + E8M0 per-block scales). + +Matches the C++ kernel's tile dimensions and thread layout: + CHUNK_DIM_Y = 64, CHUNK_DIM_X = 64, THREADS_PER_CHUNK = 64 + BUFF_DIM_Y = 32, BUFF_DIM_X = 64, STAGES = 2 + SCALE_DIM = 32 (elements per MXFP8 scaling block) + +Grid: (ceil(N / 64), ceil(M / 64)) +Each block processes a 64x64 chunk in 2 stages of 32x64 tiles loaded into +shared memory. +""" +import logging + +import transformer_engine +from transformer_engine.common.CuTeDSL.utils import str_to_cutlass_dtype +import transformer_engine_torch as tex + +from typing import Optional, Type + +import torch +import transformer_engine_torch as tex + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +from cutlass import Float32, Int64, Int32, Int16, Uint8, Uint32 +from cuda.bindings.driver import CUstream + +import hashlib +import tvm_ffi + +from .mxfp8_utils import ( + SUPPORTED_ACTIVATIONS, + SUPPORTED_DACTIVATIONS, + FP8E4M3_MAX_NORM_RCP, + FP8E5M2_MAX_NORM_RCP, + _bitcast_f32_to_i32, + _cvt_f32_to_fp8, + _cvt_f32x2_to_fp8x2, + _is_packed16, + _packed16_kit, + exp2f_rcp, + fabs_f32, + float_to_e8m0, + quantize_colwise_mxfp8, + quantize_rowwise_mxfp8, +) + +# Per-backend logger, so a fallback warning is attributable to *this* CuTeDSL +# backend (the MXFP8 quantize backend). Other CuTeDSL backends should use their +# own `transformer_engine.cutedsl.` logger. +logger = logging.getLogger("transformer_engine.cutedsl.mxfp8") + +# MXFP8 settings +MXFP8_BLOCK_SIZE = 32 # Number of elements per MXFP8 scale block. They will share the same E8M0 scale factor +SCALE_DIM = MXFP8_BLOCK_SIZE + +# Double-buffering for async copy + compute overlap +BUFFER_NUM = 2 + +# Vectorised access constants for bank-conflict avoidance (rowwise pass) +PACK_SIZE = 4 # Elements per vector load +WAVES = SCALE_DIM // PACK_SIZE # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total +THREADS_PER_WARP = 32 +TOTAL_BANKS_WIDTH = (32 * 4) // 1 # 32 banks × 4 bytes, in bytes (uint8 stride) +THREADS_PER_BANK = TOTAL_BANKS_WIDTH // SCALE_DIM # 4 threads per bank + +# Tiling sizes +NUM_STAGES = 2 # Pipeline depth of the producer/consumer ring buffer for the TMA-G2S input loads (PipelineTmaAsync stage count) +NUM_TILES = 2 # Each CTA process 2 tiles along the Y (row, slowest-changing) dimension +TILE_Y = 32 # Each tile has 32 rows, so each CTA handles 32 * 2 rows in total +TILE_X = 64 # Each tile has 64 columns + +# CTA size +THREADS_PER_CHUNK = 64 +NUM_WARPS = THREADS_PER_CHUNK // 32 + +# --------------------------------------------------------------------------- +# Kernel configuration +# --------------------------------------------------------------------------- +class MXFP8QuantizeConfig: + + def __init__( + self, + dtype: str, + fp8_dtype: str, + rowwise: bool, + colwise: bool, + with_gemm_swizzled_scales: bool, + with_amax: bool, + with_dbias: bool = False, + with_dact: bool = False, + with_act: bool = False, + with_noop: bool = False, + activation: Optional[str] = None + ): + if dtype is None or dtype not in ("fp32", "fp16", "bf16"): + raise ValueError(f"unknown input dtype {dtype!r}; expected fp32|fp16|bf16") + self.DTYPE = str_to_cutlass_dtype(dtype) + self.DTYPE_STR = dtype # readable input-dtype token, for __str__ + if fp8_dtype not in ("e4m3", "e5m2"): + raise ValueError(f"unknown FP8 dtype {fp8_dtype!r}; expected 'e4m3' or 'e5m2'") + self.FP8_DTYPE = fp8_dtype + self.ROWWISE = rowwise + self.COLWISE = colwise + if not (rowwise or colwise): + raise ValueError("at least one of rowwise or colwise must be true") + self.WITH_GEMM_SWIZZLED_SCALES = with_gemm_swizzled_scales + self.WITH_AMAX = with_amax + if not with_dact and not with_act: + if activation == "none": + self.ACTIVATION = None + else: + raise ValueError("activation must be none when with_dact and with_act are both False") + else: + if with_dact and with_act: + raise ValueError("with_dact and with_act cannot be true at the same time since they are used for different paths (bwd vs fwd)") + elif with_dact: + if activation in SUPPORTED_DACTIVATIONS: + self.ACTIVATION = activation + else: + raise ValueError(f"unknown activation {activation!r} for with_dact=True; expected one of {sorted(SUPPORTED_DACTIVATIONS)}") + elif with_act: + if activation in SUPPORTED_ACTIVATIONS: + self.ACTIVATION = activation + else: + raise ValueError(f"unknown activation {activation!r} for with_act=True; expected one of {sorted(SUPPORTED_ACTIVATIONS)}") + self.WITH_DACT = with_dact + self.WITH_ACT = with_act + # dbias is the column reduction of the (post-act/dact) element. With colwise + # output each thread owns a full column (trivial reduction); rowwise-only + # uses a cross-thread smem reduction over THREADS_Y. Both mirror the CUDA + # kernel's COLWISE_SCALING / rowwise dbias branches. + self.WITH_DBIAS = with_dbias + self.WITH_NOOP = with_noop + self.MAX_NORM_RCP = FP8E4M3_MAX_NORM_RCP if fp8_dtype == "e4m3" else FP8E5M2_MAX_NORM_RCP + + def __str__(self): + return (f"MXFP8QuantizeConfig(dtype={self.DTYPE_STR}, fp8_dtype={self.FP8_DTYPE}, " + f"rowwise={self.ROWWISE}, colwise={self.COLWISE}, " + f"swizzled={self.WITH_GEMM_SWIZZLED_SCALES}, with_amax={self.WITH_AMAX}, " + f"with_dbias={self.WITH_DBIAS}, with_dact={self.WITH_DACT}, " + f"with_act={self.WITH_ACT}, with_noop={self.WITH_NOOP}, " + f"activation={self.ACTIVATION})") + + __repr__ = __str__ + +# --------------------------------------------------------------------------- +# Unified MXFP8 quantization kernel — shared memory tiled, single-pass +# --------------------------------------------------------------------------- +class MXFP8QuantizeSmemKernel: + """MXFP8 quantization with shared-memory tiling (rowwise, colwise, or both). + + Matches C++ kernel's BIDIMENSIONAL scaling mode: + Grid (ceil(N/64), ceil(M/64)) + Block (64) + Each block processes a 64x64 chunk in 2 stages of 32x64. + + Per stage, the tile is loaded into shared memory once. The colwise + pass reads columns from smem first, then the rowwise pass reads rows. + When both directions are enabled, global memory is read only once per + element — matching the C++ single-pass behaviour. + + Thread mappings (per stage): + Colwise: thread tidx handles column tidx, 32 rows (stride BUFF_DIM_X). + Rowwise: tid_Y = tidx // 2 -> row, tid_X = tidx % 2 -> scale-block. + """ + + def __init__(self, cfg): + self.cfg = cfg + + @cute.jit + def __call__( + self, + mX: cute.Tensor, # Input tensor to quantize + mO_row: Optional[cute.Tensor], mS_row: Optional[cute.Tensor], # Rowwise output and scale tensors + mO_col: Optional[cute.Tensor], mS_col: Optional[cute.Tensor], # Colwise output and scale tensors + mAmax: Optional[cute.Tensor], # Global amax accumulator, only used in WITH_AMAX path + mNoop: Optional[cute.Tensor], # 1-element cast_noop flag, only used in WITH_NOOP path + # Backward-only slots, present to mirror the CUDA mxfp8::quantize signature + # (act_input / dbias / workspace). NOT used yet — None on the forward path; + # WITH_DACT/WITH_DBIAS configs are rejected upstream so these never carry data. + mActInput: Optional[cute.Tensor], + mDbias: Optional[cute.Tensor], + mWorkspace: Optional[cute.Tensor], + stream: CUstream, + ): + M = mX.shape[0] + N = mX.shape[1] + cfg = self.cfg + max_norm_rcp = cfg.MAX_NORM_RCP + num_scale_cols = N // SCALE_DIM + num_scale_rows = M // SCALE_DIM + + # Rewrap mS_row / mS_col with the GEMM-swizzled layout when requested. + # Wrapper passes in a tensor with the compact (M, N/32):(N/32, 1) layout + # (built from a compact fake-ptr at compile time), and we re-view the + # underlying buffer here so the per-block scale stores below land at the + # cuBLAS-swizzled byte offsets. + # See https://docs.nvidia.com/cuda/cublas/#d-block-scaling-factors-layout + # and swizzle_demo.svg for a visual of the byte permutation. + if cutlass.const_expr(cfg.WITH_GEMM_SWIZZLED_SCALES): + num_tiles_M = (M + 127) // 128 + num_tiles_SC = (num_scale_cols + 3) // 4 # = ceil(N / 128) + num_tiles_SR = (num_scale_rows + 3) // 4 # = ceil(M / 128) + num_tiles_N = (N + 127) // 128 + # row i = i_lo + 32 * (i_hi + 4 * tile_Y); col j = j_lo + 4 * tile_X. + # Within one 128×4 tile: byte = i_lo*16 + i_hi*4 + j_lo. + + # Tile-major outer dims add (tile_Y * num_tiles_SC + tile_X) * 512. + # For example, if M=256, N=512, then num_scale_cols = 16, num_scale_rows = 8, and num_tiles_M=2, num_tiles_SC=4, num_tiles_SR=2, num_tiles_N=4 + # The swizzled layout is ((32, 4, 2), (4, 4)):((16, 4, 2048), (1, 512)) + if cutlass.const_expr(cfg.ROWWISE): + mS_row = cute.make_tensor( + mS_row.iterator, + cute.make_layout( + ((32, 4, num_tiles_M), (4, num_tiles_SC)), + stride=((16, 4, num_tiles_SC * 512), (1, 512)), + ), + ) + # Colwise: same swizzle, axes swap roles — col axis gets the 32×4 + # inner decomp, scale-row axis gets the 4-extent dim. + if cutlass.const_expr(cfg.COLWISE): + mS_col = cute.make_tensor( + mS_col.iterator, + cute.make_layout( + ((4, num_tiles_SR), (32, 4, num_tiles_N)), + stride=((1, 512), (16, 4, num_tiles_SR * 512)), + ), + ) + + # Divide by the STAGE tile (TILE_Y, TILE_X // SCALE_DIM), not the CTA + # tile. Each CTA owns NUM_TILES consecutive row-tiles; the kernel walks + # them by indexing GRID's row dim with `bidy * NUM_TILES + stage` (cute + # auto-decomposes a flat coord onto GRID's hierarchical row modes). + # + # Critically, this is the only divide that cleanly cuts both layouts: + # - compact `(M, N/32):(N/32, 1)` → SCALE_TILE = (32, 2):(N/32, 1) + # - swizzled `((32,4,n_M),(4,n_SC)):((16,4,n_SC·512),(1,512))` + # → SCALE_TILE = (32, 2):(16, 1) + # The bigger (TILE_Y * NUM_TILES, ...) divide we used before tangles the + # swizzle's (32, 4) row hierarchy under flatten + sub-divide chain. + + # Declare TMA descriptors on the host side. + # make_tiled_tma_atom returns the UNTILED gmem tensor with basis strides. + # Tile it inside the kernel with zipped_divide so each coord selects + # one (TILE_Y, TILE_X) tile. + smem_tile_layout = cute.make_ordered_layout((TILE_Y, TILE_X), order=(1, 0)) + cta_tiler = (TILE_Y, TILE_X) + + # Input: TMA G2S (bf16/fp16 → smem). + op_load = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() + tma_atom, tma_src = cute.nvgpu.cpasync.make_tiled_tma_atom( + op_load, mX, smem_tile_layout, cta_tiler, num_multicast=1, + ) + + # Backward (dact): the activation input is a second G2S load, identical to + # mX's. The kernel computes `grad · act'(act_input)`; here mX carries grad. + tma_atom_act = None + tma_src_act = None + if cutlass.const_expr(cfg.WITH_DACT): + tma_atom_act, tma_src_act = cute.nvgpu.cpasync.make_tiled_tma_atom( + op_load, mActInput, smem_tile_layout, cta_tiler, num_multicast=1, + ) + + # Output: TMA S2G (uint8 smem → gmem) for both directions. Creating + # both atoms unconditionally — if a direction is disabled the kernel + # simply won't dispatch its copy, and the atom cost is negligible. + op_store = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() + out_smem_layout = cute.make_ordered_layout((TILE_Y, TILE_X), order=(1, 0)) + tma_atom_out_row = None + tma_dst_out_row = None + tma_atom_out_col = None + tma_dst_out_col = None + if cutlass.const_expr(cfg.ROWWISE): + tma_atom_out_row, tma_dst_out_row = cute.nvgpu.cpasync.make_tiled_tma_atom( + op_store, mO_row, out_smem_layout, cta_tiler, num_multicast=1, + ) + if cutlass.const_expr(cfg.COLWISE): + tma_atom_out_col, tma_dst_out_col = cute.nvgpu.cpasync.make_tiled_tma_atom( + op_store, mO_col, out_smem_layout, cta_tiler, num_multicast=1, + ) + + # Decide when to perform dbias reduction + DBIAS_REDUCTION_COLWISE: cutlass.Constexpr = False + DBIAS_REDUCTION_ROWWISE: cutlass.Constexpr = False + if cutlass.const_expr(cfg.WITH_DBIAS): + # We prefer to perform dbias reduction in the colwise pass since it doesn't require shuffle + if cutlass.const_expr(cfg.COLWISE): + DBIAS_REDUCTION_COLWISE = True + else: + DBIAS_REDUCTION_ROWWISE = True + + # CUDA launches in (0,0), (1,0), (2,0)... order, so we should make N the leading dimension for better access pattern + # So consecutive blocks will move along the N dimension first, which is the innermost dimension in memory and we can use cache better + grid = [ + cute.ceil_div(Int32(N), TILE_X), + cute.ceil_div(M, TILE_Y * NUM_TILES), + ] + block = [THREADS_PER_CHUNK,] + + self.kernel( + mX, mS_row, mS_col, mAmax, mNoop, mWorkspace, + max_norm_rcp, mX.element_type, + tma_atom, tma_src, + tma_atom_out_row, tma_dst_out_row, + tma_atom_out_col, tma_dst_out_col, + tma_atom_act, tma_src_act, + ).launch( + grid=grid, + block=block, + stream=stream, + ) + + # Device entry (launched by __call__). Reads the cast_noop flag and runs the + # work only if it is not set — matching the CUDA kernel's + # `if (noop[0]==1.0f) return;`. When WITH_NOOP is off, mNoop is None and the + # whole check is compiled out (so no flag is read). + @cute.kernel + def kernel( + self, + mX, + mS_row, + mS_col, + mAmax, + mNoop, + mWorkspace, + max_norm_rcp, + dtype: cutlass.Constexpr[Type[cutlass.Numeric]], + tma_atom, tma_src, # how to use TMA to copy the input + tma_atom_out_row, tma_dst_out_row, # how to use TMA to copy the rowwise output + tma_atom_out_col, tma_dst_out_col, # how to use TMA to copy the colwise output + tma_atom_act, tma_src_act, # dact only: how to copy the activation input + ): + cfg = self.cfg + # `not const_expr(WITH_NOOP)` is a compile-time True when noop is disabled, + # so Python short-circuits the `or` and never reads mNoop[0] (it is None). + if not cutlass.const_expr(cfg.WITH_NOOP) or mNoop[0] != Float32(1.0): + self._kernel_main( + mX, mS_row, mS_col, mAmax, mWorkspace, + max_norm_rcp, dtype, + tma_atom, tma_src, + tma_atom_out_row, tma_dst_out_row, + tma_atom_out_col, tma_dst_out_col, + tma_atom_act, tma_src_act, + ) + + # The actual quantize work. MUST be @cute.jit (not @cute.kernel): it is invoked + # from the @cute.kernel `kernel` wrapper under a runtime noop branch, and only a + # separately-traced @cute.jit callable may allocate shared memory inside such a + # branch (an inlined/undecorated method or a nested @cute.kernel would fail). + @cute.jit + def _kernel_main( + self, + mX, + mS_row, + mS_col, + mAmax, + mWorkspace, + max_norm_rcp, + dtype: cutlass.Constexpr[Type[cutlass.Numeric]], + tma_atom, tma_src, # how to use TMA to copy the input + tma_atom_out_row, tma_dst_out_row, # how to use TMA to copy the rowwise output + tma_atom_out_col, tma_dst_out_col, # how to use TMA to copy the colwise output + tma_atom_act, tma_src_act, # dact only: how to copy the activation input + ): + cfg = self.cfg + + if cutlass.const_expr(cfg.ROWWISE): + mS_row = cute.zipped_divide(mS_row, (TILE_Y, TILE_X // SCALE_DIM)) + if cutlass.const_expr(cfg.COLWISE): + mS_col = cute.zipped_divide(mS_col, (TILE_Y // SCALE_DIM, TILE_X)) + # For M=256, N=512: + # Non-swizzled: https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=zipped_divide-%28256%2C+16%29%3A%2816%2C+1%29-32%0A2 + # Swizzled: https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=zipped_divide-%28%2832%2C+4%2C+2%29%2C+%284%2C+4%29%29%3A%28%2816%2C+4%2C+2048%29%2C+%281%2C+512%29%29-32%0A2 + # print(f"mS_row after zipped_divide: {mS_row}") + + # FP8 output smem, one 32×64 tile per stage per enabled direction. + # Allocating a dead sO_col in rowwise-only (or sO_row in colwise-only) + # bumps per-CTA smem from 12 KB to 16 KB, which drops occupancy and + # regresses the single-direction path by ~8-10% at 16384^2. Match + # C++ and only allocate what the active pass actually uses. + # sAmax holds one f32 per warp for the cross-warp amax reduction — + # negligible (8 bytes for NUM_WARPS=2) and we always allocate so the + # struct doesn't fork on a 4th const-expr (cfg.WITH_AMAX) dimension. + if cutlass.const_expr(cfg.ROWWISE and cfg.COLWISE): + @cute.struct + class SharedStorage: + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + sX: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_row: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_col: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + elif cutlass.const_expr(cfg.ROWWISE and not cfg.COLWISE): + @cute.struct + class SharedStorage: + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + sX: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_row: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + elif cutlass.const_expr(cfg.ROWWISE): + @cute.struct + class SharedStorage: + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + sX: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_row: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + else: + @cute.struct + class SharedStorage: + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + sX: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sO_col: cute.struct.Align[ + cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + + # dact: the activation-input tile lives in its own smem buffer, same + # shape/layout as sX. Allocated separately so the 4 SharedStorage variants + # above don't have to fork again on WITH_DACT. + if cutlass.const_expr(cfg.WITH_DACT): + @cute.struct + class DactStorage: + sActInput: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + dact_storage = smem.allocate(DactStorage) + sActInput = dact_storage.sActInput.get_tensor( + cute.make_layout( + ((TILE_Y, TILE_X), NUM_STAGES), + stride=((TILE_X, 1), TILE_Y * TILE_X), + ) + ) + + # Rowwise-only dbias needs a cross-thread (over THREADS_Y) smem reduction, + # since each rowwise thread owns a row, not a column. Buffer is + # [THREADS_Y][THREADS_X*(SCALE_DIM+1)] f32 — the +1 per scale-block padding + # avoids bank conflicts, matching CUDA's DBIAS_BUFF_WIDTH. + DBIAS_REDUCTION_ROWWISE = cutlass.const_expr(cfg.WITH_DBIAS and not cfg.COLWISE) + DBIAS_BUFF_WIDTH = (TILE_X // SCALE_DIM) * (SCALE_DIM + 1) + if cutlass.const_expr(DBIAS_REDUCTION_ROWWISE): + @cute.struct + class DbiasStorage: + sDbias: cute.struct.MemRange[Float32, TILE_Y * DBIAS_BUFF_WIDTH] + dbias_storage = smem.allocate(DbiasStorage) + sDbias = dbias_storage.sDbias.get_tensor( + cute.make_layout(TILE_Y * DBIAS_BUFF_WIDTH) + ) + + # Per-stage shmem tile is 2D (TILE_Y, TILE_X); stages laid out back-to-back. + # Mode 0 is hierarchical ((TILE_Y, TILE_X),) so it matches the rank/shape + # of gX_tiled[(None, (ty, tx))] produced by zipped_divide. + # sX[(None, stage)] selects one (TILE_Y, TILE_X) tile. + sX = storage.sX.get_tensor( + cute.make_layout( + ((TILE_Y, TILE_X), NUM_STAGES), + stride=((TILE_X, 1), TILE_Y * TILE_X), + ) + ) + if cutlass.const_expr(cfg.ROWWISE): + sO_row = storage.sO_row.get_tensor( + cute.make_layout( + ((TILE_Y, TILE_X), NUM_STAGES), + stride=((TILE_X, 1), TILE_Y * TILE_X), + ) + ) + if cutlass.const_expr(cfg.COLWISE): + sO_col = storage.sO_col.get_tensor( + cute.make_layout( + ((TILE_Y, TILE_X), NUM_STAGES), + stride=((TILE_X, 1), TILE_Y * TILE_X), + ) + ) + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # Prefetch TMA descriptor (one-time; warp-0 only). + if warp_idx == 0: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom) + if cutlass.const_expr(cfg.WITH_DACT): + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_act) + + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, _ = cute.arch.block_idx() + + # Producer: `arrive_and_expect_tx` is wrapped in `elect_one`, so only + # one lane of warp 0 arrives on the full barrier per stage → arrive_count=1. + producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) + # Consumer: `consumer_release` arrives only on the `is_signalling_thread` + # (lane 0 of each warp), so arrive_count = num_warps per stage. + num_warps = THREADS_PER_CHUNK // 32 + consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_warps) + + # Bytes transferred per TMA copy: one (TILE_Y, TILE_X) tile of dtype. + # dact loads two tiles (grad + act_input) under the same per-stage barrier, + # so the barrier must expect both copies' bytes. + tx_count = TILE_Y * TILE_X * dtype.width // 8 + if cutlass.const_expr(cfg.WITH_DACT): + tx_count *= 2 + + mainloop_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=storage.mbar_storage.data_ptr(), + num_stages=NUM_STAGES, + producer_group=producer_group, + consumer_group=consumer_group, + tx_count=tx_count, + cta_layout_vmnk=None, # single-CTA, no cluster/multicast + ) + + prod_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, NUM_STAGES + ) + cons_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, NUM_STAGES + ) + + M = mX.shape[0] + N = mX.shape[1] + + num_tiles = cutlass.min( + NUM_TILES, + cute.ceil_div(M - bidy * TILE_Y * NUM_TILES, TILE_Y), + ) + + # Tile the TMA gmem view: ((TILE_Y, TILE_X), (M/TILE_Y, N/TILE_X)). + gX_tiled = cute.zipped_divide(tma_src, (TILE_Y, TILE_X)) + + # Partition sX/gX for the TMA atom (single-CTA, no cluster/multicast). + tXsX, tXgX = cute.nvgpu.cpasync.tma_partition( + tma_atom, + 0, # Use the only CTA to do the TMA copy + cute.make_layout(1), # This cluster only has 1 CTAs + sX, + gX_tiled, + ) + + # dact: identical partition for the activation-input load. + if cutlass.const_expr(cfg.WITH_DACT): + gA_tiled = cute.zipped_divide(tma_src_act, (TILE_Y, TILE_X)) + tXsA, tXgA = cute.nvgpu.cpasync.tma_partition( + tma_atom_act, + 0, + cute.make_layout(1), + sActInput, + gA_tiled, + ) + + # Same partitioning for S2G outputs: sO_row → mO_row and sO_col → mO_col. + if cutlass.const_expr(cfg.ROWWISE): + gO_row_tiled = cute.zipped_divide(tma_dst_out_row, (TILE_Y, TILE_X)) + tXsO_row, tXgO_row = cute.nvgpu.cpasync.tma_partition( + tma_atom_out_row, + 0, + cute.make_layout(1), + sO_row, + gO_row_tiled, + ) + if cutlass.const_expr(cfg.COLWISE): + gO_col_tiled = cute.zipped_divide(tma_dst_out_col, (TILE_Y, TILE_X)) + tXsO_col, tXgO_col = cute.nvgpu.cpasync.tma_partition( + tma_atom_out_col, + 0, + cute.make_layout(1), + sO_col, + gO_col_tiled, + ) + + # print(f"sX: {sX}\n") + # print(f"gX_tiled: {gX_tiled}\n") + # print(f"tXsX: {tXsX}\n") + # print(f"tXgX: {tXgX}\n") + + # Ensure barrier init is visible to all threads before the pipeline is used. + cute.arch.sync_threads() + + # ---- Producer: warp 0 issues one TMA copy per tile. ---- + if warp_idx == 0: + for stage in cutlass.range(num_tiles, unroll=1): + mainloop_pipeline.producer_acquire(prod_state) + tile_y = bidy * NUM_TILES + stage + cute.copy( + tma_atom, + tXgX[(None, (tile_y, bidx))], + tXsX[(None, prod_state.index)], + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(prod_state), + ) + if cutlass.const_expr(cfg.WITH_DACT): + cute.copy( + tma_atom_act, + tXgA[(None, (tile_y, bidx))], + tXsA[(None, prod_state.index)], + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(prod_state), + ) + mainloop_pipeline.producer_commit(prod_state) + prod_state.advance() + + # Per-thread amax accumulator across all stages of this CTA. Combined + # with the per-warp redux + cross-warp shmem reduce + atomic at the + # bottom to produce a global max(|x|) in mAmax. Initialised to 0 + # since amax is non-negative. + if cutlass.const_expr(cfg.WITH_AMAX): + block_amax = Float32(0.0) + + # Per-thread partial dbias: thread tidx owns column tidx of the colwise + # tile and accumulates its column sum over this CTA's rows (both stages). + # Written to workspace[bidy, col] below; reduced over row-blocks separately. + if cutlass.const_expr(cfg.WITH_DBIAS): + block_dbias = Float32(0.0) + # Rowwise-only dbias: each thread holds per-column partials for its 32-col + # block, summed across stages, then cross-thread reduced (over THREADS_Y) + # into block_dbias after the loop. + if cutlass.const_expr(DBIAS_REDUCTION_ROWWISE): + rowwise_dbias_arr = cute.make_rmem_tensor( + layout_or_shape=cute.make_layout((SCALE_DIM,), stride=(1,)), + dtype=Float32, + ) + for c in cutlass.range_constexpr(SCALE_DIM): + rowwise_dbias_arr[c] = Float32(0.0) + + # ---- Consumer: all threads quantize each completed tile. ---- + for stage in cutlass.range(num_tiles, unroll=1): + mainloop_pipeline.consumer_wait(cons_state) + sX_tile = sX[(None, stage)] # (TILE_Y, TILE_X) bf16 (grad for dact) + sActInput_tile = None + if cutlass.const_expr(cfg.WITH_DACT): + sActInput_tile = sActInput[(None, stage)] # (TILE_Y, TILE_X) act_input + + """ + grid = [ + cute.ceil_div(Int32(N), TILE_X), + cute.ceil_div(M, TILE_Y * NUM_TILES), + ] + So to obtain the tile that belongs to this CTA. + """ + # This is just block's x axis idx + tile_idx_x = bidx + # Each CTA has `NUM_TILES` tiles. Each stage we need to obtain the tile for that specific stage. + # So the tile index along Y dimension is `bidy * NUM_TILES + stage` + tile_idx_y = bidy * NUM_TILES + stage + if cutlass.const_expr(cfg.COLWISE): + # The first row that belongs to this CTA. Each CTA handles NUM_TILES of (TILE_Y, TILE_X) tiles stacked vertically, + # and each stage handles one of them. + sO_col_tile = sO_col[(None, stage)] + mS_col_stage = cute.flatten(mS_col[(None, (tile_idx_y, tile_idx_x))]) + + amax_c, dbias_c = self._process_colwise( + sX_tile, sO_col_tile, + mS_col_stage, max_norm_rcp, + tile_idx_y * TILE_Y, bidx * TILE_X, M, N, + sActInput_tile, + ) + if cutlass.const_expr(cfg.WITH_AMAX): + block_amax = cute.arch.fmax(block_amax, amax_c) + if cutlass.const_expr(cfg.WITH_DBIAS): + block_dbias += dbias_c + if cutlass.const_expr(cfg.ROWWISE): + sO_row_tile = sO_row[(None, stage)] + # mS_row is ((SCALE_TILE), (GRID)) where SCALE_TILE = (32, 2). + # Each CTA owns NUM_TILES consecutive row-tiles of GRID. cute + # auto-decomposes the flat row coord `bidy * NUM_TILES + stage` + # onto GRID's hierarchical row modes — which is the + # (i_hi, tile_Y) tile-major order for swizzled, and the plain + # row-tile order for compact. Same source, both layouts correct. + mS_row_stage = cute.flatten(mS_row[(None, (tile_idx_y, tile_idx_x))]) + # print(f"s0_row_tile: {sO_row_tile}\n") + # print(f"sO_row: {sO_row}\n") + # print(f"mS_row: {mS_row}\n") + # print(f"mS_row_stage: {mS_row_stage}\n") + # print(f"mS_row_stage: {mS_row_stage}\n") + amax_r = self._process_rowwise( + sX_tile, sO_row_tile, + mS_row_stage, max_norm_rcp, + tile_idx_y * TILE_Y, bidx * TILE_X, M, N, + sActInput_tile, + rowwise_dbias_arr if cutlass.const_expr(DBIAS_REDUCTION_ROWWISE) else None, + ) + + if cutlass.const_expr(cfg.WITH_AMAX): + block_amax = cute.arch.fmax(block_amax, amax_r) + + # Make all smem stores (sO_row and/or sO_col) visible to the TMA + # async proxy, then block-sync so warp 0 sees the fences from all + # warps before issuing the bulk store(s). Matches the C++ + # reference's fence_proxy + __syncthreads pattern. + cute.arch.fence_proxy( + "async.shared", + space="cta", + ) + cute.arch.sync_threads() + + if warp_idx == 0: + tile_y = bidy * NUM_TILES + stage + if cutlass.const_expr(cfg.ROWWISE): + cute.copy( + tma_atom_out_row, + tXsO_row[(None, stage)], + tXgO_row[(None, (tile_y, bidx))], + ) + if cutlass.const_expr(cfg.COLWISE): + cute.copy( + tma_atom_out_col, + tXsO_col[(None, stage)], + tXgO_col[(None, (tile_y, bidx))], + ) + cute.arch.cp_async_bulk_commit_group() + + mainloop_pipeline.consumer_release(cons_state) + cons_state.advance() + + # Wait for in-flight TMA stores so data is visible to the host + # before the kernel returns. + cute.arch.cp_async_bulk_wait_group(0, read=False) + + # ---- rowwise-only dbias: cross-thread reduction over THREADS_Y --------- + # In the rowwise pass each thread owns a row, so its rowwise_dbias_arr holds + # per-column partials for its 32-col block. Transpose through smem so thread + # tidx ends up owning column tidx of the chunk (mirrors CUDA's + # partial_dbias_rowwise smem buffer + reduce over THREADS_Y). + if cutlass.const_expr(DBIAS_REDUCTION_ROWWISE): + THREADS_X = TILE_X // SCALE_DIM # scale-blocks per row (=2) + tid_Y = tidx // THREADS_X + tid_X = tidx % THREADS_X + for c in cutlass.range_constexpr(SCALE_DIM): + sDbias[tid_Y * DBIAS_BUFF_WIDTH + tid_X * (SCALE_DIM + 1) + c] = \ + rowwise_dbias_arr[c] + cute.arch.sync_threads() + # thread tidx owns column tidx; +block skips the per-block padding slot. + block = tidx // SCALE_DIM + block_dbias = Float32(0.0) + for i in cutlass.range_constexpr(TILE_Y): + block_dbias += sDbias[i * DBIAS_BUFF_WIDTH + tidx + block] + + # ---- dbias: write this CTA's per-column partial to the workspace ------- + # Thread tidx owns column (bidx*TILE_X + tidx). Each CTA-row-block (bidy) + # contributes one row of the (blocks_Y, N) fp32 workspace; the reduction + # over blocks_Y to the final dbias[N] is a separate step. + if cutlass.const_expr(cfg.WITH_DBIAS): + dbias_col = bidx * TILE_X + tidx + if dbias_col < N: + mWorkspace[(bidy, dbias_col)] = block_dbias + + # ---- amax block reduction + cross-CTA atomic ---------------------- + # 1) intra-warp: redux.sync.fmax.f32 (sm_80+, single instruction). + # 2) cross-warp: NUM_WARPS shmem floats + sync_threads. + # 3) cross-CTA: int-atomic-max on the f32 bit pattern. Since amax is + # always ≥ 0, IEEE-754 bit ordering on positives matches float + # magnitude ordering, so atomic_max on i32 bits gives the right + # result. (atomic_max_float32 also exists but its pointer + # normalisation is broken as of this CuTeDSL build.) + if cutlass.const_expr(cfg.WITH_AMAX): + warp_amax = cute.arch.warp_redux_sync(block_amax, kind="fmax") + sAmax = storage.sAmax.get_tensor(cute.make_layout(NUM_WARPS)) + lane_idx = tidx % 32 + if lane_idx == 0: + sAmax[warp_idx] = warp_amax + cute.arch.sync_threads() + if tidx == 0: + cta_amax = Float32(0.0) + for w in cutlass.range_constexpr(NUM_WARPS): + cta_amax = cute.arch.fmax(cta_amax, sAmax[w]) + amax_i32 = cute.make_tensor( + cute.recast_ptr(mAmax.iterator, dtype=Int32), + cute.make_layout(1), + ) + cute.arch.atomic_max( + amax_i32.iterator, _bitcast_f32_to_i32(cta_amax), + ) + + @cute.jit + def _process_rowwise( + self, + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) + mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) + max_norm_rcp, + tile_row_start, # Int32 — global row of this stage's row 0 + tile_col_start, # Int32 — global col of this CTA's col 0 + M, N, # Int32 — full input extents, for OOB masking + sActInput_tile=None, # (TILE_Y, TILE_X) act_input tile (dact only) + dbias_acc=None, # rmem Float32[SCALE_DIM] dbias accumulator (rowwise-only dbias) + ): + """Rowwise MXFP8 pass: thread `(tid_Y, tid_X) = (tidx % 32, tidx // 32)` + owns one 32-element scale block (row `tid_Y`, columns `tid_X*32 .. +32`). + + The bank-group swizzle `((w + bank_group) * PACK_SIZE) % SCALE_DIM` + staggers each 4-thread group's starting wave, which otherwise would + collide on smem banks since all lanes in a warp read different rows + at the same column offset. + + Writes quantized bytes into `sO_row_tile` as u32s (one per wave); + caller is responsible for the TMA S2G flush. + """ + cfg = self.cfg + return quantize_rowwise_mxfp8( + sX_tile, + sO_row_tile, + mS_row_stage, + max_norm_rcp, + tile_row_start, + tile_col_start, + M, + N, + ACTIVATION=cfg.ACTIVATION, + DTYPE=cfg.DTYPE, + ROWWISE=cfg.ROWWISE, + COLWISE=cfg.COLWISE, + FP8_DTYPE=cfg.FP8_DTYPE, + TILE_Y=TILE_Y, + SCALE_DIM=SCALE_DIM, + WAVES=WAVES, + THREADS_PER_WARP=THREADS_PER_WARP, + THREADS_PER_BANK=THREADS_PER_BANK, + PACK_SIZE=PACK_SIZE, + WITH_ACT=cfg.WITH_ACT, + WITH_DACT=cfg.WITH_DACT, + sA_tile=sActInput_tile, + DBIAS_REDUCTION=cfg.WITH_DBIAS and not cfg.COLWISE, + dbias_acc=dbias_acc, + ) + + @cute.jit + def _process_colwise( + self, + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) + mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) + max_norm_rcp, + tile_row_start, # Int32 — global row of this stage's row 0 + tile_col_start, # Int32 — global col of this CTA's col 0 + M, N, # Int32 — full input extents, for OOB masking + sActInput_tile=None, # (TILE_Y, TILE_X) act_input tile (dact only) + ): + """Colwise MXFP8 pass: thread `tidx` owns column `tidx` of the (32, 64) + smem tile — 32 elements down. Writes quantized bytes into `sO_col_tile` + so the caller can flush with a TMA S2G — matches C++'s + `out_colwise_data_sh` + `cp.async.bulk.tensor.2d.shared_to_global`. + """ + cfg = self.cfg + return quantize_colwise_mxfp8( + sX_tile, + sO_col_tile, + mS_col_stage, + max_norm_rcp, + tile_row_start, + tile_col_start, + M, + N, + ACTIVATION=cfg.ACTIVATION, + DTYPE=cfg.DTYPE, + FP8_DTYPE=cfg.FP8_DTYPE, + SWIZZLE=cfg.WITH_GEMM_SWIZZLED_SCALES, + TILE_X=TILE_X, + TILE_Y=TILE_Y, + SCALE_DIM=SCALE_DIM, + WITH_ACT=cfg.WITH_ACT, + WITH_DACT=cfg.WITH_DACT, + sA_tile=sActInput_tile, + WITH_DBIAS=cfg.WITH_DBIAS, + ) + +def compile_cutedsl_function_from_cfg(cfg): + """ + Return the compiled CuTeDSL function object for the given MXFP8 quantization config. + """ + + kernel_obj = MXFP8QuantizeSmemKernel(cfg) + + # stride_order=(1, 0): row-major, dim 1 stride 1. 1D: (0,). + kw_rm16_2d = dict(stride_order=(1, 0), + memspace=cute.AddressSpace.gmem, assumed_align=16) + kw_rm4_2d = dict(stride_order=(1, 0), + memspace=cute.AddressSpace.gmem, assumed_align=4) + kw_rm4_1d = dict(stride_order=(0,), + memspace=cute.AddressSpace.gmem, assumed_align=4) + def fake(dtype, shape, kw): + return cute.runtime.make_fake_compact_tensor(dtype, shape, **kw) + + + # M, N must be divisible by the MXFP8 scale-block size (SCALE_DIM = 32) — the + # same alignment the CUDA C++ kernel requires. The C++ dispatcher gates on the + # matching value (kCuTeDSLMXFP8ShapeAlignment in cast/dispatch/quantize.cuh) + # and falls back to CUDA for anything not divisible by it, so tvm-ffi never + # sees a shape this kernel can't accept. + sym_M = cute.sym_int32(divisibility=SCALE_DIM) + sym_N = cute.sym_int32(divisibility=SCALE_DIM) + in_shape = out_shape = (sym_M, sym_N) + # TE allocates scale tensors at a padded shape (see + # MXFP8Quantizer::get_scale_shape in transformer_engine/pytorch/csrc): + # rowwise: (roundup(M, 128), roundup(N // 32, 4)) + # columnwise: (roundup(M // 32, 4), roundup(N, 128)) + # These padded extents are NOT M/N (and SymInt has no `//`/`+`), so give the + # scales their own fresh syms carrying the divisibility the padding + # guarantees (rowwise: 128 x 4; colwise: 4 x 128). + scale_r_shape = (cute.sym_int32(divisibility=128), cute.sym_int32(divisibility=4)) + scale_c_shape = (cute.sym_int32(divisibility=4), cute.sym_int32(divisibility=128)) + # Scale dim-1 is only 4-byte-divisible, so a 16-byte alignment promise would + # be a lie for many shapes; the per-block scale stores are byte-wise anyway, + # so 4-byte alignment loses nothing. + scale_kw = kw_rm4_2d + + in_fake = fake(cfg.DTYPE, in_shape, kw_rm16_2d) + out_row_fake = fake(cute.Uint8, out_shape, kw_rm16_2d) if cfg.ROWWISE else None + scale_row_fake = fake(cute.Uint8, scale_r_shape, scale_kw) if cfg.ROWWISE else None + out_col_fake = fake(cute.Uint8, out_shape, kw_rm16_2d) if cfg.COLWISE else None + scale_col_fake = fake(cute.Uint8, scale_c_shape, scale_kw) if cfg.COLWISE else None + amax_fake = fake(Float32, (1,), kw_rm4_1d) if cfg.WITH_AMAX else None + noop_fake = fake(Float32, (1,), kw_rm4_1d) if cfg.WITH_NOOP else None + # Backward-only slots (act_input/dbias/workspace). Always None today — + # WITH_DACT/WITH_DBIAS are rejected in the config — but kept in the compile + # signature so the tvm-ffi protocol matches the CUDA mxfp8::quantize args. + act_input_fake = fake(cfg.DTYPE, in_shape, kw_rm16_2d) if cfg.WITH_DACT else None + # dbias: the kernel never writes the dbias tensor — it writes per-row-block + # partials into the workspace (shape (blocks_Y, N) fp32, blocks_Y = ceil(M/64), + # set by the C++ worker's size query). The final reduction lives elsewhere, so + # mDbias stays None and only the workspace fake is built. + dbias_fake = None + ws_shape = (cute.sym_int32(), sym_N) # (blocks_Y, N); N ties to input N + workspace_fake = fake(Float32, ws_shape, kw_rm4_2d) if cfg.WITH_DBIAS else None + + compiled = cute.compile( + kernel_obj, + in_fake, # mX + out_row_fake, scale_row_fake, # mO_row, mS_row + out_col_fake, scale_col_fake, # mO_col, mS_col + amax_fake, # mAmax + noop_fake, # mNoop (1-element cast_noop flag) + act_input_fake, # mActInput (backward slot, unused) + dbias_fake, # mDbias (backward slot, unused) + workspace_fake, # mWorkspace(backward slot, unused) + cute.runtime.make_fake_stream(), # stream (compiled as an explicit tvm-ffi + # "handle" arg; C++ passes the CUDA stream + # as void*) + options="--enable-tvm-ffi", + ) + return compiled + +def get_mxfp8_quantization_function( + fn_name: str, + dtype: str, + fp8_dtype: str, + rowwise: bool, + colwise: bool, + with_gemm_swizzled_scales: bool, + with_amax: bool, + with_dbias: bool, + with_dact: bool, + with_act: bool, + with_noop: bool, + activation: str, +) -> bool: + """Compile the MXFP8 quantize kernel for this config and register it in the + TVM-FFI global registry under EXACTLY `fn_name` (the key the C++ dispatcher + built; Python treats it as an opaque name). Returns True if a kernel is + registered under `fn_name` (the C++ side then fetches it with + GetGlobal(fn_name)); False if the config is unsupported, so the caller caches + the negative result and falls back to the CUDA C++ kernel. + + The registry owns the compiled kernel's lifetime — important because it wraps + a Python object, and tvm-ffi releases registry entries at interpreter + shutdown (whereas a C++-held handle would be released after finalize → crash). + """ + # Already registered (e.g. by a prior call) -> supported. + if tvm_ffi.get_global_func(fn_name, allow_missing=True) is not None: + return True + + try: + cfg = MXFP8QuantizeConfig( + dtype=dtype, + fp8_dtype=fp8_dtype, + rowwise=rowwise, + colwise=colwise, + with_gemm_swizzled_scales=with_gemm_swizzled_scales, + with_amax=with_amax, + with_dbias=with_dbias, + with_dact=with_dact, + with_act=with_act, + with_noop=with_noop, + activation=activation, + ) + except ValueError as e: + # The exception message states exactly why the config is unsupported + # (unknown dtype/activation, dbias not implemented, ...). Surfacing it as a + # warning lets the C++ dispatcher's CUDA fallback be recognized as expected. + logger.warning(f"CuTeDSL MXFP8 backend does not support this config, " + f"falling back to the CUDA C++ kernel: {e}") + return False + + logger.debug(f"Compiling CuTeDSL MXFP8 quantization kernel for {cfg}") + compiled = compile_cutedsl_function_from_cfg(cfg) + tvm_ffi.register_global_func(fn_name, compiled, override=True) + + return True + +# Exposed so the C++ dispatcher can request on-demand compilation by name. +tvm_ffi.register_global_func("get_mxfp8_quantization_function", get_mxfp8_quantization_function, override=True) diff --git a/transformer_engine/common/CuTeDSL/utils.py b/transformer_engine/common/CuTeDSL/utils.py new file mode 100644 index 0000000000..258feea66c --- /dev/null +++ b/transformer_engine/common/CuTeDSL/utils.py @@ -0,0 +1,16 @@ +import cutlass + +_CUTLASS_DTYPE_FROM_STR = { + "fp32": cutlass.Float32, + "fp16": cutlass.Float16, + "bf16": cutlass.BFloat16, +} +_STR_FROM_CUTLASS_DTYPE = {v: k for k, v in _CUTLASS_DTYPE_FROM_STR.items()} + +def str_to_cutlass_dtype(dtype_str: str): + """Convert a string dtype to a cutlass dtype, or None if unknown.""" + return _CUTLASS_DTYPE_FROM_STR.get(dtype_str, None) + +def cutlass_dtype_to_str(dtype): + """Convert a cutlass dtype back to its protocol string, or None if unknown.""" + return _STR_FROM_CUTLASS_DTYPE.get(dtype, None) \ No newline at end of file diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 6c71285cd4..ca9d72a139 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -13,8 +13,13 @@ #include +#include +#include + #include "../../common.h" #include "../../transpose/cast_transpose.h" +#include "../mxfp8/quantize_mxfp8_cutedsl.cuh" +#include "../../util/cuda_runtime.h" #include "../../util/vectorized_pointwise.h" #include "../core/common.cuh" #include "../fp8/quantize_fp8.cuh" @@ -27,6 +32,7 @@ namespace transformer_engine { namespace dispatch { + template void quantize_fwd_helper(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { @@ -84,9 +90,16 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, const Tensor *dummy_input_tensor = nullptr; Tensor *dummy_dbias_tensor = nullptr; Tensor *dummy_workspace_tensor = nullptr; - mxfp8::quantize( + bool quantized_with_cutedsl = + quantize::mxfp8_quantize_cutedsl( + input_tensor, dummy_input_tensor, noop_tensor, output_tensor, + dummy_dbias_tensor, dummy_workspace_tensor, stream); + if (!quantized_with_cutedsl) { + mxfp8::quantize( *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, dummy_workspace_tensor, stream); + } break; } case NVTE_NVFP4_1D_SCALING: { @@ -249,9 +262,15 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens break; } case NVTE_MXFP8_1D_SCALING: { - mxfp8::quantize( - *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); + bool quantized_with_cutedsl = + quantize::mxfp8_quantize_cutedsl( + grad_tensor, input_tensor, noop_tensor, output_tensor, + dbias_tensor, workspace_tensor, stream); + if (!quantized_with_cutedsl) { + mxfp8::quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + } break; } case NVTE_NVFP4_1D_SCALING: { diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh new file mode 100644 index 0000000000..6e86bb68dd --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh @@ -0,0 +1,251 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_QUANTIZE_MXFP8_CUTEDSL_CUH_ +#define TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_QUANTIZE_MXFP8_CUTEDSL_CUH_ + +#include +#include +#include + +#include +#include + +#include "../../common.h" +#include "../../tvm_ffi_bridge.h" +#include "../../util/math.h" +#include "../core/common.cuh" // dispatch::common::reduce_dbias + +namespace transformer_engine { +namespace tvm_ffi_bridge { + +struct MXFP8QuantConfig { + static constexpr const char *kEntrypointName = "get_mxfp8_quantization_function"; + + DType dtype; + DType fp8_dtype; + bool rowwise; + bool colwise; + bool swizzled; + bool with_amax; + bool with_dbias = false; + bool with_dact = false; + bool with_act = false; + bool with_noop = false; + Activation activation = Activation::kNone; + + std::string to_key() const { + std::string key; + key.reserve(56); + key.append("cutedsl_mxfp8_") + .append(te_dtype_to_str(dtype)).append("_") + .append(te_dtype_to_str(fp8_dtype)).append("_") + .append(rowwise ? "1" : "0").append("_") + .append(colwise ? "1" : "0").append("_") + .append(swizzled ? "1" : "0").append("_") + .append(with_amax ? "1" : "0").append("_") + .append(with_dbias ? "1" : "0").append("_") + .append(with_dact ? "1" : "0").append("_") + .append(with_act ? "1" : "0").append("_") + .append(with_noop ? "1" : "0").append("_") + .append(activation_to_str(activation)); + return key; + } + + bool retrieve_func_from_python(const std::string &fn_name) const { + auto entrypoint = tvm::ffi::Function::GetGlobal(kEntrypointName); + if (!entrypoint.has_value()) { + return false; + } + tvm::ffi::Any result = (*entrypoint)( + tvm::ffi::String(fn_name), tvm::ffi::String(te_dtype_to_str(dtype)), + tvm::ffi::String(te_dtype_to_str(fp8_dtype)), rowwise, colwise, swizzled, with_amax, + with_dbias, with_dact, with_act, with_noop, + tvm::ffi::String(activation_to_str(activation))); + return result.try_cast().value_or(false); + } +}; + +template +struct MXFP8QuantFused { + static constexpr Activation activation = Activation::kNone; + // No fused op: plain quantize, or dbias-only cast (IS_DBIAS, no activation). + static constexpr bool supported = (OP == nullptr) && !IS_DACT && !IS_ACT; +}; +template <> +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kReLU; + static constexpr bool supported = true; +}; +template <> +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kGeLU; + static constexpr bool supported = true; +}; +template <> +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kSiLU; + static constexpr bool supported = true; +}; +template <> +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kQGeLU; + static constexpr bool supported = true; +}; +template <> +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kSReLU; + static constexpr bool supported = true; +}; +template +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kDReLU; + static constexpr bool supported = true; +}; +template +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kDGeLU; + static constexpr bool supported = true; +}; +template +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kDSiLU; + static constexpr bool supported = true; +}; +template +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kDQGeLU; + static constexpr bool supported = true; +}; +template +struct MXFP8QuantFused> { + static constexpr Activation activation = Activation::kDSReLU; + static constexpr bool supported = true; +}; + +} // namespace tvm_ffi_bridge + +namespace quantize { + +// Signature mirrors mxfp8::quantize (input, act_input, noop, output, dbias, +// workspace, stream). Returns false to fall back to the CUDA kernel. +inline bool mxfp8_quantize_cutedsl(const tvm_ffi_bridge::MXFP8QuantConfig &config, + const Tensor *input_tensor, const Tensor *act_input_tensor, + const Tensor *noop_tensor, Tensor *output_tensor, + Tensor *dbias_tensor, Tensor *workspace_tensor, + cudaStream_t stream) { + constexpr size_t kCuTeDSLMXFP8ShapeAlignment = 32; + const size_t flat_m = input_tensor->flat_first_dim(); + const size_t flat_n = input_tensor->flat_last_dim(); + if (flat_m % kCuTeDSLMXFP8ShapeAlignment != 0 || + flat_n % kCuTeDSLMXFP8ShapeAlignment != 0) { + return false; + } + + // dbias workspace-size query, mirroring mxfp8::quantize: the framework first + // calls with an unallocated workspace to learn its shape, allocates a buffer of + // that shape, then calls again to run. The kernel writes per-row-block partial + // dbias into this workspace; reducing it to the final dbias is a separate step. + if (config.with_dbias && workspace_tensor != nullptr && + workspace_tensor->data.dptr == nullptr) { + constexpr size_t kCuTeDSLMXFP8ChunkRows = 64; // TILE_Y * NUM_TILES (CTA row span) + const size_t dbias_rows = (flat_m + kCuTeDSLMXFP8ChunkRows - 1) / kCuTeDSLMXFP8ChunkRows; + workspace_tensor->data.shape = {dbias_rows, flat_n}; + workspace_tensor->data.dtype = DType::kFloat32; + return true; + } + + std::optional mxfp8_quant_func_opt = + tvm_ffi_bridge::TVMFFICentral::getInstance().lazyload_function(config); + if (!mxfp8_quant_func_opt.has_value()) { + return false; + } + + // Zero out swizzled scale padding when the matrix isn't a multiple of the + // 128x128 GEMM tile. The kernel writes only the meaningful scale region, so + // cuBLAS would otherwise read uninitialized padding. Mirrors the CUDA launcher + // in quantize_mxfp8.cuh (the kernel itself does not pad the scales). + // TODO: move this into the CuTeDSL host code so the padding is handled inside + // the kernel launch — this CUDA-driver memset is an implementation detail that + // doesn't belong in the dispatcher (blocked on calling the driver API there). + if (config.swizzled && (flat_m % 128 != 0 || flat_n % 128 != 0)) { + if (output_tensor->has_data()) { + NVTE_CHECK_CUDA(cudaMemsetAsync(output_tensor->scale_inv.dptr, 0, + output_tensor->scale_inv.buffer_size_bytes(), stream)); + } + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK_CUDA( + cudaMemsetAsync(output_tensor->columnwise_scale_inv.dptr, 0, + output_tensor->columnwise_scale_inv.buffer_size_bytes(), stream)); + } + } + + // Data tensors auto-flatten to 2D (DLTensorWrapper's default), matching the + // kernel's flat (rows, cols) view; scale/amax/noop are rank <= 2 and pass through. + tvm_ffi_bridge::DLTensorWrapper mX(input_tensor->data); + tvm_ffi_bridge::DLTensorWrapper mO_row(output_tensor->data); + tvm_ffi_bridge::DLTensorWrapper mS_row(output_tensor->scale_inv); + tvm_ffi_bridge::DLTensorWrapper mO_col(output_tensor->columnwise_data); + tvm_ffi_bridge::DLTensorWrapper mS_col(output_tensor->columnwise_scale_inv); + tvm_ffi_bridge::DLTensorWrapper mAmax(output_tensor->amax); + tvm_ffi_bridge::DLTensorWrapper mNoop(noop_tensor->data); + // Backward tensors: null wrapper (None) unless present, no allocation when absent. + // mDbias stays None: the kernel writes per-block partials into the workspace, and + // the final dbias is produced by a separate reduction (not by this kernel). + tvm_ffi_bridge::DLTensorWrapper mActInput, mDbias, mWorkspace; + if (act_input_tensor != nullptr) mActInput = tvm_ffi_bridge::DLTensorWrapper(act_input_tensor->data); + if (workspace_tensor != nullptr) mWorkspace = tvm_ffi_bridge::DLTensorWrapper(workspace_tensor->data); + // stream is a tvm-ffi opaque "handle"; pass the CUDA stream as void*. + (*mxfp8_quant_func_opt)(&mX, &mO_row, &mS_row, &mO_col, &mS_col, &mAmax, &mNoop, + &mActInput, &mDbias, &mWorkspace, static_cast(stream)); + + // dbias: the kernel wrote per-row-block partials into the workspace; reduce them + // over the row-blocks into the final dbias[N]. Mirrors mxfp8::quantize, which + // launches common::reduce_dbias after its quantize kernel. + if (config.with_dbias) { + const size_t blocks_Y = (flat_m + 63) / 64; // ceil(M/64) = workspace rows + const float *workspace_ptr = reinterpret_cast(workspace_tensor->data.dptr); + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input_tensor->dtype(), IType, + dispatch::common::reduce_dbias(workspace_ptr, dbias_tensor, blocks_Y, flat_n, + stream);) // NOLINT(*) + } + return true; +} + +template +bool mxfp8_quantize_cutedsl(const Tensor *input_tensor, const Tensor *act_input_tensor, + const Tensor *noop_tensor, Tensor *output_tensor, + Tensor *dbias_tensor, Tensor *workspace_tensor, + cudaStream_t stream) { + using Fused = tvm_ffi_bridge::MXFP8QuantFused; + if constexpr (!Fused::supported) { + return false; + } else { + const bool with_noop = noop_tensor != nullptr && noop_tensor->data.dptr != nullptr; + const tvm_ffi_bridge::MXFP8QuantConfig config{ + /*dtype=*/input_tensor->dtype(), + /*fp8_dtype=*/output_tensor->dtype(), + /*rowwise=*/output_tensor->has_data(), + /*colwise=*/output_tensor->has_columnwise_data(), + /*swizzled=*/output_tensor->with_gemm_swizzled_scales, + /*with_amax=*/output_tensor->amax.dptr != nullptr, + /*with_dbias=*/IS_DBIAS, + /*with_dact=*/IS_DACT, + /*with_act=*/IS_ACT, + /*with_noop=*/with_noop, + /*activation=*/Fused::activation}; + return mxfp8_quantize_cutedsl(config, input_tensor, act_input_tensor, noop_tensor, + output_tensor, dbias_tensor, workspace_tensor, stream); + } +} + +} // namespace quantize +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_QUANTIZE_MXFP8_CUTEDSL_CUH_ diff --git a/transformer_engine/common/tvm_ffi_bridge.h b/transformer_engine/common/tvm_ffi_bridge.h new file mode 100644 index 0000000000..aae48466d5 --- /dev/null +++ b/transformer_engine/common/tvm_ffi_bridge.h @@ -0,0 +1,259 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_TVM_FFI_BRIDGE_H_ +#define TRANSFORMER_ENGINE_COMMON_TVM_FFI_BRIDGE_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "transformer_engine/transformer_engine.h" +#include "util/cuda_runtime.h" +#include "util/logging.h" + +namespace transformer_engine { +namespace tvm_ffi_bridge { + +inline const char *te_dtype_to_str(DType dtype) { + switch (dtype) { + case DType::kFloat32: return "fp32"; + case DType::kFloat16: return "fp16"; + case DType::kBFloat16: return "bf16"; + case DType::kFloat8E4M3: return "e4m3"; + case DType::kFloat8E5M2: return "e5m2"; + default: return ""; + } +} + +// Fused activation token forwarded to Python. Encodes both the family and the +// forward-vs-derivative direction: "relu" is the forward activation, "drelu" its +// backward derivative (dact). This is why no separate is_act/is_dact flag is +// needed — the token carries it; only with_dbias (orthogonal) is a separate flag. +// The d-variants are slots for the not-yet-wired backward path; the forward +// tokens must match Python's SUPPORTED_ACTIVATIONS set. +enum class Activation { + kNone, + kReLU, + kGeLU, + kSiLU, + kQGeLU, + kSReLU, + kDReLU, + kDGeLU, + kDSiLU, + kDQGeLU, + kDSReLU +}; + +inline const char *activation_to_str(Activation act) { + switch (act) { + case Activation::kReLU: return "relu"; + case Activation::kGeLU: return "gelu"; + case Activation::kSiLU: return "silu"; + case Activation::kQGeLU: return "qgelu"; + case Activation::kSReLU: return "srelu"; + case Activation::kDReLU: return "drelu"; + case Activation::kDGeLU: return "dgelu"; + case Activation::kDSiLU: return "dsilu"; + case Activation::kDQGeLU: return "dqgelu"; + case Activation::kDSReLU: return "dsrelu"; + case Activation::kNone: return "none"; + } + return "none"; +} + +inline DLDataType convert_to_dltype(NVTEDType type) { + switch (type) { + case kNVTEFloat32: return DLDataType{kDLFloat, 32, 1}; + case kNVTEFloat16: return DLDataType{kDLFloat, 16, 1}; + case kNVTEBFloat16: return DLDataType{kDLBfloat, 16, 1}; + case kNVTEByte: return DLDataType{kDLUInt, 8, 1}; + case kNVTEInt32: return DLDataType{kDLInt, 32, 1}; + case kNVTEInt64: return DLDataType{kDLInt, 64, 1}; + // FP8 / E8M0 → raw 1-byte uint; the kernel interprets the bits. + case kNVTEFloat8E4M3: return DLDataType{kDLUInt, 8, 1}; + case kNVTEFloat8E5M2: return DLDataType{kDLUInt, 8, 1}; + case kNVTEFloat8E8M0: return DLDataType{kDLUInt, 8, 1}; + default: NVTE_ERROR("unsupported NVTEDType: ", static_cast(type)); + } +} + +class DLTensorWrapper : public DLTensor { + public: + // Null wrapper (data == nullptr): packs as TVM-FFI None, no allocation. + DLTensorWrapper() : DLTensor{} {} + + DLTensorWrapper(const NVTEBasicTensor &tensor, bool flatten_2D = true) { + const int32_t device_index = transformer_engine::cuda::current_device(); + const int n = static_cast(tensor.shape.ndim); + if (flatten_2D && n > 2) { + int64_t flat_first = 1; + for (int i = 0; i + 1 < n; ++i) flat_first *= static_cast(tensor.shape.data[i]); + const int64_t flat_last = static_cast(tensor.shape.data[n - 1]); + shape_buf_ = std::make_unique(2); + strides_buf_ = std::make_unique(2); + shape_buf_[0] = flat_first; shape_buf_[1] = flat_last; + strides_buf_[0] = flat_last; strides_buf_[1] = 1; + this->ndim = 2; + } else { + shape_buf_ = std::make_unique(n); + strides_buf_ = std::make_unique(n); + int64_t stride = 1; + for (int i = n - 1; i >= 0; --i) { + shape_buf_[i] = static_cast(tensor.shape.data[i]); + strides_buf_[i] = stride; + stride *= shape_buf_[i]; + } + this->ndim = n; + } + this->data = tensor.data_ptr; + this->device = DLDevice{kDLCUDA, device_index}; + this->dtype = convert_to_dltype(tensor.dtype); + this->shape = shape_buf_.get(); + this->strides = strides_buf_.get(); + this->byte_offset = 0; + } + + ~DLTensorWrapper() = default; + DLTensorWrapper(const DLTensorWrapper &) = delete; + DLTensorWrapper &operator=(const DLTensorWrapper &) = delete; + DLTensorWrapper(DLTensorWrapper &&) = default; + DLTensorWrapper &operator=(DLTensorWrapper &&) = default; + + private: + std::unique_ptr shape_buf_; + std::unique_ptr strides_buf_; +}; + +} // namespace tvm_ffi_bridge +} // namespace transformer_engine + +namespace tvm { +namespace ffi { +// Make a (borrowed) DLTensorWrapper* a first-class TVM-FFI argument, so wrappers +// can be passed straight to Function::operator()(&w, ...). Like DLTensor* it is a +// non-owning DLTensorPtr view (the wrapper must outlive the call), but a null +// pointer OR a wrapper over an absent buffer (null data) packs as TVM-FFI None — +// so a kernel's optional args need no special handling at the call site. Only +// the pack-as-argument path (CopyToAnyView) is provided; reading back is unused. +// Declared after DLTensorWrapper: the specialization needs the complete type +// (it reads src->data and static_casts to its DLTensor base). +template <> +struct TypeTraits + : public TypeTraits { + TVM_FFI_INLINE static void CopyToAnyView( + transformer_engine::tvm_ffi_bridge::DLTensorWrapper *src, TVMFFIAny *result + ) { + if (src == nullptr || src->data == nullptr) { + TypeTraits::CopyToAnyView(nullptr, result); // -> TVM-FFI None + } else { + TypeTraits::CopyToAnyView(static_cast(src), result); + } + } +}; +} // namespace ffi +} // namespace tvm + +namespace transformer_engine { +namespace tvm_ffi_bridge { + +// Compile-time check that a config provides the lazy-loadable kernel API: +// - std::string to_key() const +// - bool retrieve_func_from_python(const std::string& key) const +// (compiles + globally registers the kernel under `key`; returns whether +// a kernel is now registered / the config is supported) +// Drives the static_assert in TVMFFICentral::lazyload_function so a config that +// is missing either method fails with a clear message instead of a deref-into- +// the-template error. +namespace detail { +template +struct is_lazyloadable_config : std::false_type {}; +template +struct is_lazyloadable_config< + T, std::void_t().to_key()), + decltype(std::declval().retrieve_func_from_python( + std::declval()))>> : std::true_type {}; +} // namespace detail + + +class TVMFFICentral { + public: + static TVMFFICentral &getInstance() { + static TVMFFICentral instance; + return instance; + } + + // Resolve the compiled kernel for `cfg`. The kernel itself lives in the tvm-ffi + // global registry (registered by the Python entrypoint under cfg.to_key()), + // which releases its Python-backed entries safely at interpreter shutdown; we + // fetch it per call with GetGlobal(key). C++ caches only a bool per config + // (supported or not), so Python is asked at most once per config and we never + // hold a Python-backed handle in a static-duration object (which would crash + // at exit, when the singleton is torn down after the interpreter is finalized). + template + std::optional lazyload_function(const Config &cfg) { + static_assert(detail::is_lazyloadable_config::value, + "Config must define `std::string to_key() const` and " + "`bool retrieve_func_from_python(const std::string&) const`."); + if (!enabled_) return std::nullopt; + const std::string key = cfg.to_key(); + { + std::shared_lock read_lock(mutex_); + auto it = supported_.find(key); + if (it != supported_.end()) { + return it->second ? tvm::ffi::Function::GetGlobal(key) : std::nullopt; + } + } + // Cold miss: ask Python to compile + globally register the kernel under + // `key`; cache only the support decision (avoids re-asking Python, and + // negative-caches unsupported configs). + const bool supported = cfg.retrieve_func_from_python(key); + { + std::unique_lock write_lock(mutex_); + supported_.emplace(key, supported); + } + return supported ? tvm::ffi::Function::GetGlobal(key) : std::nullopt; + } + + private: + ~TVMFFICentral() = default; + TVMFFICentral() : enabled_(is_cutedsl_backend_enabled()) {} + TVMFFICentral(const TVMFFICentral &) = delete; + TVMFFICentral &operator=(const TVMFFICentral &) = delete; + TVMFFICentral(TVMFFICentral &&) = delete; + TVMFFICentral &operator=(TVMFFICentral &&) = delete; + + static bool is_cutedsl_backend_enabled() { + // On by default; set NVTE_ENABLE_CUTEDSL_QUANT_BACKEND=0 to disable. + const char *flag = std::getenv("NVTE_ENABLE_CUTEDSL_QUANT_BACKEND"); + return flag == nullptr || flag[0] != '0'; + } + + const bool enabled_; + std::shared_mutex mutex_; + // Per-config support decision (cfg.to_key() -> supported). Holds NO Python- + // backed handles, so it is safe to destroy at static teardown — the kernels + // live in the tvm-ffi registry, owned and released by tvm-ffi itself. + std::unordered_map supported_; +}; + +} // namespace tvm_ffi_bridge +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_TVM_FFI_BRIDGE_H_ diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 06db28ee27..334dd0eb15 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -18,6 +18,16 @@ load_framework_extension("torch") from transformer_engine.pytorch import constants from transformer_engine.pytorch.constants import DType + +# Register the CuTeDSL kernel entrypoints (TVM-FFI global funcs) so the C++ +# dispatcher can discover them via GetGlobal and compile kernels on demand. The +# CuTeDSL toolchain (cutlass, tvm_ffi) is optional; if it is unavailable the +# import is skipped and C++ simply falls back to the CUDA C++ kernels. +try: + import transformer_engine.common.CuTeDSL # noqa: F401 +except Exception: + pass + from transformer_engine.pytorch.module import LayerNormLinear from transformer_engine.pytorch.module import Linear from transformer_engine.pytorch.module import LayerNormMLP From 675b47cd9c0c7229da386db9b93841c035e25135 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Mon, 22 Jun 2026 22:42:03 +0000 Subject: [PATCH 02/22] make the code look nicer --- .../common/CuTeDSL/activations.py | 93 ++++ .../common/CuTeDSL/cast/mxfp8/mxfp8_utils.py | 503 +++++------------ .../CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 518 ++++++------------ .../common/cast/dispatch/quantize.cuh | 4 +- .../cast/mxfp8/quantize_mxfp8_cutedsl.cuh | 54 +- transformer_engine/common/tvm_ffi_bridge.h | 4 +- 6 files changed, 433 insertions(+), 743 deletions(-) create mode 100644 transformer_engine/common/CuTeDSL/activations.py diff --git a/transformer_engine/common/CuTeDSL/activations.py b/transformer_engine/common/CuTeDSL/activations.py new file mode 100644 index 0000000000..96ffd5c6c1 --- /dev/null +++ b/transformer_engine/common/CuTeDSL/activations.py @@ -0,0 +1,93 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import cutlass +import cutlass.cute as cute +from cutlass import Float32 +from cutlass._mlir.dialects import arith as mlir_arith +from cutlass.cutlass_dsl import dsl_user_op + +def act_relu(x: Float32) -> Float32: + return cute.arch.fmax(x, Float32(0.0)) + + +def act_gelu(x: Float32) -> Float32: + """Tanh-approximation GELU. Constants and operator grouping match TE's + `transformer_engine/common/util/math.h::gelu` exactly (factored form + `x · (0.5 + 0.5·tanh(x·(a + b·x²)))`) so quantized output is bit-exact + against the C++ fused IS_ACT path. Uses `cute.math.tanh(fastmath=False)` + rather than the `tanh.approx.f32` PTX intrinsic — TE compiles activation + kernels without `--use_fast_math` by default, so its `tanhf` is the + IEEE-precise expansion.""" + A = Float32(0.79788456) # sqrt(2/π) truncated to TE's 8-digit literal + B = Float32(0.03567741) # = sqrt(2/π) · 0.044715, same truncation + return x * (Float32(0.5) + Float32(0.5) * cute.math.tanh(x * (A + B * x * x))) + + +def act_silu(x: Float32) -> Float32: + """SiLU/Swish: x · σ(x) = x / (1 + e^-x). + Matches TE's `silu` (`val / (1 + expf(-val))`).""" + return x / (Float32(1.0) + cute.arch.exp(-x)) + + +def act_qgelu(x: Float32) -> Float32: + """Quick GELU: x · σ(1.702·x). Matches TE `qgelu_with_alpha(val, 1.702)` = + `cval · (1 / (1 + expf(-1.702·cval)))` (multiply by sigmoid, not a divide).""" + z = Float32(1.702) * x + return x * (Float32(1.0) / (Float32(1.0) + cute.arch.exp(-z))) + + +def act_srelu(x: Float32) -> Float32: + """Squared ReLU: x>0 ? x·x : 0 == (max(0,x))². Matches TE `srelu`.""" + r = cute.arch.fmax(x, Float32(0.0)) + return r * r + +@dsl_user_op +def dact_drelu(x: Float32, *, loc=None, ip=None) -> Float32: + """drelu: x > 0 ? 1 : 0. Matches math.h `drelu` (NaN → 0 via ordered compare).""" + cond = mlir_arith.cmpf(mlir_arith.CmpFPredicate.OGT, + x.ir_value(loc=loc, ip=ip), + Float32(0.0).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + return Float32(mlir_arith.select(cond, + Float32(1.0).ir_value(loc=loc, ip=ip), + Float32(0.0).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +def dact_dsrelu(x: Float32) -> Float32: + """dsrelu: fmax(2x, 0). Matches math.h `dsrelu`.""" + return cute.arch.fmax(Float32(2.0) * x, Float32(0.0)) + + +def sigmoid(x: Float32) -> Float32: + """σ(x) = 1 / (1 + e^-x), same exp intrinsic as the forward silu/qgelu.""" + return Float32(1.0) / (Float32(1.0) + cute.arch.exp(-x)) + + +def dact_dsilu(x: Float32) -> Float32: + """dsilu: x·σ(x)·(1-σ(x)) + σ(x). Matches math.h `dsilu` + (`cval·dsigmoid + sigmoid`, dsigmoid = s·(1-s)).""" + s = sigmoid(x) + return x * (s * (Float32(1.0) - s)) + s + + +def dact_dqgelu(x: Float32) -> Float32: + """dqgelu (alpha=1.702): a·x·dσ(a·x) + σ(a·x). Matches math.h + `dqgelu_with_alpha(val, 1.702)`.""" + a = Float32(1.702) + ax = a * x + s = sigmoid(ax) + return a * x * (s * (Float32(1.0) - s)) + s + + +def dact_dgelu(x: Float32) -> Float32: + """dgelu (tanh approximation). Matches math.h `dgelu` term-for-term; + same tanh argument as the forward `_act_gelu`.""" + t = cute.math.tanh( + Float32(0.79788456) * x * (Float32(1.0) + Float32(0.044715) * x * x), + fastmath=False, + ) + return (Float32(0.5) * x + * ((Float32(1.0) - t * t) * (Float32(0.79788456) + Float32(0.1070322243) * x * x)) + + Float32(0.5) * (Float32(1.0) + t)) + diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py index 0ee407f6b9..2389962287 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py @@ -1,26 +1,32 @@ import cutlass import cutlass.cute as cute -import cutlass.utils as utils -import cutlass.pipeline as pipeline from cutlass import Float32, Int64, Int32, Int16, Uint8, Uint32 from cutlass._mlir.dialects import arith as mlir_arith from cutlass._mlir.dialects import llvm -from cutlass.base_dsl.compiler import GPUArch -from cutlass.cute.runtime import make_ptr from cutlass.cutlass_dsl import T, dsl_user_op -from cutlass.cute.arch import cvt_f32_bf16 from types import SimpleNamespace +from transformer_engine.common.CuTeDSL.activations import ( + act_relu, + act_gelu, + act_silu, + act_qgelu, + act_srelu, + dact_drelu, + dact_dsrelu, + dact_dsilu, + dact_dqgelu, + dact_dgelu, +) + + # FP8E4M3 max representable value FP8E4M3_MAX_NORM = 448.0 FP8E4M3_MAX_NORM_RCP = 1.0 / FP8E4M3_MAX_NORM FP8E5M2_MAX_NORM = 57344.0 FP8E5M2_MAX_NORM_RCP = 1.0 / FP8E5M2_MAX_NORM -# NVFP4 (fp4e2m1) — 4-bit float, max representable value is 6.0 -FP4_E2M1_MAX = 6.0 -FP4_E2M1_MAX_RCP = 1.0 / FP4_E2M1_MAX # Largest finite f32 — used to clamp the per-block scale inverse against # division-by-zero (which produces +inf and then NaN downstream). FP32_MAX = 3.4028234663852886e38 @@ -143,30 +149,6 @@ def pack_f32x2(lo: Float32, hi: Float32, *, loc=None, ip=None) -> Int64: asm_dialect=llvm.AsmDialect.AD_ATT)) -@dsl_user_op -def pack_i32x2(lo: Int32, hi: Int32, *, loc=None, ip=None) -> Int64: - """i32 sibling of `pack_f32x2` — concat two i32 into a single b64 register. - Used by NVFP4 to glue two `(bf16,bf16)`/`(f16,f16)` Int32 packs into the - `Int64` operand the `mul_cvt.*x4` PTX expects.""" - return Int64(llvm.inline_asm( - T.i64(), - [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)], - "mov.b64 $0, {$1, $2};", - "=l,r,r", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - - -@dsl_user_op -def _trunc_i32_to_i16(val: Int32, *, loc=None, ip=None) -> Int16: - """Narrow an Int32 to Int16 by keeping the low 16 bits. - - Lives here because the existing arith-dialect narrowing pattern requires - loc/ip kwargs (see other `mlir_arith.trunci` callers); wrapping it as a - `@dsl_user_op` lets `@cute.jit` bodies use it without plumbing those in.""" - return Int16(mlir_arith.trunci( - T.i16(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - - @dsl_user_op def cvt_fp8e4m3_to_f32(byte_i32: Int32, *, loc=None, ip=None) -> Float32: """One fp8e4m3 byte (low 8 bits of `byte_i32`) → f32. @@ -192,6 +174,56 @@ def cvt_fp8e4m3_to_f32(byte_i32: Int32, *, loc=None, ip=None) -> Float32: "=f,r", has_side_effects=False, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT)) +@dsl_user_op +def cvt_f32x2_to_fp8e4m3x2(val_hi: Float32, val_lo: Float32, relu: bool = False, + *, loc=None, ip=None) -> Int32: + """Convert two float32 values to two packed fp8e4m3fn bytes in one instruction. + + Returns an int32 where bits [7:0] = fp8(val_lo), bits [15:8] = fp8(val_hi). + This mirrors ptx::mul_cvt_2x which converts 2 values in one instruction. + """ + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def cvt_f32x2_to_fp8e5m2x2(val_hi: Float32, val_lo: Float32, relu: bool = False, + *, loc=None, ip=None) -> Int32: + """e5m2 sibling of `cvt_f32x2_to_fp8e4m3x2`.""" + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e5m2x2.f32 $0, $1, $2;", + "=h,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +def _cvt_f32_to_fp8(fp8_dtype: str): + """Const-expr dispatch: pick the f32→fp8 scalar PTX op based on output dtype. + + `fp8_dtype` is the Python string from `cfg.FP8_DTYPE`, evaluated at JIT + trace time; the unused branch is never traced. + """ + if fp8_dtype == "e5m2": + return cvt_f32_to_fp8e5m2 + return cvt_f32_to_fp8e4m3 + + +def _cvt_f32x2_to_fp8x2(fp8_dtype: str): + """Const-expr dispatch for the packed f32x2→fp8x2 cvt.""" + if fp8_dtype == "e5m2": + return cvt_f32x2_to_fp8e5m2x2 + return cvt_f32x2_to_fp8e4m3x2 + + # --------------------------------------------------------------------------- # 16-bit packed input PTX kit (bf16 / f16) # @@ -355,44 +387,6 @@ def mul_cvt_to_fp8x2(fp8_dtype: str, relu: bool = False): return _build_mul_cvt("e5m2", relu) return _build_mul_cvt("e4m3", relu) - # NVFP4 fused cast: x4 × f32x2 → fp4e2m1x4 (4 fp4 packed in 16 - # bits). Same shape as `mul_cvt_to_fp8x2` but produces 4 elements at a - # time because the `cvt.rn.satfinite.e2m1x2.f32` PTX consumes pairs and - # writes a single byte (high nibble = first input, low nibble = second). - # The shuffled `mov.b64 {v1, v0}, v01` lines after the muls undo the - # PTX's hi/lo packing so the resulting byte is naturally - # `(fp4(elt1) << 4) | fp4(elt0)` — matches TE's C++ asm. - @dsl_user_op - def mul_cvt_to_fp4x4(in_4x: Int64, scale_2x: Int64, *, loc=None, ip=None) -> Int32: - asm = ( - "{\n" - ".reg.b64 v01; .reg.b64 v23;\n\t" - ".reg.b16 i0; .reg.b16 i1; .reg.b16 i2; .reg.b16 i3;\n\t" - ".reg.b32 v0; .reg.b32 v1; .reg.b32 v2; .reg.b32 v3;\n\t" - ".reg.b8 f0; .reg.b8 f1;\n\t" - "mov.b64 {i0, i1, i2, i3}, $1;\n\t" - f"cvt.f32.{in_fmt} v0, i0;\n\t" - f"cvt.f32.{in_fmt} v1, i1;\n\t" - f"cvt.f32.{in_fmt} v2, i2;\n\t" - f"cvt.f32.{in_fmt} v3, i3;\n\t" - "mov.b64 v01, {v0, v1};\n\t" - "mov.b64 v23, {v2, v3};\n\t" - "mul.f32x2 v01, v01, $2;\n\t" - "mul.f32x2 v23, v23, $2;\n\t" - "mov.b64 {v1, v0}, v01;\n\t" - "mov.b64 {v3, v2}, v23;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" - "mov.b32 $0, {f0, f1, f0, f1};\n\t" - "}" - ) - return Int32(llvm.inline_asm( - T.i32(), - [in_4x.ir_value(loc=loc, ip=ip), scale_2x.ir_value(loc=loc, ip=ip)], - asm, - "=r,l,l", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return SimpleNamespace( abs_max_x2=abs_max_x2, max_x2=max_x2, @@ -402,7 +396,6 @@ def mul_cvt_to_fp4x4(in_4x: Int64, scale_2x: Int64, *, loc=None, ip=None) -> Int x2_hi_to_f32=x2_hi_to_f32, truncate_f32=truncate_f32, mul_cvt_to_fp8x2=mul_cvt_to_fp8x2, - mul_cvt_to_fp4x4=mul_cvt_to_fp4x4, ) @@ -414,227 +407,35 @@ def _is_packed16(dtype) -> bool: """True if `dtype` is one of the 16-bit packed input formats.""" return dtype is cutlass.BFloat16 or dtype is cutlass.Float16 - def _packed16_kit(dtype): """Trace-time selector — pick a Packed16Kit for the input dtype.""" if dtype is cutlass.Float16: return _F16_KIT return _BF16_KIT - -# --------------------------------------------------------------------------- -# Forward-activation registry -# -# Each entry is a Float32 → Float32 callable applied per element before the -# MXFP8 amax + cast. Selection is by Python string at JIT trace time, so the -# const-expr machinery treats `cfg.ACTIVATION` like a C++ template argument -# — no runtime branch in the inner loop, separate kernel cached per choice. -# -# Math primitives match CUDA fast-math intrinsics so outputs are bit-exact -# with PyTorch's CUDA implementations of the same activations: -# tanh -> tanh.approx.f32 (== __tanhf) -# exp(x) -> exp2.approx.f32(x · log2(e)) (== __expf) -# --------------------------------------------------------------------------- -def _act_relu(x: Float32) -> Float32: - return cute.arch.fmax(x, Float32(0.0)) - - -def _act_gelu(x: Float32) -> Float32: - """Tanh-approximation GELU. Constants and operator grouping match TE's - `transformer_engine/common/util/math.h::gelu` exactly (factored form - `x · (0.5 + 0.5·tanh(x·(a + b·x²)))`) so quantized output is bit-exact - against the C++ fused IS_ACT path. Uses `cute.math.tanh(fastmath=False)` - rather than the `tanh.approx.f32` PTX intrinsic — TE compiles activation - kernels without `--use_fast_math` by default, so its `tanhf` is the - IEEE-precise expansion.""" - A = Float32(0.79788456) # sqrt(2/π) truncated to TE's 8-digit literal - B = Float32(0.03567741) # = sqrt(2/π) · 0.044715, same truncation - return x * (Float32(0.5) + Float32(0.5) * cute.math.tanh(x * (A + B * x * x))) - - -def _act_silu(x: Float32) -> Float32: - """SiLU/Swish: x · σ(x) = x / (1 + e^-x). - Matches TE's `silu` (`val / (1 + expf(-val))`).""" - return x / (Float32(1.0) + cute.arch.exp(-x)) - - -def _act_qgelu(x: Float32) -> Float32: - """Quick GELU: x · σ(1.702·x). Matches TE `qgelu_with_alpha(val, 1.702)` = - `cval · (1 / (1 + expf(-1.702·cval)))` (multiply by sigmoid, not a divide).""" - z = Float32(1.702) * x - return x * (Float32(1.0) / (Float32(1.0) + cute.arch.exp(-z))) - - -def _act_srelu(x: Float32) -> Float32: - """Squared ReLU: x>0 ? x·x : 0 == (max(0,x))². Matches TE `srelu`.""" - r = cute.arch.fmax(x, Float32(0.0)) - return r * r - - SUPPORTED_ACTIVATIONS = { - "relu": _act_relu, - "gelu": _act_gelu, - "silu": _act_silu, - "qgelu": _act_qgelu, - "srelu": _act_srelu, + "relu": act_relu, + "gelu": act_gelu, + "silu": act_silu, + "qgelu": act_qgelu, + "srelu": act_srelu, } - -# --------------------------------------------------------------------------- -# Backward-activation (dact) registry -# -# Each entry is the derivative act'(x) as a Float32 → Float32 callable, matching -# the corresponding `d` in transformer_engine/common/util/math.h. The dact -# kernel computes `grad · act'(x)` per element before the MXFP8 amax + cast. -# Primitives mirror the forward registry (cute.math.tanh fastmath=False for -# gelu, cute.arch.exp for the sigmoid) so output is bit-exact with the C++ path. -# --------------------------------------------------------------------------- -@dsl_user_op -def _dact_drelu(x: Float32, *, loc=None, ip=None) -> Float32: - """drelu: x > 0 ? 1 : 0. Matches math.h `drelu` (NaN → 0 via ordered compare).""" - cond = mlir_arith.cmpf(mlir_arith.CmpFPredicate.OGT, - x.ir_value(loc=loc, ip=ip), - Float32(0.0).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) - return Float32(mlir_arith.select(cond, - Float32(1.0).ir_value(loc=loc, ip=ip), - Float32(0.0).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - - -def _dact_dsrelu(x: Float32) -> Float32: - """dsrelu: fmax(2x, 0). Matches math.h `dsrelu`.""" - return cute.arch.fmax(Float32(2.0) * x, Float32(0.0)) - - -def _sigmoid(x: Float32) -> Float32: - """σ(x) = 1 / (1 + e^-x), same exp intrinsic as the forward silu/qgelu.""" - return Float32(1.0) / (Float32(1.0) + cute.arch.exp(-x)) - - -def _dact_dsilu(x: Float32) -> Float32: - """dsilu: x·σ(x)·(1-σ(x)) + σ(x). Matches math.h `dsilu` - (`cval·dsigmoid + sigmoid`, dsigmoid = s·(1-s)).""" - s = _sigmoid(x) - return x * (s * (Float32(1.0) - s)) + s - - -def _dact_dqgelu(x: Float32) -> Float32: - """dqgelu (alpha=1.702): a·x·dσ(a·x) + σ(a·x). Matches math.h - `dqgelu_with_alpha(val, 1.702)`.""" - a = Float32(1.702) - ax = a * x - s = _sigmoid(ax) - return a * x * (s * (Float32(1.0) - s)) + s - - -def _dact_dgelu(x: Float32) -> Float32: - """dgelu (tanh approximation). Matches math.h `dgelu` term-for-term; - same tanh argument as the forward `_act_gelu`.""" - t = cute.math.tanh( - Float32(0.79788456) * x * (Float32(1.0) + Float32(0.044715) * x * x), - fastmath=False, - ) - return (Float32(0.5) * x - * ((Float32(1.0) - t * t) * (Float32(0.79788456) + Float32(0.1070322243) * x * x)) - + Float32(0.5) * (Float32(1.0) + t)) - - SUPPORTED_DACTIVATIONS = { - "drelu": _dact_drelu, - "dgelu": _dact_dgelu, - "dsilu": _dact_dsilu, - "dqgelu": _dact_dqgelu, - "dsrelu": _dact_dsrelu, + "drelu": dact_drelu, + "dgelu": dact_dgelu, + "dsilu": dact_dsilu, + "dqgelu": dact_dqgelu, + "dsrelu": dact_dsrelu, } -@dsl_user_op -def cvt_f32x2_to_fp8e4m3x2(val_hi: Float32, val_lo: Float32, relu: bool = False, - *, loc=None, ip=None) -> Int32: - """Convert two float32 values to two packed fp8e4m3fn bytes in one instruction. - - Returns an int32 where bits [7:0] = fp8(val_lo), bits [15:8] = fp8(val_hi). - This mirrors ptx::mul_cvt_2x which converts 2 values in one instruction. - """ - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], - f"cvt.rn.satfinite{".relu" if relu else ""}.e4m3x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - - -@dsl_user_op -def cvt_f32x2_to_fp8e5m2x2(val_hi: Float32, val_lo: Float32, relu: bool = False, - *, loc=None, ip=None) -> Int32: - """e5m2 sibling of `cvt_f32x2_to_fp8e4m3x2`.""" - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], - f"cvt.rn.satfinite{".relu" if relu else ""}.e5m2x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - - -@dsl_user_op -def mul_cvt_f32x4_to_fp4x4(in01: Int64, in23: Int64, scale_2x: Int64, - *, loc=None, ip=None) -> Int32: - """f32x4 sibling of `kit.mul_cvt_to_fp4x4` — for the NVFP4 colwise path - where elements live on a strided column and we've already widened to f32 - for the amax reduction. `in01` = pack(f32_0, f32_1), `in23` similarly.""" - asm = ( - "{\n" - ".reg.b64 v01; .reg.b64 v23;\n\t" - ".reg.b32 v0; .reg.b32 v1; .reg.b32 v2; .reg.b32 v3;\n\t" - ".reg.b8 f0; .reg.b8 f1;\n\t" - "mov.b64 {v0, v1}, $1;\n\t" - "mov.b64 {v2, v3}, $2;\n\t" - "mov.b64 v01, {v0, v1};\n\t" - "mov.b64 v23, {v2, v3};\n\t" - "mul.f32x2 v01, v01, $3;\n\t" - "mul.f32x2 v23, v23, $3;\n\t" - "mov.b64 {v1, v0}, v01;\n\t" - "mov.b64 {v3, v2}, v23;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" - "mov.b32 $0, {f0, f1, f0, f1};\n\t" - "}" - ) - return Int32(llvm.inline_asm( - T.i32(), - [in01.ir_value(loc=loc, ip=ip), - in23.ir_value(loc=loc, ip=ip), - scale_2x.ir_value(loc=loc, ip=ip)], - asm, - "=r,l,l,l", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - - -def _cvt_f32_to_fp8(fp8_dtype: str): - """Const-expr dispatch: pick the f32→fp8 scalar PTX op based on output dtype. - - `fp8_dtype` is the Python string from `cfg.FP8_DTYPE`, evaluated at JIT - trace time; the unused branch is never traced. - """ - if fp8_dtype == "e5m2": - return cvt_f32_to_fp8e5m2 - return cvt_f32_to_fp8e4m3 - - -def _cvt_f32x2_to_fp8x2(fp8_dtype: str): - """Const-expr dispatch for the packed f32x2→fp8x2 cvt.""" - if fp8_dtype == "e5m2": - return cvt_f32x2_to_fp8e5m2x2 - return cvt_f32x2_to_fp8e4m3x2 - @cute.jit def quantize_rowwise_mxfp8( sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sA_tile, # (TILE_Y, TILE_X) activation-input smem tile (dact only) sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) - mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) + mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) max_norm_rcp, tile_row_start, # Int32 — global row index of this stage's row 0 # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores @@ -645,41 +446,25 @@ def quantize_rowwise_mxfp8( # scale store. ACTIVATION, DTYPE, - ROWWISE, - COLWISE, FP8_DTYPE, TILE_Y, - SCALE_DIM, + MXFP8_BLOCK_SIZE, WAVES, THREADS_PER_WARP, THREADS_PER_BANK, PACK_SIZE, - WITH_ACT=False, # forward: apply activation to the element - WITH_DACT=False, # backward: out = grad · act'(act_input) - sA_tile=None, # (TILE_Y, TILE_X) activation-input smem tile (dact only) - DBIAS_REDUCTION=False, # rowwise-only dbias: accumulate per-column partials - dbias_acc=None, # rmem Float32[SCALE_DIM]; += this row's pre-truncate elt per column + WITH_ACT=False, + WITH_DACT=False, + WITH_DBIAS=False, # rowwise-only dbias: accumulate per-column partials + dbias_acc=None, # only needed when WITH_DBIAS is True ): tidx, _, _ = cute.arch.thread_idx() - # Match the C++ reference's thread layout: pairs of adjacent lanes - # share a row (lanes 2k / 2k+1 both own row k), each pair covering - # the two 32-element scale blocks of that row. Express as a cute - # layout mapping `(tid_Y, tid_X) -> tidx` with stride (2, 1): - # linear(tidx) = tid_Y*2 + tid_X, so `get_flat_coord` inverts to - # `(tidx // 2, tidx % 2)` — semantically clearer than the raw - # divmod, and readily reusable if we later partition via TiledCopy. - # print(f"sX_tile: {sX_tile}") - # print(f"sO_row_tile: {sO_row_tile}") - # print(f"mS_row_stage: {mS_row_stage}") - tiler, tv_layout = cute.make_layout_tv( thr_layout=cute.make_layout((TILE_Y, 2), stride=(2, 1)), - val_layout=cute.make_layout((1, SCALE_DIM), stride=(0, 1)) + val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)) ) - # print(f"tv_layout: {tv_layout}") - # print(f"tiler: {tiler}") - + sX_tv = cute.composition(sX_tile, tv_layout) sO_tv = cute.composition(sO_row_tile, tv_layout) @@ -687,23 +472,21 @@ def quantize_rowwise_mxfp8( sX_thread = sX_tv[tidx, None] # shape (32,) bf16 sO_thread = sO_tv[tidx, None] # shape (32,) uint8 - # See https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=tv-2-%2832%2C+2%29%3A%282%2C1%29-%281%2C+32%29%3A%280%2C1%29 - # print(f"sX_thread: {sX_thread}") - # print(f"sO_thread: {sO_thread}") - sO_thread_u32_ptr = cute.recast_ptr(sO_thread.iterator, dtype=Uint32) # Each wave it writes 32 bytes = 8 uint32s, so in 4 waves we write all 32 quantized elements. sO_thread_u32 = cute.make_tensor( sO_thread_u32_ptr, - cute.make_layout((SCALE_DIM // 4,), stride=(1,)), # 1 uint32 is 4 fp8 elements + cute.make_layout((MXFP8_BLOCK_SIZE // 4,), stride=(1,)), # 1 uint32 is 4 fp8 elements ) - # print(f"sO_thread_u32: {sO_thread_u32}") + # PTX allows to fuse relu operation in `cvt.rn.satfinite` FUSE_RELU = cutlass.const_expr(ACTIVATION == "relu") - # For this fast paht we can read in pack of 2 instead of reading individual f16 / bf16 element. + # For this fast path we can read in pack of 2 instead of reading individual f16 / bf16 element. # dbias needs the per-element fp32 values to accumulate, so it forces the slow path. _row_fast = (_is_packed16(DTYPE) and (ACTIVATION is None or FUSE_RELU) - and not DBIAS_REDUCTION) + and not WITH_DBIAS) + + amax_r = Float32(0.0) if cutlass.const_expr(_row_fast): # If no activation, f16 / bf16 and rowwise quantization, we can read 2 f16 / bf16 at once in a pack @@ -711,22 +494,18 @@ def quantize_rowwise_mxfp8( kit = _packed16_kit(DTYPE) sX_thread_rw_i32 = cute.make_tensor( cute.recast_ptr(sX_thread.iterator, dtype=Int32), - cute.make_layout((1, SCALE_DIM // 2), stride=(0, 1)), # 1 int32 is 2 fp16/bf16 elements + cute.make_layout((1, MXFP8_BLOCK_SIZE // 2), stride=(0, 1)), # 1 int32 is 2 fp16/bf16 elements ) - # print(f"sX_thread_rw_i32: {sX_thread_rw_i32}") # Each wave we read 2 packed i32, which is 4 fp16/bf16 elements (PACK_SIZE) # In total we have 8 waves where each wave reads in_r = [[None, None] for _ in range(WAVES)] bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group offset = bank_group * 2 # Each bank group will read 2 i32 from their bank for w in cutlass.range_constexpr(WAVES): - idx = (w * 2 + offset) % (SCALE_DIM // 2) + idx = (w * 2 + offset) % (MXFP8_BLOCK_SIZE // 2) in_r[w][0] = sX_thread_rw_i32[0, idx] in_r[w][1] = sX_thread_rw_i32[0, idx + 1] - # 1. Packed-x2 amax — 2 PTX per wave, 16 total per thread. - # Accumulates `|elt|` in both lanes (with xorsign-drifted signs); - # final horizontal max reduces the two lanes to a single f32. amax_2x = Int32(0) # Each wave will use max.xorsign.abs.f16x2 or max.xorsign.abs.bf16x2 to compare 2 packed elements in parallel for w in cutlass.range_constexpr(WAVES): @@ -744,8 +523,7 @@ def quantize_rowwise_mxfp8( kit.x2_hi_to_f32(amax_2x), ) # For relu the max is at least 0 - if cutlass.const_expr(FUSE_RELU): - amax_r = cute.arch.fmax(amax_r, Float32(0.0)) + amax_r = cute.arch.fmax(amax_r, Float32(0.0)) else: # Compare the 2 packed abs max amax_r = cute.arch.fmax( @@ -756,7 +534,7 @@ def quantize_rowwise_mxfp8( # Since we need to do computation on individual f16 / bf16 elements, we can't read in pack sX_thread_rw = cute.make_tensor( sX_thread.iterator, - cute.make_layout((1, SCALE_DIM), stride=(0, 1)), + cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)), ) in_r = [[None] * PACK_SIZE for _ in range(WAVES)] bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group @@ -768,35 +546,35 @@ def quantize_rowwise_mxfp8( sA_thread = cute.composition(sA_tile, tv_layout)[tidx, None] sA_thread_rw = cute.make_tensor( sA_thread.iterator, - cute.make_layout((1, SCALE_DIM), stride=(0, 1)), + cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)), ) elif cutlass.const_expr(WITH_ACT): op = SUPPORTED_ACTIVATIONS[ACTIVATION] if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): kit_act = _packed16_kit(DTYPE) - amax_r = Float32(0.0) + for w in cutlass.range_constexpr(WAVES): - idx = (w * PACK_SIZE + offset) % SCALE_DIM - for e in cutlass.range_constexpr(PACK_SIZE): - x = Float32(sX_thread_rw[0, idx + e]) # grad + start = (w * PACK_SIZE + offset) % MXFP8_BLOCK_SIZE + for i in cutlass.range_constexpr(PACK_SIZE): + x = Float32(sX_thread_rw[0, start + i]) # grad if cutlass.const_expr(WITH_DACT): # out = grad · act'(act_input) - x = x * dop(Float32(sA_thread_rw[0, idx + e])) + x = x * dop(Float32(sA_thread_rw[0, start + i])) # If IS_ACT, apply activation function to x in f32 elif cutlass.const_expr(WITH_ACT): # If it's relu, we can handle it later if not cutlass.const_expr(FUSE_RELU): x = op(x) - # dbias: accumulate this row's column (idx+e) value BEFORE the bf16 - # truncation (matches CUDA's `thread_dbias_rowwise[j] += elt`). idx+e - # is a multiple-of-PACK_SIZE group + e, so it stays within [0, SCALE_DIM). - if cutlass.const_expr(DBIAS_REDUCTION): - dbias_acc[idx + e] = dbias_acc[idx + e] + x + # dbias: accumulate this row's column (start+e) value BEFORE the bf16 + # truncation (matches CUDA's `thread_dbias_rowwise[j] += elt`). start+i + # is a multiple-of-PACK_SIZE group + i, so it stays within [0, MXFP8_BLOCK_SIZE). + if cutlass.const_expr(WITH_DBIAS): + dbias_acc[start + i] += x # If 16-bit input with activation, truncate to IType if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): - x = kit_act.truncate_f32(x) # TODO: Why not just qunatize from f32? - in_r[w][e] = x + x = kit_act.truncate_f32(x) + in_r[w][i] = x if cutlass.const_expr(FUSE_RELU): amax_r = cute.arch.fmax(amax_r, x) # For relu cases, we don't need abs since negative values will be 0 so they lose comparison automatically else: @@ -804,23 +582,21 @@ def quantize_rowwise_mxfp8( if cutlass.const_expr(FUSE_RELU): amax_r = cute.arch.fmax(amax_r, Float32(0.0)) # If relu, the amax is at least 0 - # 2. E8M0 scale → gmem. mS_row's layout already encodes the swizzle - # when cfg.WITH_GEMM_SWIZZLED_SCALES=True, so 2D access just works. biased_exp_r = float_to_e8m0(amax_r * max_norm_rcp) + # mS_row_stage has logical shape (32, 2) and we have 64 threads where each is mapped to one scale factor - # The TV layout is equivalent to https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=tv-2-%2832%2C+2%29%3A%282%2C+1%29-%281%29 + # The TV layout is equivalent to TV layout with thr_layout=(32, 2):(2, 1), val_layout=(1,) # but it's too trival so let's just index it directly without using layout - # Note this is the logical layout, which is on top of the swizzled / non-swizzled scale factor layout that mappes the logical index to the physical offset - # Irregular shapes: skip the scale store if this thread's logical row / - # col-block lies past the input's actual extents. TMA already zero-fills - # OOB input reads and drops OOB output writes; only the direct scale-byte - # gmem store needs an explicit guard. + # Note this is the logical layout, which is on top of the swizzled / non-swizzled scale factor layout + # that mappes the logical index to the physical offset + + # For irregular shapes, skip the scale store if this thread's logical row / col-block lies past the input's actual extents. + # TMA already zero-fills OOB input reads and drops OOB output writes; only the direct scale-byte gmem store needs an explicit guard. scale_row = tile_row_start + tidx // 2 - scale_col_first_elt = tile_col_start + (tidx % 2) * SCALE_DIM + scale_col_first_elt = tile_col_start + (tidx % 2) * MXFP8_BLOCK_SIZE if scale_row < M and scale_col_first_elt < N: mS_row_stage[(tidx // 2, tidx % 2)] = Uint8(biased_exp_r) - # 3. scale + packed fp8 cast → smem as one u32 per wave. inv_scale_r = exp2f_rcp(biased_exp_r) # f32 reciprocal of the scale # Fetch the conversion function based on the FP8 format cvt_f32x2 = _cvt_f32x2_to_fp8x2(FP8_DTYPE) @@ -834,7 +610,7 @@ def quantize_rowwise_mxfp8( bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group offset = bank_group * 4 # Each bank group will write 4 fp8 to for w in cutlass.range_constexpr(WAVES): - idx = (w * 4 + offset) % SCALE_DIM + idx = (w * 4 + offset) % MXFP8_BLOCK_SIZE idx = idx // 4 if cutlass.const_expr(_row_fast): # One fused PTX per x2 pair: x2 × f32x2 → fp8x2. @@ -874,7 +650,7 @@ def quantize_colwise_mxfp8( SWIZZLE, TILE_X, TILE_Y, - SCALE_DIM, + MXFP8_BLOCK_SIZE, WITH_ACT=False, # forward: apply activation to the element WITH_DACT=False, # backward: out = grad · act'(act_input) sA_tile=None, # (TILE_Y, TILE_X) activation-input smem tile (dact only) @@ -882,15 +658,10 @@ def quantize_colwise_mxfp8( ): tidx, _, _ = cute.arch.thread_idx() - # print(f"sX_tile: {sX_tile}") - # print(f"sO_col_tile: {sO_col_tile}") - # print(f"mS_col_stage: {mS_col_stage}") - tiler, tv_layout = cute.make_layout_tv( thr_layout=cute.make_layout((1, TILE_X), stride=(TILE_X, 1)), - val_layout=cute.make_layout((SCALE_DIM, 1), stride=(1, 1)) + val_layout=cute.make_layout((MXFP8_BLOCK_SIZE, 1), stride=(1, 1)) ) - # print(f"tv_layout: {tv_layout}") sX_tv = cute.composition(sX_tile, tv_layout) sO_tv = cute.composition(sO_col_tile, tv_layout) @@ -899,19 +670,15 @@ def quantize_colwise_mxfp8( sX_thread = sX_tv[tidx, None] sO_thread = sO_tv[tidx, None] - # See https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=tv-2-%281%2C+64%29%3A%2864%2C+1%29-%2832%2C+1%29%3A%281%2C+1%29 - # print(f"sX_thread: {sX_thread}") # shape (32,) bf16 - # print(f"sO_thread: {sO_thread}") # shape (32,) uint8 - # dbias needs the per-element fp32 values to sum, so it takes the f32 path # (never the i16 fast path) — matching CUDA, whose f16 fast path requires # `!IS_DBIAS` (quantize_mxfp8.cuh:219). - HALF_PRECISION_PATH = _is_packed16(DTYPE) and ACTIVATION is None and not WITH_DBIAS + USE_HALF_PRECISION = _is_packed16(DTYPE) and ACTIVATION is None and not WITH_DBIAS dbias_partial = Float32(0.0) # 0. Load the 32-element column from smem into registers once (matches # C++'s `in_colwise_IType[i]` cache). Amax and cast both reuse these. - if cutlass.const_expr(HALF_PRECISION_PATH): + if cutlass.const_expr(USE_HALF_PRECISION): kit = _packed16_kit(DTYPE) # Per-thread Int16 view of the column. Same byte address as # `sX_thread` (bf16/fp16 are 16-bit, same width as Int16); the @@ -919,10 +686,10 @@ def quantize_colwise_mxfp8( # TILE_X apart in the row-major tile. sX_thread_i16 = cute.make_tensor( cute.recast_ptr(sX_thread.iterator, dtype=Int16), - cute.make_layout((SCALE_DIM,), stride=(TILE_X,)), + cute.make_layout((MXFP8_BLOCK_SIZE,), stride=(TILE_X,)), ) amax_bits = Int16(0) - for i in cutlass.range_constexpr(SCALE_DIM): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): amax_bits = kit.abs_max_scalar(amax_bits, sX_thread_i16[i]) amax_c = fabs_f32(kit.bits_to_f32(amax_bits)) else: @@ -931,41 +698,39 @@ def quantize_colwise_mxfp8( # Float32 would not widen; it would reinterpret the 16-bit bytes # as half of a 32-bit float). sX_thread_f32 = cute.make_rmem_tensor( - layout_or_shape=cute.make_layout((SCALE_DIM,), stride=(1,)), + layout_or_shape=cute.make_layout((MXFP8_BLOCK_SIZE,), stride=(1,)), dtype=Float32, ) - for i in cutlass.range_constexpr(SCALE_DIM): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): sX_thread_f32[i] = Float32(sX_thread[i]) # Apply activation (fwd) or grad·act'(act_input) (bwd dact) in f32. if cutlass.const_expr(WITH_DACT): dop = SUPPORTED_DACTIVATIONS[ACTIVATION] sA_thread = cute.composition(sA_tile, tv_layout)[tidx, None] - for i in cutlass.range_constexpr(SCALE_DIM): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): sX_thread_f32[i] = sX_thread_f32[i] * dop(Float32(sA_thread[i])) elif cutlass.const_expr(WITH_ACT): op = SUPPORTED_ACTIVATIONS[ACTIVATION] - for i in cutlass.range_constexpr(SCALE_DIM): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): sX_thread_f32[i] = op(sX_thread_f32[i]) # dbias = column sum of the (post-act/dact) value, taken BEFORE the bf16 # truncation — matches CUDA's `partial_dbias_colwise += elt`. if cutlass.const_expr(WITH_DBIAS): - for i in cutlass.range_constexpr(SCALE_DIM): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): dbias_partial += sX_thread_f32[i] # Numerical truncation through IType so amax/cast match C++. # Only needed when 16-bit input + activation; without activation # the widening was already exact. if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): kit_act = _packed16_kit(DTYPE) - for i in cutlass.range_constexpr(SCALE_DIM): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): sX_thread_f32[i] = kit_act.truncate_f32(sX_thread_f32[i]) amax_c = Float32(0.0) - for i in cutlass.range_constexpr(SCALE_DIM): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): amax_c = cute.arch.fmax(amax_c, fabs_f32(sX_thread_f32[i])) - # 2. E8M0 scale → gmem. mS_col's layout already encodes the swizzle - # when cfg.WITH_GEMM_SWIZZLED_SCALES=True, so 2D access just works. # Irregular shapes: skip when this stage's row range or this thread's - # column lies past the input extents. TILE_Y == SCALE_DIM so each stage + # column lies past the input extents. TILE_Y == MXFP8_BLOCK_SIZE so each stage # is exactly one scale-row; valid iff `tile_row_start < M`. biased_exp_c = float_to_e8m0(amax_c * max_norm_rcp) scale_col = tile_col_start + tidx @@ -975,17 +740,17 @@ def quantize_colwise_mxfp8( else: mS_col_stage[(0, tidx)] = Uint8(biased_exp_c) - # 3. scale + FP8 cast → smem (one byte per (row, tidx)). Caller - # flushes the whole (TILE_Y, TILE_X) tile with a TMA S2G. inv_scale_c = exp2f_rcp(biased_exp_c) cvt_to_fp8 = _cvt_f32_to_fp8(FP8_DTYPE) - if cutlass.const_expr(HALF_PRECISION_PATH): + if cutlass.const_expr(USE_HALF_PRECISION): kit_cast = _packed16_kit(DTYPE) - for i in cutlass.range_constexpr(SCALE_DIM): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): v_f32 = kit_cast.bits_to_f32(sX_thread_i16[i]) sO_thread[i] = Uint8(cvt_to_fp8(v_f32 * inv_scale_c)) else: - for i in cutlass.range_constexpr(SCALE_DIM): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): sO_thread[i] = Uint8(cvt_to_fp8(sX_thread_f32[i] * inv_scale_c)) + # Return this stage's per-column partial alongside amax; the caller accumulates + # it across stages (a scalar can't be updated in-place through the arg). return amax_c, dbias_partial diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index 2a57bfb4f4..a9839ca73d 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -8,9 +8,9 @@ values, quantize to MXFP8 format (FP8E4M3 data + E8M0 per-block scales). Matches the C++ kernel's tile dimensions and thread layout: - CHUNK_DIM_Y = 64, CHUNK_DIM_X = 64, THREADS_PER_CHUNK = 64 + CHUNK_DIM_Y = 64, CHUNK_DIM_X = 64, THREADS_PER_CTA = 64 BUFF_DIM_Y = 32, BUFF_DIM_X = 64, STAGES = 2 - SCALE_DIM = 32 (elements per MXFP8 scaling block) + MXFP8_BLOCK_SIZE = 32 (elements per MXFP8 scaling block) Grid: (ceil(N / 64), ceil(M / 64)) Each block processes a 64x64 chunk in 2 stages of 32x64 tiles loaded into @@ -18,22 +18,16 @@ """ import logging -import transformer_engine from transformer_engine.common.CuTeDSL.utils import str_to_cutlass_dtype -import transformer_engine_torch as tex from typing import Optional, Type -import torch -import transformer_engine_torch as tex - import cutlass import cutlass.cute as cute import cutlass.pipeline as pipeline -from cutlass import Float32, Int64, Int32, Int16, Uint8, Uint32 +from cutlass import Float32, Int32, Uint8 from cuda.bindings.driver import CUstream -import hashlib import tvm_ffi from .mxfp8_utils import ( @@ -42,35 +36,24 @@ FP8E4M3_MAX_NORM_RCP, FP8E5M2_MAX_NORM_RCP, _bitcast_f32_to_i32, - _cvt_f32_to_fp8, - _cvt_f32x2_to_fp8x2, - _is_packed16, - _packed16_kit, - exp2f_rcp, - fabs_f32, - float_to_e8m0, quantize_colwise_mxfp8, quantize_rowwise_mxfp8, ) -# Per-backend logger, so a fallback warning is attributable to *this* CuTeDSL -# backend (the MXFP8 quantize backend). Other CuTeDSL backends should use their -# own `transformer_engine.cutedsl.` logger. logger = logging.getLogger("transformer_engine.cutedsl.mxfp8") # MXFP8 settings MXFP8_BLOCK_SIZE = 32 # Number of elements per MXFP8 scale block. They will share the same E8M0 scale factor -SCALE_DIM = MXFP8_BLOCK_SIZE # Double-buffering for async copy + compute overlap BUFFER_NUM = 2 # Vectorised access constants for bank-conflict avoidance (rowwise pass) PACK_SIZE = 4 # Elements per vector load -WAVES = SCALE_DIM // PACK_SIZE # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total +WAVES = MXFP8_BLOCK_SIZE // PACK_SIZE # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total THREADS_PER_WARP = 32 TOTAL_BANKS_WIDTH = (32 * 4) // 1 # 32 banks × 4 bytes, in bytes (uint8 stride) -THREADS_PER_BANK = TOTAL_BANKS_WIDTH // SCALE_DIM # 4 threads per bank +THREADS_PER_BANK = TOTAL_BANKS_WIDTH // MXFP8_BLOCK_SIZE # 4 threads per bank # Tiling sizes NUM_STAGES = 2 # Pipeline depth of the producer/consumer ring buffer for the TMA-G2S input loads (PipelineTmaAsync stage count) @@ -79,13 +62,13 @@ TILE_X = 64 # Each tile has 64 columns # CTA size -THREADS_PER_CHUNK = 64 -NUM_WARPS = THREADS_PER_CHUNK // 32 +THREADS_PER_CTA = 64 +NUM_WARPS = THREADS_PER_CTA // 32 -# --------------------------------------------------------------------------- -# Kernel configuration -# --------------------------------------------------------------------------- class MXFP8QuantizeConfig: + """Configs for the compiled CuTeDSL kernel. These will be fixed once the kernel is compiled and + they will behave as const expressions. + """ def __init__( self, @@ -152,29 +135,18 @@ def __str__(self): __repr__ = __str__ -# --------------------------------------------------------------------------- -# Unified MXFP8 quantization kernel — shared memory tiled, single-pass -# --------------------------------------------------------------------------- class MXFP8QuantizeSmemKernel: - """MXFP8 quantization with shared-memory tiling (rowwise, colwise, or both). - - Matches C++ kernel's BIDIMENSIONAL scaling mode: - Grid (ceil(N/64), ceil(M/64)) - Block (64) - Each block processes a 64x64 chunk in 2 stages of 32x64. - - Per stage, the tile is loaded into shared memory once. The colwise - pass reads columns from smem first, then the rowwise pass reads rows. - When both directions are enabled, global memory is read only once per - element — matching the C++ single-pass behaviour. - - Thread mappings (per stage): - Colwise: thread tidx handles column tidx, 32 rows (stride BUFF_DIM_X). - Rowwise: tid_Y = tidx // 2 -> row, tid_X = tidx % 2 -> scale-block. + """The MXFP8 quantization kernel that mirrors the standard (non-specialized) MXFP8 CUDA C++ quantization kernel + with multiple fusions (activation, dbias, etc.). + `__call__` method is the entrypoint which is AOT compiled. `self` will be captured so it's fixed per compiled kernel """ def __init__(self, cfg): self.cfg = cfg + # We prefer to do dbias reduction in colwise which is easier (no cross-thread reduction needed). + # Only do rowwise reduction when we don't quantize columnwisely when WITH_DBIAS is True. + self.DBIAS_REDUCTION_COLWISE = cfg.WITH_DBIAS and cfg.COLWISE + self.DBIAS_REDUCTION_ROWWISE = cfg.WITH_DBIAS and not cfg.COLWISE @cute.jit def __call__( @@ -182,41 +154,28 @@ def __call__( mX: cute.Tensor, # Input tensor to quantize mO_row: Optional[cute.Tensor], mS_row: Optional[cute.Tensor], # Rowwise output and scale tensors mO_col: Optional[cute.Tensor], mS_col: Optional[cute.Tensor], # Colwise output and scale tensors - mAmax: Optional[cute.Tensor], # Global amax accumulator, only used in WITH_AMAX path - mNoop: Optional[cute.Tensor], # 1-element cast_noop flag, only used in WITH_NOOP path - # Backward-only slots, present to mirror the CUDA mxfp8::quantize signature - # (act_input / dbias / workspace). NOT used yet — None on the forward path; - # WITH_DACT/WITH_DBIAS configs are rejected upstream so these never carry data. - mActInput: Optional[cute.Tensor], - mDbias: Optional[cute.Tensor], - mWorkspace: Optional[cute.Tensor], + mAmax: Optional[cute.Tensor], # Global amax accumulator, only used when WITH_AMAX is True + mNoop: Optional[cute.Tensor], # 1-element cast_noop flag, only used when WITH_NOOP is True + mDActInput: Optional[cute.Tensor], # Activation input for activation derivative fusion, only used when WITH_DACT is True + mWorkspace: Optional[cute.Tensor], # Workspace for the dbias reduction, only used when WITH_DBIAS is True stream: CUstream, ): M = mX.shape[0] N = mX.shape[1] cfg = self.cfg max_norm_rcp = cfg.MAX_NORM_RCP - num_scale_cols = N // SCALE_DIM - num_scale_rows = M // SCALE_DIM + num_scale_cols = N // MXFP8_BLOCK_SIZE + num_scale_rows = M // MXFP8_BLOCK_SIZE - # Rewrap mS_row / mS_col with the GEMM-swizzled layout when requested. - # Wrapper passes in a tensor with the compact (M, N/32):(N/32, 1) layout - # (built from a compact fake-ptr at compile time), and we re-view the - # underlying buffer here so the per-block scale stores below land at the - # cuBLAS-swizzled byte offsets. - # See https://docs.nvidia.com/cuda/cublas/#d-block-scaling-factors-layout - # and swizzle_demo.svg for a visual of the byte permutation. + # If WITH_GEMM_SWIZZLED_SCALES is enabled, the output must satisfy cublas's swizzled layout + # This is expressed as a CuTe layout applied to the output tensor so it can be transparent throughout the kernel implementation. + # See https://docs.nvidia.com/cuda/cublas/#d-block-scaling-factors-layout for more details. if cutlass.const_expr(cfg.WITH_GEMM_SWIZZLED_SCALES): num_tiles_M = (M + 127) // 128 - num_tiles_SC = (num_scale_cols + 3) // 4 # = ceil(N / 128) - num_tiles_SR = (num_scale_rows + 3) // 4 # = ceil(M / 128) + num_tiles_SC = (num_scale_cols + 3) // 4 + num_tiles_SR = (num_scale_rows + 3) // 4 num_tiles_N = (N + 127) // 128 - # row i = i_lo + 32 * (i_hi + 4 * tile_Y); col j = j_lo + 4 * tile_X. - # Within one 128×4 tile: byte = i_lo*16 + i_hi*4 + j_lo. - # Tile-major outer dims add (tile_Y * num_tiles_SC + tile_X) * 512. - # For example, if M=256, N=512, then num_scale_cols = 16, num_scale_rows = 8, and num_tiles_M=2, num_tiles_SC=4, num_tiles_SR=2, num_tiles_N=4 - # The swizzled layout is ((32, 4, 2), (4, 4)):((16, 4, 2048), (1, 512)) if cutlass.const_expr(cfg.ROWWISE): mS_row = cute.make_tensor( mS_row.iterator, @@ -225,8 +184,6 @@ def __call__( stride=((16, 4, num_tiles_SC * 512), (1, 512)), ), ) - # Colwise: same swizzle, axes swap roles — col axis gets the 32×4 - # inner decomp, scale-row axis gets the 4-extent dim. if cutlass.const_expr(cfg.COLWISE): mS_col = cute.make_tensor( mS_col.iterator, @@ -235,44 +192,26 @@ def __call__( stride=((1, 512), (16, 4, num_tiles_SR * 512)), ), ) - - # Divide by the STAGE tile (TILE_Y, TILE_X // SCALE_DIM), not the CTA - # tile. Each CTA owns NUM_TILES consecutive row-tiles; the kernel walks - # them by indexing GRID's row dim with `bidy * NUM_TILES + stage` (cute - # auto-decomposes a flat coord onto GRID's hierarchical row modes). - # - # Critically, this is the only divide that cleanly cuts both layouts: - # - compact `(M, N/32):(N/32, 1)` → SCALE_TILE = (32, 2):(N/32, 1) - # - swizzled `((32,4,n_M),(4,n_SC)):((16,4,n_SC·512),(1,512))` - # → SCALE_TILE = (32, 2):(16, 1) - # The bigger (TILE_Y * NUM_TILES, ...) divide we used before tangles the - # swizzle's (32, 4) row hierarchy under flatten + sub-divide chain. - - # Declare TMA descriptors on the host side. - # make_tiled_tma_atom returns the UNTILED gmem tensor with basis strides. - # Tile it inside the kernel with zipped_divide so each coord selects - # one (TILE_Y, TILE_X) tile. + + # We have 2 stages in our pipeline where each stage loads / computes a (TILE_Y, TILE_X) tile smem_tile_layout = cute.make_ordered_layout((TILE_Y, TILE_X), order=(1, 0)) cta_tiler = (TILE_Y, TILE_X) - - # Input: TMA G2S (bf16/fp16 → smem). + + # Input TMA atoms op_load = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() tma_atom, tma_src = cute.nvgpu.cpasync.make_tiled_tma_atom( op_load, mX, smem_tile_layout, cta_tiler, num_multicast=1, ) - # Backward (dact): the activation input is a second G2S load, identical to - # mX's. The kernel computes `grad · act'(act_input)`; here mX carries grad. + # Activation input TMA atoms for activation derivative fusion tma_atom_act = None tma_src_act = None if cutlass.const_expr(cfg.WITH_DACT): tma_atom_act, tma_src_act = cute.nvgpu.cpasync.make_tiled_tma_atom( - op_load, mActInput, smem_tile_layout, cta_tiler, num_multicast=1, + op_load, mDActInput, smem_tile_layout, cta_tiler, num_multicast=1, ) - # Output: TMA S2G (uint8 smem → gmem) for both directions. Creating - # both atoms unconditionally — if a direction is disabled the kernel - # simply won't dispatch its copy, and the atom cost is negligible. + # Output TMA atoms op_store = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() out_smem_layout = cute.make_ordered_layout((TILE_Y, TILE_X), order=(1, 0)) tma_atom_out_row = None @@ -287,24 +226,12 @@ def __call__( tma_atom_out_col, tma_dst_out_col = cute.nvgpu.cpasync.make_tiled_tma_atom( op_store, mO_col, out_smem_layout, cta_tiler, num_multicast=1, ) - - # Decide when to perform dbias reduction - DBIAS_REDUCTION_COLWISE: cutlass.Constexpr = False - DBIAS_REDUCTION_ROWWISE: cutlass.Constexpr = False - if cutlass.const_expr(cfg.WITH_DBIAS): - # We prefer to perform dbias reduction in the colwise pass since it doesn't require shuffle - if cutlass.const_expr(cfg.COLWISE): - DBIAS_REDUCTION_COLWISE = True - else: - DBIAS_REDUCTION_ROWWISE = True - # CUDA launches in (0,0), (1,0), (2,0)... order, so we should make N the leading dimension for better access pattern - # So consecutive blocks will move along the N dimension first, which is the innermost dimension in memory and we can use cache better grid = [ cute.ceil_div(Int32(N), TILE_X), cute.ceil_div(M, TILE_Y * NUM_TILES), ] - block = [THREADS_PER_CHUNK,] + block = [THREADS_PER_CTA,] self.kernel( mX, mS_row, mS_col, mAmax, mNoop, mWorkspace, @@ -319,10 +246,6 @@ def __call__( stream=stream, ) - # Device entry (launched by __call__). Reads the cast_noop flag and runs the - # work only if it is not set — matching the CUDA kernel's - # `if (noop[0]==1.0f) return;`. When WITH_NOOP is off, mNoop is None and the - # whole check is compiled out (so no flag is read). @cute.kernel def kernel( self, @@ -334,14 +257,14 @@ def kernel( mWorkspace, max_norm_rcp, dtype: cutlass.Constexpr[Type[cutlass.Numeric]], - tma_atom, tma_src, # how to use TMA to copy the input - tma_atom_out_row, tma_dst_out_row, # how to use TMA to copy the rowwise output - tma_atom_out_col, tma_dst_out_col, # how to use TMA to copy the colwise output - tma_atom_act, tma_src_act, # dact only: how to copy the activation input + tma_atom, tma_src, # Input TMA atoms + tma_atom_out_row, tma_dst_out_row, # Rowwise output TMA atoms + tma_atom_out_col, tma_dst_out_col, # Colwise output TMA atoms + tma_atom_act, tma_src_act, # Activation derivative TMA atoms, or None if WITH_DACT is False ): cfg = self.cfg - # `not const_expr(WITH_NOOP)` is a compile-time True when noop is disabled, - # so Python short-circuits the `or` and never reads mNoop[0] (it is None). + # If the noop tensor is not passed (compile-time check), or the noop tensor is not 1.0 (run-time check) + # then we run the kernel for real. Otherwise, skip the quantization so this kernel becomes a no-op. if not cutlass.const_expr(cfg.WITH_NOOP) or mNoop[0] != Float32(1.0): self._kernel_main( mX, mS_row, mS_col, mAmax, mWorkspace, @@ -352,10 +275,6 @@ def kernel( tma_atom_act, tma_src_act, ) - # The actual quantize work. MUST be @cute.jit (not @cute.kernel): it is invoked - # from the @cute.kernel `kernel` wrapper under a runtime noop branch, and only a - # separately-traced @cute.jit callable may allocate shared memory inside such a - # branch (an inlined/undecorated method or a nested @cute.kernel would fail). @cute.jit def _kernel_main( self, @@ -366,30 +285,19 @@ def _kernel_main( mWorkspace, max_norm_rcp, dtype: cutlass.Constexpr[Type[cutlass.Numeric]], - tma_atom, tma_src, # how to use TMA to copy the input - tma_atom_out_row, tma_dst_out_row, # how to use TMA to copy the rowwise output - tma_atom_out_col, tma_dst_out_col, # how to use TMA to copy the colwise output - tma_atom_act, tma_src_act, # dact only: how to copy the activation input + tma_atom, tma_src, # Input TMA atoms + tma_atom_out_row, tma_dst_out_row, # Rowwise output TMA atoms + tma_atom_out_col, tma_dst_out_col, # Colwise output TMA atoms + tma_atom_act, tma_src_act, # Activation derivative TMA atoms, or None if WITH_DACT is False ): cfg = self.cfg if cutlass.const_expr(cfg.ROWWISE): - mS_row = cute.zipped_divide(mS_row, (TILE_Y, TILE_X // SCALE_DIM)) + mS_row = cute.zipped_divide(mS_row, (TILE_Y, TILE_X // MXFP8_BLOCK_SIZE)) if cutlass.const_expr(cfg.COLWISE): - mS_col = cute.zipped_divide(mS_col, (TILE_Y // SCALE_DIM, TILE_X)) - # For M=256, N=512: - # Non-swizzled: https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=zipped_divide-%28256%2C+16%29%3A%2816%2C+1%29-32%0A2 - # Swizzled: https://kainzhong.github.io/CuTe-Layout-Visualizer/?key=zipped_divide-%28%2832%2C+4%2C+2%29%2C+%284%2C+4%29%29%3A%28%2816%2C+4%2C+2048%29%2C+%281%2C+512%29%29-32%0A2 - # print(f"mS_row after zipped_divide: {mS_row}") - - # FP8 output smem, one 32×64 tile per stage per enabled direction. - # Allocating a dead sO_col in rowwise-only (or sO_row in colwise-only) - # bumps per-CTA smem from 12 KB to 16 KB, which drops occupancy and - # regresses the single-direction path by ~8-10% at 16384^2. Match - # C++ and only allocate what the active pass actually uses. - # sAmax holds one f32 per warp for the cross-warp amax reduction — - # negligible (8 bytes for NUM_WARPS=2) and we always allocate so the - # struct doesn't fork on a 4th const-expr (cfg.WITH_AMAX) dimension. + mS_col = cute.zipped_divide(mS_col, (TILE_Y // MXFP8_BLOCK_SIZE, TILE_X)) + + # Allocate shared memory for the input and rowwise / columnwise outputs if cutlass.const_expr(cfg.ROWWISE and cfg.COLWISE): @cute.struct class SharedStorage: @@ -439,43 +347,8 @@ class SharedStorage: sAmax: cute.struct.MemRange[Float32, NUM_WARPS] smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) - - # dact: the activation-input tile lives in its own smem buffer, same - # shape/layout as sX. Allocated separately so the 4 SharedStorage variants - # above don't have to fork again on WITH_DACT. - if cutlass.const_expr(cfg.WITH_DACT): - @cute.struct - class DactStorage: - sActInput: cute.struct.Align[ - cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 - ] - dact_storage = smem.allocate(DactStorage) - sActInput = dact_storage.sActInput.get_tensor( - cute.make_layout( - ((TILE_Y, TILE_X), NUM_STAGES), - stride=((TILE_X, 1), TILE_Y * TILE_X), - ) - ) - - # Rowwise-only dbias needs a cross-thread (over THREADS_Y) smem reduction, - # since each rowwise thread owns a row, not a column. Buffer is - # [THREADS_Y][THREADS_X*(SCALE_DIM+1)] f32 — the +1 per scale-block padding - # avoids bank conflicts, matching CUDA's DBIAS_BUFF_WIDTH. - DBIAS_REDUCTION_ROWWISE = cutlass.const_expr(cfg.WITH_DBIAS and not cfg.COLWISE) - DBIAS_BUFF_WIDTH = (TILE_X // SCALE_DIM) * (SCALE_DIM + 1) - if cutlass.const_expr(DBIAS_REDUCTION_ROWWISE): - @cute.struct - class DbiasStorage: - sDbias: cute.struct.MemRange[Float32, TILE_Y * DBIAS_BUFF_WIDTH] - dbias_storage = smem.allocate(DbiasStorage) - sDbias = dbias_storage.sDbias.get_tensor( - cute.make_layout(TILE_Y * DBIAS_BUFF_WIDTH) - ) - - # Per-stage shmem tile is 2D (TILE_Y, TILE_X); stages laid out back-to-back. - # Mode 0 is hierarchical ((TILE_Y, TILE_X),) so it matches the rank/shape - # of gX_tiled[(None, (ty, tx))] produced by zipped_divide. - # sX[(None, stage)] selects one (TILE_Y, TILE_X) tile. + # Apply the layout to the allocated shared memory buffers so the first rank is the tile (nested layout) + # and the second rank is the pipeline stage sX = storage.sX.get_tensor( cute.make_layout( ((TILE_Y, TILE_X), NUM_STAGES), @@ -497,10 +370,26 @@ class DbiasStorage: ) ) + # Allocate shared memory for the activation input used for the activation derivative fusion. + if cutlass.const_expr(cfg.WITH_DACT): + @cute.struct + class DactStorage: + sActInput: cute.struct.Align[ + cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + ] + dact_storage = smem.allocate(DactStorage) + # Apply the same layout as the input + sActInput = dact_storage.sActInput.get_tensor( + cute.make_layout( + ((TILE_Y, TILE_X), NUM_STAGES), + stride=((TILE_X, 1), TILE_Y * TILE_X), + ) + ) + warp_idx = cute.arch.warp_idx() warp_idx = cute.arch.make_warp_uniform(warp_idx) - # Prefetch TMA descriptor (one-time; warp-0 only). + # Prefetch TMA descriptors if warp_idx == 0: cute.nvgpu.cpasync.prefetch_descriptor(tma_atom) if cutlass.const_expr(cfg.WITH_DACT): @@ -509,18 +398,15 @@ class DbiasStorage: tidx, _, _ = cute.arch.thread_idx() bidx, bidy, _ = cute.arch.block_idx() - # Producer: `arrive_and_expect_tx` is wrapped in `elect_one`, so only - # one lane of warp 0 arrives on the full barrier per stage → arrive_count=1. + # Only warp 0 is the producer (issues TMA) producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) - # Consumer: `consumer_release` arrives only on the `is_signalling_thread` - # (lane 0 of each warp), so arrive_count = num_warps per stage. - num_warps = THREADS_PER_CHUNK // 32 - consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_warps) + # Every warp is the consumer (reads the data loaded by TMA) + consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, NUM_WARPS) # Bytes transferred per TMA copy: one (TILE_Y, TILE_X) tile of dtype. + tx_count = TILE_Y * TILE_X * dtype.width // 8 # dact loads two tiles (grad + act_input) under the same per-stage barrier, # so the barrier must expect both copies' bytes. - tx_count = TILE_Y * TILE_X * dtype.width // 8 if cutlass.const_expr(cfg.WITH_DACT): tx_count *= 2 @@ -560,7 +446,7 @@ class DbiasStorage: gX_tiled, ) - # dact: identical partition for the activation-input load. + # If WITH_DACT, partition the activation input for TMA as well in the same way if cutlass.const_expr(cfg.WITH_DACT): gA_tiled = cute.zipped_divide(tma_src_act, (TILE_Y, TILE_X)) tXsA, tXgA = cute.nvgpu.cpasync.tma_partition( @@ -571,7 +457,7 @@ class DbiasStorage: gA_tiled, ) - # Same partitioning for S2G outputs: sO_row → mO_row and sO_col → mO_col. + # Partitioning for rowwise / columnwise outputs if cutlass.const_expr(cfg.ROWWISE): gO_row_tiled = cute.zipped_divide(tma_dst_out_row, (TILE_Y, TILE_X)) tXsO_row, tXgO_row = cute.nvgpu.cpasync.tma_partition( @@ -591,11 +477,6 @@ class DbiasStorage: gO_col_tiled, ) - # print(f"sX: {sX}\n") - # print(f"gX_tiled: {gX_tiled}\n") - # print(f"tXsX: {tXsX}\n") - # print(f"tXgX: {tXgX}\n") - # Ensure barrier init is visible to all threads before the pipeline is used. cute.arch.sync_threads() @@ -620,45 +501,39 @@ class DbiasStorage: mainloop_pipeline.producer_commit(prod_state) prod_state.advance() - # Per-thread amax accumulator across all stages of this CTA. Combined - # with the per-warp redux + cross-warp shmem reduce + atomic at the - # bottom to produce a global max(|x|) in mAmax. Initialised to 0 - # since amax is non-negative. + # Per-thread amax accumulator if cutlass.const_expr(cfg.WITH_AMAX): - block_amax = Float32(0.0) - - # Per-thread partial dbias: thread tidx owns column tidx of the colwise - # tile and accumulates its column sum over this CTA's rows (both stages). - # Written to workspace[bidy, col] below; reduced over row-blocks separately. - if cutlass.const_expr(cfg.WITH_DBIAS): - block_dbias = Float32(0.0) - # Rowwise-only dbias: each thread holds per-column partials for its 32-col - # block, summed across stages, then cross-thread reduced (over THREADS_Y) - # into block_dbias after the loop. - if cutlass.const_expr(DBIAS_REDUCTION_ROWWISE): - rowwise_dbias_arr = cute.make_rmem_tensor( - layout_or_shape=cute.make_layout((SCALE_DIM,), stride=(1,)), + per_thread_amax = Float32(0.0) + + # Prepare thread-level register accumulators for rowwise dbias reduction. + # Each thread will process two (1, MXFP8_BLOCK_SIZE) rows in two stages, and in each stage the thread will add the + # (after dact applied) value to this register array with the same shape so it carries the the two stages' partial sum. + # Then it will be written to a SMEM buffer to let the whole CTA do the reduction separately to yield + # the final (1, TILE_X) dbias workspace output. + rowwise_dbias_acc = None + if cutlass.const_expr(self.DBIAS_REDUCTION_ROWWISE): + rowwise_dbias_acc = cute.make_rmem_tensor( + layout_or_shape=cute.make_layout((MXFP8_BLOCK_SIZE,), stride=(1,)), dtype=Float32, ) - for c in cutlass.range_constexpr(SCALE_DIM): - rowwise_dbias_arr[c] = Float32(0.0) + # Zero the accumulator registers. + for c in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + rowwise_dbias_acc[c] = Float32(0.0) + block_dbias = Float32(0.0) + # Prepare thread-level register for columnwise dbias reduction. + # Each thread will process two (MXFP8_BLOCK_SIZE, 1) columns in two stages, and in each stage the thread will reduce the + # (after dact applied) column to (1,) and add to this register. + # Then this partial sum scalar will be written to the GMEM workspace buffer directly. + if cutlass.const_expr(self.DBIAS_REDUCTION_COLWISE): + block_dbias = Float32(0.0) # ---- Consumer: all threads quantize each completed tile. ---- for stage in cutlass.range(num_tiles, unroll=1): mainloop_pipeline.consumer_wait(cons_state) - sX_tile = sX[(None, stage)] # (TILE_Y, TILE_X) bf16 (grad for dact) + sX_tile = sX[(None, stage)] sActInput_tile = None if cutlass.const_expr(cfg.WITH_DACT): - sActInput_tile = sActInput[(None, stage)] # (TILE_Y, TILE_X) act_input - - """ - grid = [ - cute.ceil_div(Int32(N), TILE_X), - cute.ceil_div(M, TILE_Y * NUM_TILES), - ] - So to obtain the tile that belongs to this CTA. - """ - # This is just block's x axis idx + sActInput_tile = sActInput[(None, stage)] tile_idx_x = bidx # Each CTA has `NUM_TILES` tiles. Each stage we need to obtain the tile for that specific stage. # So the tile index along Y dimension is `bidy * NUM_TILES + stage` @@ -676,8 +551,8 @@ class DbiasStorage: sActInput_tile, ) if cutlass.const_expr(cfg.WITH_AMAX): - block_amax = cute.arch.fmax(block_amax, amax_c) - if cutlass.const_expr(cfg.WITH_DBIAS): + per_thread_amax = cute.arch.fmax(per_thread_amax, amax_c) + if cutlass.const_expr(self.DBIAS_REDUCTION_COLWISE): block_dbias += dbias_c if cutlass.const_expr(cfg.ROWWISE): sO_row_tile = sO_row[(None, stage)] @@ -698,22 +573,20 @@ class DbiasStorage: mS_row_stage, max_norm_rcp, tile_idx_y * TILE_Y, bidx * TILE_X, M, N, sActInput_tile, - rowwise_dbias_arr if cutlass.const_expr(DBIAS_REDUCTION_ROWWISE) else None, + rowwise_dbias_acc, ) if cutlass.const_expr(cfg.WITH_AMAX): - block_amax = cute.arch.fmax(block_amax, amax_r) + per_thread_amax = cute.arch.fmax(per_thread_amax, amax_r) - # Make all smem stores (sO_row and/or sO_col) visible to the TMA - # async proxy, then block-sync so warp 0 sees the fences from all - # warps before issuing the bulk store(s). Matches the C++ - # reference's fence_proxy + __syncthreads pattern. + # Make the shared-memory writes visible to the TMA's async proxy before the TMA reads them. cute.arch.fence_proxy( "async.shared", space="cta", ) cute.arch.sync_threads() + # Warp 0 issues TMA copy to write the quantized output tile from shared memory to global memory and then commits if warp_idx == 0: tile_y = bidy * NUM_TILES + stage if cutlass.const_expr(cfg.ROWWISE): @@ -733,48 +606,55 @@ class DbiasStorage: mainloop_pipeline.consumer_release(cons_state) cons_state.advance() - # Wait for in-flight TMA stores so data is visible to the host - # before the kernel returns. - cute.arch.cp_async_bulk_wait_group(0, read=False) + # Complete the cross-thread dbias reduction after each thread has its own per-thread partial sum after the rowwise quantization. + if cutlass.const_expr(self.DBIAS_REDUCTION_ROWWISE): + # Allocate the SMEM buffer that all threads use to reduce the two-stage partial sum (per thread) to the + # partial sum (per block). - # ---- rowwise-only dbias: cross-thread reduction over THREADS_Y --------- - # In the rowwise pass each thread owns a row, so its rowwise_dbias_arr holds - # per-column partials for its 32-col block. Transpose through smem so thread - # tidx ends up owning column tidx of the chunk (mirrors CUDA's - # partial_dbias_rowwise smem buffer + reduce over THREADS_Y). - if cutlass.const_expr(DBIAS_REDUCTION_ROWWISE): - THREADS_X = TILE_X // SCALE_DIM # scale-blocks per row (=2) - tid_Y = tidx // THREADS_X - tid_X = tidx % THREADS_X - for c in cutlass.range_constexpr(SCALE_DIM): - sDbias[tid_Y * DBIAS_BUFF_WIDTH + tid_X * (SCALE_DIM + 1) + c] = \ - rowwise_dbias_arr[c] + # Pad the buffer to avoid bank conflicts. The logical shape is still the same. Only the stride is different. + DBIAS_BUFF_WIDTH = TILE_X // MXFP8_BLOCK_SIZE * (MXFP8_BLOCK_SIZE + 1) + @cute.struct + class DbiasStorage: + sDbias: cute.struct.MemRange[Float32, TILE_Y * DBIAS_BUFF_WIDTH] + dbias_storage = smem.allocate(DbiasStorage) + sDbias = dbias_storage.sDbias.get_tensor( + cute.make_layout((TILE_Y, TILE_X), stride=(DBIAS_BUFF_WIDTH, 1)), + ) + # Thread layout: (TILE_Y, 2); value layout: (1, MXFP8_BLOCK_SIZE) where TILE_X = 2 * MXFP8_BLOCK_SIZE + # And each thread writes the (1, MXFP8_BLOCK_SIZE) partial sum to this (TILE_Y, TILE_X) buffer + # and then each thread reads its (TILE_Y, 1) sDbias column and writes the reduced sum to the GMEM workspace. + # Since TILE_X == THREADS_PER_CTA, this column reduction yields (TILE_Y, TILE_X) -> (1, TILE_X). + _, tv_layout_dbias_write = cute.make_layout_tv( + thr_layout=cute.make_layout((TILE_Y, 2), stride=(2, 1)), + val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(MXFP8_BLOCK_SIZE, 1)), + ) + sDbias_write = cute.composition(sDbias, tv_layout_dbias_write) + # All threads write their per-thread partial sum results to the shared buffer. + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + sDbias_write[(tidx, i)] = rowwise_dbias_acc[i] cute.arch.sync_threads() - # thread tidx owns column tidx; +block skips the per-block padding slot. - block = tidx // SCALE_DIM + # All threads reduce the cross-thread partial sums to the per-block partial sum. + _, tv_layout_dbias_reduce = cute.make_layout_tv( + thr_layout=cute.make_layout((1, TILE_X), stride=(TILE_X, 1)), + val_layout=cute.make_layout((TILE_Y, 1), stride=(1, 1)) + ) + sDbias_reduce = cute.composition(sDbias, tv_layout_dbias_reduce) + # make_layout_tv yields a (thread, value) layout: thread=tidx -> column tidx, + # value=i -> row i. So index [tidx, i] (thread first), summing the column's rows. block_dbias = Float32(0.0) for i in cutlass.range_constexpr(TILE_Y): - block_dbias += sDbias[i * DBIAS_BUFF_WIDTH + tidx + block] + block_dbias += sDbias_reduce[tidx, i] - # ---- dbias: write this CTA's per-column partial to the workspace ------- - # Thread tidx owns column (bidx*TILE_X + tidx). Each CTA-row-block (bidy) - # contributes one row of the (blocks_Y, N) fp32 workspace; the reduction - # over blocks_Y to the final dbias[N] is a separate step. + # Write the per-tile reduced dbias to the global workspace. if cutlass.const_expr(cfg.WITH_DBIAS): dbias_col = bidx * TILE_X + tidx if dbias_col < N: mWorkspace[(bidy, dbias_col)] = block_dbias - # ---- amax block reduction + cross-CTA atomic ---------------------- - # 1) intra-warp: redux.sync.fmax.f32 (sm_80+, single instruction). - # 2) cross-warp: NUM_WARPS shmem floats + sync_threads. - # 3) cross-CTA: int-atomic-max on the f32 bit pattern. Since amax is - # always ≥ 0, IEEE-754 bit ordering on positives matches float - # magnitude ordering, so atomic_max on i32 bits gives the right - # result. (atomic_max_float32 also exists but its pointer - # normalisation is broken as of this CuTeDSL build.) if cutlass.const_expr(cfg.WITH_AMAX): - warp_amax = cute.arch.warp_redux_sync(block_amax, kind="fmax") + # Reduce and get the per-warp amax. + warp_amax = cute.arch.warp_redux_sync(per_thread_amax, kind="fmax") + # Write the per-warp amax to shared memory sAmax = storage.sAmax.get_tensor(cute.make_layout(NUM_WARPS)) lane_idx = tidx % 32 if lane_idx == 0: @@ -782,16 +662,22 @@ class DbiasStorage: cute.arch.sync_threads() if tidx == 0: cta_amax = Float32(0.0) + # The first thread reduces all the per-warp amax to the per-CTA amax for w in cutlass.range_constexpr(NUM_WARPS): cta_amax = cute.arch.fmax(cta_amax, sAmax[w]) amax_i32 = cute.make_tensor( cute.recast_ptr(mAmax.iterator, dtype=Int32), cute.make_layout(1), ) + # The first thread updates the global amax with an atomic max on the bitcasted float value cute.arch.atomic_max( amax_i32.iterator, _bitcast_f32_to_i32(cta_amax), ) + # Wait for in-flight TMA stores so data is visible to the host + # before the kernel returns. + cute.arch.cp_async_bulk_wait_group(0, read=False) + @cute.jit def _process_rowwise( self, @@ -803,22 +689,12 @@ def _process_rowwise( tile_col_start, # Int32 — global col of this CTA's col 0 M, N, # Int32 — full input extents, for OOB masking sActInput_tile=None, # (TILE_Y, TILE_X) act_input tile (dact only) - dbias_acc=None, # rmem Float32[SCALE_DIM] dbias accumulator (rowwise-only dbias) + dbias_acc=None, # rmem Float32[MXFP8_BLOCK_SIZE] dbias accumulator (rowwise-only dbias) ): - """Rowwise MXFP8 pass: thread `(tid_Y, tid_X) = (tidx % 32, tidx // 32)` - owns one 32-element scale block (row `tid_Y`, columns `tid_X*32 .. +32`). - - The bank-group swizzle `((w + bank_group) * PACK_SIZE) % SCALE_DIM` - staggers each 4-thread group's starting wave, which otherwise would - collide on smem banks since all lanes in a warp read different rows - at the same column offset. - - Writes quantized bytes into `sO_row_tile` as u32s (one per wave); - caller is responsible for the TMA S2G flush. - """ cfg = self.cfg return quantize_rowwise_mxfp8( sX_tile, + sActInput_tile, sO_row_tile, mS_row_stage, max_norm_rcp, @@ -828,19 +704,16 @@ def _process_rowwise( N, ACTIVATION=cfg.ACTIVATION, DTYPE=cfg.DTYPE, - ROWWISE=cfg.ROWWISE, - COLWISE=cfg.COLWISE, FP8_DTYPE=cfg.FP8_DTYPE, TILE_Y=TILE_Y, - SCALE_DIM=SCALE_DIM, + MXFP8_BLOCK_SIZE=MXFP8_BLOCK_SIZE, WAVES=WAVES, THREADS_PER_WARP=THREADS_PER_WARP, THREADS_PER_BANK=THREADS_PER_BANK, PACK_SIZE=PACK_SIZE, WITH_ACT=cfg.WITH_ACT, WITH_DACT=cfg.WITH_DACT, - sA_tile=sActInput_tile, - DBIAS_REDUCTION=cfg.WITH_DBIAS and not cfg.COLWISE, + WITH_DBIAS=self.DBIAS_REDUCTION_ROWWISE, dbias_acc=dbias_acc, ) @@ -856,11 +729,6 @@ def _process_colwise( M, N, # Int32 — full input extents, for OOB masking sActInput_tile=None, # (TILE_Y, TILE_X) act_input tile (dact only) ): - """Colwise MXFP8 pass: thread `tidx` owns column `tidx` of the (32, 64) - smem tile — 32 elements down. Writes quantized bytes into `sO_col_tile` - so the caller can flush with a TMA S2G — matches C++'s - `out_colwise_data_sh` + `cp.async.bulk.tensor.2d.shared_to_global`. - """ cfg = self.cfg return quantize_colwise_mxfp8( sX_tile, @@ -877,11 +745,11 @@ def _process_colwise( SWIZZLE=cfg.WITH_GEMM_SWIZZLED_SCALES, TILE_X=TILE_X, TILE_Y=TILE_Y, - SCALE_DIM=SCALE_DIM, + MXFP8_BLOCK_SIZE=MXFP8_BLOCK_SIZE, WITH_ACT=cfg.WITH_ACT, WITH_DACT=cfg.WITH_DACT, sA_tile=sActInput_tile, - WITH_DBIAS=cfg.WITH_DBIAS, + WITH_DBIAS=self.DBIAS_REDUCTION_COLWISE, ) def compile_cutedsl_function_from_cfg(cfg): @@ -890,25 +758,10 @@ def compile_cutedsl_function_from_cfg(cfg): """ kernel_obj = MXFP8QuantizeSmemKernel(cfg) - - # stride_order=(1, 0): row-major, dim 1 stride 1. 1D: (0,). - kw_rm16_2d = dict(stride_order=(1, 0), - memspace=cute.AddressSpace.gmem, assumed_align=16) - kw_rm4_2d = dict(stride_order=(1, 0), - memspace=cute.AddressSpace.gmem, assumed_align=4) - kw_rm4_1d = dict(stride_order=(0,), - memspace=cute.AddressSpace.gmem, assumed_align=4) - def fake(dtype, shape, kw): - return cute.runtime.make_fake_compact_tensor(dtype, shape, **kw) - - - # M, N must be divisible by the MXFP8 scale-block size (SCALE_DIM = 32) — the - # same alignment the CUDA C++ kernel requires. The C++ dispatcher gates on the - # matching value (kCuTeDSLMXFP8ShapeAlignment in cast/dispatch/quantize.cuh) - # and falls back to CUDA for anything not divisible by it, so tvm-ffi never - # sees a shape this kernel can't accept. - sym_M = cute.sym_int32(divisibility=SCALE_DIM) - sym_N = cute.sym_int32(divisibility=SCALE_DIM) + # M, N must be divisible by the MXFP8 scale-block size (MXFP8_BLOCK_SIZE = 32) — the + # same alignment the CUDA C++ kernel requires. + sym_M = cute.sym_int32(divisibility=MXFP8_BLOCK_SIZE) + sym_N = cute.sym_int32(divisibility=MXFP8_BLOCK_SIZE) in_shape = out_shape = (sym_M, sym_N) # TE allocates scale tensors at a padded shape (see # MXFP8Quantizer::get_scale_shape in transformer_engine/pytorch/csrc): @@ -917,31 +770,19 @@ def fake(dtype, shape, kw): # These padded extents are NOT M/N (and SymInt has no `//`/`+`), so give the # scales their own fresh syms carrying the divisibility the padding # guarantees (rowwise: 128 x 4; colwise: 4 x 128). - scale_r_shape = (cute.sym_int32(divisibility=128), cute.sym_int32(divisibility=4)) - scale_c_shape = (cute.sym_int32(divisibility=4), cute.sym_int32(divisibility=128)) - # Scale dim-1 is only 4-byte-divisible, so a 16-byte alignment promise would - # be a lie for many shapes; the per-block scale stores are byte-wise anyway, - # so 4-byte alignment loses nothing. - scale_kw = kw_rm4_2d - - in_fake = fake(cfg.DTYPE, in_shape, kw_rm16_2d) - out_row_fake = fake(cute.Uint8, out_shape, kw_rm16_2d) if cfg.ROWWISE else None - scale_row_fake = fake(cute.Uint8, scale_r_shape, scale_kw) if cfg.ROWWISE else None - out_col_fake = fake(cute.Uint8, out_shape, kw_rm16_2d) if cfg.COLWISE else None - scale_col_fake = fake(cute.Uint8, scale_c_shape, scale_kw) if cfg.COLWISE else None - amax_fake = fake(Float32, (1,), kw_rm4_1d) if cfg.WITH_AMAX else None - noop_fake = fake(Float32, (1,), kw_rm4_1d) if cfg.WITH_NOOP else None - # Backward-only slots (act_input/dbias/workspace). Always None today — - # WITH_DACT/WITH_DBIAS are rejected in the config — but kept in the compile - # signature so the tvm-ffi protocol matches the CUDA mxfp8::quantize args. - act_input_fake = fake(cfg.DTYPE, in_shape, kw_rm16_2d) if cfg.WITH_DACT else None - # dbias: the kernel never writes the dbias tensor — it writes per-row-block - # partials into the workspace (shape (blocks_Y, N) fp32, blocks_Y = ceil(M/64), - # set by the C++ worker's size query). The final reduction lives elsewhere, so - # mDbias stays None and only the workspace fake is built. - dbias_fake = None + scale_rowwise_shape = (cute.sym_int32(divisibility=128), cute.sym_int32(divisibility=4)) + scale_colwise_shape = (cute.sym_int32(divisibility=4), cute.sym_int32(divisibility=128)) ws_shape = (cute.sym_int32(), sym_N) # (blocks_Y, N); N ties to input N - workspace_fake = fake(Float32, ws_shape, kw_rm4_2d) if cfg.WITH_DBIAS else None + + in_fake = cute.runtime.make_fake_compact_tensor(cfg.DTYPE, in_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16) + out_row_fake = cute.runtime.make_fake_compact_tensor(cute.Uint8, out_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16) if cfg.ROWWISE else None + scale_row_fake = cute.runtime.make_fake_compact_tensor(cute.Uint8, scale_rowwise_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) if cfg.ROWWISE else None + out_col_fake = cute.runtime.make_fake_compact_tensor(cute.Uint8, out_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16) if cfg.COLWISE else None + scale_col_fake = cute.runtime.make_fake_compact_tensor(cute.Uint8, scale_colwise_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) if cfg.COLWISE else None + amax_fake = cute.runtime.make_fake_compact_tensor(Float32, (1,), stride_order=(0,), memspace=cute.AddressSpace.gmem, assumed_align=4) if cfg.WITH_AMAX else None + noop_fake = cute.runtime.make_fake_compact_tensor(Float32, (1,), stride_order=(0,), memspace=cute.AddressSpace.gmem, assumed_align=4) if cfg.WITH_NOOP else None + act_input_fake = cute.runtime.make_fake_compact_tensor(cfg.DTYPE, in_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16) if cfg.WITH_DACT else None + workspace_fake = cute.runtime.make_fake_compact_tensor(Float32, ws_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) if cfg.WITH_DBIAS else None compiled = cute.compile( kernel_obj, @@ -950,8 +791,7 @@ def fake(dtype, shape, kw): out_col_fake, scale_col_fake, # mO_col, mS_col amax_fake, # mAmax noop_fake, # mNoop (1-element cast_noop flag) - act_input_fake, # mActInput (backward slot, unused) - dbias_fake, # mDbias (backward slot, unused) + act_input_fake, # mDActInput (backward slot, unused) workspace_fake, # mWorkspace(backward slot, unused) cute.runtime.make_fake_stream(), # stream (compiled as an explicit tvm-ffi # "handle" arg; C++ passes the CUDA stream @@ -974,16 +814,10 @@ def get_mxfp8_quantization_function( with_noop: bool, activation: str, ) -> bool: - """Compile the MXFP8 quantize kernel for this config and register it in the - TVM-FFI global registry under EXACTLY `fn_name` (the key the C++ dispatcher - built; Python treats it as an opaque name). Returns True if a kernel is - registered under `fn_name` (the C++ side then fetches it with - GetGlobal(fn_name)); False if the config is unsupported, so the caller caches - the negative result and falls back to the CUDA C++ kernel. - - The registry owns the compiled kernel's lifetime — important because it wraps - a Python object, and tvm-ffi releases registry entries at interpreter - shutdown (whereas a C++-held handle would be released after finalize → crash). + """Compile the MXFP8 quantize kernel for this config and register it in the TVM-FFI global registry + under EXACTLY `fn_name` (the key the C++ dispatcher built; Python treats it as an opaque name). + Returns True if a kernel is successfully registered under `fn_name` (the C++ side then fetches it with GetGlobal(fn_name)); + False if the config is unsupported, so the caller caches the negative result and falls back to the CUDA C++ kernel. """ # Already registered (e.g. by a prior call) -> supported. if tvm_ffi.get_global_func(fn_name, allow_missing=True) is not None: diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index ca9d72a139..dfc01646dc 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -91,7 +91,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, Tensor *dummy_dbias_tensor = nullptr; Tensor *dummy_workspace_tensor = nullptr; bool quantized_with_cutedsl = - quantize::mxfp8_quantize_cutedsl( input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, dummy_workspace_tensor, stream); @@ -263,7 +263,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens } case NVTE_MXFP8_1D_SCALING: { bool quantized_with_cutedsl = - quantize::mxfp8_quantize_cutedsl( + cutedsl_backend::mxfp8_quantize_cutedsl( grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); if (!quantized_with_cutedsl) { diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh index 6e86bb68dd..16bea8305e 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh @@ -20,21 +20,25 @@ #include "../core/common.cuh" // dispatch::common::reduce_dbias namespace transformer_engine { -namespace tvm_ffi_bridge { +namespace cutedsl_backend { + +// Activation, te_dtype_to_str, activation_to_str, DLTensorWrapper, TVMFFICentral +// all live in transformer_engine::tvm_ffi_bridge (tvm_ffi_bridge.h). +using namespace tvm_ffi_bridge; struct MXFP8QuantConfig { static constexpr const char *kEntrypointName = "get_mxfp8_quantization_function"; - DType dtype; - DType fp8_dtype; - bool rowwise; - bool colwise; - bool swizzled; - bool with_amax; - bool with_dbias = false; - bool with_dact = false; - bool with_act = false; - bool with_noop = false; + DType dtype; // The input format + DType fp8_dtype; // The fp8 output format + bool rowwise; // If quantize rowwisely + bool colwise; // If quantize columnwisely + bool swizzled; // If the scale output is used for cudnn's swizzled layout + bool with_amax; // If the kernel should return the amax + bool with_dbias = false; // If the dbias is computated (via the workspace tensor) + bool with_dact = false; // If an activation derivative operation is fused + bool with_act = false; // If an activation operation is fused + bool with_noop = false; // If a non-nullptr noop tensor is passed to the kernel Activation activation = Activation::kNone; std::string to_key() const { @@ -73,7 +77,7 @@ template struct MXFP8QuantFused { static constexpr Activation activation = Activation::kNone; - // No fused op: plain quantize, or dbias-only cast (IS_DBIAS, no activation). + // No fused activation / activation derivative op: plain quantize static constexpr bool supported = (OP == nullptr) && !IS_DACT && !IS_ACT; }; template <> @@ -127,13 +131,9 @@ struct MXFP8QuantFused> { static constexpr bool supported = true; }; -} // namespace tvm_ffi_bridge - -namespace quantize { - // Signature mirrors mxfp8::quantize (input, act_input, noop, output, dbias, // workspace, stream). Returns false to fall back to the CUDA kernel. -inline bool mxfp8_quantize_cutedsl(const tvm_ffi_bridge::MXFP8QuantConfig &config, +inline bool mxfp8_quantize_cutedsl(const MXFP8QuantConfig &config, const Tensor *input_tensor, const Tensor *act_input_tensor, const Tensor *noop_tensor, Tensor *output_tensor, Tensor *dbias_tensor, Tensor *workspace_tensor, @@ -169,6 +169,7 @@ inline bool mxfp8_quantize_cutedsl(const tvm_ffi_bridge::MXFP8QuantConfig &confi // 128x128 GEMM tile. The kernel writes only the meaningful scale region, so // cuBLAS would otherwise read uninitialized padding. Mirrors the CUDA launcher // in quantize_mxfp8.cuh (the kernel itself does not pad the scales). + // TODO: move this into the CuTeDSL host code so the padding is handled inside // the kernel launch — this CUDA-driver memset is an implementation detail that // doesn't belong in the dispatcher (blocked on calling the driver API there). @@ -193,19 +194,16 @@ inline bool mxfp8_quantize_cutedsl(const tvm_ffi_bridge::MXFP8QuantConfig &confi tvm_ffi_bridge::DLTensorWrapper mS_col(output_tensor->columnwise_scale_inv); tvm_ffi_bridge::DLTensorWrapper mAmax(output_tensor->amax); tvm_ffi_bridge::DLTensorWrapper mNoop(noop_tensor->data); - // Backward tensors: null wrapper (None) unless present, no allocation when absent. - // mDbias stays None: the kernel writes per-block partials into the workspace, and - // the final dbias is produced by a separate reduction (not by this kernel). - tvm_ffi_bridge::DLTensorWrapper mActInput, mDbias, mWorkspace; + // Backward tensors: if the passed tensor pointer is nullptr, they will be empty DLTensorWrapper with null data pointer too + tvm_ffi_bridge::DLTensorWrapper mActInput, mWorkspace; + // If these tensors are not nullptr, wrap them as DLTensorWrappers with real data if (act_input_tensor != nullptr) mActInput = tvm_ffi_bridge::DLTensorWrapper(act_input_tensor->data); if (workspace_tensor != nullptr) mWorkspace = tvm_ffi_bridge::DLTensorWrapper(workspace_tensor->data); // stream is a tvm-ffi opaque "handle"; pass the CUDA stream as void*. (*mxfp8_quant_func_opt)(&mX, &mO_row, &mS_row, &mO_col, &mS_col, &mAmax, &mNoop, - &mActInput, &mDbias, &mWorkspace, static_cast(stream)); + &mActInput, &mWorkspace, static_cast(stream)); - // dbias: the kernel wrote per-row-block partials into the workspace; reduce them - // over the row-blocks into the final dbias[N]. Mirrors mxfp8::quantize, which - // launches common::reduce_dbias after its quantize kernel. + // If WITH_DBIAS, reduce the workspace partial dbias in CUDA C++ for now. if (config.with_dbias) { const size_t blocks_Y = (flat_m + 63) / 64; // ceil(M/64) = workspace rows const float *workspace_ptr = reinterpret_cast(workspace_tensor->data.dptr); @@ -223,12 +221,12 @@ bool mxfp8_quantize_cutedsl(const Tensor *input_tensor, const Tensor *act_input_ const Tensor *noop_tensor, Tensor *output_tensor, Tensor *dbias_tensor, Tensor *workspace_tensor, cudaStream_t stream) { - using Fused = tvm_ffi_bridge::MXFP8QuantFused; + using Fused = MXFP8QuantFused; if constexpr (!Fused::supported) { return false; } else { const bool with_noop = noop_tensor != nullptr && noop_tensor->data.dptr != nullptr; - const tvm_ffi_bridge::MXFP8QuantConfig config{ + const MXFP8QuantConfig config{ /*dtype=*/input_tensor->dtype(), /*fp8_dtype=*/output_tensor->dtype(), /*rowwise=*/output_tensor->has_data(), @@ -245,7 +243,7 @@ bool mxfp8_quantize_cutedsl(const Tensor *input_tensor, const Tensor *act_input_ } } -} // namespace quantize +} // namespace cutedsl_backend } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_QUANTIZE_MXFP8_CUTEDSL_CUH_ diff --git a/transformer_engine/common/tvm_ffi_bridge.h b/transformer_engine/common/tvm_ffi_bridge.h index aae48466d5..755c4266bc 100644 --- a/transformer_engine/common/tvm_ffi_bridge.h +++ b/transformer_engine/common/tvm_ffi_bridge.h @@ -240,9 +240,9 @@ class TVMFFICentral { TVMFFICentral &operator=(TVMFFICentral &&) = delete; static bool is_cutedsl_backend_enabled() { - // On by default; set NVTE_ENABLE_CUTEDSL_QUANT_BACKEND=0 to disable. + // Off by default; set NVTE_ENABLE_CUTEDSL_QUANT_BACKEND=1 to enable. const char *flag = std::getenv("NVTE_ENABLE_CUTEDSL_QUANT_BACKEND"); - return flag == nullptr || flag[0] != '0'; + return flag != nullptr && flag[0] != '0'; } const bool enabled_; From ff8d6adc22ebab2a831151fd5f98b2732deba21d Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Tue, 23 Jun 2026 00:41:22 +0000 Subject: [PATCH 03/22] polish --- .../common/CuTeDSL/cast/mxfp8/mxfp8_utils.py | 77 ++++++++----------- .../CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 35 +++++++-- 2 files changed, 63 insertions(+), 49 deletions(-) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py index 2389962287..d34be52707 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py @@ -224,18 +224,9 @@ def _cvt_f32x2_to_fp8x2(fp8_dtype: str): return cvt_f32x2_to_fp8e4m3x2 -# --------------------------------------------------------------------------- -# 16-bit packed input PTX kit (bf16 / f16) -# -# bf16 and f16 share the same fast-path shape: packed-x2 amax via -# `max.xorsign.abs.x2`, then per-lane widen-to-f32 + `mul.f32x2` + -# `cvt.rn.satfinite.x2.f32`. Only the opcodes differ. Build one PTX kit -# per format at module load and let the kernel pick the right kit at JIT -# trace time via `cfg.DTYPE` — equivalent to a C++ template arg specialization -# on `IType`, with no runtime branch. -# --------------------------------------------------------------------------- def _build_packed16_kit(in_fmt: str): - """Build a kit of PTX wrappers for a 16-bit input format. + """Build a kit of PTX wrappers for a 16-bit input format so we don't have to repeat + the same inline asm boilerplate code for FP16 and BF16 dtypes. `in_fmt` is the PTX format string ('bf16' or 'f16'). Returns a namespace with the per-format ops the rowwise/colwise inner loops need: @@ -460,7 +451,7 @@ def quantize_rowwise_mxfp8( ): tidx, _, _ = cute.arch.thread_idx() - tiler, tv_layout = cute.make_layout_tv( + _, tv_layout = cute.make_layout_tv( thr_layout=cute.make_layout((TILE_Y, 2), stride=(2, 1)), val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)) ) @@ -479,7 +470,7 @@ def quantize_rowwise_mxfp8( cute.make_layout((MXFP8_BLOCK_SIZE // 4,), stride=(1,)), # 1 uint32 is 4 fp8 elements ) - # PTX allows to fuse relu operation in `cvt.rn.satfinite` + # PTX allows to fuse relu activation in `cvt.rn.satfinite` FUSE_RELU = cutlass.const_expr(ACTIVATION == "relu") # For this fast path we can read in pack of 2 instead of reading individual f16 / bf16 element. # dbias needs the per-element fp32 values to accumulate, so it forces the slow path. @@ -497,7 +488,7 @@ def quantize_rowwise_mxfp8( cute.make_layout((1, MXFP8_BLOCK_SIZE // 2), stride=(0, 1)), # 1 int32 is 2 fp16/bf16 elements ) # Each wave we read 2 packed i32, which is 4 fp16/bf16 elements (PACK_SIZE) - # In total we have 8 waves where each wave reads + # In total we have 8 waves where each wave reads 4 elements, so we read 32 elements in total. in_r = [[None, None] for _ in range(WAVES)] bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group offset = bank_group * 2 # Each bank group will read 2 i32 from their bank @@ -536,9 +527,6 @@ def quantize_rowwise_mxfp8( sX_thread.iterator, cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)), ) - in_r = [[None] * PACK_SIZE for _ in range(WAVES)] - bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group - offset = bank_group * 4 # Each bank group will read 4 f16 from their bank if cutlass.const_expr(WITH_DACT): # Backward: out = grad · act'(act_input). sX is grad, sA is act_input. @@ -554,10 +542,17 @@ def quantize_rowwise_mxfp8( if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): kit_act = _packed16_kit(DTYPE) + # Each wave we read PACK_SIZE elements, and we have WAVES waves, so we read WAVES * PACK_SIZE (= MXFP8_BLOCK_SIZE) elements in total. + in_r = [[None] * PACK_SIZE for _ in range(WAVES)] + # Each thread start reading from the specfic bank based on its thread ID so they can do their best to access different banks + # to avoid bank conflict. + bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK + # The offset this thread should start reading from based on what's its first bank to access. + offset = bank_group * 4 # Each bank group will read 4 f16 from their bank for w in cutlass.range_constexpr(WAVES): start = (w * PACK_SIZE + offset) % MXFP8_BLOCK_SIZE for i in cutlass.range_constexpr(PACK_SIZE): - x = Float32(sX_thread_rw[0, start + i]) # grad + x = Float32(sX_thread_rw[0, start + i]) if cutlass.const_expr(WITH_DACT): # out = grad · act'(act_input) x = x * dop(Float32(sA_thread_rw[0, start + i])) @@ -566,9 +561,7 @@ def quantize_rowwise_mxfp8( # If it's relu, we can handle it later if not cutlass.const_expr(FUSE_RELU): x = op(x) - # dbias: accumulate this row's column (start+e) value BEFORE the bf16 - # truncation (matches CUDA's `thread_dbias_rowwise[j] += elt`). start+i - # is a multiple-of-PACK_SIZE group + i, so it stays within [0, MXFP8_BLOCK_SIZE). + # Accumulate to the per-thread dbias register buffer for this tile if WITH_DBIAS if cutlass.const_expr(WITH_DBIAS): dbias_acc[start + i] += x # If 16-bit input with activation, truncate to IType @@ -607,8 +600,6 @@ def quantize_rowwise_mxfp8( # the per-wave mul_cvt consumes this directly. scale_2x = pack_f32x2(inv_scale_r, inv_scale_r) - bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group - offset = bank_group * 4 # Each bank group will write 4 fp8 to for w in cutlass.range_constexpr(WAVES): idx = (w * 4 + offset) % MXFP8_BLOCK_SIZE idx = idx // 4 @@ -655,10 +646,13 @@ def quantize_colwise_mxfp8( WITH_DACT=False, # backward: out = grad · act'(act_input) sA_tile=None, # (TILE_Y, TILE_X) activation-input smem tile (dact only) WITH_DBIAS=False, # also return this thread's column sum (pre-truncate) + CACHE_ACTIVATION=False, # overwrite sX_tile in place with the post-activation + # (IType-truncated) values, so the rowwise pass can read + # them instead of recomputing op ): tidx, _, _ = cute.arch.thread_idx() - tiler, tv_layout = cute.make_layout_tv( + _, tv_layout = cute.make_layout_tv( thr_layout=cute.make_layout((1, TILE_X), stride=(TILE_X, 1)), val_layout=cute.make_layout((MXFP8_BLOCK_SIZE, 1), stride=(1, 1)) ) @@ -673,30 +667,25 @@ def quantize_colwise_mxfp8( # dbias needs the per-element fp32 values to sum, so it takes the f32 path # (never the i16 fast path) — matching CUDA, whose f16 fast path requires # `!IS_DBIAS` (quantize_mxfp8.cuh:219). - USE_HALF_PRECISION = _is_packed16(DTYPE) and ACTIVATION is None and not WITH_DBIAS + USE_HALF_PRECISION = _is_packed16(DTYPE) and ACTIVATION is None dbias_partial = Float32(0.0) - # 0. Load the 32-element column from smem into registers once (matches - # C++'s `in_colwise_IType[i]` cache). Amax and cast both reuse these. if cutlass.const_expr(USE_HALF_PRECISION): kit = _packed16_kit(DTYPE) - # Per-thread Int16 view of the column. Same byte address as - # `sX_thread` (bf16/fp16 are 16-bit, same width as Int16); the - # element stride is TILE_X because the column elements are - # TILE_X apart in the row-major tile. + # If we can use the half precision format, then use the input tile directly since there is no need to upcast sX_thread_i16 = cute.make_tensor( cute.recast_ptr(sX_thread.iterator, dtype=Int16), cute.make_layout((MXFP8_BLOCK_SIZE,), stride=(TILE_X,)), ) + if cutlass.const_expr(WITH_DBIAS): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + dbias_partial += kit.bits_to_f32(sX_thread_i16[i]) amax_bits = Int16(0) for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): amax_bits = kit.abs_max_scalar(amax_bits, sX_thread_i16[i]) amax_c = fabs_f32(kit.bits_to_f32(amax_bits)) else: - # Materialize the column into f32 registers — widen on read so - # bf16/fp16 inputs become real fp32 values (a pointer recast to - # Float32 would not widen; it would reinterpret the 16-bit bytes - # as half of a 32-bit float). + # Otherwise we need to case input values to fp32. Allocate the register tensor and load from SMEM input tiles. sX_thread_f32 = cute.make_rmem_tensor( layout_or_shape=cute.make_layout((MXFP8_BLOCK_SIZE,), stride=(1,)), dtype=Float32, @@ -713,18 +702,20 @@ def quantize_colwise_mxfp8( op = SUPPORTED_ACTIVATIONS[ACTIVATION] for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): sX_thread_f32[i] = op(sX_thread_f32[i]) - # dbias = column sum of the (post-act/dact) value, taken BEFORE the bf16 - # truncation — matches CUDA's `partial_dbias_colwise += elt`. + # Accumulate the per-thread column partial for dbias if WITH_DBIAS. if cutlass.const_expr(WITH_DBIAS): for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): dbias_partial += sX_thread_f32[i] - # Numerical truncation through IType so amax/cast match C++. - # Only needed when 16-bit input + activation; without activation - # the widening was already exact. + # Truncate the activation (after we apply op) back to the half precision type if input is also half precision. if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): kit_act = _packed16_kit(DTYPE) for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): sX_thread_f32[i] = kit_act.truncate_f32(sX_thread_f32[i]) + # Columnwise is the preferred direction so it runs first. If it needs to cache the activation in the input tile + # to let the rowwise pass read it, we need to cast and overwrite the input data in-place here + if cutlass.const_expr(CACHE_ACTIVATION): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + sX_thread[i] = DTYPE(sX_thread_f32[i]) amax_c = Float32(0.0) for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): amax_c = cute.arch.fmax(amax_c, fabs_f32(sX_thread_f32[i])) @@ -741,15 +732,15 @@ def quantize_colwise_mxfp8( mS_col_stage[(0, tidx)] = Uint8(biased_exp_c) inv_scale_c = exp2f_rcp(biased_exp_c) - cvt_to_fp8 = _cvt_f32_to_fp8(FP8_DTYPE) + cvt_to_fp8_func = _cvt_f32_to_fp8(FP8_DTYPE) if cutlass.const_expr(USE_HALF_PRECISION): kit_cast = _packed16_kit(DTYPE) for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): v_f32 = kit_cast.bits_to_f32(sX_thread_i16[i]) - sO_thread[i] = Uint8(cvt_to_fp8(v_f32 * inv_scale_c)) + sO_thread[i] = Uint8(cvt_to_fp8_func(v_f32 * inv_scale_c)) else: for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - sO_thread[i] = Uint8(cvt_to_fp8(sX_thread_f32[i] * inv_scale_c)) + sO_thread[i] = Uint8(cvt_to_fp8_func(sX_thread_f32[i] * inv_scale_c)) # Return this stage's per-column partial alongside amax; the caller accumulates # it across stages (a scalar can't be updated in-place through the arg). diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index a9839ca73d..e884a9c0e2 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -147,6 +147,23 @@ def __init__(self, cfg): # Only do rowwise reduction when we don't quantize columnwisely when WITH_DBIAS is True. self.DBIAS_REDUCTION_COLWISE = cfg.WITH_DBIAS and cfg.COLWISE self.DBIAS_REDUCTION_ROWWISE = cfg.WITH_DBIAS and not cfg.COLWISE + # Cache activation in-place in the SMEM input tile when we process both rowwise and colwise passes + # so the activation is only computed once in the direction we favor (columnwise) and the other direction (rowwise) + # reads the cached value instead of recomputing it. + # Note: if activation is relu, there is no standalong relu applied because it's already fused into `cvt.rn.satfinite` + # so it should be treated as "no activation" + self.CACHE_ACTIVATION = ( + (cfg.WITH_ACT or cfg.WITH_DACT) + and cfg.ROWWISE and cfg.COLWISE + and cfg.ACTIVATION != "relu" + ) + # The global tensor amax (mAmax) is the max over ALL elements. Each direction's + # per-block amaxes already span every element, so when both passes run we only + # fold the global amax from one of them — favor colwise (matches the flags + # above). The per-block *scale* amax is still computed in each pass for its own + # scale; this only skips the redundant global comparison in the other pass. + self.AMAX_FROM_COLWISE = cfg.WITH_AMAX and cfg.COLWISE + self.AMAX_FROM_ROWWISE = cfg.WITH_AMAX and not cfg.COLWISE @cute.jit def __call__( @@ -538,6 +555,7 @@ class DactStorage: # Each CTA has `NUM_TILES` tiles. Each stage we need to obtain the tile for that specific stage. # So the tile index along Y dimension is `bidy * NUM_TILES + stage` tile_idx_y = bidy * NUM_TILES + stage + # Process rowwise and colwise quantization separately if cutlass.const_expr(cfg.COLWISE): # The first row that belongs to this CTA. Each CTA handles NUM_TILES of (TILE_Y, TILE_X) tiles stacked vertically, # and each stage handles one of them. @@ -550,10 +568,14 @@ class DactStorage: tile_idx_y * TILE_Y, bidx * TILE_X, M, N, sActInput_tile, ) - if cutlass.const_expr(cfg.WITH_AMAX): + if cutlass.const_expr(self.AMAX_FROM_COLWISE): per_thread_amax = cute.arch.fmax(per_thread_amax, amax_c) if cutlass.const_expr(self.DBIAS_REDUCTION_COLWISE): block_dbias += dbias_c + # If we cache the activation in shared memory, we need to ensure that all threads have finished writing to the shared memory + # from the columnwise pass before any thread reads from it in the rowwise pass. + if cutlass.const_expr(self.CACHE_ACTIVATION): + cute.arch.sync_threads() if cutlass.const_expr(cfg.ROWWISE): sO_row_tile = sO_row[(None, stage)] # mS_row is ((SCALE_TILE), (GRID)) where SCALE_TILE = (32, 2). @@ -576,7 +598,7 @@ class DactStorage: rowwise_dbias_acc, ) - if cutlass.const_expr(cfg.WITH_AMAX): + if cutlass.const_expr(self.AMAX_FROM_ROWWISE): per_thread_amax = cute.arch.fmax(per_thread_amax, amax_r) # Make the shared-memory writes visible to the TMA's async proxy before the TMA reads them. @@ -694,7 +716,7 @@ def _process_rowwise( cfg = self.cfg return quantize_rowwise_mxfp8( sX_tile, - sActInput_tile, + None if self.CACHE_ACTIVATION else sActInput_tile, sO_row_tile, mS_row_stage, max_norm_rcp, @@ -702,7 +724,7 @@ def _process_rowwise( tile_col_start, M, N, - ACTIVATION=cfg.ACTIVATION, + ACTIVATION=None if self.CACHE_ACTIVATION else cfg.ACTIVATION, DTYPE=cfg.DTYPE, FP8_DTYPE=cfg.FP8_DTYPE, TILE_Y=TILE_Y, @@ -711,8 +733,8 @@ def _process_rowwise( THREADS_PER_WARP=THREADS_PER_WARP, THREADS_PER_BANK=THREADS_PER_BANK, PACK_SIZE=PACK_SIZE, - WITH_ACT=cfg.WITH_ACT, - WITH_DACT=cfg.WITH_DACT, + WITH_ACT=cfg.WITH_ACT and not self.CACHE_ACTIVATION, + WITH_DACT=cfg.WITH_DACT and not self.CACHE_ACTIVATION, WITH_DBIAS=self.DBIAS_REDUCTION_ROWWISE, dbias_acc=dbias_acc, ) @@ -750,6 +772,7 @@ def _process_colwise( WITH_DACT=cfg.WITH_DACT, sA_tile=sActInput_tile, WITH_DBIAS=self.DBIAS_REDUCTION_COLWISE, + CACHE_ACTIVATION=self.CACHE_ACTIVATION, ) def compile_cutedsl_function_from_cfg(cfg): From fbcbba21172e45c5f9ea2c23eaa2156f648e38b4 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Tue, 23 Jun 2026 17:35:48 +0000 Subject: [PATCH 04/22] fix --- .../common/CuTeDSL/cast/mxfp8/mxfp8_utils.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py index d34be52707..080f74f5eb 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py @@ -34,6 +34,7 @@ FP32_MANTISSA_BITS = 23 +# TODO: move these to util @dsl_user_op def _bitcast_f32_to_i32(val: Float32, *, loc=None, ip=None) -> Int32: return Int32(mlir_arith.bitcast(T.i32(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) @@ -479,6 +480,11 @@ def quantize_rowwise_mxfp8( amax_r = Float32(0.0) + # Each thread start reading from the specfic bank based on its thread ID so they can do their best to access different banks + # to avoid bank conflict. + bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK + # The offset this thread should start reading from based on what's its first bank to access. + offset = bank_group * 4 # Each bank group will read 4 f16 from their bank if cutlass.const_expr(_row_fast): # If no activation, f16 / bf16 and rowwise quantization, we can read 2 f16 / bf16 at once in a pack # and use max.xorsign.abs.f16x2 / max.xorsign.abs.bf16x2 to compute @@ -490,10 +496,8 @@ def quantize_rowwise_mxfp8( # Each wave we read 2 packed i32, which is 4 fp16/bf16 elements (PACK_SIZE) # In total we have 8 waves where each wave reads 4 elements, so we read 32 elements in total. in_r = [[None, None] for _ in range(WAVES)] - bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # Each 4 threads share the same bank, which forms a bank group - offset = bank_group * 2 # Each bank group will read 2 i32 from their bank for w in cutlass.range_constexpr(WAVES): - idx = (w * 2 + offset) % (MXFP8_BLOCK_SIZE // 2) + idx = (w * 2 + offset // 2) % (MXFP8_BLOCK_SIZE // 2) in_r[w][0] = sX_thread_rw_i32[0, idx] in_r[w][1] = sX_thread_rw_i32[0, idx + 1] @@ -544,11 +548,6 @@ def quantize_rowwise_mxfp8( # Each wave we read PACK_SIZE elements, and we have WAVES waves, so we read WAVES * PACK_SIZE (= MXFP8_BLOCK_SIZE) elements in total. in_r = [[None] * PACK_SIZE for _ in range(WAVES)] - # Each thread start reading from the specfic bank based on its thread ID so they can do their best to access different banks - # to avoid bank conflict. - bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK - # The offset this thread should start reading from based on what's its first bank to access. - offset = bank_group * 4 # Each bank group will read 4 f16 from their bank for w in cutlass.range_constexpr(WAVES): start = (w * PACK_SIZE + offset) % MXFP8_BLOCK_SIZE for i in cutlass.range_constexpr(PACK_SIZE): From 2a157d00a9fe0e7a74df91f8931d0af4b038d634 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Tue, 23 Jun 2026 18:53:25 +0000 Subject: [PATCH 05/22] refactor --- setup.py | 6 +- transformer_engine/common/CuTeDSL/__init__.py | 6 +- .../common/CuTeDSL/activations.py | 16 +- .../common/CuTeDSL/cast/mxfp8/__init__.py | 6 +- .../common/CuTeDSL/cast/mxfp8/mxfp8_utils.py | 746 ------------------ .../CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 384 ++++++++- transformer_engine/common/CuTeDSL/utils.py | 253 +++++- .../common/CuTeDSL/utils_fp8.py | 97 +++ transformer_engine/pytorch/__init__.py | 25 +- 9 files changed, 761 insertions(+), 778 deletions(-) delete mode 100644 transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py create mode 100644 transformer_engine/common/CuTeDSL/utils_fp8.py diff --git a/setup.py b/setup.py index ed6fe977b4..61c2a4586d 100644 --- a/setup.py +++ b/setup.py @@ -366,13 +366,17 @@ def git_check_submodules() -> None: "core_cu13": [f"transformer_engine_cu13=={__version__}"], "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], + "cutedsl": ["nvidia-cutlass-dsl>=4.2.0"], } else: install_requires, test_requires = setup_requirements() ext_modules = [setup_common_extension()] package_data = {"": ["VERSION.txt"]} include_package_data = True - extras_require = {"test": test_requires} + extras_require = { + "test": test_requires, + "cutedsl": ["nvidia-cutlass-dsl>=4.2.0"], + } if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): if "pytorch" in frameworks: diff --git a/transformer_engine/common/CuTeDSL/__init__.py b/transformer_engine/common/CuTeDSL/__init__.py index 5621c01e64..993e292339 100644 --- a/transformer_engine/common/CuTeDSL/__init__.py +++ b/transformer_engine/common/CuTeDSL/__init__.py @@ -11,9 +11,7 @@ inside a Python environment with the CuTeDSL toolchain available, so the kernel may be compiled on demand; not finding it means a plain C++ environment, and the dispatcher falls back to the CUDA C++ kernel. - -Importing requires the optional CuTeDSL toolchain (cutlass, tvm_ffi). Callers -that want graceful degradation should guard the import in a try/except. """ -from . import cast # noqa: F401 (import side effect: registers global funcs) +# Trigger the casting CuTeDSL entrypoints registration via TVM-FFI. +from . import cast # noqa: F401 diff --git a/transformer_engine/common/CuTeDSL/activations.py b/transformer_engine/common/CuTeDSL/activations.py index 96ffd5c6c1..0389310690 100644 --- a/transformer_engine/common/CuTeDSL/activations.py +++ b/transformer_engine/common/CuTeDSL/activations.py @@ -2,11 +2,13 @@ # # See LICENSE for license information. -import cutlass import cutlass.cute as cute from cutlass import Float32 from cutlass._mlir.dialects import arith as mlir_arith -from cutlass.cutlass_dsl import dsl_user_op +from cutlass.cutlass_dsl import T, dsl_user_op + +from transformer_engine.common.CuTeDSL.utils import fma_f32 + def act_relu(x: Float32) -> Float32: return cute.arch.fmax(x, Float32(0.0)) @@ -28,14 +30,14 @@ def act_gelu(x: Float32) -> Float32: def act_silu(x: Float32) -> Float32: """SiLU/Swish: x · σ(x) = x / (1 + e^-x). Matches TE's `silu` (`val / (1 + expf(-val))`).""" - return x / (Float32(1.0) + cute.arch.exp(-x)) + return x / (Float32(1.0) + cute.math.exp(-x, fastmath=True)) def act_qgelu(x: Float32) -> Float32: """Quick GELU: x · σ(1.702·x). Matches TE `qgelu_with_alpha(val, 1.702)` = `cval · (1 / (1 + expf(-1.702·cval)))` (multiply by sigmoid, not a divide).""" z = Float32(1.702) * x - return x * (Float32(1.0) / (Float32(1.0) + cute.arch.exp(-z))) + return x * (Float32(1.0) / (Float32(1.0) + cute.math.exp(-z, fastmath=True))) def act_srelu(x: Float32) -> Float32: @@ -61,14 +63,16 @@ def dact_dsrelu(x: Float32) -> Float32: def sigmoid(x: Float32) -> Float32: """σ(x) = 1 / (1 + e^-x), same exp intrinsic as the forward silu/qgelu.""" - return Float32(1.0) / (Float32(1.0) + cute.arch.exp(-x)) + return Float32(1.0) / (Float32(1.0) + cute.math.exp(-x, fastmath=True)) def dact_dsilu(x: Float32) -> Float32: """dsilu: x·σ(x)·(1-σ(x)) + σ(x). Matches math.h `dsilu` (`cval·dsigmoid + sigmoid`, dsigmoid = s·(1-s)).""" s = sigmoid(x) - return x * (s * (Float32(1.0) - s)) + s + # cval·dsigmoid + sigmoid as one FFMA — matches nvcc's contraction of + # math.h `dsilu` (`cval * dsigmoid + sigmoid`) so dbias is bit-exact. + return fma_f32(x, s * (Float32(1.0) - s), s) def dact_dqgelu(x: Float32) -> Float32: diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/__init__.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/__init__.py index c42df11c01..b5e8e55e36 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/__init__.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/__init__.py @@ -2,7 +2,5 @@ # # See LICENSE for license information. -"""MXFP8 CuTeDSL kernels. Importing ``quantize_mxfp8`` runs its module body, -which registers the ``get_mxfp8_quantization_function`` TVM-FFI global func.""" - -from . import quantize_mxfp8 # noqa: F401 (import side effect: registers the global func) +# Trigger the MXFP8 quantization CuTeDSL entrypoints registration via TVM-FFI. +from . import quantize_mxfp8 # noqa: F401 diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py deleted file mode 100644 index 080f74f5eb..0000000000 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/mxfp8_utils.py +++ /dev/null @@ -1,746 +0,0 @@ -import cutlass -import cutlass.cute as cute -from cutlass import Float32, Int64, Int32, Int16, Uint8, Uint32 -from cutlass._mlir.dialects import arith as mlir_arith -from cutlass._mlir.dialects import llvm -from cutlass.cutlass_dsl import T, dsl_user_op - -from types import SimpleNamespace - -from transformer_engine.common.CuTeDSL.activations import ( - act_relu, - act_gelu, - act_silu, - act_qgelu, - act_srelu, - dact_drelu, - dact_dsrelu, - dact_dsilu, - dact_dqgelu, - dact_dgelu, -) - - -# FP8E4M3 max representable value -FP8E4M3_MAX_NORM = 448.0 -FP8E4M3_MAX_NORM_RCP = 1.0 / FP8E4M3_MAX_NORM -FP8E5M2_MAX_NORM = 57344.0 -FP8E5M2_MAX_NORM_RCP = 1.0 / FP8E5M2_MAX_NORM - -# Largest finite f32 — used to clamp the per-block scale inverse against -# division-by-zero (which produces +inf and then NaN downstream). -FP32_MAX = 3.4028234663852886e38 - -FP32_MANTISSA_BITS = 23 - - -# TODO: move these to util -@dsl_user_op -def _bitcast_f32_to_i32(val: Float32, *, loc=None, ip=None) -> Int32: - return Int32(mlir_arith.bitcast(T.i32(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - - -@dsl_user_op -def _bitcast_i32_to_f32(val: Int32, *, loc=None, ip=None) -> Float32: - return Float32(mlir_arith.bitcast(T.f32(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - - -@dsl_user_op -def fabs_f32(val: Float32, *, loc=None, ip=None) -> Float32: - val_i32 = _bitcast_f32_to_i32(val, loc=loc, ip=ip) - abs_i32 = val_i32 & Int32(0x7FFFFFFF) - return _bitcast_i32_to_f32(abs_i32, loc=loc, ip=ip) - - -@dsl_user_op -def float_to_e8m0(val: Float32, *, loc=None, ip=None) -> Int32: - """Branchless float->E8M0: add mantissa mask to round up, clamp to 254.""" - val_i32 = _bitcast_f32_to_i32(val, loc=loc, ip=ip) - rounded = val_i32 + Int32(0x7FFFFF) - exponent = (rounded >> Int32(FP32_MANTISSA_BITS)) & Int32(0xFF) - return Int32(mlir_arith.minsi( - exponent.ir_value(loc=loc, ip=ip), - Int32(254).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - - -@dsl_user_op -def exp2f_rcp(biased_exp: Int32, *, loc=None, ip=None) -> Float32: - """2^(127 - biased_exp) with special-case handling.""" - new_exp = (Int32(254) - biased_exp) << Int32(FP32_MANTISSA_BITS) - result = _bitcast_i32_to_f32(new_exp, loc=loc, ip=ip) - for (cmp_val, repl_bits) in [(255, 0x7FFFFFFF), (254, 0x00400000), (0, 0x7F000000)]: - cond = mlir_arith.cmpi(mlir_arith.CmpIPredicate.eq, - biased_exp.ir_value(loc=loc, ip=ip), - Int32(cmp_val).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) - alt = _bitcast_i32_to_f32(Int32(repl_bits), loc=loc, ip=ip) - result = Float32(mlir_arith.select( - cond, alt.ir_value(loc=loc, ip=ip), - result.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - return result - - -@dsl_user_op -def cvt_f32_to_fp8e4m3(val: Float32, *, loc=None, ip=None) -> Int32: - """float32 -> fp8e4m3fn via PTX cvt.rn.satfinite.e4m3x2.f32.""" - zero = Float32(0.0) - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], - "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - result_i32 = Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - return result_i32 & Int32(0xFF) - - -@dsl_user_op -def cvt_f32_to_fp8e5m2(val: Float32, *, loc=None, ip=None) -> Int32: - """float32 -> fp8e5m2 via PTX cvt.rn.satfinite.e5m2x2.f32.""" - zero = Float32(0.0) - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], - "cvt.rn.satfinite.e5m2x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - result_i32 = Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - return result_i32 & Int32(0xFF) - - -@dsl_user_op -def fma_f32(a: Float32, b: Float32, c: Float32, *, loc=None, ip=None) -> Float32: - """`fma.rn.f32 d, a, b, c;` — single-instruction fused multiply-add - matching nvcc's FFMA. Used for explicit `partial += a * b` patterns - where we need the same rounding as TE's compiler-fused FFMA.""" - return Float32(llvm.inline_asm( - T.f32(), - [a.ir_value(loc=loc, ip=ip), - b.ir_value(loc=loc, ip=ip), - c.ir_value(loc=loc, ip=ip)], - "fma.rn.f32 $0, $1, $2, $3;", - "=f,f,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - - -@dsl_user_op -def tanh_approx(val: Float32, *, loc=None, ip=None) -> Float32: - """`tanh.approx.f32` — fast tanh approximation. Matches CUDA `__tanhf`.""" - return Float32(llvm.inline_asm( - T.f32(), - [val.ir_value(loc=loc, ip=ip)], - "tanh.approx.f32 $0, $1;", - "=f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - - -@dsl_user_op -def pack_f32x2(lo: Float32, hi: Float32, *, loc=None, ip=None) -> Int64: - """Pack two f32 scalars into a single 64-bit register (`floatx2` layout). - - Low 32 bits = `lo`, high 32 bits = `hi`. Uses `mov.b64 %dst, {%lo, %hi};` - which lowers to a single register move — no actual memory traffic. - """ - return Int64(llvm.inline_asm( - T.i64(), - [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)], - "mov.b64 $0, {$1, $2};", - "=l,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - - -@dsl_user_op -def cvt_fp8e4m3_to_f32(byte_i32: Int32, *, loc=None, ip=None) -> Float32: - """One fp8e4m3 byte (low 8 bits of `byte_i32`) → f32. - - PTX has no direct `cvt.f32.e4m3` for a scalar; route through the packed - `cvt.rn.f16x2.e4m3x2` and then `cvt.f32.f16`. The high byte of the .b16 - register is forced to zero so the discarded high f16 lane is well-defined.""" - asm = ( - "{\n" - ".reg .b32 masked; .reg .b16 b16; .reg .b16 b16_hi;\n\t" - ".reg .b32 f16pair; .reg .b16 lo_f16; .reg .b16 hi_f16;\n\t" - "and.b32 masked, $1, 0xFF;\n\t" - "mov.b32 {b16, b16_hi}, masked;\n\t" - "cvt.rn.f16x2.e4m3x2 f16pair, b16;\n\t" - "mov.b32 {lo_f16, hi_f16}, f16pair;\n\t" - "cvt.f32.f16 $0, lo_f16;\n\t" - "}" - ) - return Float32(llvm.inline_asm( - T.f32(), - [byte_i32.ir_value(loc=loc, ip=ip)], - asm, - "=f,r", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - -@dsl_user_op -def cvt_f32x2_to_fp8e4m3x2(val_hi: Float32, val_lo: Float32, relu: bool = False, - *, loc=None, ip=None) -> Int32: - """Convert two float32 values to two packed fp8e4m3fn bytes in one instruction. - - Returns an int32 where bits [7:0] = fp8(val_lo), bits [15:8] = fp8(val_hi). - This mirrors ptx::mul_cvt_2x which converts 2 values in one instruction. - """ - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], - f"cvt.rn.satfinite{".relu" if relu else ""}.e4m3x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - - -@dsl_user_op -def cvt_f32x2_to_fp8e5m2x2(val_hi: Float32, val_lo: Float32, relu: bool = False, - *, loc=None, ip=None) -> Int32: - """e5m2 sibling of `cvt_f32x2_to_fp8e4m3x2`.""" - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], - f"cvt.rn.satfinite{".relu" if relu else ""}.e5m2x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - - -def _cvt_f32_to_fp8(fp8_dtype: str): - """Const-expr dispatch: pick the f32→fp8 scalar PTX op based on output dtype. - - `fp8_dtype` is the Python string from `cfg.FP8_DTYPE`, evaluated at JIT - trace time; the unused branch is never traced. - """ - if fp8_dtype == "e5m2": - return cvt_f32_to_fp8e5m2 - return cvt_f32_to_fp8e4m3 - - -def _cvt_f32x2_to_fp8x2(fp8_dtype: str): - """Const-expr dispatch for the packed f32x2→fp8x2 cvt.""" - if fp8_dtype == "e5m2": - return cvt_f32x2_to_fp8e5m2x2 - return cvt_f32x2_to_fp8e4m3x2 - - -def _build_packed16_kit(in_fmt: str): - """Build a kit of PTX wrappers for a 16-bit input format so we don't have to repeat - the same inline asm boilerplate code for FP16 and BF16 dtypes. - - `in_fmt` is the PTX format string ('bf16' or 'f16'). Returns a namespace - with the per-format ops the rowwise/colwise inner loops need: - - abs_max_x2(Int32, Int32) -> Int32 # `max.xorsign.abs.x2` - abs_max_scalar(Int16, Int16) -> Int16 # `max.xorsign.abs.` - bits_to_f32(Int16) -> Float32 # widen one 16-bit element - x2_lo_to_f32(Int32) -> Float32 # extract+widen low half - x2_hi_to_f32(Int32) -> Float32 # extract+widen high half - mul_cvt_to_fp8x2(fp8_dtype) -> callable(Int32, Int64)->Int32 - # fused x2 * f32x2 -> fp8x2 - """ - - @dsl_user_op - def abs_max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32: - return Int32(llvm.inline_asm( - T.i32(), - [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], - f"max.xorsign.abs.{in_fmt}x2 $0, $1, $2;", - "=r,r,r", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - - @dsl_user_op - def max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32: - return Int32(llvm.inline_asm( - T.i32(), - [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], - f"max.{in_fmt}x2 $0, $1, $2;", - "=r,r,r", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - - @dsl_user_op - def abs_max_scalar(a: Int16, b: Int16, *, loc=None, ip=None) -> Int16: - return Int16(llvm.inline_asm( - T.i16(), - [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], - f"max.xorsign.abs.{in_fmt} $0, $1, $2;", - "=h,h,h", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - - if in_fmt == "bf16": - # bf16 == top 16 bits of f32 — widening is a free bit-shift. - @dsl_user_op - def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32: - i32 = Int32(mlir_arith.extui( - T.i32(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip) - - @dsl_user_op - def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: - return _bitcast_i32_to_f32( - (bits & Int32(0xFFFF)) << Int32(16), loc=loc, ip=ip) - - @dsl_user_op - def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: - # `(x >> 16) << 16` ≡ `x & 0xFFFF0000`, sidestepping signed-literal - # issues. Sign bits from the arith-right shift get zeroed by the - # left shift. - return _bitcast_i32_to_f32( - (bits >> Int32(16)) << Int32(16), loc=loc, ip=ip) - - @dsl_user_op - def truncate_f32(val: Float32, *, loc=None, ip=None) -> Float32: - """Round f32 to bf16 precision (round-to-nearest-even), keep f32. - Matches C++'s `static_cast(static_cast(elt))`.""" - bf16_bits = Int16(llvm.inline_asm( - T.i16(), [val.ir_value(loc=loc, ip=ip)], - "cvt.rn.bf16.f32 $0, $1;", - "=h,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - i32 = Int32(mlir_arith.extui( - T.i32(), bf16_bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip) - else: - # f16 has its own bit layout; widening requires `cvt.f32.f16`. - @dsl_user_op - def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32: - return Float32(llvm.inline_asm( - T.f32(), [bits.ir_value(loc=loc, ip=ip)], - "cvt.f32.f16 $0, $1;", - "=f,h", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - - @dsl_user_op - def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: - lo_i16 = Int16(mlir_arith.trunci( - T.i16(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - return bits_to_f32(lo_i16, loc=loc, ip=ip) - - @dsl_user_op - def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: - hi_shifted = bits >> Int32(16) - hi_i16 = Int16(mlir_arith.trunci( - T.i16(), hi_shifted.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - return bits_to_f32(hi_i16, loc=loc, ip=ip) - - @dsl_user_op - def truncate_f32(val: Float32, *, loc=None, ip=None) -> Float32: - """Round f32 to f16 precision, keep f32.""" - f16_bits = Int16(llvm.inline_asm( - T.i16(), [val.ir_value(loc=loc, ip=ip)], - "cvt.rn.f16.f32 $0, $1;", - "=h,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Float32(llvm.inline_asm( - T.f32(), [f16_bits.ir_value(loc=loc, ip=ip)], - "cvt.f32.f16 $0, $1;", - "=f,h", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - - def _build_mul_cvt(out_fmt: str, relu: bool = False): - """Build a fused `x2 * f32x2 → fp8x2` PTX wrapper. - - The shape is identical across (in_fmt, out_fmt) combos — only the - widening opcode (`cvt.f32.`) and the final saturating cvt - (`cvt.rn.satfinite.x2.f32`) differ. - """ - out_op = "e4m3x2" if out_fmt == "e4m3" else "e5m2x2" - asm = ( - "{\n" - ".reg.b64 vp0; .reg.b64 vp1;\n\t" - ".reg.b32 v1; .reg.b32 v2;\n\t" - ".reg.b16 vb1; .reg.b16 vb2;\n\t" - "mov.b32 {vb1, vb2}, $1;\n\t" - f"cvt.f32.{in_fmt} v1, vb1;\n\t" - f"cvt.f32.{in_fmt} v2, vb2;\n\t" - "mov.b64 vp0, {v1, v2};\n\t" - "mul.f32x2 vp1, vp0, $2;\n\t" - "mov.b64 {v2, v1}, vp1;\n\t" - f"cvt.rn.satfinite{".relu" if relu else ""}.{out_op}.f32 $0, v1, v2;\n\t" - "}" - ) - - @dsl_user_op - def fn(val_2x: Int32, scale_2x: Int64, *, loc=None, ip=None) -> Int32: - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [val_2x.ir_value(loc=loc, ip=ip), - scale_2x.ir_value(loc=loc, ip=ip)], - asm, - "=h,r,l", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) - return fn - - def mul_cvt_to_fp8x2(fp8_dtype: str, relu: bool = False): - if fp8_dtype == "e5m2": - return _build_mul_cvt("e5m2", relu) - return _build_mul_cvt("e4m3", relu) - - return SimpleNamespace( - abs_max_x2=abs_max_x2, - max_x2=max_x2, - abs_max_scalar=abs_max_scalar, - bits_to_f32=bits_to_f32, - x2_lo_to_f32=x2_lo_to_f32, - x2_hi_to_f32=x2_hi_to_f32, - truncate_f32=truncate_f32, - mul_cvt_to_fp8x2=mul_cvt_to_fp8x2, - ) - - -_BF16_KIT = _build_packed16_kit("bf16") -_F16_KIT = _build_packed16_kit("f16") - - -def _is_packed16(dtype) -> bool: - """True if `dtype` is one of the 16-bit packed input formats.""" - return dtype is cutlass.BFloat16 or dtype is cutlass.Float16 - -def _packed16_kit(dtype): - """Trace-time selector — pick a Packed16Kit for the input dtype.""" - if dtype is cutlass.Float16: - return _F16_KIT - return _BF16_KIT - -SUPPORTED_ACTIVATIONS = { - "relu": act_relu, - "gelu": act_gelu, - "silu": act_silu, - "qgelu": act_qgelu, - "srelu": act_srelu, -} - -SUPPORTED_DACTIVATIONS = { - "drelu": dact_drelu, - "dgelu": dact_dgelu, - "dsilu": dact_dsilu, - "dqgelu": dact_dqgelu, - "dsrelu": dact_dsrelu, -} - - -@cute.jit -def quantize_rowwise_mxfp8( - sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA - sA_tile, # (TILE_Y, TILE_X) activation-input smem tile (dact only) - sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) - mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) - max_norm_rcp, - tile_row_start, # Int32 — global row index of this stage's row 0 - # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores - # for irregular shapes. - tile_col_start, # Int32 — global col index of this CTA's col 0 - # (= bidx * TILE_X). Same purpose. - M, N, # Int32 — full tensor extents; OOB threads skip their - # scale store. - ACTIVATION, - DTYPE, - FP8_DTYPE, - TILE_Y, - MXFP8_BLOCK_SIZE, - WAVES, - THREADS_PER_WARP, - THREADS_PER_BANK, - PACK_SIZE, - WITH_ACT=False, - WITH_DACT=False, - WITH_DBIAS=False, # rowwise-only dbias: accumulate per-column partials - dbias_acc=None, # only needed when WITH_DBIAS is True -): - tidx, _, _ = cute.arch.thread_idx() - - _, tv_layout = cute.make_layout_tv( - thr_layout=cute.make_layout((TILE_Y, 2), stride=(2, 1)), - val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)) - ) - - sX_tv = cute.composition(sX_tile, tv_layout) - sO_tv = cute.composition(sO_row_tile, tv_layout) - - # I/O Elements that belong to this thread - sX_thread = sX_tv[tidx, None] # shape (32,) bf16 - sO_thread = sO_tv[tidx, None] # shape (32,) uint8 - - sO_thread_u32_ptr = cute.recast_ptr(sO_thread.iterator, dtype=Uint32) - # Each wave it writes 32 bytes = 8 uint32s, so in 4 waves we write all 32 quantized elements. - sO_thread_u32 = cute.make_tensor( - sO_thread_u32_ptr, - cute.make_layout((MXFP8_BLOCK_SIZE // 4,), stride=(1,)), # 1 uint32 is 4 fp8 elements - ) - - # PTX allows to fuse relu activation in `cvt.rn.satfinite` - FUSE_RELU = cutlass.const_expr(ACTIVATION == "relu") - # For this fast path we can read in pack of 2 instead of reading individual f16 / bf16 element. - # dbias needs the per-element fp32 values to accumulate, so it forces the slow path. - _row_fast = (_is_packed16(DTYPE) and (ACTIVATION is None or FUSE_RELU) - and not WITH_DBIAS) - - amax_r = Float32(0.0) - - # Each thread start reading from the specfic bank based on its thread ID so they can do their best to access different banks - # to avoid bank conflict. - bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK - # The offset this thread should start reading from based on what's its first bank to access. - offset = bank_group * 4 # Each bank group will read 4 f16 from their bank - if cutlass.const_expr(_row_fast): - # If no activation, f16 / bf16 and rowwise quantization, we can read 2 f16 / bf16 at once in a pack - # and use max.xorsign.abs.f16x2 / max.xorsign.abs.bf16x2 to compute - kit = _packed16_kit(DTYPE) - sX_thread_rw_i32 = cute.make_tensor( - cute.recast_ptr(sX_thread.iterator, dtype=Int32), - cute.make_layout((1, MXFP8_BLOCK_SIZE // 2), stride=(0, 1)), # 1 int32 is 2 fp16/bf16 elements - ) - # Each wave we read 2 packed i32, which is 4 fp16/bf16 elements (PACK_SIZE) - # In total we have 8 waves where each wave reads 4 elements, so we read 32 elements in total. - in_r = [[None, None] for _ in range(WAVES)] - for w in cutlass.range_constexpr(WAVES): - idx = (w * 2 + offset // 2) % (MXFP8_BLOCK_SIZE // 2) - in_r[w][0] = sX_thread_rw_i32[0, idx] - in_r[w][1] = sX_thread_rw_i32[0, idx + 1] - - amax_2x = Int32(0) - # Each wave will use max.xorsign.abs.f16x2 or max.xorsign.abs.bf16x2 to compare 2 packed elements in parallel - for w in cutlass.range_constexpr(WAVES): - if cutlass.const_expr(FUSE_RELU): - # If we fuse relu then we don't want to do abs since negative value will be set to 0 and they will lose comparison automatically - amax_2x = kit.max_x2(amax_2x, in_r[w][0]) - amax_2x = kit.max_x2(amax_2x, in_r[w][1]) - else: - amax_2x = kit.abs_max_x2(amax_2x, in_r[w][0]) - amax_2x = kit.abs_max_x2(amax_2x, in_r[w][1]) - if cutlass.const_expr(FUSE_RELU): - # Compare the 2 packed max without abs - amax_r = cute.arch.fmax( - kit.x2_lo_to_f32(amax_2x), - kit.x2_hi_to_f32(amax_2x), - ) - # For relu the max is at least 0 - amax_r = cute.arch.fmax(amax_r, Float32(0.0)) - else: - # Compare the 2 packed abs max - amax_r = cute.arch.fmax( - fabs_f32(kit.x2_lo_to_f32(amax_2x)), - fabs_f32(kit.x2_hi_to_f32(amax_2x)), - ) - else: - # Since we need to do computation on individual f16 / bf16 elements, we can't read in pack - sX_thread_rw = cute.make_tensor( - sX_thread.iterator, - cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)), - ) - - if cutlass.const_expr(WITH_DACT): - # Backward: out = grad · act'(act_input). sX is grad, sA is act_input. - dop = SUPPORTED_DACTIVATIONS[ACTIVATION] - sA_thread = cute.composition(sA_tile, tv_layout)[tidx, None] - sA_thread_rw = cute.make_tensor( - sA_thread.iterator, - cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)), - ) - elif cutlass.const_expr(WITH_ACT): - op = SUPPORTED_ACTIVATIONS[ACTIVATION] - - if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): - kit_act = _packed16_kit(DTYPE) - - # Each wave we read PACK_SIZE elements, and we have WAVES waves, so we read WAVES * PACK_SIZE (= MXFP8_BLOCK_SIZE) elements in total. - in_r = [[None] * PACK_SIZE for _ in range(WAVES)] - for w in cutlass.range_constexpr(WAVES): - start = (w * PACK_SIZE + offset) % MXFP8_BLOCK_SIZE - for i in cutlass.range_constexpr(PACK_SIZE): - x = Float32(sX_thread_rw[0, start + i]) - if cutlass.const_expr(WITH_DACT): - # out = grad · act'(act_input) - x = x * dop(Float32(sA_thread_rw[0, start + i])) - # If IS_ACT, apply activation function to x in f32 - elif cutlass.const_expr(WITH_ACT): - # If it's relu, we can handle it later - if not cutlass.const_expr(FUSE_RELU): - x = op(x) - # Accumulate to the per-thread dbias register buffer for this tile if WITH_DBIAS - if cutlass.const_expr(WITH_DBIAS): - dbias_acc[start + i] += x - # If 16-bit input with activation, truncate to IType - if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): - x = kit_act.truncate_f32(x) - in_r[w][i] = x - if cutlass.const_expr(FUSE_RELU): - amax_r = cute.arch.fmax(amax_r, x) # For relu cases, we don't need abs since negative values will be 0 so they lose comparison automatically - else: - amax_r = cute.arch.fmax(amax_r, fabs_f32(x)) - if cutlass.const_expr(FUSE_RELU): - amax_r = cute.arch.fmax(amax_r, Float32(0.0)) # If relu, the amax is at least 0 - - biased_exp_r = float_to_e8m0(amax_r * max_norm_rcp) - - # mS_row_stage has logical shape (32, 2) and we have 64 threads where each is mapped to one scale factor - # The TV layout is equivalent to TV layout with thr_layout=(32, 2):(2, 1), val_layout=(1,) - # but it's too trival so let's just index it directly without using layout - # Note this is the logical layout, which is on top of the swizzled / non-swizzled scale factor layout - # that mappes the logical index to the physical offset - - # For irregular shapes, skip the scale store if this thread's logical row / col-block lies past the input's actual extents. - # TMA already zero-fills OOB input reads and drops OOB output writes; only the direct scale-byte gmem store needs an explicit guard. - scale_row = tile_row_start + tidx // 2 - scale_col_first_elt = tile_col_start + (tidx % 2) * MXFP8_BLOCK_SIZE - if scale_row < M and scale_col_first_elt < N: - mS_row_stage[(tidx // 2, tidx % 2)] = Uint8(biased_exp_r) - - inv_scale_r = exp2f_rcp(biased_exp_r) # f32 reciprocal of the scale - # Fetch the conversion function based on the FP8 format - cvt_f32x2 = _cvt_f32x2_to_fp8x2(FP8_DTYPE) - if cutlass.const_expr(_row_fast): - kit_cast = _packed16_kit(DTYPE) - mul_cvt_x2 = kit_cast.mul_cvt_to_fp8x2(FP8_DTYPE, FUSE_RELU) - # Pack `(inv_scale_r, inv_scale_r)` as a single 64-bit f32x2 once; - # the per-wave mul_cvt consumes this directly. - scale_2x = pack_f32x2(inv_scale_r, inv_scale_r) - - for w in cutlass.range_constexpr(WAVES): - idx = (w * 4 + offset) % MXFP8_BLOCK_SIZE - idx = idx // 4 - if cutlass.const_expr(_row_fast): - # One fused PTX per x2 pair: x2 × f32x2 → fp8x2. - # Byte layout: byte[0]=fp8(lo * s), byte[1]=fp8(hi * s). - p01 = mul_cvt_x2(in_r[w][0], scale_2x) - p23 = mul_cvt_x2(in_r[w][1], scale_2x) - else: - # cvt PTX semantics: `cvt.rn.satfinite..f32 d, a, b` gives - # d[15:8]=fp8(a), d[7:0]=fp8(b). Pass (v1, v0) so the u16 low - # byte ends up as fp8(v0) and the high byte as fp8(v1). - v0 = in_r[w][0] * inv_scale_r - v1 = in_r[w][1] * inv_scale_r - v2 = in_r[w][2] * inv_scale_r - v3 = in_r[w][3] * inv_scale_r - p01 = cvt_f32x2(v1, v0, FUSE_RELU) # u16 little-endian: v0,v1 - p23 = cvt_f32x2(v3, v2, FUSE_RELU) # u16 little-endian: v2,v3 - quad = (p23 << Int32(16)) | p01 - sO_thread_u32[idx] = Uint32(quad) - - return amax_r - -@cute.jit -def quantize_colwise_mxfp8( - sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA - sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) - mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) - max_norm_rcp, - tile_row_start, # Int32 — global row index of this stage's row 0 - # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores - # for irregular shapes. - tile_col_start, # Int32 — global col index of this CTA's col 0 - # (= bidx * TILE_X). - M, N, # Int32 — full tensor extents. - ACTIVATION, - DTYPE, - FP8_DTYPE, - SWIZZLE, - TILE_X, - TILE_Y, - MXFP8_BLOCK_SIZE, - WITH_ACT=False, # forward: apply activation to the element - WITH_DACT=False, # backward: out = grad · act'(act_input) - sA_tile=None, # (TILE_Y, TILE_X) activation-input smem tile (dact only) - WITH_DBIAS=False, # also return this thread's column sum (pre-truncate) - CACHE_ACTIVATION=False, # overwrite sX_tile in place with the post-activation - # (IType-truncated) values, so the rowwise pass can read - # them instead of recomputing op -): - tidx, _, _ = cute.arch.thread_idx() - - _, tv_layout = cute.make_layout_tv( - thr_layout=cute.make_layout((1, TILE_X), stride=(TILE_X, 1)), - val_layout=cute.make_layout((MXFP8_BLOCK_SIZE, 1), stride=(1, 1)) - ) - - sX_tv = cute.composition(sX_tile, tv_layout) - sO_tv = cute.composition(sO_col_tile, tv_layout) - - # I/O Elements that belong to this thread - sX_thread = sX_tv[tidx, None] - sO_thread = sO_tv[tidx, None] - - # dbias needs the per-element fp32 values to sum, so it takes the f32 path - # (never the i16 fast path) — matching CUDA, whose f16 fast path requires - # `!IS_DBIAS` (quantize_mxfp8.cuh:219). - USE_HALF_PRECISION = _is_packed16(DTYPE) and ACTIVATION is None - dbias_partial = Float32(0.0) - - if cutlass.const_expr(USE_HALF_PRECISION): - kit = _packed16_kit(DTYPE) - # If we can use the half precision format, then use the input tile directly since there is no need to upcast - sX_thread_i16 = cute.make_tensor( - cute.recast_ptr(sX_thread.iterator, dtype=Int16), - cute.make_layout((MXFP8_BLOCK_SIZE,), stride=(TILE_X,)), - ) - if cutlass.const_expr(WITH_DBIAS): - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - dbias_partial += kit.bits_to_f32(sX_thread_i16[i]) - amax_bits = Int16(0) - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - amax_bits = kit.abs_max_scalar(amax_bits, sX_thread_i16[i]) - amax_c = fabs_f32(kit.bits_to_f32(amax_bits)) - else: - # Otherwise we need to case input values to fp32. Allocate the register tensor and load from SMEM input tiles. - sX_thread_f32 = cute.make_rmem_tensor( - layout_or_shape=cute.make_layout((MXFP8_BLOCK_SIZE,), stride=(1,)), - dtype=Float32, - ) - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - sX_thread_f32[i] = Float32(sX_thread[i]) - # Apply activation (fwd) or grad·act'(act_input) (bwd dact) in f32. - if cutlass.const_expr(WITH_DACT): - dop = SUPPORTED_DACTIVATIONS[ACTIVATION] - sA_thread = cute.composition(sA_tile, tv_layout)[tidx, None] - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - sX_thread_f32[i] = sX_thread_f32[i] * dop(Float32(sA_thread[i])) - elif cutlass.const_expr(WITH_ACT): - op = SUPPORTED_ACTIVATIONS[ACTIVATION] - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - sX_thread_f32[i] = op(sX_thread_f32[i]) - # Accumulate the per-thread column partial for dbias if WITH_DBIAS. - if cutlass.const_expr(WITH_DBIAS): - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - dbias_partial += sX_thread_f32[i] - # Truncate the activation (after we apply op) back to the half precision type if input is also half precision. - if cutlass.const_expr(_is_packed16(DTYPE) and ACTIVATION is not None): - kit_act = _packed16_kit(DTYPE) - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - sX_thread_f32[i] = kit_act.truncate_f32(sX_thread_f32[i]) - # Columnwise is the preferred direction so it runs first. If it needs to cache the activation in the input tile - # to let the rowwise pass read it, we need to cast and overwrite the input data in-place here - if cutlass.const_expr(CACHE_ACTIVATION): - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - sX_thread[i] = DTYPE(sX_thread_f32[i]) - amax_c = Float32(0.0) - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - amax_c = cute.arch.fmax(amax_c, fabs_f32(sX_thread_f32[i])) - - # Irregular shapes: skip when this stage's row range or this thread's - # column lies past the input extents. TILE_Y == MXFP8_BLOCK_SIZE so each stage - # is exactly one scale-row; valid iff `tile_row_start < M`. - biased_exp_c = float_to_e8m0(amax_c * max_norm_rcp) - scale_col = tile_col_start + tidx - if tile_row_start < M and scale_col < N: - if cutlass.const_expr(SWIZZLE): - mS_col_stage[(0, tidx % 32, tidx // 32)] = Uint8(biased_exp_c) - else: - mS_col_stage[(0, tidx)] = Uint8(biased_exp_c) - - inv_scale_c = exp2f_rcp(biased_exp_c) - cvt_to_fp8_func = _cvt_f32_to_fp8(FP8_DTYPE) - if cutlass.const_expr(USE_HALF_PRECISION): - kit_cast = _packed16_kit(DTYPE) - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - v_f32 = kit_cast.bits_to_f32(sX_thread_i16[i]) - sO_thread[i] = Uint8(cvt_to_fp8_func(v_f32 * inv_scale_c)) - else: - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - sO_thread[i] = Uint8(cvt_to_fp8_func(sX_thread_f32[i] * inv_scale_c)) - - # Return this stage's per-column partial alongside amax; the caller accumulates - # it across stages (a scalar can't be updated in-place through the arg). - return amax_c, dbias_partial diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index e884a9c0e2..de26a60e43 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -18,28 +18,45 @@ """ import logging -from transformer_engine.common.CuTeDSL.utils import str_to_cutlass_dtype +from transformer_engine.common.CuTeDSL.utils import _bitcast_f32_to_i32, str_to_cutlass_dtype from typing import Optional, Type import cutlass import cutlass.cute as cute import cutlass.pipeline as pipeline -from cutlass import Float32, Int32, Uint8 +from cutlass import Float32, Int16, Int32, Uint32, Uint8 from cuda.bindings.driver import CUstream import tvm_ffi -from .mxfp8_utils import ( - SUPPORTED_ACTIVATIONS, - SUPPORTED_DACTIVATIONS, - FP8E4M3_MAX_NORM_RCP, - FP8E5M2_MAX_NORM_RCP, - _bitcast_f32_to_i32, - quantize_colwise_mxfp8, - quantize_rowwise_mxfp8, +from transformer_engine.common.CuTeDSL.activations import ( + act_relu, + act_gelu, + act_silu, + act_qgelu, + act_srelu, + dact_drelu, + dact_dsrelu, + dact_dsilu, + dact_dqgelu, + dact_dgelu, ) +from transformer_engine.common.CuTeDSL.utils import ( + is_packed16, + packed16_kit, + fabs_f32, + exp2f_rcp, + pack_f32x2, +) +from transformer_engine.common.CuTeDSL.utils_fp8 import ( + cvt_f32_to_fp8, + cvt_f32x2_to_fp8x2, + cvt_f32_to_fp8e8m0 +) + + logger = logging.getLogger("transformer_engine.cutedsl.mxfp8") # MXFP8 settings @@ -65,6 +82,353 @@ THREADS_PER_CTA = 64 NUM_WARPS = THREADS_PER_CTA // 32 +# FP8E4M3 max representable value +FP8E4M3_MAX_NORM = 448.0 +FP8E4M3_MAX_NORM_RCP = 1.0 / FP8E4M3_MAX_NORM +FP8E5M2_MAX_NORM = 57344.0 +FP8E5M2_MAX_NORM_RCP = 1.0 / FP8E5M2_MAX_NORM + + +SUPPORTED_ACTIVATIONS = { + "relu": act_relu, + "gelu": act_gelu, + "silu": act_silu, + "qgelu": act_qgelu, + "srelu": act_srelu, +} + +SUPPORTED_DACTIVATIONS = { + "drelu": dact_drelu, + "dgelu": dact_dgelu, + "dsilu": dact_dsilu, + "dqgelu": dact_dqgelu, + "dsrelu": dact_dsrelu, +} + + +@cute.jit +def quantize_rowwise_mxfp8( + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sA_tile, # (TILE_Y, TILE_X) activation-input smem tile (dact only) + sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) + mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) + max_norm_rcp, + tile_row_start, # Int32 — global row index of this stage's row 0 + # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores + # for irregular shapes. + tile_col_start, # Int32 — global col index of this CTA's col 0 + # (= bidx * TILE_X). Same purpose. + M, N, # Int32 — full tensor extents; OOB threads skip their + # scale store. + ACTIVATION, + DTYPE, + FP8_DTYPE, + TILE_Y, + MXFP8_BLOCK_SIZE, + WAVES, + THREADS_PER_WARP, + THREADS_PER_BANK, + PACK_SIZE, + WITH_ACT=False, + WITH_DACT=False, + WITH_DBIAS=False, # rowwise-only dbias: accumulate per-column partials + dbias_acc=None, # only needed when WITH_DBIAS is True +): + tidx, _, _ = cute.arch.thread_idx() + + _, tv_layout = cute.make_layout_tv( + thr_layout=cute.make_layout((TILE_Y, 2), stride=(2, 1)), + val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)) + ) + + sX_tv = cute.composition(sX_tile, tv_layout) + sO_tv = cute.composition(sO_row_tile, tv_layout) + + # I/O Elements that belong to this thread + sX_thread = sX_tv[tidx, None] # shape (32,) bf16 + sO_thread = sO_tv[tidx, None] # shape (32,) uint8 + + sO_thread_u32_ptr = cute.recast_ptr(sO_thread.iterator, dtype=Uint32) + # Each wave it writes 32 bytes = 8 uint32s, so in 4 waves we write all 32 quantized elements. + sO_thread_u32 = cute.make_tensor( + sO_thread_u32_ptr, + cute.make_layout((MXFP8_BLOCK_SIZE // 4,), stride=(1,)), # 1 uint32 is 4 fp8 elements + ) + + # PTX allows to fuse relu activation in `cvt.rn.satfinite` + FUSE_RELU = cutlass.const_expr(ACTIVATION == "relu") + # For this fast path we can read in pack of 2 instead of reading individual f16 / bf16 element. + # dbias needs the per-element fp32 values to accumulate, so it forces the slow path. + _row_fast = (is_packed16(DTYPE) and (ACTIVATION is None or FUSE_RELU) + and not WITH_DBIAS) + + amax_r = Float32(0.0) + + # Each thread start reading from the specfic bank based on its thread ID so they can do their best to access different banks + # to avoid bank conflict. + bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK + # The offset this thread should start reading from based on what's its first bank to access. + offset = bank_group * 4 # Each bank group will read 4 f16 from their bank + if cutlass.const_expr(_row_fast): + # If no activation, f16 / bf16 and rowwise quantization, we can read 2 f16 / bf16 at once in a pack + # and use max.xorsign.abs.f16x2 / max.xorsign.abs.bf16x2 to compute + kit = packed16_kit(DTYPE) + sX_thread_rw_i32 = cute.make_tensor( + cute.recast_ptr(sX_thread.iterator, dtype=Int32), + cute.make_layout((1, MXFP8_BLOCK_SIZE // 2), stride=(0, 1)), # 1 int32 is 2 fp16/bf16 elements + ) + # Each wave we read 2 packed i32, which is 4 fp16/bf16 elements (PACK_SIZE) + # In total we have 8 waves where each wave reads 4 elements, so we read 32 elements in total. + in_r = [[None, None] for _ in range(WAVES)] + for w in cutlass.range_constexpr(WAVES): + idx = (w * 2 + offset // 2) % (MXFP8_BLOCK_SIZE // 2) + in_r[w][0] = sX_thread_rw_i32[0, idx] + in_r[w][1] = sX_thread_rw_i32[0, idx + 1] + + amax_2x = Int32(0) + # Each wave will use max.xorsign.abs.f16x2 or max.xorsign.abs.bf16x2 to compare 2 packed elements in parallel + for w in cutlass.range_constexpr(WAVES): + if cutlass.const_expr(FUSE_RELU): + # If we fuse relu then we don't want to do abs since negative value will be set to 0 and they will lose comparison automatically + amax_2x = kit.max_x2(amax_2x, in_r[w][0]) + amax_2x = kit.max_x2(amax_2x, in_r[w][1]) + else: + amax_2x = kit.abs_max_x2(amax_2x, in_r[w][0]) + amax_2x = kit.abs_max_x2(amax_2x, in_r[w][1]) + if cutlass.const_expr(FUSE_RELU): + # Compare the 2 packed max without abs + amax_r = cute.arch.fmax( + kit.x2_lo_to_f32(amax_2x), + kit.x2_hi_to_f32(amax_2x), + ) + # For relu the max is at least 0 + amax_r = cute.arch.fmax(amax_r, Float32(0.0)) + else: + # Compare the 2 packed abs max + amax_r = cute.arch.fmax( + fabs_f32(kit.x2_lo_to_f32(amax_2x)), + fabs_f32(kit.x2_hi_to_f32(amax_2x)), + ) + else: + # Since we need to do computation on individual f16 / bf16 elements, we can't read in pack + sX_thread_rw = cute.make_tensor( + sX_thread.iterator, + cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)), + ) + + if cutlass.const_expr(WITH_DACT): + # Backward: out = grad · act'(act_input). sX is grad, sA is act_input. + dop = SUPPORTED_DACTIVATIONS[ACTIVATION] + sA_thread = cute.composition(sA_tile, tv_layout)[tidx, None] + sA_thread_rw = cute.make_tensor( + sA_thread.iterator, + cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)), + ) + elif cutlass.const_expr(WITH_ACT): + op = SUPPORTED_ACTIVATIONS[ACTIVATION] + + if cutlass.const_expr(is_packed16(DTYPE) and ACTIVATION is not None): + kit_act = packed16_kit(DTYPE) + + # Each wave we read PACK_SIZE elements, and we have WAVES waves, so we read WAVES * PACK_SIZE (= MXFP8_BLOCK_SIZE) elements in total. + in_r = [[None] * PACK_SIZE for _ in range(WAVES)] + for w in cutlass.range_constexpr(WAVES): + start = (w * PACK_SIZE + offset) % MXFP8_BLOCK_SIZE + for i in cutlass.range_constexpr(PACK_SIZE): + x = Float32(sX_thread_rw[0, start + i]) + if cutlass.const_expr(WITH_DACT): + # out = grad · act'(act_input) + x = x * dop(Float32(sA_thread_rw[0, start + i])) + # If IS_ACT, apply activation function to x in f32 + elif cutlass.const_expr(WITH_ACT): + # If it's relu, we can handle it later + if not cutlass.const_expr(FUSE_RELU): + x = op(x) + # Accumulate to the per-thread dbias register buffer for this tile if WITH_DBIAS + if cutlass.const_expr(WITH_DBIAS): + dbias_acc[start + i] += x + # If 16-bit input with activation, truncate to IType + if cutlass.const_expr(is_packed16(DTYPE) and ACTIVATION is not None): + x = kit_act.truncate_f32(x) + in_r[w][i] = x + if cutlass.const_expr(FUSE_RELU): + amax_r = cute.arch.fmax(amax_r, x) # For relu cases, we don't need abs since negative values will be 0 so they lose comparison automatically + else: + amax_r = cute.arch.fmax(amax_r, fabs_f32(x)) + if cutlass.const_expr(FUSE_RELU): + amax_r = cute.arch.fmax(amax_r, Float32(0.0)) # If relu, the amax is at least 0 + + biased_exp_r = cvt_f32_to_fp8e8m0(amax_r * max_norm_rcp) + + # mS_row_stage has logical shape (32, 2) and we have 64 threads where each is mapped to one scale factor + # The TV layout is equivalent to TV layout with thr_layout=(32, 2):(2, 1), val_layout=(1,) + # but it's too trival so let's just index it directly without using layout + # Note this is the logical layout, which is on top of the swizzled / non-swizzled scale factor layout + # that mappes the logical index to the physical offset + + # For irregular shapes, skip the scale store if this thread's logical row / col-block lies past the input's actual extents. + # TMA already zero-fills OOB input reads and drops OOB output writes; only the direct scale-byte gmem store needs an explicit guard. + scale_row = tile_row_start + tidx // 2 + scale_col_first_elt = tile_col_start + (tidx % 2) * MXFP8_BLOCK_SIZE + if scale_row < M and scale_col_first_elt < N: + mS_row_stage[(tidx // 2, tidx % 2)] = Uint8(biased_exp_r) + + inv_scale_r = exp2f_rcp(biased_exp_r) # f32 reciprocal of the scale + # Fetch the conversion function based on the FP8 format + cvt_f32x2 = cvt_f32x2_to_fp8x2(FP8_DTYPE) + if cutlass.const_expr(_row_fast): + kit_cast = packed16_kit(DTYPE) + mul_cvt_x2 = kit_cast.mul_cvt_to_fp8x2(FP8_DTYPE, FUSE_RELU) + # Pack `(inv_scale_r, inv_scale_r)` as a single 64-bit f32x2 once; + # the per-wave mul_cvt consumes this directly. + scale_2x = pack_f32x2(inv_scale_r, inv_scale_r) + + for w in cutlass.range_constexpr(WAVES): + idx = (w * 4 + offset) % MXFP8_BLOCK_SIZE + idx = idx // 4 + if cutlass.const_expr(_row_fast): + # One fused PTX per x2 pair: x2 × f32x2 → fp8x2. + # Byte layout: byte[0]=fp8(lo * s), byte[1]=fp8(hi * s). + p01 = mul_cvt_x2(in_r[w][0], scale_2x) + p23 = mul_cvt_x2(in_r[w][1], scale_2x) + else: + # cvt PTX semantics: `cvt.rn.satfinite..f32 d, a, b` gives + # d[15:8]=fp8(a), d[7:0]=fp8(b). Pass (v1, v0) so the u16 low + # byte ends up as fp8(v0) and the high byte as fp8(v1). + v0 = in_r[w][0] * inv_scale_r + v1 = in_r[w][1] * inv_scale_r + v2 = in_r[w][2] * inv_scale_r + v3 = in_r[w][3] * inv_scale_r + p01 = cvt_f32x2(v1, v0, FUSE_RELU) # u16 little-endian: v0,v1 + p23 = cvt_f32x2(v3, v2, FUSE_RELU) # u16 little-endian: v2,v3 + quad = (p23 << Int32(16)) | p01 + sO_thread_u32[idx] = Uint32(quad) + + return amax_r + +@cute.jit +def quantize_colwise_mxfp8( + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) + mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) + max_norm_rcp, + tile_row_start, # Int32 — global row index of this stage's row 0 + # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores + # for irregular shapes. + tile_col_start, # Int32 — global col index of this CTA's col 0 + # (= bidx * TILE_X). + M, N, # Int32 — full tensor extents. + ACTIVATION, + DTYPE, + FP8_DTYPE, + SWIZZLE, + TILE_X, + TILE_Y, + MXFP8_BLOCK_SIZE, + WITH_ACT=False, # forward: apply activation to the element + WITH_DACT=False, # backward: out = grad · act'(act_input) + sA_tile=None, # (TILE_Y, TILE_X) activation-input smem tile (dact only) + WITH_DBIAS=False, # also return this thread's column sum (pre-truncate) + CACHE_ACTIVATION=False, # overwrite sX_tile in place with the post-activation + # (IType-truncated) values, so the rowwise pass can read + # them instead of recomputing op +): + tidx, _, _ = cute.arch.thread_idx() + + _, tv_layout = cute.make_layout_tv( + thr_layout=cute.make_layout((1, TILE_X), stride=(TILE_X, 1)), + val_layout=cute.make_layout((MXFP8_BLOCK_SIZE, 1), stride=(1, 1)) + ) + + sX_tv = cute.composition(sX_tile, tv_layout) + sO_tv = cute.composition(sO_col_tile, tv_layout) + + # I/O Elements that belong to this thread + sX_thread = sX_tv[tidx, None] + sO_thread = sO_tv[tidx, None] + + # dbias needs the per-element fp32 values to sum, so it takes the f32 path + # (never the i16 fast path) — matching CUDA, whose f16 fast path requires + # `!IS_DBIAS` (quantize_mxfp8.cuh:219). + USE_HALF_PRECISION = is_packed16(DTYPE) and ACTIVATION is None + dbias_partial = Float32(0.0) + + if cutlass.const_expr(USE_HALF_PRECISION): + kit = packed16_kit(DTYPE) + # If we can use the half precision format, then use the input tile directly since there is no need to upcast + sX_thread_i16 = cute.make_tensor( + cute.recast_ptr(sX_thread.iterator, dtype=Int16), + cute.make_layout((MXFP8_BLOCK_SIZE,), stride=(TILE_X,)), + ) + if cutlass.const_expr(WITH_DBIAS): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + dbias_partial += kit.bits_to_f32(sX_thread_i16[i]) + amax_bits = Int16(0) + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + amax_bits = kit.abs_max_scalar(amax_bits, sX_thread_i16[i]) + amax_c = fabs_f32(kit.bits_to_f32(amax_bits)) + else: + # Otherwise we need to case input values to fp32. Allocate the register tensor and load from SMEM input tiles. + sX_thread_f32 = cute.make_rmem_tensor( + layout_or_shape=cute.make_layout((MXFP8_BLOCK_SIZE,), stride=(1,)), + dtype=Float32, + ) + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + sX_thread_f32[i] = Float32(sX_thread[i]) + # Apply activation (fwd) or grad·act'(act_input) (bwd dact) in f32. + if cutlass.const_expr(WITH_DACT): + dop = SUPPORTED_DACTIVATIONS[ACTIVATION] + sA_thread = cute.composition(sA_tile, tv_layout)[tidx, None] + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + sX_thread_f32[i] = sX_thread_f32[i] * dop(Float32(sA_thread[i])) + elif cutlass.const_expr(WITH_ACT): + op = SUPPORTED_ACTIVATIONS[ACTIVATION] + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + sX_thread_f32[i] = op(sX_thread_f32[i]) + # Accumulate the per-thread column partial for dbias if WITH_DBIAS. + if cutlass.const_expr(WITH_DBIAS): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + dbias_partial += sX_thread_f32[i] + # Truncate the activation (after we apply op) back to the half precision type if input is also half precision. + if cutlass.const_expr(is_packed16(DTYPE) and ACTIVATION is not None): + kit_act = packed16_kit(DTYPE) + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + sX_thread_f32[i] = kit_act.truncate_f32(sX_thread_f32[i]) + # Columnwise is the preferred direction so it runs first. If it needs to cache the activation in the input tile + # to let the rowwise pass read it, we need to cast and overwrite the input data in-place here + if cutlass.const_expr(CACHE_ACTIVATION): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + sX_thread[i] = DTYPE(sX_thread_f32[i]) + amax_c = Float32(0.0) + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + amax_c = cute.arch.fmax(amax_c, fabs_f32(sX_thread_f32[i])) + + # Irregular shapes: skip when this stage's row range or this thread's + # column lies past the input extents. TILE_Y == MXFP8_BLOCK_SIZE so each stage + # is exactly one scale-row; valid iff `tile_row_start < M`. + biased_exp_c = cvt_f32_to_fp8e8m0(amax_c * max_norm_rcp) + scale_col = tile_col_start + tidx + if tile_row_start < M and scale_col < N: + if cutlass.const_expr(SWIZZLE): + mS_col_stage[(0, tidx % 32, tidx // 32)] = Uint8(biased_exp_c) + else: + mS_col_stage[(0, tidx)] = Uint8(biased_exp_c) + + inv_scale_c = exp2f_rcp(biased_exp_c) + cvt_to_fp8_func = cvt_f32_to_fp8(FP8_DTYPE) + if cutlass.const_expr(USE_HALF_PRECISION): + kit_cast = packed16_kit(DTYPE) + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + v_f32 = kit_cast.bits_to_f32(sX_thread_i16[i]) + sO_thread[i] = Uint8(cvt_to_fp8_func(v_f32 * inv_scale_c)) + else: + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + sO_thread[i] = Uint8(cvt_to_fp8_func(sX_thread_f32[i] * inv_scale_c)) + + # Return this stage's per-column partial alongside amax; the caller accumulates + # it across stages (a scalar can't be updated in-place through the arg). + return amax_c, dbias_partial + class MXFP8QuantizeConfig: """Configs for the compiled CuTeDSL kernel. These will be fixed once the kernel is compiled and they will behave as const expressions. diff --git a/transformer_engine/common/CuTeDSL/utils.py b/transformer_engine/common/CuTeDSL/utils.py index 258feea66c..a798eba864 100644 --- a/transformer_engine/common/CuTeDSL/utils.py +++ b/transformer_engine/common/CuTeDSL/utils.py @@ -1,4 +1,10 @@ import cutlass +from cutlass import Float32, Int64, Int32, Int16 +from cutlass._mlir.dialects import arith as mlir_arith +from cutlass._mlir.dialects import llvm +from cutlass.cutlass_dsl import T, dsl_user_op + +from types import SimpleNamespace _CUTLASS_DTYPE_FROM_STR = { "fp32": cutlass.Float32, @@ -13,4 +19,249 @@ def str_to_cutlass_dtype(dtype_str: str): def cutlass_dtype_to_str(dtype): """Convert a cutlass dtype back to its protocol string, or None if unknown.""" - return _STR_FROM_CUTLASS_DTYPE.get(dtype, None) \ No newline at end of file + return _STR_FROM_CUTLASS_DTYPE.get(dtype, None) + +FP32_MANTISSA_BITS = 23 + +@dsl_user_op +def _bitcast_f32_to_i32(val: Float32, *, loc=None, ip=None) -> Int32: + """Bitcast a float32 value to int32 without changing the bit pattern.""" + return Int32(mlir_arith.bitcast(T.i32(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def _bitcast_i32_to_f32(val: Int32, *, loc=None, ip=None) -> Float32: + """Bitcast an int32 value to float32 without changing the bit pattern.""" + return Float32(mlir_arith.bitcast(T.f32(), val.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def fabs_f32(val: Float32, *, loc=None, ip=None) -> Float32: + """Compute the absolute value of a float32.""" + val_i32 = _bitcast_f32_to_i32(val, loc=loc, ip=ip) + abs_i32 = val_i32 & Int32(0x7FFFFFFF) + return _bitcast_i32_to_f32(abs_i32, loc=loc, ip=ip) + + +@dsl_user_op +def fma_f32(a: Float32, b: Float32, c: Float32, *, loc=None, ip=None) -> Float32: + """Compute the fused multiply-add of three float32 values: a * b + c.""" + return Float32(llvm.inline_asm( + T.f32(), + [a.ir_value(loc=loc, ip=ip), + b.ir_value(loc=loc, ip=ip), + c.ir_value(loc=loc, ip=ip)], + "fma.rn.f32 $0, $1, $2, $3;", + "=f,f,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + +@dsl_user_op +def exp2f_rcp(biased_exp: Int32, *, loc=None, ip=None) -> Float32: + """2^(127 - biased_exp) with special-case handling.""" + new_exp = (Int32(254) - biased_exp) << Int32(FP32_MANTISSA_BITS) + result = _bitcast_i32_to_f32(new_exp, loc=loc, ip=ip) + for (cmp_val, repl_bits) in [(255, 0x7FFFFFFF), (254, 0x00400000), (0, 0x7F000000)]: + cond = mlir_arith.cmpi(mlir_arith.CmpIPredicate.eq, + biased_exp.ir_value(loc=loc, ip=ip), + Int32(cmp_val).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + alt = _bitcast_i32_to_f32(Int32(repl_bits), loc=loc, ip=ip) + result = Float32(mlir_arith.select( + cond, alt.ir_value(loc=loc, ip=ip), + result.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return result + + +@dsl_user_op +def pack_f32x2(lo: Float32, hi: Float32, *, loc=None, ip=None) -> Int64: + """Pack two f32 scalars into a single 64-bit register (`floatx2` layout). + + Low 32 bits = `lo`, high 32 bits = `hi`. Uses `mov.b64 %dst, {%lo, %hi};` + which lowers to a single register move — no actual memory traffic. + """ + return Int64(llvm.inline_asm( + T.i64(), + [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)], + "mov.b64 $0, {$1, $2};", + "=l,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + +def _build_packed16_kit(in_fmt: str): + """Build a kit of PTX wrappers for a 16-bit input format so we don't have to repeat + the same inline asm boilerplate code for FP16 and BF16 dtypes. + + `in_fmt` is the PTX format string ('bf16' or 'f16'). Returns a namespace + with the per-format ops the rowwise/colwise inner loops need: + + abs_max_x2(Int32, Int32) -> Int32 # `max.xorsign.abs.x2` + abs_max_scalar(Int16, Int16) -> Int16 # `max.xorsign.abs.` + bits_to_f32(Int16) -> Float32 # widen one 16-bit element + x2_lo_to_f32(Int32) -> Float32 # extract+widen low half + x2_hi_to_f32(Int32) -> Float32 # extract+widen high half + mul_cvt_to_fp8x2(fp8_dtype) -> callable(Int32, Int64)->Int32 + # fused x2 * f32x2 -> fp8x2 + """ + + @dsl_user_op + def abs_max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32: + return Int32(llvm.inline_asm( + T.i32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.xorsign.abs.{in_fmt}x2 $0, $1, $2;", + "=r,r,r", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + @dsl_user_op + def max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32: + return Int32(llvm.inline_asm( + T.i32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.{in_fmt}x2 $0, $1, $2;", + "=r,r,r", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + @dsl_user_op + def abs_max_scalar(a: Int16, b: Int16, *, loc=None, ip=None) -> Int16: + return Int16(llvm.inline_asm( + T.i16(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.xorsign.abs.{in_fmt} $0, $1, $2;", + "=h,h,h", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + if in_fmt == "bf16": + # bf16 == top 16 bits of f32 — widening is a free bit-shift. + @dsl_user_op + def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32: + i32 = Int32(mlir_arith.extui( + T.i32(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip) + + @dsl_user_op + def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: + return _bitcast_i32_to_f32( + (bits & Int32(0xFFFF)) << Int32(16), loc=loc, ip=ip) + + @dsl_user_op + def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: + # `(x >> 16) << 16` ≡ `x & 0xFFFF0000`, sidestepping signed-literal + # issues. Sign bits from the arith-right shift get zeroed by the + # left shift. + return _bitcast_i32_to_f32( + (bits >> Int32(16)) << Int32(16), loc=loc, ip=ip) + + @dsl_user_op + def truncate_f32(val: Float32, *, loc=None, ip=None) -> Float32: + """Round f32 to bf16 precision (round-to-nearest-even), keep f32. + Matches C++'s `static_cast(static_cast(elt))`.""" + bf16_bits = Int16(llvm.inline_asm( + T.i16(), [val.ir_value(loc=loc, ip=ip)], + "cvt.rn.bf16.f32 $0, $1;", + "=h,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + i32 = Int32(mlir_arith.extui( + T.i32(), bf16_bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip) + else: + # f16 has its own bit layout; widening requires `cvt.f32.f16`. + @dsl_user_op + def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32: + return Float32(llvm.inline_asm( + T.f32(), [bits.ir_value(loc=loc, ip=ip)], + "cvt.f32.f16 $0, $1;", + "=f,h", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + @dsl_user_op + def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: + lo_i16 = Int16(mlir_arith.trunci( + T.i16(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return bits_to_f32(lo_i16, loc=loc, ip=ip) + + @dsl_user_op + def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: + hi_shifted = bits >> Int32(16) + hi_i16 = Int16(mlir_arith.trunci( + T.i16(), hi_shifted.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return bits_to_f32(hi_i16, loc=loc, ip=ip) + + @dsl_user_op + def truncate_f32(val: Float32, *, loc=None, ip=None) -> Float32: + """Round f32 to f16 precision, keep f32.""" + f16_bits = Int16(llvm.inline_asm( + T.i16(), [val.ir_value(loc=loc, ip=ip)], + "cvt.rn.f16.f32 $0, $1;", + "=h,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + return Float32(llvm.inline_asm( + T.f32(), [f16_bits.ir_value(loc=loc, ip=ip)], + "cvt.f32.f16 $0, $1;", + "=f,h", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + + def _build_mul_cvt(out_fmt: str, relu: bool = False): + """Build a fused `x2 * f32x2 → fp8x2` PTX wrapper. + + The shape is identical across (in_fmt, out_fmt) combos — only the + widening opcode (`cvt.f32.`) and the final saturating cvt + (`cvt.rn.satfinite.x2.f32`) differ. + """ + out_op = "e4m3x2" if out_fmt == "e4m3" else "e5m2x2" + asm = ( + "{\n" + ".reg.b64 vp0; .reg.b64 vp1;\n\t" + ".reg.b32 v1; .reg.b32 v2;\n\t" + ".reg.b16 vb1; .reg.b16 vb2;\n\t" + "mov.b32 {vb1, vb2}, $1;\n\t" + f"cvt.f32.{in_fmt} v1, vb1;\n\t" + f"cvt.f32.{in_fmt} v2, vb2;\n\t" + "mov.b64 vp0, {v1, v2};\n\t" + "mul.f32x2 vp1, vp0, $2;\n\t" + "mov.b64 {v2, v1}, vp1;\n\t" + f"cvt.rn.satfinite{".relu" if relu else ""}.{out_op}.f32 $0, v1, v2;\n\t" + "}" + ) + + @dsl_user_op + def fn(val_2x: Int32, scale_2x: Int64, *, loc=None, ip=None) -> Int32: + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [val_2x.ir_value(loc=loc, ip=ip), + scale_2x.ir_value(loc=loc, ip=ip)], + asm, + "=h,r,l", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return fn + + def mul_cvt_to_fp8x2(fp8_dtype: str, relu: bool = False): + if fp8_dtype == "e5m2": + return _build_mul_cvt("e5m2", relu) + return _build_mul_cvt("e4m3", relu) + + return SimpleNamespace( + max_x2=max_x2, + abs_max_x2=abs_max_x2, + abs_max_scalar=abs_max_scalar, + bits_to_f32=bits_to_f32, + x2_lo_to_f32=x2_lo_to_f32, + x2_hi_to_f32=x2_hi_to_f32, + truncate_f32=truncate_f32, + mul_cvt_to_fp8x2=mul_cvt_to_fp8x2, + ) + + +_BF16_KIT = _build_packed16_kit("bf16") +_F16_KIT = _build_packed16_kit("f16") + + +def is_packed16(dtype) -> bool: + """True if `dtype` is one of the 16-bit packed input formats.""" + return dtype is cutlass.BFloat16 or dtype is cutlass.Float16 + +def packed16_kit(dtype): + """Trace-time selector — pick a Packed16Kit for the input dtype.""" + if dtype is cutlass.Float16: + return _F16_KIT + return _BF16_KIT diff --git a/transformer_engine/common/CuTeDSL/utils_fp8.py b/transformer_engine/common/CuTeDSL/utils_fp8.py new file mode 100644 index 0000000000..0c2e901afc --- /dev/null +++ b/transformer_engine/common/CuTeDSL/utils_fp8.py @@ -0,0 +1,97 @@ +import cutlass +import cutlass.cute as cute +from cutlass import Float32, Int64, Int32, Int16, Uint8, Uint32 +from cutlass._mlir.dialects import arith as mlir_arith +from cutlass._mlir.dialects import llvm +from cutlass.cutlass_dsl import T, dsl_user_op + +from transformer_engine.common.CuTeDSL.utils import FP32_MANTISSA_BITS, _bitcast_f32_to_i32 + +@dsl_user_op +def cvt_f32_to_fp8e4m3(val: Float32, *, loc=None, ip=None) -> Int32: + """float32 -> fp8e4m3 conversion.""" + zero = Float32(0.0) + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], + "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + result_i32 = Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return result_i32 & Int32(0xFF) + + +@dsl_user_op +def cvt_f32_to_fp8e5m2(val: Float32, *, loc=None, ip=None) -> Int32: + """float32 -> fp8e5m2 conversion.""" + zero = Float32(0.0) + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], + "cvt.rn.satfinite.e5m2x2.f32 $0, $1, $2;", + "=h,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + result_i32 = Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return result_i32 & Int32(0xFF) + + +@dsl_user_op +def cvt_f32_to_fp8e8m0(val: Float32, *, loc=None, ip=None) -> Int32: + """float32 -> fp8e8m0 conversion.""" + val_i32 = _bitcast_f32_to_i32(val, loc=loc, ip=ip) + rounded = val_i32 + Int32(0x7FFFFF) + exponent = (rounded >> Int32(FP32_MANTISSA_BITS)) & Int32(0xFF) + return Int32(mlir_arith.minsi( + exponent.ir_value(loc=loc, ip=ip), + Int32(254).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def cvt_f32x2_to_fp8e4m3x2(val_hi: Float32, val_lo: Float32, relu: bool = False, + *, loc=None, ip=None) -> Int32: + """Convert two float32 values to two packed fp8e4m3fn bytes in one instruction. + + Returns an int32 where bits [7:0] = fp8(val_lo), bits [15:8] = fp8(val_hi). + """ + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +@dsl_user_op +def cvt_f32x2_to_fp8e5m2x2(val_hi: Float32, val_lo: Float32, relu: bool = False, + *, loc=None, ip=None) -> Int32: + """Convert two float32 values to two packed fp8e5m2 bytes in one instruction. + + Returns an int32 where bits [7:0] = fp8(val_lo), bits [15:8] = fp8(val_hi). + """ + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e5m2x2.f32 $0, $1, $2;", + "=h,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + + +def cvt_f32_to_fp8(fp8_dtype: str): + """Returns the float32 -> float8 conversion function for the given FP8 format.""" + if fp8_dtype == "e5m2": + return cvt_f32_to_fp8e5m2 + return cvt_f32_to_fp8e4m3 + + +def cvt_f32x2_to_fp8x2(fp8_dtype: str): + """Returns the float32x2 -> float8x2 conversion function for the given FP8 format.""" + if fp8_dtype == "e5m2": + return cvt_f32x2_to_fp8e5m2x2 + return cvt_f32x2_to_fp8e4m3x2 + diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 334dd0eb15..21c1f5a1e2 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -19,14 +19,27 @@ from transformer_engine.pytorch import constants from transformer_engine.pytorch.constants import DType -# Register the CuTeDSL kernel entrypoints (TVM-FFI global funcs) so the C++ -# dispatcher can discover them via GetGlobal and compile kernels on demand. The -# CuTeDSL toolchain (cutlass, tvm_ffi) is optional; if it is unavailable the -# import is skipped and C++ simply falls back to the CUDA C++ kernels. +# Import the CuTeDSL module to register its quantize entrypoints with TVM-FFI. +# This backend is optional (install via `pip install transformer-engine[cutedsl]`); +# without it, TE simply uses the built-in CUDA kernels. try: import transformer_engine.common.CuTeDSL # noqa: F401 -except Exception: - pass +except ModuleNotFoundError as e: + # The optional CuTeDSL toolchain (e.g. nvidia-cutlass-dsl) isn't installed — + # expected for a default install; quietly fall back to the CUDA kernels. + import logging + logging.getLogger(__name__).debug( + "CuTeDSL quantize backend not available (%s); using the CUDA kernels. " + "Install `transformer-engine[cutedsl]` to enable it.", + e, + ) +except Exception as e: + # Something is really broken if it's not an import error + import logging + logging.getLogger(__name__).warning( + "CuTeDSL quantize backend failed to import (%s); using the CUDA kernels.", + e, + ) from transformer_engine.pytorch.module import LayerNormLinear from transformer_engine.pytorch.module import Linear From 0e525322f415fed148f0e182cbe604d5d787cf5b Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Tue, 23 Jun 2026 19:09:52 +0000 Subject: [PATCH 06/22] add debug logging --- .../common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index de26a60e43..83ea3bec8d 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -17,6 +17,7 @@ shared memory. """ import logging +import os from transformer_engine.common.CuTeDSL.utils import _bitcast_f32_to_i32, str_to_cutlass_dtype @@ -56,6 +57,7 @@ cvt_f32_to_fp8e8m0 ) +CUTEDSL_DEBUG_LOGGING = os.environ.get("CUTEDSL_DEBUG_LOGGING", "0") == "1" logger = logging.getLogger("transformer_engine.cutedsl.mxfp8") @@ -528,6 +530,7 @@ def __init__(self, cfg): # scale; this only skips the redundant global comparison in the other pass. self.AMAX_FROM_COLWISE = cfg.WITH_AMAX and cfg.COLWISE self.AMAX_FROM_ROWWISE = cfg.WITH_AMAX and not cfg.COLWISE + # Only enable logging if CUTEDSL_DEBUG_LOGGING=1 @cute.jit def __call__( @@ -541,6 +544,9 @@ def __call__( mWorkspace: Optional[cute.Tensor], # Workspace for the dbias reduction, only used when WITH_DBIAS is True stream: CUstream, ): + if cutlass.const_expr(CUTEDSL_DEBUG_LOGGING): + cute.printf(f"[CuTeDSL] MXFP8QuantizeSmemKernel.__call__() with config: {self.cfg}\n") + M = mX.shape[0] N = mX.shape[1] cfg = self.cfg From 5f8186ea6cf5ec03abb9bb117b270d1cff67420a Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Tue, 23 Jun 2026 19:23:30 +0000 Subject: [PATCH 07/22] nit --- .../common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index 83ea3bec8d..362ff17a42 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -68,8 +68,8 @@ BUFFER_NUM = 2 # Vectorised access constants for bank-conflict avoidance (rowwise pass) -PACK_SIZE = 4 # Elements per vector load -WAVES = MXFP8_BLOCK_SIZE // PACK_SIZE # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total +PACK_SIZE = 4 # Elements per vector load +WAVES = MXFP8_BLOCK_SIZE // PACK_SIZE # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total THREADS_PER_WARP = 32 TOTAL_BANKS_WIDTH = (32 * 4) // 1 # 32 banks × 4 bytes, in bytes (uint8 stride) THREADS_PER_BANK = TOTAL_BANKS_WIDTH // MXFP8_BLOCK_SIZE # 4 threads per bank From 1c694c84c754ed519a6bf0fb2b05d71e873cc9d9 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Tue, 23 Jun 2026 19:23:44 +0000 Subject: [PATCH 08/22] make CuTeDSL work for JAX too --- transformer_engine/common/__init__.py | 32 ++++++++++++++++++++++++++ transformer_engine/jax/__init__.py | 6 ++++- transformer_engine/pytorch/__init__.py | 26 ++++----------------- 3 files changed, 41 insertions(+), 23 deletions(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 42b458bfc5..b8b06b91c9 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -192,6 +192,38 @@ def load_framework_extension(framework: str) -> None: spec.loader.exec_module(solib) +def register_cutedsl_quant_backend() -> None: + """Import the CuTeDSL module so it registers its quantize entrypoints with + TVM-FFI, making the (framework-agnostic) C++ dispatcher able to discover and + JIT-compile them. Call this from each framework's __init__ after the extension + is loaded — the registration itself is framework-agnostic. + + The backend is optional (`pip install transformer-engine[cutedsl]`); on a + default install the import is skipped and TE uses the built-in CUDA kernels. + """ + try: + import transformer_engine.common.CuTeDSL # noqa: F401 + except ModuleNotFoundError as e: + # The optional CuTeDSL toolchain (e.g. nvidia-cutlass-dsl) isn't installed — + # expected for a default install; quietly fall back to the CUDA kernels. + import logging + + logging.getLogger(__name__).debug( + "CuTeDSL quantize backend not available (%s); using the CUDA kernels. " + "Install `transformer-engine[cutedsl]` to enable it.", + e, + ) + except Exception as e: + # Toolchain is present but the backend failed to import — unexpected and + # worth surfacing; TE still falls back to the CUDA kernels. + import logging + + logging.getLogger(__name__).warning( + "CuTeDSL quantize backend failed to import (%s); using the CUDA kernels.", + e, + ) + + def sanity_checks_for_pypi_installation() -> None: """Ensure that package is installed correctly if using PyPI.""" diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index d0afc1ff25..c31abd5f54 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -27,10 +27,14 @@ # extensions are not available. import jax -from transformer_engine.common import load_framework_extension +from transformer_engine.common import load_framework_extension, register_cutedsl_quant_backend load_framework_extension("jax") +# Register the CuTeDSL quantize backend entrypoints (TVM-FFI). Optional; falls +# back to the CUDA kernels if the CuTeDSL toolchain isn't installed. +register_cutedsl_quant_backend() + from . import flax from . import quantize diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 21c1f5a1e2..61a8ee0edf 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -10,7 +10,7 @@ import torch -from transformer_engine.common import load_framework_extension +from transformer_engine.common import load_framework_extension, register_cutedsl_quant_backend from transformer_engine.pytorch.torch_version import torch_version assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}." @@ -19,27 +19,9 @@ from transformer_engine.pytorch import constants from transformer_engine.pytorch.constants import DType -# Import the CuTeDSL module to register its quantize entrypoints with TVM-FFI. -# This backend is optional (install via `pip install transformer-engine[cutedsl]`); -# without it, TE simply uses the built-in CUDA kernels. -try: - import transformer_engine.common.CuTeDSL # noqa: F401 -except ModuleNotFoundError as e: - # The optional CuTeDSL toolchain (e.g. nvidia-cutlass-dsl) isn't installed — - # expected for a default install; quietly fall back to the CUDA kernels. - import logging - logging.getLogger(__name__).debug( - "CuTeDSL quantize backend not available (%s); using the CUDA kernels. " - "Install `transformer-engine[cutedsl]` to enable it.", - e, - ) -except Exception as e: - # Something is really broken if it's not an import error - import logging - logging.getLogger(__name__).warning( - "CuTeDSL quantize backend failed to import (%s); using the CUDA kernels.", - e, - ) +# Register the CuTeDSL quantize backend entrypoints (TVM-FFI). Optional; falls +# back to the CUDA kernels if the CuTeDSL toolchain isn't installed. +register_cutedsl_quant_backend() from transformer_engine.pytorch.module import LayerNormLinear from transformer_engine.pytorch.module import Linear From 0fe5ab9d5d2cf7d1a2182613448790110e20f97a Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Wed, 24 Jun 2026 19:00:38 +0000 Subject: [PATCH 09/22] refactor, add specialized kernel --- .../CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 420 +++++++++++++----- transformer_engine/common/CuTeDSL/utils.py | 52 ++- .../common/CuTeDSL/utils_fp8.py | 69 ++- 3 files changed, 435 insertions(+), 106 deletions(-) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index 362ff17a42..7572740f94 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -50,39 +50,23 @@ fabs_f32, exp2f_rcp, pack_f32x2, + mul_cvt_f32x4_to_fp8x4, ) from transformer_engine.common.CuTeDSL.utils_fp8 import ( - cvt_f32_to_fp8, - cvt_f32x2_to_fp8x2, + get_cvt_f32_to_fp8_func, + get_cvt_f32x2_to_fp8x2_func, cvt_f32_to_fp8e8m0 ) +from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE CUTEDSL_DEBUG_LOGGING = os.environ.get("CUTEDSL_DEBUG_LOGGING", "0") == "1" logger = logging.getLogger("transformer_engine.cutedsl.mxfp8") -# MXFP8 settings -MXFP8_BLOCK_SIZE = 32 # Number of elements per MXFP8 scale block. They will share the same E8M0 scale factor - -# Double-buffering for async copy + compute overlap -BUFFER_NUM = 2 - -# Vectorised access constants for bank-conflict avoidance (rowwise pass) -PACK_SIZE = 4 # Elements per vector load -WAVES = MXFP8_BLOCK_SIZE // PACK_SIZE # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total +# Number of elements per MXFP8 scale block. They will share the same E8M0 scale factor +MXFP8_BLOCK_SIZE = 32 +# How many threads are in one warp THREADS_PER_WARP = 32 -TOTAL_BANKS_WIDTH = (32 * 4) // 1 # 32 banks × 4 bytes, in bytes (uint8 stride) -THREADS_PER_BANK = TOTAL_BANKS_WIDTH // MXFP8_BLOCK_SIZE # 4 threads per bank - -# Tiling sizes -NUM_STAGES = 2 # Pipeline depth of the producer/consumer ring buffer for the TMA-G2S input loads (PipelineTmaAsync stage count) -NUM_TILES = 2 # Each CTA process 2 tiles along the Y (row, slowest-changing) dimension -TILE_Y = 32 # Each tile has 32 rows, so each CTA handles 32 * 2 rows in total -TILE_X = 64 # Each tile has 64 columns - -# CTA size -THREADS_PER_CTA = 64 -NUM_WARPS = THREADS_PER_CTA // 32 # FP8E4M3 max representable value FP8E4M3_MAX_NORM = 448.0 @@ -126,9 +110,7 @@ def quantize_rowwise_mxfp8( DTYPE, FP8_DTYPE, TILE_Y, - MXFP8_BLOCK_SIZE, WAVES, - THREADS_PER_WARP, THREADS_PER_BANK, PACK_SIZE, WITH_ACT=False, @@ -277,7 +259,7 @@ def quantize_rowwise_mxfp8( inv_scale_r = exp2f_rcp(biased_exp_r) # f32 reciprocal of the scale # Fetch the conversion function based on the FP8 format - cvt_f32x2 = cvt_f32x2_to_fp8x2(FP8_DTYPE) + cvt_f32x2 = get_cvt_f32x2_to_fp8x2_func(FP8_DTYPE) if cutlass.const_expr(_row_fast): kit_cast = packed16_kit(DTYPE) mul_cvt_x2 = kit_cast.mul_cvt_to_fp8x2(FP8_DTYPE, FUSE_RELU) @@ -326,7 +308,6 @@ def quantize_colwise_mxfp8( SWIZZLE, TILE_X, TILE_Y, - MXFP8_BLOCK_SIZE, WITH_ACT=False, # forward: apply activation to the element WITH_DACT=False, # backward: out = grad · act'(act_input) sA_tile=None, # (TILE_Y, TILE_X) activation-input smem tile (dact only) @@ -417,7 +398,7 @@ def quantize_colwise_mxfp8( mS_col_stage[(0, tidx)] = Uint8(biased_exp_c) inv_scale_c = exp2f_rcp(biased_exp_c) - cvt_to_fp8_func = cvt_f32_to_fp8(FP8_DTYPE) + cvt_to_fp8_func = get_cvt_f32_to_fp8_func(FP8_DTYPE) if cutlass.const_expr(USE_HALF_PRECISION): kit_cast = packed16_kit(DTYPE) for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): @@ -501,12 +482,28 @@ def __str__(self): __repr__ = __str__ -class MXFP8QuantizeSmemKernel: +class MXFP8QuantizeKernel: """The MXFP8 quantization kernel that mirrors the standard (non-specialized) MXFP8 CUDA C++ quantization kernel with multiple fusions (activation, dbias, etc.). `__call__` method is the entrypoint which is AOT compiled. `self` will be captured so it's fixed per compiled kernel """ + # Vectorised access constants for bank-conflict avoidance (rowwise pass) + _PACK_SIZE = 4 # Elements per vector load + _WAVES = MXFP8_BLOCK_SIZE // _PACK_SIZE # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total + _TOTAL_BANKS_WIDTH = (32 * 4) // 1 # 32 banks × 4 bytes, in bytes (uint8 stride) + _THREADS_PER_BANK = _TOTAL_BANKS_WIDTH // MXFP8_BLOCK_SIZE # 4 threads per bank + + # Tiling sizes + _NUM_STAGES = 2 # Pipeline depth of the producer/consumer ring buffer for the TMA-G2S input loads (PipelineTmaAsync stage count) + _NUM_TILES = 2 # Each CTA process 2 tiles along the Y (row, slowest-changing) dimension + _TILE_Y = 32 # Each tile has 32 rows, so each CTA handles 32 * 2 rows in total + _TILE_X = 64 # Each tile has 64 columns + + # CTA size + _THREADS_PER_CTA = 64 + _NUM_WARPS = _THREADS_PER_CTA // 32 + def __init__(self, cfg): self.cfg = cfg # We prefer to do dbias reduction in colwise which is easier (no cross-thread reduction needed). @@ -530,7 +527,6 @@ def __init__(self, cfg): # scale; this only skips the redundant global comparison in the other pass. self.AMAX_FROM_COLWISE = cfg.WITH_AMAX and cfg.COLWISE self.AMAX_FROM_ROWWISE = cfg.WITH_AMAX and not cfg.COLWISE - # Only enable logging if CUTEDSL_DEBUG_LOGGING=1 @cute.jit def __call__( @@ -545,7 +541,7 @@ def __call__( stream: CUstream, ): if cutlass.const_expr(CUTEDSL_DEBUG_LOGGING): - cute.printf(f"[CuTeDSL] MXFP8QuantizeSmemKernel.__call__() with config: {self.cfg}\n") + cute.printf(f"[CuTeDSL] MXFP8QuantizeKernel.__call__() with config: {self.cfg}\n") M = mX.shape[0] N = mX.shape[1] @@ -581,8 +577,8 @@ def __call__( ) # We have 2 stages in our pipeline where each stage loads / computes a (TILE_Y, TILE_X) tile - smem_tile_layout = cute.make_ordered_layout((TILE_Y, TILE_X), order=(1, 0)) - cta_tiler = (TILE_Y, TILE_X) + smem_tile_layout = cute.make_ordered_layout((self._TILE_Y, self._TILE_X), order=(1, 0)) + cta_tiler = (self._TILE_Y, self._TILE_X) # Input TMA atoms op_load = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() @@ -600,7 +596,7 @@ def __call__( # Output TMA atoms op_store = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() - out_smem_layout = cute.make_ordered_layout((TILE_Y, TILE_X), order=(1, 0)) + out_smem_layout = cute.make_ordered_layout((self._TILE_Y, self._TILE_X), order=(1, 0)) tma_atom_out_row = None tma_dst_out_row = None tma_atom_out_col = None @@ -615,10 +611,10 @@ def __call__( ) grid = [ - cute.ceil_div(Int32(N), TILE_X), - cute.ceil_div(M, TILE_Y * NUM_TILES), + cute.ceil_div(Int32(N), self._TILE_X), + cute.ceil_div(M, self._TILE_Y * self._NUM_TILES), ] - block = [THREADS_PER_CTA,] + block = [self._THREADS_PER_CTA,] self.kernel( mX, mS_row, mS_col, mAmax, mNoop, mWorkspace, @@ -680,80 +676,80 @@ def _kernel_main( cfg = self.cfg if cutlass.const_expr(cfg.ROWWISE): - mS_row = cute.zipped_divide(mS_row, (TILE_Y, TILE_X // MXFP8_BLOCK_SIZE)) + mS_row = cute.zipped_divide(mS_row, (self._TILE_Y, self._TILE_X // self._MXFP8_BLOCK_SIZE)) if cutlass.const_expr(cfg.COLWISE): - mS_col = cute.zipped_divide(mS_col, (TILE_Y // MXFP8_BLOCK_SIZE, TILE_X)) + mS_col = cute.zipped_divide(mS_col, (self._TILE_Y // self._MXFP8_BLOCK_SIZE, self._TILE_X)) # Allocate shared memory for the input and rowwise / columnwise outputs if cutlass.const_expr(cfg.ROWWISE and cfg.COLWISE): @cute.struct class SharedStorage: - mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * self._NUM_STAGES] sX: cute.struct.Align[ - cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + cute.struct.MemRange[dtype, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] sO_row: cute.struct.Align[ - cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + cute.struct.MemRange[Uint8, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] sO_col: cute.struct.Align[ - cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + cute.struct.MemRange[Uint8, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] - sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + sAmax: cute.struct.MemRange[Float32, self._NUM_WARPS] elif cutlass.const_expr(cfg.ROWWISE and not cfg.COLWISE): @cute.struct class SharedStorage: - mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * self._NUM_STAGES] sX: cute.struct.Align[ - cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + cute.struct.MemRange[dtype, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] sO_row: cute.struct.Align[ - cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + cute.struct.MemRange[Uint8, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] - sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + sAmax: cute.struct.MemRange[Float32, self._NUM_WARPS] elif cutlass.const_expr(cfg.ROWWISE): @cute.struct class SharedStorage: - mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * self._NUM_STAGES] sX: cute.struct.Align[ - cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + cute.struct.MemRange[dtype, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] sO_row: cute.struct.Align[ - cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + cute.struct.MemRange[Uint8, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] - sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + sAmax: cute.struct.MemRange[Float32, self._NUM_WARPS] else: @cute.struct class SharedStorage: - mbar_storage: cute.struct.MemRange[cute.Int64, 2 * NUM_STAGES] + mbar_storage: cute.struct.MemRange[cute.Int64, 2 * self._NUM_STAGES] sX: cute.struct.Align[ - cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + cute.struct.MemRange[dtype, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] sO_col: cute.struct.Align[ - cute.struct.MemRange[Uint8, TILE_Y * TILE_X * NUM_STAGES], 128 + cute.struct.MemRange[Uint8, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] - sAmax: cute.struct.MemRange[Float32, NUM_WARPS] + sAmax: cute.struct.MemRange[Float32, self._NUM_WARPS] smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) # Apply the layout to the allocated shared memory buffers so the first rank is the tile (nested layout) # and the second rank is the pipeline stage sX = storage.sX.get_tensor( cute.make_layout( - ((TILE_Y, TILE_X), NUM_STAGES), - stride=((TILE_X, 1), TILE_Y * TILE_X), + ((self._TILE_Y, self._TILE_X), self._NUM_STAGES), + stride=((self._TILE_X, 1), self._TILE_Y * self._TILE_X), ) ) if cutlass.const_expr(cfg.ROWWISE): sO_row = storage.sO_row.get_tensor( cute.make_layout( - ((TILE_Y, TILE_X), NUM_STAGES), - stride=((TILE_X, 1), TILE_Y * TILE_X), + ((self._TILE_Y, self._TILE_X), self._NUM_STAGES), + stride=((self._TILE_X, 1), self._TILE_Y * self._TILE_X), ) ) if cutlass.const_expr(cfg.COLWISE): sO_col = storage.sO_col.get_tensor( cute.make_layout( - ((TILE_Y, TILE_X), NUM_STAGES), - stride=((TILE_X, 1), TILE_Y * TILE_X), + ((self._TILE_Y, self._TILE_X), self._NUM_STAGES), + stride=((self._TILE_X, 1), self._TILE_Y * self._TILE_X), ) ) @@ -762,14 +758,14 @@ class SharedStorage: @cute.struct class DactStorage: sActInput: cute.struct.Align[ - cute.struct.MemRange[dtype, TILE_Y * TILE_X * NUM_STAGES], 128 + cute.struct.MemRange[dtype, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] dact_storage = smem.allocate(DactStorage) # Apply the same layout as the input sActInput = dact_storage.sActInput.get_tensor( cute.make_layout( - ((TILE_Y, TILE_X), NUM_STAGES), - stride=((TILE_X, 1), TILE_Y * TILE_X), + ((self._TILE_Y, self._TILE_X), self._NUM_STAGES), + stride=((self._TILE_X, 1), self._TILE_Y * self._TILE_X), ) ) @@ -788,10 +784,10 @@ class DactStorage: # Only warp 0 is the producer (issues TMA) producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, 1) # Every warp is the consumer (reads the data loaded by TMA) - consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, NUM_WARPS) + consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self._NUM_WARPS) # Bytes transferred per TMA copy: one (TILE_Y, TILE_X) tile of dtype. - tx_count = TILE_Y * TILE_X * dtype.width // 8 + tx_count = self._TILE_Y * self._TILE_X * dtype.width // 8 # dact loads two tiles (grad + act_input) under the same per-stage barrier, # so the barrier must expect both copies' bytes. if cutlass.const_expr(cfg.WITH_DACT): @@ -799,7 +795,7 @@ class DactStorage: mainloop_pipeline = pipeline.PipelineTmaAsync.create( barrier_storage=storage.mbar_storage.data_ptr(), - num_stages=NUM_STAGES, + num_stages=self._NUM_STAGES, producer_group=producer_group, consumer_group=consumer_group, tx_count=tx_count, @@ -807,22 +803,22 @@ class DactStorage: ) prod_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Producer, NUM_STAGES + pipeline.PipelineUserType.Producer, self._NUM_STAGES ) cons_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, NUM_STAGES + pipeline.PipelineUserType.Consumer, self._NUM_STAGES ) M = mX.shape[0] N = mX.shape[1] num_tiles = cutlass.min( - NUM_TILES, - cute.ceil_div(M - bidy * TILE_Y * NUM_TILES, TILE_Y), + self._NUM_TILES, + cute.ceil_div(M - bidy * self._TILE_Y * self._NUM_TILES, self._TILE_Y), ) # Tile the TMA gmem view: ((TILE_Y, TILE_X), (M/TILE_Y, N/TILE_X)). - gX_tiled = cute.zipped_divide(tma_src, (TILE_Y, TILE_X)) + gX_tiled = cute.zipped_divide(tma_src, (self._TILE_Y, self._TILE_X)) # Partition sX/gX for the TMA atom (single-CTA, no cluster/multicast). tXsX, tXgX = cute.nvgpu.cpasync.tma_partition( @@ -835,7 +831,7 @@ class DactStorage: # If WITH_DACT, partition the activation input for TMA as well in the same way if cutlass.const_expr(cfg.WITH_DACT): - gA_tiled = cute.zipped_divide(tma_src_act, (TILE_Y, TILE_X)) + gA_tiled = cute.zipped_divide(tma_src_act, (self._TILE_Y, self._TILE_X)) tXsA, tXgA = cute.nvgpu.cpasync.tma_partition( tma_atom_act, 0, @@ -846,7 +842,7 @@ class DactStorage: # Partitioning for rowwise / columnwise outputs if cutlass.const_expr(cfg.ROWWISE): - gO_row_tiled = cute.zipped_divide(tma_dst_out_row, (TILE_Y, TILE_X)) + gO_row_tiled = cute.zipped_divide(tma_dst_out_row, (self._TILE_Y, self._TILE_X)) tXsO_row, tXgO_row = cute.nvgpu.cpasync.tma_partition( tma_atom_out_row, 0, @@ -855,7 +851,7 @@ class DactStorage: gO_row_tiled, ) if cutlass.const_expr(cfg.COLWISE): - gO_col_tiled = cute.zipped_divide(tma_dst_out_col, (TILE_Y, TILE_X)) + gO_col_tiled = cute.zipped_divide(tma_dst_out_col, (self._TILE_Y, self._TILE_X)) tXsO_col, tXgO_col = cute.nvgpu.cpasync.tma_partition( tma_atom_out_col, 0, @@ -871,7 +867,7 @@ class DactStorage: if warp_idx == 0: for stage in cutlass.range(num_tiles, unroll=1): mainloop_pipeline.producer_acquire(prod_state) - tile_y = bidy * NUM_TILES + stage + tile_y = bidy * self._NUM_TILES + stage cute.copy( tma_atom, tXgX[(None, (tile_y, bidx))], @@ -924,7 +920,7 @@ class DactStorage: tile_idx_x = bidx # Each CTA has `NUM_TILES` tiles. Each stage we need to obtain the tile for that specific stage. # So the tile index along Y dimension is `bidy * NUM_TILES + stage` - tile_idx_y = bidy * NUM_TILES + stage + tile_idx_y = bidy * self._NUM_TILES + stage # Process rowwise and colwise quantization separately if cutlass.const_expr(cfg.COLWISE): # The first row that belongs to this CTA. Each CTA handles NUM_TILES of (TILE_Y, TILE_X) tiles stacked vertically, @@ -935,7 +931,7 @@ class DactStorage: amax_c, dbias_c = self._process_colwise( sX_tile, sO_col_tile, mS_col_stage, max_norm_rcp, - tile_idx_y * TILE_Y, bidx * TILE_X, M, N, + tile_idx_y * self._TILE_Y, bidx * self._TILE_X, M, N, sActInput_tile, ) if cutlass.const_expr(self.AMAX_FROM_COLWISE): @@ -963,7 +959,7 @@ class DactStorage: amax_r = self._process_rowwise( sX_tile, sO_row_tile, mS_row_stage, max_norm_rcp, - tile_idx_y * TILE_Y, bidx * TILE_X, M, N, + tile_idx_y * self._TILE_Y, bidx * self._TILE_X, M, N, sActInput_tile, rowwise_dbias_acc, ) @@ -980,7 +976,7 @@ class DactStorage: # Warp 0 issues TMA copy to write the quantized output tile from shared memory to global memory and then commits if warp_idx == 0: - tile_y = bidy * NUM_TILES + stage + tile_y = bidy * self._NUM_TILES + stage if cutlass.const_expr(cfg.ROWWISE): cute.copy( tma_atom_out_row, @@ -1004,20 +1000,20 @@ class DactStorage: # partial sum (per block). # Pad the buffer to avoid bank conflicts. The logical shape is still the same. Only the stride is different. - DBIAS_BUFF_WIDTH = TILE_X // MXFP8_BLOCK_SIZE * (MXFP8_BLOCK_SIZE + 1) + DBIAS_BUFF_WIDTH = self._TILE_X // MXFP8_BLOCK_SIZE * (MXFP8_BLOCK_SIZE + 1) @cute.struct class DbiasStorage: - sDbias: cute.struct.MemRange[Float32, TILE_Y * DBIAS_BUFF_WIDTH] + sDbias: cute.struct.MemRange[Float32, self._TILE_Y * DBIAS_BUFF_WIDTH] dbias_storage = smem.allocate(DbiasStorage) sDbias = dbias_storage.sDbias.get_tensor( - cute.make_layout((TILE_Y, TILE_X), stride=(DBIAS_BUFF_WIDTH, 1)), + cute.make_layout((self._TILE_Y, self._TILE_X), stride=(DBIAS_BUFF_WIDTH, 1)), ) # Thread layout: (TILE_Y, 2); value layout: (1, MXFP8_BLOCK_SIZE) where TILE_X = 2 * MXFP8_BLOCK_SIZE # And each thread writes the (1, MXFP8_BLOCK_SIZE) partial sum to this (TILE_Y, TILE_X) buffer # and then each thread reads its (TILE_Y, 1) sDbias column and writes the reduced sum to the GMEM workspace. # Since TILE_X == THREADS_PER_CTA, this column reduction yields (TILE_Y, TILE_X) -> (1, TILE_X). _, tv_layout_dbias_write = cute.make_layout_tv( - thr_layout=cute.make_layout((TILE_Y, 2), stride=(2, 1)), + thr_layout=cute.make_layout((self._TILE_Y, 2), stride=(2, 1)), val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(MXFP8_BLOCK_SIZE, 1)), ) sDbias_write = cute.composition(sDbias, tv_layout_dbias_write) @@ -1027,19 +1023,19 @@ class DbiasStorage: cute.arch.sync_threads() # All threads reduce the cross-thread partial sums to the per-block partial sum. _, tv_layout_dbias_reduce = cute.make_layout_tv( - thr_layout=cute.make_layout((1, TILE_X), stride=(TILE_X, 1)), - val_layout=cute.make_layout((TILE_Y, 1), stride=(1, 1)) + thr_layout=cute.make_layout((1, self._TILE_X), stride=(self._TILE_X, 1)), + val_layout=cute.make_layout((self._TILE_Y, 1), stride=(1, 1)) ) sDbias_reduce = cute.composition(sDbias, tv_layout_dbias_reduce) # make_layout_tv yields a (thread, value) layout: thread=tidx -> column tidx, # value=i -> row i. So index [tidx, i] (thread first), summing the column's rows. block_dbias = Float32(0.0) - for i in cutlass.range_constexpr(TILE_Y): + for i in cutlass.range_constexpr(self._TILE_Y): block_dbias += sDbias_reduce[tidx, i] # Write the per-tile reduced dbias to the global workspace. if cutlass.const_expr(cfg.WITH_DBIAS): - dbias_col = bidx * TILE_X + tidx + dbias_col = bidx * self._TILE_X + tidx if dbias_col < N: mWorkspace[(bidy, dbias_col)] = block_dbias @@ -1047,7 +1043,7 @@ class DbiasStorage: # Reduce and get the per-warp amax. warp_amax = cute.arch.warp_redux_sync(per_thread_amax, kind="fmax") # Write the per-warp amax to shared memory - sAmax = storage.sAmax.get_tensor(cute.make_layout(NUM_WARPS)) + sAmax = storage.sAmax.get_tensor(cute.make_layout(self._NUM_WARPS)) lane_idx = tidx % 32 if lane_idx == 0: sAmax[warp_idx] = warp_amax @@ -1055,7 +1051,7 @@ class DbiasStorage: if tidx == 0: cta_amax = Float32(0.0) # The first thread reduces all the per-warp amax to the per-CTA amax - for w in cutlass.range_constexpr(NUM_WARPS): + for w in cutlass.range_constexpr(self._NUM_WARPS): cta_amax = cute.arch.fmax(cta_amax, sAmax[w]) amax_i32 = cute.make_tensor( cute.recast_ptr(mAmax.iterator, dtype=Int32), @@ -1097,12 +1093,10 @@ def _process_rowwise( ACTIVATION=None if self.CACHE_ACTIVATION else cfg.ACTIVATION, DTYPE=cfg.DTYPE, FP8_DTYPE=cfg.FP8_DTYPE, - TILE_Y=TILE_Y, - MXFP8_BLOCK_SIZE=MXFP8_BLOCK_SIZE, - WAVES=WAVES, - THREADS_PER_WARP=THREADS_PER_WARP, - THREADS_PER_BANK=THREADS_PER_BANK, - PACK_SIZE=PACK_SIZE, + TILE_Y=self._TILE_Y, + WAVES=self._WAVES, + THREADS_PER_BANK=self._THREADS_PER_BANK, + PACK_SIZE=self._PACK_SIZE, WITH_ACT=cfg.WITH_ACT and not self.CACHE_ACTIVATION, WITH_DACT=cfg.WITH_DACT and not self.CACHE_ACTIVATION, WITH_DBIAS=self.DBIAS_REDUCTION_ROWWISE, @@ -1135,9 +1129,8 @@ def _process_colwise( DTYPE=cfg.DTYPE, FP8_DTYPE=cfg.FP8_DTYPE, SWIZZLE=cfg.WITH_GEMM_SWIZZLED_SCALES, - TILE_X=TILE_X, - TILE_Y=TILE_Y, - MXFP8_BLOCK_SIZE=MXFP8_BLOCK_SIZE, + TILE_X=self._TILE_X, + TILE_Y=self._TILE_Y, WITH_ACT=cfg.WITH_ACT, WITH_DACT=cfg.WITH_DACT, sA_tile=sActInput_tile, @@ -1145,12 +1138,237 @@ def _process_colwise( CACHE_ACTIVATION=self.CACHE_ACTIVATION, ) +class MXFP8QuantizeSpecializedRowwiseKernel: + """Specialized cast-only ROWWISE-only MXFP8 kernel. Requires N % 128 == 0 (full vectorizable column chunks). + + Plain rowwise-only quantize. Each thread owns one 32-element MXFP8 chunk and + uses vectorized global loads/stores (no TMA used).""" + + _TILE_Y = 4 + _TILE_X = 1024 + _THREADS_PER_CTA = 128 + + def __init__(self, cfg): + self.cfg = cfg + self.STASH_SCALE_TO_SMEM = True + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mO_row: Optional[cute.Tensor], mS_row: Optional[cute.Tensor], + mO_col: Optional[cute.Tensor], mS_col: Optional[cute.Tensor], # Unused, kept for API compatibility + mAmax: Optional[cute.Tensor], # Unused, kept for API compatibility + mNoop: Optional[cute.Tensor], # Unused, kept for API compatibility + mDActInput: Optional[cute.Tensor], # Unused, kept for API compatibility + mWorkspace: Optional[cute.Tensor], # Unused, kept for API compatibility + stream: CUstream, + ): + if cutlass.const_expr(CUTEDSL_DEBUG_LOGGING): + cute.printf(f"[CuTeDSL] MXFP8QuantizeSpecializedRowwiseKernel.__call__() with config: {self.cfg}\n") + + M = mX.shape[0] + N = mX.shape[1] + + grid = [ + cute.ceil_div(Int32(N), self._TILE_X), + cute.ceil_div(M, self._TILE_Y), + ] + block = [self._THREADS_PER_CTA] + + self.kernel( + mX, mO_row, mS_row, self.cfg.MAX_NORM_RCP, mX.element_type, + ).launch(grid=grid, block=block, stream=stream) + + @cute.kernel + def kernel(self, mX, mO_row, mS_row, max_norm_rcp, DTYPE): + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, _ = cute.arch.block_idx() + M = mX.shape[0] + N = mX.shape[1] + + # Each thread handles one 32-element MXFP8 chunk (= one scale block). + # The 128 threads in the CTA are grouped as (4, 32), so they cover a + # (4, 1024) input tile and the matching (4, 32) scale tile. + CTA_Y = self._TILE_Y + CTA_X = self._TILE_X // MXFP8_BLOCK_SCALING_SIZE + tiler, tv_layout = cute.make_layout_tv( + thr_layout=cute.make_layout((CTA_Y, CTA_X), stride=(CTA_X, 1)), + val_layout=cute.make_layout((1, MXFP8_BLOCK_SCALING_SIZE), + stride=(MXFP8_BLOCK_SCALING_SIZE, 1)), + ) + tiler_scale, tv_layout_scale = cute.make_layout_tv( + thr_layout=cute.make_layout((CTA_Y, CTA_X), stride=(CTA_X, 1)), + val_layout=cute.make_layout((1, 1), stride=(1, 1)), + ) + + # Select the tile that belongs to this CTA, then the fragment per thread. + mX_tile = cute.local_tile(mX, tiler, (bidy, bidx)) + mO_tile = cute.local_tile(mO_row, tiler, (bidy, bidx)) + mS_tile = cute.local_tile(mS_row, tiler_scale, (bidy, bidx)) + mX_thread = cute.composition(mX_tile, tv_layout)[tidx, None] + mO_thread = cute.composition(mO_tile, tv_layout)[tidx, None] + mS_thread = cute.composition(mS_tile, tv_layout_scale)[tidx, None] + + rX_thread = cute.make_rmem_tensor( + cute.make_layout((1, MXFP8_BLOCK_SCALING_SIZE), stride=(MXFP8_BLOCK_SCALING_SIZE, 1)), + dtype=DTYPE, + ) + # Inputs widened to f32 once (reused by amax and the fused cvt). The FP8 + # output stays a uint8 fragment; we write it through a uint32 view so the + # 4-wide mul_cvt drops one packed word per call (see the cvt loop). + rX_f32 = cute.make_rmem_tensor( + cute.make_layout((1, MXFP8_BLOCK_SCALING_SIZE), stride=(MXFP8_BLOCK_SCALING_SIZE, 1)), + dtype=Float32, + ) + rO_thread = cute.make_rmem_tensor( + cute.make_layout((1, MXFP8_BLOCK_SCALING_SIZE), stride=(MXFP8_BLOCK_SCALING_SIZE, 1)), + dtype=Uint8, + ) + rO_u32 = cute.make_tensor( + cute.recast_ptr(rO_thread.iterator, dtype=Uint32), + cute.make_layout((MXFP8_BLOCK_SCALING_SIZE // 4,), stride=(1,)), + ) + + # Each thread owns only one scale byte, so a direct RF->GMEM scale write + # can't vectorize (128 scattered 1-byte stores). If STASH_SCALE_TO_SMEM, + # stage the CTA's (CTA_Y, CTA_X) scale tile in SMEM, then flush it to gmem + # with wide (uint32) stores. Compile-time gated. + sS_slot = None + if cutlass.const_expr(self.STASH_SCALE_TO_SMEM): + @cute.struct + class SharedStorage: + buf: cute.struct.Align[cute.struct.MemRange[Uint8, CTA_Y * CTA_X], 16] + storage = cutlass.utils.SmemAllocator().allocate(SharedStorage) + sScale = storage.buf.get_tensor(cute.make_layout((CTA_Y, CTA_X), stride=(CTA_X, 1))) + sS_slot = cute.composition(sScale, tv_layout_scale)[tidx, None] + # Zero first so padding columns (cols past N/32 in the padded scale + # matrix) flush as 0 and we never read uninitialized smem. + sS_slot[0] = Uint8(0) + cute.arch.sync_threads() + + row = bidy * self._TILE_Y + tidx // CTA_X + col = bidx * self._TILE_X + (tidx % CTA_X) * MXFP8_BLOCK_SCALING_SIZE + if row < M and col < N: + cute.autovec_copy(mX_thread, rX_thread) + + # Widen once and reduce. bf16/fp16 -> f32 widening is exact, so the + # amax matches the CUDA 16-bit abs_max path bit-for-bit. + amax = Float32(0.0) + for i in cutlass.range_constexpr(MXFP8_BLOCK_SCALING_SIZE): + rX_f32[0, i] = Float32(rX_thread[0, i]) + amax = cute.arch.fmax(amax, fabs_f32(rX_f32[0, i])) + + biased_exp = cvt_f32_to_fp8e8m0(amax * max_norm_rcp) + if cutlass.const_expr(self.STASH_SCALE_TO_SMEM): + sS_slot[0] = Uint8(biased_exp) + else: + mS_thread[0] = Uint8(biased_exp) + + # Rescale + FP8 cast, 4 elements per fused mul_cvt (one uint32 out), + # then a vectorized store. Mirrors CUDA's _use_cvt_4x path. + inv_scale = exp2f_rcp(biased_exp) + scale_2x = pack_f32x2(inv_scale, inv_scale) + mul_cvt4 = mul_cvt_f32x4_to_fp8x4(self.cfg.FP8_DTYPE) + for w in cutlass.range_constexpr(MXFP8_BLOCK_SCALING_SIZE // 4): + b = 4 * w + rO_u32[w] = mul_cvt4(rX_f32[0, b], rX_f32[0, b + 1], + rX_f32[0, b + 2], rX_f32[0, b + 3], scale_2x) + cute.autovec_copy(rO_thread, mO_thread) + + # Cooperative wide flush of the staged scales: the first CTA_Y*(CTA_X/G) + # threads each store one G-wide group, the rest idle. Pick the widest store + # the runtime row pitch allows (mirrors CUDA's PreferredDataType + runtime + # check): uint4 (16 scale bytes) when padded_cols % 16 == 0, else uint32 + # (4 bytes, always safe since padded_cols is a multiple of 4). A group never + # straddles the allocation boundary (base and padded_cols are both multiples + # of G), so flush a group iff its first column is in-allocation. Zeroed + # padding columns flush as 0. + if cutlass.const_expr(self.STASH_SCALE_TO_SMEM): + cute.arch.sync_threads() + padded_cols = mS_row.shape[1] + if padded_cols % 16 == 0: + self._flush_scales(sScale, mS_tile, tidx, bidx, bidy, M, padded_cols, 16) + else: + self._flush_scales(sScale, mS_tile, tidx, bidx, bidy, M, padded_cols, 4) + + @cute.jit + def _flush_scales(self, sScale, mS_tile, tidx, bidx, bidy, M, padded_cols, G): + """Flush the staged (CTA_Y, CTA_X) scale tile to gmem with G-wide stores.""" + CTA_Y = self._TILE_Y + CTA_X = self._TILE_X // MXFP8_BLOCK_SCALING_SIZE + GROUPS = CTA_X // G + _, tv_flush = cute.make_layout_tv( + thr_layout=cute.make_layout((CTA_Y, GROUPS), stride=(GROUPS, 1)), + val_layout=cute.make_layout((1, G), stride=(G, 1)), + ) + if tidx < CTA_Y * GROUPS: + frow = tidx // GROUPS + fgroup = tidx % GROUPS + if bidy * CTA_Y + frow < M and bidx * CTA_X + fgroup * G < padded_cols: + cute.autovec_copy( + cute.composition(sScale, tv_flush)[tidx, None], + cute.composition(mS_tile, tv_flush)[tidx, None], + ) + +class MXFP8QuantizeSpecializedBidimensionalKernel: + """Specialized cast-only BIDIMENSIONAL (both-direction) MXFP8 kernel — the + CuTeDSL counterpart of specialized/quantize_mxfp8.cuh:: + quantize_mxfp8_kernel_cast_only<…,true,true> (non-warp-specialized). + + Plain both-direction quantize. TMA-based 32x32 warp tiles producing both the + rowwise and colwise scales/outputs from one staged tile; handles any N % 32. + + (Kernel body — TMA pipeline + dual-direction scale/cast — is implemented in a + later round; this is a dispatch stub that only logs the routing for now.)""" + + def __init__(self, cfg): + self.cfg = cfg + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mO_row: Optional[cute.Tensor], mS_row: Optional[cute.Tensor], + mO_col: Optional[cute.Tensor], mS_col: Optional[cute.Tensor], + mAmax: Optional[cute.Tensor], + mNoop: Optional[cute.Tensor], + mDActInput: Optional[cute.Tensor], + mWorkspace: Optional[cute.Tensor], + stream: CUstream, + ): + if cutlass.const_expr(CUTEDSL_DEBUG_LOGGING): + cute.printf(f"[CuTeDSL] MXFP8QuantizeSpecializedBidimensionalKernel.__call__() with config: {self.cfg}\n") + # TODO(next round): TMA-based 32x32-tile bidimensional cast-only kernel — + # grid/launch + dual-direction (rowwise+colwise) scale/cast. No output is + # produced yet; this stub exists so dispatch routing can be wired now. + + +def get_specialized_kernel_class(cfg): + """If no fusion is involved and the kernel only quantizes, dispatch to the specialized kernel for better performance.""" + plain_cast_only = ( + not cfg.WITH_GEMM_SWIZZLED_SCALES + and not cfg.WITH_AMAX and not cfg.WITH_DBIAS + and not cfg.WITH_DACT and not cfg.WITH_ACT and not cfg.WITH_NOOP + ) + if plain_cast_only: + if cfg.ROWWISE and not cfg.COLWISE: + return MXFP8QuantizeSpecializedRowwiseKernel + if cfg.ROWWISE and cfg.COLWISE: + # TODO: dispatch to the bidimensional specialized kernel once implemented; for now, use the general kernel. + return MXFP8QuantizeKernel + return MXFP8QuantizeKernel + + def compile_cutedsl_function_from_cfg(cfg): """ Return the compiled CuTeDSL function object for the given MXFP8 quantization config. """ - kernel_obj = MXFP8QuantizeSmemKernel(cfg) + # Route plain cast-only configs to the matching specialized kernel (mirrors the + # CUDA dispatcher); everything else uses the general standard kernel. + kernel_class = get_specialized_kernel_class(cfg) + kernel_obj = kernel_class(cfg) # M, N must be divisible by the MXFP8 scale-block size (MXFP8_BLOCK_SIZE = 32) — the # same alignment the CUDA C++ kernel requires. sym_M = cute.sym_int32(divisibility=MXFP8_BLOCK_SIZE) diff --git a/transformer_engine/common/CuTeDSL/utils.py b/transformer_engine/common/CuTeDSL/utils.py index a798eba864..bc5b09cd7d 100644 --- a/transformer_engine/common/CuTeDSL/utils.py +++ b/transformer_engine/common/CuTeDSL/utils.py @@ -1,5 +1,5 @@ import cutlass -from cutlass import Float32, Int64, Int32, Int16 +from cutlass import Float32, Int64, Int32, Int16, Uint32 from cutlass._mlir.dialects import arith as mlir_arith from cutlass._mlir.dialects import llvm from cutlass.cutlass_dsl import T, dsl_user_op @@ -86,6 +86,56 @@ def pack_f32x2(lo: Float32, hi: Float32, *, loc=None, ip=None) -> Int64: "=l,f,f", has_side_effects=False, is_align_stack=False, asm_dialect=llvm.AsmDialect.AD_ATT)) + +def _build_mul_cvt_f32x4(out_fmt: str, relu: bool = False): + """Build a fused 4-wide `f32x4 * f32x2 -> fp8x4` PTX wrapper. + + Multiplies four f32 inputs by a broadcast inverse scale (passed as an + f32x2 pack of (s, s)) and converts to FP8, packing the four bytes into one + uint32: byte i = fp8(v_i * s). Two `mul.f32x2` + two `cvt...x2.f32` — the + 4-wide analogue of the kit's `mul_cvt_to_fp8x2` (CUDA ptx::mul_cvt_4x). + """ + out_op = "e4m3x2" if out_fmt == "e4m3" else "e5m2x2" + asm = ( + "{\n" + ".reg.b64 vp0; .reg.b64 vp1; .reg.b64 vp2; .reg.b64 vp3;\n\t" + ".reg.b32 vs0; .reg.b32 vs1; .reg.b32 vs2; .reg.b32 vs3;\n\t" + ".reg.b16 vo0; .reg.b16 vo1;\n\t" + "mov.b64 vp0, {$1, $2};\n\t" + "mov.b64 vp2, {$3, $4};\n\t" + "mul.f32x2 vp1, vp0, $5;\n\t" + "mul.f32x2 vp3, vp2, $5;\n\t" + "mov.b64 {vs0, vs1}, vp1;\n\t" + "mov.b64 {vs2, vs3}, vp3;\n\t" + # cvt d, a, b => d[15:8]=fp8(a), d[7:0]=fp8(b); feed (hi, lo) so the low + # byte holds the earlier element. + f"cvt.rn.satfinite{".relu" if relu else ""}.{out_op}.f32 vo0, vs1, vs0;\n\t" + f"cvt.rn.satfinite{".relu" if relu else ""}.{out_op}.f32 vo1, vs3, vs2;\n\t" + "mov.b32 $0, {vo0, vo1};\n\t" + "}" + ) + + @dsl_user_op + def fn(v0: Float32, v1: Float32, v2: Float32, v3: Float32, scale_2x: Int64, + *, loc=None, ip=None) -> Uint32: + return Uint32(llvm.inline_asm( + T.i32(), + [v0.ir_value(loc=loc, ip=ip), v1.ir_value(loc=loc, ip=ip), + v2.ir_value(loc=loc, ip=ip), v3.ir_value(loc=loc, ip=ip), + scale_2x.ir_value(loc=loc, ip=ip)], + asm, + "=r,f,f,f,f,l", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + return fn + + +def mul_cvt_f32x4_to_fp8x4(fp8_dtype: str, relu: bool = False): + """Return the fused 4-wide f32->FP8 multiply+cast op for the given FP8 format. + + The op takes (v0, v1, v2, v3, scale_2x) and returns a uint32 of four packed + fp8 bytes, byte i = fp8(v_i * scale). `scale_2x` is pack_f32x2(s, s).""" + return _build_mul_cvt_f32x4("e5m2" if fp8_dtype == "e5m2" else "e4m3", relu) + def _build_packed16_kit(in_fmt: str): """Build a kit of PTX wrappers for a 16-bit input format so we don't have to repeat the same inline asm boilerplate code for FP16 and BF16 dtypes. diff --git a/transformer_engine/common/CuTeDSL/utils_fp8.py b/transformer_engine/common/CuTeDSL/utils_fp8.py index 0c2e901afc..4d09f2434c 100644 --- a/transformer_engine/common/CuTeDSL/utils_fp8.py +++ b/transformer_engine/common/CuTeDSL/utils_fp8.py @@ -1,3 +1,7 @@ +import logging +import os +import re + import cutlass import cutlass.cute as cute from cutlass import Float32, Int64, Int32, Int16, Uint8, Uint32 @@ -7,6 +11,8 @@ from transformer_engine.common.CuTeDSL.utils import FP32_MANTISSA_BITS, _bitcast_f32_to_i32 +logger = logging.getLogger("transformer_engine.cutedsl.utils_fp8") + @dsl_user_op def cvt_f32_to_fp8e4m3(val: Float32, *, loc=None, ip=None) -> Int32: """float32 -> fp8e4m3 conversion.""" @@ -38,8 +44,11 @@ def cvt_f32_to_fp8e5m2(val: Float32, *, loc=None, ip=None) -> Int32: @dsl_user_op -def cvt_f32_to_fp8e8m0(val: Float32, *, loc=None, ip=None) -> Int32: - """float32 -> fp8e8m0 conversion.""" +def cvt_f32_to_fp8e8m0_non_blackwell(val: Float32, *, loc=None, ip=None) -> Int32: + """float32 -> fp8e8m0 conversion (generic, pre-Blackwell). + + Software round-up of the biased exponent, mirroring ptx::float_to_e8m0's + non-Blackwell branch (transformer_engine/common/util/ptx.cuh).""" val_i32 = _bitcast_f32_to_i32(val, loc=loc, ip=ip) rounded = val_i32 + Int32(0x7FFFFF) exponent = (rounded >> Int32(FP32_MANTISSA_BITS)) & Int32(0xFF) @@ -48,6 +57,58 @@ def cvt_f32_to_fp8e8m0(val: Float32, *, loc=None, ip=None) -> Int32: Int32(254).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) +@dsl_user_op +def cvt_f32_to_fp8e8m0_blackwell(val: Float32, *, loc=None, ip=None) -> Int32: + """float32 -> fp8e8m0 conversion (Blackwell, SM >= 100). + + Uses the hardware cvt.rp.satfinite.ue8m0x2.f32 instruction, mirroring + ptx::float_to_e8m0's Blackwell branch. The x2 form packs two e8m0 bytes; + we feed (0.0, val) so the low byte is e8m0(val) and mask it out.""" + zero = Float32(0.0) + result_i16 = Int16(llvm.inline_asm( + T.i16(), + [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], + "cvt.rp.satfinite.ue8m0x2.f32 $0, $1, $2;", + "=h,f,f", has_side_effects=False, is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT)) + result_i32 = Int32(mlir_arith.extui( + T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return result_i32 & Int32(0xFF) + + +def _target_arch_is_blackwell() -> bool: + """Return True for the Blackwell family (SM 10.0 / 11.0 / 12.0), which has the + cvt.*.ue8m0x2.f32 hardware instruction. This mirrors the CUDA reference's + ARCH_BLACKWELL_FAMILY gate (FamilySpecific<100/110/120> in + transformer_engine/common/util/ptx.cuh) -- a family check, since the + instruction is available across the family (verified on sm_120a) even though + e.g. tcgen05 is not. + + The gate is the *compile target*, not the physical device, since that is what + decides whether the instruction codegens: CUTE_DSL_ARCH if set (what + cute.compile uses), else the current device's compute capability. Falls back + to the non-Blackwell software path if the arch can't be determined.""" + try: + arch = os.getenv("CUTE_DSL_ARCH") # e.g. "sm_120a", the explicit compile target + if arch: + major_minor = re.search(r"(\d+)", arch).group(1) # "120" + else: + from cuda.core import Device + major_minor = Device().arch # compute capability as digits, e.g. "120" + # Trailing digit is the minor version; the rest is the major version. + return int(major_minor[:-1]) in (10, 11, 12) + except Exception as e: # pragma: no cover - detection is best-effort + logger.debug("e8m0 arch detection failed (%s); using software path", e) + return False + + +# Pick the appropriate float32 -> fp8e8m0 conversion function based on the target architecture. +# Blackwell (SM >= 100) has a hardware instruction for this, while older architectures require a software implementation. +cvt_f32_to_fp8e8m0 = ( + cvt_f32_to_fp8e8m0_blackwell if _target_arch_is_blackwell() else cvt_f32_to_fp8e8m0_non_blackwell +) + + @dsl_user_op def cvt_f32x2_to_fp8e4m3x2(val_hi: Float32, val_lo: Float32, relu: bool = False, *, loc=None, ip=None) -> Int32: @@ -82,14 +143,14 @@ def cvt_f32x2_to_fp8e5m2x2(val_hi: Float32, val_lo: Float32, relu: bool = False, T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) -def cvt_f32_to_fp8(fp8_dtype: str): +def get_cvt_f32_to_fp8_func(fp8_dtype: str): """Returns the float32 -> float8 conversion function for the given FP8 format.""" if fp8_dtype == "e5m2": return cvt_f32_to_fp8e5m2 return cvt_f32_to_fp8e4m3 -def cvt_f32x2_to_fp8x2(fp8_dtype: str): +def get_cvt_f32x2_to_fp8x2_func(fp8_dtype: str): """Returns the float32x2 -> float8x2 conversion function for the given FP8 format.""" if fp8_dtype == "e5m2": return cvt_f32x2_to_fp8e5m2x2 From 2d58e73b1d306801356329e1e3763f99af5abb65 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Wed, 24 Jun 2026 19:52:08 +0000 Subject: [PATCH 10/22] nit --- .../CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 7 +++++-- transformer_engine/common/tvm_ffi_bridge.h | 21 +++++++++++++++---- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index 7572740f94..e29377defa 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -676,9 +676,9 @@ def _kernel_main( cfg = self.cfg if cutlass.const_expr(cfg.ROWWISE): - mS_row = cute.zipped_divide(mS_row, (self._TILE_Y, self._TILE_X // self._MXFP8_BLOCK_SIZE)) + mS_row = cute.zipped_divide(mS_row, (self._TILE_Y, self._TILE_X // MXFP8_BLOCK_SIZE)) if cutlass.const_expr(cfg.COLWISE): - mS_col = cute.zipped_divide(mS_col, (self._TILE_Y // self._MXFP8_BLOCK_SIZE, self._TILE_X)) + mS_col = cute.zipped_divide(mS_col, (self._TILE_Y // MXFP8_BLOCK_SIZE, self._TILE_X)) # Allocate shared memory for the input and rowwise / columnwise outputs if cutlass.const_expr(cfg.ROWWISE and cfg.COLWISE): @@ -1230,6 +1230,9 @@ def kernel(self, mX, mO_row, mS_row, max_norm_rcp, DTYPE): cute.make_layout((MXFP8_BLOCK_SCALING_SIZE // 4,), stride=(1,)), ) + # TODO: review this kernel myself. + + # Each thread owns only one scale byte, so a direct RF->GMEM scale write # can't vectorize (128 scattered 1-byte stores). If STASH_SCALE_TO_SMEM, # stage the CTA's (CTA_Y, CTA_X) scale tile in SMEM, then flush it to gmem diff --git a/transformer_engine/common/tvm_ffi_bridge.h b/transformer_engine/common/tvm_ffi_bridge.h index 755c4266bc..5f53235ed3 100644 --- a/transformer_engine/common/tvm_ffi_bridge.h +++ b/transformer_engine/common/tvm_ffi_bridge.h @@ -211,7 +211,7 @@ class TVMFFICentral { static_assert(detail::is_lazyloadable_config::value, "Config must define `std::string to_key() const` and " "`bool retrieve_func_from_python(const std::string&) const`."); - if (!enabled_) return std::nullopt; + if (!cutedsl_backend_enabled_) return std::nullopt; const std::string key = cfg.to_key(); { std::shared_lock read_lock(mutex_); @@ -228,12 +228,19 @@ class TVMFFICentral { std::unique_lock write_lock(mutex_); supported_.emplace(key, supported); } - return supported ? tvm::ffi::Function::GetGlobal(key) : std::nullopt; + if (supported) { + return tvm::ffi::Function::GetGlobal(key); + } + if (warn_unsupported_kernels_) { + NVTE_WARN("TVM-FFI kernel for config `", key, "` is not supported."); + } + return std::nullopt; } private: ~TVMFFICentral() = default; - TVMFFICentral() : enabled_(is_cutedsl_backend_enabled()) {} + TVMFFICentral() : cutedsl_backend_enabled_(is_cutedsl_backend_enabled()), + warn_unsupported_kernels_(warn_if_cutedsl_backend_unsupported()) {} TVMFFICentral(const TVMFFICentral &) = delete; TVMFFICentral &operator=(const TVMFFICentral &) = delete; TVMFFICentral(TVMFFICentral &&) = delete; @@ -244,8 +251,14 @@ class TVMFFICentral { const char *flag = std::getenv("NVTE_ENABLE_CUTEDSL_QUANT_BACKEND"); return flag != nullptr && flag[0] != '0'; } + + static bool warn_if_cutedsl_backend_unsupported() { + const char *flag = std::getenv("NVTE_WARN_IF_CUTEDSL_BACKEND_UNSUPPORTED"); + return flag != nullptr && flag[0] != '0'; + } - const bool enabled_; + const bool cutedsl_backend_enabled_; + const bool warn_unsupported_kernels_; std::shared_mutex mutex_; // Per-config support decision (cfg.to_key() -> supported). Holds NO Python- // backed handles, so it is safe to destroy at static teardown — the kernels From 840009c0078397ca01aa17746edd3ac5866f7df0 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Thu, 25 Jun 2026 02:59:11 +0000 Subject: [PATCH 11/22] fix register spill --- .../CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index e29377defa..82f2c3e251 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -152,7 +152,7 @@ def quantize_rowwise_mxfp8( # to avoid bank conflict. bank_group = (tidx % THREADS_PER_WARP) // THREADS_PER_BANK # The offset this thread should start reading from based on what's its first bank to access. - offset = bank_group * 4 # Each bank group will read 4 f16 from their bank + offset = bank_group * PACK_SIZE if cutlass.const_expr(_row_fast): # If no activation, f16 / bf16 and rowwise quantization, we can read 2 f16 / bf16 at once in a pack # and use max.xorsign.abs.f16x2 / max.xorsign.abs.bf16x2 to compute @@ -230,7 +230,8 @@ def quantize_rowwise_mxfp8( x = op(x) # Accumulate to the per-thread dbias register buffer for this tile if WITH_DBIAS if cutlass.const_expr(WITH_DBIAS): - dbias_acc[start + i] += x + # dbias_acc is register buffer so we can just write without bank conflict + dbias_acc[w * PACK_SIZE + i] += x # If 16-bit input with activation, truncate to IType if cutlass.const_expr(is_packed16(DTYPE) and ACTIVATION is not None): x = kit_act.truncate_f32(x) @@ -1017,9 +1018,16 @@ class DbiasStorage: val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(MXFP8_BLOCK_SIZE, 1)), ) sDbias_write = cute.composition(sDbias, tv_layout_dbias_write) - # All threads write their per-thread partial sum results to the shared buffer. - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - sDbias_write[(tidx, i)] = rowwise_dbias_acc[i] + # Each thread start reading from the specfic bank based on its thread ID so they can do their best to access different banks + # to avoid bank conflict. + bank_group = (tidx % THREADS_PER_WARP) // self._THREADS_PER_BANK + # The offset this thread should start reading from based on what's its first bank to access. + offset = bank_group * self._PACK_SIZE + for w in cutlass.range_constexpr(self._WAVES): # Each thread starts from this offset when writing into SMEM to avoid bank conflict + start = (w * self._PACK_SIZE + offset) % MXFP8_BLOCK_SIZE + for i in cutlass.range_constexpr(self._PACK_SIZE): + # All threads write their per-thread partial sum results to the shared buffer. + sDbias_write[(tidx, start + i)] = rowwise_dbias_acc[w * self._PACK_SIZE + i] cute.arch.sync_threads() # All threads reduce the cross-thread partial sums to the per-block partial sum. _, tv_layout_dbias_reduce = cute.make_layout_tv( From 869c97d0838ef9b6fdd18e508b1d55b50567c631 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Fri, 26 Jun 2026 19:26:40 +0000 Subject: [PATCH 12/22] fix registeration --- transformer_engine/__init__.py | 6 ++++++ transformer_engine/jax/__init__.py | 6 +----- transformer_engine/pytorch/__init__.py | 6 +----- 3 files changed, 8 insertions(+), 10 deletions(-) diff --git a/transformer_engine/__init__.py b/transformer_engine/__init__.py index 480a2e9a06..0e9bfa701b 100644 --- a/transformer_engine/__init__.py +++ b/transformer_engine/__init__.py @@ -13,6 +13,12 @@ from typing import Optional, Tuple import transformer_engine.common +# Register the CuTeDSL quantize backend (TVM-FFI entrypoints). Framework-agnostic, +# so do it once here rather than per framework — it applies no matter whether +# pytorch or jax is used. Optional: a no-op fallback to the CUDA kernels if the +# CuTeDSL toolchain isn't installed. +transformer_engine.common.register_cutedsl_quant_backend() + # Minimum NCCL version for the statically-linked NCCL EP backend. _NCCL_EP_MIN_VERSION = (2, 30, 4) diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index c31abd5f54..d0afc1ff25 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -27,14 +27,10 @@ # extensions are not available. import jax -from transformer_engine.common import load_framework_extension, register_cutedsl_quant_backend +from transformer_engine.common import load_framework_extension load_framework_extension("jax") -# Register the CuTeDSL quantize backend entrypoints (TVM-FFI). Optional; falls -# back to the CUDA kernels if the CuTeDSL toolchain isn't installed. -register_cutedsl_quant_backend() - from . import flax from . import quantize diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 61a8ee0edf..5847f5caa1 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -10,7 +10,7 @@ import torch -from transformer_engine.common import load_framework_extension, register_cutedsl_quant_backend +from transformer_engine.common import load_framework_extension from transformer_engine.pytorch.torch_version import torch_version assert torch_version() >= (2, 1), f"Minimum torch version 2.1 required. Found {torch_version()}." @@ -19,10 +19,6 @@ from transformer_engine.pytorch import constants from transformer_engine.pytorch.constants import DType -# Register the CuTeDSL quantize backend entrypoints (TVM-FFI). Optional; falls -# back to the CUDA kernels if the CuTeDSL toolchain isn't installed. -register_cutedsl_quant_backend() - from transformer_engine.pytorch.module import LayerNormLinear from transformer_engine.pytorch.module import Linear from transformer_engine.pytorch.module import LayerNormMLP From d1cb5e8c0bb62fcbd7f30c5767be6b29b306e92f Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Fri, 26 Jun 2026 19:29:31 +0000 Subject: [PATCH 13/22] fix --- .../CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 39 ++++++++----------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index 82f2c3e251..fc5f40f553 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -57,7 +57,6 @@ get_cvt_f32x2_to_fp8x2_func, cvt_f32_to_fp8e8m0 ) -from transformer_engine.pytorch.constants import MXFP8_BLOCK_SCALING_SIZE CUTEDSL_DEBUG_LOGGING = os.environ.get("CUTEDSL_DEBUG_LOGGING", "0") == "1" @@ -331,9 +330,6 @@ def quantize_colwise_mxfp8( sX_thread = sX_tv[tidx, None] sO_thread = sO_tv[tidx, None] - # dbias needs the per-element fp32 values to sum, so it takes the f32 path - # (never the i16 fast path) — matching CUDA, whose f16 fast path requires - # `!IS_DBIAS` (quantize_mxfp8.cuh:219). USE_HALF_PRECISION = is_packed16(DTYPE) and ACTIVATION is None dbias_partial = Float32(0.0) @@ -344,9 +340,6 @@ def quantize_colwise_mxfp8( cute.recast_ptr(sX_thread.iterator, dtype=Int16), cute.make_layout((MXFP8_BLOCK_SIZE,), stride=(TILE_X,)), ) - if cutlass.const_expr(WITH_DBIAS): - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - dbias_partial += kit.bits_to_f32(sX_thread_i16[i]) amax_bits = Int16(0) for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): amax_bits = kit.abs_max_scalar(amax_bits, sX_thread_i16[i]) @@ -369,10 +362,6 @@ def quantize_colwise_mxfp8( op = SUPPORTED_ACTIVATIONS[ACTIVATION] for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): sX_thread_f32[i] = op(sX_thread_f32[i]) - # Accumulate the per-thread column partial for dbias if WITH_DBIAS. - if cutlass.const_expr(WITH_DBIAS): - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - dbias_partial += sX_thread_f32[i] # Truncate the activation (after we apply op) back to the half precision type if input is also half precision. if cutlass.const_expr(is_packed16(DTYPE) and ACTIVATION is not None): kit_act = packed16_kit(DTYPE) @@ -404,9 +393,15 @@ def quantize_colwise_mxfp8( kit_cast = packed16_kit(DTYPE) for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): v_f32 = kit_cast.bits_to_f32(sX_thread_i16[i]) + if cutlass.const_expr(WITH_DBIAS): + dbias_partial += v_f32 sO_thread[i] = Uint8(cvt_to_fp8_func(v_f32 * inv_scale_c)) else: for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + # Accumulate the per-thread column partial for dbias if WITH_DBIAS. + if cutlass.const_expr(WITH_DBIAS): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): + dbias_partial += sX_thread_f32[i] sO_thread[i] = Uint8(cvt_to_fp8_func(sX_thread_f32[i] * inv_scale_c)) # Return this stage's per-column partial alongside amax; the caller accumulates @@ -1199,11 +1194,11 @@ def kernel(self, mX, mO_row, mS_row, max_norm_rcp, DTYPE): # The 128 threads in the CTA are grouped as (4, 32), so they cover a # (4, 1024) input tile and the matching (4, 32) scale tile. CTA_Y = self._TILE_Y - CTA_X = self._TILE_X // MXFP8_BLOCK_SCALING_SIZE + CTA_X = self._TILE_X // MXFP8_BLOCK_SIZE tiler, tv_layout = cute.make_layout_tv( thr_layout=cute.make_layout((CTA_Y, CTA_X), stride=(CTA_X, 1)), - val_layout=cute.make_layout((1, MXFP8_BLOCK_SCALING_SIZE), - stride=(MXFP8_BLOCK_SCALING_SIZE, 1)), + val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), + stride=(MXFP8_BLOCK_SIZE, 1)), ) tiler_scale, tv_layout_scale = cute.make_layout_tv( thr_layout=cute.make_layout((CTA_Y, CTA_X), stride=(CTA_X, 1)), @@ -1219,23 +1214,23 @@ def kernel(self, mX, mO_row, mS_row, max_norm_rcp, DTYPE): mS_thread = cute.composition(mS_tile, tv_layout_scale)[tidx, None] rX_thread = cute.make_rmem_tensor( - cute.make_layout((1, MXFP8_BLOCK_SCALING_SIZE), stride=(MXFP8_BLOCK_SCALING_SIZE, 1)), + cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(MXFP8_BLOCK_SIZE, 1)), dtype=DTYPE, ) # Inputs widened to f32 once (reused by amax and the fused cvt). The FP8 # output stays a uint8 fragment; we write it through a uint32 view so the # 4-wide mul_cvt drops one packed word per call (see the cvt loop). rX_f32 = cute.make_rmem_tensor( - cute.make_layout((1, MXFP8_BLOCK_SCALING_SIZE), stride=(MXFP8_BLOCK_SCALING_SIZE, 1)), + cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(MXFP8_BLOCK_SIZE, 1)), dtype=Float32, ) rO_thread = cute.make_rmem_tensor( - cute.make_layout((1, MXFP8_BLOCK_SCALING_SIZE), stride=(MXFP8_BLOCK_SCALING_SIZE, 1)), + cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(MXFP8_BLOCK_SIZE, 1)), dtype=Uint8, ) rO_u32 = cute.make_tensor( cute.recast_ptr(rO_thread.iterator, dtype=Uint32), - cute.make_layout((MXFP8_BLOCK_SCALING_SIZE // 4,), stride=(1,)), + cute.make_layout((MXFP8_BLOCK_SIZE // 4,), stride=(1,)), ) # TODO: review this kernel myself. @@ -1259,14 +1254,14 @@ class SharedStorage: cute.arch.sync_threads() row = bidy * self._TILE_Y + tidx // CTA_X - col = bidx * self._TILE_X + (tidx % CTA_X) * MXFP8_BLOCK_SCALING_SIZE + col = bidx * self._TILE_X + (tidx % CTA_X) * MXFP8_BLOCK_SIZE if row < M and col < N: cute.autovec_copy(mX_thread, rX_thread) # Widen once and reduce. bf16/fp16 -> f32 widening is exact, so the # amax matches the CUDA 16-bit abs_max path bit-for-bit. amax = Float32(0.0) - for i in cutlass.range_constexpr(MXFP8_BLOCK_SCALING_SIZE): + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): rX_f32[0, i] = Float32(rX_thread[0, i]) amax = cute.arch.fmax(amax, fabs_f32(rX_f32[0, i])) @@ -1281,7 +1276,7 @@ class SharedStorage: inv_scale = exp2f_rcp(biased_exp) scale_2x = pack_f32x2(inv_scale, inv_scale) mul_cvt4 = mul_cvt_f32x4_to_fp8x4(self.cfg.FP8_DTYPE) - for w in cutlass.range_constexpr(MXFP8_BLOCK_SCALING_SIZE // 4): + for w in cutlass.range_constexpr(MXFP8_BLOCK_SIZE // 4): b = 4 * w rO_u32[w] = mul_cvt4(rX_f32[0, b], rX_f32[0, b + 1], rX_f32[0, b + 2], rX_f32[0, b + 3], scale_2x) @@ -1307,7 +1302,7 @@ class SharedStorage: def _flush_scales(self, sScale, mS_tile, tidx, bidx, bidy, M, padded_cols, G): """Flush the staged (CTA_Y, CTA_X) scale tile to gmem with G-wide stores.""" CTA_Y = self._TILE_Y - CTA_X = self._TILE_X // MXFP8_BLOCK_SCALING_SIZE + CTA_X = self._TILE_X // MXFP8_BLOCK_SIZE GROUPS = CTA_X // G _, tv_flush = cute.make_layout_tv( thr_layout=cute.make_layout((CTA_Y, GROUPS), stride=(GROUPS, 1)), From 6cf1d6f1fac59229edacc1731a27330ae80e378a Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Fri, 26 Jun 2026 21:12:32 +0000 Subject: [PATCH 14/22] fix for cast_dbias_only --- .../CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 42 ++++++++++++------- .../cast/mxfp8/quantize_mxfp8_cutedsl.cuh | 22 +++++----- 2 files changed, 38 insertions(+), 26 deletions(-) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index fc5f40f553..773ba56233 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -108,6 +108,7 @@ def quantize_rowwise_mxfp8( ACTIVATION, DTYPE, FP8_DTYPE, + TILE_X, TILE_Y, WAVES, THREADS_PER_BANK, @@ -119,8 +120,11 @@ def quantize_rowwise_mxfp8( ): tidx, _, _ = cute.arch.thread_idx() + CTA_THREADS_Y = TILE_Y # threads per column (rows per tile) + CTA_THREADS_X = TILE_X // MXFP8_BLOCK_SIZE # threads per row (chunks per row) + _, tv_layout = cute.make_layout_tv( - thr_layout=cute.make_layout((TILE_Y, 2), stride=(2, 1)), + thr_layout=cute.make_layout((CTA_THREADS_Y, CTA_THREADS_X), stride=(CTA_THREADS_X, 1)), val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)) ) @@ -252,10 +256,10 @@ def quantize_rowwise_mxfp8( # For irregular shapes, skip the scale store if this thread's logical row / col-block lies past the input's actual extents. # TMA already zero-fills OOB input reads and drops OOB output writes; only the direct scale-byte gmem store needs an explicit guard. - scale_row = tile_row_start + tidx // 2 - scale_col_first_elt = tile_col_start + (tidx % 2) * MXFP8_BLOCK_SIZE + scale_row = tile_row_start + tidx // CTA_THREADS_X + scale_col_first_elt = tile_col_start + (tidx % CTA_THREADS_X) * MXFP8_BLOCK_SIZE if scale_row < M and scale_col_first_elt < N: - mS_row_stage[(tidx // 2, tidx % 2)] = Uint8(biased_exp_r) + mS_row_stage[(tidx // CTA_THREADS_X, tidx % CTA_THREADS_X)] = Uint8(biased_exp_r) inv_scale_r = exp2f_rcp(biased_exp_r) # f32 reciprocal of the scale # Fetch the conversion function based on the FP8 format @@ -492,16 +496,24 @@ class MXFP8QuantizeKernel: # Tiling sizes _NUM_STAGES = 2 # Pipeline depth of the producer/consumer ring buffer for the TMA-G2S input loads (PipelineTmaAsync stage count) - _NUM_TILES = 2 # Each CTA process 2 tiles along the Y (row, slowest-changing) dimension - _TILE_Y = 32 # Each tile has 32 rows, so each CTA handles 32 * 2 rows in total - _TILE_X = 64 # Each tile has 64 columns - - # CTA size - _THREADS_PER_CTA = 64 - _NUM_WARPS = _THREADS_PER_CTA // 32 def __init__(self, cfg): self.cfg = cfg + # Cast + dbias with no activation gets the larger tile (CUDA CAST_DBIAS_ONLY). + cast_dbias_only = cfg.WITH_DBIAS and not cfg.WITH_DACT and not cfg.WITH_ACT + # Use a different tile size for dbias only config + # No matter what tile size we use, each thread always handles a (1, MXFP8_BLOCK_SIZE) chunk + if cast_dbias_only: + self._NUM_TILES = 4 + self._THREADS_PER_CTA = 128 + self._TILE_X = 128 + self._TILE_Y = 32 + else: + self._NUM_TILES = 2 + self._THREADS_PER_CTA = 64 + self._TILE_X = 64 + self._TILE_Y = 32 + self._NUM_WARPS = self._THREADS_PER_CTA // 32 # We prefer to do dbias reduction in colwise which is easier (no cross-thread reduction needed). # Only do rowwise reduction when we don't quantize columnwisely when WITH_DBIAS is True. self.DBIAS_REDUCTION_COLWISE = cfg.WITH_DBIAS and cfg.COLWISE @@ -1004,12 +1016,9 @@ class DbiasStorage: sDbias = dbias_storage.sDbias.get_tensor( cute.make_layout((self._TILE_Y, self._TILE_X), stride=(DBIAS_BUFF_WIDTH, 1)), ) - # Thread layout: (TILE_Y, 2); value layout: (1, MXFP8_BLOCK_SIZE) where TILE_X = 2 * MXFP8_BLOCK_SIZE - # And each thread writes the (1, MXFP8_BLOCK_SIZE) partial sum to this (TILE_Y, TILE_X) buffer - # and then each thread reads its (TILE_Y, 1) sDbias column and writes the reduced sum to the GMEM workspace. - # Since TILE_X == THREADS_PER_CTA, this column reduction yields (TILE_Y, TILE_X) -> (1, TILE_X). _, tv_layout_dbias_write = cute.make_layout_tv( - thr_layout=cute.make_layout((self._TILE_Y, 2), stride=(2, 1)), + thr_layout=cute.make_layout((self._TILE_Y, self._TILE_X // MXFP8_BLOCK_SIZE), + stride=(self._TILE_X // MXFP8_BLOCK_SIZE, 1)), val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(MXFP8_BLOCK_SIZE, 1)), ) sDbias_write = cute.composition(sDbias, tv_layout_dbias_write) @@ -1096,6 +1105,7 @@ def _process_rowwise( ACTIVATION=None if self.CACHE_ACTIVATION else cfg.ACTIVATION, DTYPE=cfg.DTYPE, FP8_DTYPE=cfg.FP8_DTYPE, + TILE_X=self._TILE_X, TILE_Y=self._TILE_Y, WAVES=self._WAVES, THREADS_PER_BANK=self._THREADS_PER_BANK, diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh index 16bea8305e..0ebef35c05 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh @@ -146,15 +146,20 @@ inline bool mxfp8_quantize_cutedsl(const MXFP8QuantConfig &config, return false; } + // When only WITH_DBIAS is true, we use a larger tile size (align with CUDA C++ implementation) + const bool cast_dbias_only = config.with_dbias && !config.with_dact && !config.with_act; + const size_t chunk_rows = cast_dbias_only ? 128 : 64; // input rows reduced per CTA + // Each CTA writes one partial-dbias row, so the workspace (and the cross-CTA + // reduction below) has ceil(M / chunk_rows) rows. + const size_t workspace_rows = (flat_m + chunk_rows - 1) / chunk_rows; + // dbias workspace-size query, mirroring mxfp8::quantize: the framework first // calls with an unallocated workspace to learn its shape, allocates a buffer of // that shape, then calls again to run. The kernel writes per-row-block partial // dbias into this workspace; reducing it to the final dbias is a separate step. if (config.with_dbias && workspace_tensor != nullptr && workspace_tensor->data.dptr == nullptr) { - constexpr size_t kCuTeDSLMXFP8ChunkRows = 64; // TILE_Y * NUM_TILES (CTA row span) - const size_t dbias_rows = (flat_m + kCuTeDSLMXFP8ChunkRows - 1) / kCuTeDSLMXFP8ChunkRows; - workspace_tensor->data.shape = {dbias_rows, flat_n}; + workspace_tensor->data.shape = {workspace_rows, flat_n}; workspace_tensor->data.dtype = DType::kFloat32; return true; } @@ -167,12 +172,10 @@ inline bool mxfp8_quantize_cutedsl(const MXFP8QuantConfig &config, // Zero out swizzled scale padding when the matrix isn't a multiple of the // 128x128 GEMM tile. The kernel writes only the meaningful scale region, so - // cuBLAS would otherwise read uninitialized padding. Mirrors the CUDA launcher - // in quantize_mxfp8.cuh (the kernel itself does not pad the scales). + // cuBLAS would otherwise read uninitialized padding. - // TODO: move this into the CuTeDSL host code so the padding is handled inside - // the kernel launch — this CUDA-driver memset is an implementation detail that - // doesn't belong in the dispatcher (blocked on calling the driver API there). + // TODO: see if it's possible to move this into the CuTeDSL host code so the padding is handled inside + // the kernel launch so it's more flexible if (config.swizzled && (flat_m % 128 != 0 || flat_n % 128 != 0)) { if (output_tensor->has_data()) { NVTE_CHECK_CUDA(cudaMemsetAsync(output_tensor->scale_inv.dptr, 0, @@ -205,11 +208,10 @@ inline bool mxfp8_quantize_cutedsl(const MXFP8QuantConfig &config, // If WITH_DBIAS, reduce the workspace partial dbias in CUDA C++ for now. if (config.with_dbias) { - const size_t blocks_Y = (flat_m + 63) / 64; // ceil(M/64) = workspace rows const float *workspace_ptr = reinterpret_cast(workspace_tensor->data.dptr); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( input_tensor->dtype(), IType, - dispatch::common::reduce_dbias(workspace_ptr, dbias_tensor, blocks_Y, flat_n, + dispatch::common::reduce_dbias(workspace_ptr, dbias_tensor, workspace_rows, flat_n, stream);) // NOLINT(*) } return true; From 69b7240ee9583340f15bc4051837f1d3d064c8a0 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Sat, 27 Jun 2026 06:22:36 +0000 Subject: [PATCH 15/22] fix deadlock --- .../CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 234 ++++++++++-------- 1 file changed, 135 insertions(+), 99 deletions(-) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index 773ba56233..246c9dac5a 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -404,8 +404,7 @@ def quantize_colwise_mxfp8( for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): # Accumulate the per-thread column partial for dbias if WITH_DBIAS. if cutlass.const_expr(WITH_DBIAS): - for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE): - dbias_partial += sX_thread_f32[i] + dbias_partial += sX_thread_f32[i] sO_thread[i] = Uint8(cvt_to_fp8_func(sX_thread_f32[i] * inv_scale_c)) # Return this stage's per-column partial alongside amax; the caller accumulates @@ -493,9 +492,7 @@ class MXFP8QuantizeKernel: _WAVES = MXFP8_BLOCK_SIZE // _PACK_SIZE # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total _TOTAL_BANKS_WIDTH = (32 * 4) // 1 # 32 banks × 4 bytes, in bytes (uint8 stride) _THREADS_PER_BANK = _TOTAL_BANKS_WIDTH // MXFP8_BLOCK_SIZE # 4 threads per bank - - # Tiling sizes - _NUM_STAGES = 2 # Pipeline depth of the producer/consumer ring buffer for the TMA-G2S input loads (PipelineTmaAsync stage count) + _NUM_STAGES = 2 # The pipeline depth is always 2 def __init__(self, cfg): self.cfg = cfg @@ -871,26 +868,27 @@ class DactStorage: # Ensure barrier init is visible to all threads before the pipeline is used. cute.arch.sync_threads() - # ---- Producer: warp 0 issues one TMA copy per tile. ---- + # Prologue: warp 0 prefetches up to NUM_STAGES tiles to fully fill the pipeline if warp_idx == 0: - for stage in cutlass.range(num_tiles, unroll=1): - mainloop_pipeline.producer_acquire(prod_state) - tile_y = bidy * self._NUM_TILES + stage - cute.copy( - tma_atom, - tXgX[(None, (tile_y, bidx))], - tXsX[(None, prod_state.index)], - tma_bar_ptr=mainloop_pipeline.producer_get_barrier(prod_state), - ) - if cutlass.const_expr(cfg.WITH_DACT): + for s in cutlass.range_constexpr(self._NUM_STAGES): + if s < num_tiles: + mainloop_pipeline.producer_acquire(prod_state) + tile_y = bidy * self._NUM_TILES + s cute.copy( - tma_atom_act, - tXgA[(None, (tile_y, bidx))], - tXsA[(None, prod_state.index)], + tma_atom, + tXgX[(None, (tile_y, bidx))], + tXsX[(None, prod_state.index)], tma_bar_ptr=mainloop_pipeline.producer_get_barrier(prod_state), ) - mainloop_pipeline.producer_commit(prod_state) - prod_state.advance() + if cutlass.const_expr(cfg.WITH_DACT): + cute.copy( + tma_atom_act, + tXgA[(None, (tile_y, bidx))], + tXsA[(None, prod_state.index)], + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(prod_state), + ) + mainloop_pipeline.producer_commit(prod_state) + prod_state.advance() # Per-thread amax accumulator if cutlass.const_expr(cfg.WITH_AMAX): @@ -918,22 +916,31 @@ class DactStorage: if cutlass.const_expr(self.DBIAS_REDUCTION_COLWISE): block_dbias = Float32(0.0) - # ---- Consumer: all threads quantize each completed tile. ---- - for stage in cutlass.range(num_tiles, unroll=1): + # Consumer: all threads fetch from the pipeline, and + for tile_idx in cutlass.range(num_tiles, unroll=1): mainloop_pipeline.consumer_wait(cons_state) - sX_tile = sX[(None, stage)] + # Only allow at most _NUM_STAGES-1 stages to be in-flight, because this iteration will reuse the ring buffer + # that is read _NUM_STAGES iterations ago + if warp_idx == 0: + cute.arch.cp_async_bulk_wait_group(self._NUM_STAGES - 1, read=True) + cute.arch.sync_threads() + # The current pipeline stage index, which is the tile index modulo the number of stages. + # This is used to index into the shared memory ring buffers that are wrapped around the number of stages. + stage_idx = cons_state.index + sX_tile = sX[(None, stage_idx)] + # Also fetch the activation input if WITH_DACT sActInput_tile = None if cutlass.const_expr(cfg.WITH_DACT): - sActInput_tile = sActInput[(None, stage)] + sActInput_tile = sActInput[(None, stage_idx)] + # Each CTA handles `NUM_TILES` tiles stacked vertically, so tile_idx_x is just the block index along X dimension + # and tile_idx_y is the tile that this stage handles out of the `NUM_TILES` tiles tile_idx_x = bidx - # Each CTA has `NUM_TILES` tiles. Each stage we need to obtain the tile for that specific stage. - # So the tile index along Y dimension is `bidy * NUM_TILES + stage` - tile_idx_y = bidy * self._NUM_TILES + stage + tile_idx_y = bidy * self._NUM_TILES + tile_idx # Process rowwise and colwise quantization separately if cutlass.const_expr(cfg.COLWISE): # The first row that belongs to this CTA. Each CTA handles NUM_TILES of (TILE_Y, TILE_X) tiles stacked vertically, # and each stage handles one of them. - sO_col_tile = sO_col[(None, stage)] + sO_col_tile = sO_col[(None, stage_idx)] mS_col_stage = cute.flatten(mS_col[(None, (tile_idx_y, tile_idx_x))]) amax_c, dbias_c = self._process_colwise( @@ -951,19 +958,14 @@ class DactStorage: if cutlass.const_expr(self.CACHE_ACTIVATION): cute.arch.sync_threads() if cutlass.const_expr(cfg.ROWWISE): - sO_row_tile = sO_row[(None, stage)] + sO_row_tile = sO_row[(None, stage_idx)] # mS_row is ((SCALE_TILE), (GRID)) where SCALE_TILE = (32, 2). # Each CTA owns NUM_TILES consecutive row-tiles of GRID. cute - # auto-decomposes the flat row coord `bidy * NUM_TILES + stage` + # auto-decomposes the flat row coord `bidy * NUM_TILES + tile_idx` # onto GRID's hierarchical row modes — which is the # (i_hi, tile_Y) tile-major order for swizzled, and the plain # row-tile order for compact. Same source, both layouts correct. mS_row_stage = cute.flatten(mS_row[(None, (tile_idx_y, tile_idx_x))]) - # print(f"s0_row_tile: {sO_row_tile}\n") - # print(f"sO_row: {sO_row}\n") - # print(f"mS_row: {mS_row}\n") - # print(f"mS_row_stage: {mS_row_stage}\n") - # print(f"mS_row_stage: {mS_row_stage}\n") amax_r = self._process_rowwise( sX_tile, sO_row_tile, mS_row_stage, max_norm_rcp, @@ -984,17 +986,17 @@ class DactStorage: # Warp 0 issues TMA copy to write the quantized output tile from shared memory to global memory and then commits if warp_idx == 0: - tile_y = bidy * self._NUM_TILES + stage + tile_y = bidy * self._NUM_TILES + tile_idx if cutlass.const_expr(cfg.ROWWISE): cute.copy( tma_atom_out_row, - tXsO_row[(None, stage)], + tXsO_row[(None, stage_idx)], tXgO_row[(None, (tile_y, bidx))], ) if cutlass.const_expr(cfg.COLWISE): cute.copy( tma_atom_out_col, - tXsO_col[(None, stage)], + tXsO_col[(None, stage_idx)], tXgO_col[(None, (tile_y, bidx))], ) cute.arch.cp_async_bulk_commit_group() @@ -1002,48 +1004,35 @@ class DactStorage: mainloop_pipeline.consumer_release(cons_state) cons_state.advance() + # The pipeline is no longer fully filled after we consume this tile, so we fetch a new tile to fill the pipeline. + # The next _NUM_STAGES-1 tiles are already in-flight, so the next tile to fetch is after _NUM_STAGES tiles. + if warp_idx == 0: + next_tile_idx = tile_idx + self._NUM_STAGES + if next_tile_idx < num_tiles: + mainloop_pipeline.producer_acquire(prod_state) + tile_y = bidy * self._NUM_TILES + next_tile_idx + cute.copy( + tma_atom, + tXgX[(None, (tile_y, bidx))], + tXsX[(None, prod_state.index)], + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(prod_state), + ) + if cutlass.const_expr(cfg.WITH_DACT): + cute.copy( + tma_atom_act, + tXgA[(None, (tile_y, bidx))], + tXsA[(None, prod_state.index)], + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(prod_state), + ) + mainloop_pipeline.producer_commit(prod_state) + prod_state.advance() + # End of the main pipeline loop + # Complete the cross-thread dbias reduction after each thread has its own per-thread partial sum after the rowwise quantization. if cutlass.const_expr(self.DBIAS_REDUCTION_ROWWISE): - # Allocate the SMEM buffer that all threads use to reduce the two-stage partial sum (per thread) to the - # partial sum (per block). - - # Pad the buffer to avoid bank conflicts. The logical shape is still the same. Only the stride is different. - DBIAS_BUFF_WIDTH = self._TILE_X // MXFP8_BLOCK_SIZE * (MXFP8_BLOCK_SIZE + 1) - @cute.struct - class DbiasStorage: - sDbias: cute.struct.MemRange[Float32, self._TILE_Y * DBIAS_BUFF_WIDTH] - dbias_storage = smem.allocate(DbiasStorage) - sDbias = dbias_storage.sDbias.get_tensor( - cute.make_layout((self._TILE_Y, self._TILE_X), stride=(DBIAS_BUFF_WIDTH, 1)), - ) - _, tv_layout_dbias_write = cute.make_layout_tv( - thr_layout=cute.make_layout((self._TILE_Y, self._TILE_X // MXFP8_BLOCK_SIZE), - stride=(self._TILE_X // MXFP8_BLOCK_SIZE, 1)), - val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(MXFP8_BLOCK_SIZE, 1)), - ) - sDbias_write = cute.composition(sDbias, tv_layout_dbias_write) - # Each thread start reading from the specfic bank based on its thread ID so they can do their best to access different banks - # to avoid bank conflict. - bank_group = (tidx % THREADS_PER_WARP) // self._THREADS_PER_BANK - # The offset this thread should start reading from based on what's its first bank to access. - offset = bank_group * self._PACK_SIZE - for w in cutlass.range_constexpr(self._WAVES): # Each thread starts from this offset when writing into SMEM to avoid bank conflict - start = (w * self._PACK_SIZE + offset) % MXFP8_BLOCK_SIZE - for i in cutlass.range_constexpr(self._PACK_SIZE): - # All threads write their per-thread partial sum results to the shared buffer. - sDbias_write[(tidx, start + i)] = rowwise_dbias_acc[w * self._PACK_SIZE + i] - cute.arch.sync_threads() - # All threads reduce the cross-thread partial sums to the per-block partial sum. - _, tv_layout_dbias_reduce = cute.make_layout_tv( - thr_layout=cute.make_layout((1, self._TILE_X), stride=(self._TILE_X, 1)), - val_layout=cute.make_layout((self._TILE_Y, 1), stride=(1, 1)) - ) - sDbias_reduce = cute.composition(sDbias, tv_layout_dbias_reduce) - # make_layout_tv yields a (thread, value) layout: thread=tidx -> column tidx, - # value=i -> row i. So index [tidx, i] (thread first), summing the column's rows. - block_dbias = Float32(0.0) - for i in cutlass.range_constexpr(self._TILE_Y): - block_dbias += sDbias_reduce[tidx, i] + # If we do the dbias reduction in the rowwise pass, each thread will have a (1, MXFP8_BLOCK_SIZE) partial sum + # and we need to write these to a SMEM buffer and let each thread reduce it in the columnwise direction + block_dbias = self._dbias_reduction_rowwise_epilouge(smem, tidx, rowwise_dbias_acc) # Write the per-tile reduced dbias to the global workspace. if cutlass.const_expr(cfg.WITH_DBIAS): @@ -1052,32 +1041,79 @@ class DbiasStorage: mWorkspace[(bidy, dbias_col)] = block_dbias if cutlass.const_expr(cfg.WITH_AMAX): - # Reduce and get the per-warp amax. - warp_amax = cute.arch.warp_redux_sync(per_thread_amax, kind="fmax") - # Write the per-warp amax to shared memory sAmax = storage.sAmax.get_tensor(cute.make_layout(self._NUM_WARPS)) - lane_idx = tidx % 32 - if lane_idx == 0: - sAmax[warp_idx] = warp_amax - cute.arch.sync_threads() - if tidx == 0: - cta_amax = Float32(0.0) - # The first thread reduces all the per-warp amax to the per-CTA amax - for w in cutlass.range_constexpr(self._NUM_WARPS): - cta_amax = cute.arch.fmax(cta_amax, sAmax[w]) - amax_i32 = cute.make_tensor( - cute.recast_ptr(mAmax.iterator, dtype=Int32), - cute.make_layout(1), - ) - # The first thread updates the global amax with an atomic max on the bitcasted float value - cute.arch.atomic_max( - amax_i32.iterator, _bitcast_f32_to_i32(cta_amax), - ) + self._amax_epilogue(sAmax, mAmax, tidx, warp_idx, per_thread_amax) # Wait for in-flight TMA stores so data is visible to the host # before the kernel returns. cute.arch.cp_async_bulk_wait_group(0, read=False) + @cute.jit + def _dbias_reduction_rowwise_epilouge(self, smem, tidx, rowwise_dbias_acc): + # Pad the buffer to avoid bank conflicts. The logical shape is still the same. Only the stride is different. + DBIAS_BUFF_WIDTH = self._TILE_X // MXFP8_BLOCK_SIZE * (MXFP8_BLOCK_SIZE + 1) + # Allocate the SMEM buffer that all threads use to reduce the two-stage partial sum (per thread) to the + # partial sum (per block). + @cute.struct + class DbiasStorage: + sDbias: cute.struct.MemRange[Float32, self._TILE_Y * DBIAS_BUFF_WIDTH] + dbias_storage = smem.allocate(DbiasStorage) + sDbias = dbias_storage.sDbias.get_tensor( + cute.make_layout((self._TILE_Y, self._TILE_X), stride=(DBIAS_BUFF_WIDTH, 1)), + ) + _, tv_layout_dbias_write = cute.make_layout_tv( + thr_layout=cute.make_layout((self._TILE_Y, self._TILE_X // MXFP8_BLOCK_SIZE), + stride=(self._TILE_X // MXFP8_BLOCK_SIZE, 1)), + val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(MXFP8_BLOCK_SIZE, 1)), + ) + sDbias_write = cute.composition(sDbias, tv_layout_dbias_write) + # Each thread start reading from the specfic bank based on its thread ID so they can do their best to access different banks + # to avoid bank conflict. + bank_group = (tidx % THREADS_PER_WARP) // self._THREADS_PER_BANK + # The offset this thread should start reading from based on what's its first bank to access. + offset = bank_group * self._PACK_SIZE + for w in cutlass.range_constexpr(self._WAVES): # Each thread starts from this offset when writing into SMEM to avoid bank conflict + start = (w * self._PACK_SIZE + offset) % MXFP8_BLOCK_SIZE + for i in cutlass.range_constexpr(self._PACK_SIZE): + # All threads write their per-thread partial sum results to the shared buffer. + sDbias_write[(tidx, start + i)] = rowwise_dbias_acc[w * self._PACK_SIZE + i] + cute.arch.sync_threads() + # All threads reduce the cross-thread partial sums to the per-block partial sum. + _, tv_layout_dbias_reduce = cute.make_layout_tv( + thr_layout=cute.make_layout((1, self._TILE_X), stride=(self._TILE_X, 1)), + val_layout=cute.make_layout((self._TILE_Y, 1), stride=(1, 1)) + ) + sDbias_reduce = cute.composition(sDbias, tv_layout_dbias_reduce) + # make_layout_tv yields a (thread, value) layout: thread=tidx -> column tidx, + # value=i -> row i. So index [tidx, i] (thread first), summing the column's rows. + block_dbias = Float32(0.0) + for i in cutlass.range_constexpr(self._TILE_Y): + block_dbias += sDbias_reduce[tidx, i] + return block_dbias + + @cute.jit + def _amax_epilogue(self, sAmax, mAmax, tidx, warp_idx, per_thread_amax): + # Reduce and get the per-warp amax. + warp_amax = cute.arch.warp_redux_sync(per_thread_amax, kind="fmax") + # Write the per-warp amax to shared memory + lane_idx = tidx % 32 + if lane_idx == 0: + sAmax[warp_idx] = warp_amax + cute.arch.sync_threads() + if tidx == 0: + cta_amax = Float32(0.0) + # The first thread reduces all the per-warp amax to the per-CTA amax + for w in cutlass.range_constexpr(self._NUM_WARPS): + cta_amax = cute.arch.fmax(cta_amax, sAmax[w]) + amax_i32 = cute.make_tensor( + cute.recast_ptr(mAmax.iterator, dtype=Int32), + cute.make_layout(1), + ) + # The first thread updates the global amax with an atomic max on the bitcasted float value + cute.arch.atomic_max( + amax_i32.iterator, _bitcast_f32_to_i32(cta_amax), + ) + @cute.jit def _process_rowwise( self, From 319156d65174bb3e32d71ee6b4fd520bf2a07f3d Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Sat, 27 Jun 2026 07:45:47 +0000 Subject: [PATCH 16/22] fix warning --- transformer_engine/common/tvm_ffi_bridge.h | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/tvm_ffi_bridge.h b/transformer_engine/common/tvm_ffi_bridge.h index 5f53235ed3..866975eaac 100644 --- a/transformer_engine/common/tvm_ffi_bridge.h +++ b/transformer_engine/common/tvm_ffi_bridge.h @@ -211,7 +211,14 @@ class TVMFFICentral { static_assert(detail::is_lazyloadable_config::value, "Config must define `std::string to_key() const` and " "`bool retrieve_func_from_python(const std::string&) const`."); - if (!cutedsl_backend_enabled_) return std::nullopt; + if (!cutedsl_backend_enabled_) { + if (warn_cutedsl_backend_not_chosen_) { + NVTE_WARN("TVM-FFI kernel for config `", cfg.to_key(), + "` is not supported because the CuTeDSL backend is disabled. " + "Set NVTE_ENABLE_CUTEDSL_QUANT_BACKEND=1 to enable it."); + } + return std::nullopt; + } const std::string key = cfg.to_key(); { std::shared_lock read_lock(mutex_); @@ -231,7 +238,7 @@ class TVMFFICentral { if (supported) { return tvm::ffi::Function::GetGlobal(key); } - if (warn_unsupported_kernels_) { + if (warn_cutedsl_backend_not_chosen_) { NVTE_WARN("TVM-FFI kernel for config `", key, "` is not supported."); } return std::nullopt; @@ -240,7 +247,7 @@ class TVMFFICentral { private: ~TVMFFICentral() = default; TVMFFICentral() : cutedsl_backend_enabled_(is_cutedsl_backend_enabled()), - warn_unsupported_kernels_(warn_if_cutedsl_backend_unsupported()) {} + warn_cutedsl_backend_not_chosen_(warn_if_cutedsl_backend_not_chosen()) {} TVMFFICentral(const TVMFFICentral &) = delete; TVMFFICentral &operator=(const TVMFFICentral &) = delete; TVMFFICentral(TVMFFICentral &&) = delete; @@ -252,13 +259,13 @@ class TVMFFICentral { return flag != nullptr && flag[0] != '0'; } - static bool warn_if_cutedsl_backend_unsupported() { - const char *flag = std::getenv("NVTE_WARN_IF_CUTEDSL_BACKEND_UNSUPPORTED"); + static bool warn_if_cutedsl_backend_not_chosen() { + const char *flag = std::getenv("NVTE_WARN_IF_CUTEDSL_BACKEND_NOT_CHOSEN"); return flag != nullptr && flag[0] != '0'; } const bool cutedsl_backend_enabled_; - const bool warn_unsupported_kernels_; + const bool warn_cutedsl_backend_not_chosen_; std::shared_mutex mutex_; // Per-config support decision (cfg.to_key() -> supported). Holds NO Python- // backed handles, so it is safe to destroy at static teardown — the kernels From 67317a0a80f2a05ba74ff369f6efaccea665bbdb Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Sat, 27 Jun 2026 08:23:42 +0000 Subject: [PATCH 17/22] fix rowwise specialized --- .../CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 72 ++++++++++--------- 1 file changed, 38 insertions(+), 34 deletions(-) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index 246c9dac5a..a50f733fb2 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -1199,7 +1199,9 @@ class MXFP8QuantizeSpecializedRowwiseKernel: def __init__(self, cfg): self.cfg = cfg - self.STASH_SCALE_TO_SMEM = True + # If True, then this kernel will first write each thread's scale byte to a shared memory buffer, + # then utilize vectorized store to flush the buffer to global memory. + self._STASH_SCALE_TO_SMEM = True # Hardcode to true for now @cute.jit def __call__( @@ -1276,27 +1278,22 @@ def kernel(self, mX, mO_row, mS_row, max_norm_rcp, DTYPE): ) rO_u32 = cute.make_tensor( cute.recast_ptr(rO_thread.iterator, dtype=Uint32), - cute.make_layout((MXFP8_BLOCK_SIZE // 4,), stride=(1,)), + cute.make_layout((MXFP8_BLOCK_SIZE // 4,), stride=(1,)), # Unit is Uint32, divide by 4 here ) - # TODO: review this kernel myself. - - - # Each thread owns only one scale byte, so a direct RF->GMEM scale write - # can't vectorize (128 scattered 1-byte stores). If STASH_SCALE_TO_SMEM, - # stage the CTA's (CTA_Y, CTA_X) scale tile in SMEM, then flush it to gmem - # with wide (uint32) stores. Compile-time gated. - sS_slot = None - if cutlass.const_expr(self.STASH_SCALE_TO_SMEM): + sS_thread = None + if cutlass.const_expr(self._STASH_SCALE_TO_SMEM): @cute.struct class SharedStorage: buf: cute.struct.Align[cute.struct.MemRange[Uint8, CTA_Y * CTA_X], 16] storage = cutlass.utils.SmemAllocator().allocate(SharedStorage) sScale = storage.buf.get_tensor(cute.make_layout((CTA_Y, CTA_X), stride=(CTA_X, 1))) - sS_slot = cute.composition(sScale, tv_layout_scale)[tidx, None] + # sScale is (CTA_Y, CTA_X):(CTA_X, 1), which is the same layout as tv_layout_scale + # so sS_thread is really just an 1 Uint8 buffer for this thread's scale byte. + sS_thread = cute.composition(sScale, tv_layout_scale)[tidx, None] # Zero first so padding columns (cols past N/32 in the padded scale # matrix) flush as 0 and we never read uninitialized smem. - sS_slot[0] = Uint8(0) + sS_thread[0] = Uint8(0) cute.arch.sync_threads() row = bidy * self._TILE_Y + tidx // CTA_X @@ -1312,8 +1309,8 @@ class SharedStorage: amax = cute.arch.fmax(amax, fabs_f32(rX_f32[0, i])) biased_exp = cvt_f32_to_fp8e8m0(amax * max_norm_rcp) - if cutlass.const_expr(self.STASH_SCALE_TO_SMEM): - sS_slot[0] = Uint8(biased_exp) + if cutlass.const_expr(self._STASH_SCALE_TO_SMEM): + sS_thread[0] = Uint8(biased_exp) else: mS_thread[0] = Uint8(biased_exp) @@ -1322,10 +1319,10 @@ class SharedStorage: inv_scale = exp2f_rcp(biased_exp) scale_2x = pack_f32x2(inv_scale, inv_scale) mul_cvt4 = mul_cvt_f32x4_to_fp8x4(self.cfg.FP8_DTYPE) - for w in cutlass.range_constexpr(MXFP8_BLOCK_SIZE // 4): - b = 4 * w - rO_u32[w] = mul_cvt4(rX_f32[0, b], rX_f32[0, b + 1], - rX_f32[0, b + 2], rX_f32[0, b + 3], scale_2x) + for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE // 4): + offset = 4 * i + rO_u32[i] = mul_cvt4(rX_f32[0, offset], rX_f32[0, offset + 1], + rX_f32[0, offset + 2], rX_f32[0, offset + 3], scale_2x) cute.autovec_copy(rO_thread, mO_thread) # Cooperative wide flush of the staged scales: the first CTA_Y*(CTA_X/G) @@ -1336,28 +1333,35 @@ class SharedStorage: # straddles the allocation boundary (base and padded_cols are both multiples # of G), so flush a group iff its first column is in-allocation. Zeroed # padding columns flush as 0. - if cutlass.const_expr(self.STASH_SCALE_TO_SMEM): + if cutlass.const_expr(self._STASH_SCALE_TO_SMEM): cute.arch.sync_threads() padded_cols = mS_row.shape[1] if padded_cols % 16 == 0: - self._flush_scales(sScale, mS_tile, tidx, bidx, bidy, M, padded_cols, 16) + # If columns is divisible by 16, use 16 bytes as the vectorized store width + self._flush_scales_to_gmem(sScale, mS_tile, tidx, bidx, bidy, M, padded_cols, 16) else: - self._flush_scales(sScale, mS_tile, tidx, bidx, bidy, M, padded_cols, 4) - + # Otherwise use 4 bytes as the vectorized store width. + # Note our fake tensor requires 4 divisibility so this is enforced as long as you can get here + self._flush_scales_to_gmem(sScale, mS_tile, tidx, bidx, bidy, M, padded_cols, 4) @cute.jit - def _flush_scales(self, sScale, mS_tile, tidx, bidx, bidy, M, padded_cols, G): - """Flush the staged (CTA_Y, CTA_X) scale tile to gmem with G-wide stores.""" + def _flush_scales_to_gmem(self, sScale, mS_tile, tidx, bidx, bidy, M, padded_cols, width): + """Flush the staged (CTA_Y, CTA_X) scale tile to gmem with vectorized stores.""" CTA_Y = self._TILE_Y CTA_X = self._TILE_X // MXFP8_BLOCK_SIZE - GROUPS = CTA_X // G + # Previously each threads has 1 byte, but now we are doing vectorized store, + # which means only a subset of threads will need to issue the store while other threads are not used. + active_threads = CTA_X // width _, tv_flush = cute.make_layout_tv( - thr_layout=cute.make_layout((CTA_Y, GROUPS), stride=(GROUPS, 1)), - val_layout=cute.make_layout((1, G), stride=(G, 1)), + thr_layout=cute.make_layout((CTA_Y, active_threads), stride=(active_threads, 1)), + val_layout=cute.make_layout((1, width), stride=(width, 1)), ) - if tidx < CTA_Y * GROUPS: - frow = tidx // GROUPS - fgroup = tidx % GROUPS - if bidy * CTA_Y + frow < M and bidx * CTA_X + fgroup * G < padded_cols: + # We only need to use a subset of threads with shape (CTA_Y, active_threads) to write + # so if the thread is outside of this subset, it will remain inactive + if tidx < CTA_Y * active_threads: + # Absolute position of the scale vector to write in the GMEM buffer + thread_y = bidy * CTA_Y + tidx // active_threads + thread_x = bidx * CTA_X + (tidx % active_threads) * width + if thread_y < M and thread_x < padded_cols: cute.autovec_copy( cute.composition(sScale, tv_flush)[tidx, None], cute.composition(mS_tile, tv_flush)[tidx, None], @@ -1396,7 +1400,7 @@ def __call__( # produced yet; this stub exists so dispatch routing can be wired now. -def get_specialized_kernel_class(cfg): +def get_kernel_class(cfg): """If no fusion is involved and the kernel only quantizes, dispatch to the specialized kernel for better performance.""" plain_cast_only = ( not cfg.WITH_GEMM_SWIZZLED_SCALES @@ -1419,7 +1423,7 @@ def compile_cutedsl_function_from_cfg(cfg): # Route plain cast-only configs to the matching specialized kernel (mirrors the # CUDA dispatcher); everything else uses the general standard kernel. - kernel_class = get_specialized_kernel_class(cfg) + kernel_class = get_kernel_class(cfg) kernel_obj = kernel_class(cfg) # M, N must be divisible by the MXFP8 scale-block size (MXFP8_BLOCK_SIZE = 32) — the # same alignment the CUDA C++ kernel requires. From f8ae14e446b0f1575f9e1e4c7e236f52d912d0cc Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Sat, 27 Jun 2026 08:25:01 +0000 Subject: [PATCH 18/22] nit --- transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index a50f733fb2..8adf28eafc 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -1343,6 +1343,7 @@ class SharedStorage: # Otherwise use 4 bytes as the vectorized store width. # Note our fake tensor requires 4 divisibility so this is enforced as long as you can get here self._flush_scales_to_gmem(sScale, mS_tile, tidx, bidx, bidy, M, padded_cols, 4) + @cute.jit def _flush_scales_to_gmem(self, sScale, mS_tile, tidx, bidx, bidy, M, padded_cols, width): """Flush the staged (CTA_Y, CTA_X) scale tile to gmem with vectorized stores.""" From 218cd246e1d42169364ddcdc7e1ee402e899796c Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Sat, 27 Jun 2026 08:26:29 +0000 Subject: [PATCH 19/22] todo --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 61c2a4586d..2189f170b2 100644 --- a/setup.py +++ b/setup.py @@ -366,7 +366,7 @@ def git_check_submodules() -> None: "core_cu13": [f"transformer_engine_cu13=={__version__}"], "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], - "cutedsl": ["nvidia-cutlass-dsl>=4.2.0"], + "cutedsl": ["nvidia-cutlass-dsl>=4.2.0"], # TODO: explain this in the docs when shipping this: `pip3 install --no-build-isolation '.[cutedsl]' ` } else: install_requires, test_requires = setup_requirements() From f357b6940356401d71cfd5eddec3cf5ec8bc9714 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 27 Jun 2026 08:28:36 +0000 Subject: [PATCH 20/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- setup.py | 4 +- .../common/CuTeDSL/activations.py | 34 +- .../CuTeDSL/cast/mxfp8/quantize_mxfp8.py | 553 ++++++++++++------ transformer_engine/common/CuTeDSL/utils.py | 261 ++++++--- .../common/CuTeDSL/utils_fp8.py | 131 +++-- .../common/cast/dispatch/quantize.cuh | 17 +- .../cast/mxfp8/quantize_mxfp8_cutedsl.cuh | 125 ++-- transformer_engine/common/tvm_ffi_bridge.h | 126 ++-- 8 files changed, 808 insertions(+), 443 deletions(-) diff --git a/setup.py b/setup.py index 2189f170b2..16d3d4812a 100644 --- a/setup.py +++ b/setup.py @@ -366,7 +366,9 @@ def git_check_submodules() -> None: "core_cu13": [f"transformer_engine_cu13=={__version__}"], "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], - "cutedsl": ["nvidia-cutlass-dsl>=4.2.0"], # TODO: explain this in the docs when shipping this: `pip3 install --no-build-isolation '.[cutedsl]' ` + "cutedsl": [ + "nvidia-cutlass-dsl>=4.2.0" + ], # TODO: explain this in the docs when shipping this: `pip3 install --no-build-isolation '.[cutedsl]' ` } else: install_requires, test_requires = setup_requirements() diff --git a/transformer_engine/common/CuTeDSL/activations.py b/transformer_engine/common/CuTeDSL/activations.py index 0389310690..dd43e75cad 100644 --- a/transformer_engine/common/CuTeDSL/activations.py +++ b/transformer_engine/common/CuTeDSL/activations.py @@ -22,8 +22,8 @@ def act_gelu(x: Float32) -> Float32: rather than the `tanh.approx.f32` PTX intrinsic — TE compiles activation kernels without `--use_fast_math` by default, so its `tanhf` is the IEEE-precise expansion.""" - A = Float32(0.79788456) # sqrt(2/π) truncated to TE's 8-digit literal - B = Float32(0.03567741) # = sqrt(2/π) · 0.044715, same truncation + A = Float32(0.79788456) # sqrt(2/π) truncated to TE's 8-digit literal + B = Float32(0.03567741) # = sqrt(2/π) · 0.044715, same truncation return x * (Float32(0.5) + Float32(0.5) * cute.math.tanh(x * (A + B * x * x))) @@ -45,15 +45,26 @@ def act_srelu(x: Float32) -> Float32: r = cute.arch.fmax(x, Float32(0.0)) return r * r + @dsl_user_op def dact_drelu(x: Float32, *, loc=None, ip=None) -> Float32: """drelu: x > 0 ? 1 : 0. Matches math.h `drelu` (NaN → 0 via ordered compare).""" - cond = mlir_arith.cmpf(mlir_arith.CmpFPredicate.OGT, - x.ir_value(loc=loc, ip=ip), - Float32(0.0).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) - return Float32(mlir_arith.select(cond, - Float32(1.0).ir_value(loc=loc, ip=ip), - Float32(0.0).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + cond = mlir_arith.cmpf( + mlir_arith.CmpFPredicate.OGT, + x.ir_value(loc=loc, ip=ip), + Float32(0.0).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + return Float32( + mlir_arith.select( + cond, + Float32(1.0).ir_value(loc=loc, ip=ip), + Float32(0.0).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) def dact_dsrelu(x: Float32) -> Float32: @@ -91,7 +102,6 @@ def dact_dgelu(x: Float32) -> Float32: Float32(0.79788456) * x * (Float32(1.0) + Float32(0.044715) * x * x), fastmath=False, ) - return (Float32(0.5) * x - * ((Float32(1.0) - t * t) * (Float32(0.79788456) + Float32(0.1070322243) * x * x)) - + Float32(0.5) * (Float32(1.0) + t)) - + return Float32(0.5) * x * ( + (Float32(1.0) - t * t) * (Float32(0.79788456) + Float32(0.1070322243) * x * x) + ) + Float32(0.5) * (Float32(1.0) + t) diff --git a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py index 8adf28eafc..2f40384a03 100644 --- a/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py +++ b/transformer_engine/common/CuTeDSL/cast/mxfp8/quantize_mxfp8.py @@ -53,9 +53,9 @@ mul_cvt_f32x4_to_fp8x4, ) from transformer_engine.common.CuTeDSL.utils_fp8 import ( - get_cvt_f32_to_fp8_func, + get_cvt_f32_to_fp8_func, get_cvt_f32x2_to_fp8x2_func, - cvt_f32_to_fp8e8m0 + cvt_f32_to_fp8e8m0, ) CUTEDSL_DEBUG_LOGGING = os.environ.get("CUTEDSL_DEBUG_LOGGING", "0") == "1" @@ -93,18 +93,19 @@ @cute.jit def quantize_rowwise_mxfp8( - sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA - sA_tile, # (TILE_Y, TILE_X) activation-input smem tile (dact only) - sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) - mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sA_tile, # (TILE_Y, TILE_X) activation-input smem tile (dact only) + sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) + mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) max_norm_rcp, - tile_row_start, # Int32 — global row index of this stage's row 0 - # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores - # for irregular shapes. - tile_col_start, # Int32 — global col index of this CTA's col 0 - # (= bidx * TILE_X). Same purpose. - M, N, # Int32 — full tensor extents; OOB threads skip their - # scale store. + tile_row_start, # Int32 — global row index of this stage's row 0 + # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores + # for irregular shapes. + tile_col_start, # Int32 — global col index of this CTA's col 0 + # (= bidx * TILE_X). Same purpose. + M, + N, # Int32 — full tensor extents; OOB threads skip their + # scale store. ACTIVATION, DTYPE, FP8_DTYPE, @@ -116,38 +117,37 @@ def quantize_rowwise_mxfp8( WITH_ACT=False, WITH_DACT=False, WITH_DBIAS=False, # rowwise-only dbias: accumulate per-column partials - dbias_acc=None, # only needed when WITH_DBIAS is True + dbias_acc=None, # only needed when WITH_DBIAS is True ): tidx, _, _ = cute.arch.thread_idx() - CTA_THREADS_Y = TILE_Y # threads per column (rows per tile) + CTA_THREADS_Y = TILE_Y # threads per column (rows per tile) CTA_THREADS_X = TILE_X // MXFP8_BLOCK_SIZE # threads per row (chunks per row) _, tv_layout = cute.make_layout_tv( thr_layout=cute.make_layout((CTA_THREADS_Y, CTA_THREADS_X), stride=(CTA_THREADS_X, 1)), - val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)) + val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(0, 1)), ) sX_tv = cute.composition(sX_tile, tv_layout) sO_tv = cute.composition(sO_row_tile, tv_layout) # I/O Elements that belong to this thread - sX_thread = sX_tv[tidx, None] # shape (32,) bf16 - sO_thread = sO_tv[tidx, None] # shape (32,) uint8 + sX_thread = sX_tv[tidx, None] # shape (32,) bf16 + sO_thread = sO_tv[tidx, None] # shape (32,) uint8 sO_thread_u32_ptr = cute.recast_ptr(sO_thread.iterator, dtype=Uint32) # Each wave it writes 32 bytes = 8 uint32s, so in 4 waves we write all 32 quantized elements. sO_thread_u32 = cute.make_tensor( sO_thread_u32_ptr, - cute.make_layout((MXFP8_BLOCK_SIZE // 4,), stride=(1,)), # 1 uint32 is 4 fp8 elements + cute.make_layout((MXFP8_BLOCK_SIZE // 4,), stride=(1,)), # 1 uint32 is 4 fp8 elements ) # PTX allows to fuse relu activation in `cvt.rn.satfinite` FUSE_RELU = cutlass.const_expr(ACTIVATION == "relu") # For this fast path we can read in pack of 2 instead of reading individual f16 / bf16 element. # dbias needs the per-element fp32 values to accumulate, so it forces the slow path. - _row_fast = (is_packed16(DTYPE) and (ACTIVATION is None or FUSE_RELU) - and not WITH_DBIAS) + _row_fast = is_packed16(DTYPE) and (ACTIVATION is None or FUSE_RELU) and not WITH_DBIAS amax_r = Float32(0.0) @@ -162,7 +162,9 @@ def quantize_rowwise_mxfp8( kit = packed16_kit(DTYPE) sX_thread_rw_i32 = cute.make_tensor( cute.recast_ptr(sX_thread.iterator, dtype=Int32), - cute.make_layout((1, MXFP8_BLOCK_SIZE // 2), stride=(0, 1)), # 1 int32 is 2 fp16/bf16 elements + cute.make_layout( + (1, MXFP8_BLOCK_SIZE // 2), stride=(0, 1) + ), # 1 int32 is 2 fp16/bf16 elements ) # Each wave we read 2 packed i32, which is 4 fp16/bf16 elements (PACK_SIZE) # In total we have 8 waves where each wave reads 4 elements, so we read 32 elements in total. @@ -240,11 +242,13 @@ def quantize_rowwise_mxfp8( x = kit_act.truncate_f32(x) in_r[w][i] = x if cutlass.const_expr(FUSE_RELU): - amax_r = cute.arch.fmax(amax_r, x) # For relu cases, we don't need abs since negative values will be 0 so they lose comparison automatically + amax_r = cute.arch.fmax( + amax_r, x + ) # For relu cases, we don't need abs since negative values will be 0 so they lose comparison automatically else: amax_r = cute.arch.fmax(amax_r, fabs_f32(x)) if cutlass.const_expr(FUSE_RELU): - amax_r = cute.arch.fmax(amax_r, Float32(0.0)) # If relu, the amax is at least 0 + amax_r = cute.arch.fmax(amax_r, Float32(0.0)) # If relu, the amax is at least 0 biased_exp_r = cvt_f32_to_fp8e8m0(amax_r * max_norm_rcp) @@ -254,14 +258,14 @@ def quantize_rowwise_mxfp8( # Note this is the logical layout, which is on top of the swizzled / non-swizzled scale factor layout # that mappes the logical index to the physical offset - # For irregular shapes, skip the scale store if this thread's logical row / col-block lies past the input's actual extents. + # For irregular shapes, skip the scale store if this thread's logical row / col-block lies past the input's actual extents. # TMA already zero-fills OOB input reads and drops OOB output writes; only the direct scale-byte gmem store needs an explicit guard. scale_row = tile_row_start + tidx // CTA_THREADS_X scale_col_first_elt = tile_col_start + (tidx % CTA_THREADS_X) * MXFP8_BLOCK_SIZE if scale_row < M and scale_col_first_elt < N: mS_row_stage[(tidx // CTA_THREADS_X, tidx % CTA_THREADS_X)] = Uint8(biased_exp_r) - inv_scale_r = exp2f_rcp(biased_exp_r) # f32 reciprocal of the scale + inv_scale_r = exp2f_rcp(biased_exp_r) # f32 reciprocal of the scale # Fetch the conversion function based on the FP8 format cvt_f32x2 = get_cvt_f32x2_to_fp8x2_func(FP8_DTYPE) if cutlass.const_expr(_row_fast): @@ -294,37 +298,39 @@ def quantize_rowwise_mxfp8( return amax_r + @cute.jit def quantize_colwise_mxfp8( - sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA - sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) - mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) + mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) max_norm_rcp, - tile_row_start, # Int32 — global row index of this stage's row 0 - # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores - # for irregular shapes. - tile_col_start, # Int32 — global col index of this CTA's col 0 - # (= bidx * TILE_X). - M, N, # Int32 — full tensor extents. + tile_row_start, # Int32 — global row index of this stage's row 0 + # (= tile_idx_y * TILE_Y). Used to mask OOB scale stores + # for irregular shapes. + tile_col_start, # Int32 — global col index of this CTA's col 0 + # (= bidx * TILE_X). + M, + N, # Int32 — full tensor extents. ACTIVATION, DTYPE, FP8_DTYPE, SWIZZLE, TILE_X, TILE_Y, - WITH_ACT=False, # forward: apply activation to the element - WITH_DACT=False, # backward: out = grad · act'(act_input) - sA_tile=None, # (TILE_Y, TILE_X) activation-input smem tile (dact only) - WITH_DBIAS=False, # also return this thread's column sum (pre-truncate) + WITH_ACT=False, # forward: apply activation to the element + WITH_DACT=False, # backward: out = grad · act'(act_input) + sA_tile=None, # (TILE_Y, TILE_X) activation-input smem tile (dact only) + WITH_DBIAS=False, # also return this thread's column sum (pre-truncate) CACHE_ACTIVATION=False, # overwrite sX_tile in place with the post-activation - # (IType-truncated) values, so the rowwise pass can read - # them instead of recomputing op + # (IType-truncated) values, so the rowwise pass can read + # them instead of recomputing op ): tidx, _, _ = cute.arch.thread_idx() _, tv_layout = cute.make_layout_tv( thr_layout=cute.make_layout((1, TILE_X), stride=(TILE_X, 1)), - val_layout=cute.make_layout((MXFP8_BLOCK_SIZE, 1), stride=(1, 1)) + val_layout=cute.make_layout((MXFP8_BLOCK_SIZE, 1), stride=(1, 1)), ) sX_tv = cute.composition(sX_tile, tv_layout) @@ -411,6 +417,7 @@ def quantize_colwise_mxfp8( # it across stages (a scalar can't be updated in-place through the arg). return amax_c, dbias_partial + class MXFP8QuantizeConfig: """Configs for the compiled CuTeDSL kernel. These will be fixed once the kernel is compiled and they will behave as const expressions. @@ -428,7 +435,7 @@ def __init__( with_dact: bool = False, with_act: bool = False, with_noop: bool = False, - activation: Optional[str] = None + activation: Optional[str] = None, ): if dtype is None or dtype not in ("fp32", "fp16", "bf16"): raise ValueError(f"unknown input dtype {dtype!r}; expected fp32|fp16|bf16") @@ -447,20 +454,31 @@ def __init__( if activation == "none": self.ACTIVATION = None else: - raise ValueError("activation must be none when with_dact and with_act are both False") + raise ValueError( + "activation must be none when with_dact and with_act are both False" + ) else: if with_dact and with_act: - raise ValueError("with_dact and with_act cannot be true at the same time since they are used for different paths (bwd vs fwd)") + raise ValueError( + "with_dact and with_act cannot be true at the same time since they are used for" + " different paths (bwd vs fwd)" + ) elif with_dact: if activation in SUPPORTED_DACTIVATIONS: self.ACTIVATION = activation else: - raise ValueError(f"unknown activation {activation!r} for with_dact=True; expected one of {sorted(SUPPORTED_DACTIVATIONS)}") + raise ValueError( + f"unknown activation {activation!r} for with_dact=True; expected one of" + f" {sorted(SUPPORTED_DACTIVATIONS)}" + ) elif with_act: if activation in SUPPORTED_ACTIVATIONS: self.ACTIVATION = activation else: - raise ValueError(f"unknown activation {activation!r} for with_act=True; expected one of {sorted(SUPPORTED_ACTIVATIONS)}") + raise ValueError( + f"unknown activation {activation!r} for with_act=True; expected one of" + f" {sorted(SUPPORTED_ACTIVATIONS)}" + ) self.WITH_DACT = with_dact self.WITH_ACT = with_act # dbias is the column reduction of the (post-act/dact) element. With colwise @@ -472,15 +490,18 @@ def __init__( self.MAX_NORM_RCP = FP8E4M3_MAX_NORM_RCP if fp8_dtype == "e4m3" else FP8E5M2_MAX_NORM_RCP def __str__(self): - return (f"MXFP8QuantizeConfig(dtype={self.DTYPE_STR}, fp8_dtype={self.FP8_DTYPE}, " - f"rowwise={self.ROWWISE}, colwise={self.COLWISE}, " - f"swizzled={self.WITH_GEMM_SWIZZLED_SCALES}, with_amax={self.WITH_AMAX}, " - f"with_dbias={self.WITH_DBIAS}, with_dact={self.WITH_DACT}, " - f"with_act={self.WITH_ACT}, with_noop={self.WITH_NOOP}, " - f"activation={self.ACTIVATION})") + return ( + f"MXFP8QuantizeConfig(dtype={self.DTYPE_STR}, fp8_dtype={self.FP8_DTYPE}, " + f"rowwise={self.ROWWISE}, colwise={self.COLWISE}, " + f"swizzled={self.WITH_GEMM_SWIZZLED_SCALES}, with_amax={self.WITH_AMAX}, " + f"with_dbias={self.WITH_DBIAS}, with_dact={self.WITH_DACT}, " + f"with_act={self.WITH_ACT}, with_noop={self.WITH_NOOP}, " + f"activation={self.ACTIVATION})" + ) __repr__ = __str__ + class MXFP8QuantizeKernel: """The MXFP8 quantization kernel that mirrors the standard (non-specialized) MXFP8 CUDA C++ quantization kernel with multiple fusions (activation, dbias, etc.). @@ -488,11 +509,13 @@ class MXFP8QuantizeKernel: """ # Vectorised access constants for bank-conflict avoidance (rowwise pass) - _PACK_SIZE = 4 # Elements per vector load - _WAVES = MXFP8_BLOCK_SIZE // _PACK_SIZE # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total + _PACK_SIZE = 4 # Elements per vector load + _WAVES = ( + MXFP8_BLOCK_SIZE // _PACK_SIZE + ) # Each thread reads 8 waves with each wave reads 4 packed bf16, so it reads a whole MXFP8 block in total _TOTAL_BANKS_WIDTH = (32 * 4) // 1 # 32 banks × 4 bytes, in bytes (uint8 stride) _THREADS_PER_BANK = _TOTAL_BANKS_WIDTH // MXFP8_BLOCK_SIZE # 4 threads per bank - _NUM_STAGES = 2 # The pipeline depth is always 2 + _NUM_STAGES = 2 # The pipeline depth is always 2 def __init__(self, cfg): self.cfg = cfg @@ -503,7 +526,7 @@ def __init__(self, cfg): if cast_dbias_only: self._NUM_TILES = 4 self._THREADS_PER_CTA = 128 - self._TILE_X = 128 + self._TILE_X = 128 self._TILE_Y = 32 else: self._NUM_TILES = 2 @@ -522,7 +545,8 @@ def __init__(self, cfg): # so it should be treated as "no activation" self.CACHE_ACTIVATION = ( (cfg.WITH_ACT or cfg.WITH_DACT) - and cfg.ROWWISE and cfg.COLWISE + and cfg.ROWWISE + and cfg.COLWISE and cfg.ACTIVATION != "relu" ) # The global tensor amax (mAmax) is the max over ALL elements. Each direction's @@ -536,13 +560,19 @@ def __init__(self, cfg): @cute.jit def __call__( self, - mX: cute.Tensor, # Input tensor to quantize - mO_row: Optional[cute.Tensor], mS_row: Optional[cute.Tensor], # Rowwise output and scale tensors - mO_col: Optional[cute.Tensor], mS_col: Optional[cute.Tensor], # Colwise output and scale tensors - mAmax: Optional[cute.Tensor], # Global amax accumulator, only used when WITH_AMAX is True - mNoop: Optional[cute.Tensor], # 1-element cast_noop flag, only used when WITH_NOOP is True - mDActInput: Optional[cute.Tensor], # Activation input for activation derivative fusion, only used when WITH_DACT is True - mWorkspace: Optional[cute.Tensor], # Workspace for the dbias reduction, only used when WITH_DBIAS is True + mX: cute.Tensor, # Input tensor to quantize + mO_row: Optional[cute.Tensor], + mS_row: Optional[cute.Tensor], # Rowwise output and scale tensors + mO_col: Optional[cute.Tensor], + mS_col: Optional[cute.Tensor], # Colwise output and scale tensors + mAmax: Optional[cute.Tensor], # Global amax accumulator, only used when WITH_AMAX is True + mNoop: Optional[cute.Tensor], # 1-element cast_noop flag, only used when WITH_NOOP is True + mDActInput: Optional[ + cute.Tensor + ], # Activation input for activation derivative fusion, only used when WITH_DACT is True + mWorkspace: Optional[ + cute.Tensor + ], # Workspace for the dbias reduction, only used when WITH_DBIAS is True stream: CUstream, ): if cutlass.const_expr(CUTEDSL_DEBUG_LOGGING): @@ -554,7 +584,7 @@ def __call__( max_norm_rcp = cfg.MAX_NORM_RCP num_scale_cols = N // MXFP8_BLOCK_SIZE num_scale_rows = M // MXFP8_BLOCK_SIZE - + # If WITH_GEMM_SWIZZLED_SCALES is enabled, the output must satisfy cublas's swizzled layout # This is expressed as a CuTe layout applied to the output tensor so it can be transparent throughout the kernel implementation. # See https://docs.nvidia.com/cuda/cublas/#d-block-scaling-factors-layout for more details. @@ -563,7 +593,7 @@ def __call__( num_tiles_SC = (num_scale_cols + 3) // 4 num_tiles_SR = (num_scale_rows + 3) // 4 num_tiles_N = (N + 127) // 128 - + if cutlass.const_expr(cfg.ROWWISE): mS_row = cute.make_tensor( mS_row.iterator, @@ -588,7 +618,11 @@ def __call__( # Input TMA atoms op_load = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() tma_atom, tma_src = cute.nvgpu.cpasync.make_tiled_tma_atom( - op_load, mX, smem_tile_layout, cta_tiler, num_multicast=1, + op_load, + mX, + smem_tile_layout, + cta_tiler, + num_multicast=1, ) # Activation input TMA atoms for activation derivative fusion @@ -596,7 +630,11 @@ def __call__( tma_src_act = None if cutlass.const_expr(cfg.WITH_DACT): tma_atom_act, tma_src_act = cute.nvgpu.cpasync.make_tiled_tma_atom( - op_load, mDActInput, smem_tile_layout, cta_tiler, num_multicast=1, + op_load, + mDActInput, + smem_tile_layout, + cta_tiler, + num_multicast=1, ) # Output TMA atoms @@ -608,26 +646,46 @@ def __call__( tma_dst_out_col = None if cutlass.const_expr(cfg.ROWWISE): tma_atom_out_row, tma_dst_out_row = cute.nvgpu.cpasync.make_tiled_tma_atom( - op_store, mO_row, out_smem_layout, cta_tiler, num_multicast=1, + op_store, + mO_row, + out_smem_layout, + cta_tiler, + num_multicast=1, ) if cutlass.const_expr(cfg.COLWISE): tma_atom_out_col, tma_dst_out_col = cute.nvgpu.cpasync.make_tiled_tma_atom( - op_store, mO_col, out_smem_layout, cta_tiler, num_multicast=1, + op_store, + mO_col, + out_smem_layout, + cta_tiler, + num_multicast=1, ) grid = [ cute.ceil_div(Int32(N), self._TILE_X), cute.ceil_div(M, self._TILE_Y * self._NUM_TILES), ] - block = [self._THREADS_PER_CTA,] - + block = [ + self._THREADS_PER_CTA, + ] + self.kernel( - mX, mS_row, mS_col, mAmax, mNoop, mWorkspace, - max_norm_rcp, mX.element_type, - tma_atom, tma_src, - tma_atom_out_row, tma_dst_out_row, - tma_atom_out_col, tma_dst_out_col, - tma_atom_act, tma_src_act, + mX, + mS_row, + mS_col, + mAmax, + mNoop, + mWorkspace, + max_norm_rcp, + mX.element_type, + tma_atom, + tma_src, + tma_atom_out_row, + tma_dst_out_row, + tma_atom_out_col, + tma_dst_out_col, + tma_atom_act, + tma_src_act, ).launch( grid=grid, block=block, @@ -645,22 +703,35 @@ def kernel( mWorkspace, max_norm_rcp, dtype: cutlass.Constexpr[Type[cutlass.Numeric]], - tma_atom, tma_src, # Input TMA atoms - tma_atom_out_row, tma_dst_out_row, # Rowwise output TMA atoms - tma_atom_out_col, tma_dst_out_col, # Colwise output TMA atoms - tma_atom_act, tma_src_act, # Activation derivative TMA atoms, or None if WITH_DACT is False + tma_atom, + tma_src, # Input TMA atoms + tma_atom_out_row, + tma_dst_out_row, # Rowwise output TMA atoms + tma_atom_out_col, + tma_dst_out_col, # Colwise output TMA atoms + tma_atom_act, + tma_src_act, # Activation derivative TMA atoms, or None if WITH_DACT is False ): cfg = self.cfg # If the noop tensor is not passed (compile-time check), or the noop tensor is not 1.0 (run-time check) # then we run the kernel for real. Otherwise, skip the quantization so this kernel becomes a no-op. if not cutlass.const_expr(cfg.WITH_NOOP) or mNoop[0] != Float32(1.0): self._kernel_main( - mX, mS_row, mS_col, mAmax, mWorkspace, - max_norm_rcp, dtype, - tma_atom, tma_src, - tma_atom_out_row, tma_dst_out_row, - tma_atom_out_col, tma_dst_out_col, - tma_atom_act, tma_src_act, + mX, + mS_row, + mS_col, + mAmax, + mWorkspace, + max_norm_rcp, + dtype, + tma_atom, + tma_src, + tma_atom_out_row, + tma_dst_out_row, + tma_atom_out_col, + tma_dst_out_col, + tma_atom_act, + tma_src_act, ) @cute.jit @@ -673,10 +744,14 @@ def _kernel_main( mWorkspace, max_norm_rcp, dtype: cutlass.Constexpr[Type[cutlass.Numeric]], - tma_atom, tma_src, # Input TMA atoms - tma_atom_out_row, tma_dst_out_row, # Rowwise output TMA atoms - tma_atom_out_col, tma_dst_out_col, # Colwise output TMA atoms - tma_atom_act, tma_src_act, # Activation derivative TMA atoms, or None if WITH_DACT is False + tma_atom, + tma_src, # Input TMA atoms + tma_atom_out_row, + tma_dst_out_row, # Rowwise output TMA atoms + tma_atom_out_col, + tma_dst_out_col, # Colwise output TMA atoms + tma_atom_act, + tma_src_act, # Activation derivative TMA atoms, or None if WITH_DACT is False ): cfg = self.cfg @@ -687,6 +762,7 @@ def _kernel_main( # Allocate shared memory for the input and rowwise / columnwise outputs if cutlass.const_expr(cfg.ROWWISE and cfg.COLWISE): + @cute.struct class SharedStorage: mbar_storage: cute.struct.MemRange[cute.Int64, 2 * self._NUM_STAGES] @@ -700,7 +776,9 @@ class SharedStorage: cute.struct.MemRange[Uint8, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] sAmax: cute.struct.MemRange[Float32, self._NUM_WARPS] + elif cutlass.const_expr(cfg.ROWWISE and not cfg.COLWISE): + @cute.struct class SharedStorage: mbar_storage: cute.struct.MemRange[cute.Int64, 2 * self._NUM_STAGES] @@ -711,7 +789,9 @@ class SharedStorage: cute.struct.MemRange[Uint8, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] sAmax: cute.struct.MemRange[Float32, self._NUM_WARPS] + elif cutlass.const_expr(cfg.ROWWISE): + @cute.struct class SharedStorage: mbar_storage: cute.struct.MemRange[cute.Int64, 2 * self._NUM_STAGES] @@ -722,7 +802,9 @@ class SharedStorage: cute.struct.MemRange[Uint8, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] sAmax: cute.struct.MemRange[Float32, self._NUM_WARPS] + else: + @cute.struct class SharedStorage: mbar_storage: cute.struct.MemRange[cute.Int64, 2 * self._NUM_STAGES] @@ -733,9 +815,10 @@ class SharedStorage: cute.struct.MemRange[Uint8, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] sAmax: cute.struct.MemRange[Float32, self._NUM_WARPS] + smem = cutlass.utils.SmemAllocator() storage = smem.allocate(SharedStorage) - # Apply the layout to the allocated shared memory buffers so the first rank is the tile (nested layout) + # Apply the layout to the allocated shared memory buffers so the first rank is the tile (nested layout) # and the second rank is the pipeline stage sX = storage.sX.get_tensor( cute.make_layout( @@ -760,11 +843,13 @@ class SharedStorage: # Allocate shared memory for the activation input used for the activation derivative fusion. if cutlass.const_expr(cfg.WITH_DACT): + @cute.struct class DactStorage: sActInput: cute.struct.Align[ cute.struct.MemRange[dtype, self._TILE_Y * self._TILE_X * self._NUM_STAGES], 128 ] + dact_storage = smem.allocate(DactStorage) # Apply the same layout as the input sActInput = dact_storage.sActInput.get_tensor( @@ -804,7 +889,7 @@ class DactStorage: producer_group=producer_group, consumer_group=consumer_group, tx_count=tx_count, - cta_layout_vmnk=None, # single-CTA, no cluster/multicast + cta_layout_vmnk=None, # single-CTA, no cluster/multicast ) prod_state = pipeline.make_pipeline_state( @@ -828,8 +913,8 @@ class DactStorage: # Partition sX/gX for the TMA atom (single-CTA, no cluster/multicast). tXsX, tXgX = cute.nvgpu.cpasync.tma_partition( tma_atom, - 0, # Use the only CTA to do the TMA copy - cute.make_layout(1), # This cluster only has 1 CTAs + 0, # Use the only CTA to do the TMA copy + cute.make_layout(1), # This cluster only has 1 CTAs sX, gX_tiled, ) @@ -916,7 +1001,7 @@ class DactStorage: if cutlass.const_expr(self.DBIAS_REDUCTION_COLWISE): block_dbias = Float32(0.0) - # Consumer: all threads fetch from the pipeline, and + # Consumer: all threads fetch from the pipeline, and for tile_idx in cutlass.range(num_tiles, unroll=1): mainloop_pipeline.consumer_wait(cons_state) # Only allow at most _NUM_STAGES-1 stages to be in-flight, because this iteration will reuse the ring buffer @@ -924,7 +1009,7 @@ class DactStorage: if warp_idx == 0: cute.arch.cp_async_bulk_wait_group(self._NUM_STAGES - 1, read=True) cute.arch.sync_threads() - # The current pipeline stage index, which is the tile index modulo the number of stages. + # The current pipeline stage index, which is the tile index modulo the number of stages. # This is used to index into the shared memory ring buffers that are wrapped around the number of stages. stage_idx = cons_state.index sX_tile = sX[(None, stage_idx)] @@ -944,16 +1029,21 @@ class DactStorage: mS_col_stage = cute.flatten(mS_col[(None, (tile_idx_y, tile_idx_x))]) amax_c, dbias_c = self._process_colwise( - sX_tile, sO_col_tile, - mS_col_stage, max_norm_rcp, - tile_idx_y * self._TILE_Y, bidx * self._TILE_X, M, N, + sX_tile, + sO_col_tile, + mS_col_stage, + max_norm_rcp, + tile_idx_y * self._TILE_Y, + bidx * self._TILE_X, + M, + N, sActInput_tile, ) if cutlass.const_expr(self.AMAX_FROM_COLWISE): per_thread_amax = cute.arch.fmax(per_thread_amax, amax_c) if cutlass.const_expr(self.DBIAS_REDUCTION_COLWISE): block_dbias += dbias_c - # If we cache the activation in shared memory, we need to ensure that all threads have finished writing to the shared memory + # If we cache the activation in shared memory, we need to ensure that all threads have finished writing to the shared memory # from the columnwise pass before any thread reads from it in the rowwise pass. if cutlass.const_expr(self.CACHE_ACTIVATION): cute.arch.sync_threads() @@ -967,9 +1057,14 @@ class DactStorage: # row-tile order for compact. Same source, both layouts correct. mS_row_stage = cute.flatten(mS_row[(None, (tile_idx_y, tile_idx_x))]) amax_r = self._process_rowwise( - sX_tile, sO_row_tile, - mS_row_stage, max_norm_rcp, - tile_idx_y * self._TILE_Y, bidx * self._TILE_X, M, N, + sX_tile, + sO_row_tile, + mS_row_stage, + max_norm_rcp, + tile_idx_y * self._TILE_Y, + bidx * self._TILE_X, + M, + N, sActInput_tile, rowwise_dbias_acc, ) @@ -1034,7 +1129,7 @@ class DactStorage: # and we need to write these to a SMEM buffer and let each thread reduce it in the columnwise direction block_dbias = self._dbias_reduction_rowwise_epilouge(smem, tidx, rowwise_dbias_acc) - # Write the per-tile reduced dbias to the global workspace. + # Write the per-tile reduced dbias to the global workspace. if cutlass.const_expr(cfg.WITH_DBIAS): dbias_col = bidx * self._TILE_X + tidx if dbias_col < N: @@ -1052,18 +1147,22 @@ class DactStorage: def _dbias_reduction_rowwise_epilouge(self, smem, tidx, rowwise_dbias_acc): # Pad the buffer to avoid bank conflicts. The logical shape is still the same. Only the stride is different. DBIAS_BUFF_WIDTH = self._TILE_X // MXFP8_BLOCK_SIZE * (MXFP8_BLOCK_SIZE + 1) - # Allocate the SMEM buffer that all threads use to reduce the two-stage partial sum (per thread) to the + + # Allocate the SMEM buffer that all threads use to reduce the two-stage partial sum (per thread) to the # partial sum (per block). @cute.struct class DbiasStorage: sDbias: cute.struct.MemRange[Float32, self._TILE_Y * DBIAS_BUFF_WIDTH] + dbias_storage = smem.allocate(DbiasStorage) sDbias = dbias_storage.sDbias.get_tensor( cute.make_layout((self._TILE_Y, self._TILE_X), stride=(DBIAS_BUFF_WIDTH, 1)), ) _, tv_layout_dbias_write = cute.make_layout_tv( - thr_layout=cute.make_layout((self._TILE_Y, self._TILE_X // MXFP8_BLOCK_SIZE), - stride=(self._TILE_X // MXFP8_BLOCK_SIZE, 1)), + thr_layout=cute.make_layout( + (self._TILE_Y, self._TILE_X // MXFP8_BLOCK_SIZE), + stride=(self._TILE_X // MXFP8_BLOCK_SIZE, 1), + ), val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(MXFP8_BLOCK_SIZE, 1)), ) sDbias_write = cute.composition(sDbias, tv_layout_dbias_write) @@ -1072,7 +1171,9 @@ class DbiasStorage: bank_group = (tidx % THREADS_PER_WARP) // self._THREADS_PER_BANK # The offset this thread should start reading from based on what's its first bank to access. offset = bank_group * self._PACK_SIZE - for w in cutlass.range_constexpr(self._WAVES): # Each thread starts from this offset when writing into SMEM to avoid bank conflict + for w in cutlass.range_constexpr( + self._WAVES + ): # Each thread starts from this offset when writing into SMEM to avoid bank conflict start = (w * self._PACK_SIZE + offset) % MXFP8_BLOCK_SIZE for i in cutlass.range_constexpr(self._PACK_SIZE): # All threads write their per-thread partial sum results to the shared buffer. @@ -1081,7 +1182,7 @@ class DbiasStorage: # All threads reduce the cross-thread partial sums to the per-block partial sum. _, tv_layout_dbias_reduce = cute.make_layout_tv( thr_layout=cute.make_layout((1, self._TILE_X), stride=(self._TILE_X, 1)), - val_layout=cute.make_layout((self._TILE_Y, 1), stride=(1, 1)) + val_layout=cute.make_layout((self._TILE_Y, 1), stride=(1, 1)), ) sDbias_reduce = cute.composition(sDbias, tv_layout_dbias_reduce) # make_layout_tv yields a (thread, value) layout: thread=tidx -> column tidx, @@ -1111,21 +1212,23 @@ def _amax_epilogue(self, sAmax, mAmax, tidx, warp_idx, per_thread_amax): ) # The first thread updates the global amax with an atomic max on the bitcasted float value cute.arch.atomic_max( - amax_i32.iterator, _bitcast_f32_to_i32(cta_amax), + amax_i32.iterator, + _bitcast_f32_to_i32(cta_amax), ) @cute.jit def _process_rowwise( self, - sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA - sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) - mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_row_tile, # (TILE_Y, TILE_X) uint8 smem view (rowwise FP8 output) + mS_row_stage, # rowwise scale tensor (1D swizzled, or 2D linear) max_norm_rcp, - tile_row_start, # Int32 — global row of this stage's row 0 - tile_col_start, # Int32 — global col of this CTA's col 0 - M, N, # Int32 — full input extents, for OOB masking + tile_row_start, # Int32 — global row of this stage's row 0 + tile_col_start, # Int32 — global col of this CTA's col 0 + M, + N, # Int32 — full input extents, for OOB masking sActInput_tile=None, # (TILE_Y, TILE_X) act_input tile (dact only) - dbias_acc=None, # rmem Float32[MXFP8_BLOCK_SIZE] dbias accumulator (rowwise-only dbias) + dbias_acc=None, # rmem Float32[MXFP8_BLOCK_SIZE] dbias accumulator (rowwise-only dbias) ): cfg = self.cfg return quantize_rowwise_mxfp8( @@ -1155,13 +1258,14 @@ def _process_rowwise( @cute.jit def _process_colwise( self, - sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA - sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) - mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) + sX_tile, # (TILE_Y, TILE_X) bf16/fp16 smem view, post-TMA + sO_col_tile, # (TILE_Y, TILE_X) uint8 smem view (colwise FP8 output) + mS_col_stage, # colwise scale tensor (1D swizzled, or 2D linear) max_norm_rcp, - tile_row_start, # Int32 — global row of this stage's row 0 - tile_col_start, # Int32 — global col of this CTA's col 0 - M, N, # Int32 — full input extents, for OOB masking + tile_row_start, # Int32 — global row of this stage's row 0 + tile_col_start, # Int32 — global col of this CTA's col 0 + M, + N, # Int32 — full input extents, for OOB masking sActInput_tile=None, # (TILE_Y, TILE_X) act_input tile (dact only) ): cfg = self.cfg @@ -1187,6 +1291,7 @@ def _process_colwise( CACHE_ACTIVATION=self.CACHE_ACTIVATION, ) + class MXFP8QuantizeSpecializedRowwiseKernel: """Specialized cast-only ROWWISE-only MXFP8 kernel. Requires N % 128 == 0 (full vectorizable column chunks). @@ -1201,22 +1306,27 @@ def __init__(self, cfg): self.cfg = cfg # If True, then this kernel will first write each thread's scale byte to a shared memory buffer, # then utilize vectorized store to flush the buffer to global memory. - self._STASH_SCALE_TO_SMEM = True # Hardcode to true for now + self._STASH_SCALE_TO_SMEM = True # Hardcode to true for now @cute.jit def __call__( self, mX: cute.Tensor, - mO_row: Optional[cute.Tensor], mS_row: Optional[cute.Tensor], - mO_col: Optional[cute.Tensor], mS_col: Optional[cute.Tensor], # Unused, kept for API compatibility - mAmax: Optional[cute.Tensor], # Unused, kept for API compatibility - mNoop: Optional[cute.Tensor], # Unused, kept for API compatibility - mDActInput: Optional[cute.Tensor], # Unused, kept for API compatibility - mWorkspace: Optional[cute.Tensor], # Unused, kept for API compatibility + mO_row: Optional[cute.Tensor], + mS_row: Optional[cute.Tensor], + mO_col: Optional[cute.Tensor], + mS_col: Optional[cute.Tensor], # Unused, kept for API compatibility + mAmax: Optional[cute.Tensor], # Unused, kept for API compatibility + mNoop: Optional[cute.Tensor], # Unused, kept for API compatibility + mDActInput: Optional[cute.Tensor], # Unused, kept for API compatibility + mWorkspace: Optional[cute.Tensor], # Unused, kept for API compatibility stream: CUstream, ): if cutlass.const_expr(CUTEDSL_DEBUG_LOGGING): - cute.printf(f"[CuTeDSL] MXFP8QuantizeSpecializedRowwiseKernel.__call__() with config: {self.cfg}\n") + cute.printf( + "[CuTeDSL] MXFP8QuantizeSpecializedRowwiseKernel.__call__() with config:" + f" {self.cfg}\n" + ) M = mX.shape[0] N = mX.shape[1] @@ -1228,7 +1338,11 @@ def __call__( block = [self._THREADS_PER_CTA] self.kernel( - mX, mO_row, mS_row, self.cfg.MAX_NORM_RCP, mX.element_type, + mX, + mO_row, + mS_row, + self.cfg.MAX_NORM_RCP, + mX.element_type, ).launch(grid=grid, block=block, stream=stream) @cute.kernel @@ -1245,8 +1359,7 @@ def kernel(self, mX, mO_row, mS_row, max_norm_rcp, DTYPE): CTA_X = self._TILE_X // MXFP8_BLOCK_SIZE tiler, tv_layout = cute.make_layout_tv( thr_layout=cute.make_layout((CTA_Y, CTA_X), stride=(CTA_X, 1)), - val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), - stride=(MXFP8_BLOCK_SIZE, 1)), + val_layout=cute.make_layout((1, MXFP8_BLOCK_SIZE), stride=(MXFP8_BLOCK_SIZE, 1)), ) tiler_scale, tv_layout_scale = cute.make_layout_tv( thr_layout=cute.make_layout((CTA_Y, CTA_X), stride=(CTA_X, 1)), @@ -1278,14 +1391,18 @@ def kernel(self, mX, mO_row, mS_row, max_norm_rcp, DTYPE): ) rO_u32 = cute.make_tensor( cute.recast_ptr(rO_thread.iterator, dtype=Uint32), - cute.make_layout((MXFP8_BLOCK_SIZE // 4,), stride=(1,)), # Unit is Uint32, divide by 4 here + cute.make_layout( + (MXFP8_BLOCK_SIZE // 4,), stride=(1,) + ), # Unit is Uint32, divide by 4 here ) sS_thread = None if cutlass.const_expr(self._STASH_SCALE_TO_SMEM): + @cute.struct class SharedStorage: buf: cute.struct.Align[cute.struct.MemRange[Uint8, CTA_Y * CTA_X], 16] + storage = cutlass.utils.SmemAllocator().allocate(SharedStorage) sScale = storage.buf.get_tensor(cute.make_layout((CTA_Y, CTA_X), stride=(CTA_X, 1))) # sScale is (CTA_Y, CTA_X):(CTA_X, 1), which is the same layout as tv_layout_scale @@ -1321,8 +1438,13 @@ class SharedStorage: mul_cvt4 = mul_cvt_f32x4_to_fp8x4(self.cfg.FP8_DTYPE) for i in cutlass.range_constexpr(MXFP8_BLOCK_SIZE // 4): offset = 4 * i - rO_u32[i] = mul_cvt4(rX_f32[0, offset], rX_f32[0, offset + 1], - rX_f32[0, offset + 2], rX_f32[0, offset + 3], scale_2x) + rO_u32[i] = mul_cvt4( + rX_f32[0, offset], + rX_f32[0, offset + 1], + rX_f32[0, offset + 2], + rX_f32[0, offset + 3], + scale_2x, + ) cute.autovec_copy(rO_thread, mO_thread) # Cooperative wide flush of the staged scales: the first CTA_Y*(CTA_X/G) @@ -1368,6 +1490,7 @@ def _flush_scales_to_gmem(self, sScale, mS_tile, tidx, bidx, bidy, M, padded_col cute.composition(mS_tile, tv_flush)[tidx, None], ) + class MXFP8QuantizeSpecializedBidimensionalKernel: """Specialized cast-only BIDIMENSIONAL (both-direction) MXFP8 kernel — the CuTeDSL counterpart of specialized/quantize_mxfp8.cuh:: @@ -1386,8 +1509,10 @@ def __init__(self, cfg): def __call__( self, mX: cute.Tensor, - mO_row: Optional[cute.Tensor], mS_row: Optional[cute.Tensor], - mO_col: Optional[cute.Tensor], mS_col: Optional[cute.Tensor], + mO_row: Optional[cute.Tensor], + mS_row: Optional[cute.Tensor], + mO_col: Optional[cute.Tensor], + mS_col: Optional[cute.Tensor], mAmax: Optional[cute.Tensor], mNoop: Optional[cute.Tensor], mDActInput: Optional[cute.Tensor], @@ -1395,7 +1520,10 @@ def __call__( stream: CUstream, ): if cutlass.const_expr(CUTEDSL_DEBUG_LOGGING): - cute.printf(f"[CuTeDSL] MXFP8QuantizeSpecializedBidimensionalKernel.__call__() with config: {self.cfg}\n") + cute.printf( + "[CuTeDSL] MXFP8QuantizeSpecializedBidimensionalKernel.__call__() with config:" + f" {self.cfg}\n" + ) # TODO(next round): TMA-based 32x32-tile bidimensional cast-only kernel — # grid/launch + dual-direction (rowwise+colwise) scale/cast. No output is # produced yet; this stub exists so dispatch routing can be wired now. @@ -1405,8 +1533,11 @@ def get_kernel_class(cfg): """If no fusion is involved and the kernel only quantizes, dispatch to the specialized kernel for better performance.""" plain_cast_only = ( not cfg.WITH_GEMM_SWIZZLED_SCALES - and not cfg.WITH_AMAX and not cfg.WITH_DBIAS - and not cfg.WITH_DACT and not cfg.WITH_ACT and not cfg.WITH_NOOP + and not cfg.WITH_AMAX + and not cfg.WITH_DBIAS + and not cfg.WITH_DACT + and not cfg.WITH_ACT + and not cfg.WITH_NOOP ) if plain_cast_only: if cfg.ROWWISE and not cfg.COLWISE: @@ -1439,35 +1570,108 @@ def compile_cutedsl_function_from_cfg(cfg): # scales their own fresh syms carrying the divisibility the padding # guarantees (rowwise: 128 x 4; colwise: 4 x 128). scale_rowwise_shape = (cute.sym_int32(divisibility=128), cute.sym_int32(divisibility=4)) - scale_colwise_shape = (cute.sym_int32(divisibility=4), cute.sym_int32(divisibility=128)) - ws_shape = (cute.sym_int32(), sym_N) # (blocks_Y, N); N ties to input N - - in_fake = cute.runtime.make_fake_compact_tensor(cfg.DTYPE, in_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16) - out_row_fake = cute.runtime.make_fake_compact_tensor(cute.Uint8, out_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16) if cfg.ROWWISE else None - scale_row_fake = cute.runtime.make_fake_compact_tensor(cute.Uint8, scale_rowwise_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) if cfg.ROWWISE else None - out_col_fake = cute.runtime.make_fake_compact_tensor(cute.Uint8, out_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16) if cfg.COLWISE else None - scale_col_fake = cute.runtime.make_fake_compact_tensor(cute.Uint8, scale_colwise_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) if cfg.COLWISE else None - amax_fake = cute.runtime.make_fake_compact_tensor(Float32, (1,), stride_order=(0,), memspace=cute.AddressSpace.gmem, assumed_align=4) if cfg.WITH_AMAX else None - noop_fake = cute.runtime.make_fake_compact_tensor(Float32, (1,), stride_order=(0,), memspace=cute.AddressSpace.gmem, assumed_align=4) if cfg.WITH_NOOP else None - act_input_fake = cute.runtime.make_fake_compact_tensor(cfg.DTYPE, in_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16) if cfg.WITH_DACT else None - workspace_fake = cute.runtime.make_fake_compact_tensor(Float32, ws_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4) if cfg.WITH_DBIAS else None + scale_colwise_shape = (cute.sym_int32(divisibility=4), cute.sym_int32(divisibility=128)) + ws_shape = (cute.sym_int32(), sym_N) # (blocks_Y, N); N ties to input N + + in_fake = cute.runtime.make_fake_compact_tensor( + cfg.DTYPE, in_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=16 + ) + out_row_fake = ( + cute.runtime.make_fake_compact_tensor( + cute.Uint8, + out_shape, + stride_order=(1, 0), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + if cfg.ROWWISE + else None + ) + scale_row_fake = ( + cute.runtime.make_fake_compact_tensor( + cute.Uint8, + scale_rowwise_shape, + stride_order=(1, 0), + memspace=cute.AddressSpace.gmem, + assumed_align=4, + ) + if cfg.ROWWISE + else None + ) + out_col_fake = ( + cute.runtime.make_fake_compact_tensor( + cute.Uint8, + out_shape, + stride_order=(1, 0), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + if cfg.COLWISE + else None + ) + scale_col_fake = ( + cute.runtime.make_fake_compact_tensor( + cute.Uint8, + scale_colwise_shape, + stride_order=(1, 0), + memspace=cute.AddressSpace.gmem, + assumed_align=4, + ) + if cfg.COLWISE + else None + ) + amax_fake = ( + cute.runtime.make_fake_compact_tensor( + Float32, (1,), stride_order=(0,), memspace=cute.AddressSpace.gmem, assumed_align=4 + ) + if cfg.WITH_AMAX + else None + ) + noop_fake = ( + cute.runtime.make_fake_compact_tensor( + Float32, (1,), stride_order=(0,), memspace=cute.AddressSpace.gmem, assumed_align=4 + ) + if cfg.WITH_NOOP + else None + ) + act_input_fake = ( + cute.runtime.make_fake_compact_tensor( + cfg.DTYPE, + in_shape, + stride_order=(1, 0), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + if cfg.WITH_DACT + else None + ) + workspace_fake = ( + cute.runtime.make_fake_compact_tensor( + Float32, ws_shape, stride_order=(1, 0), memspace=cute.AddressSpace.gmem, assumed_align=4 + ) + if cfg.WITH_DBIAS + else None + ) compiled = cute.compile( kernel_obj, - in_fake, # mX - out_row_fake, scale_row_fake, # mO_row, mS_row - out_col_fake, scale_col_fake, # mO_col, mS_col - amax_fake, # mAmax - noop_fake, # mNoop (1-element cast_noop flag) - act_input_fake, # mDActInput (backward slot, unused) - workspace_fake, # mWorkspace(backward slot, unused) - cute.runtime.make_fake_stream(), # stream (compiled as an explicit tvm-ffi - # "handle" arg; C++ passes the CUDA stream - # as void*) + in_fake, # mX + out_row_fake, + scale_row_fake, # mO_row, mS_row + out_col_fake, + scale_col_fake, # mO_col, mS_col + amax_fake, # mAmax + noop_fake, # mNoop (1-element cast_noop flag) + act_input_fake, # mDActInput (backward slot, unused) + workspace_fake, # mWorkspace(backward slot, unused) + cute.runtime.make_fake_stream(), # stream (compiled as an explicit tvm-ffi + # "handle" arg; C++ passes the CUDA stream + # as void*) options="--enable-tvm-ffi", ) return compiled + def get_mxfp8_quantization_function( fn_name: str, dtype: str, @@ -1482,8 +1686,8 @@ def get_mxfp8_quantization_function( with_noop: bool, activation: str, ) -> bool: - """Compile the MXFP8 quantize kernel for this config and register it in the TVM-FFI global registry - under EXACTLY `fn_name` (the key the C++ dispatcher built; Python treats it as an opaque name). + """Compile the MXFP8 quantize kernel for this config and register it in the TVM-FFI global registry + under EXACTLY `fn_name` (the key the C++ dispatcher built; Python treats it as an opaque name). Returns True if a kernel is successfully registered under `fn_name` (the C++ side then fetches it with GetGlobal(fn_name)); False if the config is unsupported, so the caller caches the negative result and falls back to the CUDA C++ kernel. """ @@ -1509,8 +1713,10 @@ def get_mxfp8_quantization_function( # The exception message states exactly why the config is unsupported # (unknown dtype/activation, dbias not implemented, ...). Surfacing it as a # warning lets the C++ dispatcher's CUDA fallback be recognized as expected. - logger.warning(f"CuTeDSL MXFP8 backend does not support this config, " - f"falling back to the CUDA C++ kernel: {e}") + logger.warning( + "CuTeDSL MXFP8 backend does not support this config, " + f"falling back to the CUDA C++ kernel: {e}" + ) return False logger.debug(f"Compiling CuTeDSL MXFP8 quantization kernel for {cfg}") @@ -1519,5 +1725,8 @@ def get_mxfp8_quantization_function( return True + # Exposed so the C++ dispatcher can request on-demand compilation by name. -tvm_ffi.register_global_func("get_mxfp8_quantization_function", get_mxfp8_quantization_function, override=True) +tvm_ffi.register_global_func( + "get_mxfp8_quantization_function", get_mxfp8_quantization_function, override=True +) diff --git a/transformer_engine/common/CuTeDSL/utils.py b/transformer_engine/common/CuTeDSL/utils.py index bc5b09cd7d..ddf7ed56ce 100644 --- a/transformer_engine/common/CuTeDSL/utils.py +++ b/transformer_engine/common/CuTeDSL/utils.py @@ -13,16 +13,20 @@ } _STR_FROM_CUTLASS_DTYPE = {v: k for k, v in _CUTLASS_DTYPE_FROM_STR.items()} + def str_to_cutlass_dtype(dtype_str: str): """Convert a string dtype to a cutlass dtype, or None if unknown.""" return _CUTLASS_DTYPE_FROM_STR.get(dtype_str, None) + def cutlass_dtype_to_str(dtype): """Convert a cutlass dtype back to its protocol string, or None if unknown.""" return _STR_FROM_CUTLASS_DTYPE.get(dtype, None) + FP32_MANTISSA_BITS = 23 + @dsl_user_op def _bitcast_f32_to_i32(val: Float32, *, loc=None, ip=None) -> Int32: """Bitcast a float32 value to int32 without changing the bit pattern.""" @@ -46,14 +50,17 @@ def fabs_f32(val: Float32, *, loc=None, ip=None) -> Float32: @dsl_user_op def fma_f32(a: Float32, b: Float32, c: Float32, *, loc=None, ip=None) -> Float32: """Compute the fused multiply-add of three float32 values: a * b + c.""" - return Float32(llvm.inline_asm( - T.f32(), - [a.ir_value(loc=loc, ip=ip), - b.ir_value(loc=loc, ip=ip), - c.ir_value(loc=loc, ip=ip)], - "fma.rn.f32 $0, $1, $2, $3;", - "=f,f,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Float32( + llvm.inline_asm( + T.f32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip), c.ir_value(loc=loc, ip=ip)], + "fma.rn.f32 $0, $1, $2, $3;", + "=f,f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) @dsl_user_op @@ -61,14 +68,20 @@ def exp2f_rcp(biased_exp: Int32, *, loc=None, ip=None) -> Float32: """2^(127 - biased_exp) with special-case handling.""" new_exp = (Int32(254) - biased_exp) << Int32(FP32_MANTISSA_BITS) result = _bitcast_i32_to_f32(new_exp, loc=loc, ip=ip) - for (cmp_val, repl_bits) in [(255, 0x7FFFFFFF), (254, 0x00400000), (0, 0x7F000000)]: - cond = mlir_arith.cmpi(mlir_arith.CmpIPredicate.eq, - biased_exp.ir_value(loc=loc, ip=ip), - Int32(cmp_val).ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + for cmp_val, repl_bits in [(255, 0x7FFFFFFF), (254, 0x00400000), (0, 0x7F000000)]: + cond = mlir_arith.cmpi( + mlir_arith.CmpIPredicate.eq, + biased_exp.ir_value(loc=loc, ip=ip), + Int32(cmp_val).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) alt = _bitcast_i32_to_f32(Int32(repl_bits), loc=loc, ip=ip) - result = Float32(mlir_arith.select( - cond, alt.ir_value(loc=loc, ip=ip), - result.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + result = Float32( + mlir_arith.select( + cond, alt.ir_value(loc=loc, ip=ip), result.ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + ) return result @@ -79,12 +92,17 @@ def pack_f32x2(lo: Float32, hi: Float32, *, loc=None, ip=None) -> Int64: Low 32 bits = `lo`, high 32 bits = `hi`. Uses `mov.b64 %dst, {%lo, %hi};` which lowers to a single register move — no actual memory traffic. """ - return Int64(llvm.inline_asm( - T.i64(), - [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)], - "mov.b64 $0, {$1, $2};", - "=l,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int64( + llvm.inline_asm( + T.i64(), + [lo.ir_value(loc=loc, ip=ip), hi.ir_value(loc=loc, ip=ip)], + "mov.b64 $0, {$1, $2};", + "=l,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) def _build_mul_cvt_f32x4(out_fmt: str, relu: bool = False): @@ -116,16 +134,27 @@ def _build_mul_cvt_f32x4(out_fmt: str, relu: bool = False): ) @dsl_user_op - def fn(v0: Float32, v1: Float32, v2: Float32, v3: Float32, scale_2x: Int64, - *, loc=None, ip=None) -> Uint32: - return Uint32(llvm.inline_asm( - T.i32(), - [v0.ir_value(loc=loc, ip=ip), v1.ir_value(loc=loc, ip=ip), - v2.ir_value(loc=loc, ip=ip), v3.ir_value(loc=loc, ip=ip), - scale_2x.ir_value(loc=loc, ip=ip)], - asm, - "=r,f,f,f,f,l", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + def fn( + v0: Float32, v1: Float32, v2: Float32, v3: Float32, scale_2x: Int64, *, loc=None, ip=None + ) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [ + v0.ir_value(loc=loc, ip=ip), + v1.ir_value(loc=loc, ip=ip), + v2.ir_value(loc=loc, ip=ip), + v3.ir_value(loc=loc, ip=ip), + scale_2x.ir_value(loc=loc, ip=ip), + ], + asm, + "=r,f,f,f,f,l", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return fn @@ -136,6 +165,7 @@ def mul_cvt_f32x4_to_fp8x4(fp8_dtype: str, relu: bool = False): fp8 bytes, byte i = fp8(v_i * scale). `scale_2x` is pack_f32x2(s, s).""" return _build_mul_cvt_f32x4("e5m2" if fp8_dtype == "e5m2" else "e4m3", relu) + def _build_packed16_kit(in_fmt: str): """Build a kit of PTX wrappers for a 16-bit input format so we don't have to repeat the same inline asm boilerplate code for FP16 and BF16 dtypes. @@ -154,100 +184,140 @@ def _build_packed16_kit(in_fmt: str): @dsl_user_op def abs_max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32: - return Int32(llvm.inline_asm( - T.i32(), - [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], - f"max.xorsign.abs.{in_fmt}x2 $0, $1, $2;", - "=r,r,r", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - + return Int32( + llvm.inline_asm( + T.i32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.xorsign.abs.{in_fmt}x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + @dsl_user_op def max_x2(a: Int32, b: Int32, *, loc=None, ip=None) -> Int32: - return Int32(llvm.inline_asm( - T.i32(), - [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], - f"max.{in_fmt}x2 $0, $1, $2;", - "=r,r,r", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int32( + llvm.inline_asm( + T.i32(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.{in_fmt}x2 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) @dsl_user_op def abs_max_scalar(a: Int16, b: Int16, *, loc=None, ip=None) -> Int16: - return Int16(llvm.inline_asm( - T.i16(), - [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], - f"max.xorsign.abs.{in_fmt} $0, $1, $2;", - "=h,h,h", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Int16( + llvm.inline_asm( + T.i16(), + [a.ir_value(loc=loc, ip=ip), b.ir_value(loc=loc, ip=ip)], + f"max.xorsign.abs.{in_fmt} $0, $1, $2;", + "=h,h,h", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) if in_fmt == "bf16": # bf16 == top 16 bits of f32 — widening is a free bit-shift. @dsl_user_op def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32: - i32 = Int32(mlir_arith.extui( - T.i32(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + i32 = Int32(mlir_arith.extui(T.i32(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip) @dsl_user_op def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: - return _bitcast_i32_to_f32( - (bits & Int32(0xFFFF)) << Int32(16), loc=loc, ip=ip) + return _bitcast_i32_to_f32((bits & Int32(0xFFFF)) << Int32(16), loc=loc, ip=ip) @dsl_user_op def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: # `(x >> 16) << 16` ≡ `x & 0xFFFF0000`, sidestepping signed-literal # issues. Sign bits from the arith-right shift get zeroed by the # left shift. - return _bitcast_i32_to_f32( - (bits >> Int32(16)) << Int32(16), loc=loc, ip=ip) + return _bitcast_i32_to_f32((bits >> Int32(16)) << Int32(16), loc=loc, ip=ip) @dsl_user_op def truncate_f32(val: Float32, *, loc=None, ip=None) -> Float32: """Round f32 to bf16 precision (round-to-nearest-even), keep f32. Matches C++'s `static_cast(static_cast(elt))`.""" - bf16_bits = Int16(llvm.inline_asm( - T.i16(), [val.ir_value(loc=loc, ip=ip)], - "cvt.rn.bf16.f32 $0, $1;", - "=h,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - i32 = Int32(mlir_arith.extui( - T.i32(), bf16_bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + bf16_bits = Int16( + llvm.inline_asm( + T.i16(), + [val.ir_value(loc=loc, ip=ip)], + "cvt.rn.bf16.f32 $0, $1;", + "=h,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + i32 = Int32( + mlir_arith.extui(T.i32(), bf16_bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) return _bitcast_i32_to_f32(i32 << Int32(16), loc=loc, ip=ip) + else: # f16 has its own bit layout; widening requires `cvt.f32.f16`. @dsl_user_op def bits_to_f32(bits: Int16, *, loc=None, ip=None) -> Float32: - return Float32(llvm.inline_asm( - T.f32(), [bits.ir_value(loc=loc, ip=ip)], - "cvt.f32.f16 $0, $1;", - "=f,h", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + return Float32( + llvm.inline_asm( + T.f32(), + [bits.ir_value(loc=loc, ip=ip)], + "cvt.f32.f16 $0, $1;", + "=f,h", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) @dsl_user_op def x2_lo_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: - lo_i16 = Int16(mlir_arith.trunci( - T.i16(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + lo_i16 = Int16( + mlir_arith.trunci(T.i16(), bits.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) return bits_to_f32(lo_i16, loc=loc, ip=ip) @dsl_user_op def x2_hi_to_f32(bits: Int32, *, loc=None, ip=None) -> Float32: hi_shifted = bits >> Int32(16) - hi_i16 = Int16(mlir_arith.trunci( - T.i16(), hi_shifted.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + hi_i16 = Int16( + mlir_arith.trunci(T.i16(), hi_shifted.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) return bits_to_f32(hi_i16, loc=loc, ip=ip) @dsl_user_op def truncate_f32(val: Float32, *, loc=None, ip=None) -> Float32: """Round f32 to f16 precision, keep f32.""" - f16_bits = Int16(llvm.inline_asm( - T.i16(), [val.ir_value(loc=loc, ip=ip)], - "cvt.rn.f16.f32 $0, $1;", - "=h,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Float32(llvm.inline_asm( - T.f32(), [f16_bits.ir_value(loc=loc, ip=ip)], - "cvt.f32.f16 $0, $1;", - "=f,h", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) + f16_bits = Int16( + llvm.inline_asm( + T.i16(), + [val.ir_value(loc=loc, ip=ip)], + "cvt.rn.f16.f32 $0, $1;", + "=h,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return Float32( + llvm.inline_asm( + T.f32(), + [f16_bits.ir_value(loc=loc, ip=ip)], + "cvt.f32.f16 $0, $1;", + "=f,h", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) def _build_mul_cvt(out_fmt: str, relu: bool = False): """Build a fused `x2 * f32x2 → fp8x2` PTX wrapper. @@ -268,21 +338,27 @@ def _build_mul_cvt(out_fmt: str, relu: bool = False): "mov.b64 vp0, {v1, v2};\n\t" "mul.f32x2 vp1, vp0, $2;\n\t" "mov.b64 {v2, v1}, vp1;\n\t" - f"cvt.rn.satfinite{".relu" if relu else ""}.{out_op}.f32 $0, v1, v2;\n\t" + f"cvt.rn.satfinite{'.relu' if relu else ''}.{out_op}.f32 $0, v1, v2;\n\t" "}" ) @dsl_user_op def fn(val_2x: Int32, scale_2x: Int64, *, loc=None, ip=None) -> Int32: - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [val_2x.ir_value(loc=loc, ip=ip), - scale_2x.ir_value(loc=loc, ip=ip)], - asm, - "=h,r,l", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [val_2x.ir_value(loc=loc, ip=ip), scale_2x.ir_value(loc=loc, ip=ip)], + asm, + "=h,r,l", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return Int32( + mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) + return fn def mul_cvt_to_fp8x2(fp8_dtype: str, relu: bool = False): @@ -310,6 +386,7 @@ def is_packed16(dtype) -> bool: """True if `dtype` is one of the 16-bit packed input formats.""" return dtype is cutlass.BFloat16 or dtype is cutlass.Float16 + def packed16_kit(dtype): """Trace-time selector — pick a Packed16Kit for the input dtype.""" if dtype is cutlass.Float16: diff --git a/transformer_engine/common/CuTeDSL/utils_fp8.py b/transformer_engine/common/CuTeDSL/utils_fp8.py index 4d09f2434c..51fdfe8ecf 100644 --- a/transformer_engine/common/CuTeDSL/utils_fp8.py +++ b/transformer_engine/common/CuTeDSL/utils_fp8.py @@ -13,18 +13,25 @@ logger = logging.getLogger("transformer_engine.cutedsl.utils_fp8") + @dsl_user_op def cvt_f32_to_fp8e4m3(val: Float32, *, loc=None, ip=None) -> Int32: """float32 -> fp8e4m3 conversion.""" zero = Float32(0.0) - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], - "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - result_i32 = Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], + "cvt.rn.satfinite.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + result_i32 = Int32( + mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) return result_i32 & Int32(0xFF) @@ -32,14 +39,20 @@ def cvt_f32_to_fp8e4m3(val: Float32, *, loc=None, ip=None) -> Int32: def cvt_f32_to_fp8e5m2(val: Float32, *, loc=None, ip=None) -> Int32: """float32 -> fp8e5m2 conversion.""" zero = Float32(0.0) - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], - "cvt.rn.satfinite.e5m2x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - result_i32 = Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], + "cvt.rn.satfinite.e5m2x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + result_i32 = Int32( + mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) return result_i32 & Int32(0xFF) @@ -52,9 +65,11 @@ def cvt_f32_to_fp8e8m0_non_blackwell(val: Float32, *, loc=None, ip=None) -> Int3 val_i32 = _bitcast_f32_to_i32(val, loc=loc, ip=ip) rounded = val_i32 + Int32(0x7FFFFF) exponent = (rounded >> Int32(FP32_MANTISSA_BITS)) & Int32(0xFF) - return Int32(mlir_arith.minsi( - exponent.ir_value(loc=loc, ip=ip), - Int32(254).ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + return Int32( + mlir_arith.minsi( + exponent.ir_value(loc=loc, ip=ip), Int32(254).ir_value(loc=loc, ip=ip), loc=loc, ip=ip + ) + ) @dsl_user_op @@ -65,14 +80,20 @@ def cvt_f32_to_fp8e8m0_blackwell(val: Float32, *, loc=None, ip=None) -> Int32: ptx::float_to_e8m0's Blackwell branch. The x2 form packs two e8m0 bytes; we feed (0.0, val) so the low byte is e8m0(val) and mask it out.""" zero = Float32(0.0) - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], - "cvt.rp.satfinite.ue8m0x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - result_i32 = Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [zero.ir_value(loc=loc, ip=ip), val.ir_value(loc=loc, ip=ip)], + "cvt.rp.satfinite.ue8m0x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + result_i32 = Int32( + mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip) + ) return result_i32 & Int32(0xFF) @@ -94,6 +115,7 @@ def _target_arch_is_blackwell() -> bool: major_minor = re.search(r"(\d+)", arch).group(1) # "120" else: from cuda.core import Device + major_minor = Device().arch # compute capability as digits, e.g. "120" # Trailing digit is the minor version; the rest is the major version. return int(major_minor[:-1]) in (10, 11, 12) @@ -105,42 +127,54 @@ def _target_arch_is_blackwell() -> bool: # Pick the appropriate float32 -> fp8e8m0 conversion function based on the target architecture. # Blackwell (SM >= 100) has a hardware instruction for this, while older architectures require a software implementation. cvt_f32_to_fp8e8m0 = ( - cvt_f32_to_fp8e8m0_blackwell if _target_arch_is_blackwell() else cvt_f32_to_fp8e8m0_non_blackwell + cvt_f32_to_fp8e8m0_blackwell + if _target_arch_is_blackwell() + else cvt_f32_to_fp8e8m0_non_blackwell ) @dsl_user_op -def cvt_f32x2_to_fp8e4m3x2(val_hi: Float32, val_lo: Float32, relu: bool = False, - *, loc=None, ip=None) -> Int32: +def cvt_f32x2_to_fp8e4m3x2( + val_hi: Float32, val_lo: Float32, relu: bool = False, *, loc=None, ip=None +) -> Int32: """Convert two float32 values to two packed fp8e4m3fn bytes in one instruction. Returns an int32 where bits [7:0] = fp8(val_lo), bits [15:8] = fp8(val_hi). """ - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], - f"cvt.rn.satfinite{".relu" if relu else ""}.e4m3x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e4m3x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return Int32(mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) @dsl_user_op -def cvt_f32x2_to_fp8e5m2x2(val_hi: Float32, val_lo: Float32, relu: bool = False, - *, loc=None, ip=None) -> Int32: +def cvt_f32x2_to_fp8e5m2x2( + val_hi: Float32, val_lo: Float32, relu: bool = False, *, loc=None, ip=None +) -> Int32: """Convert two float32 values to two packed fp8e5m2 bytes in one instruction. Returns an int32 where bits [7:0] = fp8(val_lo), bits [15:8] = fp8(val_hi). """ - result_i16 = Int16(llvm.inline_asm( - T.i16(), - [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], - f"cvt.rn.satfinite{".relu" if relu else ""}.e5m2x2.f32 $0, $1, $2;", - "=h,f,f", has_side_effects=False, is_align_stack=False, - asm_dialect=llvm.AsmDialect.AD_ATT)) - return Int32(mlir_arith.extui( - T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) + result_i16 = Int16( + llvm.inline_asm( + T.i16(), + [val_hi.ir_value(loc=loc, ip=ip), val_lo.ir_value(loc=loc, ip=ip)], + f"cvt.rn.satfinite{".relu" if relu else ""}.e5m2x2.f32 $0, $1, $2;", + "=h,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + return Int32(mlir_arith.extui(T.i32(), result_i16.ir_value(loc=loc, ip=ip), loc=loc, ip=ip)) def get_cvt_f32_to_fp8_func(fp8_dtype: str): @@ -155,4 +189,3 @@ def get_cvt_f32x2_to_fp8x2_func(fp8_dtype: str): if fp8_dtype == "e5m2": return cvt_f32x2_to_fp8e5m2x2 return cvt_f32x2_to_fp8e4m3x2 - diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index dfc01646dc..5051ce32ce 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -18,13 +18,13 @@ #include "../../common.h" #include "../../transpose/cast_transpose.h" -#include "../mxfp8/quantize_mxfp8_cutedsl.cuh" #include "../../util/cuda_runtime.h" #include "../../util/vectorized_pointwise.h" #include "../core/common.cuh" #include "../fp8/quantize_fp8.cuh" #include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" +#include "../mxfp8/quantize_mxfp8_cutedsl.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" #include "../nvfp4/quantize_4over6_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" @@ -32,7 +32,6 @@ namespace transformer_engine { namespace dispatch { - template void quantize_fwd_helper(const NVTETensor input, NVTETensor output, const NVTEQuantizationConfig quant_config, cudaStream_t stream) { @@ -92,13 +91,13 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, Tensor *dummy_workspace_tensor = nullptr; bool quantized_with_cutedsl = cutedsl_backend::mxfp8_quantize_cutedsl( - input_tensor, dummy_input_tensor, noop_tensor, output_tensor, - dummy_dbias_tensor, dummy_workspace_tensor, stream); + ParamOP, OP>( + input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); if (!quantized_with_cutedsl) { mxfp8::quantize( - *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, - dummy_workspace_tensor, stream); + *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); } break; } @@ -264,8 +263,8 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens case NVTE_MXFP8_1D_SCALING: { bool quantized_with_cutedsl = cutedsl_backend::mxfp8_quantize_cutedsl( - grad_tensor, input_tensor, noop_tensor, output_tensor, - dbias_tensor, workspace_tensor, stream); + grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); if (!quantized_with_cutedsl) { mxfp8::quantize( *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh index 0ebef35c05..934adec59c 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8_cutedsl.cuh @@ -7,13 +7,13 @@ #ifndef TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_QUANTIZE_MXFP8_CUTEDSL_CUH_ #define TRANSFORMER_ENGINE_COMMON_CAST_MXFP8_QUANTIZE_MXFP8_CUTEDSL_CUH_ +#include +#include + #include #include #include -#include -#include - #include "../../common.h" #include "../../tvm_ffi_bridge.h" #include "../../util/math.h" @@ -29,32 +29,42 @@ using namespace tvm_ffi_bridge; struct MXFP8QuantConfig { static constexpr const char *kEntrypointName = "get_mxfp8_quantization_function"; - DType dtype; // The input format - DType fp8_dtype; // The fp8 output format - bool rowwise; // If quantize rowwisely - bool colwise; // If quantize columnwisely - bool swizzled; // If the scale output is used for cudnn's swizzled layout - bool with_amax; // If the kernel should return the amax - bool with_dbias = false; // If the dbias is computated (via the workspace tensor) - bool with_dact = false; // If an activation derivative operation is fused - bool with_act = false; // If an activation operation is fused - bool with_noop = false; // If a non-nullptr noop tensor is passed to the kernel + DType dtype; // The input format + DType fp8_dtype; // The fp8 output format + bool rowwise; // If quantize rowwisely + bool colwise; // If quantize columnwisely + bool swizzled; // If the scale output is used for cudnn's swizzled layout + bool with_amax; // If the kernel should return the amax + bool with_dbias = false; // If the dbias is computated (via the workspace tensor) + bool with_dact = false; // If an activation derivative operation is fused + bool with_act = false; // If an activation operation is fused + bool with_noop = false; // If a non-nullptr noop tensor is passed to the kernel Activation activation = Activation::kNone; std::string to_key() const { std::string key; key.reserve(56); key.append("cutedsl_mxfp8_") - .append(te_dtype_to_str(dtype)).append("_") - .append(te_dtype_to_str(fp8_dtype)).append("_") - .append(rowwise ? "1" : "0").append("_") - .append(colwise ? "1" : "0").append("_") - .append(swizzled ? "1" : "0").append("_") - .append(with_amax ? "1" : "0").append("_") - .append(with_dbias ? "1" : "0").append("_") - .append(with_dact ? "1" : "0").append("_") - .append(with_act ? "1" : "0").append("_") - .append(with_noop ? "1" : "0").append("_") + .append(te_dtype_to_str(dtype)) + .append("_") + .append(te_dtype_to_str(fp8_dtype)) + .append("_") + .append(rowwise ? "1" : "0") + .append("_") + .append(colwise ? "1" : "0") + .append("_") + .append(swizzled ? "1" : "0") + .append("_") + .append(with_amax ? "1" : "0") + .append("_") + .append(with_dbias ? "1" : "0") + .append("_") + .append(with_dact ? "1" : "0") + .append("_") + .append(with_act ? "1" : "0") + .append("_") + .append(with_noop ? "1" : "0") + .append("_") .append(activation_to_str(activation)); return key; } @@ -64,11 +74,11 @@ struct MXFP8QuantConfig { if (!entrypoint.has_value()) { return false; } - tvm::ffi::Any result = (*entrypoint)( - tvm::ffi::String(fn_name), tvm::ffi::String(te_dtype_to_str(dtype)), - tvm::ffi::String(te_dtype_to_str(fp8_dtype)), rowwise, colwise, swizzled, with_amax, - with_dbias, with_dact, with_act, with_noop, - tvm::ffi::String(activation_to_str(activation))); + tvm::ffi::Any result = + (*entrypoint)(tvm::ffi::String(fn_name), tvm::ffi::String(te_dtype_to_str(dtype)), + tvm::ffi::String(te_dtype_to_str(fp8_dtype)), rowwise, colwise, swizzled, + with_amax, with_dbias, with_dact, with_act, with_noop, + tvm::ffi::String(activation_to_str(activation))); return result.try_cast().value_or(false); } }; @@ -133,16 +143,14 @@ struct MXFP8QuantFused> { // Signature mirrors mxfp8::quantize (input, act_input, noop, output, dbias, // workspace, stream). Returns false to fall back to the CUDA kernel. -inline bool mxfp8_quantize_cutedsl(const MXFP8QuantConfig &config, - const Tensor *input_tensor, const Tensor *act_input_tensor, - const Tensor *noop_tensor, Tensor *output_tensor, - Tensor *dbias_tensor, Tensor *workspace_tensor, - cudaStream_t stream) { +inline bool mxfp8_quantize_cutedsl(const MXFP8QuantConfig &config, const Tensor *input_tensor, + const Tensor *act_input_tensor, const Tensor *noop_tensor, + Tensor *output_tensor, Tensor *dbias_tensor, + Tensor *workspace_tensor, cudaStream_t stream) { constexpr size_t kCuTeDSLMXFP8ShapeAlignment = 32; const size_t flat_m = input_tensor->flat_first_dim(); const size_t flat_n = input_tensor->flat_last_dim(); - if (flat_m % kCuTeDSLMXFP8ShapeAlignment != 0 || - flat_n % kCuTeDSLMXFP8ShapeAlignment != 0) { + if (flat_m % kCuTeDSLMXFP8ShapeAlignment != 0 || flat_n % kCuTeDSLMXFP8ShapeAlignment != 0) { return false; } @@ -157,8 +165,7 @@ inline bool mxfp8_quantize_cutedsl(const MXFP8QuantConfig &config, // calls with an unallocated workspace to learn its shape, allocates a buffer of // that shape, then calls again to run. The kernel writes per-row-block partial // dbias into this workspace; reducing it to the final dbias is a separate step. - if (config.with_dbias && workspace_tensor != nullptr && - workspace_tensor->data.dptr == nullptr) { + if (config.with_dbias && workspace_tensor != nullptr && workspace_tensor->data.dptr == nullptr) { workspace_tensor->data.shape = {workspace_rows, flat_n}; workspace_tensor->data.dtype = DType::kFloat32; return true; @@ -182,9 +189,9 @@ inline bool mxfp8_quantize_cutedsl(const MXFP8QuantConfig &config, output_tensor->scale_inv.buffer_size_bytes(), stream)); } if (output_tensor->has_columnwise_data()) { - NVTE_CHECK_CUDA( - cudaMemsetAsync(output_tensor->columnwise_scale_inv.dptr, 0, - output_tensor->columnwise_scale_inv.buffer_size_bytes(), stream)); + NVTE_CHECK_CUDA(cudaMemsetAsync(output_tensor->columnwise_scale_inv.dptr, 0, + output_tensor->columnwise_scale_inv.buffer_size_bytes(), + stream)); } } @@ -200,11 +207,13 @@ inline bool mxfp8_quantize_cutedsl(const MXFP8QuantConfig &config, // Backward tensors: if the passed tensor pointer is nullptr, they will be empty DLTensorWrapper with null data pointer too tvm_ffi_bridge::DLTensorWrapper mActInput, mWorkspace; // If these tensors are not nullptr, wrap them as DLTensorWrappers with real data - if (act_input_tensor != nullptr) mActInput = tvm_ffi_bridge::DLTensorWrapper(act_input_tensor->data); - if (workspace_tensor != nullptr) mWorkspace = tvm_ffi_bridge::DLTensorWrapper(workspace_tensor->data); + if (act_input_tensor != nullptr) + mActInput = tvm_ffi_bridge::DLTensorWrapper(act_input_tensor->data); + if (workspace_tensor != nullptr) + mWorkspace = tvm_ffi_bridge::DLTensorWrapper(workspace_tensor->data); // stream is a tvm-ffi opaque "handle"; pass the CUDA stream as void*. - (*mxfp8_quant_func_opt)(&mX, &mO_row, &mS_row, &mO_col, &mS_col, &mAmax, &mNoop, - &mActInput, &mWorkspace, static_cast(stream)); + (*mxfp8_quant_func_opt)(&mX, &mO_row, &mS_row, &mO_col, &mS_col, &mAmax, &mNoop, &mActInput, + &mWorkspace, static_cast(stream)); // If WITH_DBIAS, reduce the workspace partial dbias in CUDA C++ for now. if (config.with_dbias) { @@ -220,26 +229,24 @@ inline bool mxfp8_quantize_cutedsl(const MXFP8QuantConfig &config, template bool mxfp8_quantize_cutedsl(const Tensor *input_tensor, const Tensor *act_input_tensor, - const Tensor *noop_tensor, Tensor *output_tensor, - Tensor *dbias_tensor, Tensor *workspace_tensor, - cudaStream_t stream) { + const Tensor *noop_tensor, Tensor *output_tensor, Tensor *dbias_tensor, + Tensor *workspace_tensor, cudaStream_t stream) { using Fused = MXFP8QuantFused; if constexpr (!Fused::supported) { return false; } else { const bool with_noop = noop_tensor != nullptr && noop_tensor->data.dptr != nullptr; - const MXFP8QuantConfig config{ - /*dtype=*/input_tensor->dtype(), - /*fp8_dtype=*/output_tensor->dtype(), - /*rowwise=*/output_tensor->has_data(), - /*colwise=*/output_tensor->has_columnwise_data(), - /*swizzled=*/output_tensor->with_gemm_swizzled_scales, - /*with_amax=*/output_tensor->amax.dptr != nullptr, - /*with_dbias=*/IS_DBIAS, - /*with_dact=*/IS_DACT, - /*with_act=*/IS_ACT, - /*with_noop=*/with_noop, - /*activation=*/Fused::activation}; + const MXFP8QuantConfig config{/*dtype=*/input_tensor->dtype(), + /*fp8_dtype=*/output_tensor->dtype(), + /*rowwise=*/output_tensor->has_data(), + /*colwise=*/output_tensor->has_columnwise_data(), + /*swizzled=*/output_tensor->with_gemm_swizzled_scales, + /*with_amax=*/output_tensor->amax.dptr != nullptr, + /*with_dbias=*/IS_DBIAS, + /*with_dact=*/IS_DACT, + /*with_act=*/IS_ACT, + /*with_noop=*/with_noop, + /*activation=*/Fused::activation}; return mxfp8_quantize_cutedsl(config, input_tensor, act_input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); } diff --git a/transformer_engine/common/tvm_ffi_bridge.h b/transformer_engine/common/tvm_ffi_bridge.h index 866975eaac..446f045a81 100644 --- a/transformer_engine/common/tvm_ffi_bridge.h +++ b/transformer_engine/common/tvm_ffi_bridge.h @@ -7,6 +7,11 @@ #ifndef TRANSFORMER_ENGINE_COMMON_TVM_FFI_BRIDGE_H_ #define TRANSFORMER_ENGINE_COMMON_TVM_FFI_BRIDGE_H_ +#include +#include +#include +#include + #include #include #include @@ -18,11 +23,6 @@ #include #include -#include -#include -#include -#include - #include "transformer_engine/transformer_engine.h" #include "util/cuda_runtime.h" #include "util/logging.h" @@ -32,12 +32,18 @@ namespace tvm_ffi_bridge { inline const char *te_dtype_to_str(DType dtype) { switch (dtype) { - case DType::kFloat32: return "fp32"; - case DType::kFloat16: return "fp16"; - case DType::kBFloat16: return "bf16"; - case DType::kFloat8E4M3: return "e4m3"; - case DType::kFloat8E5M2: return "e5m2"; - default: return ""; + case DType::kFloat32: + return "fp32"; + case DType::kFloat16: + return "fp16"; + case DType::kBFloat16: + return "bf16"; + case DType::kFloat8E4M3: + return "e4m3"; + case DType::kFloat8E5M2: + return "e5m2"; + default: + return ""; } } @@ -63,34 +69,55 @@ enum class Activation { inline const char *activation_to_str(Activation act) { switch (act) { - case Activation::kReLU: return "relu"; - case Activation::kGeLU: return "gelu"; - case Activation::kSiLU: return "silu"; - case Activation::kQGeLU: return "qgelu"; - case Activation::kSReLU: return "srelu"; - case Activation::kDReLU: return "drelu"; - case Activation::kDGeLU: return "dgelu"; - case Activation::kDSiLU: return "dsilu"; - case Activation::kDQGeLU: return "dqgelu"; - case Activation::kDSReLU: return "dsrelu"; - case Activation::kNone: return "none"; + case Activation::kReLU: + return "relu"; + case Activation::kGeLU: + return "gelu"; + case Activation::kSiLU: + return "silu"; + case Activation::kQGeLU: + return "qgelu"; + case Activation::kSReLU: + return "srelu"; + case Activation::kDReLU: + return "drelu"; + case Activation::kDGeLU: + return "dgelu"; + case Activation::kDSiLU: + return "dsilu"; + case Activation::kDQGeLU: + return "dqgelu"; + case Activation::kDSReLU: + return "dsrelu"; + case Activation::kNone: + return "none"; } return "none"; } inline DLDataType convert_to_dltype(NVTEDType type) { switch (type) { - case kNVTEFloat32: return DLDataType{kDLFloat, 32, 1}; - case kNVTEFloat16: return DLDataType{kDLFloat, 16, 1}; - case kNVTEBFloat16: return DLDataType{kDLBfloat, 16, 1}; - case kNVTEByte: return DLDataType{kDLUInt, 8, 1}; - case kNVTEInt32: return DLDataType{kDLInt, 32, 1}; - case kNVTEInt64: return DLDataType{kDLInt, 64, 1}; + case kNVTEFloat32: + return DLDataType{kDLFloat, 32, 1}; + case kNVTEFloat16: + return DLDataType{kDLFloat, 16, 1}; + case kNVTEBFloat16: + return DLDataType{kDLBfloat, 16, 1}; + case kNVTEByte: + return DLDataType{kDLUInt, 8, 1}; + case kNVTEInt32: + return DLDataType{kDLInt, 32, 1}; + case kNVTEInt64: + return DLDataType{kDLInt, 64, 1}; // FP8 / E8M0 → raw 1-byte uint; the kernel interprets the bits. - case kNVTEFloat8E4M3: return DLDataType{kDLUInt, 8, 1}; - case kNVTEFloat8E5M2: return DLDataType{kDLUInt, 8, 1}; - case kNVTEFloat8E8M0: return DLDataType{kDLUInt, 8, 1}; - default: NVTE_ERROR("unsupported NVTEDType: ", static_cast(type)); + case kNVTEFloat8E4M3: + return DLDataType{kDLUInt, 8, 1}; + case kNVTEFloat8E5M2: + return DLDataType{kDLUInt, 8, 1}; + case kNVTEFloat8E8M0: + return DLDataType{kDLUInt, 8, 1}; + default: + NVTE_ERROR("unsupported NVTEDType: ", static_cast(type)); } } @@ -106,27 +133,29 @@ class DLTensorWrapper : public DLTensor { int64_t flat_first = 1; for (int i = 0; i + 1 < n; ++i) flat_first *= static_cast(tensor.shape.data[i]); const int64_t flat_last = static_cast(tensor.shape.data[n - 1]); - shape_buf_ = std::make_unique(2); + shape_buf_ = std::make_unique(2); strides_buf_ = std::make_unique(2); - shape_buf_[0] = flat_first; shape_buf_[1] = flat_last; - strides_buf_[0] = flat_last; strides_buf_[1] = 1; + shape_buf_[0] = flat_first; + shape_buf_[1] = flat_last; + strides_buf_[0] = flat_last; + strides_buf_[1] = 1; this->ndim = 2; } else { - shape_buf_ = std::make_unique(n); + shape_buf_ = std::make_unique(n); strides_buf_ = std::make_unique(n); int64_t stride = 1; for (int i = n - 1; i >= 0; --i) { - shape_buf_[i] = static_cast(tensor.shape.data[i]); + shape_buf_[i] = static_cast(tensor.shape.data[i]); strides_buf_[i] = stride; stride *= shape_buf_[i]; } this->ndim = n; } - this->data = tensor.data_ptr; - this->device = DLDevice{kDLCUDA, device_index}; - this->dtype = convert_to_dltype(tensor.dtype); - this->shape = shape_buf_.get(); - this->strides = strides_buf_.get(); + this->data = tensor.data_ptr; + this->device = DLDevice{kDLCUDA, device_index}; + this->dtype = convert_to_dltype(tensor.dtype); + this->shape = shape_buf_.get(); + this->strides = strides_buf_.get(); this->byte_offset = 0; } @@ -157,9 +186,8 @@ namespace ffi { template <> struct TypeTraits : public TypeTraits { - TVM_FFI_INLINE static void CopyToAnyView( - transformer_engine::tvm_ffi_bridge::DLTensorWrapper *src, TVMFFIAny *result - ) { + TVM_FFI_INLINE static void CopyToAnyView(transformer_engine::tvm_ffi_bridge::DLTensorWrapper *src, + TVMFFIAny *result) { if (src == nullptr || src->data == nullptr) { TypeTraits::CopyToAnyView(nullptr, result); // -> TVM-FFI None } else { @@ -191,7 +219,6 @@ struct is_lazyloadable_config< std::declval()))>> : std::true_type {}; } // namespace detail - class TVMFFICentral { public: static TVMFFICentral &getInstance() { @@ -246,8 +273,9 @@ class TVMFFICentral { private: ~TVMFFICentral() = default; - TVMFFICentral() : cutedsl_backend_enabled_(is_cutedsl_backend_enabled()), - warn_cutedsl_backend_not_chosen_(warn_if_cutedsl_backend_not_chosen()) {} + TVMFFICentral() + : cutedsl_backend_enabled_(is_cutedsl_backend_enabled()), + warn_cutedsl_backend_not_chosen_(warn_if_cutedsl_backend_not_chosen()) {} TVMFFICentral(const TVMFFICentral &) = delete; TVMFFICentral &operator=(const TVMFFICentral &) = delete; TVMFFICentral(TVMFFICentral &&) = delete; @@ -258,7 +286,7 @@ class TVMFFICentral { const char *flag = std::getenv("NVTE_ENABLE_CUTEDSL_QUANT_BACKEND"); return flag != nullptr && flag[0] != '0'; } - + static bool warn_if_cutedsl_backend_not_chosen() { const char *flag = std::getenv("NVTE_WARN_IF_CUTEDSL_BACKEND_NOT_CHOSEN"); return flag != nullptr && flag[0] != '0'; From e00e802280335250774366f237595e942bb88da9 Mon Sep 17 00:00:00 2001 From: Kaining Zhong Date: Sat, 27 Jun 2026 08:31:32 +0000 Subject: [PATCH 21/22] add test --- .../mxfp8/test_mxfp8_cutedsl_backend.py | 402 ++++++++++++++++++ 1 file changed, 402 insertions(+) create mode 100644 tests/pytorch/mxfp8/test_mxfp8_cutedsl_backend.py diff --git a/tests/pytorch/mxfp8/test_mxfp8_cutedsl_backend.py b/tests/pytorch/mxfp8/test_mxfp8_cutedsl_backend.py new file mode 100644 index 0000000000..ff6862333f --- /dev/null +++ b/tests/pytorch/mxfp8/test_mxfp8_cutedsl_backend.py @@ -0,0 +1,402 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Recompile-robustness tests for the CuTeDSL MXFP8 quantize backend. + +The CuTeDSL backend JIT-compiles one kernel per distinct config (input dtype × +fp8 format × direction × activation × dbias × swizzle), registers it in the +TVM-FFI global registry under a config key, and fetches it per call. These tests +stress that compile/cache machinery rather than numerics: + + * many distinct configs each compile and produce finite, correct output; + * interleaving configs never clobbers a cached kernel (the right kernel is + served for each key, regardless of what else was compiled); + * a single symbolic-shape kernel handles many (M, N) shapes from one compile; + * repeated calls are bit-for-bit deterministic. + +This is backend-specific, so it only runs when the CuTeDSL MXFP8 backend is +actually active in the process. Run it with:: + + NVTE_ENABLE_CUTEDSL_QUANT_BACKEND=1 CUTE_DSL_ARCH=sm_100a \\ + python -m pytest tests/pytorch/mxfp8/test_mxfp8_cutedsl_recompile.py + +otherwise every test skips (the env var is read once, at the first quantize). +""" +# TODO: review this file + +import pytest +import torch +import torch.nn.functional as F + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import MXFP8Quantizer + +recipe_available, reason_for_no_recipe = te.is_mxfp8_available(return_reason=True) + +_FP8 = {"e4m3": tex.DType.kFloat8E4M3, "e5m2": tex.DType.kFloat8E5M2} +_DT = {"bf16": torch.bfloat16, "fp16": torch.float16} +_FWD = {"plain", "gelu", "silu", "relu", "qgelu", "srelu"} +_FWD_FN = {"gelu": tex.gelu, "silu": tex.silu, "relu": tex.relu, + "qgelu": tex.qgelu, "srelu": tex.srelu} +_DACT_FN = {"dgelu": tex.dgelu, "dsilu": tex.dsilu, "drelu": tex.drelu, + "dqgelu": tex.dqgelu, "dsrelu": tex.dsrelu} +_DBIAS_DACT_FN = {f"dbias_{k}": getattr(tex, f"dbias_{k}") + for k in ("dgelu", "dsilu", "drelu", "dqgelu", "dsrelu")} + +# A diverse set of configs to interleave/repeat: mixed dtypes, fp8 formats, +# directions, and the plain / forward-act / dact / dbias / dbias+dact families. +_CONFIGS = [ + # (combo, rowwise, columnwise, in_dtype, fp8) + ("plain", True, True, "bf16", "e4m3"), + ("plain", True, False, "bf16", "e4m3"), + ("plain", False, True, "bf16", "e4m3"), + ("plain", True, True, "bf16", "e5m2"), + ("plain", True, True, "fp16", "e4m3"), + ("gelu", True, True, "bf16", "e4m3"), + ("relu", True, True, "bf16", "e4m3"), + ("silu", True, True, "bf16", "e4m3"), + ("dgelu", True, True, "bf16", "e4m3"), + ("dbias", True, True, "bf16", "e4m3"), + ("dbias_dsilu", True, True, "bf16", "e4m3"), + ("dbias_dqgelu", True, False, "bf16", "e4m3"), +] + + +def _inputs(M, N, in_dtype, seed=0): + g = torch.Generator(device="cuda").manual_seed(seed) + dt = _DT[in_dtype] + x = torch.empty(M, N, dtype=dt, device="cuda").uniform_(-4.0, 4.0, generator=g) + ain = torch.empty(M, N, dtype=dt, device="cuda").uniform_(-3.0, 3.0, generator=g) + return x, ain + + +def _run(combo, x, ain, rowwise, columnwise, fp8, swizzle=False): + """Quantize via the public dispatch; returns (mxfp8_tensor, dbias_or_None). + + swizzle=True requests cuBLAS-swizzled scale layout (optimize_for_gemm).""" + q = MXFP8Quantizer(fp8_dtype=_FP8[fp8], rowwise=rowwise, columnwise=columnwise) + if swizzle: + q.optimize_for_gemm = True + if combo == "plain": + return q(x), None + if combo in _FWD_FN: + return _FWD_FN[combo](x, q), None + if combo in _DACT_FN: + return _DACT_FN[combo](x, ain, q), None + if combo == "dbias": + db, out = tex.bgrad_quantize(x, q) + return out, db + if combo in _DBIAS_DACT_FN: + db, out = _DBIAS_DACT_FN[combo](x, ain, q) + return out, db + raise ValueError(f"unknown combo {combo!r}") + + +def _signature(out, db, rowwise, columnwise): + """Bit-level fingerprint of a quantized result, for golden comparison. + + The scale tensors are allocated at a 128-padded shape; only the meaningful + region is written by the kernel, so we slice to it (M, ceil(N/32)) rowwise / + (ceil(M/32), N) columnwise). Comparing the padding would spuriously fail — + it's uninitialized and reflects whatever was in the (dirty) allocator pool. + M, N are read from the data tensor, which is exactly (M, N), unpadded.""" + parts = [] + if rowwise: + data = out._rowwise_data.view(torch.uint8) + M, N = data.shape + parts += [data.clone(), + out._rowwise_scale_inv[:M, :(N + 31) // 32].clone()] + if columnwise: + data = out._columnwise_data.view(torch.uint8) + M, N = data.shape + parts += [data.clone(), + out._columnwise_scale_inv[:(M + 31) // 32, :N].clone()] + if db is not None: + parts.append(db.clone()) + return parts + + +def _sig_equal(a, b): + return len(a) == len(b) and all(torch.equal(p, q) for p, q in zip(a, b)) + + +def _ref_fwd(combo, xf): + if combo == "plain": + return xf + if combo == "gelu": + return F.gelu(xf, approximate="tanh") + if combo == "silu": + return F.silu(xf) + if combo == "relu": + return F.relu(xf) + if combo == "qgelu": + return xf * torch.sigmoid(1.702 * xf) + if combo == "srelu": + return F.relu(xf) ** 2 + raise ValueError(combo) + + +@pytest.fixture(scope="module", autouse=True) +def _require_active_cutedsl_backend(): + """Skip unless the CuTeDSL backend is actually active (it registers its kernel + under a config key in the TVM-FFI registry on first use).""" + if not recipe_available: + pytest.skip(reason_for_no_recipe) + # Trigger one quantize, then confirm the CuTeDSL kernel registered itself. + x = torch.randn(64, 64, dtype=torch.bfloat16, device="cuda") + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True)(x) + active = False + try: + import tvm_ffi + active = tvm_ffi.get_global_func( + "cutedsl_mxfp8_bf16_e4m3_1_1_0_0_0_0_0_0_none", allow_missing=True + ) is not None + except Exception: + active = False + if not active: + pytest.skip( + "CuTeDSL MXFP8 backend not active in this process; run with " + "NVTE_ENABLE_CUTEDSL_QUANT_BACKEND=1 set before the first quantize." + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +def test_interleaved_configs_do_not_clobber_each_other(): + """Compile + capture a golden output for every config, then re-run them all in + reverse order. Each must reproduce its golden bit-for-bit — compiling/running + other configs must never corrupt a cached kernel or serve the wrong one.""" + M, N = 256, 512 + golden = {} + for combo, rw, cw, dt, fp8 in _CONFIGS: + x, ain = _inputs(M, N, dt) + out, db = _run(combo, x, ain, rw, cw, fp8) + golden[(combo, rw, cw, dt, fp8)] = _signature(out, db, rw, cw) + + for cfg in reversed(_CONFIGS): + combo, rw, cw, dt, fp8 = cfg + x, ain = _inputs(M, N, dt) + out, db = _run(combo, x, ain, rw, cw, fp8) + assert _sig_equal(_signature(out, db, rw, cw), golden[cfg]), ( + f"config {cfg} produced different output after other configs were " + f"(re)compiled — cached kernel was clobbered or mis-keyed" + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +def test_cached_kernel_stable_while_new_configs_compile(): + """A fixed probe config run once for golden, then re-run after compiling each + other config. The probe output must never change — a newly compiled kernel + must not evict or overwrite the probe's cached kernel.""" + M, N = 320, 640 + p_combo, p_rw, p_cw, p_dt, p_fp8 = ("gelu", True, True, "bf16", "e4m3") + px, pain = _inputs(M, N, p_dt) + out, db = _run(p_combo, px, pain, p_rw, p_cw, p_fp8) + golden = _signature(out, db, p_rw, p_cw) + + for combo, rw, cw, dt, fp8 in _CONFIGS: + x, ain = _inputs(M, N, dt) + _run(combo, x, ain, rw, cw, fp8) # (re)compile / run another config + out, db = _run(p_combo, px, pain, p_rw, p_cw, p_fp8) + assert _sig_equal(_signature(out, db, p_rw, p_cw), golden), ( + f"probe ({p_combo}) output changed after running config {combo!r}" + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("combo", ["plain", "gelu", "dbias_dsilu"]) +def test_one_symbolic_kernel_handles_many_shapes(combo): + """The kernel is compiled once with symbolic (M, N) (divisible by 32). Feeding + many shapes through that single compile must all give finite output (and, for + forward combos, output close to the reference).""" + rw = cw = True + fp8, dt = "e4m3", "bf16" + shapes = [(32, 32), (64, 64), (32, 2048), (2048, 32), + (256, 512), (1024, 1536), (2048, 2048)] + for M, N in shapes: + x, ain = _inputs(M, N, dt) + out, _ = _run(combo, x, ain, rw, cw, fp8) + deq = out.dequantize(dtype=torch.float32) + assert torch.isfinite(deq).all(), f"{combo} {M}x{N}: non-finite output" + if combo in _FWD: + ref = _ref_fwd(combo, x.float()) + rel = (deq - ref).norm() / ref.norm().clamp_min(1e-6) + assert rel < 0.12, f"{combo} {M}x{N}: rel_err={rel:.4f}" + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +def test_repeated_calls_are_deterministic(): + """The same config + same input, called repeatedly, must be bit-for-bit + identical (the cached kernel is stable across reuse).""" + M, N = 256, 512 + for combo, rw, cw, dt, fp8 in _CONFIGS: + x, ain = _inputs(M, N, dt) + sigs = [_signature(*_run(combo, x, ain, rw, cw, fp8), rw, cw) + for _ in range(4)] + for i in range(1, len(sigs)): + assert _sig_equal(sigs[i], sigs[0]), ( + f"config ({combo},{dt},{fp8}) call {i} differs from call 0" + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("direction", ["both", "row"]) +@pytest.mark.parametrize("fp8", ["e4m3", "e5m2"]) +@pytest.mark.parametrize("combo", ["plain", "gelu", "relu"]) +def test_distinct_configs_compile_and_are_correct(combo, fp8, direction): + """Each distinct (combo, fp8, direction) is its own compile. Verify it produces + finite output close to the reference — i.e. a freshly compiled kernel is not + garbage (catches e.g. a write-index regression in a recompiled kernel).""" + rw = direction in ("both", "row") + cw = direction in ("both", "col") + M, N = 256, 512 + x, _ = _inputs(M, N, "bf16") + out, _ = _run(combo, x, None, rw, cw, fp8) + deq = out.dequantize(dtype=torch.float32) + assert torch.isfinite(deq).all() + ref = _ref_fwd(combo, x.float()) + rel = (deq - ref).norm() / ref.norm().clamp_min(1e-6) + tol = 0.12 if fp8 == "e4m3" else 0.30 + assert rel < tol, f"{combo}/{fp8}/{direction}: rel_err={rel:.4f}" + +# --------------------------------------------------------------------------- +# Numerical parity vs an fp32 reference, mirroring tests/cpp/operator/ +# test_cast_mxfp8.cu. The C++ gtests never exercise the CuTeDSL backend (it is +# registered from Python), so this re-runs the C++ methodology with the backend +# active. Same case selection as the C++ test: +# * ops: GeLU family only (the C++ test has SiLU/ReLU/QGeLU/SReLU commented +# out) -> CAST_ONLY=plain, CAST_DBIAS=dbias, CAST_ACT=gelu, CAST_DACT=dgelu, +# CAST_DBIAS_DACT=dbias_dgelu +# * direction = the C++ block_size: {1,32}=row, {32,1}=col, {32,32}=both +# * three orthogonal sweeps (the C++ INSTANTIATE_TEST_SUITE_P blocks) instead +# of one giant cross product +# * no swizzle (the C++ cast test doesn't cover it; see +# test_mxfp8_quantize_swizzle_fusion for the swizzled layout) +# Comparison also mirrors the C++ test: e8m0 scales bit-exact (zero tolerance, +# valid where the reference value matches the kernel input exactly, i.e. the +# no-activation ops), FP8 data within fp8 atol/rtol, dbias relaxed. Activation +# ops use a relative-error bound instead of bit-exact scales because the torch +# reference activation isn't bit-identical to TE's device activation (the C++ +# test gets bit-exactness only by reusing TE's own host activation). +_PT_DT = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32} +_FP8_MAX_RCP = {"e4m3": 1.0 / 448.0, "e5m2": 1.0 / 57344.0} +_FP8_T = {"e4m3": torch.float8_e4m3fn, "e5m2": torch.float8_e5m2} + +# Shapes: %32 (so the CuTeDSL backend handles them; non-%32 falls back to CUDA), +# mixing %128 and %32-not-%128 (the kernels' partial-tile / OOB edge). +_PARITY_SHAPES = [(128, 128), (256, 1024), (512, 512), (160, 160), (128, 1056), (256, 384)] +# (op for _run, is_activation): the GeLU-family ProcessingMethods. +_CPP_OPS = [("plain", False), ("dbias", False), + ("gelu", True), ("dgelu", True), ("dbias_dgelu", True)] +_CPP_DIRECTIONS = ["row", "col", "both"] + + +def _parity_inputs(M, N, in_dtype, seed=0): + g = torch.Generator(device="cuda").manual_seed(seed) + dt = _PT_DT[in_dtype] + x = torch.empty(M, N, dtype=dt, device="cuda").uniform_(-4.0, 4.0, generator=g) + ain = torch.empty(M, N, dtype=dt, device="cuda").uniform_(-3.0, 3.0, generator=g) + return x, ain + + +def _ref_value(op, x, ain): + """fp32 reference of the (pre-quantization) tensor the kernel quantizes.""" + xf = x.float() + if op in ("plain", "dbias"): + return xf + if op == "gelu": + return _ref_fwd("gelu", xf) + # dgelu / dbias_dgelu: grad (x) * d(gelu)/d(input), via autograd of the + # matching forward so it tracks TE's exact tanh-gelu derivative. + av = ain.float().detach().requires_grad_(True) + _ref_fwd("gelu", av).backward(xf) + return av.grad + + +def _ref_e8m0(amax_rcp): + """fp32 (amax * max_reciprocal) -> e8m0 scale byte. Round-up of the biased + exponent; bit-identical to the Blackwell cvt.rp.ue8m0x2 the kernel uses.""" + bits = amax_rcp.contiguous().view(torch.int32) + exp = ((bits + 0x7FFFFF) >> 23) & 0xFF + return exp.clamp(max=254).to(torch.uint8) + + +def _kernel_scale_data(out, d, M, N, fp8): + """(e8m0 scales, dequantized output) for the meaningful region, direction d.""" + if d == "row": + sc = out._rowwise_scale_inv[:M, : (N + 31) // 32] + data = out._rowwise_data.view(_FP8_T[fp8]).float() + deq = data * torch.exp2(sc.float() - 127.0).repeat_interleave(32, dim=1)[:, :N] + else: + sc = out._columnwise_scale_inv[: (M + 31) // 32, :N] + data = out._columnwise_data.view(_FP8_T[fp8]).float() + deq = data * torch.exp2(sc.float() - 127.0).repeat_interleave(32, dim=0)[:M, :] + return sc, deq + + +def _ref_scales(v, d, fp8): + """Reference e8m0 scales for value v (fp32), direction d.""" + M, N = v.shape + if d == "row": + amax = v.reshape(M, N // 32, 32).abs().amax(-1) # (M, N//32) + else: + amax = v.reshape(M // 32, 32, N).abs().amax(1) # (M//32, N) + return _ref_e8m0(amax * _FP8_MAX_RCP[fp8]) + + +def _check_parity(op, is_act, direction, M, N, in_dtype, fp8): + rw = direction in ("row", "both") + cw = direction in ("col", "both") + x, ain = _parity_inputs(M, N, in_dtype) + v = _ref_value(op, x, ain) + out, db = _run(op, x, ain, rw, cw, fp8) + + tol = 0.12 if fp8 == "e4m3" else 0.30 + for d in (["row"] if rw else []) + (["col"] if cw else []): + sc, deq = _kernel_scale_data(out, d, M, N, fp8) + assert torch.isfinite(deq).all(), f"{op}/{d}/{fp8}/{in_dtype} {M}x{N}: non-finite" + # Data: dequant within MXFP8 granularity (the C++ fp8 atol/rtol bar). + rel = (deq - v).norm() / v.norm().clamp_min(1e-6) + assert rel < tol, f"{op}/{d}/{fp8}/{in_dtype} {M}x{N}: rel_err={rel:.4f}" + # Scales: bit-exact vs the fp32 reference (C++ zero-tolerance) — only for + # no-activation ops, where the reference value equals the kernel input. + if not is_act: + assert torch.equal(sc, _ref_scales(v, d, fp8)), \ + f"{op}/{d}/{fp8}/{in_dtype} {M}x{N}: e8m0 scales differ from reference" + + if db is not None: + dref = v.sum(dim=0) + drel = (db.float() - dref).norm() / dref.norm().clamp_min(1e-6) + assert drel < 0.1, f"{op}/{in_dtype} {M}x{N}: dbias rel_err={drel:.4f}" + + +# Sweep 1 — CAST_ONLY across all shapes/directions/dtypes/formats +# (C++ OperatorTest_FusedCastMXFP8_CastOnly). +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("shape", _PARITY_SHAPES, ids=lambda s: f"{s[0]}x{s[1]}") +@pytest.mark.parametrize("in_dtype", ["bf16", "fp16", "fp32"]) +@pytest.mark.parametrize("fp8", ["e4m3", "e5m2"]) +@pytest.mark.parametrize("direction", _CPP_DIRECTIONS) +def test_parity_cast_only(direction, fp8, in_dtype, shape): + _check_parity("plain", False, direction, *shape, in_dtype, fp8) + + +# Sweep 2 — all ops/directions/shapes at bf16/e4m3 +# (C++ OperatorTest_FusedCastMXFP8_Sizes). +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("shape", _PARITY_SHAPES, ids=lambda s: f"{s[0]}x{s[1]}") +@pytest.mark.parametrize("direction", _CPP_DIRECTIONS) +@pytest.mark.parametrize("op,is_act", _CPP_OPS, ids=[o for o, _ in _CPP_OPS]) +def test_parity_ops_and_sizes(op, is_act, direction, shape): + _check_parity(op, is_act, direction, *shape, "bf16", "e4m3") + + +# Sweep 3 — all ops/dtypes/formats at a fixed both-direction shape +# (C++ OperatorTest_FusedCastMXFP8_Dtypes, {256,384}, block {32,32}). +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("in_dtype", ["bf16", "fp16", "fp32"]) +@pytest.mark.parametrize("fp8", ["e4m3", "e5m2"]) +@pytest.mark.parametrize("op,is_act", _CPP_OPS, ids=[o for o, _ in _CPP_OPS]) +def test_parity_dtypes(op, is_act, fp8, in_dtype): + _check_parity(op, is_act, "both", 256, 384, in_dtype, fp8) \ No newline at end of file From 97e801ce27ca6733a66200121e7c58f07e0643c4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 27 Jun 2026 08:33:23 +0000 Subject: [PATCH 22/22] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../mxfp8/test_mxfp8_cutedsl_backend.py | 105 +++++++++++------- 1 file changed, 62 insertions(+), 43 deletions(-) diff --git a/tests/pytorch/mxfp8/test_mxfp8_cutedsl_backend.py b/tests/pytorch/mxfp8/test_mxfp8_cutedsl_backend.py index ff6862333f..4e31643b05 100644 --- a/tests/pytorch/mxfp8/test_mxfp8_cutedsl_backend.py +++ b/tests/pytorch/mxfp8/test_mxfp8_cutedsl_backend.py @@ -38,29 +38,41 @@ _FP8 = {"e4m3": tex.DType.kFloat8E4M3, "e5m2": tex.DType.kFloat8E5M2} _DT = {"bf16": torch.bfloat16, "fp16": torch.float16} _FWD = {"plain", "gelu", "silu", "relu", "qgelu", "srelu"} -_FWD_FN = {"gelu": tex.gelu, "silu": tex.silu, "relu": tex.relu, - "qgelu": tex.qgelu, "srelu": tex.srelu} -_DACT_FN = {"dgelu": tex.dgelu, "dsilu": tex.dsilu, "drelu": tex.drelu, - "dqgelu": tex.dqgelu, "dsrelu": tex.dsrelu} -_DBIAS_DACT_FN = {f"dbias_{k}": getattr(tex, f"dbias_{k}") - for k in ("dgelu", "dsilu", "drelu", "dqgelu", "dsrelu")} +_FWD_FN = { + "gelu": tex.gelu, + "silu": tex.silu, + "relu": tex.relu, + "qgelu": tex.qgelu, + "srelu": tex.srelu, +} +_DACT_FN = { + "dgelu": tex.dgelu, + "dsilu": tex.dsilu, + "drelu": tex.drelu, + "dqgelu": tex.dqgelu, + "dsrelu": tex.dsrelu, +} +_DBIAS_DACT_FN = { + f"dbias_{k}": getattr(tex, f"dbias_{k}") + for k in ("dgelu", "dsilu", "drelu", "dqgelu", "dsrelu") +} # A diverse set of configs to interleave/repeat: mixed dtypes, fp8 formats, # directions, and the plain / forward-act / dact / dbias / dbias+dact families. _CONFIGS = [ # (combo, rowwise, columnwise, in_dtype, fp8) - ("plain", True, True, "bf16", "e4m3"), - ("plain", True, False, "bf16", "e4m3"), - ("plain", False, True, "bf16", "e4m3"), - ("plain", True, True, "bf16", "e5m2"), - ("plain", True, True, "fp16", "e4m3"), - ("gelu", True, True, "bf16", "e4m3"), - ("relu", True, True, "bf16", "e4m3"), - ("silu", True, True, "bf16", "e4m3"), - ("dgelu", True, True, "bf16", "e4m3"), - ("dbias", True, True, "bf16", "e4m3"), - ("dbias_dsilu", True, True, "bf16", "e4m3"), - ("dbias_dqgelu", True, False, "bf16", "e4m3"), + ("plain", True, True, "bf16", "e4m3"), + ("plain", True, False, "bf16", "e4m3"), + ("plain", False, True, "bf16", "e4m3"), + ("plain", True, True, "bf16", "e5m2"), + ("plain", True, True, "fp16", "e4m3"), + ("gelu", True, True, "bf16", "e4m3"), + ("relu", True, True, "bf16", "e4m3"), + ("silu", True, True, "bf16", "e4m3"), + ("dgelu", True, True, "bf16", "e4m3"), + ("dbias", True, True, "bf16", "e4m3"), + ("dbias_dsilu", True, True, "bf16", "e4m3"), + ("dbias_dqgelu", True, False, "bf16", "e4m3"), ] @@ -106,13 +118,11 @@ def _signature(out, db, rowwise, columnwise): if rowwise: data = out._rowwise_data.view(torch.uint8) M, N = data.shape - parts += [data.clone(), - out._rowwise_scale_inv[:M, :(N + 31) // 32].clone()] + parts += [data.clone(), out._rowwise_scale_inv[:M, : (N + 31) // 32].clone()] if columnwise: data = out._columnwise_data.view(torch.uint8) M, N = data.shape - parts += [data.clone(), - out._columnwise_scale_inv[:(M + 31) // 32, :N].clone()] + parts += [data.clone(), out._columnwise_scale_inv[: (M + 31) // 32, :N].clone()] if db is not None: parts.append(db.clone()) return parts @@ -150,9 +160,13 @@ def _require_active_cutedsl_backend(): active = False try: import tvm_ffi - active = tvm_ffi.get_global_func( - "cutedsl_mxfp8_bf16_e4m3_1_1_0_0_0_0_0_0_none", allow_missing=True - ) is not None + + active = ( + tvm_ffi.get_global_func( + "cutedsl_mxfp8_bf16_e4m3_1_1_0_0_0_0_0_0_none", allow_missing=True + ) + is not None + ) except Exception: active = False if not active: @@ -180,7 +194,7 @@ def test_interleaved_configs_do_not_clobber_each_other(): out, db = _run(combo, x, ain, rw, cw, fp8) assert _sig_equal(_signature(out, db, rw, cw), golden[cfg]), ( f"config {cfg} produced different output after other configs were " - f"(re)compiled — cached kernel was clobbered or mis-keyed" + "(re)compiled — cached kernel was clobbered or mis-keyed" ) @@ -199,9 +213,9 @@ def test_cached_kernel_stable_while_new_configs_compile(): x, ain = _inputs(M, N, dt) _run(combo, x, ain, rw, cw, fp8) # (re)compile / run another config out, db = _run(p_combo, px, pain, p_rw, p_cw, p_fp8) - assert _sig_equal(_signature(out, db, p_rw, p_cw), golden), ( - f"probe ({p_combo}) output changed after running config {combo!r}" - ) + assert _sig_equal( + _signature(out, db, p_rw, p_cw), golden + ), f"probe ({p_combo}) output changed after running config {combo!r}" @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @@ -212,8 +226,7 @@ def test_one_symbolic_kernel_handles_many_shapes(combo): forward combos, output close to the reference).""" rw = cw = True fp8, dt = "e4m3", "bf16" - shapes = [(32, 32), (64, 64), (32, 2048), (2048, 32), - (256, 512), (1024, 1536), (2048, 2048)] + shapes = [(32, 32), (64, 64), (32, 2048), (2048, 32), (256, 512), (1024, 1536), (2048, 2048)] for M, N in shapes: x, ain = _inputs(M, N, dt) out, _ = _run(combo, x, ain, rw, cw, fp8) @@ -232,12 +245,11 @@ def test_repeated_calls_are_deterministic(): M, N = 256, 512 for combo, rw, cw, dt, fp8 in _CONFIGS: x, ain = _inputs(M, N, dt) - sigs = [_signature(*_run(combo, x, ain, rw, cw, fp8), rw, cw) - for _ in range(4)] + sigs = [_signature(*_run(combo, x, ain, rw, cw, fp8), rw, cw) for _ in range(4)] for i in range(1, len(sigs)): - assert _sig_equal(sigs[i], sigs[0]), ( - f"config ({combo},{dt},{fp8}) call {i} differs from call 0" - ) + assert _sig_equal( + sigs[i], sigs[0] + ), f"config ({combo},{dt},{fp8}) call {i} differs from call 0" @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @@ -260,6 +272,7 @@ def test_distinct_configs_compile_and_are_correct(combo, fp8, direction): tol = 0.12 if fp8 == "e4m3" else 0.30 assert rel < tol, f"{combo}/{fp8}/{direction}: rel_err={rel:.4f}" + # --------------------------------------------------------------------------- # Numerical parity vs an fp32 reference, mirroring tests/cpp/operator/ # test_cast_mxfp8.cu. The C++ gtests never exercise the CuTeDSL backend (it is @@ -287,8 +300,13 @@ def test_distinct_configs_compile_and_are_correct(combo, fp8, direction): # mixing %128 and %32-not-%128 (the kernels' partial-tile / OOB edge). _PARITY_SHAPES = [(128, 128), (256, 1024), (512, 512), (160, 160), (128, 1056), (256, 384)] # (op for _run, is_activation): the GeLU-family ProcessingMethods. -_CPP_OPS = [("plain", False), ("dbias", False), - ("gelu", True), ("dgelu", True), ("dbias_dgelu", True)] +_CPP_OPS = [ + ("plain", False), + ("dbias", False), + ("gelu", True), + ("dgelu", True), + ("dbias_dgelu", True), +] _CPP_DIRECTIONS = ["row", "col", "both"] @@ -339,9 +357,9 @@ def _ref_scales(v, d, fp8): """Reference e8m0 scales for value v (fp32), direction d.""" M, N = v.shape if d == "row": - amax = v.reshape(M, N // 32, 32).abs().amax(-1) # (M, N//32) + amax = v.reshape(M, N // 32, 32).abs().amax(-1) # (M, N//32) else: - amax = v.reshape(M // 32, 32, N).abs().amax(1) # (M//32, N) + amax = v.reshape(M // 32, 32, N).abs().amax(1) # (M//32, N) return _ref_e8m0(amax * _FP8_MAX_RCP[fp8]) @@ -362,8 +380,9 @@ def _check_parity(op, is_act, direction, M, N, in_dtype, fp8): # Scales: bit-exact vs the fp32 reference (C++ zero-tolerance) — only for # no-activation ops, where the reference value equals the kernel input. if not is_act: - assert torch.equal(sc, _ref_scales(v, d, fp8)), \ - f"{op}/{d}/{fp8}/{in_dtype} {M}x{N}: e8m0 scales differ from reference" + assert torch.equal( + sc, _ref_scales(v, d, fp8) + ), f"{op}/{d}/{fp8}/{in_dtype} {M}x{N}: e8m0 scales differ from reference" if db is not None: dref = v.sum(dim=0) @@ -399,4 +418,4 @@ def test_parity_ops_and_sizes(op, is_act, direction, shape): @pytest.mark.parametrize("fp8", ["e4m3", "e5m2"]) @pytest.mark.parametrize("op,is_act", _CPP_OPS, ids=[o for o, _ in _CPP_OPS]) def test_parity_dtypes(op, is_act, fp8, in_dtype): - _check_parity(op, is_act, "both", 256, 384, in_dtype, fp8) \ No newline at end of file + _check_parity(op, is_act, "both", 256, 384, in_dtype, fp8)