From 4340c1c3993b5f378cd549fe18f524dcfed441c7 Mon Sep 17 00:00:00 2001 From: Cong Ma Date: Thu, 29 Jan 2026 22:18:26 -0500 Subject: [PATCH 1/2] [CK TILE] fix bugs of preshuffle_b --- .../17_grouped_gemm/abquant_grouped_gemm.cpp | 13 ++++----- include/ck_tile/host/tensor_shuffle_utils.hpp | 20 +++++--------- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 27 +++++++++---------- .../gemm_abquant_pipeline_ag_bg_cr_v3.hpp | 4 +-- ..._abquant_pipeline_ag_bg_cr_base_policy.hpp | 11 +++++--- .../test_gemm_pipeline_util.hpp | 16 ++++++++++- 6 files changed, 50 insertions(+), 41 deletions(-) diff --git a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp index 703751b7604..28b3884d0f4 100644 --- a/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp @@ -75,8 +75,8 @@ float grouped_gemm_abquant(const std::vector& gemm_descs, ck_tile::GemmPipelineProblem; using BaseGemmPipeline = - GemmQuantConfig::template BaseGemmPipeline; + typename GemmQuantConfig::template BaseGemmPipeline; const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile; const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile; @@ -108,8 +108,8 @@ float grouped_gemm_abquant(const std::vector& gemm_descs, tail_number_v>; using GemmPipeline = - GemmQuantConfig::template GemmPipeline; + typename GemmQuantConfig::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem; - using GemmPipeline = GemmQuantConfig::template GemmPipeline; + using GemmPipeline = + typename GemmQuantConfig::template GemmPipeline; using GemmEpilogue = ck_tile::CShuffleEpilogue< ck_tile::CShuffleEpilogueProblem& t, const GemmConfig& gemmC } else { - int divisor = 1; - if(ck_tile::is_gfx11_supported()) - { - divisor = 1; - } - else - { - assert(is_wave32() == false); - divisor = get_warp_size() / gemmConfig.N_Warp_Tile; - } + constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile; + constexpr int ItemsPerAccess = + std::min(16 / static_cast(sizeof(T)), GemmConfig::K_Warp_Tile / KLane); ck_tile::HostTensor t_view({n_ / gemmConfig.N_Tile, gemmConfig.N_Warp, gemmConfig.N_Warp_Tile, NRepeat, - k_ / gemmConfig.K_Warp_Tile, - divisor, - gemmConfig.K_Warp_Tile / divisor}); + k_ / ItemsPerAccess, + ItemsPerAccess /*gemmConfig.K_Warp_Tile / divisor*/}); std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6}); + return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5}); } } diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 0044b412ec5..4903f8e501b 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -271,20 +271,19 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy constexpr index_t WaveSize = get_warp_size(); constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize; - using BDataType = typename Problem::BDataType; - constexpr index_t KLaneBytes = - KLane / numeric_traits::PackedSize * sizeof(BDataType); - constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); - using WarpGemm = WarpGemmDispatcher; + // When BDataType is pk_int4_t, it is internally converted to fp8 for computation. + constexpr index_t KLaneBytes = KLane * sizeof(BTypeToUse); + constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); + using WarpGemm = WarpGemmDispatcher; using BlockWeightPreshufflePolicy = BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy())>; // A/B DataType gets converted from PkInt4/PkFp4 during loading - using OverrideADataType = BlockGemm::OverrideADataType; - using OverrideBDataType = BlockGemm::OverrideBDataType; + using OverrideADataType = typename BlockGemm::OverrideADataType; + using OverrideBDataType = typename BlockGemm::OverrideBDataType; static constexpr index_t BlockSize = Problem::kBlockSize; static constexpr index_t MPerBlock = BlockGemmShape::kM; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp index f136b863141..36d85605435 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp @@ -97,10 +97,13 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel constexpr index_t WaveSize = get_warp_size(); constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize; - using BDataType = typename Problem::BDataType; - constexpr index_t KLaneBytes = - KLane / numeric_traits::PackedSize * sizeof(BDataType); - constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); + + // When BDataType is pk_int4_t, it is internally converted to fp8 for computation. + using BTypeToUse = mixed_prec_compute_type_from_input_t; + constexpr index_t KLaneBytes = KLane * sizeof(BTypeToUse); + constexpr auto NumAccess = static_cast(max(1, KLaneBytes / 16)); using WarpGemm = WarpGemmDispatcher +struct config_mn_16x16 : public config +{ static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32; }; +template +struct config_mn_32x32 : public config +{ + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); +}; + template struct config_wmma { @@ -252,7 +264,9 @@ class TestCkTileGemmPipeline : public ::testing::Test RunSingle, PadM, PadN, PadK, Preshuffle>( M, N, K, StrideA, StrideB, StrideC, kb); #else - RunSingle, PadM, PadN, PadK, Preshuffle>( + RunSingle, PadM, PadN, PadK, Preshuffle>( + M, N, K, StrideA, StrideB, StrideC, kb); + RunSingle, PadM, PadN, PadK, Preshuffle>( M, N, K, StrideA, StrideB, StrideC, kb); #endif } From 758921f999cce95f985d096658d0663405da8baa Mon Sep 17 00:00:00 2001 From: Cong Ma Date: Sat, 31 Jan 2026 18:40:42 -0500 Subject: [PATCH 2/2] [CK TILE] fix bugs of preshuffle_b --- include/ck_tile/host/tensor_shuffle_utils.hpp | 2 +- .../gemm_weight_preshuffle/test_gemm_pipeline_util.hpp | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/host/tensor_shuffle_utils.hpp b/include/ck_tile/host/tensor_shuffle_utils.hpp index 0b15dd093af..ea2174f4356 100644 --- a/include/ck_tile/host/tensor_shuffle_utils.hpp +++ b/include/ck_tile/host/tensor_shuffle_utils.hpp @@ -172,7 +172,7 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor& t, const GemmConfig& gemmC gemmConfig.N_Warp_Tile, NRepeat, k_ / ItemsPerAccess, - ItemsPerAccess /*gemmConfig.K_Warp_Tile / divisor*/}); + ItemsPerAccess}); std::copy(t.begin(), t.end(), t_view.begin()); return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5}); } diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 2387c0d50fd..2416ef09b01 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -86,15 +86,15 @@ struct config }; template -struct config_mn_16x16 : public config +struct config_mn_32x32 : public config { static constexpr ck_tile::index_t M_Warp_Tile = 32; static constexpr ck_tile::index_t N_Warp_Tile = 32; - static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); }; template -struct config_mn_32x32 : public config +struct config_mn_16x16 : public config { static constexpr ck_tile::index_t M_Warp_Tile = 16; static constexpr ck_tile::index_t N_Warp_Tile = 16;