@CongMa13
Collaborator

This pull request introduces several improvements and fixes related to quantized grouped GEMM (General Matrix Multiply) pipelines and their supporting utilities.

The numerical issue

Steps to reproduce

Run:
./bin/tile_example_gemm_weight_preshuffle -prec=fp8
./bin/tile_example_gemm_weight_preshuffle -prec=int4

Solution

The main changes address type correctness, improve data layout and shuffling logic, and expand test coverage to better validate different GEMM configurations.

Key changes include:

Data layout and shuffling logic

  • Refactored the logic in shuffle_b_permuteN to use constexpr variables for KLane and ItemsPerAccess, simplifying tile view construction and correcting the permutation order for improved efficiency and correctness (tensor_shuffle_utils.hpp).
  • Fixed the calculation of KLaneBytes in weight preshuffle pipeline policies to account for internal data type conversion (e.g., from pk_int4_t to fp8), ensuring accurate memory access and alignment in quantized GEMM policies (wp_pipeline_agmem_bgmem_creg_base_policy.hpp, gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp).
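
To illustrate why the element type used for KLaneBytes matters, here is a minimal sketch. It is not the ck_tile implementation: the stand-in types, the `k_lane_bytes` and `num_access` helpers, the 32-element lane width, and the 16-byte access width are all assumptions chosen for illustration.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical stand-ins for the real ck_tile types (assumptions for illustration).
struct pk_int4_stub { std::uint8_t data; }; // one byte packs two 4-bit values
struct fp8_stub     { std::uint8_t data; }; // one byte holds one 8-bit value

// Number of logical elements packed into one object of type T.
template <typename T> inline constexpr std::size_t packed_size = 1;
template <> inline constexpr std::size_t packed_size<pk_int4_stub> = 2;

// Bytes one lane covers along K when the data is held as T.
template <typename T>
constexpr std::size_t k_lane_bytes(std::size_t k_elems_per_lane)
{
    return k_elems_per_lane * sizeof(T) / packed_size<T>;
}

// Number of fixed-width accesses needed to cover those bytes
// (the 16-byte access width is an assumption).
constexpr std::size_t num_access(std::size_t lane_bytes, std::size_t access_bytes = 16)
{
    return lane_bytes < access_bytes ? 1 : lane_bytes / access_bytes;
}

constexpr std::size_t k_elems = 32; // assumed elements per lane along K

// Sized from the packed storage type: 32 int4 values occupy 16 bytes.
constexpr std::size_t storage_bytes = k_lane_bytes<pk_int4_stub>(k_elems);
// Sized from the compute type after conversion: 32 fp8 values occupy 32 bytes.
constexpr std::size_t compute_bytes = k_lane_bytes<fp8_stub>(k_elems);

static_assert(storage_bytes == 16 && num_access(storage_bytes) == 1);
static_assert(compute_bytes == 32 && num_access(compute_bytes) == 2);
```

In this sketch the two sizing choices disagree on the access count (1 versus 2), which is the kind of mismatch that sizing from the compute type is meant to avoid.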

Test infrastructure enhancements

  • Unit tests did not catch this issue since there were no tests for fp8. Added new configuration structs (config_mn_16x16, config_mn_32x32) to support additional GEMM tile shapes and updated tests to run with these configurations for broader coverage (test_gemm_pipeline_util.hpp).

Contributor

Copilot AI left a comment

Pull request overview

This PR fixes numerical issues in the weight preshuffle path for quantized GEMM (including pk_int4 → fp8), aligns host-side shuffle logic with the kernel’s expectations, and broadens test coverage to additional GEMM tile shapes so fp8 issues are exercised.

Changes:

  • Adjusted weight preshuffle and AB-quant pipeline policies to compute KLaneBytes based on the internal compute type (mixed_prec_compute_type_from_input_t) instead of the packed input storage type, ensuring correct NumAccess and memory access patterns for pk_int4/fp8 flows.
  • Fixed shuffle_b_permuteN’s tile-view layout and permutation to use a consistent KLane/ItemsPerAccess formulation and corrected the permutation order to match the intended layout.
  • Extended the GEMM preshuffle test utilities with new config structs and updated example grouped-GEMM code to use properly qualified dependent template types, improving both coverage and compilation robustness.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.

File | Description
test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp | Adds config_mn_16x16/config_mn_32x32 warp-tile configs and runs the preshuffle tests with both, increasing coverage across GEMM tile shapes (note: the *_16x16/*_32x32 names are currently inverted relative to their M_Warp_Tile/N_Warp_Tile values).
include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp | Updates UniversalWeightPreshufflePipelineAgBgCrPolicy::GetBlockWeightPreshuffle to compute KLaneBytes from BTypeToUse (mixed-precision compute type) so access granularity matches the actual compute datatype.
include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_base_policy.hpp | Mirrors the KLaneBytes/NumAccess fix in the AB-quant weight preshuffle B-quant policy, basing the lane-byte count on BTypeToUse derived from Problem::BDataType and the compute type.
include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp | Qualifies BlockGemm::OverrideADataType/OverrideBDataType with typename to correctly refer to dependent nested types in the AB-quant compute pipeline.
include/ck_tile/host/tensor_shuffle_utils.hpp | Reimplements the non-gfx12 branch of shuffle_b_permuteN to use constexpr KLane/ItemsPerAccess (aligned with shuffle_b) and adjusts the tensor view rank and reference_permute order to a layout consistent with the preshuffle kernel.
example/ck_tile/17_grouped_gemm/abquant_grouped_gemm.cpp | Adds typename in BaseGemmPipeline and GemmPipeline dependent template instantiations so grouped AB-quant GEMM examples compile cleanly with strict C++ template rules.
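
The `typename` fixes listed above follow a general C++ rule for dependent names. A minimal sketch of that rule, using hypothetical stand-ins (`BlockGemmStub`, `PipelineUser`, and the nested `Pipeline` template are not ck_tile names and only illustrate the pattern):

```cpp
#include <type_traits>

// Hypothetical stand-in for a block GEMM exposing nested types and a nested template.
struct BlockGemmStub
{
    using OverrideADataType = float;
    using OverrideBDataType = double;

    template <typename T>
    struct Pipeline
    {
        using value_type = T;
    };
};

template <typename BlockGemm>
struct PipelineUser
{
    // BlockGemm is a template parameter, so its nested names are dependent;
    // `typename` tells the compiler that they denote types.
    using AType = typename BlockGemm::OverrideADataType;
    using BType = typename BlockGemm::OverrideBDataType;

    // A dependent member template needs `template` in addition to `typename`.
    using APipeline = typename BlockGemm::template Pipeline<AType>;
};

static_assert(std::is_same_v<PipelineUser<BlockGemmStub>::AType, float>);
static_assert(std::is_same_v<PipelineUser<BlockGemmStub>::APipeline::value_type, float>);
```

C++20 makes `typename` optional in some of these positions, but writing it explicitly stays valid and keeps the code portable across compilers and language modes.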

Comment on lines 91 to +100
    static constexpr ck_tile::index_t M_Warp_Tile = 32;
    static constexpr ck_tile::index_t N_Warp_Tile = 32;
    static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32;
};

template <typename Datatype>
struct config_mn_32x32 : public config<Datatype>
{
    static constexpr ck_tile::index_t M_Warp_Tile = 16;
    static constexpr ck_tile::index_t N_Warp_Tile = 16;

Copilot AI Jan 31, 2026

config_mn_16x16 and config_mn_32x32 appear to have their M_Warp_Tile/N_Warp_Tile values swapped relative to their names: config_mn_16x16 sets both tiles to 32, while config_mn_32x32 sets them to 16. This mismatch is confusing when reading the tests and can make it harder to reason about which warp-tile shape is actually being exercised; consider either swapping the tile constants or renaming the structs so the *_16x16/*_32x32 suffix matches the configured warp tile sizes.

Suggested change
-     static constexpr ck_tile::index_t M_Warp_Tile = 32;
-     static constexpr ck_tile::index_t N_Warp_Tile = 32;
-     static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32;
- };
- template <typename Datatype>
- struct config_mn_32x32 : public config<Datatype>
- {
-     static constexpr ck_tile::index_t M_Warp_Tile = 16;
-     static constexpr ck_tile::index_t N_Warp_Tile = 16;
+     static constexpr ck_tile::index_t M_Warp_Tile = 16;
+     static constexpr ck_tile::index_t N_Warp_Tile = 16;
+     static constexpr ck_tile::index_t K_Warp_Tile = sizeof(Datatype) == 2 ? 16 : 32;
+ };
+ template <typename Datatype>
+ struct config_mn_32x32 : public config<Datatype>
+ {
+     static constexpr ck_tile::index_t M_Warp_Tile = 32;
+     static constexpr ck_tile::index_t N_Warp_Tile = 32;
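
One possible way to keep such names and values from drifting apart again is a compile-time check next to the structs. The sketch below is only an illustration: `index_t`, `config_stub`, and `has_warp_tile` are hypothetical stand-ins, and only the M/N warp-tile constants are modeled.

```cpp
#include <cstdint>

using index_t = std::int32_t; // stand-in for ck_tile::index_t (assumption)

// Hypothetical placeholder for the test's base config.
template <typename Datatype>
struct config_stub
{
};

template <typename Datatype>
struct config_mn_16x16 : public config_stub<Datatype>
{
    static constexpr index_t M_Warp_Tile = 16;
    static constexpr index_t N_Warp_Tile = 16;
};

template <typename Datatype>
struct config_mn_32x32 : public config_stub<Datatype>
{
    static constexpr index_t M_Warp_Tile = 32;
    static constexpr index_t N_Warp_Tile = 32;
};

// True when Config's M/N warp-tile constants equal the expected sizes.
template <typename Config, index_t M, index_t N>
inline constexpr bool has_warp_tile =
    Config::M_Warp_Tile == M && Config::N_Warp_Tile == N;

// These fail to compile if a struct's name and its constants disagree.
static_assert(has_warp_tile<config_mn_16x16<std::uint16_t>, 16, 16>);
static_assert(has_warp_tile<config_mn_32x32<std::uint16_t>, 32, 32>);
```

With the values as originally submitted (32 in the 16x16 struct and 16 in the 32x32 struct), both assertions would fail at compile time.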
