From 30e7163f2465fdcf5a56cbda364ea8c4a0af0930 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Sat, 24 Jan 2026 17:49:31 +0000 Subject: [PATCH 1/7] WIP: add splitk to bquant --- .../38_block_scale_gemm/CMakeLists.txt | 40 ++--- .../gemm_bquant_quantgrouped_bf8.cpp | 2 +- .../gemm_bquant_quantgrouped_fp8.cpp | 2 +- .../run_gemm_quant_example.inc | 163 +----------------- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 95 +++++++++- test_splitk_stress.sh | 157 +++++++++++++++++ 6 files changed, 271 insertions(+), 188 deletions(-) create mode 100755 test_splitk_stress.sh diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index 13cbcc8b558..ff51122af6e 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -13,28 +13,28 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") set(EXE_NAME tile_example_gemm_quant) add_executable(${EXE_NAME} gemm_quant.cpp - gemm_abquant_quantgrouped.cpp - gemm_aquant_quantgrouped.cpp - gemm_aquant_quantgrouped_preshufflequant.cpp - gemm_bquant_quantgrouped_bf8i4.cpp - gemm_bquant_quantgrouped_fp8i4.cpp - gemm_bquant_quantgrouped_bf16mxfp4.cpp + # gemm_abquant_quantgrouped.cpp + # gemm_aquant_quantgrouped.cpp + # gemm_aquant_quantgrouped_preshufflequant.cpp + # gemm_bquant_quantgrouped_bf8i4.cpp + # gemm_bquant_quantgrouped_fp8i4.cpp + # gemm_bquant_quantgrouped_bf16mxfp4.cpp gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp - gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp - gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp - gemm_bquant_quantgrouped_preshuffleb_bf8.cpp - gemm_bquant_quantgrouped_preshuffleb_fp8.cpp - gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp - gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp - gemm_bquant_quantgrouped_preshufflequant_bf8.cpp - gemm_bquant_quantgrouped_preshufflequant_fp8.cpp - gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp - gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp - gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp - gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp - gemm_quant_rowcol.cpp - gemm_quant_tensor.cpp + # gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp + # gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp + # gemm_bquant_quantgrouped_preshuffleb_bf8.cpp + # gemm_bquant_quantgrouped_preshuffleb_fp8.cpp + # gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp + # gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp + # gemm_bquant_quantgrouped_preshufflequant_bf8.cpp + # gemm_bquant_quantgrouped_preshufflequant_fp8.cpp + # gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp + # gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp + # gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp + # gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp + # gemm_quant_rowcol.cpp + # gemm_quant_tensor.cpp ) target_compile_options(${EXE_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp index a95c0346cf7..1520f2c591b 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp index a8c13c1b3dd..39747ff0bcb 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 665c7828ad3..e14685cf791 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -214,11 +214,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); const dim3 blocks = Kernel::BlockSize(); - if(args.k_batch != 1) - { - throw std::runtime_error("split-k is not supported yet!"); - } - + // Split-K validation is handled by Kernel::IsSupportedArgument + // BQuantGrouped without preshuffle supports split-K if(!Kernel::IsSupportedArgument(kargs)) { throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); @@ -570,6 +567,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else { + ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *bq_tensor_ptr); @@ -653,7 +651,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } } } - else if(init_method == 3) + else if(init_method ==3) { if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) { @@ -671,7 +669,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, else { ck_tile::FillConstant{static_cast(0x22)}(a_m_k); - ck_tile::FillConstant{static_cast(2.0f)}(*aq_tensor_ptr); + ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); ck_tile::FillConstant{static_cast(0x38)}(b_k_n); if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) @@ -680,157 +678,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } } } - else if(init_method == 4) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - } - else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - } - ck_tile::FillUniformDistribution{2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - } - else if(init_method == 5) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{125.f, 130.f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); - } - else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - } - else - { - ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(a_m_k); - } - // Fill aquant such that column j has value 2^j (1, 2, 4, 8, ...) - for(ck_tile::index_t row = 0; - row < static_cast(aq_tensor_ptr->get_length(0)); - ++row) - { - for(ck_tile::index_t col = 0; - col < static_cast(aq_tensor_ptr->get_length(1)); - ++col) - { - (*aq_tensor_ptr)(row, col) = static_cast(col + 1); - } - } - // std::cout << "aq_tensor_ptr: " << *aq_tensor_ptr << std::endl; - ck_tile::FillUniformDistribution{1.0f, 1.0f, fill_seed(gen)}(b_k_n); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - b_k_n); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); - } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *aq_tensor_ptr); - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } - } else { a_m_k.SetZero(); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 21bd691b497..dee4553956f 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -410,10 +410,46 @@ struct QuantGemmKernel { splitted_k = amd_wave_read_first_lane(kargs.K - KRead * (kargs.k_batch - 1)); } + + // Compute BQ offset for BQuantGrouped mode (non-preshuffle only) + // Note: With the alignment validation in IsSupportedArgument, KRead is always + // a multiple of QuantGroupSize::kK, so bq_k_split_offset will be correctly aligned. + if constexpr(kQuantType == QuantType::BQuantGrouped && !PreshuffleQuant) + { + using QuantGroupSize = remove_cvref_t; + // Compute the K offset for this batch (in terms of K elements) + const index_t k_offset = amd_wave_read_first_lane(k_id * KRead); + // Convert K offset to BQ group offset + const index_t bq_group_offset = + amd_wave_read_first_lane(k_offset / QuantGroupSize::kK); + + // BQ tensor layout: + // RowMajor: [K/kK, N/kN] with stride [N/kN, 1] + // ColumnMajor: [N/kN, K/kK] with stride [K/kK, 1] + if constexpr(std::is_same_v) + { + // For RowMajor BQ, K is the row dimension + // offset = bq_group_offset * stride_BQ + const index_t stride_bq = + amd_wave_read_first_lane(integer_divide_ceil(kargs.N, QuantGroupSize::kN)); + bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset * stride_bq); + } + else if constexpr(std::is_same_v) + { + // For ColumnMajor BQ, K is the column dimension + // offset = bq_group_offset + bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset); + } + } + else + { + bq_k_split_offset = 0; + } } index_t a_k_split_offset; index_t b_k_split_offset; + index_t bq_k_split_offset; index_t splitted_k; }; @@ -809,6 +845,9 @@ struct QuantGemmKernel const index_t i_n) { // Step 1: Create tensor view for BQ + // Note: For split-K, the bq_ptr is already offset by bq_k_split_offset. + // The tensor view should use kargs.QK_B (full K-groups) as the dimension + // because the view needs to see all remaining K-groups from the offset position. const auto& bq_tensor_view = [&]() { if constexpr(kQuantType == QuantType::RowColQuant) { @@ -854,7 +893,7 @@ struct QuantGemmKernel { return make_naive_tensor_view( bq_ptr, - make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), + make_tuple(kargs.QK_B, integer_divide_ceil(kargs.N, BQuantGroupSize::kN)), make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1), number{}, @@ -865,8 +904,8 @@ struct QuantGemmKernel return make_naive_tensor_view( bq_ptr, make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), - integer_divide_ceil(kargs.K, BQuantGroupSize::kK)), - make_tuple(integer_divide_ceil(kargs.K, BQuantGroupSize::kK), 1), + kargs.QK_B), + make_tuple(kargs.QK_B, 1), number{}, number<1>{}); } @@ -1047,13 +1086,49 @@ struct QuantGemmKernel CK_TILE_HOST static bool IsSupportedArgument(const QuantGemmKernelArgs& kargs) { + // Split-K is supported for BQuantGrouped mode without preshuffle if(kargs.k_batch != 1) { - if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + constexpr bool is_bquant_non_preshuffle = + (kQuantType == QuantType::BQuantGrouped) && !PreshuffleQuant; + if constexpr(!is_bquant_non_preshuffle) { - CK_TILE_ERROR("Conditions not met for Kbatch >1 !"); + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Conditions not met for Kbatch >1 ! " + "Split-K only supported for BQuantGrouped without preshuffle."); + } + return false; + } + else + { + // For BQuantGrouped split-K, the K split must be aligned with quantization groups. + // This is because the pipeline applies BQ scales based on the relative K index + // within each batch, not the absolute K index. When a batch starts in the middle + // of a quantization group, the scale selection would be incorrect. + // + // To support misaligned splits, the pipeline would need to be modified to accept + // a K offset parameter and use (k_offset + local_k) / QuantGroupSize::kK for + // scale selection instead of local_k / QuantGroupSize::kK. + using QuantGroupSize = remove_cvref_t; + constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); + const index_t K_t = kargs.k_batch * K1; + const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + + // KRead must be a multiple of QuantGroupSize::kK to ensure proper BQ alignment + if(KRead % QuantGroupSize::kK != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("Split-K batch size must be aligned with quantization group " + "size! KRead=" + + std::to_string(KRead) + + " is not divisible by QuantGroupSize::kK=" + + std::to_string(QuantGroupSize::kK)); + } + return false; + } } - return false; } if constexpr(std::is_same_v) @@ -1215,6 +1290,9 @@ struct QuantGemmKernel const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); + // Note: BQ tensor view uses full dimensions (not splitted_qk_b) because + // the split-K offset is already applied to bq_ptr. The tensor view needs + // to see the full remaining K-groups from the offset position. const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); const index_t num_loop = @@ -1343,8 +1421,9 @@ struct QuantGemmKernel const BDataType* b_ptr = static_cast(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset; const AQDataType* aq_ptr = static_cast(kargs.aq_ptr); - const BQDataType* bq_ptr = static_cast(kargs.bq_ptr); - CDataType* c_ptr = static_cast(kargs.c_ptr); + const BQDataType* bq_ptr = + static_cast(kargs.bq_ptr) + splitk_batch_offset.bq_k_split_offset; + CDataType* c_ptr = static_cast(kargs.c_ptr); // allocate LDS __shared__ char smem_ptr[GetSmemSize()]; diff --git a/test_splitk_stress.sh b/test_splitk_stress.sh new file mode 100755 index 00000000000..7613db26c90 --- /dev/null +++ b/test_splitk_stress.sh @@ -0,0 +1,157 @@ +#!/bin/bash +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +# Stress test script for Split-K BQuant GEMM +# Tests only valid configurations where KRead % 128 == 0 +# Tests both fp8 and bf8 precisions + +BINARY="./build/bin/tile_example_gemm_quant" +QUANT_MODE="bquant" +INIT=0 +REPEAT=1 +WARMUP=0 + +# Arrays to track results +declare -a PASSED_TESTS=() +declare -a FAILED_TESTS=() +declare -a SKIPPED_TESTS=() + +echo "==============================================" +echo "Split-K BQuant GEMM Stress Test" +echo "==============================================" +echo "Binary: $BINARY" +echo "Quant Mode: $QUANT_MODE" +echo "Init: $INIT (random)" +echo "Testing precisions: fp8, bf8" +echo "==============================================" +echo "" + +# M values to test +M_VALUES=(16 32 64 128 256) + +# N values to test +N_VALUES=(64 128 256) + +# Valid (K, split_k) combinations where KRead % 128 == 0 +# Format: "K:split_k1,split_k2,..." +declare -A VALID_K_SPLITS +VALID_K_SPLITS[256]="1,2" +VALID_K_SPLITS[384]="1,3" +VALID_K_SPLITS[512]="1,2,4,5" +VALID_K_SPLITS[640]="1,5,6" +VALID_K_SPLITS[768]="1,2,3,6,7" +VALID_K_SPLITS[896]="1,7,8" +VALID_K_SPLITS[1024]="1,2,4,8" +VALID_K_SPLITS[1152]="1,3,5" +VALID_K_SPLITS[1280]="1,2,5" +VALID_K_SPLITS[1536]="1,2,3,4,6" +VALID_K_SPLITS[1792]="1,2,5,7" +VALID_K_SPLITS[2048]="1,2,4,8" +VALID_K_SPLITS[2560]="1,2,4,5,7" +VALID_K_SPLITS[3072]="1,2,3,4,5,6,8" +VALID_K_SPLITS[4096]="1,2,4,8" + +# Precisions to test +PREC_VALUES=("fp8" "bf8") + +TOTAL_TESTS=0 +PASS_COUNT=0 +FAIL_COUNT=0 +SKIP_COUNT=0 + +for PREC in "${PREC_VALUES[@]}"; do + echo "" + echo "############################################" + echo "Testing precision: $PREC" + echo "############################################" + echo "" + + for M in "${M_VALUES[@]}"; do + for N in "${N_VALUES[@]}"; do + for K in "${!VALID_K_SPLITS[@]}"; do + # Parse valid split_k values for this K + IFS=',' read -ra SPLIT_K_ARRAY <<< "${VALID_K_SPLITS[$K]}" + + for SPLIT_K in "${SPLIT_K_ARRAY[@]}"; do + ((TOTAL_TESTS++)) + + echo "----------------------------------------------" + echo "Test #$TOTAL_TESTS: prec=$PREC M=$M N=$N K=$K split_k=$SPLIT_K" + echo "----------------------------------------------" + + OUTPUT=$($BINARY -quant_mode=$QUANT_MODE -repeat=$REPEAT -warmup=$WARMUP \ + -prec=$PREC -split_k=$SPLIT_K -m=$M -n=$N -init=$INIT -k=$K 2>&1) + + # Print kernel output (grid size and verification result) + echo "$OUTPUT" | grep -E "(grid:|verification)" | head -2 + + # Check result + if echo "$OUTPUT" | grep -q "verification result is:correct"; then + echo "Result: PASS" + ((PASS_COUNT++)) + PASSED_TESTS+=("prec=$PREC M=$M N=$N K=$K split_k=$SPLIT_K") + elif echo "$OUTPUT" | grep -q "verification result is:fail"; then + echo "Result: FAIL (numerical error)" + ((FAIL_COUNT++)) + FAILED_TESTS+=("prec=$PREC M=$M N=$N K=$K split_k=$SPLIT_K") + # Show error details + echo "$OUTPUT" | grep -E "max err:|wrong values" | head -2 + elif echo "$OUTPUT" | grep -q "not supported\|Skipping\|Arguments not supported"; then + echo "Result: SKIPPED (configuration not supported)" + ((SKIP_COUNT++)) + SKIPPED_TESTS+=("prec=$PREC M=$M N=$N K=$K split_k=$SPLIT_K") + ((TOTAL_TESTS--)) + else + echo "Result: FAIL (unknown error)" + ((FAIL_COUNT++)) + FAILED_TESTS+=("prec=$PREC M=$M N=$N K=$K split_k=$SPLIT_K") + echo "$OUTPUT" | tail -5 + fi + echo "" + done + done + done + done +done + +echo "" +echo "==============================================" +echo " SUMMARY" +echo "==============================================" +echo "" +echo "Total Tests Run: $TOTAL_TESTS" +echo "Passed: $PASS_COUNT" +echo "Failed: $FAIL_COUNT" +echo "Skipped: $SKIP_COUNT" +echo "" + +if [ $FAIL_COUNT -eq 0 ]; then + echo "✓ ALL TESTS PASSED!" +else + echo "✗ SOME TESTS FAILED!" + echo "" + echo "Failed test cases:" + for test in "${FAILED_TESTS[@]}"; do + echo " - $test" + done +fi + +if [ $SKIP_COUNT -gt 0 ]; then + echo "" + echo "Skipped test cases (not supported):" + for test in "${SKIPPED_TESTS[@]}"; do + echo " - $test" + done +fi + +echo "" +echo "==============================================" +echo "Test completed at $(date)" +echo "==============================================" + +# Exit with error code if any tests failed +if [ $FAIL_COUNT -gt 0 ]; then + exit 1 +fi +exit 0 From c0c9ab80f85d1f826c74dbd5debeae0ac946cbe0 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Mon, 26 Jan 2026 22:21:00 +0000 Subject: [PATCH 2/7] feat: add support for bf8i4 and fp8i4 by calculating correct stride for packed data types --- .../38_block_scale_gemm/CMakeLists.txt | 40 ++--- .../gemm_bquant_quantgrouped_bf8i4.cpp | 2 +- .../gemm_bquant_quantgrouped_fp8i4.cpp | 2 +- .../run_gemm_quant_example.inc | 30 +--- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 49 ++++-- test_splitk_stress.sh | 142 +++++++++--------- 6 files changed, 133 insertions(+), 132 deletions(-) diff --git a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt index ff51122af6e..13cbcc8b558 100644 --- a/example/ck_tile/38_block_scale_gemm/CMakeLists.txt +++ b/example/ck_tile/38_block_scale_gemm/CMakeLists.txt @@ -13,28 +13,28 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") set(EXE_NAME tile_example_gemm_quant) add_executable(${EXE_NAME} gemm_quant.cpp - # gemm_abquant_quantgrouped.cpp - # gemm_aquant_quantgrouped.cpp - # gemm_aquant_quantgrouped_preshufflequant.cpp - # gemm_bquant_quantgrouped_bf8i4.cpp - # gemm_bquant_quantgrouped_fp8i4.cpp - # gemm_bquant_quantgrouped_bf16mxfp4.cpp + gemm_abquant_quantgrouped.cpp + gemm_aquant_quantgrouped.cpp + gemm_aquant_quantgrouped_preshufflequant.cpp + gemm_bquant_quantgrouped_bf8i4.cpp + gemm_bquant_quantgrouped_fp8i4.cpp + gemm_bquant_quantgrouped_bf16mxfp4.cpp gemm_bquant_quantgrouped_bf8.cpp gemm_bquant_quantgrouped_fp8.cpp - # gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp - # gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp - # gemm_bquant_quantgrouped_preshuffleb_bf8.cpp - # gemm_bquant_quantgrouped_preshuffleb_fp8.cpp - # gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp - # gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp - # gemm_bquant_quantgrouped_preshufflequant_bf8.cpp - # gemm_bquant_quantgrouped_preshufflequant_fp8.cpp - # gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp - # gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp - # gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp - # gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp - # gemm_quant_rowcol.cpp - # gemm_quant_tensor.cpp + gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_fp8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_bf8.cpp + gemm_bquant_quantgrouped_preshuffleb_fp8.cpp + gemm_bquant_quantgrouped_preshufflequant_bf8i4.cpp + gemm_bquant_quantgrouped_preshufflequant_fp8i4.cpp + gemm_bquant_quantgrouped_preshufflequant_bf8.cpp + gemm_bquant_quantgrouped_preshufflequant_fp8.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8i4.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_bf8.cpp + gemm_bquant_quantgrouped_preshuffleb_preshufflequant_fp8.cpp + gemm_quant_rowcol.cpp + gemm_quant_tensor.cpp ) target_compile_options(${EXE_NAME} PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp index d2b95d32633..a93fe15a1ba 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_bf8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp index 6576b22c038..ed18cd8890f 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_fp8i4.cpp @@ -4,7 +4,7 @@ #include "run_gemm_quant_example.inc" template -using GemmConfig = GemmConfigQuantPrefill; +using GemmConfig = GemmConfigQuantDecode; #define RUN_GEMM_EXAMPLE_PREC_TYPE \ run_gemm_example_prec_type, \ diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index e14685cf791..077e241d13e 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -215,7 +215,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str const dim3 blocks = Kernel::BlockSize(); // Split-K validation is handled by Kernel::IsSupportedArgument - // BQuantGrouped without preshuffle supports split-K + // Split-K is only supported for BQuantGrouped without preshuffle if(!Kernel::IsSupportedArgument(kargs)) { throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n"); @@ -567,7 +567,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } else { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(b_k_n); ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( *bq_tensor_ptr); @@ -651,33 +650,6 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser, } } } - else if(init_method ==3) - { - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) - { - ck_tile::FillConstant{static_cast(0x38)}(a_m_k); - ck_tile::FillConstant{static_cast(0x22)}(b_k_n); - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - else if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped) - { - ck_tile::FillConstant{static_cast(0x38)}(a_m_k); - ck_tile::FillConstant{static_cast(0x22)}(b_k_n); - ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - else - { - ck_tile::FillConstant{static_cast(0x22)}(a_m_k); - ck_tile::FillConstant{static_cast(0.5f)}(*aq_tensor_ptr); - ck_tile::FillConstant{static_cast(0x38)}(b_k_n); - - if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) - { - ck_tile::FillConstant{static_cast(0.5f)}(*bq_tensor_ptr); - } - } - } else { a_m_k.SetZero(); diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index dee4553956f..74dc41a8219 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -380,9 +380,18 @@ struct QuantGemmKernel __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs, const std::size_t k_id = blockIdx.z) { - constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); - const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1); - const index_t KRead = amd_wave_read_first_lane((kargs.K + K_t - 1) / K_t * K1); + constexpr auto K1 = + GemmPipeline::BlockGemmShape::WarpTile::at(I2); // smallest unit of K work per block + const index_t K_t = amd_wave_read_first_lane( + kargs.k_batch * K1); // amount of K elements consumed if every split-K batch + // performs exactly one "unit" (K1) + const index_t KRead = amd_wave_read_first_lane( + (kargs.K + K_t - 1) / K_t * K1); // total k elements to be read in this batch + // offset not necessarily = KRead, because B can have packed elements (e.g. fp8i4) + constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + const index_t b_k_offset_elements = + amd_wave_read_first_lane(k_id * KRead / BPackedSize); if constexpr(std::is_same_v) { @@ -395,11 +404,11 @@ struct QuantGemmKernel if constexpr(std::is_same_v) { - b_k_split_offset = amd_wave_read_first_lane(k_id * KRead * kargs.stride_B); + b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements * kargs.stride_B); } else if constexpr(std::is_same_v) { - b_k_split_offset = amd_wave_read_first_lane(k_id * KRead); + b_k_split_offset = amd_wave_read_first_lane(b_k_offset_elements); } if(k_id < static_cast(kargs.k_batch - 1)) @@ -1102,20 +1111,32 @@ struct QuantGemmKernel } else { - // For BQuantGrouped split-K, the K split must be aligned with quantization groups. - // This is because the pipeline applies BQ scales based on the relative K index - // within each batch, not the absolute K index. When a batch starts in the middle - // of a quantization group, the scale selection would be incorrect. - // - // To support misaligned splits, the pipeline would need to be modified to accept - // a K offset parameter and use (k_offset + local_k) / QuantGroupSize::kK for - // scale selection instead of local_k / QuantGroupSize::kK. using QuantGroupSize = remove_cvref_t; constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); const index_t K_t = kargs.k_batch * K1; const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + constexpr index_t BPackedSize = + ck_tile::numeric_traits>::PackedSize; + + // Constraint 1: KRead must align with B packing (pk_int4_t packs 2 elements/byte) + // For packed types like pk_int4_t, two K elements are stored in one byte. + // Split-K advances the B pointer by (KRead / BPackedSize) bytes per batch. + // If KRead is odd, this division produces a fractional byte offset, which is + // impossible - we cannot start reading from the middle of a packed byte. + if(KRead % BPackedSize != 0) + { + if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) + { + CK_TILE_ERROR("KRead must be a multiple of B packed size for split-K!"); + } + return false; + } - // KRead must be a multiple of QuantGroupSize::kK to ensure proper BQ alignment + // Constraint 2: KRead must align with quantization group boundaries. + // Each split-K batch reads KRead consecutive K elements. If KRead is not + // a multiple of QuantGroupSize::kK, the batch will span partial quantization + // groups. Since the pipeline indexes BQ scales using (local_k / kK) without + // knowledge of the global K offset, it would apply scales from the wrong groups. if(KRead % QuantGroupSize::kK != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) diff --git a/test_splitk_stress.sh b/test_splitk_stress.sh index 7613db26c90..2b8d902a6d7 100755 --- a/test_splitk_stress.sh +++ b/test_splitk_stress.sh @@ -3,8 +3,8 @@ # SPDX-License-Identifier: MIT # Stress test script for Split-K BQuant GEMM -# Tests only valid configurations where KRead % 128 == 0 -# Tests both fp8 and bf8 precisions +# Modified to test misaligned split-K configurations +# Tests fp8, bf8, fp8i4, and bf8i4 precisions BINARY="./build/bin/tile_example_gemm_quant" QUANT_MODE="bquant" @@ -18,42 +18,45 @@ declare -a FAILED_TESTS=() declare -a SKIPPED_TESTS=() echo "==============================================" -echo "Split-K BQuant GEMM Stress Test" +echo "Split-K BQuant GEMM Stress Test (Misaligned Support)" echo "==============================================" echo "Binary: $BINARY" echo "Quant Mode: $QUANT_MODE" echo "Init: $INIT (random)" -echo "Testing precisions: fp8, bf8" +echo "Testing precisions: fp8, bf8, fp8i4, bf8i4" echo "==============================================" echo "" -# M values to test +# More thorough set of values for edge-case coverage M_VALUES=(16 32 64 128 256) - -# N values to test N_VALUES=(64 128 256) -# Valid (K, split_k) combinations where KRead % 128 == 0 +# Layouts to test (A x B) +# BQuant in 38_block_scale_gemm only supports A=R, B=C +A_LAYOUT_VALUES=("R") +B_LAYOUT_VALUES=("C") + +# Test cases including aligned and misaligned ones # Format: "K:split_k1,split_k2,..." -declare -A VALID_K_SPLITS -VALID_K_SPLITS[256]="1,2" -VALID_K_SPLITS[384]="1,3" -VALID_K_SPLITS[512]="1,2,4,5" -VALID_K_SPLITS[640]="1,5,6" -VALID_K_SPLITS[768]="1,2,3,6,7" -VALID_K_SPLITS[896]="1,7,8" -VALID_K_SPLITS[1024]="1,2,4,8" -VALID_K_SPLITS[1152]="1,3,5" -VALID_K_SPLITS[1280]="1,2,5" -VALID_K_SPLITS[1536]="1,2,3,4,6" -VALID_K_SPLITS[1792]="1,2,5,7" -VALID_K_SPLITS[2048]="1,2,4,8" -VALID_K_SPLITS[2560]="1,2,4,5,7" -VALID_K_SPLITS[3072]="1,2,3,4,5,6,8" -VALID_K_SPLITS[4096]="1,2,4,8" +declare -A TEST_K_SPLITS +# Aligned cases (KRead % 128 == 0) +TEST_K_SPLITS[128]="1" +TEST_K_SPLITS[256]="1,2" +TEST_K_SPLITS[384]="1,3" +TEST_K_SPLITS[512]="1,2,4" +TEST_K_SPLITS[640]="1,5" +TEST_K_SPLITS[768]="1,2,3,6" +TEST_K_SPLITS[896]="1,7" +TEST_K_SPLITS[1024]="1,2,4,8" +TEST_K_SPLITS[1536]="1,2,3,4,6" +TEST_K_SPLITS[2048]="1,2,4,8" +# Misaligned cases (expected to be skipped by IsSupportedArgument) +TEST_K_SPLITS[320]="2,5" +TEST_K_SPLITS[448]="2,7" +TEST_K_SPLITS[960]="3,5,6" # Precisions to test -PREC_VALUES=("fp8" "bf8") +PREC_VALUES=("fp8" "bf8" "fp8i4" "bf8i4") TOTAL_TESTS=0 PASS_COUNT=0 @@ -67,48 +70,53 @@ for PREC in "${PREC_VALUES[@]}"; do echo "############################################" echo "" - for M in "${M_VALUES[@]}"; do - for N in "${N_VALUES[@]}"; do - for K in "${!VALID_K_SPLITS[@]}"; do - # Parse valid split_k values for this K - IFS=',' read -ra SPLIT_K_ARRAY <<< "${VALID_K_SPLITS[$K]}" - - for SPLIT_K in "${SPLIT_K_ARRAY[@]}"; do - ((TOTAL_TESTS++)) - - echo "----------------------------------------------" - echo "Test #$TOTAL_TESTS: prec=$PREC M=$M N=$N K=$K split_k=$SPLIT_K" - echo "----------------------------------------------" - - OUTPUT=$($BINARY -quant_mode=$QUANT_MODE -repeat=$REPEAT -warmup=$WARMUP \ - -prec=$PREC -split_k=$SPLIT_K -m=$M -n=$N -init=$INIT -k=$K 2>&1) - - # Print kernel output (grid size and verification result) - echo "$OUTPUT" | grep -E "(grid:|verification)" | head -2 - - # Check result - if echo "$OUTPUT" | grep -q "verification result is:correct"; then - echo "Result: PASS" - ((PASS_COUNT++)) - PASSED_TESTS+=("prec=$PREC M=$M N=$N K=$K split_k=$SPLIT_K") - elif echo "$OUTPUT" | grep -q "verification result is:fail"; then - echo "Result: FAIL (numerical error)" - ((FAIL_COUNT++)) - FAILED_TESTS+=("prec=$PREC M=$M N=$N K=$K split_k=$SPLIT_K") - # Show error details - echo "$OUTPUT" | grep -E "max err:|wrong values" | head -2 - elif echo "$OUTPUT" | grep -q "not supported\|Skipping\|Arguments not supported"; then - echo "Result: SKIPPED (configuration not supported)" - ((SKIP_COUNT++)) - SKIPPED_TESTS+=("prec=$PREC M=$M N=$N K=$K split_k=$SPLIT_K") - ((TOTAL_TESTS--)) - else - echo "Result: FAIL (unknown error)" - ((FAIL_COUNT++)) - FAILED_TESTS+=("prec=$PREC M=$M N=$N K=$K split_k=$SPLIT_K") - echo "$OUTPUT" | tail -5 - fi - echo "" + for A_LAYOUT in "${A_LAYOUT_VALUES[@]}"; do + for B_LAYOUT in "${B_LAYOUT_VALUES[@]}"; do + for M in "${M_VALUES[@]}"; do + for N in "${N_VALUES[@]}"; do + for K in "${!TEST_K_SPLITS[@]}"; do + # Parse split_k values for this K + IFS=',' read -ra SPLIT_K_ARRAY <<< "${TEST_K_SPLITS[$K]}" + + for SPLIT_K in "${SPLIT_K_ARRAY[@]}"; do + ((TOTAL_TESTS++)) + + echo "----------------------------------------------" + echo "Test #$TOTAL_TESTS: prec=$PREC A=$A_LAYOUT B=$B_LAYOUT M=$M N=$N K=$K split_k=$SPLIT_K" + echo "----------------------------------------------" + + OUTPUT=$($BINARY -quant_mode=$QUANT_MODE -repeat=$REPEAT -warmup=$WARMUP \ + -prec=$PREC -split_k=$SPLIT_K -m=$M -n=$N -init=$INIT -k=$K \ + -a_layout=$A_LAYOUT -b_layout=$B_LAYOUT 2>&1) + + # Print kernel output (grid size and verification result) + echo "$OUTPUT" | grep -E "(grid:|verification)" | head -2 + + # Check result + if echo "$OUTPUT" | grep -q "verification result is:correct"; then + echo "Result: PASS" + ((PASS_COUNT++)) + PASSED_TESTS+=("prec=$PREC A=$A_LAYOUT B=$B_LAYOUT M=$M N=$N K=$K split_k=$SPLIT_K") + elif echo "$OUTPUT" | grep -q "verification result is:fail"; then + echo "Result: FAIL (numerical error)" + ((FAIL_COUNT++)) + FAILED_TESTS+=("prec=$PREC A=$A_LAYOUT B=$B_LAYOUT M=$M N=$N K=$K split_k=$SPLIT_K") + # Show error details + echo "$OUTPUT" | grep -E "max err:|wrong values" | head -2 + elif echo "$OUTPUT" | grep -q "not supported\|Skipping\|Arguments not supported"; then + echo "Result: SKIPPED (configuration not supported)" + ((SKIP_COUNT++)) + SKIPPED_TESTS+=("prec=$PREC A=$A_LAYOUT B=$B_LAYOUT M=$M N=$N K=$K split_k=$SPLIT_K") + ((TOTAL_TESTS--)) + else + echo "Result: FAIL (unknown error)" + ((FAIL_COUNT++)) + FAILED_TESTS+=("prec=$PREC A=$A_LAYOUT B=$B_LAYOUT M=$M N=$N K=$K split_k=$SPLIT_K") + echo "$OUTPUT" | tail -5 + fi + echo "" + done + done done done done From 22cdf55b4a50237fe99690ec68425aff72b06d92 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Mon, 26 Jan 2026 23:06:07 +0000 Subject: [PATCH 3/7] chore: remove temporary test script --- test_splitk_stress.sh | 165 ------------------------------------------ 1 file changed, 165 deletions(-) delete mode 100755 test_splitk_stress.sh diff --git a/test_splitk_stress.sh b/test_splitk_stress.sh deleted file mode 100755 index 2b8d902a6d7..00000000000 --- a/test_splitk_stress.sh +++ /dev/null @@ -1,165 +0,0 @@ -#!/bin/bash -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -# Stress test script for Split-K BQuant GEMM -# Modified to test misaligned split-K configurations -# Tests fp8, bf8, fp8i4, and bf8i4 precisions - -BINARY="./build/bin/tile_example_gemm_quant" -QUANT_MODE="bquant" -INIT=0 -REPEAT=1 -WARMUP=0 - -# Arrays to track results -declare -a PASSED_TESTS=() -declare -a FAILED_TESTS=() -declare -a SKIPPED_TESTS=() - -echo "==============================================" -echo "Split-K BQuant GEMM Stress Test (Misaligned Support)" -echo "==============================================" -echo "Binary: $BINARY" -echo "Quant Mode: $QUANT_MODE" -echo "Init: $INIT (random)" -echo "Testing precisions: fp8, bf8, fp8i4, bf8i4" -echo "==============================================" -echo "" - -# More thorough set of values for edge-case coverage -M_VALUES=(16 32 64 128 256) -N_VALUES=(64 128 256) - -# Layouts to test (A x B) -# BQuant in 38_block_scale_gemm only supports A=R, B=C -A_LAYOUT_VALUES=("R") -B_LAYOUT_VALUES=("C") - -# Test cases including aligned and misaligned ones -# Format: "K:split_k1,split_k2,..." -declare -A TEST_K_SPLITS -# Aligned cases (KRead % 128 == 0) -TEST_K_SPLITS[128]="1" -TEST_K_SPLITS[256]="1,2" -TEST_K_SPLITS[384]="1,3" -TEST_K_SPLITS[512]="1,2,4" -TEST_K_SPLITS[640]="1,5" -TEST_K_SPLITS[768]="1,2,3,6" -TEST_K_SPLITS[896]="1,7" -TEST_K_SPLITS[1024]="1,2,4,8" -TEST_K_SPLITS[1536]="1,2,3,4,6" -TEST_K_SPLITS[2048]="1,2,4,8" -# Misaligned cases (expected to be skipped by IsSupportedArgument) -TEST_K_SPLITS[320]="2,5" -TEST_K_SPLITS[448]="2,7" -TEST_K_SPLITS[960]="3,5,6" - -# Precisions to test -PREC_VALUES=("fp8" "bf8" "fp8i4" "bf8i4") - -TOTAL_TESTS=0 -PASS_COUNT=0 -FAIL_COUNT=0 -SKIP_COUNT=0 - -for PREC in "${PREC_VALUES[@]}"; do - echo "" - echo "############################################" - echo "Testing precision: $PREC" - echo "############################################" - echo "" - - for A_LAYOUT in "${A_LAYOUT_VALUES[@]}"; do - for B_LAYOUT in "${B_LAYOUT_VALUES[@]}"; do - for M in "${M_VALUES[@]}"; do - for N in "${N_VALUES[@]}"; do - for K in "${!TEST_K_SPLITS[@]}"; do - # Parse split_k values for this K - IFS=',' read -ra SPLIT_K_ARRAY <<< "${TEST_K_SPLITS[$K]}" - - for SPLIT_K in "${SPLIT_K_ARRAY[@]}"; do - ((TOTAL_TESTS++)) - - echo "----------------------------------------------" - echo "Test #$TOTAL_TESTS: prec=$PREC A=$A_LAYOUT B=$B_LAYOUT M=$M N=$N K=$K split_k=$SPLIT_K" - echo "----------------------------------------------" - - OUTPUT=$($BINARY -quant_mode=$QUANT_MODE -repeat=$REPEAT -warmup=$WARMUP \ - -prec=$PREC -split_k=$SPLIT_K -m=$M -n=$N -init=$INIT -k=$K \ - -a_layout=$A_LAYOUT -b_layout=$B_LAYOUT 2>&1) - - # Print kernel output (grid size and verification result) - echo "$OUTPUT" | grep -E "(grid:|verification)" | head -2 - - # Check result - if echo "$OUTPUT" | grep -q "verification result is:correct"; then - echo "Result: PASS" - ((PASS_COUNT++)) - PASSED_TESTS+=("prec=$PREC A=$A_LAYOUT B=$B_LAYOUT M=$M N=$N K=$K split_k=$SPLIT_K") - elif echo "$OUTPUT" | grep -q "verification result is:fail"; then - echo "Result: FAIL (numerical error)" - ((FAIL_COUNT++)) - FAILED_TESTS+=("prec=$PREC A=$A_LAYOUT B=$B_LAYOUT M=$M N=$N K=$K split_k=$SPLIT_K") - # Show error details - echo "$OUTPUT" | grep -E "max err:|wrong values" | head -2 - elif echo "$OUTPUT" | grep -q "not supported\|Skipping\|Arguments not supported"; then - echo "Result: SKIPPED (configuration not supported)" - ((SKIP_COUNT++)) - SKIPPED_TESTS+=("prec=$PREC A=$A_LAYOUT B=$B_LAYOUT M=$M N=$N K=$K split_k=$SPLIT_K") - ((TOTAL_TESTS--)) - else - echo "Result: FAIL (unknown error)" - ((FAIL_COUNT++)) - FAILED_TESTS+=("prec=$PREC A=$A_LAYOUT B=$B_LAYOUT M=$M N=$N K=$K split_k=$SPLIT_K") - echo "$OUTPUT" | tail -5 - fi - echo "" - done - done - done - done - done - done -done - -echo "" -echo "==============================================" -echo " SUMMARY" -echo "==============================================" -echo "" -echo "Total Tests Run: $TOTAL_TESTS" -echo "Passed: $PASS_COUNT" -echo "Failed: $FAIL_COUNT" -echo "Skipped: $SKIP_COUNT" -echo "" - -if [ $FAIL_COUNT -eq 0 ]; then - echo "✓ ALL TESTS PASSED!" -else - echo "✗ SOME TESTS FAILED!" - echo "" - echo "Failed test cases:" - for test in "${FAILED_TESTS[@]}"; do - echo " - $test" - done -fi - -if [ $SKIP_COUNT -gt 0 ]; then - echo "" - echo "Skipped test cases (not supported):" - for test in "${SKIPPED_TESTS[@]}"; do - echo " - $test" - done -fi - -echo "" -echo "==============================================" -echo "Test completed at $(date)" -echo "==============================================" - -# Exit with error code if any tests failed -if [ $FAIL_COUNT -gt 0 ]; then - exit 1 -fi -exit 0 From 650fa1e034a628f732e102dd182f55e1c4d59a07 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Tue, 27 Jan 2026 19:29:10 +0000 Subject: [PATCH 4/7] fix: incorrect tile window length for splitted bq tensor window --- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 33 ++++++++++--------- .../kernel/grouped_gemm_quant_kernel.hpp | 8 ++--- 2 files changed, 21 insertions(+), 20 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 74dc41a8219..c7035cb41dd 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -428,9 +428,8 @@ struct QuantGemmKernel using QuantGroupSize = remove_cvref_t; // Compute the K offset for this batch (in terms of K elements) const index_t k_offset = amd_wave_read_first_lane(k_id * KRead); - // Convert K offset to BQ group offset - const index_t bq_group_offset = - amd_wave_read_first_lane(k_offset / QuantGroupSize::kK); + // Convert K offset to BQ group offset (logical offset in K/kK dimension) + bq_group_offset = amd_wave_read_first_lane(k_offset / QuantGroupSize::kK); // BQ tensor layout: // RowMajor: [K/kK, N/kN] with stride [N/kN, 1] @@ -452,13 +451,15 @@ struct QuantGemmKernel } else { + bq_group_offset = 0; bq_k_split_offset = 0; } } index_t a_k_split_offset; index_t b_k_split_offset; - index_t bq_k_split_offset; + index_t bq_group_offset; // Logical offset in K-groups (K/kK dimension) + index_t bq_k_split_offset; // Memory pointer offset (accounting for layout/stride) index_t splitted_k; }; @@ -850,13 +851,13 @@ struct QuantGemmKernel CK_TILE_DEVICE static auto MakeBQBlockWindow(const BQDataType* bq_ptr, const QuantGemmKernelArgs& kargs, + const index_t bq_group_offset, const index_t i_m, const index_t i_n) { // Step 1: Create tensor view for BQ - // Note: For split-K, the bq_ptr is already offset by bq_k_split_offset. - // The tensor view should use kargs.QK_B (full K-groups) as the dimension - // because the view needs to see all remaining K-groups from the offset position. + // Note: For split-K, the bq_ptr is already offset by bq_k_split_offset (pointer offset). + // The dimension should use the remaining K-groups from this offset position. const auto& bq_tensor_view = [&]() { if constexpr(kQuantType == QuantType::RowColQuant) { @@ -902,9 +903,9 @@ struct QuantGemmKernel { return make_naive_tensor_view( bq_ptr, - make_tuple(kargs.QK_B, - integer_divide_ceil(kargs.N, BQuantGroupSize::kN)), - make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1), + make_tuple(kargs.QK_B - bq_group_offset, + integer_divide_ceil(kargs.N, QuantGroupSize::kN)), + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), number{}, number<1>{}); } @@ -912,8 +913,8 @@ struct QuantGemmKernel { return make_naive_tensor_view( bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), - kargs.QK_B), + make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), + kargs.QK_B - bq_group_offset), make_tuple(kargs.QK_B, 1), number{}, number<1>{}); @@ -1311,10 +1312,10 @@ struct QuantGemmKernel const auto& b_block_window = MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); const auto& aq_block_window = MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); - // Note: BQ tensor view uses full dimensions (not splitted_qk_b) because - // the split-K offset is already applied to bq_ptr. The tensor view needs - // to see the full remaining K-groups from the offset position. - const auto& bq_block_window = MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); + // Note: Pass bq_group_offset so the tensor view dimension reflects + // the remaining K-groups from the split-K offset position. + const auto& bq_block_window = MakeBQBlockWindow( + bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); diff --git a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index c9e725f5fde..8b77b01e2fe 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -387,8 +387,8 @@ struct QuantGroupedGemmKernel Base::MakeABlockWindow(a_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_m); const auto& b_block_window = Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); - const auto& bq_block_window = - Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = Base::MakeBQBlockWindow( + bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); const index_t num_loop = __builtin_amdgcn_readfirstlane( TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k)); @@ -453,8 +453,8 @@ struct QuantGroupedGemmKernel Base::MakeBBlockWindow(b_ptr, kargs, splitk_batch_offset.splitted_k, block_idx_n); const auto& aq_block_window = Base::MakeAQBlockWindow(aq_ptr, kargs, block_idx_m, block_idx_n); - const auto& bq_block_window = - Base::MakeBQBlockWindow(bq_ptr, kargs, block_idx_m, block_idx_n); + const auto& bq_block_window = Base::MakeBQBlockWindow( + bq_ptr, kargs, splitk_batch_offset.bq_group_offset, block_idx_m, block_idx_n); // Get hot-loop and tail configuration const index_t num_loop = __builtin_amdgcn_readfirstlane( From 23ac2365d4439308bf006536e22ce07b18e77c6f Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Tue, 27 Jan 2026 19:36:51 +0000 Subject: [PATCH 5/7] chore: improve comments --- .../ops/gemm_quant/kernel/gemm_quant_kernel.hpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index c7035cb41dd..1418891ed1c 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -1119,11 +1119,11 @@ struct QuantGemmKernel constexpr index_t BPackedSize = ck_tile::numeric_traits>::PackedSize; - // Constraint 1: KRead must align with B packing (pk_int4_t packs 2 elements/byte) - // For packed types like pk_int4_t, two K elements are stored in one byte. - // Split-K advances the B pointer by (KRead / BPackedSize) bytes per batch. - // If KRead is odd, this division produces a fractional byte offset, which is - // impossible - we cannot start reading from the middle of a packed byte. + // Constraint 1: KRead must align with B packing requirements. + // For packed data types, multiple K elements are stored in each storage unit. + // Split-K advances the B pointer by (KRead / BPackedSize) storage units per batch. + // If KRead is not divisible by BPackedSize, this division produces a fractional + // offset, making it impossible to start reading from a valid storage unit boundary. if(KRead % BPackedSize != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) @@ -1136,8 +1136,8 @@ struct QuantGemmKernel // Constraint 2: KRead must align with quantization group boundaries. // Each split-K batch reads KRead consecutive K elements. If KRead is not // a multiple of QuantGroupSize::kK, the batch will span partial quantization - // groups. Since the pipeline indexes BQ scales using (local_k / kK) without - // knowledge of the global K offset, it would apply scales from the wrong groups. + // groups, requiring split access to a quantization scale. This violates the + // atomic processing requirement where each batch must work with complete groups. if(KRead % QuantGroupSize::kK != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) From 6498fc92de282304893a0e7efc407ce4136d4509 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Tue, 27 Jan 2026 21:23:51 +0000 Subject: [PATCH 6/7] test: add unit tests to cover bquant splitk functionality --- test/ck_tile/gemm_block_scale/CMakeLists.txt | 11 ++++ .../test_gemm_quant_bquant_splitk_decode.cpp | 61 ++++++++++++++++++ .../test_gemm_quant_bquant_splitk_prefill.cpp | 64 +++++++++++++++++++ .../test_gemm_quant_fixtures.hpp | 16 +++-- 4 files changed, 147 insertions(+), 5 deletions(-) create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp create mode 100644 test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 9dd9670ff50..8dca80fa245 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -112,6 +112,17 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12") ) target_compile_options(test_tile_gemm_quant_bquant_transpose PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # BQuant split-K tests (no preshuffle) + add_gtest_executable(test_tile_gemm_quant_bquant_splitk_decode + test_gemm_quant_bquant_splitk_decode.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_splitk_decode PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + + add_gtest_executable(test_tile_gemm_quant_bquant_splitk_prefill + test_gemm_quant_bquant_splitk_prefill.cpp + ) + target_compile_options(test_tile_gemm_quant_bquant_splitk_prefill PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) + # BQuant tests (with PreshuffleB) - split into 5 files add_gtest_executable(test_tile_gemm_quant_bquant_preshuffle_decode_1d test_gemm_quant_bquant_preshuffle_decode_1d.cpp diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp new file mode 100644 index 00000000000..ea1a8a1fbbd --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_decode.cpp @@ -0,0 +1,61 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant split-K tests - Decode shape, GroupSize 128 +// Tuple format: +// clang-format off +using BQuantSplitKDecodeTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant split-K Decode +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKDecodeTypes); + +// BQuant split-K tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test) +{ + // K=1024 for split_k=2: 1024/2=512=4×128 ✓ + this->run_test_with_validation(32, 128, 1024, 2); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test) +{ + // K=3072 for split_k=3: 3072/3=1024=8×128 ✓ + this->run_test_with_validation(32, 128, 3072, 3); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test) +{ + // K=2048 for split_k=4: 2048/4=512=4×128 ✓ + this->run_test_with_validation(32, 128, 2048, 4); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test) +{ + // K=2560 for split_k=5: 2560/5=512=4×128 ✓ + // Also K must be divisible by K_Tile(256)*split_k(5)=1280 + this->run_test_with_validation(32, 128, 2560, 5); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp new file mode 100644 index 00000000000..f4f93dbbb65 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_bquant_splitk_prefill.cpp @@ -0,0 +1,64 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using BQuantGrouped = std::integral_constant; +using GroupSize128 = ck_tile::QuantGroupShape>; + +// Type combinations for BQuant split-K tests - Prefill shape, GroupSize 128 +// Tuple format: +// clang-format off +using BQuantSplitKPrefillTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// Test suite for BQuant split-K Prefill +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantSplitKPrefillTypes); + +// BQuant split-K tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK2Test) +{ + // K=1024 for split_k=2: 1024/2=512=4×128 ✓ + // K must be divisible by K_Tile(128)*split_k(2)=256 + this->run_test_with_validation(128, 128, 1024, 2); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK3Test) +{ + // K=3072 for split_k=3: 3072/3=1024=8×128 ✓ + // K must be divisible by K_Tile(128)*split_k(3)=384 + this->run_test_with_validation(128, 128, 3072, 3); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK4Test) +{ + // K=2048 for split_k=4: 2048/4=512=4×128 ✓ + // K must be divisible by K_Tile(128)*split_k(4)=512 + this->run_test_with_validation(128, 128, 2048, 4); +} + +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedSplitK5Test) +{ + // K=1920 for split_k=5: 1920/5=384=3×128 ✓ + // K must be divisible by K_Tile(128)*split_k(5)=640 + this->run_test_with_validation(128, 128, 1920, 5); +} diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp index 9683fa98aa3..1181e56dc3f 100644 --- a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -655,7 +655,10 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase b_k_n_dev = b_k_n; @@ -746,12 +752,12 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBasetemplate calculate_rtol_atol( - K, 1, max_accumulated_value); + K, k_batch, max_accumulated_value); // Validate results bool pass = ck_tile::check_err(c_m_n_dev_result, @@ -806,7 +812,7 @@ class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase{})); EXPECT_TRUE(pass) << "BQuantGrouped validation failed with M=" << M << ", N=" << N - << ", K=" << K; + << ", K=" << K << ", k_batch=" << k_batch; if(!pass) { From 08998da360dde47c04747c9ec723714643c40575 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Thu, 29 Jan 2026 13:26:26 +0000 Subject: [PATCH 7/7] fix: conflict resolution by renaming variables --- .../gemm_quant/kernel/gemm_quant_kernel.hpp | 35 ++++++++++--------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 1418891ed1c..db86fdbeac8 100644 --- a/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -422,14 +422,14 @@ struct QuantGemmKernel // Compute BQ offset for BQuantGrouped mode (non-preshuffle only) // Note: With the alignment validation in IsSupportedArgument, KRead is always - // a multiple of QuantGroupSize::kK, so bq_k_split_offset will be correctly aligned. - if constexpr(kQuantType == QuantType::BQuantGrouped && !PreshuffleQuant) + // a multiple of BQuantGroupSize::kK, so bq_k_split_offset will be correctly aligned. + if constexpr(kQuantType == QuantType::BQuantGrouped && !BPreshuffleQuant) { - using QuantGroupSize = remove_cvref_t; + using BQuantGroupSize = remove_cvref_t; // Compute the K offset for this batch (in terms of K elements) const index_t k_offset = amd_wave_read_first_lane(k_id * KRead); // Convert K offset to BQ group offset (logical offset in K/kK dimension) - bq_group_offset = amd_wave_read_first_lane(k_offset / QuantGroupSize::kK); + bq_group_offset = amd_wave_read_first_lane(k_offset / BQuantGroupSize::kK); // BQ tensor layout: // RowMajor: [K/kK, N/kN] with stride [N/kN, 1] @@ -439,7 +439,7 @@ struct QuantGemmKernel // For RowMajor BQ, K is the row dimension // offset = bq_group_offset * stride_BQ const index_t stride_bq = - amd_wave_read_first_lane(integer_divide_ceil(kargs.N, QuantGroupSize::kN)); + amd_wave_read_first_lane(integer_divide_ceil(kargs.N, BQuantGroupSize::kN)); bq_k_split_offset = amd_wave_read_first_lane(bq_group_offset * stride_bq); } else if constexpr(std::is_same_v) @@ -899,13 +899,14 @@ struct QuantGemmKernel "ABQuantGrouped requires ColumnMajor BQ layout"); } + using BQuantGroupSize = remove_cvref_t; if constexpr(std::is_same_v) { return make_naive_tensor_view( bq_ptr, make_tuple(kargs.QK_B - bq_group_offset, - integer_divide_ceil(kargs.N, QuantGroupSize::kN)), - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), 1), + integer_divide_ceil(kargs.N, BQuantGroupSize::kN)), + make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), 1), number{}, number<1>{}); } @@ -913,7 +914,7 @@ struct QuantGemmKernel { return make_naive_tensor_view( bq_ptr, - make_tuple(integer_divide_ceil(kargs.N, QuantGroupSize::kN), + make_tuple(integer_divide_ceil(kargs.N, BQuantGroupSize::kN), kargs.QK_B - bq_group_offset), make_tuple(kargs.QK_B, 1), number{}, @@ -1100,7 +1101,7 @@ struct QuantGemmKernel if(kargs.k_batch != 1) { constexpr bool is_bquant_non_preshuffle = - (kQuantType == QuantType::BQuantGrouped) && !PreshuffleQuant; + (kQuantType == QuantType::BQuantGrouped) && !BPreshuffleQuant; if constexpr(!is_bquant_non_preshuffle) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) @@ -1112,10 +1113,10 @@ struct QuantGemmKernel } else { - using QuantGroupSize = remove_cvref_t; - constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); - const index_t K_t = kargs.k_batch * K1; - const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; + using BQuantGroupSize = remove_cvref_t; + constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(I2); + const index_t K_t = kargs.k_batch * K1; + const index_t KRead = (kargs.K + K_t - 1) / K_t * K1; constexpr index_t BPackedSize = ck_tile::numeric_traits>::PackedSize; @@ -1135,18 +1136,18 @@ struct QuantGemmKernel // Constraint 2: KRead must align with quantization group boundaries. // Each split-K batch reads KRead consecutive K elements. If KRead is not - // a multiple of QuantGroupSize::kK, the batch will span partial quantization + // a multiple of BQuantGroupSize::kK, the batch will span partial quantization // groups, requiring split access to a quantization scale. This violates the // atomic processing requirement where each batch must work with complete groups. - if(KRead % QuantGroupSize::kK != 0) + if(KRead % BQuantGroupSize::kK != 0) { if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING))) { CK_TILE_ERROR("Split-K batch size must be aligned with quantization group " "size! KRead=" + std::to_string(KRead) + - " is not divisible by QuantGroupSize::kK=" + - std::to_string(QuantGroupSize::kK)); + " is not divisible by BQuantGroupSize::kK=" + + std::to_string(BQuantGroupSize::kK)); } return false; }