4 changes: 3 additions & 1 deletion example/ck_tile/38_block_scale_gemm/CMakeLists.txt
@@ -18,7 +18,9 @@ if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
gemm_aquant_quantgrouped_preshufflequant.cpp
gemm_bquant_quantgrouped_bf8i4.cpp
gemm_bquant_quantgrouped_fp8i4.cpp
- gemm_bquant_quantgrouped_bf16mxfp4.cpp
+ gemm_bquant_quantgrouped_mx_bf16fp4.cpp
+ gemm_bquant_quantgrouped_mx_bf16bf8.cpp
+ gemm_bquant_quantgrouped_mx_bf16bf16.cpp
gemm_bquant_quantgrouped_bf8.cpp
gemm_bquant_quantgrouped_fp8.cpp
gemm_bquant_quantgrouped_preshuffleb_bf8i4.cpp
2 changes: 1 addition & 1 deletion example/ck_tile/38_block_scale_gemm/README.md
@@ -53,7 +53,7 @@ args:
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0: No validation, 1: Validation on CPU, 2: Validation on GPU (default:1)
- -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, or bf16fp4 (default for both AQuant and Bquant: fp8)
+ -prec Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, bf8i4, mxbf16bf16, mxbf16bf8, or mxbf16fp4 (default for both AQuant and Bquant: fp8)
-warmup Number of iterations before benchmarking the kernel (default:50)
-repeat Number of iterations to benchmark the kernel (default:1000)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
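For example, a run that exercises one of the new MX kernels could look like the following (a sketch: the binary name and problem sizes are assumptions; the `-prec` spelling follows the option list above):

    ./tile_example_gemm_quant -prec=mxbf16fp4 -m=1024 -n=1024 -k=1024 -v=1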
35 changes: 35 additions & 0 deletions example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf16.cpp
@@ -0,0 +1,35 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "run_gemm_quant_example.inc"

template <typename T>
using GemmConfig = GemmConfigQuantPrefill<T>;

#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig<ck_tile::bf16_t>, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
ck_tile::bf16_t,
ck_tile::bf16_t,
ck_tile::e8m0_t>{});

lut[hash_multiple_strings(
{"mxbf16bf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings(
{"mxbf16bf16", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
return 0;
}();
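The new translation units above rely on a self-registration idiom: a file-scope lambda runs during static initialization and inserts kernel launchers into a process-wide lookup table keyed by a hash of the option strings, so the dispatcher only performs a table lookup at run time. A minimal self-contained sketch of the idiom (the map type, hash combiner, and launcher signature here are assumptions, not ck_tile's actual definitions):

#include <cstddef>
#include <functional>
#include <initializer_list>
#include <string>
#include <unordered_map>

using KernelFn = std::function<int()>; // the real launchers take the parsed args

// Meyers singleton: constructed on first use, so it is safe to touch from any
// translation unit's static initializers.
std::unordered_map<std::size_t, KernelFn>& kernel_lut_sketch()
{
    static std::unordered_map<std::size_t, KernelFn> lut;
    return lut;
}

std::size_t hash_strings_sketch(std::initializer_list<std::string> parts)
{
    std::size_t h = 0;
    for(const auto& p : parts) // boost-style hash combine over all key parts
        h ^= std::hash<std::string>{}(p) + 0x9e3779b9 + (h << 6) + (h >> 2);
    return h;
}

// File-scope registration: the lambda runs before main(); its dummy return
// value exists only to initialize the static variable.
static auto registered = []() {
    kernel_lut_sketch()[hash_strings_sketch({"mxbf16bf16", "bquant", "1x1x32"})] =
        []() { return 0; /* launch the matching kernel instantiation */ };
    return 0;
}();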
34 changes: 34 additions & 0 deletions example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16bf8.cpp
@@ -0,0 +1,34 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#include "run_gemm_quant_example.inc"

using GemmConfig = GemmConfigMixedPrecision;

#define RUN_GEMM_EXAMPLE_PREC_TYPE \
run_gemm_example_prec_type<GemmConfig, \
TypeConfig, \
QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
ck_tile::bf8_t,
ck_tile::bf16_t,
ck_tile::e8m0_t>{});

lut[hash_multiple_strings(
{"mxbf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings(
{"mxbf16bf8", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
return 0;
}();
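Each registration above binds a compile-time quantization group shape. For BQuant, a `1x1x64` group should mean one e8m0 scale per 64 consecutive K elements of one N column of B (my reading of the MxNxK spelling). A quick sanity check of the implied scale-tensor size under that assumption:

// Sketch only, not ck_tile code.
constexpr int K = 4096, N = 4096;
constexpr int group_k = 64;                    // the 1x1x64 registration above
constexpr int scales_per_column = K / group_k; // 64 scales along K
static_assert(scales_per_column * N == 262144, "one scale per 1x1x64 block of B");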
example/ck_tile/38_block_scale_gemm/gemm_bquant_quantgrouped_mx_bf16fp4.cpp (renamed from gemm_bquant_quantgrouped_bf16mxfp4.cpp)
@@ -6,33 +6,33 @@
template <typename T>
using GemmConfig = GemmConfigQuantPrefill<T>;

- #define RUN_GEMM_EXAMPLE_PREC_TYPE \
- run_gemm_example_prec_type<GemmConfig<ck_tile::pk_fp4_raw_t>, \
- TypeConfig, \
- QuantGroupSize, \
+ #define RUN_GEMM_EXAMPLE_PREC_TYPE \
+ run_gemm_example_prec_type<GemmConfig<ck_tile::pk_fp4_t>, \
+ TypeConfig, \
+ QuantGroupSize, \
ck_tile::QuantType::BQuantGrouped>(arg_parser);

static auto _ = []() {
auto& lut = get_kernel_lut();
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::bf16_t,
- ck_tile::pk_fp4_raw_t,
+ ck_tile::pk_fp4_t,
ck_tile::bf16_t,
- ck_tile::pk_fp4_raw_t>{});
+ ck_tile::e8m0_t>{});

lut[hash_multiple_strings(
{"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] =
{"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x32"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 32>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings(
{"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
{"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x64"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 64>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
};
lut[hash_multiple_strings(
{"bf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
{"mxbf16fp4", "bquant", "non-preshuffleb", "non-preshufflequant", "1x1x128"})] =
[](const ck_tile::ArgParser& arg_parser) {
using QuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
return RUN_GEMM_EXAMPLE_PREC_TYPE;
2 changes: 1 addition & 1 deletion example/ck_tile/38_block_scale_gemm/gemm_quant.cpp
@@ -32,7 +32,7 @@ auto create_args(int argc, char* argv[])
.insert("prec",
"fp8",
"Data type. For AQuant: fp8, bf8, i4fp8, or i4bf8; for Bquant: fp8, bf8, fp8i4, "
"or bf8i4; for ABQuant: fp8, bf8, fp4")
" mxbf16bf16, mxbf16bf8, mxbf16fp4 or bf8i4; for ABQuant: fp8, bf8, fp4")
.insert("warmup", "50", "Number of iterations before benchmarking the kernel")
.insert("repeat", "1000", "Number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
20 changes: 19 additions & 1 deletion example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
@@ -45,7 +45,7 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
const float max_accumulated_value)
{
using ComputeType = std::conditional_t<
- std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>,
+ std::is_same_v<BDataType, ck_tile::pk_fp4_t>,
ADataType,
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>>;
// Calculate thresholds
@@ -271,6 +271,24 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
static constexpr bool TransposeC = true;
};

+ // Used for A=16bit and B=8bit. The warp tile has KPack=16
+ // Matrix A: Vectorsize = 8, KPack=16 -> LDS read/write vectorsize = 8 (128bit)
+ // Matrix B: Vectorsize = 16, KPack=16 -> LDS read/write vectorsize = 16 (128bit)
+ struct GemmConfigMixedPrecision : public GemmConfigBase
+ {
+ static constexpr ck_tile::index_t M_Tile = 128;
+ static constexpr ck_tile::index_t N_Tile = 128;
+ static constexpr ck_tile::index_t K_Tile = 128;
+
+ static constexpr ck_tile::index_t M_Warp = 1;
+ static constexpr ck_tile::index_t N_Warp = 4;
+ static constexpr ck_tile::index_t K_Warp = 1;
+
+ static constexpr ck_tile::index_t M_Warp_Tile = 16;
+ static constexpr ck_tile::index_t N_Warp_Tile = 16;
+ static constexpr ck_tile::index_t K_Warp_Tile = 64;
+ };

template <typename PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{
76 changes: 41 additions & 35 deletions example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
@@ -93,6 +93,10 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
+ constexpr auto b_cast_policy =
+ std::is_same_v<typename TypeConfig::ADataType, typename TypeConfig::BDataType>
+ ? ck_tile::CastPolicy::BeforeLDSWrite
+ : ck_tile::CastPolicy::AfterLDSRead;
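// (When ADataType and BDataType match, B is cast to the compute type before
// the LDS write; in the mixed-precision cases the cast happens after the LDS
// read, which lets LDS hold B in its narrower source format, e.g. bf8 or fp4.)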

// row-col and tensor quants use the regular pipeline, A/B/AB quants use their own
using PipelineProblem = std::conditional_t<
@@ -135,7 +139,8 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ComputeDataType,
GemmConfig::Scheduler,
has_hot_loop_v,
- tail_number_v>,
+ tail_number_v,
+ b_cast_policy>,
ck_tile::GemmABQuantPipelineProblem<typename TypeConfig::ADataType,
typename TypeConfig::QDataType, // For AQ
typename TypeConfig::BDataType,
@@ -158,10 +163,9 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
using BQuantPipeline = std::conditional_t<
GemmConfig::PreshuffleB,
ck_tile::WPQuantBPipelineAgBgCrV2<PipelineProblem>,
- std::conditional_t<
- std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>,
- ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
- ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
+ std::conditional_t<std::is_same_v<typename TypeConfig::QDataType, ck_tile::e8m0_t>,
+ ck_tile::MxGemmPipelineAgBgCrCompV3<PipelineProblem>,
+ ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;
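// (The MX pipeline is selected by the scale type, e8m0, rather than by a
// specific B data type, so the mx bf16/bf8/fp4 variants all share it.)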

using ABQuantPipeline =
std::conditional_t<GemmConfig::DoubleSmemBuffer && GemmConfig::PreshuffleB,
@@ -243,11 +247,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::HostTensor<typename TypeConfig::ADataType> a_m(ck_tile::host_tensor_descriptor(
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
ck_tile::HostTensor<typename TypeConfig::BDataType> b_n(ck_tile::host_tensor_descriptor(
- std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t> ? args.K / 2
- : args.K,
- args.N,
- args.stride_B,
- is_row_major(BLayout{})));
+ args.K, args.N, args.stride_B, is_row_major(BLayout{})));

auto size_a_buffer = a_m.get_element_space_size_in_bytes();
auto size_b_buffer = b_n.get_element_space_size_in_bytes();
@@ -479,11 +479,7 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
int rotating_count = arg_parser.get_int("rotating_count");

stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
- stride_B = ck_tile::get_default_stride(
- (std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>) ? (K / 2) : K,
- N,
- stride_B,
- is_row_major(b_layout));
+ stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
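// (pk_fp4_t carries its two-values-per-byte packing in the type itself, cf.
// the PackedSize change in static_distributed_tensor.hpp below, so the
// host-side descriptors can use the logical K extent instead of K / 2.)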
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));

// Conditional stride calculation based on QuantMode
@@ -515,11 +511,8 @@

ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
- ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
- (std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>) ? (K / 2) : K,
- N,
- stride_B,
- is_row_major(b_layout)));
+ ck_tile::HostTensor<BDataType> b_k_n(
+ ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));

@@ -696,18 +689,31 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
b_k_n);
- ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
- *bq_tensor_ptr);
}
- else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
+ else if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_t>)
{
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
- ck_tile::FillUniformDistribution<BQDataType>{125.f, 130.f, fill_seed(gen)}(
- *bq_tensor_ptr);
}
else
{
ck_tile::FillUniformDistribution<BDataType>{-2.0f, 3.0f, fill_seed(gen)}(b_k_n);
}

+ if constexpr(std::is_same_v<BQDataType, ck_tile::e8m0_t>)
+ {
+ auto gen_scales = [&](auto& scales, float range_min, float range_max) {
+ // e8m0_t is basically an exponent of float32
+ ck_tile::HostTensor<float> pow2(scales.get_lengths());
+ ck_tile::FillUniformDistributionIntegerValue<float>{
+ range_min, range_max, fill_seed(gen)}(pow2);
+ scales.ForEach([&](auto& self, const auto& i) {
+ self(i) = static_cast<BQDataType>(std::exp2(pow2(i)));
+ });
+ };
+ gen_scales(*bq_tensor_ptr, -2, 2);
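// (An e8m0 value encodes 2^(E - 127) for the stored byte E, with no sign or
// mantissa bits, so the integer exponents drawn in [-2, 2] above survive the
// static_cast to BQDataType exactly.)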
+ }
+ else
+ {
+ ck_tile::FillUniformDistribution<BQDataType>{-2.0f, 2.0f, fill_seed(gen)}(
+ *bq_tensor_ptr);
+ }
Expand Down Expand Up @@ -1010,18 +1016,18 @@ int run_gemm_example_with_layouts(const ck_tile::ArgParser& arg_parser,
}
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
{
- if constexpr(std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>)
- ck_tile::reference_mxfp4gemm_quant<ADataType,
- BQDataType,
- BDataType,
- AccDataType,
- CDataType,
- BQuantGroupSize,
- false>(
+ if constexpr(std::is_same_v<BQDataType, ck_tile::e8m0_t>)
+ ck_tile::reference_mx_gemm_bquant<ADataType,
+ BQDataType,
+ BDataType,
+ AccDataType,
+ CDataType,
+ BQuantGroupSize,
+ false>(
a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref);
else
ck_tile::reference_gemm_quant<ADataType,
- AQDataType,
+ BQDataType,
BDataType,
AccDataType,
CDataType,
@@ -1121,7 +1127,7 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)

if((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant ||
- std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_raw_t>) &&
+ std::is_same_v<typename TypeConfig::BDataType, ck_tile::pk_fp4_t>) &&
GemmConfig::PreshuffleB)
{
throw std::runtime_error(
include/ck_tile/core/tensor/static_distributed_tensor.hpp
@@ -75,7 +75,7 @@ struct static_distributed_tensor
constexpr auto sliced_thread_tensor_desc =
make_naive_tensor_descriptor_packed(make_tuple(YSliceLengths...));

- thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size()>
+ thread_buffer<DataType, sliced_thread_tensor_desc.get_element_space_size() / PackedSize>
sliced_thread_data;
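// (get_element_space_size() counts logical values; for packed element types
// such as pk_fp4_t, each stored DataType element holds PackedSize of them, so
// the buffer length is divided down accordingly.)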

static_ford<sequence<YSliceLengths...>>{}([&](auto idx) {
Expand Down
13 changes: 13 additions & 0 deletions include/ck_tile/core/utility/type_traits.hpp
@@ -98,6 +98,19 @@ using is_known_at_compile_time = is_static<T>;
// , this helper will also return false, which is not good(?)
// do we need something like is_constexpr()?

+ #define DEFINE_STATIC_MEMBER_CHECKER(trait_name, member) \
+ template <typename, typename = void> \
+ struct trait_name : std::false_type \
+ { \
+ }; \
+ \
+ template <typename T> \
+ struct trait_name<T, std::void_t<decltype(&T::member)>> : std::true_type \
+ { \
+ };
+
+ DEFINE_STATIC_MEMBER_CHECKER(has_bcastpolicy, BCastPolicy);
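// (Illustrative use of the generated checker; `SomePipeline` is hypothetical:
//     if constexpr(has_bcastpolicy<SomePipeline>::value) { /* use SomePipeline::BCastPolicy */ }
// decltype(&T::member) is well-formed only when T declares the member, which
// is what selects the std::void_t specialization.)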

// FIXME: do we need this anymore?
template <
typename PY,
Expand Down