Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;

using BaseGemmPipeline =
GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
GemmConfig::PreshuffleB>;
typename GemmQuantConfig<QuantMode>::template BaseGemmPipeline<GemmPipelineProblem,
GemmConfig::PreshuffleB>;

const ck_tile::index_t k_grain = gemm_descs[0].k_batch * GemmConfig::K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
Expand Down Expand Up @@ -108,8 +108,8 @@ float grouped_gemm_abquant(const std::vector<grouped_gemm_kargs>& gemm_descs,
tail_number_v>;

using GemmPipeline =
GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
typename GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;

using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
Expand Down Expand Up @@ -227,8 +227,9 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
BQuantGroupSize,
GemmConfig::TransposeC>;

using GemmPipeline = GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;
using GemmPipeline =
typename GemmQuantConfig<QuantMode>::template GemmPipeline<QuantGemmProblem,
GemmConfig::PreshuffleB>;

using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
Expand Down
20 changes: 6 additions & 14 deletions include/ck_tile/host/tensor_shuffle_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,25 +164,17 @@ auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t, const GemmConfig& gemmC
}
else
{
int divisor = 1;
if(ck_tile::is_gfx11_supported())
{
divisor = 1;
}
else
{
assert(is_wave32() == false);
divisor = get_warp_size() / gemmConfig.N_Warp_Tile;
}
constexpr int KLane = ck_tile::get_warp_size() / GemmConfig::N_Warp_Tile;
constexpr int ItemsPerAccess =
std::min(16 / static_cast<int>(sizeof(T)), GemmConfig::K_Warp_Tile / KLane);
ck_tile::HostTensor<T> t_view({n_ / gemmConfig.N_Tile,
gemmConfig.N_Warp,
gemmConfig.N_Warp_Tile,
NRepeat,
k_ / gemmConfig.K_Warp_Tile,
divisor,
gemmConfig.K_Warp_Tile / divisor});
k_ / ItemsPerAccess,
ItemsPerAccess /*gemmConfig.K_Warp_Tile / divisor*/});
std::copy(t.begin(), t.end(), t_view.begin());
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 2, 5});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,20 +271,19 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy

constexpr index_t WaveSize = get_warp_size();
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
using BDataType = typename Problem::BDataType;
constexpr index_t KLaneBytes =
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
NumAccess>;
// When BDataType is pk_int4_t, it is internally converted to fp8 for computation.
constexpr index_t KLaneBytes = KLane * sizeof(BTypeToUse);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC,
false,
false,
NumAccess>;

using BlockWeightPreshufflePolicy =
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Pro
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;

// A/B DataType gets converted from PkInt4/PkFp4 during loading
using OverrideADataType = BlockGemm::OverrideADataType;
using OverrideBDataType = BlockGemm::OverrideBDataType;
using OverrideADataType = typename BlockGemm::OverrideADataType;
using OverrideBDataType = typename BlockGemm::OverrideBDataType;

static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,13 @@ struct GemmWPABQuantPipelineAgBgCrPolicy : public UniversalWeightPreshufflePipel

constexpr index_t WaveSize = get_warp_size();
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
using BDataType = typename Problem::BDataType;
constexpr index_t KLaneBytes =
KLane / numeric_traits<BDataType>::PackedSize * sizeof(BDataType);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));

// When BDataType is pk_int4_t, it is internally converted to fp8 for computation.
using BTypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::BDataType,
typename Problem::ADataType,
typename Problem::ComputeDataType>;
constexpr index_t KLaneBytes = KLane * sizeof(BTypeToUse);
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));

using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,24 @@ struct config
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
};

template <typename Datatype>
struct config_mn_16x16 : public config<Datatype>
{
static constexpr ck_tile::index_t M_Warp_Tile = 32;
static constexpr ck_tile::index_t N_Warp_Tile = 32;
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32;
};

template <typename Datatype>
struct config_mn_32x32 : public config<Datatype>
{
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
Comment on lines 91 to +100
Copy link

Copilot AI Jan 31, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

config_mn_16x16 and config_mn_32x32 appear to have their M_Warp_Tile/N_Warp_Tile values swapped relative to their names: config_mn_16x16 sets both tiles to 32, while config_mn_32x32 sets them to 16. This mismatch is confusing when reading the tests and makes it harder to reason about which warp-tile shape is actually being exercised. Consider either swapping the tile constants or renaming the structs so that the *_16x16/*_32x32 suffix matches the configured warp tile sizes.

Suggested change
- static constexpr ck_tile::index_t M_Warp_Tile = 32;
- static constexpr ck_tile::index_t N_Warp_Tile = 32;
- static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32;
- };
- template <typename Datatype>
- struct config_mn_32x32 : public config<Datatype>
- {
- static constexpr ck_tile::index_t M_Warp_Tile = 16;
- static constexpr ck_tile::index_t N_Warp_Tile = 16;
+ static constexpr ck_tile::index_t M_Warp_Tile = 16;
+ static constexpr ck_tile::index_t N_Warp_Tile = 16;
+ static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32;
+ };
+ template <typename Datatype>
+ struct config_mn_32x32 : public config<Datatype>
+ {
+ static constexpr ck_tile::index_t M_Warp_Tile = 32;
+ static constexpr ck_tile::index_t N_Warp_Tile = 32;

Copilot uses AI. Check for mistakes.
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<Datatype, M_Warp_Tile>();
};

template <typename Datatype>
struct config_wmma
{
Expand Down Expand Up @@ -252,7 +264,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
RunSingle<config_wmma<ADataType>, PadM, PadN, PadK, Preshuffle>(
M, N, K, StrideA, StrideB, StrideC, kb);
#else
RunSingle<config<ADataType>, PadM, PadN, PadK, Preshuffle>(
RunSingle<config_mn_16x16<ADataType>, PadM, PadN, PadK, Preshuffle>(
M, N, K, StrideA, StrideB, StrideC, kb);
RunSingle<config_mn_32x32<ADataType>, PadM, PadN, PadK, Preshuffle>(
M, N, K, StrideA, StrideB, StrideC, kb);
#endif
}
Expand Down