diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 13cbcc8b558..305cbf907b0 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -18,7 +18,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp gemm_bquant_quantgrouped_fp8i4.cpp - gemm_bquant_quantgrouped_bf16mxfp4.cpp + gemm_bquant_quantgrouped_mx_bf16fp4.cpp + gemm_bquant_quantgrouped_mx_bf16bf8.cpp + gemm_bquant_quantgrouped_mx_bf16bf16.cpp gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index eb36ae58008..accac6f0838 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -53,7 +53,7 @@ args: -stride_b Tensor B stride (default:0) -stride_c Tensor C stride (default:0) -v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1) - -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, or bf16fp4 (default for both AQuant and Bquant: fp8) + -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, mxbf16bf16, mxbf16bf8 or mxbf16fp4 (default for both AQuant and Bquant: fp8) -warmup Number of iterations before benchmarking the kernel (default:50) -repeat Number of iterations to benchmark the kernel (default:1000) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp new file mode 100644 index 00000000000..e1a64c86569 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp @@ -0,0 +1,35 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +template +using GemmConfig = GemmConfigQuantPrefill; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ + ck_tile::QuantType::BQuantGrouped>(arg_parser); + +static auto _ = []() { + auto& lut = get_kernel_lut(); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + + lut[hash_multiple_strings( + {"mxbf16bf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"mxbf16bf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp new file mode 100644 index 00000000000..0eb2a0ce349 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp @@ -0,0 +1,34 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
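For context, the quant-group shapes registered in these new mx* examples (e.g. 1x1x32 or 1x1x64) assign one e8m0 scale to each group of 32 or 64 consecutive K elements of a B column. A minimal host-side sketch of the dequantization this implies, with a hypothetical helper name and the scales assumed to be laid out per B column (not part of this patch):

    #include <cmath>
    #include <cstdint>

    // Hypothetical helper: bq_col points at the (K / group_k) e8m0 scales of one B column.
    float dequant_b_element(float b_raw, const uint8_t* bq_col, int k, int group_k /* e.g. 32 */)
    {
        // e8m0 holds only a biased float32 exponent; the scale value is 2^(raw - 127),
        // matching reference_mx_gemm_bquant later in this patch.
        const int raw = static_cast<int>(bq_col[k / group_k]);
        return b_raw * std::exp2(static_cast<float>(raw - 127));
    }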
+// SPDX-License-Identifier: MIT + +#include "run_gemm_quant_example.inc" + +using GemmConfig = GemmConfigMixedPrecision; + +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type(arg_parser); + +static auto _ = []() { + auto& lut = get_kernel_lut(); + using TypeConfig = decltype(GemmQuantTypeConfig{}); + + lut[hash_multiple_strings( + {"mxbf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + lut[hash_multiple_strings( + {"mxbf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + [](const ck_tile::ArgParser& arg_parser) { + using QuantGroupSize = ck_tile::QuantGroupShape>; + return RUN_GEMM_EXAMPLE_PREC_TYPE; + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16fp4.cpp similarity index 67% rename from example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp rename to example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16fp4.cpp index b8eb670135a..1f48609a1f0 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf16mxfp4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16fp4.cpp @@ -6,33 +6,33 @@ template using GemmConfig = GemmConfigQuantPrefill; -#define RUN_GEMM_EXAMPLE_PREC_TYPE \ - run_gemm_example_prec_type, \ - TypeConfig, \ - QuantGroupSize, \ +#define RUN_GEMM_EXAMPLE_PREC_TYPE \ + run_gemm_example_prec_type, \ + TypeConfig, \ + QuantGroupSize, \ ck_tile::QuantType::BQuantGrouped>(arg_parser); static auto _ = []() { auto& lut = get_kernel_lut(); using TypeConfig = decltype(GemmQuantTypeConfig{}); + ck_tile::e8m0_t>{}); lut[hash_multiple_strings( - {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = + {"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; lut[hash_multiple_strings( - {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = + {"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; }; lut[hash_multiple_strings( - {"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = + {"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { using QuantGroupSize = ck_tile::QuantGroupShape>; return RUN_GEMM_EXAMPLE_PREC_TYPE; diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp index cc4302a992c..dc4d1ad8147 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant.cpp @@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[]) .insert("prec", "fp8", "Data type. 
For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, " - "or bf8i4; for ABQuant: fp8, bf8, fp4") + " mxbf16bf16, mxbf16bf8, mxbf16fp4 or bf8i4; for ABQuant: fp8, bf8, fp4") .insert("warmup", "50", "Number of iterations before benchmarking the kernel") .insert("repeat", "1000", "Number of iterations to benchmark the kernel") .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 085d6344415..deee211d812 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -45,7 +45,7 @@ auto calculate_rtol_atol(const ck_tile::index_t K, const float max_accumulated_value) { using ComputeType = std::conditional_t< - std::is_same_v, + std::is_same_v, ADataType, std::conditional_t>; // Calculate thresholds @@ -271,6 +271,24 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill static constexpr bool TransposeC = true; }; +// Used for A=16bit and B=8bit. The warp tile has KPack=16 +// Matrix A: Vectorsize = 8, KPack=16 -> LDS read/write vectorsize = 8 (128bit) +// Matrix B: Vectorsize = 16, KPack=16 -> LDS read/write vectorsize = 16 (128bit) +struct GemmConfigMixedPrecision : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 64; +}; + template struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill { diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 540d5725dd9..086894d4f2d 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -93,6 +93,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr auto tail_number_v = tail_number_.value; + constexpr auto b_cast_policy = + std::is_same_v + ? ck_tile::CastPolicy::BeforeLDSWrite + : ck_tile::CastPolicy::AfterLDSRead; // row-col and tensor quants use the regular pipeline, A/B/AB quants use their own using PipelineProblem = std::conditional_t< @@ -135,7 +139,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str ComputeDataType, GemmConfig::Scheduler, has_hot_loop_v, - tail_number_v>, + tail_number_v, + b_cast_policy>, ck_tile::GemmABQuantPipelineProblem, - std::conditional_t< - std::is_same_v, - ck_tile::MxFp4GemmPipelineAgBgCrCompV3, - ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; + std::conditional_t, + ck_tile::MxGemmPipelineAgBgCrCompV3, + ck_tile::BQuantGemmPipelineAgBgCrCompV3>>; using ABQuantPipeline = std::conditional_t a_m(ck_tile::host_tensor_descriptor( args.M, args.K, args.stride_A, is_row_major(ALayout{}))); ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( - std::is_same_v ? 
args.K / 2 - : args.K, - args.N, - args.stride_B, - is_row_major(BLayout{}))); + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); auto size_a_buffer = a_m.get_element_space_size_in_bytes(); auto size_b_buffer = b_n.get_element_space_size_in_bytes(); @@ -479,11 +479,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, int rotating_count = arg_parser.get_int("rotating_count"); stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_B = ck_tile::get_default_stride( - (std::is_same_v) ? (K / 2) : K, - N, - stride_B, - is_row_major(b_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); // Conditional stride calculation based on QuantMode @@ -515,11 +511,8 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, ck_tile::HostTensor a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); - ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( - (std::is_same_v) ? (K / 2) : K, - N, - stride_B, - is_row_major(b_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); @@ -696,18 +689,31 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); } - else if constexpr(std::is_same_v) + else if constexpr(std::is_same_v) { ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); } else { ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); + } + + if constexpr(std::is_same_v) + { + auto gen_scales = [&](auto& scales, float range_min, float range_max) { + // e8m0_t is basically an exponent of float32 + ck_tile::HostTensor pow2(scales.get_lengths()); + ck_tile::FillUniformDistributionIntegerValue{ + range_min, range_max, fill_seed(gen)}(pow2); + scales.ForEach([&](auto& self, const auto& i) { + self(i) = static_cast(std::exp2(pow2(i))); + }); + }; + gen_scales(*bq_tensor_ptr, -2, 2); + } + else + { ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *bq_tensor_ptr); } @@ -1010,18 +1016,18 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { - if constexpr(std::is_same_v) - ck_tile::reference_mxfp4gemm_quant( + if constexpr(std::is_same_v) + ck_tile::reference_mx_gemm_bquant( a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); else ck_tile::reference_gemm_quant) && + std::is_same_v) && GemmConfig::PreshuffleB) { throw std::runtime_error( diff --git a/include/ck_tile/core/tensor/static_distributed_tensor.hpp b/include/ck_tile/core/tensor/static_distributed_tensor.hpp index 10c7587bcb4..6a2d560bbcd 100644 --- a/include/ck_tile/core/tensor/static_distributed_tensor.hpp +++ b/include/ck_tile/core/tensor/static_distributed_tensor.hpp @@ -75,7 +75,7 @@ struct static_distributed_tensor constexpr auto sliced_thread_tensor_desc = make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...)); - thread_buffer + thread_buffer sliced_thread_data; static_ford>{}([&](auto idx) { diff --git a/include/ck_tile/core/utility/type_traits.hpp 
b/include/ck_tile/core/utility/type_traits.hpp index c11d180839b..9ef3a675a71 100644 --- a/include/ck_tile/core/utility/type_traits.hpp +++ b/include/ck_tile/core/utility/type_traits.hpp @@ -98,6 +98,19 @@ using is_known_at_compile_time = is_static; // , this helper will also return false, which is not good(?) // do we need something like is_constexpr()? +#define DEFINE_STATIC_MEMBER_CHECKER(trait_name, member) \ + template \ + struct trait_name : std::false_type \ + { \ + }; \ + \ + template \ + struct trait_name> : std::true_type \ + { \ + }; + +DEFINE_STATIC_MEMBER_CHECKER(has_bcastpolicy, BCastPolicy); + // FIXME: do we need this anymore? template < typename PY, diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index 7830150b630..24cce418032 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -392,45 +392,50 @@ template -CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor& a_m_k, - const HostTensor& q, - const HostTensor& b_k_n, - HostTensor& c_m_n, - const AElementOp& a_element_op = {}, - const BElementOp& b_element_op = {}, - const ACCElementOp& acc_element_op = {}) +CK_TILE_HOST void reference_mx_gemm_bquant(const HostTensor& a_m_k, + const HostTensor& q, + const HostTensor& b_k_n, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) { const std::size_t M = a_m_k.get_length(0); const std::size_t N = b_k_n.get_length(1); const std::size_t K = a_m_k.get_length(1); auto f_mn = [&](auto m, auto n) { - AccDataType v_acc = 0; - AccDataType pasual = 0; - for(std::size_t k = 0; k < (K / 2); k++) - { - using ComputeType = float; - auto b_scale = type_convert(q((2 * k) / QuantGroupSize::kK, n)) - 127; - ComputeType v_a_0, v_a_1; - ComputeType v_b_0, v_b_1; - - v_a_0 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k)))); - v_a_1 = ck_tile::type_convert((a_element_op(a_m_k(m, 2 * k + 1)))); + AccDataType v_acc = 0; + using ComputeType = float; + ComputeType v_a; + ComputeType v_b; - if constexpr(std::is_same_v) + for(std::size_t k = 0; k < K; k++) + { + auto b_scale = type_convert(q((k) / QuantGroupSize::kK, n)) - 127; + auto b_scale_fp32 = type_convert(std::pow(2.0f, b_scale)); + v_a = ck_tile::type_convert(a_element_op(a_m_k(m, k))); + if constexpr(std::is_same_v) { - auto b_pack = type_convert(b_element_op(b_k_n(k, n))); - auto b_scale_fp4 = type_convert(std::pow(2.0f, b_scale)); - - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + auto b_pack = type_convert(b_element_op(b_k_n(k, n))); - v_b_0 = type_convert(b_f4_lo) * b_scale_fp4; - v_b_1 = type_convert(b_f4_hi) * b_scale_fp4; + if(k % 2 == 0) + { + auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); + v_b = type_convert(b_f4_lo); + } + else + { + auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); + v_b = type_convert(b_f4_hi); + } } - - pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1; - v_acc += pasual; + else + { + v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); + } + v_b *= b_scale_fp32; + v_acc += v_a * v_b; } c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); }; diff --git a/include/ck_tile/ops/common/utils.hpp b/include/ck_tile/ops/common/utils.hpp index 425083a9de3..a6f7513538d 100644 --- a/include/ck_tile/ops/common/utils.hpp +++ b/include/ck_tile/ops/common/utils.hpp @@ -23,6 +23,7 @@ template <> struct 
DataTypeTraits { static constexpr const char * name = template <> struct DataTypeTraits { static constexpr const char * name = "pk_int4"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4"; }; template <> struct DataTypeTraits { static constexpr const char * name = "pk_fp4_raw"; }; +template <> struct DataTypeTraits { static constexpr const char * name = "e8m0"; }; template struct memOpToStr; template <> struct memOpToStr { static constexpr const char * name = "set"; }; diff --git a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp index 3f58eceb333..06dfbaa5d7f 100644 --- a/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp +++ b/include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp @@ -359,6 +359,84 @@ CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q) } #endif +CK_TILE_HOST_DEVICE bf16x8_t bf8x8_to_bf16x8_scale(const bf8x8_t& src, const float& scale) +{ + bf16x8_t y; +#if defined(__gfx950__) + union + { + uint32_t i16val; + bf8_t i8val[4]; + } input; + + union + { + bf16x2_t bhalf_vec; + bf16_t bhalf_arr[2]; + } output; + + input.i8val[0] = src[0]; + input.i8val[1] = src[1]; + input.i8val[2] = src[2]; + input.i8val[3] = src[3]; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, scale, 0); + y[0] = output.bhalf_arr[0]; + y[1] = output.bhalf_arr[1]; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, scale, 1); + y[2] = output.bhalf_arr[0]; + y[3] = output.bhalf_arr[1]; + + input.i8val[0] = src[4]; + input.i8val[1] = src[5]; + input.i8val[2] = src[6]; + input.i8val[3] = src[7]; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, scale, 0); + y[4] = output.bhalf_arr[0]; + y[5] = output.bhalf_arr[1]; + output.bhalf_vec = __builtin_amdgcn_cvt_scalef32_pk_bf16_bf8(input.i16val, scale, 1); + y[6] = output.bhalf_arr[0]; + y[7] = output.bhalf_arr[1]; +#else + static_for<0, 8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(src[i.value]) * scale); + }); +#endif + return y; +} + +CK_TILE_HOST_DEVICE bf16x8_t fp4x4_to_bf16x8_scale(const pk_fp4x4_t& src, const float& scale) +{ + bf16x8_t y; +#if defined(__gfx950__) + union + { + uint32_t u32; + pk_fp4x4_t pf4; + } cvt; + cvt.pf4 = src; + bf16x2_t y0 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, 0); + bf16x2_t y1 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, 1); + bf16x2_t y2 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, 2); + bf16x2_t y3 = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(cvt.u32, scale, 3); + + y[0] = y0[0]; + y[1] = y0[1]; + y[2] = y1[0]; + y[3] = y1[1]; + y[4] = y2[0]; + y[5] = y2[1]; + y[6] = y3[0]; + y[7] = y3[1]; +#else + static_for<0, 4, 1>{}([&](auto i) { + auto yi = pk_fp4_to_bf16x2(src[i.value], scale); + y[2 * i.value] = yi[0]; + y[2 * i.value + 1] = yi[1]; + }); +#endif + return y; +} + struct PassThroughPack8 { static constexpr const char* name = "PassThroughPack8"; @@ -437,6 +515,26 @@ struct DequantPack8 y.hi = i4_to_half4_scale(bit_cast(x) >> 8, z); } + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const pk_fp4x4_t& x, const float& z) const + { + y = fp4x4_to_bf16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const bf8x8_t& x, const float& z) const + { + y = bf8x8_to_bf16x8_scale(x, z); + } + + CK_TILE_HOST_DEVICE constexpr void + operator()(bf16x8_t& y, const bf16x8_t& x, const float& z) const + { + static_for<0, 
8, 1>{}([&](auto i) { + y[i.value] = type_convert(type_convert(x[i.value]) * z); + }); + } + constexpr const static bool is_pack8_invocable = true; }; diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 4f636b59625..2a02d9fbdb7 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -97,11 +97,11 @@ struct CShuffleEpilogue BDataType, ADataType>; // Used for weight-only quantization kernel, B would be dequantized to the same data type as A - using BTypeToUse = std::conditional_t || - std::is_same_v || - std::is_same_v, - ADataType, - BDataType>; + using BTypeToUse = std::conditional_t< + std::is_same_v || std::is_same_v || + std::is_same_v || sizeof(BDataType) < sizeof(ADataType), + ADataType, + BDataType>; using ELayout = remove_cvref_t; using CDElementwise = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 79030fcd513..7f34ae24bbd 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -97,7 +97,8 @@ struct BlockUniversalGemmAsBsCr using ATypeToUse = std::conditional_t, BDataType, ADataType>; using BTypeToUse = std::conditional_t || - std::is_same_v, + std::is_same_v || + sizeof(BDataType) < sizeof(ADataType), ADataType, BDataType>; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 4973d9c9410..9f6edbad261 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -20,8 +20,16 @@ struct GemmPipelineAgBgCrImplBase using ADataType = remove_cvref_t{}, AsDataType>>; using ALayout = remove_cvref_t{}, AsLayout>>; using BInDataType = remove_cvref_t{}, BsDataType>>; - using BDataType = - std::conditional_t, ADataType, BInDataType>; + + static constexpr bool IsBCastPolicyBeforeLDSWrite = [] { + if constexpr(has_bcastpolicy::value) + return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + else + return false; + }(); + + using BDataType = std::conditional_t; + using BLayout = remove_cvref_t{}, BsLayout>>; static constexpr index_t MPerBlock = BlockGemmShape::kM; @@ -313,10 +321,9 @@ struct GemmPipelineAgBgCrImplBase auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); - using BLdsDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + using BLdsDataType = std::conditional_t; auto b_lds_load_tile_distr = []() { if constexpr(is_b_load_tr) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp index 957cf7ab8f3..69455abc9c2 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp @@ -10,6 +10,12 @@ namespace ck_tile { +enum struct CastPolicy +{ + BeforeLDSWrite, + AfterLDSRead, +}; + enum struct GemmPipelineScheduler { Default, diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 8074994fdd3..a6c02232a8e 100644 --- 
a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -80,6 +80,14 @@ struct UniversalGemmBasePolicy static constexpr bool is_b_load_tr = false; #endif + template + static constexpr bool IsBCastPolicyBeforeLDSWrite_v = [] { + if constexpr(has_bcastpolicy::value) + return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + else + return false; + }(); + static constexpr auto I0 = number<0>{}; static constexpr auto I1 = number<1>{}; static constexpr auto I2 = number<2>{}; @@ -305,11 +313,11 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using BLayout = remove_cvref_t; - using BDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + using BLayout = remove_cvref_t; + constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; + using BDataType = std::conditional_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -589,15 +597,14 @@ struct UniversalGemmBasePolicy CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB() { using BsLayout = remove_cvref_t; - using BsDataType = remove_cvref_t; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; using BLayout = remove_cvref_t{}, BsLayout>>; - using BInDataType = remove_cvref_t{}, BsDataType>>; - using BDataType = std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; + using BDataType = std::conditional_t; if constexpr(Problem::FixedVectorSize) { @@ -739,13 +746,13 @@ struct UniversalGemmBasePolicy { constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - using BDataType = remove_cvref_t; - constexpr index_t KPerBlock = std::is_same_v - ? Problem::BlockGemmShape::kK / 2 - : Problem::BlockGemmShape::kK; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + // If we cast before writing to LDS, the vectorsize is defined by the A type + // since the assumption is that A type is going to be the B LDS type + constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; constexpr index_t VecLoadSize = - std::is_same_v - ? 4 + IsBCastPolicyBeforeLDSWrite + ? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA()) : (Problem::FixedVectorSize ? 
Problem::VectorSizeB : GetVectorSizeB()); constexpr index_t NumWaveGroups = Problem::NumWaveGroups; using BLayout = remove_cvref_t< @@ -855,10 +862,10 @@ struct UniversalGemmBasePolicy template CK_TILE_DEVICE static constexpr index_t GetSmemSizeB() { - using BDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; + constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v; + using BDataType = std::conditional_t; constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor(); constexpr index_t smem_size_b = integer_least_multiple( b_lds_block_desc.get_element_space_size() * sizeof(BDataType), 16); @@ -900,7 +907,8 @@ struct UniversalGemmPipelineAgBgCrPolicy using ATypeToUse = std::conditional_t, BDataType, ADataType>; using BTypeToUse = std::conditional_t || - std::is_same_v, + std::is_same_v || + sizeof(BDataType) < sizeof(ADataType), ADataType, BDataType>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 00512424752..d00fa58731a 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -189,12 +189,24 @@ template using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl< WarpGemmAttributeMfma, AttrNumAccess>>; + +template +using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl, + 2, + AttrNumAccess>>; #else template using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, 2, AttrNumAccess>>; + +template +using WarpGemmMfmaBf16Bf16F32M16N16K64 = WarpGemmImpl, + 4, + AttrNumAccess>>; #endif using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl struct Dispatcher { using template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K64<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct Dispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 696de378aaf..a7edb66f15a 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -21,9 +21,9 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp" diff --git 
a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 9d711c48623..5545112a4f3 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" @@ -101,15 +102,18 @@ struct BQuantBlockUniversalGemmAsBsCr // 2. bf8, bf8, fp32 -> f32 // 3. i4, fp8, (fp8/fp32) -> f32 // 4. i4, bf8, (fp8/fp32) -> f32 - static_assert((std::is_same_v || std::is_same_v) && - (std::is_same_v || std::is_same_v || - std::is_same_v) && - (std::is_same_v || - std::is_same_v || - std::is_same_v) && - (std::is_same_v || - std::is_same_v) && - std::is_same_v); + static_assert( + (std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v || + std::is_same_v) && + (std::is_same_v || std::is_same_v || + std::is_same_v) && + std::is_same_v); static constexpr index_t InterWaveSchedulingMacClusters = 1; @@ -176,57 +180,17 @@ struct BQuantBlockUniversalGemmAsBsCr using I0 = number<0>; using I1 = number<1>; + // Use gemm universal block distribution encoding instead of duplicating it + using BlockGemmBase = BlockUniversalGemmAsBsCr; + CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { - constexpr index_t KPerThread = Traits::KPerThread; - constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - - constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); - - constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; - - using KIterSeq = std::conditional_t, - sequence>; - - constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); - - return a_block_dstr_encode; + return BlockGemmBase::MakeABlockDistributionEncode(); } CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { - constexpr index_t KPerThread = Traits::KPerThread; - constexpr index_t NumMacClusters = Traits::InterWaveSchedulingMacClusters; - constexpr index_t KPerInnerLoop = - ck_tile::max(KPerThread / NumMacClusters, WarpGemm::kKPerThread); - constexpr index_t KIterInterwave = KPerInnerLoop / WarpGemm::kKPerThread; - - using KIterSeq = std::conditional_t, - sequence>; - - constexpr auto b_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, KIterSeq>, - tuple>, - tuple>, - sequence<1, 2>, - sequence<0, 0>>{}; - - constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); - - return b_block_dstr_encode; + return BlockGemmBase::MakeBBlockDistributionEncode(); } private: @@ -235,20 +199,24 @@ struct BQuantBlockUniversalGemmAsBsCr { }; + using BlockGemmImplBase = typename 
BlockUniversalGemmAsBsCr:: + template BlockGemmImpl; + template - struct BlockGemmImpl + struct BlockGemmImpl : public BlockGemmImplBase { - static constexpr auto ALdsTileDistr = - decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; - static constexpr auto BLdsTileDistr = - decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; - - using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); - using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); - - ALdsTile a_warp_tile_; - BLdsTile b_warp_tile_; - + using BlockGemmImplBase::a_warp_tile_; + using BlockGemmImplBase::b_warp_tile_; + using BlockGemmImplBase::BLdsTileDistr; + // If we apply scale while reading from LDS, then we can use the operator() from + // BlockUniversalGemmAsBsCr + using BlockGemmImplBase::operator(); + + // static distributed tensor with LDS type + using BTypeTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); + BTypeTile b_warp_tile_lds_; + + // Load from LDS (assumption is that the scale will be applied in the block gemm) template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + const BQRegBlockTile& bq_block_tensor, + bool_constant = {}, + bool_constant = {}) + { + // Load tile from LDS + + // Do not use load_int4_tile here because it will have support to cast from fp4 to + // compute type, while here we want to only load from LDS and then apply the scale + // and cast later + if constexpr(ALoadTranspose) + { + a_warp_tile_ = load_tile_transpose(a_block_window); + } + else + { + load_tile(a_warp_tile_, a_block_window); + } + + if constexpr(BLoadTranspose) + { + b_warp_tile_lds_ = load_tile_transpose(b_block_window); + } + else + { + load_tile(b_warp_tile_lds_, b_block_window); + } + + // Apply scale and cast + using BDataTypeRaw = + std::conditional_t, pk_fp4_t::type, BDataType>; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t nelements = WarpGemm::kK * WarpGemm::kN / warp_size; + constexpr index_t thread_buffer_size = nelements / UnaryOpSize_; + const element_wise::DequantPack8 elementwise_op{}; + using SrcVectorRawType = + BDataTypeRaw __attribute__((ext_vector_type(UnaryOpSize_ / BPackedSize))); + using DstVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize_))); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, Traits::QScalesPerBlockRow, 1>{}([&](auto kQScale) { + // B scale register offset + constexpr index_t reg_offset = [&]() { + if constexpr(GemmTraits::BQuantGroupSize::kN >= (NWarp * WarpGemm::kN)) + return (nIter * NWarp * WarpGemm::kN) / + GemmTraits::BQuantGroupSize::kN * Traits::KQPerBlock + + kQScale; + else + { + return nIter * Traits::KQPerBlock + kQScale; + } + }(); + + // Get B scale from thread buffer + auto& scale_reg = bq_block_tensor.get_thread_buffer()[reg_offset]; + float b_scale_f = float(scale_reg); + + static_for<0, Traits::KIterPerQScale, 1>{}([&](auto kIterInQScale) { + constexpr auto kIter = kQScale * Traits::KIterPerQScale + kIterInQScale; + // Thread buffers + using BWarpThreadBuffer = decltype(b_warp_tile_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths))); + using BLDSThreadBuffer = decltype(b_warp_tile_lds_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths))); + + BWarpThreadBuffer 
b_warp_thread_buffer; + BLDSThreadBuffer b_lds_thread_buffer; + + // Load thread buffer from tile (LDS type) + b_lds_thread_buffer = b_warp_tile_lds_.get_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths)); + + // Apply scale to B thread buffer and cast + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + elementwise_op( + b_warp_thread_buffer.template get_as()(i), + b_lds_thread_buffer.template get_as()[i], + b_scale_f); + }); + + // Store B thread buffer to tile (MMA type) + b_warp_tile_.set_y_sliced_thread_data( + merge_sequences(sequence{}, b_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, b_warp_y_lengths), + b_warp_thread_buffer); + }); + }); + }); + } + // C += A * B template + CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window, + BQRegBlockTile bq_block_tile, + bool_constant a_load_tr = {}, + bool_constant b_load_tr = {}) + { + block_gemm_impl_.LocalPrefetch( + a_block_window, b_block_window, bq_block_tile, a_load_tr, b_load_tr); + } + // C += A * B + // Apply scale after MMA template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + const ASmemBlockWindow& a_block_window, + const BSmemBlockWindow& b_block_window) + { + block_gemm_impl_(c_block_tensor, a_block_window, b_block_window); + } + private: BlockGemmImpl block_gemm_impl_{}; }; diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 21bd691b497..77ccf600577 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -717,20 +717,12 @@ struct QuantGemmKernel } else { - if constexpr(std::is_same_v) - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, k_size / 2), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); - else - return make_naive_tensor_view( - b_ptr, - make_tuple(kargs.N, k_size), - make_tuple(kargs.stride_B, 1), - number{}, - number<1>{}); + return make_naive_tensor_view( + b_ptr, + make_tuple(kargs.N, k_size), + make_tuple(kargs.stride_B, 1), + number{}, + number<1>{}); } } } @@ -744,16 +736,10 @@ struct QuantGemmKernel } else if constexpr(std::is_same_v) { - if constexpr(std::is_same_v) - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); - else - return pad_tensor_view(b_tensor_view, - make_tuple(number{}, - number{}), - sequence{}); + return pad_tensor_view(b_tensor_view, + make_tuple(number{}, + number{}), + sequence{}); } else { @@ -778,17 +764,10 @@ struct QuantGemmKernel { if constexpr(std::is_same_v) { - if constexpr(std::is_same_v) - return make_tile_window( - b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); - else - return make_tile_window(b_pad_view, - make_tuple(number{}, - number{}), - {i_n, 0}); + return make_tile_window(b_pad_view, + make_tuple(number{}, + number{}), + {i_n, 0}); } else { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp similarity index 96% rename from include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp index facec252a35..d8b055bd2ea 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp +++ 
b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp @@ -10,7 +10,7 @@ namespace ck_tile { template -struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase +struct GemmMxPipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase { using Base = GemmPipelineAgBgCrImplBase; using ADataType = typename Base::ADataType; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp similarity index 53% rename from include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp index 6cf9e22f414..d77a2d1da64 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_policy.hpp @@ -9,7 +9,7 @@ namespace ck_tile { -struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy +struct GemmMxPipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy { using Base = UniversalGemmPipelineAgBgCrPolicy; using Base::I0; @@ -70,37 +70,74 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution() { - // using BLayout = remove_cvref_t; - - constexpr index_t BlockSize = Problem::kBlockSize; - constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; - constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + // If we apply scale before writing to LDS, we need a tile distribution for + // BQuant consistent with global memory reading of matrix B, while + // if we apply scale after reading from LDS, we need a tile distribution for + // BQuant consistent with the MMA instructions layout + if constexpr(Problem::BCastPolicy == CastPolicy::AfterLDSRead) + { + using BQLayout = remove_cvref_t; + using BlockGemmShape = typename Problem::BlockGemmShape; + + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t NPerBlockBQ = NPerBlock / Problem::BQuantGroupSize::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t KPerBlockBQ = KPerBlock / Problem::BQuantGroupSize::kK; + + using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using WarpGemm = WarpGemmDispatcher; - constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; // k_scale num //2 - constexpr index_t VecLoadSize = - Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); - constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + using TileEncodingPattern = + tile_distribution_encoding_pattern_bq; - constexpr index_t warp_size = get_warp_size(); - constexpr index_t num_warps = BlockSize / get_warp_size(); - constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size); - constexpr index_t b_vec = VecLoadSize > LargestVec ? 
LargestVec : VecLoadSize; - constexpr index_t K0 = KPerBlock / b_vec; - constexpr index_t K1 = K0 / KScale; - constexpr index_t K3 = K0 / K1; - constexpr index_t K2 = 1; - - constexpr index_t N0 = num_warps / NumWaveGroups; - constexpr index_t N1 = warp_size / K0; - constexpr index_t N2 = NPerBlock / (N0 * N1); - - return make_static_tile_distribution( - tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2, 0>>, - tuple, sequence<1, 0, 0>>, - sequence<1, 2>, - sequence<2, 1>>{}); + return TileEncodingPattern::make_2d_static_tile_distribution(); + } + else + { + constexpr index_t BlockSize = Problem::kBlockSize; + constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t KScale = KPerBlock / Problem::BQuantGroupSize::kK; // k_scale num //2 + constexpr index_t VecLoadSize = + Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB(); + constexpr index_t NumWaveGroups = Problem::NumWaveGroups; + + constexpr index_t warp_size = get_warp_size(); + constexpr index_t num_warps = BlockSize / get_warp_size(); + constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size); + constexpr index_t b_vec = VecLoadSize > LargestVec ? LargestVec : VecLoadSize; + constexpr index_t K0 = KPerBlock / b_vec; + constexpr index_t K1 = K0 / KScale; + constexpr index_t K3 = K0 / K1; + constexpr index_t K2 = 1; + + constexpr index_t N0 = num_warps / NumWaveGroups; + constexpr index_t N1 = warp_size / K0; + constexpr index_t N2 = NPerBlock / (N0 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2, 0>>, + tuple, sequence<1, 0, 0>>, + sequence<1, 2>, + sequence<2, 1>>{}); + } } template @@ -126,14 +163,14 @@ struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy< typename Problem::ADataType, - std::conditional_t, + std::conditional_t, typename Problem::ADataType, typename Problem::BDataType>, typename Problem::CDataType, BlockWarps, WarpGemm>; - return BlockUniversalGemmAsBsCr{}; + return BQuantBlockUniversalGemmAsBsCr{}; } }; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp similarity index 68% rename from include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp index 7c448599edf..589ec44ab6d 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_v3.hpp @@ -9,7 +9,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_mx_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/host/concat.hpp" namespace ck_tile { @@ -18,15 +18,21 @@ namespace ck_tile { // B Tile Window: global memory // C Distributed tensor: register -template -struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 +template +struct MxGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 { using Base = BaseGemmPipelineAgBgCrCompV3; - using PipelineImplBase = 
GemmMxFp4PipelineAgBgCrImplBase; + using PipelineImplBase = GemmMxPipelineAgBgCrImplBase; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + + using BDqDataType = remove_cvref_t; + + static constexpr bool IsCastBeforeLDS = Problem::BCastPolicy == CastPolicy::BeforeLDSWrite; + + using BLDSType = std::conditional_t; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using BDqDataType = remove_cvref_t; using BQDataType = remove_cvref_t; using CDataType = remove_cvref_t; using BlockGemmShape = remove_cvref_t; @@ -40,12 +46,16 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3>::PackedSize; + static constexpr index_t BPackedSize = - ck_tile::numeric_traits>::PackedSize; + ck_tile::numeric_traits>::PackedSize; static constexpr index_t BQPackedSize = ck_tile::numeric_traits>::PackedSize; + static constexpr index_t BLDSPackedSize = + ck_tile::numeric_traits>::PackedSize; + using ALayout = remove_cvref_t; using BQLayout = remove_cvref_t; using BLayout = remove_cvref_t; @@ -165,6 +175,11 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3; + static constexpr bool is_b_row_major = + std::is_same_v; + CK_TILE_DEVICE static constexpr auto HotLoopScheduler() { constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM; @@ -207,7 +222,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + CK_TILE_DEVICE static void + ScaleTile(TileType& block_tile, CastTileType& block_tile_cast, ScaleTileType& scale_tile) + { + if constexpr(IsCastBeforeLDS) + { + constexpr auto b_block = TileType::get_distributed_spans(); + constexpr auto idx1_js = tile_distributed_index<0>{}; + + // Internally this is using V_CVT_SCALEF32_PK_BF16_FP4 or V_CVT_SCALEF32_PK_FP16_FP4 + // on gfx950 + auto pk_mxfp4_to_compute_v2 = [](auto pk_mxfp4, float fscale) { + if constexpr(std::is_same_v) + { + return pk_fp4_to_fp16x2(pk_mxfp4, fscale); + } + else if constexpr(std::is_same_v) + { + return pk_fp4_to_bf16x2(pk_mxfp4, fscale); + } + else + { + static_assert(false, "unsupported compute type"); + } + }; + + sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { + sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { + constexpr auto i_j_idx = make_tuple(idx0, idx1); + constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); + auto scale = scale_tile(i_j_idx_scale); + auto b_scale_uint = uint32_t(scale.data) << 23; + if constexpr(std::is_same_v) + { + if constexpr(idx1.impl_.at(0) % BPackedSize == 0) + { + constexpr auto idx1_lo = tile_distributed_index{}; + constexpr auto idx1_hi = + tile_distributed_index{}; + constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); + constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); + auto b_pack = block_tile(i_j_idx); + auto cvt = + pk_mxfp4_to_compute_v2(b_pack, bit_cast(b_scale_uint)); + block_tile_cast(i_j_idx_lo) = cvt.x; + block_tile_cast(i_j_idx_hi) = cvt.y; + } + } + else + { + auto b_pack = block_tile(i_j_idx); + block_tile_cast(i_j_idx) = type_convert( + type_convert(b_pack) * bit_cast(b_scale_uint)); + } + }); + }); + } + } + + template + CK_TILE_DEVICE void ALocalPrefill(WindowType& lds_window, + const TileType& block_tile, + const ElementwiseFunc& element_func) const + { + if constexpr(is_a_col_major) + { + auto a_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledARegTileDistribution()); + transpose_tile2d(a_shuffle_tmp, block_tile); + Base::LocalPrefill(lds_window, a_shuffle_tmp, element_func); + } + else + { + Base::LocalPrefill(lds_window, 
block_tile, element_func); + } + } + + template + CK_TILE_DEVICE void BLocalPrefill(WindowType& lds_window, + const TileType& block_tile, + const TileTypeCast& block_tile_cast, + const ElementwiseFunc& element_func) const + { + // Fill LDS and apply the scale if IsCastBeforeLDS + auto get_b_block_tile = [](auto& b_block_tile_orig, auto& b_block_tile_cast) { + if constexpr(IsCastBeforeLDS) + { + return b_block_tile_cast; + } + else + { + return b_block_tile_orig; + } + }; + + if constexpr(is_b_row_major) + { + auto b_shuffle_tmp = make_static_distributed_tensor( + Policy::template MakeShuffledBRegTileDistribution()); + transpose_tile2d(b_shuffle_tmp, get_b_block_tile(block_tile, block_tile_cast)); + Base::LocalPrefill(lds_window, b_shuffle_tmp, element_func); + } + else + { + Base::LocalPrefill( + lds_window, get_b_block_tile(block_tile, block_tile_cast), element_func); + } + } + + template + CK_TILE_DEVICE void LocalPrefetch(BlockGemmType& block_gemm, + const AWindowType& a_lds_window, + const BWindowType& b_lds_window, + const QTileType& q_block_tile) const + { + // Load from LDS + // It can apply the scale and cast if we scale after reading from LDS + if constexpr(IsCastBeforeLDS) + { + block_gemm.LocalPrefetch(a_lds_window, b_lds_window); + } + else + { + block_gemm.LocalPrefetch(a_lds_window, b_lds_window, q_block_tile); + } + } + template > && std::is_same_v; constexpr bool is_bq_col_major = std::is_same_v; - constexpr bool is_b_row_major = std::is_same_v; static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)"); static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] && @@ -347,13 +494,12 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3()); auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp); auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){}; + // This defines the scaled and casted block tile for B matrix. + // Effectively, it is used only if we scale and cast before writing to LDS. 
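+ // Note: the ScaleTile() helper above converts an e8m0 scale to float by shifting the
+ // stored biased exponent straight into the float32 exponent field; an equivalent
+ // standalone expression (sketch, assuming scale.data holds the raw biased exponent) is:
+ //
+ //     const float scale_f = ck_tile::bit_cast<float>(uint32_t(scale.data) << 23);
+ //     // same value as std::exp2(float(int(scale.data) - 127)), since a float with zero
+ //     // sign and mantissa bits and exponent field E encodes 2^(E - 127)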
+ auto bdq_block_tile = make_static_distributed_tensor( + Policy::template MakeBRegTileDistribution()); + // Block GEMM auto block_gemm = BlockGemm(); auto c_block_tile = block_gemm.MakeCBlockTile(); using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); - // using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); using ABlockTile = @@ -402,7 +547,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(BBlockTileDistr{})); ABlockTile a_block_tile; - BBlockTile b_fp4_block_tile; + BBlockTile b_block_tile; using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; @@ -410,106 +555,49 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 Vgpr 0 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); - // BDataType - auto b_block_tile = make_static_distributed_tensor( - Policy::template MakeBRegTileDistribution()); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // Vmem -> Vgpr 0 (Q matrix) + // Scale and cast tile before writing to LDS (if IsCastBeforeLDS) bq_block_tile = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); - constexpr auto idx1_js = tile_distributed_index<0>{}; - constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans(); - sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - auto b_scale_uint = type_convert(bq_block_tile(i_j_idx_scale)) - 127; - auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - - auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - b_block_tile(i_j_idx_lo) = - type_convert(type_convert(b_f4_lo) * b_scale); - b_block_tile(i_j_idx_hi) = - type_convert(type_convert(b_f4_hi) * b_scale); - }); - }); - - // initialize C + // initialize C tile to zero tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); block_sync_lds(); - // LDS write 0 - if constexpr(is_a_col_major) - { - auto a_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); - } + // Vgpr -> LDS 0 + ALocalPrefill(a_copy_lds_window, a_block_tile, 
a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); + // Vmem -> Vgpr 1 Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - bq_block_tile = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); - - sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - - auto b_scale_uint = type_convert(bq_block_tile(i_j_idx_scale)) - 127; - auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - - auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - b_block_tile(i_j_idx_lo) = - type_convert(type_convert(b_f4_lo) * b_scale); - b_block_tile(i_j_idx_hi) = - type_convert(type_convert(b_f4_hi) * b_scale); - }); - }); + // If we scale and cast before writing to LDS, + // we need to read another tile of Q matrix from Vmem, then scale and cast tile + if constexpr(IsCastBeforeLDS) + { + bq_block_tile = load_tile(bq_copy_dram_window); + move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + } + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + + // LDS -> Vgpr 0 + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); __builtin_amdgcn_sched_barrier(0); @@ -521,72 +609,34 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); - } + // Vgpr -> LDS + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); + // Vmem -> Vgpr Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( - b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + // Vmem -> Vgpr (Q matrix) + // Scale and cast tile before writing to LDS (if IsCastBeforeLDS) bq_block_tile = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); + ScaleTile(b_block_tile, bdq_block_tile, bq_block_tile); - sweep_tile_span(b_block[number<0>{}], [&](auto idx0) { - sweep_tile_span(b_block[number<1>{}], [&](auto idx1) { - 
constexpr auto i_j_idx = make_tuple(idx0, idx1); - constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js); - - auto b_scale_uint = - type_convert(bq_block_tile(i_j_idx_scale)) - 127; - auto b_scale = type_convert(std::pow(2.0f, b_scale_uint)); - constexpr auto idx1_lo = tile_distributed_index{}; - constexpr auto idx1_hi = - tile_distributed_index{}; - constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo); - constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi); - - auto b_pack = type_convert(b_fp4_block_tile(i_j_idx)); - auto b_f4_lo = type_convert(b_pack.unpack(number<0>{})); - auto b_f4_hi = type_convert(b_pack.unpack(number<1>{})); - b_block_tile(i_j_idx_lo) = - type_convert(type_convert(b_f4_lo) * b_scale); - b_block_tile(i_j_idx_hi) = - type_convert(type_convert(b_f4_hi) * b_scale); - }); - }); - + // Consume tile block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // LDS -> Vgpr + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); + HotLoopScheduler(); __builtin_amdgcn_sched_barrier(0); i += 1; - // b_block_stride +=1; } while(i < (num_loop - 1)); } - // tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile); + // tail if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd)) { @@ -596,35 +646,31 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3( - Policy::template MakeShuffledARegTileDistribution()); - transpose_tile2d(a_shuffle_tmp, a_block_tile); - Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func); - } - else - { - Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); - } - if constexpr(is_b_row_major) - { - auto b_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledBRegTileDistribution()); - transpose_tile2d(b_shuffle_tmp, b_block_tile); - Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func); - } - else - { - Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func); - } + // Vgpr -> LDS last tile + ALocalPrefill(a_copy_lds_window, a_block_tile, a_element_func); + BLocalPrefill(b_copy_lds_window, b_block_tile, bdq_block_tile, b_element_func); block_sync_lds(); - block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window); + // LDS -> Vgpr last tile + LocalPrefetch(block_gemm, a_lds_gemm_window, b_lds_gemm_window, bq_block_tile); + + // Consume last tile block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); } @@ -655,7 +701,7 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 + TailNumber TailNum_ = TailNumber::Full, + CastPolicy BCastPolicy_ = CastPolicy::AfterLDSRead> struct GemmQuantPipelineProblemBase : public GemmPipelineProblemBase< ADataType_, @@ -78,9 +79,10 @@ struct GemmQuantPipelineProblemBase using AQLayout = remove_cvref_t; using BQLayout = remove_cvref_t; - static constexpr auto Scheduler = Scheduler_; - static constexpr auto HasHotLoop = HasHotLoop_; - static constexpr auto TailNum = TailNum_; + static constexpr auto Scheduler = Scheduler_; + static constexpr auto HasHotLoop = HasHotLoop_; + static constexpr auto TailNum = TailNum_; + static constexpr auto BCastPolicy = BCastPolicy_; static_assert(BlockGemmShape::kM % AQuantGroupSize::kM == 0); static_assert(BlockGemmShape::kK % AQuantGroupSize::kK == 0); @@ -155,7 +157,8 @@ template + TailNumber TailNum_ = TailNumber::Full, + CastPolicy BCastPolicy_ = CastPolicy::AfterLDSRead> using 
GemmBQuantPipelineProblem = GemmQuantPipelineProblemBase; + TailNum_, + BCastPolicy_>; template , + std::is_same_v, ADataType_, std::conditional_t>; // Calculate thresholds diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp index d491d89ef4e..d28fe620122 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_128.cpp @@ -14,8 +14,11 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; using FP8 = ck_tile::fp8_t; using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; using Half = ck_tile::half_t; using PkInt4 = ck_tile::pk_int4_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; using BQuantGrouped = std::integral_constant; using GroupSize = ck_tile::QuantGroupShape>; @@ -25,9 +28,12 @@ using GroupSize = ck_tile::QuantGroupShape>; // clang-format off using BQuant1D128Types = ::testing::Types< // 1d cases with grouping only on k axis - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp index 1019caf1bca..4965d48a34b 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_1d_64.cpp @@ -14,8 +14,11 @@ using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; using FP8 = ck_tile::fp8_t; using BF8 = ck_tile::bf8_t; +using BF16 = ck_tile::bf16_t; using Half = ck_tile::half_t; using PkInt4 = ck_tile::pk_int4_t; +using PkFP4 = ck_tile::pk_fp4_t; +using E8M0 = ck_tile::e8m0_t; using BQuantGrouped = std::integral_constant; using GroupSize64 = ck_tile::QuantGroupShape>; @@ -24,10 +27,13 @@ using GroupSize64 = ck_tile::QuantGroupShape>; // QuantType, GemmConfig, QuantGroupSize> // clang-format off using BQuant1D64Types = ::testing::Types< - std::tuple, - std::tuple, - std::tuple, - std::tuple + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple, + std::tuple >; // clang-format on diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 0033bb42a80..bede7f11a1d 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -102,7 +102,7 @@ struct GemmConfigDecodeInterwave : public GemmConfigBase static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; }; -struct GemmConfigMxFp4 : public GemmConfigBase +struct GemmConfigMx : public GemmConfigBase { static constexpr ck_tile::index_t M_Tile = 128; static constexpr ck_tile::index_t N_Tile = 128; @@ -658,8 +658,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase ? (K / 2) : K; + const ck_tile::index_t stride_B = K; const ck_tile::index_t stride_C = N; // BQuant uses block/grouped quantization for B matrix @@ -670,24 +669,36 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase a_m_k( ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); - ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( - std::is_same_v ? 
K / 2 : K, - N, - stride_B, - this->is_row_major(BLayout{}))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); ck_tile::HostTensor bq_bqk_bqn( ck_tile::host_tensor_descriptor(BQK, BQN, stride_BQ, this->is_row_major(BQLayout{}))); // Initialize data with random values ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { ck_tile::FillUniformDistribution{-5.0f, 5.0f}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f}(bq_bqk_bqn); } else { ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n); + } + + if constexpr(std::is_same_v) + { + auto gen_scales = [&](auto& scales, float range_min, float range_max) { + // e8m0_t is basically an exponent of float32 + ck_tile::HostTensor pow2(scales.get_lengths()); + ck_tile::FillUniformDistributionIntegerValue{range_min, range_max}(pow2); + scales.ForEach([&](auto& self, const auto& i) { + self(i) = static_cast(std::exp2(pow2(i))); + }); + }; + gen_scales(bq_bqk_bqn, -2, 2); + } + else + { ck_tile::FillUniformDistribution{-1.0f, 1.0f}(bq_bqk_bqn); } @@ -769,14 +780,14 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase) - ck_tile::reference_mxfp4gemm_quant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); + if constexpr(std::is_same_v) + ck_tile::reference_mx_gemm_bquant(a_m_k, bq_bqk_bqn, b_k_n, c_m_n_host_ref); else ck_tile::reference_gemm_quant, - ck_tile::MxFp4GemmPipelineAgBgCrCompV3, + std::conditional_t, + ck_tile::MxGemmPipelineAgBgCrCompV3, ck_tile::BQuantGemmPipelineAgBgCrCompV3>, ck_tile::WPQuantBPipelineAgBgCrV2>; using GemmEpilogue = ck_tile::CShuffleEpilogue, + std::conditional_t, ADataType, BDataType>, ck_tile::tuple<>,
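
Note on the dequantization math used throughout the pipeline changes above: the hand-rolled fp4 loop that is removed and the new ScaleTile/BLocalPrefill helpers apply the same MX rule — each e8m0 byte is a biased power-of-two exponent, and every B element in a K-group is multiplied by that shared scale, either before or after LDS depending on CastPolicy. A minimal host-side sketch of that rule follows; decode_e8m0 and dequantize_b_group are illustrative names for this sketch, not ck_tile APIs.

#include <cmath>
#include <cstdint>

// e8m0 stores only a biased power-of-two exponent: value = 2^(byte - 127).
// This is the same "subtract 127, then exp2" step the removed kernel loop performed.
inline float decode_e8m0(std::uint8_t e)
{
    return std::exp2(static_cast<int>(e) - 127);
}

// Dequantize one K-group of a column of B: all group_k elements share one scale,
// matching the 1x1x32 / 1x1x64 / 1x1x128 QuantGroupShape variants registered above.
inline void dequantize_b_group(const float* b_quant,    // quantized B values widened to float
                               std::uint8_t scale_byte, // shared e8m0 scale for the group
                               float* b_out,
                               int group_k)
{
    const float scale = decode_e8m0(scale_byte);
    for(int k = 0; k < group_k; ++k)
        b_out[k] = b_quant[k] * scale;
}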
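
The test fixture above generates scales by taking exp2 of small random integers and casting the result to e8m0_t, so the CPU check only needs the same grouped dequantization folded into a plain GEMM. Below is a sketch of what a reference such as reference_mx_gemm_bquant is expected to compute, assuming row-major A and C, B stored K x N, and one scale per (K-group, N) pair as in the bq_bqk_bqn host tensor; it is an illustration under those assumptions, not the ck_tile implementation.

#include <cmath>
#include <cstdint>
#include <vector>

// Reference C = A * dequant(B) with B scales grouped along K.
// bq[(k / group_k) * N + n] holds the e8m0 scale that applies to B[k, n].
void reference_bquant_gemm_sketch(const std::vector<float>& a,         // M x K
                                  const std::vector<float>& b,         // K x N, quantized values widened to float
                                  const std::vector<std::uint8_t>& bq, // (K / group_k) x N e8m0 scales
                                  std::vector<float>& c,               // M x N
                                  int M, int N, int K, int group_k)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
            {
                const float scale = std::exp2(static_cast<int>(bq[(k / group_k) * N + n]) - 127);
                acc += a[m * K + k] * (b[k * N + n] * scale);
            }
            c[m * N + n] = acc;
        }
}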