Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions example/ck_tile/38_block_scale_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
add_executable(${EXE_NAME}
gemm_quant.cpp
gemm_abquant_quantgrouped.cpp
gemm_abquant_quantgrouped_preshuffleb_preshufflequant.cpp
gemm_aquant_quantgrouped.cpp
gemm_aquant_quantgrouped_preshufflequant.cpp
gemm_bquant_quantgrouped_bf8i4.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "38_block_scale_gemm/gemm_utils.hpp"
#include "run_gemm_quant_example.inc"

// Short, file-local spelling of GemmConfigPreshuffleB_ABQuant_PreshuffleBQuant_Prefill that
// simply forwards the precision type.
template <typename PrecType>
using GemmConfigPreshuffleB_PreshuffleBQuant =
    GemmConfigPreshuffleB_ABQuant_PreshuffleBQuant_Prefill<PrecType>;

static auto _ = []() {
auto& lut = get_kernel_lut();
lut[hash_multiple_strings({"fp8",
"abquant",
"preshuffleb",
"preshufflequant",
"1x1x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_PreshuffleBQuant<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
lut[hash_multiple_strings({"fp8",
"abquant",
"preshuffleb",
"preshufflequant",
"1x128x128"})] = [](const ck_tile::ArgParser& arg_parser) {
using AQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB_PreshuffleBQuant<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
ck_tile::QuantType::ABQuantGrouped>(arg_parser);
};
return 0;
}();
7 changes: 7 additions & 0 deletions example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,13 @@ struct GemmConfigPreshuffleB_ABQuant_Prefill : public GemmConfigPreshuffleB_BQua
static constexpr bool TransposeC = true;
};

/// Variant of GemmConfigPreshuffleB_ABQuant_Prefill<PrecType> that sets BPreshuffleQuant to
/// true. Every other setting is inherited unchanged from the base configuration.
template <typename PrecType>
struct GemmConfigPreshuffleB_ABQuant_PreshuffleBQuant_Prefill
    : GemmConfigPreshuffleB_ABQuant_Prefill<PrecType>
{
    // Hides any BPreshuffleQuant member supplied by the base configuration.
    static constexpr bool BPreshuffleQuant{true};
};

template <typename PrecType>
struct GemmConfigPreshuffleB_ABQuant_Decode : public GemmConfigPreshuffleB_BQuant_Prefill<PrecType>
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,21 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
using Base::m_preload;

// Vector load size as declared by the pipeline problem.
static constexpr index_t VectorLoadSize = Problem::VectorLoadSize;
// NOTE(review): KPerBlockAQ / KPerBlockBQ are declared here and again further down with an
// equivalent expression (KPerBlock is BlockGemmShape::kK). This text is a rendered diff, so
// this first pair is presumably the removed side of the hunk — confirm only one pair remains.
static constexpr index_t KPerBlockAQ =
integer_divide_ceil(BlockGemmShape::kK, AQuantGroupSize::kK);
static constexpr index_t KPerBlockBQ =
integer_divide_ceil(BlockGemmShape::kK, BQuantGroupSize::kK);

// Block-tile extents along N and K, read from the block GEMM shape.
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;

// Number of B quant groups along N per block tile: ceil(NPerBlock / kN) when the guard
// holds, otherwise 1.
// NOTE(review): the guard compares the N group size (kN) against KPerBlock rather than
// NPerBlock — confirm this is intentional.
static constexpr index_t NPerBlockBQ = (BQuantGroupSize::kN <= KPerBlock)
? integer_divide_ceil(NPerBlock, BQuantGroupSize::kN)
: 1;

// Number of A / B quant groups along K per block tile, rounded up.
static constexpr index_t KPerBlockAQ = integer_divide_ceil(KPerBlock, AQuantGroupSize::kK);
static constexpr index_t KPerBlockBQ = integer_divide_ceil(KPerBlock, BQuantGroupSize::kK);
// NOTE(review): this uses kKPerBlock, not the KPerBlock declared above; kKPerBlock is not
// declared in the visible part of this struct (presumably inherited from Base — confirm).
static constexpr index_t QScalesPerBlockRow =
integer_divide_ceil(kKPerBlock, BQuantGroupSize::kK);

// Compile-time flag taken from the problem traits.
static constexpr bool BPreshuffleQuant = Problem::Traits::BPreshuffleQuant;

static constexpr index_t GetVectorSizeAQ()
{
return PipelinePolicy::template GetVectorSizeAQ<Problem>();
Expand Down Expand Up @@ -348,6 +357,17 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
bq_dram_block_window_tmp.get_window_origin(),
PipelinePolicy::template MakeBQDramTileDistribution<Problem>());

// BQ DRAM window step
using BQDramTileWindowStep = typename BQDramBlockWindowTmp::BottomTensorIndex;
const BQDramTileWindowStep bq_dram_tile_window_step =
(BPreshuffleQuant)
? make_array(((NPerBlockBQ <= BlockGemmShape::BlockWarps::at(number<1>{}))
? ck_tile::integer_divide_ceil(n, BQuantGroupSize::kN)
: ck_tile::integer_least_multiple(n, NPerBlock) /
BlockGemmShape::WarpTile::at(number<1>{})),
0)
: make_array(0, KPerBlockBQ);

// Prefetch A0
auto a_block_tile = load_tile(a_copy_dram_window);
// move A window to next k
Expand Down Expand Up @@ -383,7 +403,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
bq_block_tile = load_tile(bq_copy_dram_window);
// move BQ to tile 1
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
move_tile_window(bq_copy_dram_window, bq_dram_tile_window_step);
// Prefill A0
auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
store_tile(a_copy_lds_window_ping, a_block_tile_tmp);
Expand Down Expand Up @@ -453,7 +473,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
aq_block_tile_2 = load_tile(aq_copy_dram_window);
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
bq_block_tile_2 = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
move_tile_window(bq_copy_dram_window, bq_dram_tile_window_step);

// Preload A(2i+1) ds_read
static_for<0, m_preload, 1>{}([&](auto loadIter) {
Expand Down Expand Up @@ -482,7 +502,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
aq_block_tile = load_tile(aq_copy_dram_window);
move_tile_window(aq_copy_dram_window, {0, KPerBlockAQ});
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, KPerBlockBQ});
move_tile_window(bq_copy_dram_window, bq_dram_tile_window_step);

// Prefill A(2i+2)
a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile);
Expand Down
4 changes: 4 additions & 0 deletions test/ck_tile/gemm_block_scale/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
)
target_compile_options(test_tile_gemm_quant_abquant_preshuffle PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})

# GTest target built from test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp, compiled
# with the shared TEST_GEMM_COMPILE_OPTIONS.
add_gtest_executable(test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant test_gemm_quant_abquant_preshuffle_preshuffleQuant.cpp)
target_compile_options(
    test_tile_gemm_quant_abquant_preshuffle_preshuffleQuant
    PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}
)

add_gtest_executable(test_tile_gemm_quant_abquant_a4w4_base
test_gemm_quant_abquant_a4w4_base.cpp
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"

#include <gtest/gtest.h>
#include <memory>

#include "test_gemm_quant_fixtures.hpp"

// Type aliases for readability
// Tensor layout tags.
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
// Element types. NOTE(review): BF8 and PkInt4 are not referenced in the visible part of this
// file — confirm whether they are needed.
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
// Wraps the QuantType::ABQuantGrouped enumerator in a type so it can be used as a tuple element.
using ABQuantGrouped =
std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::ABQuantGrouped>;
// Quant group shape <1, 1, 128> (presumably M x N x K extents — confirm against QuantGroupShape).
using GroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;

// 2d block sizes for BQuant
// Quant group shape <1, 128, 128>.
using GroupSize2D128N = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;

// Type combinations for ABQuant tests
// Tuple format: <ALayout, BLayout, CLayout, AQLayout, ADataType, BDataType, QDataType, CDataType,
// QuantType, GemmConfig, AQuantGroupSize, BQuantGroupSize, BQLayout>
// Both entries use FP8 inputs, float quant data, Half output and
// GemmConfigPreshuffleBPreshuffleQuantPrefill; they differ only in BQuantGroupSize
// (GroupSize vs. GroupSize2D128N).
// clang-format off
using ABQuantPreshuffleQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantPrefill, GroupSize, GroupSize, ColumnMajor>,
std::tuple<RowMajor, ColumnMajor, RowMajor, RowMajor, FP8, FP8, float, Half, ABQuantGrouped, GemmConfigPreshuffleBPreshuffleQuantPrefill, GroupSize, GroupSize2D128N, ColumnMajor>
>;
// clang-format on

// Test suite for ABQuant
// Instantiates the TestCkTileGemmABQuant fixture once per tuple in ABQuantPreshuffleQuantTypes.
TYPED_TEST_SUITE(TestCkTileGemmABQuant, ABQuantPreshuffleQuantTypes);

// ABQuant tests
// Calls the fixture's run_test_with_validation with (1024, 1024, 1024); the arguments are
// presumably the M, N, K problem sizes — confirm against the fixture.
TYPED_TEST(TestCkTileGemmABQuant, ABQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}
5 changes: 5 additions & 0 deletions test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ struct GemmConfigPreshuffleBPrefillTiledPermuteN : public GemmConfigPreshuffleBP
static constexpr bool TiledMMAPermuteN = N_Repeat % 2 == 0;
};

/// GemmConfigPreshuffleBPrefill with BPreshuffleQuant set to true. All other settings are
/// inherited unchanged from the base configuration.
struct GemmConfigPreshuffleBPreshuffleQuantPrefill : GemmConfigPreshuffleBPrefill
{
    // Hides any BPreshuffleQuant member supplied by the base configuration.
    static constexpr bool BPreshuffleQuant{true};
};

struct GemmConfigPreshuffleBPreshuffleQuantDecode : public GemmConfigPreshuffleBDecode
{
static constexpr bool BPreshuffleQuant = true;
Expand Down