Skip to content
7 changes: 6 additions & 1 deletion example/ck_tile/38_block_scale_gemm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@ if(CK_USE_OCP_FP8)
endif()

list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -Wno-global-constructors) # use global constructors to add kernel instances
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -enable-noalias-to-md-conversion=1")
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS "SHELL: -mllvm -greedy-reverse-local-assignment=1")

if(GPU_TARGETS MATCHES "gfx95")
list(APPEND EXAMPLE_GEMM_COMPILE_OPTIONS -DCK_TILE_EIGHTWARP_SUP)
endif()

if(GPU_TARGETS MATCHES "gfx94|gfx95|gfx12")
set(EXE_NAME tile_example_gemm_quant)
Expand Down
25 changes: 14 additions & 11 deletions example/ck_tile/38_block_scale_gemm/gemm_abquant_quantgrouped.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@

#include "run_gemm_quant_example.inc"

#if defined(CK_TILE_EIGHTWARP_SUP)
template <typename T>
using GemmConfig = GemmConfigEightWarps<T>;
template <typename T>
using GemmConfigPrefill = GemmConfigPreshuffleBEightWarps<T>;
#else
template <typename T>
using GemmConfig = GemmConfigABQuantPrefill<T>;

template <typename T>
using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Prefill<T>;

// template <typename T>
// using GemmConfigPreshuffleB = GemmConfigPreshuffleB_ABQuant_Decode<T>;
using GemmConfigPrefill = GemmConfigPreshuffleB_ABQuant_Prefill<T>;
#endif

static auto _ = []() {
auto& lut = get_kernel_lut();
Expand All @@ -23,7 +26,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand Down Expand Up @@ -53,7 +56,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfig<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigABQuantPrefill<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand Down Expand Up @@ -83,7 +86,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand All @@ -98,7 +101,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::fp8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::fp8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand All @@ -113,7 +116,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 1, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand All @@ -128,7 +131,7 @@ static auto _ = []() {
using BQuantGroupSize = ck_tile::QuantGroupShape<ck_tile::sequence<1, 128, 128>>;
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_example_prec_type<GemmConfigPreshuffleB<ck_tile::bf8_t>,
return run_gemm_example_prec_type<GemmConfigPrefill<ck_tile::bf8_t>,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
Expand Down
23 changes: 23 additions & 0 deletions example/ck_tile/38_block_scale_gemm/gemm_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,29 @@ struct GemmConfigABQuantPrefill : public GemmConfigQuantPrefill<PrecType>
static constexpr bool TransposeC = true;
};

/// ABQuant example configuration whose three warp counts multiply to eight
/// (M_Warp * N_Warp * K_Warp = 4 * 2 * 1). Any setting not redeclared here is
/// inherited from GemmConfigABQuantPrefill<PrecType>.
template <typename PrecType>
struct GemmConfigEightWarps : public GemmConfigABQuantPrefill<PrecType>
{
    // Boolean / scalar switches redeclared by this configuration.
    static constexpr bool TransposeC = true;
    static constexpr bool kPadK      = false;
    static constexpr int kBlockPerCu = 1;

    // Warp counts along each GEMM dimension.
    static constexpr ck_tile::index_t K_Warp = 1;
    static constexpr ck_tile::index_t N_Warp = 2; // NWarps == 2 for ping-pong!
    static constexpr ck_tile::index_t M_Warp = 4;

    // Block tile extents. M is a fixed 192; N is 128 per N-warp; K is
    // (128 / sizeof(PrecType)) elements per K-warp.
    static constexpr ck_tile::index_t K_Tile = K_Warp * (128 / sizeof(PrecType));
    static constexpr ck_tile::index_t N_Tile = N_Warp * 128;
    static constexpr ck_tile::index_t M_Tile = 192;
};

/// Variant of GemmConfigEightWarps that switches on both PreshuffleB and
/// DoubleSmemBuffer; every other setting is inherited unchanged.
template <typename PrecType>
struct GemmConfigPreshuffleBEightWarps : public GemmConfigEightWarps<PrecType>
{
    static constexpr bool DoubleSmemBuffer = true;
    static constexpr bool PreshuffleB      = true;
};

template <typename PrecType>
struct GemmConfigPreshuffleBQuantPrefill : public GemmConfigQuantPrefill<PrecType>
{
Expand Down
66 changes: 42 additions & 24 deletions example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
{
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr bool transpose_c =
GemmConfig::TransposeC; // QuantMode == ck_tile::QuantType::ABQuantGrouped;
QuantMode == ck_tile::QuantType::ABQuantGrouped && BQuantGroupSize::kN == 128;
constexpr bool eight_warps =
BQuantGroupSize::kN == 128 &&
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8) &&
GemmConfig::K_Warp_Tile == 128;

using ComputeDataType = std::conditional_t<QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::RowColQuant,
typename TypeConfig::BDataType,
Expand Down Expand Up @@ -71,16 +76,20 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ComputeDataType>;

// Base pipeline selection based on quant mode and preshuffle settings
using BaseGemmPipeline = std::conditional_t<
GemmConfig::PreshuffleB == true,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped && GemmConfig::APreshuffleQuant == true,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>,
std::conditional_t<
QuantMode == ck_tile::QuantType::AQuantGrouped,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>>>>;
constexpr auto base_gemm_pipeline = []() {
if constexpr(eight_warps)
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
else if constexpr(GemmConfig::PreshuffleB)
return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>{};
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped &&
GemmConfig::APreshuffleQuant)
return ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>{};
else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
return ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>{};
else
return ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>{};
}();
using BaseGemmPipeline = std::decay_t<decltype(base_gemm_pipeline)>;

const ck_tile::index_t K_split = ck_tile::integer_least_multiple(args.K, GemmConfig::K_Tile);
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
Expand Down Expand Up @@ -160,10 +169,12 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
ck_tile::MxFp4GemmPipelineAgBgCrCompV3<PipelineProblem>,
ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;

using ABQuantPipeline =
using ABQuantPipeline = std::conditional_t<
eight_warps,
ck_tile::ABQuantGemmPipelineAgBgCrAsync<PipelineProblem>,
std::conditional_t<GemmConfig::DoubleSmemBuffer && GemmConfig::PreshuffleB,
ck_tile::WPABQuantBPipelineAgBgCrV2<PipelineProblem>,
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>;
ck_tile::ABQuantGemmPipelineAgBgCrCompV3<PipelineProblem>>>;

using GemmPipeline = std::conditional_t<
QuantMode == ck_tile::QuantType::RowColQuant ||
Expand Down Expand Up @@ -197,7 +208,7 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::N_Warp * GemmConfig::K_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
Expand Down Expand Up @@ -1104,20 +1115,27 @@ int run_gemm_example_prec_type(const ck_tile::ArgParser& arg_parser)
{
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");

if(a_layout == "R" && b_layout == "C")
{
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
if constexpr(QuantMode == ck_tile::QuantType::ABQuantGrouped &&
!GemmConfig::APreshuffleQuant && BQuantGroupSize::kN == 128 &&
(GemmConfig::M_Warp * GemmConfig::N_Warp * GemmConfig::K_Warp == 8))
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Row{}, Col{}, Col{}, Col{}, Row{});
else
return run_gemm_example_with_layouts<GemmConfig,
TypeConfig,
AQuantGroupSize,
BQuantGroupSize,
QuantMode>(
arg_parser, Row{}, Row{}, Col{}, Col{}, Row{});
}

if constexpr((QuantMode == ck_tile::QuantType::AQuantGrouped ||
QuantMode == ck_tile::QuantType::ABQuantGrouped) &&
!GemmConfig::APreshuffleQuant && !GemmConfig::PreshuffleB)
if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped)
{
if(a_layout == "R" && b_layout == "R")
{
Expand Down
61 changes: 32 additions & 29 deletions include/ck_tile/core/arch/amd_buffer_addressing_builtins.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr,
}
return r;
}
/// Builds an `__amdgpu_buffer_rsrc_t` through `__builtin_amdgcn_make_buffer_rsrc`.
///
/// @param ptr  Base address handed to the builtin (constness is cast away because
///             the builtin is called with a `void*`).
/// @param size Forwarded as the builtin's third argument; defaults to 0xffffffff.
///             NOTE(review): presumably the addressable extent in bytes - confirm.
/// @return The resource produced by the builtin, created with a stride argument of 0
///         and CK_TILE_BUFFER_RESOURCE_3RD_DWORD as its last argument.
CK_TILE_DEVICE __amdgpu_buffer_rsrc_t make_builtin_buffer_resource(const void* ptr,
                                                                   uint32_t size = 0xffffffff)
{
    void* const base_addr = const_cast<void*>(ptr);

    return __builtin_amdgcn_make_buffer_rsrc(
        base_addr, /*stride*/ 0, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD);
}

namespace impl {
// below type indicate the data type used for buffer load inline asm
Expand Down Expand Up @@ -1695,46 +1701,41 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = true>
bool oob_conditional_check = true,
index_t IMM = 0>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
int32x4_t src_wave_buffer_resource,
const __amdgpu_buffer_rsrc_t rsrc,
index_t src_thread_addr_offset,
index_t src_wave_addr_offset,
index_t src_immediate_addr_offset = 0,
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
index_t src_wave_addr_offset = 0,
number<IMM> /*src_immediate_addr_offset*/ = {},
index_t flag = 0,
bool_constant<oob_conditional_check> = {})
{
constexpr index_t bytes = sizeof(T) * N;

// Used to catch the cases when src_immediate_addr_offset is NOT 0.
// Remove this assert once other sizes are implemented.
assert(src_immediate_addr_offset == 0 &&
"wrong! not implemented src_immediate_addr_offset size, only 0 supported");
ignore = src_immediate_addr_offset;
static_assert(IMM < (1 << 12), "wrong! immediate offset too large");

#if defined(__gfx950__)
static_assert(bytes == 4 || bytes == 12 || bytes == 16,
"wrong! only support in dword, dwordx3, dwordx4");
src_wave_addr_offset = 0;
#else
static_assert(bytes == 4, "wrong! not implemented vector size");
#endif

// Set up v_offset:
index_t v_offset = src_thread_addr_offset;
if constexpr(oob_conditional_check)
v_offset = flag ? v_offset : src_wave_buffer_resource[2];
v_offset = flag ? v_offset : 0x7fffffff; // large offset to cause OOB access

#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
// Use C-style cast to change address space without dropping llvm noalias attribute
llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
(as3_uint32_ptr)(smem),
bytes,
v_offset,
src_wave_addr_offset,
/*src_immediate_addr_offset*/ 0,
static_cast<index_t>(coherence));
__builtin_amdgcn_raw_ptr_buffer_load_lds(rsrc,
smem,
bytes,
v_offset,
src_wave_addr_offset,
/*imm*/ IMM,
static_cast<index_t>(coherence));
#pragma clang diagnostic pop
}

Expand Down Expand Up @@ -2535,22 +2536,24 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
template <typename T,
index_t N,
amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
bool oob_conditional_check = false>
bool oob_conditional_check = false,
typename linear_offset_t>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
const int32x4_t src_wave_buffer_resource,
const __amdgpu_buffer_rsrc_t rsrc,
index_t src_thread_element_offset,
index_t src_linear_element_offset,
index_t src_wave_addr_offset,
linear_offset_t,
bool is_valid_element,
bool_constant<oob_conditional_check> = {})
{
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
index_t src_linear_addr_offset = src_linear_element_offset * sizeof(T);
index_t src_thread_addr_offset = src_thread_element_offset * sizeof(T);
constexpr index_t src_linear_addr_offset = static_cast<index_t>(linear_offset_t{}) * sizeof(T);

amd_async_buffer_load<T, N, coherence>(smem,
src_wave_buffer_resource,
rsrc,
src_thread_addr_offset,
0,
src_linear_addr_offset,
src_wave_addr_offset,
number<src_linear_addr_offset>{},
is_valid_element,
bool_constant<oob_conditional_check>{});
}
Expand Down
5 changes: 5 additions & 0 deletions include/ck_tile/core/arch/arch.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,11 @@ CK_TILE_DEVICE void s_waitcnt()
waitcnt_arg::from_lgkmcnt<lgkmcnt>());
#endif
}
/// Convenience wrapper over `s_waitcnt` that only lets the caller choose `lgkmcnt`.
/// The vm and exp template arguments are fixed to `waitcnt_arg::kMaxVmCnt` and
/// `waitcnt_arg::kMaxExpCnt`.
/// NOTE(review): passing the maximum counts presumably means those two counters
/// impose no wait - confirm against s_waitcnt's definition.
template <index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt_lgkm()
{
    constexpr index_t vm_count  = waitcnt_arg::kMaxVmCnt;
    constexpr index_t exp_count = waitcnt_arg::kMaxExpCnt;

    s_waitcnt<vm_count, exp_count, lgkmcnt>();
}

template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
index_t expcnt = waitcnt_arg::kMaxExpCnt,
Expand Down
8 changes: 8 additions & 0 deletions include/ck_tile/core/container/sequence.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1068,6 +1068,14 @@ CK_TILE_HOST_DEVICE constexpr auto to_sequence(tuple<number<Is>...>)
return sequence<Is...>{};
}

/// `number_tuple<I0, I1, ...>` names the type `tuple<number<I0>, number<I1>, ...>`.
template <index_t... Is>
using number_tuple = tuple<number<Is>...>;

/// Maps a `sequence<Is...>` to a value-initialized `number_tuple<Is...>`.
/// Only the argument's type is used, and it is defaulted, so the function can also
/// be called with no argument when `Is...` is given explicitly.
template <index_t... Is>
CK_TILE_HOST_DEVICE constexpr auto to_number_tuple(sequence<Is...> = {})
{
    using result_type = tuple<number<Is>...>;
    return result_type{};
}

namespace detail {
template <index_t h_idx, typename SeqSortedSamples, typename SeqRange>
struct sorted_sequence_histogram;
Expand Down
Loading