From 28149875096054ab54dfd9ff5d06c85a5bf79287 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 16 Jan 2026 17:39:35 +0000 Subject: [PATCH 01/15] Generalize implementation to support different types --- .../38_block_scale_gemm/CMakeLists.txt | 2 + .../gemm_bquant_quantgrouped_bf16mxbf16.cpp | 35 +++++ .../gemm_bquant_quantgrouped_bf16mxbf8.cpp | 35 +++++ .../gemm_bquant_quantgrouped_bf16mxfp4.cpp | 12 +- .../38_block_scale_gemm/gemm_quant.cpp | 1 - .../38_block_scale_gemm/gemm_utils.hpp | 2 +- .../run_gemm_quant_example.inc | 52 ++++--- include/ck_tile/core/numeric/type_convert.hpp | 6 + .../ck_tile/host/reference/reference_gemm.hpp | 35 +++-- include/ck_tile/ops/common/utils.hpp | 1 + .../ops/epilogue/cshuffle_epilogue.hpp | 10 +- .../block/block_universal_gemm_as_bs_cr.hpp | 27 +++- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 4 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 27 ++-- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 6 +- .../gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp | 2 +- .../gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp | 132 ++++++++---------- 17 files changed, 247 insertions(+), 142 deletions(-) create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf16.cpp create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf8.cpp diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 13cbcc8b558..40e3e91ac03 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -19,6 +19,8 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_bquant_quantgrouped_bf8i4.cpp gemm_bquant_quantgrouped_fp8i4.cpp gemm_bquant_quantgrouped_bf16mxfp4.cpp + gemm_bquant_quantgrouped_bf16mxbf8.cpp + gemm_bquant_quantgrouped_bf16mxbf16.cpp gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf16.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf16.cpp new file mode 100644 index 00000000000..f064c4a8e6d --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf16.cpp @@ -0,0 +1,35 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#include "run_gemm_quant_example.inc"
+
+template <typename PrecType>
+using GemmConfig = GemmConfigQuantPrefill<PrecType>;
+
+#define RUN_GEMM_EXAMPLE_PREC_TYPE \
+    run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, \
+    TypeConfig, \
+    QuantGroupSize, \
+    ck_tile::QuantType::BQuantGrouped>(arg_parser);
+
+void bquant_quantgrouped_mx_bf16bf16_instance_factory(
+    std::unordered_map<std::size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
+{
+    using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
+                                                    ck_tile::bf16_t,
+                                                    ck_tile::bf16_t,
+                                                    ck_tile::e8m0_t>{});
+
+    lut[hash_multiple_strings(
+        {"bf16mxbf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] =
+        [](const ck_tile::ArgParser& arg_parser) {
+            using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
+            return RUN_GEMM_EXAMPLE_PREC_TYPE;
+        };
+    lut[hash_multiple_strings(
+        {"bf16mxbf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
+        [](const ck_tile::ArgParser& arg_parser) {
+            using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
+            return RUN_GEMM_EXAMPLE_PREC_TYPE;
+        };
+}
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf8.cpp
new file mode 100644
index 00000000000..30d65edd474
--- /dev/null
+++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf8.cpp
@@ -0,0 +1,35 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#include "run_gemm_quant_example.inc"
+
+template <typename PrecType>
+using GemmConfig = GemmConfigQuantPrefill<PrecType>;
+
+#define RUN_GEMM_EXAMPLE_PREC_TYPE \
+    run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, \
+    TypeConfig, \
+    QuantGroupSize, \
+    ck_tile::QuantType::BQuantGrouped>(arg_parser);
+
+void bquant_quantgrouped_mx_bf16bf8_instance_factory(
+    std::unordered_map<std::size_t, std::function<int(const ck_tile::ArgParser&)>>& lut)
+{
+    using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
+                                                    ck_tile::bf8_t,
+                                                    ck_tile::bf16_t,
+                                                    ck_tile::e8m0_t>{});
+
+    lut[hash_multiple_strings(
+        {"bf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] =
+        [](const ck_tile::ArgParser& arg_parser) {
+            using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
+            return RUN_GEMM_EXAMPLE_PREC_TYPE;
+        };
+    lut[hash_multiple_strings(
+        {"bf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
+        [](const ck_tile::ArgParser& arg_parser) {
+            using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
+            return RUN_GEMM_EXAMPLE_PREC_TYPE;
+        };
+}
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp
index b8eb670135a..2f8cd12990c 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp
@@ -6,18 +6,18 @@
 template
 using GemmConfig = GemmConfigQuantPrefill;
 
-#define RUN_GEMM_EXAMPLE_PREC_TYPE \
-    run_gemm_example_prec_type, \
-        TypeConfig, \
-        QuantGroupSize, \
+#define RUN_GEMM_EXAMPLE_PREC_TYPE \
+    run_gemm_example_prec_type, \
+        TypeConfig, \
+        QuantGroupSize, \
         ck_tile::QuantType::BQuantGrouped>(arg_parser);
 
 static auto _ = []() {
     auto& lut = get_kernel_lut();
 
     using TypeConfig = decltype(GemmQuantTypeConfig{});
+                                ck_tile::e8m0_t>{});
 
     lut[hash_multiple_strings(
         {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] =
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp
index cc4302a992c..4dbb93efc25 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp
@@ -123,5 +123,4 @@ int main(int argc, char* argv[])
                      "group_size not supported."
<< std::endl; return -1; - } } diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 085d6344415..3048a01bbb6 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -45,7 +45,7 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const float max_accumulated_value) { using ComputeType = std::conditional_t< - std::is_same_v, + std::is_same_v, ADataType, std::conditional_t>; // Calculate thresholds diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 540d5725dd9..3f42bcf518b 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -158,10 +158,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str using BQuantPipeline = std::conditional_t< GemmConfig::PreshuffleB, ck_tile::WPQuantBPipelineAgBgCrV2, - std::conditional_t< - std::is_same_v, - ck_tile::MxFp4GemmPipelineAgBgCrCompV3, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + std::conditional_t, + ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; using ABQuantPipeline = std::conditional_t a_m(ck_tile::host_tensor_descriptor( args.M, args.K, args.stride_A, is_row_major(ALayout{}))); ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - std::is_same_v ? args.K / 2 - : args.K, + std::is_same_v ? args.K / 2 + : args.K, args.N, args.stride_B, is_row_major(BLayout{}))); @@ -479,11 +478,11 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, int rotating_count = arg_parser.get_int("rotating_count"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride( - (std::is_same_v) ? (K / 2) : K, - N, - stride_B, - is_row_major(b_layout)); + stride_B = + ck_tile::get_default_stride((std::is_same_v) ? (K / 2) : K, + N, + stride_B, + is_row_major(b_layout)); stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); // Conditional stride calculation based on QuantMode @@ -516,7 +515,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( - (std::is_same_v) ? (K / 2) : K, + (std::is_same_v) ? 
(K / 2) : K, N, stride_B, is_row_major(b_layout))); @@ -696,18 +695,31 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); } else { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + } + + if constexpr(std::is_same_v) + { + auto gen_scales = [&](auto& scales, float range_min, float range_max) { + // e8m0_t is basically an exponent of float32 + ck_tile::HostTensor pow2(scales.get_lengths()); + ck_tile::FillUniformDistributionIntegerValue{ + range_min, range_max, fill_seed(gen)}(pow2); + scales.ForEach([&](auto& self, const auto& i) { + self(i) = static_cast(std::exp2(pow2(i))); + }); + }; + gen_scales(*bq_tensor_ptr, -2, 2); + } + else + { ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *bq_tensor_ptr); } @@ -1010,7 +1022,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) ck_tile::reference_mxfp4gemm_quant) && + std::is_same_v) && GemmConfig::PreshuffleB) { throw std::runtime_error( diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index deaa9e0bd90..f65b2a0e5e3 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -69,6 +69,12 @@ CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, fp32x2_t, fp32x2) CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2) #undef CK_TILE_TYPE_CONVERT +template <> +CK_TILE_HOST_DEVICE constexpr bf16_t type_convert(bf8_t x) +{ + return float_to_bf16(bf8_to_float(x)); +} + } // namespace ck_tile #include "ck_tile/core/numeric/pk_fp4.hpp" diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 7830150b630..77d8d47875a 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -404,21 +404,22 @@ CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor& a_m_k, const std::size_t N = b_k_n.get_length(1); const std::size_t K = a_m_k.get_length(1); + constexpr index_t PackedSize = std::is_same_v ? 
2 : 1; + auto f_mn = [&](auto m, auto n) { AccDataType v_acc = 0; AccDataType pasual = 0; - for(std::size_t k = 0; k < (K / 2); k++) + for(std::size_t k = 0; k < (K / PackedSize); k++) { using ComputeType = float; - auto b_scale = type_convert(q((2 * k) / QuantGroupSize::kK, n)) - 127; - ComputeType v_a_0, v_a_1; - ComputeType v_b_0, v_b_1; - - v_a_0 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k)))); - v_a_1 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k + 1)))); - - if constexpr(std::is_same_v) + auto b_scale = type_convert(q((PackedSize * k) / QuantGroupSize::kK, n)) - 127; + if constexpr(std::is_same_v) { + ComputeType v_a_0, v_a_1; + ComputeType v_b_0, v_b_1; + + v_a_0 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k)))); + v_a_1 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k + 1)))); auto b_pack = type_convert(b_element_op(b_k_n(k, n))); auto b_scale_fp4 = type_convert(std::pow(2.0f, b_scale)); @@ -427,9 +428,23 @@ CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor& a_m_k, v_b_0 = type_convert(b_f4_lo) * b_scale_fp4; v_b_1 = type_convert(b_f4_hi) * b_scale_fp4; + + pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1; + } + else + { + ComputeType v_a; + ComputeType v_b; + + v_a = ck_tile::type_convert((a_element_op(a_m_k(m, k)))); + v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); + auto b_scale_fp4 = type_convert(std::pow(2.0f, b_scale)); + + v_b *= b_scale_fp4; + + pasual = v_a * v_b; } - pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1; v_acc += pasual; } c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp index 425083a9de3..a6f7513538d 100644 --- a/include/ck_tile/ops/common/utils.hpp +++ b/include/ck_tile/ops/common/utils.hpp @@ -23,6 +23,7 @@ template <> struct DataTypeTraits { static constexpr const char * name = template <> struct DataTypeTraits { static constexpr const char * name = "pk_int4"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4_raw"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "e8m0"; }; template struct memOpToStr; template <> struct memOpToStr { static constexpr const char * name = "set"; }; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 4f636b59625..2a02d9fbdb7 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -97,11 +97,11 @@ struct CShuffleEpilogue BDataType, ADataType>; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A - using BTypeToUse = std::conditional_t || - std::is_same_v || - std::is_same_v, - ADataType, - BDataType>; + using BTypeToUse = std::conditional_t< + std::is_same_v || std::is_same_v || + std::is_same_v || sizeof(BDataType) < sizeof(ADataType), + ADataType, + BDataType>; using ELayout = remove_cvref_t; using CDElementwise = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 79030fcd513..a26786fbb56 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -97,7 +97,8 @@ struct BlockUniversalGemmAsBsCr using ATypeToUse = std::conditional_t, BDataType, ADataType>; using BTypeToUse = 
std::conditional_t || - std::is_same_v, + std::is_same_v || + sizeof(BDataType) < sizeof(ADataType), ADataType, BDataType>; @@ -202,9 +203,14 @@ struct BlockUniversalGemmAsBsCr static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; - using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); - using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + using BTypeTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + // static distributed tensor with LDS type + BTypeTile b_warp_tile_lds_; + + // static distributed tensors with MMA type ALdsTile a_warp_tile_; BLdsTile b_warp_tile_; @@ -219,8 +225,19 @@ struct BlockUniversalGemmAsBsCr { load_int4_tile(a_warp_tile_, a_block_window); - load_int4_tile(b_warp_tile_, - b_block_window); + if constexpr(!std::is_same_v && + !std::is_same_v && + !std::is_same_v) + { + load_int4_tile( + b_warp_tile_lds_, b_block_window); + b_warp_tile_ = cast_tile(b_warp_tile_lds_); + } + else + { + load_int4_tile(b_warp_tile_, + b_block_window); + } } // C += A * B diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 4973d9c9410..8c1d9f362f8 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -21,7 +21,7 @@ struct GemmPipelineAgBgCrImplBase using ALayout = remove_cvref_t{}, AsLayout>>; using BInDataType = remove_cvref_t{}, BsDataType>>; using BDataType = - std::conditional_t, ADataType, BInDataType>; + std::conditional_t, ADataType, BInDataType>; using BLayout = remove_cvref_t{}, BsLayout>>; static constexpr index_t MPerBlock = BlockGemmShape::kM; @@ -314,7 +314,7 @@ struct GemmPipelineAgBgCrImplBase auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); using BLdsDataType = - std::conditional_t, + std::conditional_t, typename Problem::ADataType, typename Problem::BDataType>; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 8074994fdd3..98c4178f69e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -305,11 +305,10 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using BLayout = remove_cvref_t; - using BDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + using BLayout = remove_cvref_t; + using BDataType = std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -595,7 +594,9 @@ struct UniversalGemmBasePolicy using BLayout = remove_cvref_t{}, BsLayout>>; using BInDataType = remove_cvref_t{}, BsDataType>>; - using BDataType = std::conditional_t, + using BDataType = std::conditional_t || + sizeof(typename Problem::BDataType) < + sizeof(typename Problem::ADataType), typename Problem::ADataType, typename Problem::BDataType>; @@ -740,11 +741,11 @@ struct UniversalGemmBasePolicy constexpr index_t 
BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; using BDataType = remove_cvref_t; - constexpr index_t KPerBlock = std::is_same_v + constexpr index_t KPerBlock = std::is_same_v ? Problem::BlockGemmShape::kK / 2 : Problem::BlockGemmShape::kK; constexpr index_t VecLoadSize = - std::is_same_v + std::is_same_v ? 4 : (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB()); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; @@ -855,10 +856,9 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() { - using BDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + using BDataType = std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor(); constexpr index_t smem_size_b = integer_least_multiple( b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16); @@ -900,7 +900,8 @@ struct UniversalGemmPipelineAgBgCrPolicy using ATypeToUse = std::conditional_t, BDataType, ADataType>; using BTypeToUse = std::conditional_t || - std::is_same_v, + std::is_same_v || + sizeof(BDataType) < sizeof(ADataType), ADataType, BDataType>; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 21bd691b497..07f0c7da0a0 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -717,7 +717,7 @@ struct QuantGemmKernel } else { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) return make_naive_tensor_view( b_ptr, make_tuple(kargs.N, k_size / 2), @@ -744,7 +744,7 @@ struct QuantGemmKernel } else if constexpr(std::is_same_v) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) return pad_tensor_view(b_tensor_view, make_tuple(number{}, number{}), @@ -778,7 +778,7 @@ struct QuantGemmKernel { if constexpr(std::is_same_v) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) return make_tile_window( b_pad_view, make_tuple(number{}, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp index 6cf9e22f414..e07c0f206d8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp @@ -126,7 +126,7 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy< typename Problem::ADataType, - std::conditional_t, + std::conditional_t, typename Problem::ADataType, typename Problem::BDataType>, typename Problem::CDataType, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp index 7c448599edf..292250bfb8f 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp @@ -24,12 +24,14 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using BDqDataType = remove_cvref_t; - using BQDataType = 
remove_cvref_t; - using CDataType = remove_cvref_t; - using BlockGemmShape = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BDqDataType = std::conditional_t, + remove_cvref_t, + BDataType>; + using BQDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; using BQuantGroupSize = remove_cvref_t; static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); @@ -40,9 +42,14 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3>::PackedSize; - static constexpr index_t BPackedSize = + static constexpr index_t BDqPackedSize = ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BPackedSize = + std::is_same_v + ? 2 + : ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BQPackedSize = ck_tile::numeric_traits>::PackedSize; @@ -207,7 +214,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}; constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans(); - sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - auto b_scale_uint = type_convert(bq_block_tile(i_j_idx_scale)) - 127; - auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - - auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - b_block_tile(i_j_idx_lo) = - type_convert(type_convert(b_f4_lo) * b_scale); - b_block_tile(i_j_idx_hi) = - type_convert(type_convert(b_f4_hi) * b_scale); + + auto apply_scale_func = [&]() { + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + auto scale = bq_block_tile(i_j_idx_scale); + auto b_scale_uint = uint32_t(scale.data) << 23; + if constexpr(std::is_same_v) + { + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = + tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + b_block_tile(i_j_idx_lo) = type_convert( + type_convert(b_f4_lo) * bit_cast(b_scale_uint)); + b_block_tile(i_j_idx_hi) = type_convert( + type_convert(b_f4_hi) * bit_cast(b_scale_uint)); + } + else + { + auto b_pack = b_fp4_block_tile(i_j_idx); + b_block_tile(i_j_idx) = type_convert( + type_convert(b_pack) * bit_cast(b_scale_uint)); + } + }); }); - }); + }; + + apply_scale_func(); // initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -486,29 +509,10 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - - auto b_scale_uint = 
type_convert(bq_block_tile(i_j_idx_scale)) - 127; - auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - - auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - b_block_tile(i_j_idx_lo) = - type_convert(type_convert(b_f4_lo) * b_scale); - b_block_tile(i_j_idx_hi) = - type_convert(type_convert(b_f4_hi) * b_scale); - }); - }); + apply_scale_func(); block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); __builtin_amdgcn_sched_barrier(0); @@ -551,29 +555,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - - auto b_scale_uint = - type_convert(bq_block_tile(i_j_idx_scale)) - 127; - auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = - tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - - auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - b_block_tile(i_j_idx_lo) = - type_convert(type_convert(b_f4_lo) * b_scale); - b_block_tile(i_j_idx_hi) = - type_convert(type_convert(b_f4_hi) * b_scale); - }); - }); + apply_scale_func(); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); From da4dbb18f8bfd9634c0ab24747c6e90b1335dd15 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Mon, 19 Jan 2026 16:50:35 +0000 Subject: [PATCH 02/15] Add tests --- .../gemm_block_scale/test_gemm_quant_base.hpp | 2 +- .../test_gemm_quant_bquant_1d_128.cpp | 12 +++++-- .../test_gemm_quant_bquant_1d_64.cpp | 14 ++++++--- .../test_gemm_quant_fixtures.hpp | 31 ++++++++++++++----- 4 files changed, 43 insertions(+), 16 deletions(-) diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp index 5937b442291..e12974d8579 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -153,7 +153,7 @@ class TestCkTileGemmQuantBase : public ::testing::Test const float max_accumulated_value) { using ComputeType = std::conditional_t< - std::is_same_v, + std::is_same_v, ADataType_, std::conditional_t>; // Calculate thresholds diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp index d491d89ef4e..d28fe620122 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp @@ -14,8 +14,11 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; using FP8 = ck_tile::fp8_t; using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; using Half = ck_tile::half_t; using PkInt4 = ck_tile::pk_int4_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; 
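+
+// A reading aid (inferred from the surviving "QuantType, GemmConfig,
+// QuantGroupSize" tuple comment in the sibling 1d_64 test, not a documented
+// API): the trailing elements of each std::tuple below select the quant mode,
+// the tile configuration, and the quant group shape, while the leading
+// elements carry the data types and layouts. The new MX cases pair BF16
+// activations with PkFP4, BF8, or BF16 weights and an E8M0 per-group scale.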
using BQuantGrouped = std::integral_constant; using GroupSize = ck_tile::QuantGroupShape>; @@ -25,9 +28,12 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using BQuant1D128Types = ::testing::Types< // 1d cases with grouping only on k axis - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp index 1019caf1bca..4965d48a34b 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp @@ -14,8 +14,11 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; using FP8 = ck_tile::fp8_t; using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; using Half = ck_tile::half_t; using PkInt4 = ck_tile::pk_int4_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; using BQuantGrouped = std::integral_constant; using GroupSize64 = ck_tile::QuantGroupShape>; @@ -24,10 +27,13 @@ using GroupSize64 = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BQuant1D64Types = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 0033bb42a80..dcd6dcb598d 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -102,7 +102,7 @@ struct GemmConfigDecodeInterwave : public GemmConfigBase static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; -struct GemmConfigMxFp4 : public GemmConfigBase +struct GemmConfigMx : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -659,7 +659,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase ? (K / 2) : K; + std::is_same_v ? (K / 2) : K; const ck_tile::index_t stride_C = N; // BQuant uses block/grouped quantization for B matrix @@ -671,7 +671,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( - std::is_same_v ? K / 2 : K, + std::is_same_v ? 
K / 2 : K,
             N,
             stride_B,
             this->is_row_major(BLayout{})));
@@ -680,14 +680,29 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase{-0.5f, 0.5f}(a_m_k);
 
-        if constexpr(std::is_same_v)
+        if constexpr(std::is_same_v)
         {
             ck_tile::FillUniformDistribution{-5.0f, 5.0f}(b_k_n);
-            ck_tile::FillUniformDistribution{125.f, 130.f}(bq_bqk_bqn);
         }
         else
         {
             ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n);
+        }
+
+        if constexpr(std::is_same_v)
+        {
+            auto gen_scales = [&](auto& scales, float range_min, float range_max) {
+                // e8m0_t is basically an exponent of float32
+                ck_tile::HostTensor pow2(scales.get_lengths());
+                ck_tile::FillUniformDistributionIntegerValue{range_min, range_max}(pow2);
+                scales.ForEach([&](auto& self, const auto& i) {
+                    self(i) = static_cast(std::exp2(pow2(i)));
+                });
+            };
+            gen_scales(bq_bqk_bqn, -2, 2);
+        }
+        else
+        {
             ck_tile::FillUniformDistribution{-1.0f, 1.0f}(bq_bqk_bqn);
         }
 
@@ -769,7 +784,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase)
+        if constexpr(std::is_same_v)
             ck_tile::reference_mxfp4gemm_quant,
+            std::conditional_t,
                                ck_tile::MxFp4GemmPipelineAgBgCrCompV3,
                                ck_tile::BQuantGemmPipelineAgBgCrCompV3>,
             ck_tile::WPQuantBPipelineAgBgCrV2>;
 
         using GemmEpilogue = ck_tile::CShuffleEpilogue,
+            std::conditional_t,
                                ADataType,
                                BDataType>,
             ck_tile::tuple<>,

From 16e849da754daa0dbdf921c9aefbc47a409151e8 Mon Sep 17 00:00:00 2001
From: Enrico Degregori
Date: Mon, 19 Jan 2026 16:52:51 +0000
Subject: [PATCH 03/15] Use pk cvt instruction bf8 to bf16 on gfx950

---
 .../ck_tile/core/tensor/tile_elementwise.hpp  | 49 +++++++++++++++++++
 1 file changed, 49 insertions(+)

diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp
index bc6d7d2f5ad..8a6eb90cfb8 100644
--- a/include/ck_tile/core/tensor/tile_elementwise.hpp
+++ b/include/ck_tile/core/tensor/tile_elementwise.hpp
@@ -282,6 +282,51 @@ CK_TILE_DEVICE auto cast_tile_pk_fp16bf16_fp32(const InTensor& in_dstr_tensors)
     return out_dstr_tensor;
 }
 
+template <typename InTensor>
+CK_TILE_DEVICE auto cast_tile_pk_bf16_bf8(const InTensor& in_dstr_tensors)
+{
+#if defined(__gfx950__)
+    // This API is designed to use the _pk_ series of functions
+    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
+
+    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
+    static_assert(thread_buffer_size % 2 == 0);
+    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
+
+    auto out_dstr_tensor = make_static_distributed_tensor<bf16_t>(in_tile_dstr);
+
+    union
+    {
+        uint16_t i16val;
+        bf8_t i8val[2];
+    } input;
+
+    union
+    {
+        bf16x2_t bhalf_vec;
+        bf16_t bhalf_arr[2];
+    } output;
+
+    // TODO: this is an rtz cvt and needs to be handled very carefully
+    for(index_t i = 0; i < thread_buffer_size_pk; i++)
+    {
+        input.i8val[0] = in_dstr_tensors.get_thread_buffer()[2 * i + 0];
+        input.i8val[1] = in_dstr_tensors.get_thread_buffer()[2 * i + 1];
+        output.bhalf_vec =
+            __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, /*scale*/ 1.f, 0);
+
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = output.bhalf_arr[0];
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = output.bhalf_arr[1];
+    }
+
+    return out_dstr_tensor;
+#else
+    // fallback
+    return tile_elementwise_in(type_convert<bf16_t, bf8_t>,
+                               in_dstr_tensors);
+#endif
+}
+
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
 // this function assume either src or dst (or both) date type is under 1 dword
 // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)
@@ -354,6 +399,10 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor&
src_tensor) std::is_same_v && (SrcTensor::get_thread_buffer_size() % 4 == 0)) return impl::cast_tile_pk_fp8_fp32(src_tensor); + else if constexpr((std::is_same_v) && + std::is_same_v && + (SrcTensor::get_thread_buffer_size() % 2 == 0)) + return impl::cast_tile_pk_bf16_bf8(src_tensor); #if CK_TILE_USE_PK_FP16_TILE_CAST else if constexpr(std::is_same_v && std::is_same_v && From 263d3af09f1cc1bfbff1be21775b99ab78b6bff9 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Tue, 20 Jan 2026 13:09:14 +0000 Subject: [PATCH 04/15] Fix fp4 * scale -> bf16 to use pk instruction on gfx950 --- .../gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp index 292250bfb8f..20beb5d28c9 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp @@ -440,6 +440,23 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}; constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans(); + // Internally this is using V_CVT_SCALEF32_PK_BF16_FP4 or V_CVT_SCALEF32_PK_FP16_FP4 on + // gfx950 + auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) { + if constexpr(std::is_same_v) + { + return pk_fp4_to_fp16x2(pk_mxfp4, fscale); + } + else if constexpr(std::is_same_v) + { + return pk_fp4_to_bf16x2(pk_mxfp4, fscale); + } + else + { + static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type"); + } + }; + auto apply_scale_func = [&]() { sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { @@ -454,13 +471,12 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}; constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - b_block_tile(i_j_idx_lo) = type_convert( - type_convert(b_f4_lo) * bit_cast(b_scale_uint)); - b_block_tile(i_j_idx_hi) = type_convert( - type_convert(b_f4_hi) * bit_cast(b_scale_uint)); + auto b_pack = b_fp4_block_tile(i_j_idx); + + auto cvt = + pk_mxfp4_to_compute_v2(b_pack, bit_cast(b_scale_uint)); + b_block_tile(i_j_idx_lo) = cvt.x; + b_block_tile(i_j_idx_hi) = cvt.y; } else { From 4121ed220e1694850cc2aaee0f3dd46f7eeb3185 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Tue, 20 Jan 2026 14:24:39 +0000 Subject: [PATCH 05/15] Fix vectorsize buffer load for 16/8 bit case --- .../gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 98c4178f69e..a4950bc7979 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -594,9 +594,7 @@ struct UniversalGemmBasePolicy using BLayout = remove_cvref_t{}, BsLayout>>; using BInDataType = remove_cvref_t{}, BsDataType>>; - using BDataType = std::conditional_t || - sizeof(typename Problem::BDataType) < - sizeof(typename 
Problem::ADataType),
+    using BDataType = std::conditional_t,
                                          typename Problem::ADataType,
                                          typename Problem::BDataType>;
 
From 4121ed220e1694850cc2aaee0f3dd46f7eeb3185 Mon Sep 17 00:00:00 2001
From: Enrico Degregori
Date: Tue, 20 Jan 2026 14:31:18 +0000
Subject: [PATCH 06/15] Fix LDS read/write for 16/8 bit case

Both A and B use 128-bit instructions
---
 .../gemm_bquant_quantgrouped_bf16mxbf8.cpp    | 15 +++++++--------
 .../ck_tile/38_block_scale_gemm/gemm_utils.hpp | 18 ++++++++++++++++++
 include/ck_tile/ops/gemm/warp/warp_gemm.hpp   | 12 ++++++++++++
 .../ops/gemm/warp/warp_gemm_dispatcher.hpp    |  1 +
 4 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf8.cpp
index 30d65edd474..2e09af27247 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf8.cpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf8.cpp
@@ -3,13 +3,12 @@
 
 #include "run_gemm_quant_example.inc"
 
-template <typename PrecType>
-using GemmConfig = GemmConfigQuantPrefill<PrecType>;
+using GemmConfig = GemmConfigMixedPrecision;
 
-#define RUN_GEMM_EXAMPLE_PREC_TYPE \
-    run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, \
-    TypeConfig, \
-    QuantGroupSize, \
+#define RUN_GEMM_EXAMPLE_PREC_TYPE \
+    run_gemm_example_prec_type<GemmConfig, \
+    TypeConfig, \
+    QuantGroupSize, \
     ck_tile::QuantType::BQuantGrouped>(arg_parser);
 
 void bquant_quantgrouped_mx_bf16bf8_instance_factory(
@@ -21,9 +20,9 @@ void bquant_quantgrouped_mx_bf16bf8_instance_factory(
     ck_tile::e8m0_t>{});
 
     lut[hash_multiple_strings(
-        {"bf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] =
+        {"bf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
         [](const ck_tile::ArgParser& arg_parser) {
-            using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
+            using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
             return RUN_GEMM_EXAMPLE_PREC_TYPE;
         };
     lut[hash_multiple_strings(
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
index 3048a01bbb6..deee211d812 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
@@ -271,6 +271,24 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill
     static constexpr bool TransposeC = true;
 };
 
+// Used for A=16bit and B=8bit.
The warp tile has KPack=16 +// Matrix A: Vectorsize = 8, KPack=16 -> LDS read/write vectorsize = 8 (128bit) +// Matrix B: Vectorsize = 16, KPack=16 -> LDS read/write vectorsize = 16 (128bit) +struct GemmConfigMixedPrecision : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 64; +}; + template struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 00512424752..d00fa58731a 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -189,12 +189,24 @@ template using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; + +template +using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl, + 2, + AttrNumAccess>>; #else template using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, 2, AttrNumAccess>>; + +template +using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl, + 4, + AttrNumAccess>>; #endif using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl struct Dispatcher { using template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; From 4119209eb9f585475d7508ca837297e11cc3137b Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Wed, 21 Jan 2026 17:54:44 +0000 Subject: [PATCH 07/15] Refactor pipeline Support both scale before writing to LDS or scale after reading from LDS --- .../run_gemm_quant_example.inc | 26 +-- .../core/tensor/static_distributed_tensor.hpp | 2 +- .../ck_tile/host/reference/reference_gemm.hpp | 56 ++--- .../unary_element_wise_operation.hpp | 98 +++++++++ .../block/block_universal_gemm_as_bs_cr.hpp | 24 +-- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 12 +- .../gemm_pipeline_ag_bg_cr_scheduler.hpp | 6 + ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 19 +- .../block_universal_gemm_as_bs_bquant_cr.hpp | 204 ++++++++++++------ .../gemm_quant/kernel/gemm_quant_kernel.hpp | 49 ++--- .../gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp | 93 +++++--- .../gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp | 135 ++++++++---- .../pipeline/gemm_quant_pipeline_problem.hpp | 16 +- .../test_gemm_quant_fixtures.hpp | 10 +- 14 files changed, 475 insertions(+), 275 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 3f42bcf518b..db5cde50652 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -93,6 
+93,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; + constexpr auto b_cast_policy = + std::is_same_v + ? ck_tile::CastPolicy::BeforeLDSWrite + : ck_tile::CastPolicy::AfterLDSRead; // row-col and tensor quants use the regular pipeline, A/B/AB quants use their own using PipelineProblem = std::conditional_t< @@ -135,7 +139,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ComputeDataType, GemmConfig::Scheduler, has_hot_loop_v, - tail_number_v>, + tail_number_v, + b_cast_policy>, ck_tile::GemmABQuantPipelineProblem a_m(ck_tile::host_tensor_descriptor( args.M, args.K, args.stride_A, is_row_major(ALayout{}))); ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - std::is_same_v ? args.K / 2 - : args.K, - args.N, - args.stride_B, - is_row_major(BLayout{}))); + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); auto size_a_buffer = a_m.get_element_space_size_in_bytes(); auto size_b_buffer = b_n.get_element_space_size_in_bytes(); @@ -478,11 +479,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, int rotating_count = arg_parser.get_int("rotating_count"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = - ck_tile::get_default_stride((std::is_same_v) ? (K / 2) : K, - N, - stride_B, - is_row_major(b_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); // Conditional stride calculation based on QuantMode @@ -514,11 +511,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); - ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( - (std::is_same_v) ? (K / 2) : K, - N, - stride_B, - is_row_major(b_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 10c7587bcb4..6a2d560bbcd 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -75,7 +75,7 @@ struct static_distributed_tensor constexpr auto sliced_thread_tensor_desc = make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...)); - thread_buffer + thread_buffer sliced_thread_data; static_ford>{}([&](auto idx) { diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 77d8d47875a..2b6a1be7a44 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -404,48 +404,38 @@ CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor& a_m_k, const std::size_t N = b_k_n.get_length(1); const std::size_t K = a_m_k.get_length(1); - constexpr index_t PackedSize = std::is_same_v ? 
2 : 1; - auto f_mn = [&](auto m, auto n) { - AccDataType v_acc = 0; - AccDataType pasual = 0; - for(std::size_t k = 0; k < (K / PackedSize); k++) + AccDataType v_acc = 0; + using ComputeType = float; + ComputeType v_a; + ComputeType v_b; + + for(std::size_t k = 0; k < K; k++) { - using ComputeType = float; - auto b_scale = type_convert(q((PackedSize * k) / QuantGroupSize::kK, n)) - 127; + auto b_scale = type_convert(q((k) / QuantGroupSize::kK, n)) - 127; + auto b_scale_fp32 = type_convert(std::pow(2.0f, b_scale)); + v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); if constexpr(std::is_same_v) { - ComputeType v_a_0, v_a_1; - ComputeType v_b_0, v_b_1; - - v_a_0 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k)))); - v_a_1 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k + 1)))); - auto b_pack = type_convert(b_element_op(b_k_n(k, n))); - auto b_scale_fp4 = type_convert(std::pow(2.0f, b_scale)); - - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - - v_b_0 = type_convert(b_f4_lo) * b_scale_fp4; - v_b_1 = type_convert(b_f4_hi) * b_scale_fp4; + auto b_pack = type_convert(b_element_op(b_k_n(k, n))); - pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1; + if(k % 2 == 0) + { + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + v_b = type_convert(b_f4_lo); + } + else + { + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + v_b = type_convert(b_f4_hi); + } } else { - ComputeType v_a; - ComputeType v_b; - - v_a = ck_tile::type_convert((a_element_op(a_m_k(m, k)))); - v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); - auto b_scale_fp4 = type_convert(std::pow(2.0f, b_scale)); - - v_b *= b_scale_fp4; - - pasual = v_a * v_b; + v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); } - - v_acc += pasual; + v_b *= b_scale_fp32; + v_acc += v_a * v_b; } c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); }; diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 3f58eceb333..06dfbaa5d7f 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -359,6 +359,84 @@ CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q) } #endif +CK_TILE_HOST_DEVICE bf16x8_t bf8x8_to_bf16x8_scale(const bf8x8_t& src, const float& scale) +{ + bf16x8_t y; +#if defined(__gfx950__) + union + { + uint32_t i16val; + bf8_t i8val[4]; + } input; + + union + { + bf16x2_t bhalf_vec; + bf16_t bhalf_arr[2]; + } output; + + input.i8val[0] = src[0]; + input.i8val[1] = src[1]; + input.i8val[2] = src[2]; + input.i8val[3] = src[3]; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, scale, 0); + y[0] = output.bhalf_arr[0]; + y[1] = output.bhalf_arr[1]; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, scale, 1); + y[2] = output.bhalf_arr[0]; + y[3] = output.bhalf_arr[1]; + + input.i8val[0] = src[4]; + input.i8val[1] = src[5]; + input.i8val[2] = src[6]; + input.i8val[3] = src[7]; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, scale, 0); + y[4] = output.bhalf_arr[0]; + y[5] = output.bhalf_arr[1]; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, scale, 1); + y[6] = output.bhalf_arr[0]; + y[7] = output.bhalf_arr[1]; +#else + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); +#endif + return y; +} + 
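+// A minimal host-side sketch (an editorial illustration, not code from this
+// patch): the float `scale` consumed by the converters above and below is the
+// dequant factor 2^(e - 127) carried by an e8m0 block scale. Assuming e8m0_t
+// exposes its raw exponent byte as `.data` (as the MXFP4 pipeline in this
+// series does), the factor can be materialized by writing that byte straight
+// into the exponent field of an fp32:
+//
+//   CK_TILE_HOST_DEVICE float e8m0_to_scale(e8m0_t s)
+//   {
+//       const uint32_t bits = uint32_t(s.data) << 23; // fp32 exponent bits
+//       return bit_cast<float>(bits); // == exp2f(int(s.data) - 127)
+//   }
+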
+CK_TILE_HOST_DEVICE bf16x8_t fp4x4_to_bf16x8_scale(const pk_fp4x4_t& src, const float& scale) +{ + bf16x8_t y; +#if defined(__gfx950__) + union + { + uint32_t u32; + pk_fp4x4_t pf4; + } cvt; + cvt.pf4 = src; + bf16x2_t y0 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, 0); + bf16x2_t y1 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, 1); + bf16x2_t y2 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, 2); + bf16x2_t y3 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, 3); + + y[0] = y0[0]; + y[1] = y0[1]; + y[2] = y1[0]; + y[3] = y1[1]; + y[4] = y2[0]; + y[5] = y2[1]; + y[6] = y3[0]; + y[7] = y3[1]; +#else + static_for<0, 4, 1>{}([&](auto i) { + auto yi = pk_fp4_to_bf16x2(src[i.value], scale); + y[2 * i.value] = yi[0]; + y[2 * i.value + 1] = yi[1]; + }); +#endif + return y; +} + struct PassThroughPack8 { static constexpr const char* name = "PassThroughPack8"; @@ -437,6 +515,26 @@ struct DequantPack8 y.hi = i4_to_half4_scale(bit_cast(x) >> 8, z); } + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const pk_fp4x4_t& x, const float& z) const + { + y = fp4x4_to_bf16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const bf8x8_t& x, const float& z) const + { + y = bf8x8_to_bf16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const bf16x8_t& x, const float& z) const + { + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(x[i.value]) * z); + }); + } + constexpr const static bool is_pack8_invocable = true; }; diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index a26786fbb56..7f34ae24bbd 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -203,14 +203,9 @@ struct BlockUniversalGemmAsBsCr static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; - using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); - using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); - using BTypeTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); - - // static distributed tensor with LDS type - BTypeTile b_warp_tile_lds_; + using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); + using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); - // static distributed tensors with MMA type ALdsTile a_warp_tile_; BLdsTile b_warp_tile_; @@ -225,19 +220,8 @@ struct BlockUniversalGemmAsBsCr { load_int4_tile(a_warp_tile_, a_block_window); - if constexpr(!std::is_same_v && - !std::is_same_v && - !std::is_same_v) - { - load_int4_tile( - b_warp_tile_lds_, b_block_window); - b_warp_tile_ = cast_tile(b_warp_tile_lds_); - } - else - { - load_int4_tile(b_warp_tile_, - b_block_window); - } + load_int4_tile(b_warp_tile_, + b_block_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 8c1d9f362f8..8b1456631e7 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -20,8 +20,9 @@ struct GemmPipelineAgBgCrImplBase using ADataType = remove_cvref_t{}, AsDataType>>; using ALayout = remove_cvref_t{}, AsLayout>>; using BInDataType = remove_cvref_t{}, 
BsDataType>>; - using BDataType = - std::conditional_t, ADataType, BInDataType>; + using BDataType = std:: + conditional_t; + using BLayout = remove_cvref_t{}, BsLayout>>; static constexpr index_t MPerBlock = BlockGemmShape::kM; @@ -313,10 +314,9 @@ struct GemmPipelineAgBgCrImplBase auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); - using BLdsDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + using BLdsDataType = std::conditional_t; auto b_lds_load_tile_distr = []() { if constexpr(is_b_load_tr) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index 957cf7ab8f3..69455abc9c2 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -10,6 +10,12 @@ namespace ck_tile { +enum struct CastPolicy +{ + BeforeLDSWrite, + AfterLDSRead, +}; + enum struct GemmPipelineScheduler { Default, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index a4950bc7979..4177e556717 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -306,7 +306,7 @@ struct UniversalGemmBasePolicy CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { using BLayout = remove_cvref_t; - using BDataType = std::conditional_t, + using BDataType = std::conditional_t; @@ -588,13 +588,11 @@ struct UniversalGemmBasePolicy CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB() { using BsLayout = remove_cvref_t; - using BsDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; using BLayout = remove_cvref_t{}, BsLayout>>; - using BInDataType = remove_cvref_t{}, BsDataType>>; - using BDataType = std::conditional_t, + using BDataType = std::conditional_t; @@ -738,13 +736,12 @@ struct UniversalGemmBasePolicy { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - using BDataType = remove_cvref_t; - constexpr index_t KPerBlock = std::is_same_v - ? Problem::BlockGemmShape::kK / 2 - : Problem::BlockGemmShape::kK; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + // If we cast before writing to LDS, the vectorsize is defined by the A type + // since the assumption is that A type is going to be the B LDS type constexpr index_t VecLoadSize = - std::is_same_v - ? 4 + Problem::BCastPolicy == CastPolicy::BeforeLDSWrite + ? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA()) : (Problem::FixedVectorSize ? 
Problem::VectorSizeB : GetVectorSizeB()); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; using BLayout = remove_cvref_t< @@ -854,7 +851,7 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() { - using BDataType = std::conditional_t, + using BDataType = std::conditional_t; constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor(); diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 9d711c48623..65a92811662 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" @@ -101,15 +102,18 @@ struct BQuantBlockUniversalGemmAsBsCr // 2. bf8, bf8, fp32 -> f32 // 3. i4, fp8, (fp8/fp32) -> f32 // 4. i4, bf8, (fp8/fp32) -> f32 - static_assert((std::is_same_v || std::is_same_v) && - (std::is_same_v || std::is_same_v || - std::is_same_v) && - (std::is_same_v || - std::is_same_v || - std::is_same_v) && - (std::is_same_v || - std::is_same_v) && - std::is_same_v); + static_assert( + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + std::is_same_v); static constexpr index_t InterWaveSchedulingMacClusters = 1; @@ -176,57 +180,17 @@ struct BQuantBlockUniversalGemmAsBsCr using I0 = number<0>; using I1 = number<1>; + // Use gemm universal block distribution encoding instead of duplicating it + using BlockGemmBase = BlockUniversalGemmAsBsCr; + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { - constexpr index_t KPerThread = Traits::KPerThread; - constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - - constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); - - constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; - - using KIterSeq = std::conditional_t, - sequence>; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; + return BlockGemmBase::MakeABlockDistributionEncode(); } CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { - constexpr index_t KPerThread = Traits::KPerThread; - constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); - constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; - - using KIterSeq = std::conditional_t, - sequence>; - - constexpr auto b_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - 
sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - return b_block_dstr_encode; + return BlockGemmBase::MakeBBlockDistributionEncode(); } private: @@ -235,19 +199,22 @@ struct BQuantBlockUniversalGemmAsBsCr { }; + using BlockGemmImplBase = typename BlockUniversalGemmAsBsCr:: + template BlockGemmImpl; + template - struct BlockGemmImpl + struct BlockGemmImpl : public BlockGemmImplBase { - static constexpr auto ALdsTileDistr = - decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; - static constexpr auto BLdsTileDistr = - decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; - - using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); - using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + using BlockGemmImplBase::a_warp_tile_; + using BlockGemmImplBase::b_warp_tile_; + using BlockGemmImplBase::BLdsTileDistr; + // If we apply scale while reading from LDS, then we can use the operator() from + // BlockUniversalGemmAsBsCr + using BlockGemmImplBase::operator(); - ALdsTile a_warp_tile_; - BLdsTile b_warp_tile_; + // static distributed tensor with LDS type + using BTypeTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + BTypeTile b_warp_tile_lds_; template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + BQRegBlockTile& bq_block_tensor, + bool_constant = {}, + bool_constant = {}) + { + load_int4_tile( + a_warp_tile_, a_block_window); + load_int4_tile( + b_warp_tile_lds_, b_block_window); + + // Apply scale + using BDataTypeRaw = typename std:: + conditional, pk_fp4_t::type, BDataType>::type; + + constexpr auto warp_size = get_warp_size(); + constexpr index_t nelements = WarpGemm::kK * WarpGemm::kN / warp_size; + constexpr index_t thread_buffer_size = nelements / UnaryOpSize_; + const element_wise::DequantPack8 elementwise_op{}; + using SrcVectorRawType = + BDataTypeRaw __attribute__((ext_vector_type(UnaryOpSize_ / BPackedSize))); + using DstVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize_))); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + // Thread buffers + using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths))); + using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths))); + + BWarpThreadBuffer b_warp_thread_buffer; + BLDSThreadBuffer b_lds_thread_buffer; + + // BQuant register offset + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN)) + return (nIter * NWarp * WarpGemm::kN) / + GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock + + kQScale; + else + { + return nIter * Traits::KQPerBlock + kQScale; + } + }(); + + // Load thread buffer from tile (LDS type) + b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + 
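+                    // Illustrative sizing (a sketch; actual values depend on the
+                    // instantiated WarpGemm): assuming a 16x16x128 warp tile on a
+                    // 64-lane wave and UnaryOpSize_ == 8, each lane holds
+                    // nelements = 16 * 128 / 64 = 32 B values, i.e.
+                    // thread_buffer_size = 32 / 8 = 4 vector dequants per (nIter, kIter).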
+ // Apply scale to thread buffer and cast + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_f = float(scale_reg); + + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + elementwise_op( + b_warp_thread_buffer.template get_as()(i), + b_lds_thread_buffer.template get_as()[i], + b_scale_f); + }); + + // Store thread buffer to tile (MMA type) + b_warp_tile_.set_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths), + b_warp_thread_buffer); + }); + }); + }); + } + // C += A * B template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + BQRegBlockTile bq_block_tile, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) + { + block_gemm_impl_.LocalPrefetch( + a_block_window, b_block_window, bq_block_tile, a_load_tr, b_load_tr); + } + // C += A * B + // Apply scale after MMA template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + block_gemm_impl_(c_block_tensor, a_block_window, b_block_window); + } + private: BlockGemmImpl block_gemm_impl_{}; }; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 07f0c7da0a0..77ccf600577 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -717,20 +717,12 @@ struct QuantGemmKernel } else { - if constexpr(std::is_same_v) - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, k_size / 2), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - else - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, k_size), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, k_size), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); } } } @@ -744,16 +736,10 @@ struct QuantGemmKernel } else if constexpr(std::is_same_v) { - if constexpr(std::is_same_v) - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - else - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } else { @@ -778,17 +764,10 @@ struct QuantGemmKernel { if constexpr(std::is_same_v) { - if constexpr(std::is_same_v) - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - else - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); } else { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp index e07c0f206d8..309585304d8 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp @@ -70,37 +70,70 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() { - // using BLayout = remove_cvref_t; - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = 
Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + if constexpr(Problem::BCastPolicy == CastPolicy::AfterLDSRead) + { + using BQLayout = remove_cvref_t; + using BlockGemmShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; + + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmDispatcher; - constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; // k_scale num //2 - constexpr index_t VecLoadSize = - Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); - constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + using TileEncodingPattern = + tile_distribution_encoding_pattern_bq; - constexpr index_t warp_size = get_warp_size(); - constexpr index_t num_warps = BlockSize / get_warp_size(); - constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size); - constexpr index_t b_vec = VecLoadSize > LargestVec ? LargestVec : VecLoadSize; - constexpr index_t K0 = KPerBlock / b_vec; - constexpr index_t K1 = K0 / KScale; - constexpr index_t K3 = K0 / K1; - constexpr index_t K2 = 1; - - constexpr index_t N0 = num_warps / NumWaveGroups; - constexpr index_t N1 = warp_size / K0; - constexpr index_t N2 = NPerBlock / (N0 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2, 0>>, - tuple, sequence<1, 0, 0>>, - sequence<1, 2>, - sequence<2, 1>>{}); + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + else + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; // k_scale num //2 + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t num_warps = BlockSize / get_warp_size(); + constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size); + constexpr index_t b_vec = VecLoadSize > LargestVec ? 
LargestVec : VecLoadSize; + constexpr index_t K0 = KPerBlock / b_vec; + constexpr index_t K1 = K0 / KScale; + constexpr index_t K3 = K0 / K1; + constexpr index_t K2 = 1; + + constexpr index_t N0 = num_warps / NumWaveGroups; + constexpr index_t N1 = warp_size / K0; + constexpr index_t N2 = NPerBlock / (N0 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 0>>, + tuple, sequence<1, 0, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } } template @@ -133,7 +166,7 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy BlockWarps, WarpGemm>; - return BlockUniversalGemmAsBsCr{}; + return BQuantBlockUniversalGemmAsBsCr{}; } }; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp index 20beb5d28c9..a007190d4c4 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp @@ -24,8 +24,9 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using BDqDataType = std::conditional_t, remove_cvref_t, BDataType>; @@ -89,6 +90,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3()); auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp); auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){}; @@ -401,7 +397,6 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3{}], [&](auto idx0) { sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { @@ -466,17 +471,20 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3) { - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = - tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - auto b_pack = b_fp4_block_tile(i_j_idx); - - auto cvt = - pk_mxfp4_to_compute_v2(b_pack, bit_cast(b_scale_uint)); - b_block_tile(i_j_idx_lo) = cvt.x; - b_block_tile(i_j_idx_hi) = cvt.y; + if constexpr(idx1.impl_.at(0) % BPackedSize == 0) + { + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = + tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + auto b_pack = b_fp4_block_tile(i_j_idx); + + auto cvt = + pk_mxfp4_to_compute_v2(b_pack, bit_cast(b_scale_uint)); + b_block_tile(i_j_idx_lo) = cvt.x; + b_block_tile(i_j_idx_hi) = cvt.y; + } } else { @@ -488,7 +496,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); + auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); + transpose_tile2d(b_shuffle_tmp, b_block_tile_); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func); } Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, 
a_dram_tile_window_step); Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); - bq_block_tile = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + if constexpr(IsCastBeforeLDS) + { + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); - apply_scale_func(); + apply_scale_func(); - block_sync_lds(); + block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + } + else + { + block_sync_lds(); + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); + } __builtin_amdgcn_sched_barrier(0); // main body @@ -556,12 +575,14 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); + auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); + transpose_tile2d(b_shuffle_tmp, b_block_tile_); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func); } Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); @@ -571,12 +592,18 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); + auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); + transpose_tile2d(b_shuffle_tmp, b_block_tile_); Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); } else { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); + auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); + Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func); } block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + if constexpr(IsCastBeforeLDS) + { + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + } + else + { + block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); + } block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); @@ -648,12 +690,13 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; + ck_tile::ignore = n; return PipelineImpl{}.template operator()( a_dram_block_window_tmp, [](const ADataType& a) { return a; }, b_dram_block_window_tmp, - [](const BDqDataType& b) { return b; }, + [](const BElementwise& b) { return b; }, bq_dram_block_window_tmp, num_loop, p_smem); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index 9b02585e691..b6d645c3f2b 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -24,7 +24,8 @@ template + TailNumber TailNum_ = TailNumber::Full, + CastPolicy BCastPolicy_ = CastPolicy::AfterLDSRead> struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase< ADataType_, @@ -78,9 +79,10 @@ struct GemmQuantPipelineProblemBase using AQLayout = remove_cvref_t; using BQLayout = remove_cvref_t; - static 
constexpr auto Scheduler = Scheduler_; - static constexpr auto HasHotLoop = HasHotLoop_; - static constexpr auto TailNum = TailNum_; + static constexpr auto Scheduler = Scheduler_; + static constexpr auto HasHotLoop = HasHotLoop_; + static constexpr auto TailNum = TailNum_; + static constexpr auto BCastPolicy = BCastPolicy_; static_assert(BlockGemmShape::kM % AQuantGroupSize::kM == 0); static_assert(BlockGemmShape::kK % AQuantGroupSize::kK == 0); @@ -155,7 +157,8 @@ template + TailNumber TailNum_ = TailNumber::Full, + CastPolicy BCastPolicy_ = CastPolicy::AfterLDSRead> using GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase; + TailNum_, + BCastPolicy_>; template ? (K / 2) : K; + const ck_tile::index_t stride_B = K; const ck_tile::index_t stride_C = N; // BQuant uses block/grouped quantization for B matrix @@ -670,11 +669,8 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); - ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( - std::is_same_v ? K / 2 : K, - N, - stride_B, - this->is_row_major(BLayout{}))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); ck_tile::HostTensor bq_bqk_bqn( ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{}))); From 8981ffecc884884e7b7fa0188c76c39bc8c4316f Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Wed, 28 Jan 2026 15:48:45 +0000 Subject: [PATCH 08/15] Naming convention examples mx blockscale --- example/ck_tile/38_block_scale_gemm/CMakeLists.txt | 6 +++--- example/ck_tile/38_block_scale_gemm/README.md | 2 +- ...6mxbf16.cpp => gemm_bquant_quantgrouped_mx_bf16bf16.cpp} | 4 ++-- ...f16mxbf8.cpp => gemm_bquant_quantgrouped_mx_bf16bf8.cpp} | 4 ++-- ...f16mxfp4.cpp => gemm_bquant_quantgrouped_mx_bf16fp4.cpp} | 6 +++--- example/ck_tile/38_block_scale_gemm/gemm_quant.cpp | 2 +- 6 files changed, 12 insertions(+), 12 deletions(-) rename example/ck_tile/38_block_scale_gemm/{gemm_bquant_quantgrouped_bf16mxbf16.cpp => gemm_bquant_quantgrouped_mx_bf16bf16.cpp} (91%) rename example/ck_tile/38_block_scale_gemm/{gemm_bquant_quantgrouped_bf16mxbf8.cpp => gemm_bquant_quantgrouped_mx_bf16bf8.cpp} (88%) rename example/ck_tile/38_block_scale_gemm/{gemm_bquant_quantgrouped_bf16mxfp4.cpp => gemm_bquant_quantgrouped_mx_bf16fp4.cpp} (85%) diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 40e3e91ac03..305cbf907b0 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -18,9 +18,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp gemm_bquant_quantgrouped_fp8i4.cpp - gemm_bquant_quantgrouped_bf16mxfp4.cpp - gemm_bquant_quantgrouped_bf16mxbf8.cpp - gemm_bquant_quantgrouped_bf16mxbf16.cpp + gemm_bquant_quantgrouped_mx_bf16fp4.cpp + gemm_bquant_quantgrouped_mx_bf16bf8.cpp + gemm_bquant_quantgrouped_mx_bf16bf16.cpp gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index eb36ae58008..8b92e10cb2f 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -53,7 +53,7 @@ args: -stride_b Tensor B stride (default:0) -stride_c Tensor C stride 
(default:0) -v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1) - -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, or bf16fp4 (default for both AQuant and Bquant: fp8) + -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, mxbf16bf16, mxbf16bf8, or mxbf16fp4 (default for both AQuant and Bquant: fp8) -warmup Number of iterations before benchmarking the kernel (default:50) -repeat Number of iterations to benchmark the kernel (default:1000) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf16.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp similarity index 91% rename from example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf16.cpp rename to example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp index f064c4a8e6d..77b3be86848 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf16.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp @@ -21,13 +21,13 @@ void bquant_quantgrouped_mx_bf16bf16_instance_factory( ck_tile::e8m0_t>{}); lut[hash_multiple_strings( - {"bf16mxbf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = + {"mxbf16bf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; lut[hash_multiple_strings( - {"bf16mxbf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + {"mxbf16bf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp similarity index 88% rename from example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf8.cpp rename to example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp index 2e09af27247..9ac00b62ddd 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxbf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp @@ -20,13 +20,13 @@ void bquant_quantgrouped_mx_bf16bf8_instance_factory( ck_tile::e8m0_t>{}); lut[hash_multiple_strings( - {"bf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = + {"mxbf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; lut[hash_multiple_strings( - {"bf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + {"mxbf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16fp4.cpp similarity index 85% rename from example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp rename to
example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16fp4.cpp index 2f8cd12990c..1f48609a1f0 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16fp4.cpp @@ -20,19 +20,19 @@ static auto _ = []() { ck_tile::e8m0_t>{}); lut[hash_multiple_strings( - {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = + {"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; lut[hash_multiple_strings( - {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + {"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; lut[hash_multiple_strings( - {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = + {"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index 4dbb93efc25..d21c94ff891 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp8", "Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " - "or bf8i4; for ABQuant: fp8, bf8, fp4") + " mxbf16bf16, mxbf16bf8, mxbf16fp4 or bf8i4; for ABQuant: fp8, bf8, fp4") .insert("warmup", "50", "Number of iterations before benchmarking the kernel") .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") From 86f08a99dc990bc42e259ae03c4a0ab3c60d428d Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Wed, 28 Jan 2026 16:08:18 +0000 Subject: [PATCH 09/15] Rename pipeline (not limited to fp4 any more) --- .../38_block_scale_gemm/run_gemm_quant_example.inc | 2 +- include/ck_tile/ops/gemm_quant.hpp | 6 +++--- ..._bg_cr_base.hpp => gemm_mx_pipeline_ag_bg_cr_base.hpp} | 2 +- ...cr_policy.hpp => gemm_mx_pipeline_ag_bg_cr_policy.hpp} | 2 +- ...e_ag_bg_cr_v3.hpp => gemm_mx_pipeline_ag_bg_cr_v3.hpp} | 8 ++++---- .../ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) rename include/ck_tile/ops/gemm_quant/pipeline/{gemm_mxfp4_pipeline_ag_bg_cr_base.hpp => gemm_mx_pipeline_ag_bg_cr_base.hpp} (96%) rename include/ck_tile/ops/gemm_quant/pipeline/{gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp => gemm_mx_pipeline_ag_bg_cr_policy.hpp} (99%) rename include/ck_tile/ops/gemm_quant/pipeline/{gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp => gemm_mx_pipeline_ag_bg_cr_v3.hpp} (99%) diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index db5cde50652..fd3102fe5f6 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -164,7 +164,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str GemmConfig::PreshuffleB, 
ck_tile::WPQuantBPipelineAgBgCrV2, std::conditional_t, - ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::MxGemmPipelineAgBgCrCompV3, ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; using ABQuantPipeline = diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 696de378aaf..a7edb66f15a 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -21,9 +21,9 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp" diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp similarity index 96% rename from include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp index facec252a35..d8b055bd2ea 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp @@ -10,7 +10,7 @@ namespace ck_tile { template -struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase +struct GemmMxPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase { using Base = GemmPipelineAgBgCrImplBase; using ADataType = typename Base::ADataType; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp similarity index 99% rename from include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp index 309585304d8..64fff27fa1b 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp @@ -9,7 +9,7 @@ namespace ck_tile { -struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy +struct GemmMxPipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy { using Base = UniversalGemmPipelineAgBgCrPolicy; using Base::I0; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp similarity index 99% rename from include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp index a007190d4c4..03381f569ab 100644 --- 
a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp @@ -9,7 +9,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -18,11 +18,11 @@ namespace ck_tile { // B Tile Window: global memory // C Distributed tensor: register -template -struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 +template +struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { using Base = BaseGemmPipelineAgBgCrCompV3; - using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase; + using PipelineImplBase = GemmMxPipelineAgBgCrImplBase; using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index ab3ba0640cf..8e8ea1dc6eb 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -871,7 +871,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase, - ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + ck_tile::MxGemmPipelineAgBgCrCompV3, ck_tile::BQuantGemmPipelineAgBgCrCompV3>, ck_tile::WPQuantBPipelineAgBgCrV2>; From ce392d5dd4708dd61ac7593c3d0187b74c1835ef Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Wed, 28 Jan 2026 16:32:54 +0000 Subject: [PATCH 10/15] Clean up tile_elementwise (casting not needed with new approach) --- .../ck_tile/core/tensor/tile_elementwise.hpp | 49 ------------------- 1 file changed, 49 deletions(-) diff --git a/include/ck_tile/core/tensor/tile_elementwise.hpp b/include/ck_tile/core/tensor/tile_elementwise.hpp index 8a6eb90cfb8..bc6d7d2f5ad 100644 --- a/include/ck_tile/core/tensor/tile_elementwise.hpp +++ b/include/ck_tile/core/tensor/tile_elementwise.hpp @@ -282,51 +282,6 @@ CK_TILE_DEVICE auto cast_tile_pk_fp16bf16_fp32(const InTensor& in_dstr_tensors) return out_dstr_tensor; } -template -CK_TILE_DEVICE auto cast_tile_pk_bf16_bf8(const InTensor& in_dstr_tensors) -{ -#if defined(__gfx950__) - // This API is designed to use the _pk_ serious of function - constexpr auto in_tile_dstr = InTensor::get_tile_distribution(); - - constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size(); - static_assert(thread_buffer_size % 2 == 0); - constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2; - - auto out_dstr_tensor = make_static_distributed_tensor(in_tile_dstr); - - union - { - uint16_t i16val; - bf8_t i8val[2]; - } input; - - union - { - bf16x2_t bhalf_vec; - bf16_t bhalf_arr[2]; - } output; - - // TODO: this is rtz cvt, need be very careful - for(index_t i = 0; i < thread_buffer_size_pk; i++) - { - input.i8val[0] = in_dstr_tensors.get_thread_buffer()[2 * i + 0]; - input.i8val[1] = in_dstr_tensors.get_thread_buffer()[2 * i + 1]; - output.bhalf_vec = - __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, /*scale*/ 1.f, 0); - - out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = output.bhalf_arr[0]; - out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = output.bhalf_arr[1]; - } - - return out_dstr_tensor; -#else - // fallback - return 
tile_elementwise_in(type_convert, - in_dstr_tensors); -#endif -} - #if CK_TILE_USE_SUBDWORD_TILE_CAST // this function assume either src or dst (or both) date type is under 1 dword // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy) @@ -399,10 +354,6 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor) std::is_same_v && (SrcTensor::get_thread_buffer_size() % 4 == 0)) return impl::cast_tile_pk_fp8_fp32(src_tensor); - else if constexpr((std::is_same_v) && - std::is_same_v && - (SrcTensor::get_thread_buffer_size() % 2 == 0)) - return impl::cast_tile_pk_bf16_bf8(src_tensor); #if CK_TILE_USE_PK_FP16_TILE_CAST else if constexpr(std::is_same_v && std::is_same_v && From 4615450a82eeb6c663f364e1fb77d53e706f3a11 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Wed, 28 Jan 2026 17:15:05 +0000 Subject: [PATCH 11/15] Clean up pipeline --- include/ck_tile/core/numeric/type_convert.hpp | 6 - .../block_universal_gemm_as_bs_bquant_cr.hpp | 2 +- .../pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp | 405 +++++++++--------- 3 files changed, 206 insertions(+), 207 deletions(-) diff --git a/include/ck_tile/core/numeric/type_convert.hpp b/include/ck_tile/core/numeric/type_convert.hpp index f65b2a0e5e3..deaa9e0bd90 100644 --- a/include/ck_tile/core/numeric/type_convert.hpp +++ b/include/ck_tile/core/numeric/type_convert.hpp @@ -69,12 +69,6 @@ CK_TILE_TYPE_CONVERT(fp16x2_t, fp16x2, fp32x2_t, fp32x2) CK_TILE_TYPE_CONVERT(bf16x2_t, bf16x2, fp32x2_t, fp32x2) #undef CK_TILE_TYPE_CONVERT -template <> -CK_TILE_HOST_DEVICE constexpr bf16_t type_convert(bf8_t x) -{ - return float_to_bf16(bf8_to_float(x)); -} - } // namespace ck_tile #include "ck_tile/core/numeric/pk_fp4.hpp" diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 65a92811662..51bc35efe95 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -239,7 +239,7 @@ struct BQuantBlockUniversalGemmAsBsCr bool BLoadTranspose = false> CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, const BSmemBlockWindow& b_block_window, - BQRegBlockTile& bq_block_tensor, + const BQRegBlockTile& bq_block_tensor, bool_constant = {}, bool_constant = {}) { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp index 03381f569ab..de92d45763b 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp @@ -27,9 +27,12 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; - using BDqDataType = std::conditional_t, - remove_cvref_t, - BDataType>; + using BDqDataType = remove_cvref_t; + + static constexpr bool IsCastBeforeLDS = Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + + using BLDSType = std::conditional_t; + using BQDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; @@ -43,17 +46,16 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr index_t APackedSize = ck_tile::numeric_traits>::PackedSize; - static constexpr index_t BDqPackedSize = - ck_tile::numeric_traits>::PackedSize; 
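+    // BLDSType is what is actually staged in LDS for B: the dequantized BDqDataType when
+    // casting before the LDS write, the raw quantized BDataType otherwise. As a rough
+    // sketch of the trade-off (assuming BDataType = pk_fp4_t with PackedSize 2 and
+    // BDqDataType = bf16_t): AfterLDSRead stores one byte per two B values but must
+    // dequantize on every LDS read, while BeforeLDSWrite stores two bytes per value and
+    // reads back MMA-ready data. BLDSPackedSize below accounts for this in the LDS sizing.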
static constexpr index_t BPackedSize = - std::is_same_v - ? 2 - : ck_tile::numeric_traits>::PackedSize; + ck_tile::numeric_traits>::PackedSize; static constexpr index_t BQPackedSize = ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BLDSPackedSize = + ck_tile::numeric_traits>::PackedSize; + using ALayout = remove_cvref_t; using BQLayout = remove_cvref_t; using BLayout = remove_cvref_t; @@ -90,8 +92,6 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 static constexpr auto TailNum = Problem::TailNum; static constexpr auto Scheduler = Problem::Scheduler; - static constexpr bool IsCastBeforeLDS = Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; - using Base::PrefetchStages; [[nodiscard]] CK_TILE_HOST static const std::string GetName() @@ -175,6 +175,11 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { using Base = PipelineImplBase; + static constexpr bool is_a_col_major = + std::is_same_v; + static constexpr bool is_b_row_major = + std::is_same_v; + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() { constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; @@ -217,7 +222,7 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num : A_LDS_Read_Inst_Num / 2; constexpr auto num_ds_read_inst_b = - B_LDS_Read_Width * sizeof(BDqDataType) / BDqPackedSize == 16 + B_LDS_Read_Width * sizeof(BLDSType) / BLDSPackedSize == 16 ? B_LDS_Read_Inst_Num : B_LDS_Read_Inst_Num / 2; @@ -233,7 +238,7 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 constexpr auto ds_read_a_issue_cycle = A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4; constexpr auto ds_read_b_issue_cycle = - B_LDS_Read_Width * sizeof(BDqDataType) / BDqPackedSize == 16 ? 8 : 4; + B_LDS_Read_Width * sizeof(BLDSType) / BLDSPackedSize == 16 ? 
8 : 4; constexpr auto ds_read_a_mfma_rate = (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle); constexpr auto ds_read_b_mfma_rate = @@ -316,6 +321,139 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 }); } + template + CK_TILE_DEVICE static void + ScaleTile(TileType& block_tile, CastTileType& block_tile_cast, ScaleTileType& scale_tile) + { + if constexpr(IsCastBeforeLDS) + { + constexpr auto b_block = TileType::get_distributed_spans(); + constexpr auto idx1_js = tile_distributed_index<0>{}; + + // Internally this is using V_CVT_SCALEF32_PK_BF16_FP4 or V_CVT_SCALEF32_PK_FP16_FP4 + // on gfx950 + auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) { + if constexpr(std::is_same_v) + { + return pk_fp4_to_fp16x2(pk_mxfp4, fscale); + } + else if constexpr(std::is_same_v) + { + return pk_fp4_to_bf16x2(pk_mxfp4, fscale); + } + else + { + static_assert(false, "unsupported compute type"); + } + }; + + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + auto scale = scale_tile(i_j_idx_scale); + auto b_scale_uint = uint32_t(scale.data) << 23; + if constexpr(std::is_same_v) + { + if constexpr(idx1.impl_.at(0) % BPackedSize == 0) + { + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = + tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + auto b_pack = block_tile(i_j_idx); + auto cvt = + pk_mxfp4_to_compute_v2(b_pack, bit_cast(b_scale_uint)); + block_tile_cast(i_j_idx_lo) = cvt.x; + block_tile_cast(i_j_idx_hi) = cvt.y; + } + } + else + { + auto b_pack = block_tile(i_j_idx); + block_tile_cast(i_j_idx) = type_convert( + type_convert(b_pack) * bit_cast(b_scale_uint)); + } + }); + }); + } + } + + template + CK_TILE_DEVICE void ALocalPrefill(WindowType& lds_window, + const TileType& block_tile, + const ElementwiseFunc& element_func) const + { + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, block_tile); + Base::LocalPrefill(lds_window, a_shuffle_tmp, element_func); + } + else + { + Base::LocalPrefill(lds_window, block_tile, element_func); + } + } + + template + CK_TILE_DEVICE void BLocalPrefill(WindowType& lds_window, + const TileType& block_tile, + const TileTypeCast& block_tile_cast, + const ElementwiseFunc& element_func) const + { + // Fill LDS and apply the scale if IsCastBeforeLDS + auto get_b_block_tile = [](auto& b_block_tile_orig, auto& b_block_tile_cast) { + if constexpr(IsCastBeforeLDS) + { + return b_block_tile_cast; + } + else + { + return b_block_tile_orig; + } + }; + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, get_b_block_tile(block_tile, block_tile_cast)); + Base::LocalPrefill(lds_window, b_shuffle_tmp, element_func); + } + else + { + Base::LocalPrefill( + lds_window, get_b_block_tile(block_tile, block_tile_cast), element_func); + } + } + + template + CK_TILE_DEVICE void LocalPrefetch(BlockGemmType& block_gemm, + const AWindowType& a_lds_window, + const BWindowType& b_lds_window, + const QTileType& q_block_tile) const + { + // Load from LDS + // It can apply the scale and cast if we 
scale after reading from LDS + if constexpr(IsCastBeforeLDS) + { + block_gemm.LocalPrefetch(a_lds_window, b_lds_window); + } + else + { + block_gemm.LocalPrefetch(a_lds_window, b_lds_window, q_block_tile); + } + } + template index_t num_loop, void* p_smem) const { + // ----------------------------------------------------------------------------------------- + // Pipeline checks static_assert( std::is_same_v> && std::is_same_v "A/B/BQ Dram block window should have the same data type as appropriate " "([A|B|BQ]DataType) defined in Problem definition!"); - constexpr bool is_a_col_major = - std::is_same_v; constexpr bool is_bq_col_major = std::is_same_v; - constexpr bool is_b_row_major = std::is_same_v; static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && @@ -393,6 +530,11 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){}; + // This defines the scaled and casted block tile for B matrix. + // Effectively, it is used only if we scale and cast before writing to LDS. + auto bdq_block_tile = make_static_distributed_tensor( + Policy::template MakeBRegTileDistribution()); + // Block GEMM auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); @@ -405,7 +547,7 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 decltype(make_static_distributed_tensor(BBlockTileDistr{})); ABlockTile a_block_tile; - BBlockTile b_fp4_block_tile; + BBlockTile b_block_tile; using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; @@ -419,137 +561,44 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 // ----------------------------------------------------------------------------------------- // Gemm pipeline start - // prefetch - // global read 0 - // auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){}; + // prefetch stages + + // Vmem -> Vgpr 0 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); - // BDataType - auto b_block_tile = make_static_distributed_tensor( - Policy::template MakeBRegTileDistribution()); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // Vmem -> Vgpr 0 (Q matrix) + // Scale and cast tile before writing to LDS (if IsCastBeforeLDS) bq_block_tile = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); - constexpr auto idx1_js = tile_distributed_index<0>{}; - constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans(); - - // Internally this is using V_CVT_SCALEF32_PK_BF16_FP4 or V_CVT_SCALEF32_PK_FP16_FP4 on - // gfx950 - auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) { - if constexpr(std::is_same_v) - { - return pk_fp4_to_fp16x2(pk_mxfp4, fscale); - } - else if constexpr(std::is_same_v) - { - return pk_fp4_to_bf16x2(pk_mxfp4, fscale); - } - else - { - static_assert(sizeof(pk_mxfp4) == 0, "unsupported compute type"); - } - }; - - auto get_b_block_tile = [](auto& b_block_tile_orig, auto& b_block_tile_cast) { - if constexpr(IsCastBeforeLDS) - { - return b_block_tile_cast; - } - else - { - return b_block_tile_orig; - } - }; - - 
auto apply_scale_func = [&]() { - sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - auto scale = bq_block_tile(i_j_idx_scale); - auto b_scale_uint = uint32_t(scale.data) << 23; - if constexpr(std::is_same_v) - { - if constexpr(idx1.impl_.at(0) % BPackedSize == 0) - { - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = - tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - auto b_pack = b_fp4_block_tile(i_j_idx); - - auto cvt = - pk_mxfp4_to_compute_v2(b_pack, bit_cast(b_scale_uint)); - b_block_tile(i_j_idx_lo) = cvt.x; - b_block_tile(i_j_idx_hi) = cvt.y; - } - } - else - { - auto b_pack = b_fp4_block_tile(i_j_idx); - b_block_tile(i_j_idx) = type_convert( - type_convert(b_pack) * bit_cast(b_scale_uint)); - } - }); - }); - }; - - if constexpr(IsCastBeforeLDS) - apply_scale_func(); - - // initialize C + // initialize C tile to zero tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); block_sync_lds(); - // LDS write 0 - if constexpr(is_a_col_major) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); - transpose_tile2d(b_shuffle_tmp, b_block_tile_); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func); - } + // Vgpr -> LDS 0 + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); + // Vmem -> Vgpr 1 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // If we scale and cast before writing to LDS, + // we need to read another tile of Q matrix from Vmem, then scale and cast tile if constexpr(IsCastBeforeLDS) { bq_block_tile = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + } + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); - apply_scale_func(); - - block_sync_lds(); + block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); - } - else - { - block_sync_lds(); + // LDS -> Vgpr 0 + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); - } __builtin_amdgcn_sched_barrier(0); // main body @@ -560,58 +609,34 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { block_sync_lds(); - if constexpr(is_a_col_major) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - 
transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); - transpose_tile2d(b_shuffle_tmp, b_block_tile_); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func); - } + // Vgpr -> LDS + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); + // Vmem -> Vgpr Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // Vmem -> Vgpr (Q matrix) + // Scale and cast tile before writing to LDS (if IsCastBeforeLDS) bq_block_tile = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); - if constexpr(IsCastBeforeLDS) - apply_scale_func(); - + // Consume tile block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - if constexpr(IsCastBeforeLDS) - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); - else - block_gemm.LocalPrefetch( - a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); + // LDS -> Vgpr + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); i += 1; - // b_block_stride +=1; } while(i < (num_loop - 1)); } - // tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile); + // tail if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) { @@ -621,50 +646,31 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 } else { + // If we scale and cast after reading from LDS, + // we didn't read the second tile of Q matrix from Vmem during prefetch stages, + // so we need to read the last tile here. + // This is not a problem because we have all block_gemm instructions to hide the + // latency. 
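+                // In the BeforeLDSWrite path this extra load is skipped: every scale tile
+                // was already folded into B by ScaleTile right after its global prefetch,
+                // so the bq_block_tile argument passed to LocalPrefetch below is ignored.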
if constexpr(!IsCastBeforeLDS) { bq_block_tile = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); } + // Consume second to last tile block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - if constexpr(is_a_col_major) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); - transpose_tile2d(b_shuffle_tmp, b_block_tile_); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - auto b_block_tile_ = get_b_block_tile(b_fp4_block_tile, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_block_tile_, b_element_func); - } + // Vgpr -> LDS last tile + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); block_sync_lds(); - if constexpr(IsCastBeforeLDS) - { - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); - } - else - { - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); - } + // LDS -> Vgpr last tile + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); + + // Consume last tile block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); } @@ -690,13 +696,12 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 void* p_smem, index_t n = 0) const { - using BElementwise = std::conditional_t; - ck_tile::ignore = n; + ck_tile::ignore = n; return PipelineImpl{}.template operator()( a_dram_block_window_tmp, [](const ADataType& a) { return a; }, b_dram_block_window_tmp, - [](const BElementwise& b) { return b; }, + [](const BLDSType& b) { return b; }, bq_dram_block_window_tmp, num_loop, p_smem); From b4304b7388bcfb1c87ede5fbde3a40cbd6e49111 Mon Sep 17 00:00:00 2001 From: Enrico Degregori Date: Fri, 30 Jan 2026 11:27:27 +0000 Subject: [PATCH 12/15] Finalize cleanup --- .../run_gemm_quant_example.inc | 14 ++-- .../ck_tile/host/reference/reference_gemm.hpp | 14 ++-- .../block_universal_gemm_as_bs_bquant_cr.hpp | 75 ++++++++++++------- .../gemm_mx_pipeline_ag_bg_cr_policy.hpp | 4 + .../test_gemm_quant_fixtures.hpp | 14 ++-- 5 files changed, 73 insertions(+), 48 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index fd3102fe5f6..086894d4f2d 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -1017,13 +1017,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { if constexpr(std::is_same_v) - ck_tile::reference_mxfp4gemm_quant( + ck_tile::reference_mx_gemm_bquant( a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); else ck_tile::reference_gemm_quant -CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor& a_m_k, - const HostTensor& q, - const HostTensor& b_k_n, - HostTensor& c_m_n, - const AElementOp& a_element_op = 
From b4304b7388bcfb1c87ede5fbde3a40cbd6e49111 Mon Sep 17 00:00:00 2001
From: Enrico Degregori
Date: Fri, 30 Jan 2026 11:27:27 +0000
Subject: [PATCH 12/15] Finalize cleanup

---
 .../run_gemm_quant_example.inc                     | 14 ++--
 .../ck_tile/host/reference/reference_gemm.hpp      | 14 ++--
 .../block_universal_gemm_as_bs_bquant_cr.hpp       | 75 ++++++++++++-------
 .../gemm_mx_pipeline_ag_bg_cr_policy.hpp           |  4 +
 .../test_gemm_quant_fixtures.hpp                   | 14 ++--
 5 files changed, 73 insertions(+), 48 deletions(-)

diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
index fd3102fe5f6..086894d4f2d 100644
--- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
+++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
@@ -1017,13 +1017,13 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
         else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
         {
             if constexpr(std::is_same_v)
-                ck_tile::reference_mxfp4gemm_quant(
+                ck_tile::reference_mx_gemm_bquant(
                     a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
             else
                 ck_tile::reference_gemm_quant
diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp
--- a/include/ck_tile/host/reference/reference_gemm.hpp
+++ b/include/ck_tile/host/reference/reference_gemm.hpp
-CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor& a_m_k,
-                                            const HostTensor& q,
-                                            const HostTensor& b_k_n,
-                                            HostTensor& c_m_n,
-                                            const AElementOp& a_element_op     = {},
-                                            const BElementOp& b_element_op     = {},
-                                            const ACCElementOp& acc_element_op = {})
+CK_TILE_HOST void reference_mx_gemm_bquant(const HostTensor& a_m_k,
+                                           const HostTensor& q,
+                                           const HostTensor& b_k_n,
+                                           HostTensor& c_m_n,
+                                           const AElementOp& a_element_op     = {},
+                                           const BElementOp& b_element_op     = {},
+                                           const ACCElementOp& acc_element_op = {})
 {
     const std::size_t M = a_m_k.get_length(0);
     const std::size_t N = b_k_n.get_length(1);
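The renamed reference_mx_gemm_bquant computes a GEMM in which every B element is dequantized by the scale of its quantization group before the multiply-accumulate. A self-contained sketch of that math follows; plain float storage and row-major layout are assumptions here, while the real routine is templated on the ck_tile host tensor and element-op types:

    #include <cstddef>
    #include <vector>

    // c[m][n] = sum_k a[m][k] * (q[k/GroupK][n/GroupN] * b[k][n])
    // q holds one scale per GroupK x GroupN block of B (BQuantGrouped).
    void reference_bquant_gemm(const std::vector<float>& a, // M x K
                               const std::vector<float>& b, // K x N
                               const std::vector<float>& q, // (K/GroupK) x (N/GroupN)
                               std::vector<float>& c,       // M x N
                               std::size_t M, std::size_t N, std::size_t K,
                               std::size_t GroupK, std::size_t GroupN)
    {
        const std::size_t QN = N / GroupN;
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(std::size_t k = 0; k < K; ++k)
                    acc += a[m * K + k] * q[(k / GroupK) * QN + n / GroupN] * b[k * N + n];
                c[m * N + n] = acc;
            }
    }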
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
index 51bc35efe95..277a2496142 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
@@ -216,6 +216,7 @@ struct BQuantBlockUniversalGemmAsBsCr
     using BTypeTile = decltype(make_static_distributed_tensor(BLdsTileDistr));
     BTypeTile b_warp_tile_lds_;

+    // Load from LDS (assumption is that the scale will be applied in the block gemm)
     template = {}, bool_constant = {})
     {
-        load_int4_tile(
-            a_warp_tile_, a_block_window);
-        load_int4_tile(
-            b_warp_tile_lds_, b_block_window);
-
-        // Apply scale
-        using BDataTypeRaw = typename std::
-            conditional, pk_fp4_t::type, BDataType>::type;
-
-        constexpr auto warp_size = get_warp_size();
+        // Load the tile from LDS.
+        // Do not use load_int4_tile here: it also casts from fp4 to the compute type,
+        // whereas here we only want to load from LDS, then apply the scale and cast later.
+        if constexpr(ALoadTranspose)
+        {
+            a_warp_tile_ = load_tile_transpose(a_block_window);
+        }
+        else
+        {
+            load_tile(a_warp_tile_, a_block_window);
+        }
+
+        if constexpr(BLoadTranspose)
+        {
+            b_warp_tile_lds_ = load_tile_transpose(b_block_window);
+        }
+        else
+        {
+            load_tile(b_warp_tile_lds_, b_block_window);
+        }
+
+        // Apply scale and cast
+        using BDataTypeRaw =
+            std::conditional_t, pk_fp4_t::type, BDataType>;
+
+        constexpr index_t warp_size = get_warp_size();
         constexpr index_t nelements = WarpGemm::kK * WarpGemm::kN / warp_size;
         constexpr index_t thread_buffer_size = nelements / UnaryOpSize_;
         const element_wise::DequantPack8 elementwise_op{};
@@ -262,6 +282,22 @@ struct BQuantBlockUniversalGemmAsBsCr
         static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
             static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
+                // B scale register offset
+                constexpr index_t reg_offset = [&]() {
+                    if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN))
+                        return (nIter * NWarp * WarpGemm::kN) / GemmTraits::QuantGroupSize::kN *
+                                   Traits::KQPerBlock +
+                               kQScale;
+                    else
+                    {
+                        return nIter * Traits::KQPerBlock + kQScale;
+                    }
+                }();
+
+                // Get B scale from thread buffer
+                auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
+                float b_scale_f = float(scale_reg);
+
                 static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) {
                     constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale;
                     // Thread buffers
@@ -275,27 +311,12 @@ struct BQuantBlockUniversalGemmAsBsCr
                     BWarpThreadBuffer b_warp_thread_buffer;
                     BLDSThreadBuffer b_lds_thread_buffer;

-                    // BQuant register offset
-                    constexpr index_t reg_offset = [&]() {
-                        if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN))
-                            return (nIter * NWarp * WarpGemm::kN) /
-                                       GemmTraits::QuantGroupSize::kN * Traits::KQPerBlock +
-                                   kQScale;
-                        else
-                        {
-                            return nIter * Traits::KQPerBlock + kQScale;
-                        }
-                    }();
-
                     // Load thread buffer from tile (LDS type)
                     b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data(
                         merge_sequences(sequence{}, b_warp_y_index_zeros),
                         merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
-                    // Apply scale to thread buffer and cast
-                    auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset];
-                    float b_scale_f = float(scale_reg);
-
+                    // Apply scale to B thread buffer and cast
                     static_for<0, thread_buffer_size, 1>{}([&](auto i) {
                         elementwise_op(
                             b_warp_thread_buffer.template get_as()(i),
                             b_scale_f);
                     });

-                    // Store thread buffer to tile (MMA type)
+                    // Store B thread buffer to tile (MMA type)
                     b_warp_tile_.set_y_sliced_thread_data(
                         merge_sequences(sequence{}, b_warp_y_index_zeros),
                         merge_sequences(sequence<1, 1>{}, b_warp_y_lengths),
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp
index 64fff27fa1b..d77a2d1da64 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp
@@ -70,6 +70,10 @@ struct GemmMxPipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
     template
     CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution()
     {
+        // If we apply the scale before writing to LDS, we need a tile distribution for
+        // BQuant consistent with the global-memory read pattern of matrix B, while
+        // if we apply the scale after reading from LDS, we need a tile distribution for
+        // BQuant consistent with the MMA instruction layout.
         if constexpr(Problem::BCastPolicy == CastPolicy::AfterLDSRead)
         {
             using BQLayout = remove_cvref_t;
diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
index 8e8ea1dc6eb..bede7f11a1d 100644
--- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
+++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
@@ -781,13 +781,13 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase)
-            ck_tile::reference_mxfp4gemm_quant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
+            ck_tile::reference_mx_gemm_bquant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref);
         else
             ck_tile::reference_gemm_quant

From 6e26d0ba1724742c89e57e2c21b4dc7b86393ab0 Mon Sep 17 00:00:00 2001
From: Enrico Degregori
Date: Fri, 30 Jan 2026 13:12:37 +0000
Subject: [PATCH 13/15] Fix rebase issues

---
 .../gemm_bquant_quantgrouped_mx_bf16bf16.cpp             | 8 ++++----
 .../gemm_bquant_quantgrouped_mx_bf16bf8.cpp              | 8 ++++----
 example/ck_tile/38_block_scale_gemm/gemm_quant.cpp       | 1 +
 .../block/block_universal_gemm_as_bs_bquant_cr.hpp       | 6 +++---
 .../gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp | 6 +++---
 5 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp
index 77b3be86848..e1a64c86569 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp
@@ -12,9 +12,8 @@ using GemmConfig = GemmConfigQuantPrefill;
     QuantGroupSize, \
     ck_tile::QuantType::BQuantGrouped>(arg_parser);

-void bquant_quantgrouped_mx_bf16bf16_instance_factory(
-    std::unordered_map>& lut)
-{
+static auto _ = []() {
+    auto& lut = get_kernel_lut();
     using TypeConfig = decltype(GemmQuantTypeConfig>;
                 return RUN_GEMM_EXAMPLE_PREC_TYPE;
             };
-}
+    return 0;
+}();
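The conversion above from a named instance-factory function to a `static auto _ = []() { ... }();` initializer uses the common self-registration idiom: each translation unit inserts its launch lambdas into a shared lookup table during static initialization, so no central function has to enumerate every instance. A standalone sketch of the idiom; the map type, the key scheme, and get_kernel_lut are assumptions modeled on the surrounding code:

    #include <functional>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    using KernelFn = std::function<int(int /*arg*/)>;

    // One shared registry; a function-local static avoids any dependence on the
    // initialization order across translation units.
    std::unordered_map<std::string, KernelFn>& get_kernel_lut()
    {
        static std::unordered_map<std::string, KernelFn> lut;
        return lut;
    }

    // Each .cpp registers its kernels before main() runs; the lambda returns a
    // dummy int only so it can initialize the static variable.
    static auto _ = []() {
        get_kernel_lut()["bf16mxbf8/bquant/1x1x32"] = [](int arg) {
            std::cout << "launching kernel with arg " << arg << "\n";
            return 0;
        };
        return 0;
    }();

    int main() { return get_kernel_lut().at("bf16mxbf8/bquant/1x1x32")(42); }

Because `_` has internal linkage, every example file can use the same variable name without collisions.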
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp
index 9ac00b62ddd..0eb2a0ce349 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp
@@ -11,9 +11,8 @@ using GemmConfig = GemmConfigMixedPrecision;
     QuantGroupSize, \
     ck_tile::QuantType::BQuantGrouped>(arg_parser);

-void bquant_quantgrouped_mx_bf16bf8_instance_factory(
-    std::unordered_map>& lut)
-{
+static auto _ = []() {
+    auto& lut = get_kernel_lut();
     using TypeConfig = decltype(GemmQuantTypeConfig>;
                 return RUN_GEMM_EXAMPLE_PREC_TYPE;
             };
-}
+    return 0;
+}();
diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp
index d21c94ff891..dc4d1ad8147 100644
--- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp
+++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp
@@ -123,4 +123,5 @@ int main(int argc, char* argv[])
                   "group_size not supported."
                << std::endl;
         return -1;
+    }
 }
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
index 277a2496142..5545112a4f3 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
@@ -284,9 +284,9 @@ struct BQuantBlockUniversalGemmAsBsCr
             static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) {
                 // B scale register offset
                 constexpr index_t reg_offset = [&]() {
-                    if constexpr(GemmTraits::QuantGroupSize::kN >= (NWarp * WarpGemm::kN))
-                        return (nIter * NWarp * WarpGemm::kN) / GemmTraits::QuantGroupSize::kN *
-                                   Traits::KQPerBlock +
+                    if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN))
+                        return (nIter * NWarp * WarpGemm::kN) /
+                                   GemmTraits::BQuantGroupSize::kN * Traits::KQPerBlock +
                                kQScale;
                     else
                     {
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp
index de92d45763b..589ec44ab6d 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp
@@ -33,9 +33,9 @@ struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
     using BLDSType = std::conditional_t;

-    using BQDataType = remove_cvref_t;
-    using CDataType = remove_cvref_t;
-    using BlockGemmShape = remove_cvref_t;
+    using BQDataType      = remove_cvref_t;
+    using CDataType       = remove_cvref_t;
+    using BlockGemmShape  = remove_cvref_t;
     using BQuantGroupSize = remove_cvref_t;

     static_assert(BQuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
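The reg_offset lambda renamed above picks which scale register in the per-thread buffer belongs to a given (nIter, kQScale) pair; the branch distinguishes quantization groups that span several N iterations from groups that fit within one. A constexpr sketch of that indexing, with plain ints standing in for the GemmTraits and WarpGemm constants:

    #include <cstdio>

    // Offset of the B scale in the per-thread buffer for warp iteration nIter and
    // scale index kQScale along K. kq_per_block = scales per K within one block row.
    constexpr int bscale_reg_offset(int nIter, int kQScale,
                                    int quant_group_n, int nwarp_times_warp_n,
                                    int kq_per_block)
    {
        if(quant_group_n >= nwarp_times_warp_n)
            // One scale covers several N iterations: collapse nIter accordingly.
            return (nIter * nwarp_times_warp_n) / quant_group_n * kq_per_block + kQScale;
        // At least one scale per N iteration: stride directly by kq_per_block.
        return nIter * kq_per_block + kQScale;
    }

    int main()
    {
        // e.g. QuantGroupSize::kN = 32, NWarp * WarpGemm::kN = 32, KQPerBlock = 2:
        std::printf("%d\n", bscale_reg_offset(1, 0, 32, 32, 2)); // -> 2
        // with a wider group (kN = 64), two N iterations share one scale column:
        std::printf("%d\n", bscale_reg_offset(1, 1, 64, 32, 2)); // -> 1
    }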
From 95c082dd3506184f006456f4ad88c08b9adb0172 Mon Sep 17 00:00:00 2001
From: Enrico Degregori
Date: Fri, 30 Jan 2026 16:27:06 +0000
Subject: [PATCH 14/15] Fix usage of Problem static member BCastPolicy

---
 example/ck_tile/38_block_scale_gemm/README.md  |  2 +-
 include/ck_tile/core/utility/type_traits.hpp   | 11 ++++++
 .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp   |  7 ++--
 ...emm_universal_pipeline_ag_bg_cr_policy.hpp  | 36 +++++++++++++------
 4 files changed, 41 insertions(+), 15 deletions(-)

diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md
index 8b92e10cb2f..accac6f0838 100644
--- a/example/ck_tile/38_block_scale_gemm/README.md
+++ b/example/ck_tile/38_block_scale_gemm/README.md
@@ -53,7 +53,7 @@ args:
 -stride_b    Tensor B stride (default:0)
 -stride_c    Tensor C stride (default:0)
 -v           0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
--prec        Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, mxbf16bf16, mxbf16bf8 or bf16fp4 (default for both AQuant and Bquant: fp8)
+-prec        Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, mxbf16bf16, mxbf16bf8 or mxbf16fp4 (default for both AQuant and Bquant: fp8)
 -warmup      Number of iterations before benchmarking the kernel (default:50)
 -repeat      Number of iterations to benchmark the kernel (default:1000)
 -timer       gpu:gpu timer, cpu:cpu timer (default:gpu)
diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp
index c11d180839b..143a970e01c 100644
--- a/include/ck_tile/core/utility/type_traits.hpp
+++ b/include/ck_tile/core/utility/type_traits.hpp
@@ -98,6 +98,17 @@ using is_known_at_compile_time = is_static;
 // , this helper will also return false, which is not good(?)
 // do we need something like is_constexpr()?

+#define DEFINE_STATIC_MEMBER_CHECKER(trait_name, member) \
+    template \
+    struct trait_name : std::false_type \
+    { \
+    }; \
+    \
+    template \
+    struct trait_name> : std::true_type \
+    { \
+    };
+
 // FIXME: do we need this anymore?
 template <
     typename PY,
diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
index 8b1456631e7..250fe3af385 100644
--- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
+++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
@@ -20,8 +20,9 @@ struct GemmPipelineAgBgCrImplBase
     using ADataType = remove_cvref_t{}, AsDataType>>;
     using ALayout = remove_cvref_t{}, AsLayout>>;
     using BInDataType = remove_cvref_t{}, BsDataType>>;
-    using BDataType = std::
-        conditional_t;
+
+    static constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v;
+    using BDataType = std::conditional_t;
     using BLayout = remove_cvref_t{}, BsLayout>>;
@@ -314,7 +315,7 @@ struct GemmPipelineAgBgCrImplBase
         auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});

-        using BLdsDataType = std::conditional_t;
+        using BLdsDataType = std::conditional_t;
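DEFINE_STATIC_MEMBER_CHECKER is a standard void_t member-detection trait: the primary template is false_type, and a partial specialization that is viable only when T::member names a valid expression is true_type. Part of the macro body is garbled in this rendering of the patch, so the following standalone sketch is a reconstruction under that assumption, not a verbatim copy:

    #include <type_traits>

    // Likely expansion of the checker macro: detection via std::void_t.
    #define DEFINE_STATIC_MEMBER_CHECKER(trait_name, member)                    \
        template <typename, typename = void>                                    \
        struct trait_name : std::false_type                                     \
        {                                                                       \
        };                                                                      \
                                                                                \
        template <typename T>                                                   \
        struct trait_name<T, std::void_t<decltype(T::member)>> : std::true_type \
        {                                                                       \
        };

    DEFINE_STATIC_MEMBER_CHECKER(has_bcastpolicy, BCastPolicy);

    enum class CastPolicy { BeforeLDSWrite, AfterLDSRead };

    struct ProblemWithPolicy    { static constexpr CastPolicy BCastPolicy = CastPolicy::BeforeLDSWrite; };
    struct ProblemWithoutPolicy { };

    static_assert(has_bcastpolicy<ProblemWithPolicy>::value);
    static_assert(!has_bcastpolicy<ProblemWithoutPolicy>::value);

    int main() { return 0; }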
diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
index 4177e556717..db00f87fd57 100644
--- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
@@ -11,6 +11,16 @@ namespace ck_tile {

+DEFINE_STATIC_MEMBER_CHECKER(has_bcastpolicy, BCastPolicy);
+
+template
+static constexpr bool IsBCastPolicyBeforeLDSWrite_v = [] {
+    if constexpr(has_bcastpolicy::value)
+        return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
+    else
+        return false;
+}();
+
 template
 struct has_a_tile_access_pattern : std::false_type
 {
@@ -305,10 +315,11 @@ struct UniversalGemmBasePolicy
     template
     CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
     {
-        using BLayout = remove_cvref_t;
-        using BDataType = std::conditional_t;
+        using BLayout = remove_cvref_t;
+        constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v;
+        using BDataType = std::conditional_t;

         constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
@@ -592,9 +603,10 @@ struct UniversalGemmBasePolicy
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
         using BLayout = remove_cvref_t{}, BsLayout>>;
-        using BDataType = std::conditional_t;
+        constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v;
+        using BDataType = std::conditional_t;

         if constexpr(Problem::FixedVectorSize)
         {
@@ -739,8 +751,9 @@ struct UniversalGemmBasePolicy
         constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
         // If we cast before writing to LDS, the vector size is defined by the A type,
         // since the assumption is that the A type is going to be the B LDS type.
+        constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v;
         constexpr index_t VecLoadSize =
-            Problem::BCastPolicy == CastPolicy::BeforeLDSWrite
+            IsBCastPolicyBeforeLDSWrite
                 ? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA())
                 : (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB());
         constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
@@ -851,9 +864,10 @@ struct UniversalGemmBasePolicy
     template
     CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
     {
-        using BDataType = std::conditional_t;
+        constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v;
+        using BDataType = std::conditional_t;
         constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor();
         constexpr index_t smem_size_b = integer_least_multiple(
             b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16);
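GetSmemSizeB rounds the B tile's byte footprint up to a multiple of 16 so that each LDS buffer starts on a 16-byte boundary; integer_least_multiple is a plain round-up. A sketch of the arithmetic, with the helper reimplemented locally for illustration:

    #include <cstdio>

    // Round x up to the nearest multiple of m, mirroring ck_tile's
    // integer_least_multiple semantics.
    constexpr long round_up(long x, long m) { return (x + m - 1) / m * m; }

    int main()
    {
        // A 128 x 65 tile of 2-byte elements occupies 16640 bytes, already aligned;
        // an odd byte count gets padded up to the next 16-byte boundary.
        std::printf("%ld\n", round_up(128L * 65 * 2, 16)); // -> 16640
        std::printf("%ld\n", round_up(8321, 16));          // -> 8336
    }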
From b2c836c726b5b5e01903fdc78317967e379730ab Mon Sep 17 00:00:00 2001
From: Enrico Degregori
Date: Fri, 30 Jan 2026 18:08:51 +0000
Subject: [PATCH 15/15] Move static constexpr to structs

---
 include/ck_tile/core/utility/type_traits.hpp   |  2 ++
 .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp   |  8 +++++++-
 ...gemm_universal_pipeline_ag_bg_cr_policy.hpp | 18 ++++++++----------
 3 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/include/ck_tile/core/utility/type_traits.hpp b/include/ck_tile/core/utility/type_traits.hpp
index 143a970e01c..9ef3a675a71 100644
--- a/include/ck_tile/core/utility/type_traits.hpp
+++ b/include/ck_tile/core/utility/type_traits.hpp
@@ -109,6 +109,8 @@ using is_known_at_compile_time = is_static;
     { \
     };

+DEFINE_STATIC_MEMBER_CHECKER(has_bcastpolicy, BCastPolicy);
+
 // FIXME: do we need this anymore?
 template <
     typename PY,
diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
index 250fe3af385..9f6edbad261 100644
--- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
+++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
@@ -21,7 +21,13 @@ struct GemmPipelineAgBgCrImplBase
     using ALayout = remove_cvref_t{}, AsLayout>>;
     using BInDataType = remove_cvref_t{}, BsDataType>>;

-    static constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v;
+    static constexpr bool IsBCastPolicyBeforeLDSWrite = [] {
+        if constexpr(has_bcastpolicy::value)
+            return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
+        else
+            return false;
+    }();
+
     using BDataType = std::conditional_t;
     using BLayout = remove_cvref_t{}, BsLayout>>;
diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
index db00f87fd57..a6c02232a8e 100644
--- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
@@ -11,16 +11,6 @@ namespace ck_tile {

-DEFINE_STATIC_MEMBER_CHECKER(has_bcastpolicy, BCastPolicy);
-
-template
-static constexpr bool IsBCastPolicyBeforeLDSWrite_v = [] {
-    if constexpr(has_bcastpolicy::value)
-        return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
-    else
-        return false;
-}();
-
 template
 struct has_a_tile_access_pattern : std::false_type
 {
@@ -90,6 +80,14 @@ struct UniversalGemmBasePolicy
     static constexpr bool is_b_load_tr = false;
 #endif

+    template
+    static constexpr bool IsBCastPolicyBeforeLDSWrite_v = [] {
+        if constexpr(has_bcastpolicy::value)
+            return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
+        else
+            return false;
+    }();
+
     static constexpr auto I0 = number<0>{};
     static constexpr auto I1 = number<1>{};
     static constexpr auto I2 = number<2>{};
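With this final patch the guard lives as a member variable template of UniversalGemmBasePolicy, so derived policies can query it without a namespace-scope helper, and Problem types that lack a BCastPolicy member quietly default to false. A compact standalone model of that final arrangement; CastPolicy, the toy Problem structs, and the BDataType alias below are illustrative stand-ins, not the ck_tile definitions:

    #include <type_traits>

    enum class CastPolicy { BeforeLDSWrite, AfterLDSRead };

    template <typename, typename = void>
    struct has_bcastpolicy : std::false_type {};
    template <typename T>
    struct has_bcastpolicy<T, std::void_t<decltype(T::BCastPolicy)>> : std::true_type {};

    struct Policy
    {
        // Member variable template: false for Problems without a BCastPolicy member.
        template <typename Problem>
        static constexpr bool IsBCastPolicyBeforeLDSWrite_v = [] {
            if constexpr(has_bcastpolicy<Problem>::value)
                return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
            else
                return false;
        }();

        // If B is cast before the LDS write, the LDS-facing element type differs
        // from the incoming B type (here just modeled with two type parameters).
        template <typename Problem, typename BIn, typename BLds>
        using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite_v<Problem>, BLds, BIn>;
    };

    struct QuantProblem { static constexpr CastPolicy BCastPolicy = CastPolicy::BeforeLDSWrite; };
    struct PlainProblem { };

    static_assert(Policy::IsBCastPolicyBeforeLDSWrite_v<QuantProblem>);
    static_assert(!Policy::IsBCastPolicyBeforeLDSWrite_v<PlainProblem>);
    static_assert(std::is_same_v<Policy::BDataType<QuantProblem, int, short>, short>);

    int main() { return 0; }

Scoping the variable template inside the policy struct keeps the ck_tile namespace clean and lets derived policies inherit the check instead of re-deriving it.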