diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 13cbcc8b55..2a5b12d4d5 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -14,6 +14,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_executable(${EXE_NAME} gemm_quant.cpp gemm_abquant_quantgrouped.cpp + gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp gemm_aquant_quantgrouped.cpp gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp new file mode 100644 index 0000000000..e854a3bcc8 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "38_block_scale_gemm/gemm_utils.hpp" +#include "run_gemm_quant_example.inc" + +template +using GemmConfigPreshuffleB_PreshuffleBQuant = + GemmConfigPreshuffleB_ABQuant_PreshuffleBQuant_Prefill; + +static auto _ = []() { + auto& lut = get_kernel_lut(); + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 085d634441..de734832d6 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -234,6 +234,13 @@ struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQua static constexpr bool TransposeC = true; }; +template +struct GemmConfigPreshuffleB_ABQuant_PreshuffleBQuant_Prefill + : public GemmConfigPreshuffleB_ABQuant_Prefill +{ + static constexpr bool BPreshuffleQuant = true; +}; + template struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuant_Prefill { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index 051b71e2c3..ca4a931487 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -73,12 +73,21 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe using Base::m_preload; static constexpr index_t VectorLoadSize = Problem::VectorLoadSize; - static constexpr index_t KPerBlockAQ = - integer_divide_ceil(BlockGemmShape::kK, AQuantGroupSize::kK); - static constexpr index_t KPerBlockBQ = - integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK); + + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NPerBlockBQ = (BQuantGroupSize::kN <= KPerBlock) + ? integer_divide_ceil(NPerBlock, BQuantGroupSize::kN) + : 1; + + static constexpr index_t KPerBlockAQ = integer_divide_ceil(KPerBlock, AQuantGroupSize::kK); + static constexpr index_t KPerBlockBQ = integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); static constexpr index_t QScalesPerBlockRow = integer_divide_ceil(kKPerBlock, BQuantGroupSize::kK); + + static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant; + static constexpr index_t GetVectorSizeAQ() { return PipelinePolicy::template GetVectorSizeAQ(); @@ -348,6 +357,17 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe bq_dram_block_window_tmp.get_window_origin(), PipelinePolicy::template MakeBQDramTileDistribution()); + // BQ DRAM window step + using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex; + const BQDramTileWindowStep bq_dram_tile_window_step = + (BPreshuffleQuant) + ? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, NPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), + 0) + : make_array(0, KPerBlockBQ); + // Prefetch A0 auto a_block_tile = load_tile(a_copy_dram_window); // move A window to next k @@ -383,7 +403,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe bq_block_tile = load_tile(bq_copy_dram_window); // move BQ to tile 1 move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); - move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + move_tile_window(bq_copy_dram_window, bq_dram_tile_window_step); // Prefill A0 auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_ping, a_block_tile_tmp); @@ -453,7 +473,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe aq_block_tile_2 = load_tile(aq_copy_dram_window); move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); bq_block_tile_2 = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + move_tile_window(bq_copy_dram_window, bq_dram_tile_window_step); // Preload A(2i+1) ds_read static_for<0, m_preload, 1>{}([&](auto loadIter) { @@ -482,7 +502,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe aq_block_tile = load_tile(aq_copy_dram_window); move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); bq_block_tile = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + move_tile_window(bq_copy_dram_window, bq_dram_tile_window_step); // Prefill A(2i+2) a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 8e005d588e..c84848700f 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -76,6 +76,10 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant + test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base test_gemm_quant_abquant_a4w4_base.cpp diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp new file mode 100644 index 0000000000..f061c7dd47 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantPreshuffleQuantTypes = ::testing::Types< + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 0033bb42a8..fc671a6d59 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -159,6 +159,11 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; +struct GemmConfigPreshuffleBPreshuffleQuantPrefill : public GemmConfigPreshuffleBPrefill +{ + static constexpr bool BPreshuffleQuant = true; +}; + struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode { static constexpr bool BPreshuffleQuant = true;