From b6213e69434dd296630378794e9db01af8a937de Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Fri, 30 Jan 2026 12:26:08 +0000 Subject: [PATCH 1/5] feat: bquant preshuffle for preshuffleb abquant pipeline --- .../38_block_scale_gemm/CMakeLists.txt | 1 + ...antgrouped_preshuffleb_preshufflequant.cpp | 44 +++++++++++++++++++ .../38_block_scale_gemm/gemm_utils.hpp | 7 +++ .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 34 +++++++++++--- 4 files changed, 79 insertions(+), 7 deletions(-) create mode 100644 example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 13cbcc8b558..2a5b12d4d58 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -14,6 +14,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") add_executable(${EXE_NAME} gemm_quant.cpp gemm_abquant_quantgrouped.cpp + gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp gemm_aquant_quantgrouped.cpp gemm_aquant_quantgrouped_preshufflequant.cpp gemm_bquant_quantgrouped_bf8i4.cpp diff --git a/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp new file mode 100644 index 00000000000..e854a3bcc80 --- /dev/null +++ b/example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp @@ -0,0 +1,44 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "38_block_scale_gemm/gemm_utils.hpp" +#include "run_gemm_quant_example.inc" + +template +using GemmConfigPreshuffleB_PreshuffleBQuant = + GemmConfigPreshuffleB_ABQuant_PreshuffleBQuant_Prefill; + +static auto _ = []() { + auto& lut = get_kernel_lut(); + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "preshufflequant", + "1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + lut[hash_multiple_strings({"fp8", + "abquant", + "preshuffleb", + "preshufflequant", + "1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) { + using AQuantGroupSize = ck_tile::QuantGroupShape>; + using BQuantGroupSize = ck_tile::QuantGroupShape>; + using TypeConfig = + decltype(GemmQuantTypeConfig{}); + return run_gemm_example_prec_type, + TypeConfig, + AQuantGroupSize, + BQuantGroupSize, + ck_tile::QuantType::ABQuantGrouped>(arg_parser); + }; + return 0; +}(); diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 085d6344415..de734832d6c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -234,6 +234,13 @@ struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQua static constexpr bool TransposeC = true; }; +template +struct GemmConfigPreshuffleB_ABQuant_PreshuffleBQuant_Prefill + : public GemmConfigPreshuffleB_ABQuant_Prefill +{ + static constexpr bool BPreshuffleQuant = true; +}; + template struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuant_Prefill { diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index 051b71e2c33..ca4a931487f 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -73,12 +73,21 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe using Base::m_preload; static constexpr index_t VectorLoadSize = Problem::VectorLoadSize; - static constexpr index_t KPerBlockAQ = - integer_divide_ceil(BlockGemmShape::kK, AQuantGroupSize::kK); - static constexpr index_t KPerBlockBQ = - integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK); + + static constexpr index_t NPerBlock = BlockGemmShape::kN; + static constexpr index_t KPerBlock = BlockGemmShape::kK; + + static constexpr index_t NPerBlockBQ = (BQuantGroupSize::kN <= KPerBlock) + ? integer_divide_ceil(NPerBlock, BQuantGroupSize::kN) + : 1; + + static constexpr index_t KPerBlockAQ = integer_divide_ceil(KPerBlock, AQuantGroupSize::kK); + static constexpr index_t KPerBlockBQ = integer_divide_ceil(KPerBlock, BQuantGroupSize::kK); static constexpr index_t QScalesPerBlockRow = integer_divide_ceil(kKPerBlock, BQuantGroupSize::kK); + + static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant; + static constexpr index_t GetVectorSizeAQ() { return PipelinePolicy::template GetVectorSizeAQ(); @@ -348,6 +357,17 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe bq_dram_block_window_tmp.get_window_origin(), PipelinePolicy::template MakeBQDramTileDistribution()); + // BQ DRAM window step + using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex; + const BQDramTileWindowStep bq_dram_tile_window_step = + (BPreshuffleQuant) + ? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{})) + ? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN) + : ck_tile::integer_least_multiple(n, NPerBlock) / + BlockGemmShape::WarpTile::at(number<1>{})), + 0) + : make_array(0, KPerBlockBQ); + // Prefetch A0 auto a_block_tile = load_tile(a_copy_dram_window); // move A window to next k @@ -383,7 +403,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe bq_block_tile = load_tile(bq_copy_dram_window); // move BQ to tile 1 move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); - move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + move_tile_window(bq_copy_dram_window, bq_dram_tile_window_step); // Prefill A0 auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); store_tile(a_copy_lds_window_ping, a_block_tile_tmp); @@ -453,7 +473,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe aq_block_tile_2 = load_tile(aq_copy_dram_window); move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); bq_block_tile_2 = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + move_tile_window(bq_copy_dram_window, bq_dram_tile_window_step); // Preload A(2i+1) ds_read static_for<0, m_preload, 1>{}([&](auto loadIter) { @@ -482,7 +502,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe aq_block_tile = load_tile(aq_copy_dram_window); move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ}); bq_block_tile = load_tile(bq_copy_dram_window); - move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ}); + move_tile_window(bq_copy_dram_window, bq_dram_tile_window_step); // Prefill A(2i+2) a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); From 87dda06e779720ce1184a567283c7bc9072bb81c Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Fri, 30 Jan 2026 13:33:34 +0000 Subject: [PATCH 2/5] fix: disable PermuteN for ABQuant PreshuffleB in example --- example/ck_tile/38_block_scale_gemm/gemm_utils.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index de734832d6c..231df134aff 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -232,6 +232,8 @@ struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQua static constexpr bool kPadK = false; static constexpr bool TransposeC = true; + + static constexpr bool TiledMMAPermuteN = false; }; template From c14fa9aab3aefc472df8866d7ad590c9affbbc94 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Fri, 30 Jan 2026 13:34:03 +0000 Subject: [PATCH 3/5] feat: add test for abquant preshuffle b + bquant --- test/ck_tile/gemm_block_scale/CMakeLists.txt | 4 ++ ...ant_abquant_preshuffle_preshuffleQuant.cpp | 43 +++++++++++++++++++ .../test_gemm_quant_fixtures.hpp | 5 +++ 3 files changed, 52 insertions(+) create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 8e005d588e7..c84848700f0 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -76,6 +76,10 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + add_gtest_executable(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant + test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp + ) + target_compile_options(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base test_gemm_quant_abquant_a4w4_base.cpp diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp new file mode 100644 index 00000000000..f061c7dd47e --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using ABQuantGrouped = + std::integral_constant; +using GroupSize = ck_tile::QuantGroupShape>; + +// 2d block sizes for BQuant +using GroupSize2D128N = ck_tile::QuantGroupShape>; + +// Type combinations for ABQuant tests +// Tuple format: +// clang-format off +using ABQuantPreshuffleQuantTypes = ::testing::Types< + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for ABQuant +TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleQuantTypes); + +// AQuant tests +TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 0033bb42a80..fc671a6d59c 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -159,6 +159,11 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0; }; +struct GemmConfigPreshuffleBPreshuffleQuantPrefill : public GemmConfigPreshuffleBPrefill +{ + static constexpr bool BPreshuffleQuant = true; +}; + struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode { static constexpr bool BPreshuffleQuant = true; From 99264c69082f2d06c53de703107f6c518fae5f18 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Fri, 30 Jan 2026 16:24:11 +0000 Subject: [PATCH 4/5] chore: removed PermuteN override again, as this seemed not to be the issue for 1D block scale --- example/ck_tile/38_block_scale_gemm/gemm_utils.hpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index 231df134aff..de734832d6c 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -232,8 +232,6 @@ struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQua static constexpr bool kPadK = false; static constexpr bool TransposeC = true; - - static constexpr bool TiledMMAPermuteN = false; }; template From 8e3ec4765d31bcde79826a0d0fb50691a563dd80 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Sat, 31 Jan 2026 10:38:03 +0000 Subject: [PATCH 5/5] chore: empty commit to trigger CI again