From de69e8188ace7b34ee771f5b5beffe371937833a Mon Sep 17 00:00:00 2001 From: Marton Bidlek Date: Tue, 9 Dec 2025 09:36:12 +0000 Subject: [PATCH 01/16] Migrated last commit and some additional changes from origin/172-implement-device_grouped_gemm_fastgelu-for-rdna4 --- .../device_grouped_gemm_wmma_fixed_nk.hpp | 1007 +++++++++++++++++ ...ce_grouped_gemm_wmma_fixed_nk_instance.hpp | 240 ++++ .../gpu/grouped_gemm_fixed_nk.hpp | 28 + .../gpu/grouped_gemm_fixed_nk/CMakeLists.txt | 6 +- ...ed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp | 73 ++ ...ed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp | 76 ++ 6 files changed, 1428 insertions(+), 2 deletions(-) create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp create mode 100644 library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp new file mode 100644 index 00000000000..7d7efeb7e47 --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -0,0 +1,1007 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_grouped_gemm_Wmma_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + uint32_t* barrier_count, + const index_t barrier_size_grp, + const index_t group_count, + const index_t grid_size_grp, + const index_t KBatch, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) +{ +#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) + if constexpr(GridwiseGemm::template IsValidCompilationParameter()) + { + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + + const index_t block_id = get_block_1d_id(); + + const auto gemm_desc_ptr = reinterpret_cast( + cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + + if(group_id >= group_count) + return; + + const index_t M = gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + return; + + const auto StrideA = gemm_desc_ptr[group_id].StrideA; + const auto StrideB = gemm_desc_ptr[group_id].StrideB; + 
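+        // Fixed-NK dispatch: every group is assigned the same number of blocks
+        // (grid_size_grp), so a block recovers its group from its flat id alone.
+        // When a group owns more e-tiles than grid_size_grp, the while-loop below
+        // re-offsets the same blocks by grid_size_grp per iteration (id_off,
+        // id_local) until the group's whole tile space is covered.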
const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + + const index_t BlockStart = group_id * grid_size_grp; + + const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; + + const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumDTensor = DsDataType::Size(); + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + + DsGridPointer p_ds_grid_; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + // D pointer + p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - BlockStart; + + const index_t mn_blocks = local_grid_size / KBatch; + + while(id_local < local_grid_size) + { + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); + + if constexpr(Zeroing) + { + auto barrier_count_finished = + barrier_count + group_id * barrier_size_grp + id_local % mn_blocks; + GridwiseGemm::template RunWithZeroing(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + barrier_count_finished, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } + else + { + + GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, + gemm_desc_ptr[group_id].p_b_grid, + p_ds_grid_, + gemm_desc_ptr[group_id].p_e_grid, + p_shared, + nullptr, + a_element_op, + b_element_op, + c_element_op, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + KBatch, + block_2_etile_map); + } + + id_off += grid_size_grp; + id_local += grid_size_grp; + } + } +#else + ignore = gemm_descs_const; + ignore = barrier_count; + ignore = barrier_size_grp; + ignore = group_count; + ignore = grid_size_grp; + ignore = KBatch; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK +{ + using DeviceOp = DeviceGroupedGemm_Wmma_Fixed_Nk; + + // GET_NXDL_PER_WAVE_IMPL + // static constexpr auto NXdlPerWave64 = GetNXdlPerWave(); + // static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); + + static constexpr index_t NumDTensor = DsDataType::Size(); + + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + + using AComputeType = ComputeType; + using BComputeType = ComputeType; + + // GridwiseGemm + template + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 
BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, // PermuteA not supported by DeviceBatchedGemm base class. + false>; // PermuteB not supported by DeviceBatchedGemm base class. + + using GridwiseGemm64 = GridwiseGemmBase; + using GridwiseGemm32 = GridwiseGemmBase; + + template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), 
c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + // TODO: replace with GroupedGemmKernelArgument + struct GemmBiasTransKernelArg + { + // pointers + const void* a_ptr_; + const void* b_ptr_; + std::array ds_ptr_; + void* e_ptr_; + + index_t M_, N_, K_; + index_t StrideA_, StrideB_; + std::array StrideDs_; + index_t StrideE_; + }; + + // Argument + struct Argument : public BaseArgument + { + + void UpdateKBatch(index_t k_batch) + { + k_batch_ = k_batch; + + if(k_batch_ < 1) + { + + throw std::runtime_error("wrong! k_batch must be > 0"); + } + + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + + const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_; + const index_t N = gemm_desc_kernel_arg_[0].N_; + + const auto e_grid_desc_m_n = + GridwiseGemm64::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + grid_size_ = grid_size_grp_ * group_count_; + } + + Argument(std::vector&, + std::vector&, + std::vector>&, + std::vector&, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} + { + grid_size_ = 0; + + k_batch_ = 1; + + grouped_gemm_kernel_args_dev = nullptr; + + group_count_ = ck::type_convert(gemm_descs.size()); + + gemm_desc_kernel_arg_.reserve(group_count_); + + index_t group_id = 0; + + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + const index_t N = gemm_descs[0].N_; + const index_t K = gemm_descs[0].K_; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_) + { + throw std::runtime_error("wrong! 
M/N/K is not identical"); + } + + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + const index_t StrideA = gemm_descs[i].stride_A_; + const index_t StrideB = gemm_descs[i].stride_B_; + const index_t StrideE = gemm_descs[i].stride_C_; + + // pointer + std::array p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); + + std::array StrideDs; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + // using DLayout = remove_cvref_t>; + + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + StrideDs[j] = gemm_descs[i].stride_Ds_[j]; + }); + + const auto e_grid_desc_m_n = + GridwiseGemm64::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + // block-to-e-tile map + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + if(group_id * grid_size_grp_ != grid_size_) + { + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); + } + + grid_size_ += grid_size_grp_; + + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! block_2_etile_map validation failed"); + } + + const auto& karg = reinterpret_cast( + arg.gemm_kernel_args_[i].karg_); + + if(!GridwiseGemm::CheckValidity(karg)) + { + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{ + nullptr, + nullptr, + p_ds_grid, + nullptr, + AverM, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + }); + + group_id++; + } + + const auto e_grid_desc_sum_m_n = + GridwiseGemm64::template MakeEGridDescriptor_M_N( + sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; + + barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); + } + + // private: + index_t group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; + + index_t grid_size_; + index_t grid_size_grp_; + index_t barrier_size_grp_; + index_t sum_of_m; + + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool has_main_k_block_loop = true; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + const auto KPad = + GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, arg.k_batch_); + + if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop) + { + throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + } + } + + if(arg.grouped_gemm_kernel_args_dev == nullptr) + { + throw std::runtime_error("wrong! 
grouped_gemm_kernel_args_dev is nullpr"); + } + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) { + if(arg.k_batch_ == 1) + { + const auto kernel = + kernel_grouped_gemm_wmma_fixed_nk, + GemmSpec, + false, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + nullptr, + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + else + { + const auto kernel = + kernel_grouped_gemm_wmma_fixed_nk, + GemmSpec, + true, + ALayout, + BLayout, + DsLayout, + ELayout, + DsDataType, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + e_global_memory_operation_, + has_main_k_block_loop_>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + reinterpret_cast(arg.p_workspace_), + arg.barrier_size_grp_, + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + } + }; + + constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; + constexpr auto Set = InMemoryDataOperationEnum::Set; + + // For bf16 datatype only kbatch = 1 scenario is supported. This condition is + // enforced in IsSupportedArgument function + if constexpr(std::is_same::value) + { + if(has_main_k_block_loop) + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = launch_kernel(integral_constant{}, + integral_constant{}); + } + } + else + { + if(arg.k_batch_ > 1) + { + if(has_main_k_block_loop) + { + ave_time = launch_kernel( + integral_constant{}, + integral_constant{}); + } + else + { + ave_time = launch_kernel( + integral_constant{}, + integral_constant{}); + } + } + else + { + if(has_main_k_block_loop) + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time = + launch_kernel(integral_constant{}, + integral_constant{}); + } + } + } + + return ave_time; + } + + INVOKER_RUN_IMPL + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(ck::type_convert(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + { + return false; + } + + bool supported = true; + + // If we use padding we do not support vector loads for dimensions not divisible by + // vector load size. + if constexpr(GemmSpec != GemmSpecialization::Default) + { + // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} + // layout, thus we have to adapt it to the {M,K} or {N,K} layout. + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 
1 : 0; + + for(index_t i = 0; i < arg.group_count_; ++i) + { + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + } + } + + // For bf16 datatype only kbatch = 1 is supported since there is no AtomicAdd + // instruction that supports bf16 and we cannot use splitk because of that + if constexpr(std::is_same::value) + { + supported = supported & (arg.k_batch_ == 1); + } + + return supported; + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedGemm_Wmma_Fixed_Nk" + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWMMA << ", " + << NPerWMMA << ", " + << MWmmaPerWave << ", " + << NWmmaPerWave << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMWmmaPerWavePerShuffle << ", " + << CShuffleNWmmaPerWavePerShuffle << ", " + << getGemmSpecializationString(GemmSpec) + << ">"; + // clang-format on + + return str.str(); + } + + // polymorphic + void SetDeviceKernelArgs(BaseArgument* p_arg, void* kernel_args) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->grouped_gemm_kernel_args_dev = kernel_args; + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_Nk::Argument structure!"); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + return arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_Nk::Argument structure!"); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + return arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_Nk::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const 
StreamConfig& stream_config = StreamConfig{}) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->p_workspace_ = p_workspace; + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_Nk::Argument structure!"); + + hip_check_error( + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_)); + } + + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + + // polymorphic + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->UpdateKBatch(k_batch); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_Nk::Argument structure!"); + } + + void SetKBatchSize(BaseArgument* p_arg, index_t kbatch) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->UpdateKBatch(kbatch); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_Nk::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp new file mode 100644 index 00000000000..a4f3ee3df8a --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp @@ -0,0 +1,240 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
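+
+// This header defines tuple lists of DeviceGroupedGemm_Wmma_Fixed_Nk instances for the
+// four A/B layout combinations, plus helpers that expand each list over the
+// padding/scheduler/pipeline combinations in InstanceVariants below. A minimal usage
+// sketch (assuming BF16 in/out with PassThrough element-wise ops; the element type of
+// the vector drives template deduction, as noted on the helper declarations):
+//
+//   std::vector<std::unique_ptr<DeviceGroupedGemmFixedNK<Row, Row, DsLayout, ELayout,
+//       BF16, BF16, DsDataType, BF16, PassThrough, PassThrough, PassThrough>>> ops;
+//   add_device_grouped_gemm_wmma_universal_instances<
+//       device_grouped_gemm_wmma_universal_mk_kn_mn_instances>(ops);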
+ +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/loop_scheduler.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using AccDataType = F32; +using DsDataType = Empty_Tuple; + +using DsLayout = Empty_Tuple; +using ELayout = Row; + +static constexpr auto PipelineV1 = BlockGemmPipelineVersion::v1; +static constexpr auto PipelineV3 = BlockGemmPipelineVersion::v3; +static constexpr auto IntrawaveScheduler = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto InterwaveScheduler = BlockGemmPipelineScheduler::Interwave; +static constexpr auto GemmMNKPadding = device::GemmSpecialization::MNKPadding; +static constexpr auto GemmDefault = device::GemmSpecialization::Default; + +// Instances for 2 byte datatypes in CRR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_km_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, 
BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang`-format on + >; + +// Instances for 2 byte datatypes in CCR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_km_nk_mn_instances = std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// Instances for 2 byte datatypes in RRR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_mk_kn_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| 
CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// Instances for 2 byte datatypes in RCR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_universal_mk_nk_mn_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | 
Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// List of instance variants to add (pipeline/scheduler/padding combinations) +// Some are disabled now, can be re-enabled if needed +using InstanceVariant = + ck::Tuple; +static constexpr InstanceVariant InstanceVariants[] = { + + make_tuple(GemmDefault, IntrawaveScheduler, PipelineV1), + // make_tuple(GemmDefault, InterwaveScheduler, PipelineV1), + make_tuple(GemmDefault, IntrawaveScheduler, PipelineV3), + + make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV1), + // make_tuple(GemmMNKPadding, InterwaveScheduler, PipelineV1), + // make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV3), +}; + +// Helper function to add a list of layout instances with specific A/B/E datatypes for all supported +// padding/scheduler/pipeline version combinations +template + typename LayoutInstances, + typename ADataType, // NOTE: type parameters as last so that they can be inferred from the + typename BDataType, // vector argument + typename EDataType, + typename AElementOp, + typename BElementOp, + typename CDEElementOp> +void add_device_grouped_gemm_wmma_universal_instances( + std::vector>>& instances) +{ + // Add all instances from our instance list + static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { + constexpr auto instance = InstanceVariants[i]; + add_device_operation_instances(instances, + LayoutInstances{}), + instance.At(Number<1>{}), + instance.At(Number<2>{}), + AElementOp, + BElementOp, + CDEElementOp>{}); + }); +} + +// Helper function to add a list of layout instances for instances with matching A/B/E data types +// for all supported padding/scheduler/pipeline version combinations +template + typename LayoutInstances, + typename AElementOp, // NOTE: element-wise op parameters as last so that they can be + typename BElementOp, // inferred from the vector argument + typename CDEElementOp> +void add_device_grouped_gemm_wmma_universal_instances( + std::vector>>& instances) +{ + // Add all instances from our instance list + static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { + constexpr auto instance = InstanceVariants[i]; + add_device_operation_instances(instances, + 
LayoutInstances{}), + instance.At(Number<1>{}), + instance.At(Number<2>{}), + AElementOp, + BElementOp, + CDEElementOp>{}); + }); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp index 3f17a930a0b..7b018b3a402 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp @@ -141,6 +141,19 @@ void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& instances); + void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( std::vector>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& instances); #endif // CK_ENABLE_BF16 template ) { add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); + add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); } if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); + add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); } } #endif // CK_ENABLE_BF16 diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt index e56df524cfa..ffefecc1ac1 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_GEMM_FIXED_NK_INSTANCES) list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -13,6 +13,8 @@ list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16 device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp) + device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp) add_instance_library(device_grouped_gemm_fixed_nk_instance ${GROUPED_GEMM_FIXED_NK_INSTANCES}) \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..d6ba50472c1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,73 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
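+// Note: with BF16 E data there is no BF16 AtomicAdd for the output tensor, so the
+// device op only supports k_batch == 1 for these instances (enforced in
+// IsSupportedArgument of device_grouped_gemm_wmma_fixed_nk.hpp) and always launches
+// the Set memory-operation path.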
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using DsDataType = ck::Tuple<>; +using DsLayout = ck::Tuple<>; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 
128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..d3d82c9f75b --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,76 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using DsDataType = ck::Tuple<>; +using DsLayout = ck::Tuple<>; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_irregular_tile_instances = + std::tuple< + // clang-format off + //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| + //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| + //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 
64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> + // clang-format on + >; + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_operation_instances( + instances, + device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_irregular_tile_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace 
tensor_operation +} // namespace ck From 6060579c22d7d59fcd77445e4d7a7f515d898c66 Mon Sep 17 00:00:00 2001 From: Marton Bidlek Date: Wed, 17 Dec 2025 15:46:14 +0000 Subject: [PATCH 02/16] Implementation working and example added --- example/15_grouped_gemm/CMakeLists.txt | 4 + .../grouped_gemm_wmma_fixed_nk_fp16.cpp | 70 ++ .../run_grouped_gemm_example.inc | 8 + .../device_grouped_gemm_wmma_fixed_nk.hpp | 880 +++++++++--------- ...ce_grouped_gemm_wmma_fixed_nk_instance.hpp | 51 +- .../gpu/grouped_gemm_fixed_nk.hpp | 54 +- .../gpu/grouped_gemm_fixed_nk/CMakeLists.txt | 27 +- ...ed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp | 73 -- ...ed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp | 76 -- ...fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp | 38 + ...fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp | 38 + 11 files changed, 646 insertions(+), 673 deletions(-) create mode 100644 example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp delete mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index a7dae9dcd82..6b75ff76ced 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -47,6 +47,10 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp) add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16) +add_example_executable(example_grouped_gemm_wmma_fixed_nk_fp16 grouped_gemm_wmma_fixed_nk_fp16.cpp) +add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_fixed_nk_fp16) + + list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp new file mode 100644 index 00000000000..b35b6463f12 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp @@ -0,0 +1,70 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Fixed_Nk + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; + +// clang-format on + +// #define EXAMPLE_USE_SPLITK +#include "run_grouped_gemm_example.inc" + +int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc 
b/example/15_grouped_gemm/run_grouped_gemm_example.inc index ffd0c5e9b7b..b6b9835161a 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -272,6 +272,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ComputeDataType>(c_device_tensors[i], c_host_tensors[i]); #endif } + + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; } if(config.time_kernel) @@ -347,8 +349,14 @@ bool run_grouped_gemm_example(int argc, char* argv[]) for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(256 + 256 * i); + +#ifdef FIX_NK + problem_size.Ns.push_back(512); + problem_size.Ks.push_back(512); +#else problem_size.Ns.push_back(128 + 128 * i); problem_size.Ks.push_back(128 + 64 * i); +#endif problem_size.stride_As.push_back( get_stride(ALayout{}, problem_size.Ms[i], problem_size.Ks[i])); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index 7d7efeb7e47..a1b70619549 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -6,6 +6,14 @@ #include #include +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/env.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/tuple.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + #include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" @@ -22,165 +30,88 @@ namespace device { template + bool HasMainKBlockLoop, + InMemoryDataOperationEnum CGlobalMemoryDataOperation, + typename Block2CTileMap, + index_t MinimumOccupancy = 1, + TailNumber TailNum = TailNumber::Full> __global__ void #if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - kernel_grouped_gemm_Wmma_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - uint32_t* barrier_count, - const index_t barrier_size_grp, - const index_t group_count, - const index_t grid_size_grp, - const index_t KBatch, - const AElementwiseOperation a_element_op, - const BElementwiseOperation b_element_op, - const CDEElementwiseOperation c_element_op) + kernel_grouped_gemm_wmma_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count) { -#if defined(__gfx9__) || defined(__gfx11__) || defined(__gfx12__) - if constexpr(GridwiseGemm::template IsValidCompilationParameter()) +#if(defined(__gfx11__) || defined(__gfx12__)) + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; + + const index_t block_id = get_block_1d_id(); + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + // Binary search lookup to find which group this block is part of + index_t left = 0; + index_t right = group_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && + block_id < gemm_desc_ptr[group_id].block_end_)) && + left <= 
right) { - __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; - - const index_t block_id = get_block_1d_id(); - - const auto gemm_desc_ptr = reinterpret_cast( - cast_pointer_to_generic_address_space(gemm_descs_const)); - - const index_t group_id = block_id / grid_size_grp; - - if(group_id >= group_count) - return; - - const index_t M = gemm_desc_ptr[group_id].M; - const index_t N = gemm_desc_ptr[group_id].N; - const index_t K = gemm_desc_ptr[group_id].K; - - if(M == 0 || N == 0 || K == 0) - return; - - const auto StrideA = gemm_desc_ptr[group_id].StrideA; - const auto StrideB = gemm_desc_ptr[group_id].StrideB; - const auto StrideDs = gemm_desc_ptr[group_id].StrideDs; - const auto StrideE = gemm_desc_ptr[group_id].StrideE; - - const auto e_grid_desc_m_n = - GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); - - const index_t BlockStart = group_id * grid_size_grp; - - const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch}; - - const auto local_grid_size = local_b2e_tile_map.CalculateGridSize(e_grid_desc_m_n); - - constexpr auto NumDTensor = DsDataType::Size(); - - using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); - - DsGridPointer p_ds_grid_; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - // D pointer - p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); - }); - - index_t id_off = 0; - index_t id_local = get_block_1d_id() - BlockStart; - - const index_t mn_blocks = local_grid_size / KBatch; - - while(id_local < local_grid_size) + if(block_id < gemm_desc_ptr[group_id].block_start_) { - const auto block_2_etile_map = - GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, id_off); - - if constexpr(Zeroing) - { - auto barrier_count_finished = - barrier_count + group_id * barrier_size_grp + id_local % mn_blocks; - GridwiseGemm::template RunWithZeroing(gemm_desc_ptr[group_id].p_a_grid, - gemm_desc_ptr[group_id].p_b_grid, - p_ds_grid_, - gemm_desc_ptr[group_id].p_e_grid, - p_shared, - barrier_count_finished, - a_element_op, - b_element_op, - c_element_op, - M, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideE, - KBatch, - block_2_etile_map); - } - else - { - - GridwiseGemm::template Run(gemm_desc_ptr[group_id].p_a_grid, - gemm_desc_ptr[group_id].p_b_grid, - p_ds_grid_, - gemm_desc_ptr[group_id].p_e_grid, - p_shared, - nullptr, - a_element_op, - b_element_op, - c_element_op, - M, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideE, - KBatch, - block_2_etile_map); - } - - id_off += grid_size_grp; - id_local += grid_size_grp; + right = group_id; + } + else + { + left = group_id; } + group_id = index_t((left + right) / 2); + } + + // NOTE: Local copy of the arg struct since SplitKBatchOffset verifies and modifies K index + // and thus needs a non-const reference. 
It's also not feasible to store this in global + // memory as different threads would be writing different K values to the same arg struct + auto karg = gemm_desc_ptr[group_id].karg_; + +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + const auto& block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_; + + // Tile index first dimension is the K batch + auto tile_index = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + auto splitk_batch_offset = + typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]); + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run(static_cast(p_shared), + splitk_batch_offset, + karg, + block_2_ctile_map, + epilogue_args); +#if defined(__gfx11__) } +#endif #else ignore = gemm_descs_const; - ignore = barrier_count; - ignore = barrier_size_grp; ignore = group_count; - ignore = grid_size_grp; - ignore = KBatch; - ignore = a_element_op; - ignore = b_element_op; - ignore = c_element_op; -#endif +#endif // end of if(defined(__gfx11__) || defined(__gfx12__)) } template + BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, + BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, + typename ComputeTypeA = EDataType, + typename ComputeTypeB = ComputeTypeA, + bool PermuteA = false, + bool PermuteB = false> struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK(); - // static constexpr auto NXdlPerWave32 = GetNXdlPerWave(); - static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; - using AComputeType = ComputeType; - using BComputeType = ComputeType; - - // GridwiseGemm - template using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, @@ -309,11 +232,12 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK; // PermuteB not supported by DeviceBatchedGemm base class. 
+ false, + false>; - using GridwiseGemm64 = GridwiseGemmBase; - using GridwiseGemm32 = GridwiseGemmBase; + using CGridDesc_M_N = + remove_cvref_t( + 1, 1, 1, 1, 1))>; template struct OffsettedBlockToCTileMapMLoops @@ -364,9 +288,6 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops { - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( @@ -459,92 +380,109 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK; using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; - // TODO: replace with GroupedGemmKernelArgument - struct GemmBiasTransKernelArg + static constexpr index_t DefaultKBatch = 1; + using KernelArgument = typename GridwiseGemm::Argument; + + template + struct GemmTransKernelArgBase { - // pointers - const void* a_ptr_; - const void* b_ptr_; - std::array ds_ptr_; - void* e_ptr_; - - index_t M_, N_, K_; - index_t StrideA_, StrideB_; - std::array StrideDs_; - index_t StrideE_; + KernelArgument_ karg_; + GroupedGemmBlock2ETileMap block_2_ctile_map_; + index_t block_start_, block_end_; + + GemmTransKernelArgBase() = default; + GemmTransKernelArgBase(KernelArgument_&& karg, + GroupedGemmBlock2ETileMap&& b2c_map, + index_t block_start, + index_t block_end) + : karg_{karg}, + block_2_ctile_map_{b2c_map}, + block_start_{block_start}, + block_end_{block_end} + { + } }; + using GemmTransKernelArg = GemmTransKernelArgBase; + + static constexpr bool CalculateHasMainKBlockLoop(const KernelArgument& karg) + { + index_t k_grain = karg.KBatch * KPerBlock; + index_t K_split = (karg.K + k_grain - 1) / karg.KBatch; + return GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + } // Argument struct Argument : public BaseArgument { - void UpdateKBatch(index_t k_batch) + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : Argument(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_element_op, + b_element_op, + c_element_op, + DefaultKBatch) { - k_batch_ = k_batch; - - if(k_batch_ < 1) - { - - throw std::runtime_error("wrong! k_batch must be > 0"); - } - - const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); - - const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE_; - const index_t N = gemm_desc_kernel_arg_[0].N_; - - const auto e_grid_desc_m_n = - GridwiseGemm64::template MakeEGridDescriptor_M_N( - AverM, N, StrideE); - - const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; - - grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); - - grid_size_ = grid_size_grp_ * group_count_; + // TODO: use occupancy api to calculate appropriate batch size. 
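+ // As an illustrative sketch only (an assumption, not wired up by this
+ // patch), the TODO above could take the usual occupancy-driven shape:
+ // query how many blocks of the selected kernel fit per CU and grow KBatch
+ // until the grid roughly fills the device. `kernel_ptr` and `num_cu` are
+ // hypothetical names for the kernel symbol and the device CU count.
+ //
+ //   int blocks_per_cu = 0;
+ //   hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
+ //       &blocks_per_cu, kernel_ptr, BlockSize, 0));
+ //   const index_t device_capacity = blocks_per_cu * num_cu;
+ //   UpdateKBatch(math::max(index_t{1},
+ //                          device_capacity / math::max(index_t{1}, grid_size_)));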
} - Argument(std::vector&, - std::vector&, - std::vector>&, - std::vector&, + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, std::vector& gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CDEElementwiseOperation c_element_op) - : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} + CDEElementwiseOperation c_element_op, + index_t kbatch) + : group_count_{ck::type_convert(gemm_descs.size())}, + grouped_gemm_kernel_args_dev{nullptr}, + gemm_kernel_host_args_{nullptr}, + grid_size_{0}, + k_batch_{kbatch} { - grid_size_ = 0; - - k_batch_ = 1; - - grouped_gemm_kernel_args_dev = nullptr; - group_count_ = ck::type_convert(gemm_descs.size()); + if(!(group_count_ == ck::type_convert(p_As.size()) && + group_count_ == ck::type_convert(p_Bs.size()) && + ((NumDTensor == 0 && p_Ds.size() == 0) || + group_count_ == ck::type_convert(p_Ds.size())) && + group_count_ == ck::type_convert(p_Es.size()))) + { + throw std::runtime_error("wrong! group_count_ != p_As/b/d/e.size"); + } gemm_desc_kernel_arg_.reserve(group_count_); - index_t group_id = 0; - - sum_of_m = gemm_descs[0].M_; - const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); - const index_t N = gemm_descs[0].N_; - const index_t K = gemm_descs[0].K_; + const index_t fixed_N = gemm_descs[0].N_; + const index_t fixed_K = gemm_descs[0].K_; for(std::size_t i = 0; i < gemm_descs.size(); i++) { - if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_) + const index_t M = gemm_descs[i].M_; + const index_t N = gemm_descs[i].N_; + const index_t K = gemm_descs[i].K_; + + if(N != fixed_N || K != fixed_K) { - throw std::runtime_error("wrong! M/N/K is not identical"); + throw std::runtime_error("wrong! N or K are not fixed across GEMM groups"); } - a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + a_mtx_mraw_kraw_.emplace_back(M, K); b_mtx_nraw_kraw_.emplace_back(N, K); const index_t StrideA = gemm_descs[i].stride_A_; const index_t StrideB = gemm_descs[i].stride_B_; const index_t StrideE = gemm_descs[i].stride_C_; - // pointer std::array p_ds_grid; @@ -564,60 +502,124 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK( - AverM, N, StrideE); + GridwiseGemm::template MakeDEGridDescriptor_M_N( + M, m_padded, N, n_padded, StrideE); // block-to-e-tile map const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); - if(group_id * grid_size_grp_ != grid_size_) + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) { - throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); + throw std::runtime_error("wrong! block_2_etile_map validation failed"); } + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp_; + grid_size_ += grid_size_grp_; - if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) - { - throw std::runtime_error("wrong! 
block_2_etile_map validation failed"); - } + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + auto karg = KernelArgument(std::array{p_As[i]}, + std::array{p_Bs[i]}, + p_Ds[i], + type_convert(p_Es[i]), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideE, + k_batch_, + a_element_op, + b_element_op, + c_element_op, + false); + + gemm_desc_kernel_arg_.emplace_back( + std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); + + // group_id++; + } - const auto& karg = reinterpret_cast( - arg.gemm_kernel_args_[i].karg_); + const auto e_grid_desc_sum_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + group_count_ * gemm_descs[0].M_, + group_count_ * gemm_descs[0].M_, + gemm_descs[0].N_, + gemm_descs[0].N_, + gemm_descs[0].stride_C_); + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, k_batch_}; + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); - if(!GridwiseGemm::CheckValidity(karg)) - { - std::ostringstream err; - err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); - } + barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); + } - gemm_desc_kernel_arg_.push_back(GemmBiasTransKernelArg{ - nullptr, - nullptr, - p_ds_grid, - nullptr, - AverM, - N, - K, - StrideA, - StrideB, - StrideDs, - StrideE, - }); + /** + * @brief Recalculate group grid size for all gemms and update B2C maps. + * + * @param[in] k_batch The new splitK parameter value. + */ + void UpdateKBatch(index_t k_batch) + { + k_batch_ = k_batch; + if(k_batch_ < 1) + { + throw std::runtime_error("wrong! 
k_batch must be > 0"); + } - group_id++; + for(std::size_t i = 0; i < gemm_desc_kernel_arg_.size(); ++i) + { + auto& karg = gemm_desc_kernel_arg_[i].karg_; + + const index_t k_read = GridwiseGemm::CalculateKRead(karg.K, k_batch_); + const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, k_batch_); + const index_t ak0_padded = GridwiseGemm::CalculateAK0Padded(karg.K, k_batch_); + const index_t bk0_padded = GridwiseGemm::CalculateBK0Padded(karg.K, k_batch_); + + const auto c_grid_desc_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{c_grid_desc_m_n, k_batch_}; + const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); + + const index_t block_start = grid_size_; + const index_t block_end = grid_size_ + grid_size_grp; + + grid_size_ += grid_size_grp; + + auto grouped_block_2_ctile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + + karg.KRead = k_read; + karg.KPadded = k_padded; + karg.AK0 = ak0_padded; + karg.BK0 = bk0_padded; + karg.KBatch = k_batch_; + gemm_desc_kernel_arg_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; + gemm_desc_kernel_arg_[i].block_start_ = block_start; + gemm_desc_kernel_arg_[i].block_end_ = block_end; } const auto e_grid_desc_sum_m_n = - GridwiseGemm64::template MakeEGridDescriptor_M_N( - sum_of_m, gemm_desc_kernel_arg_[0].N_, gemm_desc_kernel_arg_[0].StrideE_); + GridwiseGemm::template MakeDEGridDescriptor_M_N( + group_count_ * gemm_desc_kernel_arg_[0].karg_.M, + group_count_ * gemm_desc_kernel_arg_[0].karg_.M, + gemm_desc_kernel_arg_[0].karg_.N, + gemm_desc_kernel_arg_[0].karg_.N, + gemm_desc_kernel_arg_[0].karg_.StrideE); - const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, 1}; + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, k_batch_}; + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); } @@ -625,175 +627,194 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK gemm_desc_kernel_arg_; + std::vector gemm_desc_kernel_arg_; std::vector> a_mtx_mraw_kraw_; std::vector> b_mtx_nraw_kraw_; const void* grouped_gemm_kernel_args_dev; - + void* gemm_kernel_host_args_; index_t grid_size_; index_t grid_size_grp_; index_t barrier_size_grp_; - index_t sum_of_m; - index_t k_batch_; }; // Invoker struct Invoker : public BaseInvoker { - using Argument = DeviceOp::Argument; - - template - float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + float Run(const Argument& arg, + const StreamConfig& stream_config = StreamConfig{}, + hipStream_t cpy_stream = nullptr, + hipEvent_t cpy_event = nullptr) { - bool has_main_k_block_loop = true; + using GemmTransKernelArg_ = GemmTransKernelArgBase; + static_assert(sizeof(GemmTransKernelArg_) == sizeof(GemmTransKernelArg)); + + bool all_have_kbatch_gt_one = arg.gemm_desc_kernel_arg_[0].karg_.KBatch > 1; + bool all_have_main_k0_block_loop = + CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[0].karg_); - for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + bool not_all_have_main_k0_block_loop_same = false; + bool not_all_have_kbatch_value_same = false; + + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); ++i) { - const auto KPad = - GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K_, arg.k_batch_); + const auto& karg = 
reinterpret_cast( + arg.gemm_desc_kernel_arg_[i].karg_); + + if(stream_config.log_level_ > 0) + { + karg.Print(); + } + + auto kbatch = karg.KBatch; - if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop) + if(!GridwiseGemm::CheckValidity(karg)) { - throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); + std::ostringstream err; + err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); } + + not_all_have_main_k0_block_loop_same |= + all_have_main_k0_block_loop xor CalculateHasMainKBlockLoop(karg); + not_all_have_kbatch_value_same |= all_have_kbatch_gt_one xor (kbatch > 1); } - if(arg.grouped_gemm_kernel_args_dev == nullptr) + if(not_all_have_main_k0_block_loop_same) { - throw std::runtime_error("wrong! grouped_gemm_kernel_args_dev is nullpr"); + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + // throw std::runtime_error(err.str()); } - float ave_time = 0; + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! " << " in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } - auto launch_kernel = [&](auto has_main_k_block_loop_, auto e_global_memory_operation_) { - if(arg.k_batch_ == 1) - { - const auto kernel = - kernel_grouped_gemm_wmma_fixed_nk, - GemmSpec, - false, - ALayout, - BLayout, - DsLayout, - ELayout, - DsDataType, - Block2ETileMap, - GroupedGemmBlock2ETileMap, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - e_global_memory_operation_, - has_main_k_block_loop_>; - - return launch_and_time_kernel( - stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), - nullptr, - arg.barrier_size_grp_, - arg.gemm_desc_kernel_arg_.size(), - arg.grid_size_grp_, - arg.k_batch_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); - } - else + // If the user provides copy stream and copy event, we assume that they're also + // responsible for providing allocated host memory (eg. pinned) which + // would be used to copy kernel arguments to the device. + if(cpy_stream && cpy_event) + { + if(arg.gemm_kernel_host_args_ == nullptr) { - const auto kernel = - kernel_grouped_gemm_wmma_fixed_nk, - GemmSpec, - true, - ALayout, - BLayout, - DsLayout, - ELayout, - DsDataType, - Block2ETileMap, - GroupedGemmBlock2ETileMap, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - e_global_memory_operation_, - has_main_k_block_loop_>; - - return launch_and_time_kernel( - stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), - reinterpret_cast(arg.p_workspace_), - arg.barrier_size_grp_, - arg.gemm_desc_kernel_arg_.size(), - arg.grid_size_grp_, - arg.k_batch_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); + std::ostringstream err; + err << "No memory has been allocated for gemm kernel host args " + << "when providing the copy stream and copy event! 
In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); } - }; + hip_check_error(hipMemcpyAsync(arg.p_workspace_, + arg.gemm_kernel_host_args_, + arg.group_count_ * sizeof(GemmTransKernelArg_), + hipMemcpyHostToDevice, + cpy_stream)); - constexpr auto AtomicAdd = InMemoryDataOperationEnum::AtomicAdd; - constexpr auto Set = InMemoryDataOperationEnum::Set; + hip_check_error(hipEventRecord(cpy_event, cpy_stream)); - // For bf16 datatype only kbatch = 1 scenario is supported. This condition is - // enforced in IsSupportedArgument function - if constexpr(std::is_same::value) + hip_check_error(hipEventSynchronize(cpy_event)); + } + else // In this case CK owns memory allocated on host. { - if(has_main_k_block_loop) - { - ave_time = launch_kernel(integral_constant{}, - integral_constant{}); - } - else + + hip_check_error( + hipMemcpyAsync(arg.p_workspace_, + arg.gemm_desc_kernel_arg_.data(), + arg.gemm_desc_kernel_arg_.size() * sizeof(GemmTransKernelArg_), + hipMemcpyHostToDevice, + stream_config.stream_id_)); + } + + float ave_time = 0; + + const auto Run = [&](const auto& kernel) { + if(all_have_kbatch_gt_one) { - ave_time = launch_kernel(integral_constant{}, - integral_constant{}); + for(const auto& trans_arg : arg.gemm_desc_kernel_arg_) + { + + const auto& karg = trans_arg.karg_; + hip_check_error(hipMemsetAsync(karg.p_e_grid, + 0, + karg.M * karg.N * sizeof(EDataType), + stream_config.stream_id_)); + } } - } - else + + ave_time = + launch_and_time_kernel(stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.p_workspace_), + arg.gemm_desc_kernel_arg_.size()); + }; + + // NOTE: If at least one gemm problem has a main k0 block loop, we include it for all + if(all_have_main_k0_block_loop || not_all_have_main_k0_block_loop_same) { - if(arg.k_batch_ > 1) + // Tail number always full + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { - if(has_main_k_block_loop) + if(all_have_kbatch_gt_one) { - ave_time = launch_kernel( - integral_constant{}, - integral_constant{}); + const auto kernel = + kernel_grouped_gemm_wmma_fixed_nk; + + Run(kernel); } else { - ave_time = launch_kernel( - integral_constant{}, - integral_constant{}); + const auto kernel = + kernel_grouped_gemm_wmma_fixed_nk; + + Run(kernel); } } - else + } + else + { + // Tail number always 1 + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - if(has_main_k_block_loop) + if(all_have_kbatch_gt_one) { - ave_time = - launch_kernel(integral_constant{}, - integral_constant{}); + const auto kernel = + kernel_grouped_gemm_wmma_fixed_nk; + + Run(kernel); } else { - ave_time = - launch_kernel(integral_constant{}, - integral_constant{}); + const auto kernel = + kernel_grouped_gemm_wmma_fixed_nk; + + Run(kernel); } } } @@ -801,8 +822,6 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK"; // clang-format on @@ -923,56 +940,27 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK(p_arg); - if(arg_ptr) - { - arg_ptr->grouped_gemm_kernel_args_dev = kernel_args; - } - else - throw std::runtime_error("The argument pointer is not an object of " - "DeviceGroupedGemm_Wmma_Fixed_Nk::Argument structure!"); + return this->SetWorkSpacePointer(p_arg, kernel_args); } size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override { - auto arg_ptr = dynamic_cast(p_arg); - if(arg_ptr) + auto p_arg_ = 
dynamic_cast(p_arg); + if(p_arg_) { - return arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t); + return p_arg_->gemm_desc_kernel_arg_.size() * sizeof(GemmTransKernelArg); } else throw std::runtime_error("The argument pointer is not an object of " - "DeviceGroupedGemm_Wmma_Fixed_Nk::Argument structure!"); + "DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!"); } size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override { - auto arg_ptr = dynamic_cast(p_arg); - if(arg_ptr) - { - return arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument); - } - else - throw std::runtime_error("The argument pointer is not an object of " - "DeviceGroupedGemm_Wmma_Fixed_Nk::Argument structure!"); + return GetWorkSpaceSize(p_arg); } - void SetWorkSpacePointer(BaseArgument* p_arg, - void* p_workspace, - const StreamConfig& stream_config = StreamConfig{}) const override - { - auto arg_ptr = dynamic_cast(p_arg); - if(arg_ptr) - { - arg_ptr->p_workspace_ = p_workspace; - } - else - throw std::runtime_error("The argument pointer is not an object of " - "DeviceGroupedGemm_Wmma_Fixed_Nk::Argument structure!"); - - hip_check_error( - hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_)); - } + size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); } static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } @@ -1000,6 +988,20 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK(p_arg); + if(!pArg_) + { + throw std::runtime_error("Failed to cast argument pointer!"); + } + + pArg_->gemm_kernel_host_args_ = p_host_kernel_args; + std::copy(pArg_->gemm_desc_kernel_arg_.begin(), + pArg_->gemm_desc_kernel_arg_.end(), + static_cast(pArg_->gemm_kernel_host_args_)); + } }; } // namespace device diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp index a4f3ee3df8a..0e072584ee3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp @@ -46,49 +46,6 @@ static constexpr auto InterwaveScheduler = BlockGemmPipelineScheduler::Interwave static constexpr auto GemmMNKPadding = device::GemmSpecialization::MNKPadding; static constexpr auto GemmDefault = device::GemmSpecialization::Default; -// Instances for 2 byte datatypes in CRR layout with ADataType = BDataType = EDataType -template = false> -using device_grouped_gemm_wmma_universal_km_kn_mn_instances = - std::tuple< - // clang-format off - //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| - //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| - //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Fixed_Nk< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Col, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> - // clang`-format on - >; - -// Instances for 2 byte datatypes in CCR layout with ADataType = BDataType = EDataType -template = false> -using device_grouped_gemm_wmma_universal_km_nk_mn_instances = std::tuple< - // clang-format off - //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| - //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| - //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Fixed_Nk< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 
8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Col, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> - // clang-format on - >; - // Instances for 2 byte datatypes in RRR layout with ADataType = BDataType = EDataType template = false> -using device_grouped_gemm_wmma_universal_mk_kn_mn_instances = +using device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_instances = std::tuple< // clang-format off //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -120,7 +77,7 @@ template = false> -using device_grouped_gemm_wmma_universal_mk_nk_mn_instances = +using device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_instances = std::tuple< // clang-format off //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| @@ -165,7 +122,7 @@ template -void add_device_grouped_gemm_wmma_universal_instances( +void add_device_grouped_gemm_wmma_fixed_nk_instances( std::vector -void add_device_grouped_gemm_wmma_universal_instances( +void add_device_grouped_gemm_wmma_fixed_nk_instances( std::vector>>& instances); +void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances); + // fp8_inputB void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instances( std::vector>>& instances); -void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( - std::vector>>& instances); void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( std::vector>>& instances); -void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( - std::vector>>& instances); #endif // CK_ENABLE_BF16 template ) { add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); } if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + 
add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); } } @@ -291,13 +295,11 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); - add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); } if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); - add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); } } #endif // CK_ENABLE_BF16 diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt index ffefecc1ac1..c8f48b2a7f3 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt @@ -4,17 +4,20 @@ # ONLY XDL_AND_WMMA_KERNELS set(GROUPED_GEMM_FIXED_NK_INSTANCES) -list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp - device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp - device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp) +list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES + device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp + + device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp + ) add_instance_library(device_grouped_gemm_fixed_nk_instance ${GROUPED_GEMM_FIXED_NK_INSTANCES}) \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp deleted file mode 100644 index d6ba50472c1..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using BF16 = ck::bhalf_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using DsDataType = ck::Tuple<>; -using DsLayout = ck::Tuple<>; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_irregular_tile_instances = - std::tuple< - // clang-format off - //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| C| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 32, 8, 2, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 16,16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 
128, 64, 32, 8, 8, 32, 32, 2, 1, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 1, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 2, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S< 1, 4, 64, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 64, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 2, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 2, 0, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 32, 8, 8, 32, 32, 2, 2, S< 1, 4, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 4, 32, 1>, S<0, 1, 3, 2>, S<0, 1, 3, 2>, 2, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 8> - // clang-format on - >; - -void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_irregular_tile_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp deleted file mode 100644 index d3d82c9f75b..00000000000 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp +++ /dev/null @@ -1,76 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
-// SPDX-License-Identifier: MIT - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp" - -#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -using BF16 = ck::bhalf_t; -using F32 = float; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -template -using S = ck::Sequence; - -using DsDataType = ck::Tuple<>; -using DsLayout = ck::Tuple<>; -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding; - -using device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_irregular_tile_instances = - std::tuple< - // clang-format off - //############################| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| - //############################| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| - //############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| - //############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 256, 64, 8, 8, 32, 32, 2, 4, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 128, 64, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 256, 
64, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 32, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 128, 64, 8, 8, 32, 32, 4, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 64, 128, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 128, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 128, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 128, 32, 256, 64, 8, 8, 32, 32, 1, 4, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 16, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 64, 64, 8, 8, 32, 32, 2, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 64, 32, 64, 8, 8, 32, 32, 2, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, Row, BF16, BF16, F32, F32, DsDataType, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 1, 64, 32, 64, 64, 8, 8, 32, 32, 1, 2, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, S<1, 8, 8, 1>, S<0, 2, 1, 3>, S<0, 2, 1, 3>, 3, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8> - // clang-format on - >; - -void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( - std::vector>>& instances) -{ - add_device_operation_instances( - instances, - device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_irregular_tile_instances{}); -} - -} // namespace instance -} // namespace device -} // namespace 
tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..f9c2eba0793 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + + +void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_instances< + F16, + Row, + Row, + device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..f870a01f9ad --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + + +void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_instances< + F16, + Row, + Col, + device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 277d3691adcd61693e7be413ce644ad9c421c8ad Mon Sep 17 00:00:00 2001 From: Marton Bidlek Date: Mon, 22 Dec 2025 10:52:09 +0000 Subject: [PATCH 03/16] Added instances supported by the XDL version --- .../device_grouped_gemm_wmma_fixed_nk.hpp | 1 - ...ce_grouped_gemm_wmma_fixed_nk_instance.hpp | 135 ++++++++++++++---- .../gpu/grouped_gemm_fixed_nk.hpp | 110 +++++++++++++- .../gpu/grouped_gemm_fixed_nk/CMakeLists.txt | 8 ++ ...ed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp | 37 +++++ ...ed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp | 37 +++++ ...ixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp | 39 +++++ ...ixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp | 39 +++++ ...fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp | 4 +- ...fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp | 4 +- ...fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp | 39 +++++ ...fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp | 39 +++++ ..._fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp | 39 +++++ ..._fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp | 39 +++++ 14 files changed, 538 insertions(+), 32 deletions(-) create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index a1b70619549..286f82035d8 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -547,7 +547,6 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK; +template +using device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_irregular_instances = + std::tuple< + // clang-format off + 
//##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +template +using device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_irregular_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| 
Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + // List of instance variants to add (pipeline/scheduler/padding combinations) // Some are disabled now, can be re-enabled if needed using InstanceVariant = @@ -114,8 +157,7 @@ template - typename LayoutInstances, + typename CDEElementOp> typename LayoutInstances, typename ADataType, // NOTE: type parameters as last so that they can be inferred from the typename BDataType, // vector argument typename EDataType, @@ -123,17 +165,17 @@ template void add_device_grouped_gemm_wmma_fixed_nk_instances( - std::vector>>& instances) + std::vector>>& instances) { // Add all instances from our instance list static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { @@ -159,23 +201,22 @@ template - typename LayoutInstances, + typename CDEElementOp> typename LayoutInstances, typename AElementOp, // NOTE: element-wise op parameters as last so that they can be typename BElementOp, // inferred from the vector argument typename CDEElementOp> void add_device_grouped_gemm_wmma_fixed_nk_instances( - std::vector>>& instances) + std::vector>>& instances) { // Add all instances from our instance list static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { @@ -191,6 +232,48 @@ void add_device_grouped_gemm_wmma_fixed_nk_instances( }); } +template typename LayoutInstances, + typename AElementOp, + typename BElementOp, + typename CDEElementOp> +void add_device_grouped_gemm_wmma_fixed_nk_irregular_instances( + std::vector>>& instances) +{ + // Add all instances from our instance list + static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { + constexpr auto instance = InstanceVariants[i]; + add_device_operation_instances(instances, + LayoutInstances{}), + instance.At(Number<1>{}), + instance.At(Number<2>{}), + AElementOp, + BElementOp, + CDEElementOp>{}); + }); +} + } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp index 0df4013f660..0aab9913571 100644 --- 
a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp @@ -96,6 +96,32 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_gemm_wmma_fixed_nk_f16_f8_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_f16_f8_f16_mk_nk_mn_instances( + std::vector>>& instances); + // i8_inputB void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instances( std::vector>>& instances); +void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instances( + std::vector>>& instances); + // bf16_inputA i8_inputB #if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( @@ -138,6 +190,19 @@ void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( + std::vector>>& instances); + void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( std::vector>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( + std::vector>>& instances); #endif // bf16_inputA bf16_inputB @@ -181,6 +259,33 @@ void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( PassThrough, PassThrough>>>& instances); +void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& instances); + + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& instances); + #endif // CK_ENABLE_BF16 template ) { add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(op_ptrs); + add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(op_ptrs); } if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(op_ptrs); + add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(op_ptrs); } } #endif @@ -295,11 +401,13 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); + add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); } if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); + add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); } } #endif // CK_ENABLE_BF16 diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt index c8f48b2a7f3..f7ce0a70730 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt @@ -16,8 +16,16 @@ list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp + 
device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp ) add_instance_library(device_grouped_gemm_fixed_nk_instance ${GROUPED_GEMM_FIXED_NK_INSTANCES}) \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..81e0e9a4414 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_instances< + BF16, + Row, + Row, + device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..939f21a88e3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_instances< + BF16, + Row, + Col, + device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..cfde259a4d3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< + BF16, + I8, + Row, + Row, + device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_irregular_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..64c96b8d913 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< + BF16, + I8, + Row, + Col, + device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_irregular_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp index f9c2eba0793..c5f08ca1ee2 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -13,7 +13,7 @@ namespace instance { void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances( - std::vector>>& instances) + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_fixed_nk_instances< F16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp index f870a01f9ad..eae317f8f73 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -13,7 +13,7 @@ namespace instance { void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances( - std::vector>>& instances) + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_fixed_nk_instances< F16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..43a990064ba --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + + +void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< + F16, + F8, + Row, + Row, + device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_irregular_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..115cc95558f --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + + +void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< + F16, + F8, + Row, + Col, + device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_irregular_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..3bb479cafe9 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + + +void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< + F16, + I8, + Row, + Row, + device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_irregular_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..3e40d594557 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + + +void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< + F16, + I8, + Row, + Col, + device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_irregular_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck From 1cf5026c85cee5a6a5db18e05db711966342d7b5 Mon Sep 17 00:00:00 2001 From: Marton Bidlek Date: Tue, 23 Dec 2025 10:43:42 +0000 Subject: [PATCH 04/16] unit test basic implementation --- .../gpu/grouped_gemm_fixed_nk.hpp | 164 ++++++++++++++---- test/grouped_gemm/CMakeLists.txt | 6 + .../test_grouped_gemm_fixed_nk.cpp | 81 +++++++++ .../test_grouped_gemm_fixed_nk_cases.inc | 84 +++++++++ 4 files changed, 299 insertions(+), 36 deletions(-) create mode 100644 test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp create mode 100644 test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp index 0aab9913571..c3a3fbd6b4c 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp @@ -17,6 +17,7 @@ namespace device { namespace instance { // fp16_output +#if defined(CK_USE_XDL) void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances( std::vector>>& instances); -void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances( +// fp8_inputB +void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instances( std::vector>>& instances); -void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances( +void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instances( std::vector>>& instances); + 
PassThrough>>>& instances); -// fp8_inputB -void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instances( +// i8_inputB +void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instances( std::vector>>& instances); -void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instances( +void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances( std::vector>>& instances); +#endif -void add_device_grouped_gemm_wmma_fixed_nk_f16_f8_f16_mk_kn_mn_instances( +#if defined (CK_USE_WMMA) +void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances( std::vector>>& instances); -void add_device_grouped_gemm_wmma_fixed_nk_f16_f8_f16_mk_nk_mn_instances( +void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances( std::vector>>& instances); + PassThrough>>>& instances); -// i8_inputB -void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instances( +void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instances( std::vector>>& instances); -void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances( +void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instances( std::vector>>& instances); +#endif // bf16_inputA i8_inputB #if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) +#if defined (CK_USE_XDL) void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( std::vector>>& instances); + PassThrough>>>& instances); -void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( +void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( std::vector>>& instances); - -void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( + PassThrough>>>& instances); +#endif +#if defined (CK_USE_WMMA) +void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( std::vector>>& instances); -void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( +void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( std::vector>>& instances); + PassThrough>>>& instances); +#endif #endif // bf16_inputA bf16_inputB #if defined(CK_ENABLE_BF16) +#if defined (CK_USE_XDL) void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( std::vector>>& instances); - +#endif +#if defined (CK_USE_WMMA) void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( std::vector>>& instances); - +#endif #endif // CK_ENABLE_BF16 template > op_ptrs; - +#if defined (CK_USE_XDL) // fp16_output if constexpr(is_same_v && is_same_v && is_same_v) @@ -331,13 +340,11 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); - add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); } if constexpr(is_same_v && is_same_v && is_same_v) { add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); - add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); } } @@ -382,6 +389,91 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(op_ptrs); + } + } +#endif + +// bf16_inputA bf16_inputB +#if defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + 
add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); + } + } +#endif // CK_ENABLE_BF16 +#endif //CK_USE_XDL + +#if defined (CK_USE_WMMA) + // fp16_output + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + } + + // fp8_input + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instances(op_ptrs); + } + } + + // i8_input + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instances(op_ptrs); + } + } + +// bf16_i8_input +#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(op_ptrs); } if constexpr(is_same_v && is_same_v && @@ -400,17 +492,17 @@ struct DeviceOperationInstanceFactory< if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); } if constexpr(is_same_v && is_same_v && is_same_v) { - add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); } } #endif // CK_ENABLE_BF16 +#endif // CK_USE_WMMA + return op_ptrs; } diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 450950cbd66..3091667da3f 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -18,6 +18,12 @@ if (CK_USE_XDL OR CK_USE_WMMA) target_link_libraries(test_grouped_gemm_fastgelu PRIVATE utility device_grouped_gemm_fastgelu_instance) add_dependencies(test_grouped_gemm test_grouped_gemm_fastgelu) endif() + + add_gtest_executable(test_grouped_gemm_fixed_nk test_grouped_gemm_fixed_nk.cpp) + if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_fixed_nk PRIVATE utility device_grouped_gemm_fixed_nk_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_fixed_nk) + endif() endif() add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp new file mode 100644 index 00000000000..f83283e9813 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp @@ -0,0 +1,81 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT
+
+#include
+#include
+
+#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
+#include "ck/utility/data_type.hpp"
+
+#include "gtest/gtest.h"
+#include "test_grouped_gemm_util.hpp"
+
+// Command-line overrides (see main below); an instance_index of -1 runs all instances.
+ck::index_t param_mask = 0xffffff;
+ck::index_t instance_index = -1;
+
+using F16 = ck::half_t;
+using BF16 = ck::bhalf_t;
+using F8 = ck::f8_t;
+using I8 = int8_t;
+
+using AElementOp = ck::tensor_operation::element_wise::PassThrough;
+using BElementOp = ck::tensor_operation::element_wise::PassThrough;
+using CDEElementOp = ck::tensor_operation::element_wise::PassThrough;
+
+using Row = ck::tensor_layout::gemm::RowMajor;
+using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+template
+class TestGroupedGemm : public ck::test::TestGroupedGemm
+{
+    public:
+    void SetUp() override
+    {
+        ck::test::TestGroupedGemm::SetUp();
+
+#if defined(CK_USE_WMMA)
+        // The old XDL tests didn't fail if instances were not supported, so we keep that
+        // behaviour. When WMMA instances are compiled and WMMA is supported, however, we do
+        // fail if a specific case is not supported.
+        this->fail_if_no_supported_instances_ =
+            ck::is_gfx11_supported() || ck::is_gfx12_supported();
+#endif
+    }
+};
+
+// clang-format off
+using KernelTypes = ::testing::Types<
+    std::tuple< Row, Row, Row, BF16, BF16, BF16, AElementOp, BElementOp, CDEElementOp>,
+    std::tuple< Row, Col, Row, BF16, BF16, BF16, AElementOp, BElementOp, CDEElementOp>,
+    std::tuple< Row, Row, Row, BF16, I8, BF16, AElementOp, BElementOp, CDEElementOp>,
+    std::tuple< Row, Col, Row, BF16, I8, BF16, AElementOp, BElementOp, CDEElementOp>,
+    std::tuple< Row, Row, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
+    std::tuple< Row, Col, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>,
+    std::tuple< Row, Row, Row, F16, F8, F16, AElementOp, BElementOp, CDEElementOp>,
+    std::tuple< Row, Col, Row, F16, F8, F16, AElementOp, BElementOp, CDEElementOp>,
+    std::tuple< Row, Row, Row, F16, I8, F16, AElementOp, BElementOp, CDEElementOp>,
+    std::tuple< Row, Col, Row, F16, I8, F16, AElementOp, BElementOp, CDEElementOp>
+>;
+// clang-format on
+
+TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes);
+
+#include "test_grouped_gemm_fixed_nk_cases.inc"
+
+int main(int argc, char** argv)
+{
+    testing::InitGoogleTest(&argc, argv);
+    if(argc == 3)
+    {
+        param_mask = strtol(argv[1], nullptr, 0);
+        instance_index = atoi(argv[2]);
+    }
+    else if(argc != 1)
+    {
+        std::cout << "Usage: " << argv[0] << " [param_mask instance_index]" << std::endl;
+        std::cout << "An instance_index of -1 means all instances." << std::endl;
+    }
+    return RUN_ALL_TESTS();
+}
diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc
new file mode 100644
index 00000000000..efdf467d66a
--- /dev/null
+++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc
@@ -0,0 +1,84 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+TYPED_TEST(TestGroupedGemm, TinyCases)
+{
+    const std::vector Ms{2, 1};
+    constexpr int N = 544;
+    constexpr int K = 544;
+
+    const std::vector Ns(Ms.size(), N);
+    const std::vector Ks(Ms.size(), K);
+
+    this->Run(Ms, Ns, Ks);
+}
+
+TYPED_TEST(TestGroupedGemm, SmallCases)
+{
+    const std::vector Ms{2, 1, 3, 4, 5};
+    constexpr int N = 544;
+    constexpr int K = 544;
+
+    const std::vector Ns(Ms.size(), N);
+    const std::vector Ks(Ms.size(), K);
+
+    this->Run(Ms, Ns, Ks);
+}
+
+TYPED_TEST(TestGroupedGemm, MidCases)
+{
+    const std::vector Ms{167, 183, 177, 153, 139, 204};
+    constexpr int N = 768;
+    constexpr int K = 768;
+
+    const std::vector Ns(Ms.size(), N);
+    const std::vector Ks(Ms.size(), K);
+
+    this->Run(Ms, Ns, Ks);
+}
+
+TYPED_TEST(TestGroupedGemm, Regular)
+{
+    const std::vector Ms{64, 128, 256};
+    constexpr int N = 320;
+    constexpr int K = 320;
+
+    const std::vector Ns(Ms.size(), N);
+    const std::vector Ks(Ms.size(), K);
+
+    this->Run(Ms, Ns, Ks);
+}
+
+TYPED_TEST(TestGroupedGemm, MNKPadded)
+{
+    const std::vector Ms{127, 150, 188, 210};
+    constexpr int N = 280;
+    constexpr int K = 280;
+
+    const std::vector Ns(Ms.size(), N);
+    const std::vector Ks(Ms.size(), K);
+
+    this->Run(Ms, Ns, Ks);
+}
+
+TYPED_TEST(TestGroupedGemm, TestLargeKBatch)
+{
+    // In some cases Split-K is not supported. Running this test would fail since no instance
+    // would be supported, so we skip it.
+    if(!this->IsSplitKSupported())
+        GTEST_SKIP() << "Split-K not supported for the current configuration (FP16/BF16 on "
+                        "GFX11, or using CDE element-wise operation)";
+
+    const std::vector Ms{188, 210};
+    constexpr int N = 4096;
+    constexpr int K = 4096;
+
+    const std::vector Ns(Ms.size(), N);
+    const std::vector Ks(Ms.size(), K);
+
+    this->k_batches_ = {32, 64};
+
+    this->Run(Ms, Ns, Ks);
+}

From be375c4c6794600d78d5c33118f6c0c9b76efb10 Mon Sep 17 00:00:00 2001
From: Marton Bidlek
Date: Mon, 5 Jan 2026 09:53:17 +0000
Subject: [PATCH 05/16] Test added

---
 .../device_grouped_gemm_wmma_fixed_nk.hpp | 4 +-
 .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 7 +-
 .../profile_grouped_gemm_fixed_nk_impl.hpp | 6 +-
 .../test_grouped_gemm_fixed_nk.cpp | 24 +--
 .../test_grouped_gemm_fixed_nk_cases.inc | 51 +----
 test/grouped_gemm/test_grouped_gemm_util.hpp | 202 +++++++++++++++---
 6 files changed, 201 insertions(+), 93 deletions(-)

diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp
index 286f82035d8..5cc18205f30 100644
--- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp
+++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp
@@ -50,7 +50,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
     const index_t block_id = get_block_1d_id();
     const auto gemm_desc_ptr = reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const));
-
+    // Binary search lookup to find which group this block is part of
     index_t left = 0;
     index_t right = group_count;
@@ -570,6 +570,8 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK 0");
diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp
index bcf131003c2..aea770016b4 100644
--- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp
+++ 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -1091,8 +1091,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base { if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " - << karg.K << " " << __FILE__ << ":" << __LINE__ + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! " + "K_Batch:" << karg.KBatch << " " << + "K0PerBlock:" << KPerBlock << " " << + "K1:" << AK1Number << " " << + "K:" << karg.K << " " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } return false; diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp index f551a16a1b9..f488c5210be 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp @@ -74,11 +74,11 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, std::vector> b_k_n; std::vector> c_m_n_host_results; std::vector> c_m_n_device_results; - int sum_of_m = 0; + // int sum_of_m = 0; for(std::size_t i = 0; i < group_count; i++) { - sum_of_m += Ms[i]; + // sum_of_m += Ms[i]; a_m_k.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); b_k_n.push_back( @@ -150,7 +150,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); - gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp index f83283e9813..76fa7d2ad65 100644 --- a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp @@ -28,12 +28,12 @@ using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; template -class TestGroupedGemm : public ck::test::TestGroupedGemm +class TestGroupedGemm : public ck::test::TestGroupedGemm { public: void SetUp() override { - ck::test::TestGroupedGemm::SetUp(); + ck::test::TestGroupedGemm::SetUp(); #if defined(CK_USE_WMMA) // The old XDL tests didn't fail if instances were not supported, so we want to keep that @@ -47,16 +47,16 @@ class TestGroupedGemm : public ck::test::TestGroupedGemm // clang-format off using KernelTypes = ::testing::Types< - std::tuple< Row, Row, Row, BF16, BF16, BF16, AElementOp, BElementOp, CDEElementOp>, - std::tuple< Row, Col, Row, BF16, BF16, BF16, AElementOp, BElementOp, CDEElementOp>, - std::tuple< Row, Row, Row, BF16, I8, BF16, AElementOp, BElementOp, CDEElementOp>, - std::tuple< Row, Col, Row, BF16, I8, BF16, AElementOp, BElementOp, CDEElementOp>, - std::tuple< Row, Row, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>, - std::tuple< Row, Col, Row, F16, F16, F16, AElementOp, BElementOp, CDEElementOp>, - std::tuple< Row, Row, Row, F16, F8, F16, AElementOp, BElementOp, CDEElementOp>, - std::tuple< Row, Col, Row, F16, F8, F16, AElementOp, BElementOp, CDEElementOp>, - std::tuple< Row, Row, Row, F16, I8, F16, AElementOp, BElementOp, CDEElementOp>, - std::tuple< Row, Col, Row, F16, I8, F16, AElementOp, BElementOp, CDEElementOp> + std::tuple< Row, Row, Row, 
BF16, BF16, BF16>, + std::tuple< Row, Col, Row, BF16, BF16, BF16>, + std::tuple< Row, Row, Row, BF16, I8, BF16>, + std::tuple< Row, Col, Row, BF16, I8, BF16>, + std::tuple< Row, Row, Row, F16, F16, F16>, + std::tuple< Row, Col, Row, F16, F16, F16>, + std::tuple< Row, Row, Row, F16, F8, F16>, + std::tuple< Row, Col, Row, F16, F8, F16>, + std::tuple< Row, Row, Row, F16, I8, F16>, + std::tuple< Row, Col, Row, F16, I8, F16> >; // clang-format on diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc index efdf467d66a..e41508b9f0d 100644 --- a/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc @@ -3,59 +3,12 @@ #pragma once -TYPED_TEST(TestGroupedGemm, TinyCases) -{ - const std::vector Ms{2, 1}; - constexpr int N = 544; - constexpr int K = 544; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - - this->Run(Ms, Ns, Ks); -} - -TYPED_TEST(TestGroupedGemm, SmallCases) -{ - const std::vector Ms{2, 1, 3, 4, 5}; - constexpr int N = 544; - constexpr int K = 544; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - - this->Run(Ms, Ns, Ks); -} - -TYPED_TEST(TestGroupedGemm, MidCases) -{ - const std::vector Ms{167, 183, 177, 153, 139, 204}; - constexpr int N = 768; - constexpr int K = 768; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - - this->Run(Ms, Ns, Ks); -} - -TYPED_TEST(TestGroupedGemm, Regular) -{ - const std::vector Ms{64, 128, 256}; - constexpr int N = 320; - constexpr int K = 320; - - const std::vector Ns(Ms.size(), N); - const std::vector Ks(Ms.size(), K); - - this->Run(Ms, Ns, Ks); -} TYPED_TEST(TestGroupedGemm, MNKPadded) { const std::vector Ms{127, 150, 188, 210}; - constexpr int N = 280; - constexpr int K = 280; + constexpr int N = 512; + constexpr int K = 1024; const std::vector Ns(Ms.size(), N); const std::vector Ks(Ms.size(), K); diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index ee95fe03c66..38841d74963 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -16,6 +16,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "profiler/profile_grouped_gemm_impl.hpp" +#include "profiler/profile_grouped_gemm_fixed_nk_impl.hpp" extern ck::index_t param_mask; extern ck::index_t instance_index; @@ -23,7 +24,124 @@ extern ck::index_t instance_index; namespace ck { namespace test { -template + +struct DefaultGroupedGemmProfiler +{ + template < + typename ADataType, + typename BDataType, + typename EDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename ELayout, + typename AElementOp, + typename BElementOp, + typename CDEElementOp> + static bool Run(bool verify, + int init_method, + bool log, + bool bench, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + const std::vector& kbatches, + int n_warmup, + int n_iter, + int instance_index, + bool fail_if_no_supported_instances) + { + return ck::profiler::profile_grouped_gemm_impl( + verify, + init_method, + log, + bench, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup, + n_iter, + instance_index, + fail_if_no_supported_instances); + } +}; + 
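+// Profiler policy for the fixed-NK tests. Unlike profile_grouped_gemm_impl
+// above, profile_grouped_gemm_fixed_nk_impl profiles a single k_batch per
+// call, so the adapter below loops over the requested k-batches and reports
+// a pass only if every k_batch variant passes.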
+struct FixedNKGroupedGemmProfiler +{ + template < + typename ADataType, + typename BDataType, + typename EDataType, + typename AccDataType, + typename ALayout, + typename BLayout, + typename CLayout> + static bool Run( + bool verify, + int init_method, + bool log, + bool bench, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + const std::vector& kbatches, + int n_warmup, + int n_iter, + int /*instance_index*/, + bool /*fail_if_no_supported_instances*/) + { + bool pass = true; + for(int kbatch : kbatches) + { + pass &= ck::profiler::profile_grouped_gemm_fixed_nk_impl< + ADataType, + BDataType, + EDataType, + AccDataType, + ALayout, + BLayout, + CLayout>( + verify, + init_method, + log, + bench, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + } + return pass; + } +}; + + +template class TestGroupedGemm : public testing::Test { protected: @@ -146,31 +264,63 @@ class TestGroupedGemm : public testing::Test const std::vector& StrideCs, const std::vector& kbatches) { - bool pass = - ck::profiler::profile_grouped_gemm_impl(verify_, - init_method_, - log_, - bench_, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatches, - n_warmup_, - n_iter_, - instance_index, - fail_if_no_supported_instances_); + bool pass = false; + using AccDataType = float; + + if constexpr (std::is_same_v) + { + pass = Profiler::template Run( + verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup_, + n_iter_, + instance_index, + fail_if_no_supported_instances_); + } + else + { + pass = Profiler::template Run( + verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup_, + n_iter_, + instance_index, + fail_if_no_supported_instances_); + } + EXPECT_TRUE(pass); } }; From 02580be6c2eaf85292cc9810c439afa56c330992 Mon Sep 17 00:00:00 2001 From: Marton Bidlek Date: Tue, 6 Jan 2026 14:19:34 +0000 Subject: [PATCH 06/16] Pushing everything to test out xdl version on another server --- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 28 ++++++ ...ce_grouped_gemm_wmma_fixed_nk_instance.hpp | 2 +- .../gpu/grouped_gemm_fixed_nk.hpp | 87 ------------------- profiler/src/CMakeLists.txt | 29 +++++++ .../test_grouped_gemm_fixed_nk.cpp | 27 ++++-- .../test_grouped_gemm_fixed_nk_cases.inc | 53 ++++++++++- test/grouped_gemm/test_grouped_gemm_util.hpp | 2 +- 7 files changed, 129 insertions(+), 99 deletions(-) diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index aea770016b4..6efcfbaa955 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -1043,6 +1043,34 @@ struct GridwiseGemm_wmma_cshuffle_v3_base static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && (NPerBlock % (NPerWmma * NRepeat)) == 0, "Invalid tuning param!"); + // if (!(MPerBlock % (MPerWmma * MRepeat) == 0 && NPerBlock % (NPerWmma * NRepeat) == 0)) + // { + // std::cout << "[DEBUG] Invalid tuning param!\n" + // << " MPerBlock: " << MPerBlock << "\n" + // << " NPerBlock: " << NPerBlock << "\n" + // << " MPerWmma : " << MPerWmma << "\n" + // << " NPerWmma : " << NPerWmma << "\n" + // << " MRepeat : " << MRepeat << "\n" + // << " NRepeat : " << 
NRepeat << "\n" + // << " Check: MPerBlock % (MPerWmma * MRepeat) == " + // << (MPerBlock % (MPerWmma * MRepeat)) << "\n" + // << " NPerBlock % (NPerWmma * NRepeat) == " + // << (NPerBlock % (NPerWmma * NRepeat)) << "\n"; + // } + // std::cout + // << "[CK_CHECK] " + // << "M=" << karg.M + // << " N=" << karg.N + // << " K=" << karg.K + // << " KBatch=" << karg.KBatch + // << " | MPerBlock=" << MPerBlock + // << " NPerBlock=" << NPerBlock + // << " KPerBlock=" << KPerBlock + // << " | MRepeat=" << MRepeat + // << " NRepeat=" << NRepeat + // << " | AK1=" << AK1Number + // << " BK1=" << BK1Number + // << std::endl; if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp index d0c92fdb1d9..660472b7140 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp @@ -141,7 +141,7 @@ static constexpr InstanceVariant InstanceVariants[] = { make_tuple(GemmDefault, IntrawaveScheduler, PipelineV1), // make_tuple(GemmDefault, InterwaveScheduler, PipelineV1), - make_tuple(GemmDefault, IntrawaveScheduler, PipelineV3), + // make_tuple(GemmDefault, IntrawaveScheduler, PipelineV3), make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV1), // make_tuple(GemmMNKPadding, InterwaveScheduler, PipelineV1), diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp index c3a3fbd6b4c..ae3751d296d 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp @@ -331,93 +331,7 @@ struct DeviceOperationInstanceFactory< static auto GetInstances() { std::vector> op_ptrs; -#if defined (CK_USE_XDL) - // fp16_output - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); - } - if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); - } - } - - // fp8_input - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instances(op_ptrs); - } - if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instances(op_ptrs); - } - } - // i8_input - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instances(op_ptrs); - } - if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances(op_ptrs); - } - } - -// bf16_i8_input -#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(is_same_v && is_same_v && - is_same_v) - { 
- add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(op_ptrs); - } - if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(op_ptrs); - } - } -#endif - -// bf16_inputA bf16_inputB -#if defined(CK_ENABLE_BF16) - if constexpr(is_same_v && is_same_v && - is_same_v) - { - if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); - } - if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); - } - } -#endif // CK_ENABLE_BF16 -#endif //CK_USE_XDL - -#if defined (CK_USE_WMMA) // fp16_output if constexpr(is_same_v && is_same_v && is_same_v) @@ -501,7 +415,6 @@ struct DeviceOperationInstanceFactory< } } #endif // CK_ENABLE_BF16 -#endif // CK_USE_WMMA return op_ptrs; diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 012d6e15027..f507d418661 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -298,3 +298,32 @@ message(VERBOSE "ckProfiler libs: ${PROFILER_LIBS}") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE ${PROFILER_LIBS}) rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) + + +## Defining specific operation targets + +macro(define_profiler_target NAME SOURCES LIBS) + add_executable(${NAME} profiler.cpp ${SOURCES}) + target_compile_options(${NAME} PRIVATE -Wno-global-constructors) + + if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132) + target_compile_options(${NAME} PRIVATE --offload-compress) + endif() + + target_link_libraries(${NAME} PRIVATE utility getopt::getopt ${LIBS}) + + rocm_install(TARGETS ${NAME} COMPONENT profiler) +endmacro() + + +define_profiler_target(ckProfiler_gemm + "profile_gemm.cpp" + "device_gemm_instance") + +define_profiler_target(ckProfiler_gemm_universal + "profile_gemm_universal.cpp" + "device_gemm_universal_instance") + +define_profiler_target(ckProfiler_gemm_fixed_nk + "profile_grouped_gemm_fixed_nk.cpp" + "device_grouped_gemm_fixed_nk_instance") diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp index 76fa7d2ad65..8fb80b09813 100644 --- a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp @@ -45,19 +45,30 @@ class TestGroupedGemm : public ck::test::TestGroupedGemm, - std::tuple< Row, Col, Row, BF16, BF16, BF16>, + +#if defined(CK_USE_XDL) && defined(__gfx9__) + // XDL only at the moment, instances for WMMA not defined std::tuple< Row, Row, Row, BF16, I8, BF16>, std::tuple< Row, Col, Row, BF16, I8, BF16>, - std::tuple< Row, Row, Row, F16, F16, F16>, - std::tuple< Row, Col, Row, F16, F16, F16>, +#endif + +#if (defined(CK_USE_XDL) && (defined(__gfx9__) || defined(__gfx12__))) || (defined(CK_USE_WMMA) && defined(__gfx12__)) std::tuple< Row, Row, Row, F16, F8, F16>, std::tuple< Row, Col, Row, F16, F8, F16>, - std::tuple< Row, Row, Row, F16, I8, F16>, - std::tuple< Row, Col, Row, F16, I8, F16> ->; +#endif + + std::tuple< Row, Row, Row, F16, F16, F16>, + std::tuple< Row, Col, Row, F16, F16, F16>, + + + std::tuple< Row, Row, Row, BF16, BF16, BF16>, + std::tuple< Row, Col, Row, BF16, BF16, BF16>, + + std::tuple, + std::tuple + >; // clang-format on TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes); diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc index 
e41508b9f0d..bd514de4e83 100644 --- a/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc @@ -4,11 +4,60 @@ #pragma once +TYPED_TEST(TestGroupedGemm, TinyCases) +{ + const std::vector Ms{2, 1}; + constexpr int N = 512; + constexpr int K = 256; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemm, SmallCases) +{ + const std::vector Ms{2, 1, 3, 4, 5}; + constexpr int N = 512; + constexpr int K = 256; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemm, MidCases) +{ + const std::vector Ms{167, 183, 177, 153, 139, 204}; + constexpr int N = 512; + constexpr int K = 256; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemm, Regular) +{ + const std::vector Ms{64, 128, 256}; + constexpr int N = 512; + constexpr int K = 256; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + + TYPED_TEST(TestGroupedGemm, MNKPadded) { const std::vector Ms{127, 150, 188, 210}; - constexpr int N = 512; - constexpr int K = 1024; + constexpr int N = 128; + constexpr int K = 256; const std::vector Ns(Ms.size(), N); const std::vector Ks(Ms.size(), K); diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index 38841d74963..ff94d0c1ccb 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -194,7 +194,7 @@ class TestGroupedGemm : public testing::Test } else { - k_batches_ = {1, 2, 3, 5, 8}; + k_batches_ = {1, 2, 3, 4, 8}; } } From c6f9cd9e7baee46e64b7f3c4e01edf0d4e3e3630 Mon Sep 17 00:00:00 2001 From: Marton Bidlek Date: Fri, 9 Jan 2026 12:43:57 +0000 Subject: [PATCH 07/16] added xdl to factory and fixed wmma test bugs --- .../device_grouped_gemm_wmma_fixed_nk.hpp | 296 ++++-------------- ...ce_grouped_gemm_wmma_fixed_nk_instance.hpp | 28 +- .../gpu/grouped_gemm_fixed_nk.hpp | 99 +++++- .../profile_grouped_gemm_fixed_nk_impl.hpp | 51 ++- .../test_grouped_gemm_fixed_nk.cpp | 7 +- .../test_grouped_gemm_fixed_nk_cases.inc | 12 +- test/grouped_gemm/test_grouped_gemm_util.hpp | 2 +- 7 files changed, 204 insertions(+), 291 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index 5cc18205f30..82f7eeba8d9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -12,9 +12,7 @@ #include "ck/host_utility/hip_check_error.hpp" #include "ck/utility/common_header.hpp" #include "ck/utility/tuple.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/utility/common_header.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" #include "ck/tensor_description/tensor_descriptor_helper.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" @@ -179,10 +177,6 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< ALayout, BLayout, @@ -239,146 +233,9 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : 
public DeviceGroupedGemmFixedNK( 1, 1, 1, 1, 1))>; - template - struct OffsettedBlockToCTileMapMLoops - { - using underlying_type = UnderlyingBlockToCTileMap; - - __host__ __device__ OffsettedBlockToCTileMapMLoops( - UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) - { - block_to_ctile_map_ = block_to_ctile_map; - block_start_ = block_start; - id_off_ = id_off; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( - make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); - - return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, - const CTileDim& c_tile_dim) const - { - return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); - } - - template - __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); - } - - UnderlyingBlockToCTileMap block_to_ctile_map_; - index_t block_start_; - index_t id_off_; - }; - - template - struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops - { - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& - operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; - - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, - index_t N, - index_t KBatch, - index_t M01 = 8) - : M_(M), N_(N), KBatch_(KBatch), M01_(M01) - { - } - - template - __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) - : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( - c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) - { - } - - __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const - { - const auto M0 = math::integer_divide_ceil(M, MPerBlock); - const auto N0 = math::integer_divide_ceil(N, NPerBlock); - - return M0 * N0 * KBatch_; - } - - template - __host__ __device__ constexpr index_t - CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const - { - return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); - } - - template - __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const - { - return true; - } - - template - __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const - { - auto block_1d_id = idx_top[I0]; - - const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); - const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); - - block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups - - const index_t idx_ksplit = block_1d_id / (M0 * N0); - block_1d_id = block_1d_id % (M0 * N0); - - 
index_t idx_N0 = block_1d_id % N0; - index_t idx_M0 = block_1d_id / N0; - - const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_; - - index_t idx_M00 = idx_M0 / M01_; - index_t idx_M01 = idx_M0 % M01_; - index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; - - return make_tuple(idx_ksplit, - idx_N0_M01_local % M01_adapt + idx_M00 * M01_, - idx_N0_M01_local / M01_adapt); - } - - template - __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, - const CTileDim& /* c_tile_dim */) const - { - return true; // always valid provided that user gets grid size from CalculateGridSize() - } - - private: - index_t M_; - index_t N_; - index_t KBatch_; - index_t M01_; - }; - - using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; - using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + using Block2ETileMap = BlockToCTileMap_KSplit_M00_N0_M01Adapt; + static constexpr index_t B2E_M01 = 8; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; static constexpr index_t DefaultKBatch = 1; using KernelArgument = typename GridwiseGemm::Argument; @@ -446,12 +303,10 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK(gemm_descs.size())}, - grouped_gemm_kernel_args_dev{nullptr}, gemm_kernel_host_args_{nullptr}, grid_size_{0}, k_batch_{kbatch} { - if(!(group_count_ == ck::type_convert(p_As.size()) && group_count_ == ck::type_convert(p_Bs.size()) && ((NumDTensor == 0 && p_Ds.size() == 0) || @@ -466,7 +321,7 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK p_ds_grid; + const auto& stride_d_vec = gemm_descs[i].stride_Ds_; - static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); + if(!(NumDTensor == ck::type_convert(stride_d_vec.size()))) + { + throw std::runtime_error("wrong! stride D mismatch"); + } std::array StrideDs; - - static_for<0, NumDTensor, 1>{}([&](auto j) { - // using DLayout = remove_cvref_t>; - - if(gemm_descs[i].stride_Ds_.size() != NumDTensor) - { - throw std::runtime_error( - "wrong! 
gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); - } - - StrideDs[j] = gemm_descs[i].stride_Ds_[j]; - }); + if constexpr(NumDTensor > 0) + { + std::copy(stride_d_vec.begin(), stride_d_vec.end(), StrideDs); + } const index_t m_padded = GridwiseGemm::CalculateMPadded(M); const index_t n_padded = GridwiseGemm::CalculateNPadded(N); @@ -510,9 +356,9 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK( - group_count_ * gemm_descs[0].M_, - group_count_ * gemm_descs[0].M_, - gemm_descs[0].N_, - gemm_descs[0].N_, - gemm_descs[0].stride_C_); - const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, k_batch_}; - grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); - - barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); } /** @@ -590,7 +423,7 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK( karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideE); - const auto local_b2c_tile_map = Block2ETileMap{c_grid_desc_m_n, k_batch_}; + const auto local_b2c_tile_map = Block2ETileMap{c_grid_desc_m_n, B2E_M01, k_batch_}; const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); const index_t block_start = grid_size_; @@ -611,32 +444,17 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK( - group_count_ * gemm_desc_kernel_arg_[0].karg_.M, - group_count_ * gemm_desc_kernel_arg_[0].karg_.M, - gemm_desc_kernel_arg_[0].karg_.N, - gemm_desc_kernel_arg_[0].karg_.N, - gemm_desc_kernel_arg_[0].karg_.StrideE); - - const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, k_batch_}; - grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); - barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); } // private: index_t group_count_; std::vector gemm_desc_kernel_arg_; - std::vector> a_mtx_mraw_kraw_; - std::vector> b_mtx_nraw_kraw_; - const void* grouped_gemm_kernel_args_dev; void* gemm_kernel_host_args_; index_t grid_size_; - index_t grid_size_grp_; - index_t barrier_size_grp_; + index_t k_batch_; }; @@ -833,42 +651,65 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) { return false; } - - bool supported = true; - - // If we use padding we do not support vector loads for dimensions not divisible by - // vector load size. - if constexpr(GemmSpec != GemmSpecialization::Default) + if constexpr(std::is_same_v || + std::is_same_v) { - // [A|B]BlockTransferSrcVectorDim value define dimension in the block {K0,M,K1} - // layout, thus we have to adapt it to the {M,K} or {N,K} layout. - const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; - const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 
1 : 0; - - for(index_t i = 0; i < arg.group_count_; ++i) + if(arg.k_batch_ > 1 && ck::is_gfx11_supported()) { - const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); - const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } - supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); - supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); + if constexpr(!std::is_same_v) + { + if(arg.k_batch_ > 1) + { + // Using SplitK and a C element op would require a two stage kernel where the second + // stage applies the op on the accumulated results + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "C element operators are not supported when using SplitK. Set " + "K_BATCH to 1 or remove the operator." + << std::endl; + } + return false; } } - // For bf16 datatype only kbatch = 1 is supported since there is no AtomicAdd - // instruction that supports bf16 and we cannot use splitk because of that - if constexpr(std::is_same::value) + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { - supported = supported & (arg.k_batch_ == 1); + if(ck::is_gfx11_supported()) + { + return false; + } } + bool supported = true; + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); ++i) + { + const auto& a = arg.gemm_desc_kernel_arg_[i].karg_; + bool group_arg_valid = GridwiseGemm::CheckValidity(a); + + if(not group_arg_valid) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" << std::endl; + a.Print(); + } + } + supported = supported && group_arg_valid; + } return supported; } - // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { @@ -966,7 +807,7 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK(p_arg); if(arg_ptr) @@ -978,17 +819,6 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK(p_arg); - if(arg_ptr) - { - arg_ptr->UpdateKBatch(kbatch); - } - else - throw std::runtime_error("The argument pointer is not an object of " - "DeviceGroupedGemm_Wmma_Fixed_Nk::Argument structure!"); - } void SetHostKernelArgsPointer(BaseArgument* p_arg, void* p_host_kernel_args) const { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp index 660472b7140..1a17454fea6 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp @@ -59,10 +59,10 @@ template , S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, 
CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> @@ -81,10 +81,10 @@ template , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> @@ -106,9 +106,9 @@ using device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_irregular_instances = //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, 
AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> // clang-format on >; @@ -127,9 +127,9 @@ using device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_irregular_instances = //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, 
CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> // clang-format on >; diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp index ae3751d296d..3418f0c69a3 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp @@ -99,7 +99,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances( PassThrough>>>& instances); #endif -#if defined (CK_USE_WMMA) +#if defined(CK_USE_WMMA) void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances( std::vector>>& instances); #endif -#if defined (CK_USE_WMMA) +#if defined(CK_USE_WMMA) void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( std::vector>>& instances); #endif -#if defined (CK_USE_WMMA) +#if defined(CK_USE_WMMA) void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( std::vector> op_ptrs; +#if defined(CK_USE_XDL) + // fp16_output + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + } + + // fp8_input + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instances(op_ptrs); + } + } + + // i8_input + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instances(op_ptrs); + } + } + +// bf16_i8_input +#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(op_ptrs); + } + } +#endif + +// bf16_inputA bf16_inputB +#if defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); + } + } +#endif // CK_ENABLE_BF16 +#endif //CK_USE_XDL + + +#if defined(CK_USE_WMMA) // fp16_output if constexpr(is_same_v && is_same_v && is_same_v) @@ -415,6 +503,7 @@ struct DeviceOperationInstanceFactory< } } #endif // CK_ENABLE_BF16 +#endif // CK_USE_WMMA return op_ptrs; diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp 
b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp index f488c5210be..efcb52d3093 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp @@ -47,6 +47,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, int n_iter = 10) { bool pass = true; + using ComputeDataType = ADataType; auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { @@ -54,11 +55,11 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; @@ -74,11 +75,10 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, std::vector> b_k_n; std::vector> c_m_n_host_results; std::vector> c_m_n_device_results; - // int sum_of_m = 0; + double max_abs_in_val = 0.f; for(std::size_t i = 0; i < group_count; i++) { - // sum_of_m += Ms[i]; a_m_k.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); b_k_n.push_back( @@ -95,17 +95,18 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i << "]:" << c_m_n_device_results[i].mDesc << std::endl; } - std::size_t num_thread = 1; switch(init_method) { case 0: break; case 1: - a_m_k[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b_k_n[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k[i]); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n[i]); + max_abs_in_val = 10.f; break; default: - a_m_k[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - b_k_n[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + ck::utils::FillUniformDistribution{0.0f, 1.0f}(a_m_k[i]); + ck::utils::FillUniformDistribution{-0.5f, 0.5f}(b_k_n[i]); + max_abs_in_val = 1.0f; } } @@ -282,23 +283,18 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, bool instance_pass = true; for(std::size_t i = 0; i < gemm_descs.size(); i++) { - c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); - - if(std::is_same_v && kbatch_curr > 1) - { - instance_pass = - instance_pass && ck::utils::check_err(c_m_n_device_results[i], - c_m_n_host_results[i], - "Error: Incorrect results!", - 0.06); - } - else - { - instance_pass = - instance_pass && ck::utils::check_err(c_m_n_device_results[i], - c_m_n_host_results[i]); - } + auto atol = ck::utils::get_absolute_threshold( + max_abs_in_val, gemm_descs[i].K_); + auto rtol = ck::utils::get_relative_threshold( + gemm_descs[i].K_); + + instance_pass = + instance_pass && ck::utils::check_err(c_m_n_device_results[i], + c_m_n_host_results[i], + "Error: Incorrect results!", + rtol, + atol); if(do_log) { @@ -315,7 +311,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, } } - std::cout << "Instance: " << gemm_name << " verification " + std::cout << "Instance: " << gemm_name << "; KBatch: " << kbatch_curr << " " << (instance_pass ? 
"SUCCEED" : "FAILED") << std::endl; pass = pass && instance_pass; @@ -355,7 +351,8 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, } else { - std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + std::cout << "Instance: " << gemm_name + << ", does not support this GEMM problem (KBatch: " << kbatch_curr << ")" << std::endl; } } diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp index 8fb80b09813..253d179a995 100644 --- a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp @@ -48,11 +48,6 @@ class TestGroupedGemm : public ck::test::TestGroupedGemm, - std::tuple< Row, Col, Row, BF16, I8, BF16>, -#endif #if (defined(CK_USE_XDL) && (defined(__gfx9__) || defined(__gfx12__))) || (defined(CK_USE_WMMA) && defined(__gfx12__)) std::tuple< Row, Row, Row, F16, F8, F16>, @@ -65,6 +60,8 @@ using KernelTypes = ::testing::Types< std::tuple< Row, Row, Row, BF16, BF16, BF16>, std::tuple< Row, Col, Row, BF16, BF16, BF16>, + std::tuple< Row, Row, Row, BF16, I8, BF16>, + std::tuple< Row, Col, Row, BF16, I8, BF16>, std::tuple, std::tuple diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc index bd514de4e83..68647741b25 100644 --- a/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc @@ -6,7 +6,7 @@ TYPED_TEST(TestGroupedGemm, TinyCases) { - const std::vector Ms{2, 1}; + const std::vector Ms{2, 2}; constexpr int N = 512; constexpr int K = 256; @@ -18,7 +18,7 @@ TYPED_TEST(TestGroupedGemm, TinyCases) TYPED_TEST(TestGroupedGemm, SmallCases) { - const std::vector Ms{2, 1, 3, 4, 5}; + const std::vector Ms{2, 2, 2, 2, 2}; constexpr int N = 512; constexpr int K = 256; @@ -30,7 +30,7 @@ TYPED_TEST(TestGroupedGemm, SmallCases) TYPED_TEST(TestGroupedGemm, MidCases) { - const std::vector Ms{167, 183, 177, 153, 139, 204}; + const std::vector Ms{167, 167, 167, 167, 167, 167}; constexpr int N = 512; constexpr int K = 256; @@ -42,7 +42,7 @@ TYPED_TEST(TestGroupedGemm, MidCases) TYPED_TEST(TestGroupedGemm, Regular) { - const std::vector Ms{64, 128, 256}; + const std::vector Ms{64, 64, 64}; constexpr int N = 512; constexpr int K = 256; @@ -55,7 +55,7 @@ TYPED_TEST(TestGroupedGemm, Regular) TYPED_TEST(TestGroupedGemm, MNKPadded) { - const std::vector Ms{127, 150, 188, 210}; + const std::vector Ms{127, 127, 127, 127}; constexpr int N = 128; constexpr int K = 256; @@ -73,7 +73,7 @@ TYPED_TEST(TestGroupedGemm, TestLargeKBatch) GTEST_SKIP() << "Split-K not supported for for the current configuration (FP16/BF16 on " "GFX11, or using CDE element-wise operation)"; - const std::vector Ms{188, 210}; + const std::vector Ms{188, 188}; constexpr int N = 4096; constexpr int K = 4096; diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index ff94d0c1ccb..c4e0ed68439 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -194,7 +194,7 @@ class TestGroupedGemm : public testing::Test } else { - k_batches_ = {1, 2, 3, 4, 8}; + k_batches_ = {1, 2, 3, 4, 8}; } } From 6053dcbb5dc693772238241ddcd4c644e61c2a00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rton=20Bidlek?= Date: Thu, 22 Jan 2026 09:54:28 +0000 Subject: [PATCH 08/16] fixed kernel and examples --- .../grouped_gemm_wmma_fixed_nk_fp16.cpp | 323 ++++- .../grouped_gemm_xdl_fixed_nk_fp16.cpp | 
7 +- .../device_grouped_gemm_wmma_fixed_nk.hpp | 1175 +++++++++++------ .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 1 + .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 64 +- .../test_grouped_gemm_fixed_nk.cpp | 4 +- .../test_grouped_gemm_fixed_nk_cases.inc | 30 +- 7 files changed, 1108 insertions(+), 496 deletions(-) diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp index b35b6463f12..a31596d6fa1 100644 --- a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp @@ -52,7 +52,7 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = PassThrough; -static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNPadding; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Fixed_Nk // clang-format off @@ -60,11 +60,322 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_ //######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| //######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| //######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; - + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; // clang-format on -// #define EXAMPLE_USE_SPLITK -#include "run_grouped_gemm_example.inc" +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_Cs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, 
col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + } + } + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back( + std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + c_tensors_device.emplace_back( + std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + c_tensors_device[i]->SetZero(); + + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {}}); + + grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {}, + problem_size.stride_Cs[i]}); + } + + auto 
a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + // // Copy device tensors back to host + // for(std::size_t i = 0; i < c_device_tensors.size(); i++) + // { + // c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + // c_device_tensors[i].mDesc.GetElementSize() * + // sizeof(EDataType)); + + + // } + // // Print out device and reference results for debugging + // std::cout << "[CK GEMM RESULT TRACE]\n"; + // for(std::size_t i = 0; i < c_device_tensors.size(); i++) + // { + // std::cout << "GEMM[" << i << "] device C:\n"; + // auto& devC = c_device_tensors[i].mData; + // for(std::size_t m = 0; m < static_cast(problem_size.Ms[i]); m++) + // { + // for(std::size_t n = 0; n < static_cast(problem_size.Ns[i]); n++) + // { + // std::cout << static_cast(devC[m * problem_size.Ns[i] + n]) << " "; + // } + // std::cout << "\n"; + // } + + // std::cout << "GEMM[" << i << "] reference C:\n"; + // auto& refC = c_host_tensors[i].mData; + // for(std::size_t m = 0; m < static_cast(problem_size.Ms[i]); m++) + // { + // for(std::size_t n = 0; n < static_cast(problem_size.Ns[i]); n++) + // { + // std::cout << static_cast(refC[m * problem_size.Ns[i] + n]) << " "; + // } + // std::cout << "\n"; + // } + // std::cout << "--------------------------------------\n"; + // } + } + + + + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" 
<< std::endl; + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=n0, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + printf("arg5: group count (default=16)"); + + exit(0); + } + + for(int i = 0; i < problem_size.group_count; i++) + { + problem_size.Ms.push_back(256); + problem_size.Ns.push_back(256); + problem_size.Ks.push_back(256); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } -int main(int argc, char* argv[]) { return !run_grouped_gemm_example(argc, argv); } + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 61f03907b72..d8759eb5461 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -291,6 +291,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co } } + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" 
<< std::endl; return pass; } @@ -329,9 +330,9 @@ int main(int argc, char* argv[]) for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ms.push_back(128 + rand() % 128); - problem_size.Ns.push_back(1024); - problem_size.Ks.push_back(1024); + problem_size.Ms.push_back(256); + problem_size.Ns.push_back(256); + problem_size.Ks.push_back(256); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index 82f7eeba8d9..792c89a1b44 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -29,18 +29,38 @@ namespace device { template + typename Block2ETileMap, + typename GroupedGemmBlock2ETileMap, + typename AElementwiseOperation, + typename BElementwiseOperation, + typename CDEElementwiseOperation, + index_t MinimumOccupancy, + TailNumber TailNum, + index_t MPerBlock, + index_t NPerBlock, + index_t KPerBlock, + GemmSpecialization GemmSpec> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif kernel_grouped_gemm_wmma_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, - const index_t group_count) + + const index_t group_count, + const index_t grid_size_grp, + const index_t k_batch_, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) { #if(defined(__gfx11__) || defined(__gfx12__)) + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< typename GridwiseGemm::EpilogueCShuffle>(); __shared__ char p_shared[LDS_size]; @@ -49,67 +69,190 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const auto gemm_desc_ptr = reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); - // Binary search lookup to find which group this block is part of - index_t left = 0; - index_t right = group_count; - index_t group_id = index_t((left + right) / 2); - while((!(block_id >= gemm_desc_ptr[group_id].block_start_ && - block_id < gemm_desc_ptr[group_id].block_end_)) && - left <= right) - { - if(block_id < gemm_desc_ptr[group_id].block_start_) - { - right = group_id; - } - else + + const index_t group_id = block_id / grid_size_grp; + if(group_id >= group_count) + return; + const index_t group_start = group_id * grid_size_grp; + + + const index_t M = gemm_desc_ptr[group_id].M; + const index_t N = gemm_desc_ptr[group_id].N; + const index_t K = gemm_desc_ptr[group_id].K; + + if(M == 0 || N == 0 || K == 0) + return; + + + const auto StrideE = gemm_desc_ptr[group_id].StrideE; + // const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + // const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + M, N, StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + const auto local_grid_size = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + constexpr auto NumDTensor = DsDataType::Size(); + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + + DsGridPointer p_ds_grid_; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + // D pointer + p_ds_grid_(i) = 
static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); + }); + + + + +// #if defined(__gfx11__) +// // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions +// using c_data_type = remove_cvref_t>; +// if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && +// (std::is_same_v || +// std::is_same_v))) +// { +// #endif + + + auto epilogue_args = + typename GridwiseGemm::EpilogueCShuffle{}; + + const auto& desc = gemm_desc_ptr[group_id]; + const typename GridwiseGemm::Problem problem{ + desc.M, + desc.N, + desc.K, + std::array{desc.StrideA}, + std::array{desc.StrideB}, + desc.StrideDs, + desc.StrideE, + k_batch_ + }; + + using AsGridPointer = typename GridwiseGemm::AsGridPointer; + using ADataType0 = remove_cvref_t>; + + AsGridPointer p_as_grid_ = make_tuple( + static_cast(gemm_desc_ptr[group_id].p_a_grid) + ); + using BsGridPointer = typename GridwiseGemm::BsGridPointer; + using BDataType0 = remove_cvref_t>; + + BsGridPointer p_bs_grid_ = make_tuple( + static_cast(gemm_desc_ptr[group_id].p_b_grid) + ); + + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - group_start; + + while(id_local < local_grid_size) { - left = group_id; + + // if(threadIdx.x == 0) + // { + // printf( + // "\n[CK GEMM TRACE]\n" + // " id_local = %d\n" + // " local_grid_size = %d\n", + // int(id_local), + // int(local_grid_size) + // ); + + // } + + const auto block_2_etile_map = + GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_start, id_off); + + // auto tile_idx = block_2_etile_map.CalculateBottomIndex(make_multi_index(id_local)); + + // const index_t m_tile_idx = tile_idx[Number<0>{}]; + // const index_t n_tile_idx = tile_idx[Number<1>{}]; + // const index_t k_tile_idx = tile_idx[Number<2>{}]; + + // calculate ranges for each dimension + // const index_t m_start = m_tile_idx * MPerBlock; + // const index_t m_end = min(m_start + MPerBlock, M); + + // const index_t n_start = n_tile_idx * NPerBlock; + // const index_t n_end = min(n_start + NPerBlock, N); + + // const index_t k_start = k_tile_idx * KPerBlock; + // const index_t k_end = min(k_start + KPerBlock, K); + + // if(threadIdx.x == 0) + // { + // printf("[CK GEMM TRACE] grid_size=%d, group_id=%d, block_id=%d, " + // "m_tile=%d, n_tile=%d, k_tile=%d, " + // "M_range=[%d,%d), N_range=[%d,%d), K_range=[%d,%d)\n", + // int(local_grid_size), + // int(group_id), + // int(get_block_1d_id()), + // int(m_tile_idx), + // int(n_tile_idx), + // int(k_tile_idx), + // int(m_start), + // int(m_end), + // int(n_start), + // int(n_end), + // int(k_start), + // int(k_end)); + // } + + + + GridwiseGemm::template Run, + typename GridwiseGemm::EpilogueCShuffle, + 1, + 2> + (p_as_grid_, + p_bs_grid_, + p_ds_grid_, + static_cast(gemm_desc_ptr[group_id].p_e_grid), + static_cast(p_shared), + problem, + block_2_etile_map, + a_element_op, + b_element_op, + c_element_op, + epilogue_args); + + // if(threadIdx.x == 0) + // { + // printf( + // "\n[CK GEMM TRACE]\n" + // " id_local = %d\n" + // " local_grid_size = %d\n", + // int(id_local), + // int(local_grid_size)); + // } + id_off += grid_size_grp; + id_local += grid_size_grp; } - group_id = index_t((left + right) / 2); - } - // NOTE: Local copy of the arg struct since SplitKBatchOffset verifies and modifies K index - // and thus needs a non-const reference. 
It's also not feasible to store this in global - // memory as different threads would be writing different K values to the same arg struct - auto karg = gemm_desc_ptr[group_id].karg_; - -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; - if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - const auto& block_2_ctile_map = gemm_desc_ptr[group_id].block_2_ctile_map_; - - // Tile index first dimension is the K batch - auto tile_index = - block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); - - auto splitk_batch_offset = - typename GridwiseGemm::SplitKBatchOffset(karg, tile_index[Number<0>{}]); - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - - GridwiseGemm::template Run(static_cast(p_shared), - splitk_batch_offset, - karg, - block_2_ctile_map, - epilogue_args); -#if defined(__gfx11__) - } -#endif +#undef TRACE_THREAD +// #if defined(__gfx11__) +// } +// #endif #else ignore = gemm_descs_const; ignore = group_count; -#endif // end of if(defined(__gfx11__) || defined(__gfx12__)) + ignore = grid_size_grp; + ignore = k_batch_; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif } template ; using CGridDesc_M_N = - remove_cvref_t( - 1, 1, 1, 1, 1))>; + remove_cvref_t( + 1, 1, 1))>; - using Block2ETileMap = BlockToCTileMap_KSplit_M00_N0_M01Adapt; - static constexpr index_t B2E_M01 = 8; - using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap; - static constexpr index_t DefaultKBatch = 1; - using KernelArgument = typename GridwiseGemm::Argument; - template - struct GemmTransKernelArgBase + template + struct OffsettedBlockToCTileMapMLoops { - KernelArgument_ karg_; - GroupedGemmBlock2ETileMap block_2_ctile_map_; - index_t block_start_, block_end_; - - GemmTransKernelArgBase() = default; - GemmTransKernelArgBase(KernelArgument_&& karg, - GroupedGemmBlock2ETileMap&& b2c_map, - index_t block_start, - index_t block_end) - : karg_{karg}, - block_2_ctile_map_{b2c_map}, - block_start_{block_start}, - block_end_{block_end} + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; }; - using GemmTransKernelArg = GemmTransKernelArgBase; + + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { 
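+        // Sketch of the mapping implemented below: a 1-D block id is first wrapped
+        // modulo M0 * N0 * KBatch so that a fixed-size per-group grid can loop over
+        // all tiles of its group, then decomposed into a k-split index and an
+        // M01-swizzled (m-tile, n-tile) pair; CalculateBottomIndex() returns these
+        // as a (k, m, n) tuple.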
+ static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N&) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + const auto total_tiles_per_group = M0 * N0 * KBatch_; + + // #if defined(__HIP_DEVICE_COMPILE__) + // if(threadIdx.x == 0) + // { + // printf( + // "\n[CK TILE MAP TRACE]\n" + // " raw block_1d_id = %d\n" + // " M = %d\n" + // " N = %d\n" + // " MPerBlock = %d\n" + // " NPerBlock = %d\n" + // " M0 (tiles) = %d\n" + // " N0 (tiles) = %d\n" + // " KBatch = %d\n" + // " tiles/group = %d\n", + // int(block_1d_id), + // int(M_), + // int(N_), + // int(MPerBlock_), + // int(NPerBlock_), + // int(M0), + // int(N0), + // int(KBatch_), + // int(total_tiles_per_group)); + // } + // #endif + + // wrap block id into this group + block_1d_id = block_1d_id % total_tiles_per_group; + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = + (idx_M0 < M0 - M0 % M01_) ? 
M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + // #if defined(__HIP_DEVICE_COMPILE__) + // if(threadIdx.x == 0) + // { + // printf( + // " wrapped block_id = %d\n" + // " idx_ksplit = %d\n" + // " idx_M0 = %d\n" + // " idx_N0 = %d\n" + // " M01 = %d\n" + // " M01_adapt = %d\n" + // " idx_M00 = %d\n" + // " idx_M01 = %d\n" + // " idx_N0_M01_local = %d\n" + // " --> m_tile = %d\n" + // " --> n_tile = %d\n" + // " --> k_tile = %d\n" + // "\n", + // int(block_1d_id), + // int(idx_ksplit), + // int(idx_M0), + // int(idx_N0), + // int(M01_), + // int(M01_adapt), + // int(idx_M00), + // int(idx_M01), + // int(idx_N0_M01_local), + // int(idx_N0_M01_local % M01_adapt + idx_M00 * M01_), + // int(idx_N0_M01_local / M01_adapt), + // int(idx_ksplit)); + // } + // #endif + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + static constexpr index_t DefaultKBatch = 1; + using KernelArgument = typename GridwiseGemm::Argument; + + + using GemmTransKernelArg = GroupedGemmKernelArgument; static constexpr bool CalculateHasMainKBlockLoop(const KernelArgument& karg) { @@ -268,379 +598,355 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK& p_As, - std::vector& p_Bs, - std::vector>& p_Ds, - std::vector& p_Es, - std::vector& gemm_descs, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CDEElementwiseOperation c_element_op) - : Argument(p_As, - p_Bs, - p_Ds, - p_Es, - gemm_descs, - a_element_op, - b_element_op, - c_element_op, - DefaultKBatch) + void UpdateKBatch(index_t k_batch) { - // TODO: use occupancy api to calculate appropriate batch size. + k_batch_ = k_batch; + + if(k_batch_ < 1) + { + throw std::runtime_error("wrong! 
k_batch must be > 0"); + } + + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + + const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE; + const index_t N = gemm_desc_kernel_arg_[0].N; + + // const index_t m_padded = GridwiseGemm::CalculateMPadded(AverM); + // const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + grid_size_ = grid_size_grp_ * group_count_; } - Argument(std::vector& p_As, - std::vector& p_Bs, - std::vector>& p_Ds, - std::vector& p_Es, + Argument(std::vector&, + std::vector&, + std::vector>&, + std::vector&, std::vector& gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CDEElementwiseOperation c_element_op, - index_t kbatch) - : group_count_{ck::type_convert(gemm_descs.size())}, - gemm_kernel_host_args_{nullptr}, - grid_size_{0}, - k_batch_{kbatch} + CDEElementwiseOperation c_element_op) + : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} { - if(!(group_count_ == ck::type_convert(p_As.size()) && - group_count_ == ck::type_convert(p_Bs.size()) && - ((NumDTensor == 0 && p_Ds.size() == 0) || - group_count_ == ck::type_convert(p_Ds.size())) && - group_count_ == ck::type_convert(p_Es.size()))) - { - throw std::runtime_error("wrong! group_count_ != p_As/b/d/e.size"); - } + grid_size_ = 0; + + k_batch_ = 1; + + grouped_gemm_kernel_args_dev = nullptr; + + group_count_ = ck::type_convert(gemm_descs.size()); gemm_desc_kernel_arg_.reserve(group_count_); - const index_t fixed_N = gemm_descs[0].N_; - const index_t fixed_K = gemm_descs[0].K_; + index_t group_id = 0; - for(std::size_t i = 0; i < gemm_descs.size(); ++i) - { - const index_t M = gemm_descs[i].M_; - const index_t N = gemm_descs[i].N_; - const index_t K = gemm_descs[i].K_; + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + const index_t N = gemm_descs[0].N_; + const index_t K = gemm_descs[0].K_; - if(N != fixed_N || K != fixed_K) + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_) { - throw std::runtime_error("wrong! N or K are not fixed across GEMM groups"); + throw std::runtime_error("wrong! M/N/K is not identical"); } + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + const index_t StrideA = gemm_descs[i].stride_A_; const index_t StrideB = gemm_descs[i].stride_B_; const index_t StrideE = gemm_descs[i].stride_C_; - const auto& stride_d_vec = gemm_descs[i].stride_Ds_; - if(!(NumDTensor == ck::type_convert(stride_d_vec.size()))) - { - throw std::runtime_error("wrong! stride D mismatch"); - } + // pointer + std::array p_ds_grid; + + static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); std::array StrideDs; - if constexpr(NumDTensor > 0) - { - std::copy(stride_d_vec.begin(), stride_d_vec.end(), StrideDs); - } - const index_t m_padded = GridwiseGemm::CalculateMPadded(M); - const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + static_for<0, NumDTensor, 1>{}([&](auto j) { + // using DLayout = remove_cvref_t>; + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! 
gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + StrideDs[j] = gemm_descs[i].stride_Ds_[j]; + }); + // const index_t m_padded = GridwiseGemm::CalculateMPadded(AverM); + // const index_t n_padded = GridwiseGemm::CalculateNPadded(N); const auto e_grid_desc_m_n = - GridwiseGemm::template MakeDEGridDescriptor_M_N( - M, m_padded, N, n_padded, StrideE); + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); // block-to-e-tile map - const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, B2E_M01, k_batch_}; + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; - const index_t grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); - if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + if(group_id * grid_size_grp_ != grid_size_) { - throw std::runtime_error("wrong! block_2_etile_map validation failed"); + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); } - const index_t block_start = grid_size_; - const index_t block_end = grid_size_ + grid_size_grp_; - grid_size_ += grid_size_grp_; - auto grouped_block_2_ctile_map = - GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); - - auto karg = KernelArgument(std::array{p_As[i]}, - std::array{p_Bs[i]}, - p_Ds[i], - type_convert(p_Es[i]), - M, - N, - K, - std::array{StrideA}, - std::array{StrideB}, - StrideDs, - StrideE, - k_batch_, - a_element_op, - b_element_op, - c_element_op, - false); - - gemm_desc_kernel_arg_.emplace_back( - std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end); - } - } - - /** - * @brief Recalculate group grid size for all gemms and update B2C maps. - * - * @param[in] k_batch The new splitK parameter value. - */ - void UpdateKBatch(index_t k_batch) - { - k_batch_ = k_batch; - grid_size_ = 0; - - if(k_batch_ < 1) - { - throw std::runtime_error("wrong! k_batch must be > 0"); - } - - for(std::size_t i = 0; i < gemm_desc_kernel_arg_.size(); ++i) - { - auto& karg = gemm_desc_kernel_arg_[i].karg_; - - const index_t k_read = GridwiseGemm::CalculateKRead(karg.K, k_batch_); - const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, k_batch_); - const index_t ak0_padded = GridwiseGemm::CalculateAK0Padded(karg.K, k_batch_); - const index_t bk0_padded = GridwiseGemm::CalculateBK0Padded(karg.K, k_batch_); - - const auto c_grid_desc_m_n = - GridwiseGemm::template MakeDEGridDescriptor_M_N( - karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideE); - - const auto local_b2c_tile_map = Block2ETileMap{c_grid_desc_m_n, B2E_M01, k_batch_}; - const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n); - - const index_t block_start = grid_size_; - const index_t block_end = grid_size_ + grid_size_grp; - - grid_size_ += grid_size_grp; - - auto grouped_block_2_ctile_map = - GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start); + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! 
block_2_etile_map validation failed"); + } - karg.KRead = k_read; - karg.KPadded = k_padded; - karg.AK0 = ak0_padded; - karg.BK0 = bk0_padded; - karg.KBatch = k_batch_; - gemm_desc_kernel_arg_[i].block_2_ctile_map_ = grouped_block_2_ctile_map; - gemm_desc_kernel_arg_[i].block_start_ = block_start; - gemm_desc_kernel_arg_[i].block_end_ = block_end; + // if(!GridwiseGemm::CheckValidity(arg)) + // { + // std::ostringstream err; + // err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + // << ":" << __LINE__ << ", in function: " << __func__; + // throw std::runtime_error(err.str()); + // } + + gemm_desc_kernel_arg_.push_back(GemmTransKernelArg{ + nullptr, + nullptr, + p_ds_grid, + nullptr, + AverM, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + }); + + group_id++; } + // const index_t sum_of_m_padded = GridwiseGemm::CalculateMPadded(sum_of_m); + // const index_t n_padded = GridwiseGemm::CalculateNPadded(gemm_desc_kernel_arg_[0].N); + const auto e_grid_desc_sum_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + sum_of_m, gemm_desc_kernel_arg_[0].N, + gemm_desc_kernel_arg_[0].StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, k_batch_}; + barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); } // private: index_t group_count_; + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; - void* gemm_kernel_host_args_; index_t grid_size_; + index_t grid_size_grp_; + index_t barrier_size_grp_; + index_t sum_of_m; index_t k_batch_; }; - // Invoker + +// Invoker struct Invoker : public BaseInvoker { - float Run(const Argument& arg, - const StreamConfig& stream_config = StreamConfig{}, - hipStream_t cpy_stream = nullptr, - hipEvent_t cpy_event = nullptr) - { - using GemmTransKernelArg_ = GemmTransKernelArgBase; - static_assert(sizeof(GemmTransKernelArg_) == sizeof(GemmTransKernelArg)); - - bool all_have_kbatch_gt_one = arg.gemm_desc_kernel_arg_[0].karg_.KBatch > 1; - bool all_have_main_k0_block_loop = - CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[0].karg_); + using Argument = DeviceOp::Argument; - bool not_all_have_main_k0_block_loop_same = false; - bool not_all_have_kbatch_value_same = false; + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + constexpr bool has_main_k_block_loop = true; - for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); ++i) + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) { - const auto& karg = reinterpret_cast( - arg.gemm_desc_kernel_arg_[i].karg_); - - if(stream_config.log_level_ > 0) - { - karg.Print(); - } - - auto kbatch = karg.KBatch; + const auto KPad = + GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K, arg.k_batch_); - if(!GridwiseGemm::CheckValidity(karg)) + if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop) { - std::ostringstream err; - err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ - << ":" << __LINE__ << ", in function: " << __func__; - throw std::runtime_error(err.str()); + throw std::runtime_error("wrong! 
not all gemms have the same has_main_k_block_loop");
                }
            }

            if(arg.grouped_gemm_kernel_args_dev == nullptr)
            {
                throw std::runtime_error("wrong! 
grouped_gemm_kernel_args_dev is nullptr");
            }

            float ave_time = 0;

-            const auto Run = [&](const auto& kernel) {
-                if(all_have_kbatch_gt_one)
-                {
-                    for(const auto& trans_arg : arg.gemm_desc_kernel_arg_)
-                    {
+            auto launch_kernel = [&](auto has_main_k_block_loop_,
+                                     auto e_global_memory_operation_,
+                                     auto min_occupancy_,
+                                     auto tail_num_) {

-                        const auto& karg = trans_arg.karg_;
-                        hip_check_error(hipMemsetAsync(karg.p_e_grid,
-                                                       0,
-                                                       karg.M * karg.N * sizeof(EDataType),
-                                                       stream_config.stream_id_));
-                    }
+                if(arg.k_batch_ == 1)
+                {
+                    const auto kernel =
+                        kernel_grouped_gemm_wmma_fixed_nk,
+                                        Tuple,
+                                        DsDataType,
+                                        EDataType,
+                                        e_global_memory_operation_,
+                                        Block2ETileMap,
+                                        GroupedGemmBlock2ETileMap,
+                                        AElementwiseOperation,
+                                        BElementwiseOperation,
+                                        CDEElementwiseOperation,
+                                        min_occupancy_,
+                                        tail_num_,
+                                        MPerBlock,
+                                        NPerBlock,
+                                        KPerBlock,
+                                        GemmSpec>;
+
+                    return launch_and_time_kernel(stream_config,
+                                        kernel,
+                                        dim3(arg.grid_size_),
+                                        dim3(BlockSize),
+                                        0,
+                                        cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
+                                        arg.gemm_desc_kernel_arg_.size(),
+                                        arg.grid_size_grp_,
+                                        arg.k_batch_,
+                                        arg.a_element_op_,
+                                        arg.b_element_op_,
+                                        arg.c_element_op_);
                }
-
-                ave_time =
-                    launch_and_time_kernel(stream_config,
+                else
+                {
+                    const auto kernel =
+                        kernel_grouped_gemm_wmma_fixed_nk,
+                                        Tuple,
+                                        DsDataType,
+                                        EDataType,
+                                        e_global_memory_operation_,
+                                        Block2ETileMap,
+                                        GroupedGemmBlock2ETileMap,
+                                        AElementwiseOperation,
+                                        BElementwiseOperation,
+                                        CDEElementwiseOperation,
+                                        min_occupancy_,
+                                        tail_num_,
+                                        MPerBlock,
+                                        NPerBlock,
+                                        KPerBlock,
+                                        GemmSpec>;
+
+                    return launch_and_time_kernel(stream_config,
                                            kernel,
                                            dim3(arg.grid_size_),
                                            dim3(BlockSize),
                                            0,
-                                           cast_pointer_to_constant_address_space(arg.p_workspace_),
-                                           arg.gemm_desc_kernel_arg_.size());
+                                           cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev),
+                                           arg.gemm_desc_kernel_arg_.size(),
+                                           arg.grid_size_grp_,
+                                           arg.k_batch_,
+                                           arg.a_element_op_,
+                                           arg.b_element_op_,
+                                           arg.c_element_op_);
+                }
            };

-            // NOTE: If at least one gemm problem has a main k0 block loop, we include it for all
-            if(all_have_main_k0_block_loop || not_all_have_main_k0_block_loop_same)
-            {
-                // Tail number always full
-                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
-                             BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
-                {
-                    if(all_have_kbatch_gt_one)
-                    {
-                        const auto kernel =
-                            kernel_grouped_gemm_wmma_fixed_nk;

-                        Run(kernel);
-                    }
-                    else
-                    {
-                        const auto kernel =
-                            kernel_grouped_gemm_wmma_fixed_nk;
-                        Run(kernel);
-                    }
-                }
            }
            else
            {
-                // Tail number always 1
-                if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
                {
-                    if(all_have_kbatch_gt_one)
-                    {
-                        const auto kernel =
-                            kernel_grouped_gemm_wmma_fixed_nk;
-
-                        Run(kernel);
-                    }
-                    else
-                    {
-                        const auto kernel =
-                            kernel_grouped_gemm_wmma_fixed_nk;
-
-                        Run(kernel);
-                    }
                }
            }
+            const auto tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(arg.gemm_desc_kernel_arg_[0].K);
+            constexpr index_t min_occupancy = 1;

+            if constexpr(std::is_same::value)
+            {
+                SelectTailNumber(tail_num, [&](auto tail_num_ct) {
+                    ave_time = launch_kernel(
+                        std::integral_constant{},
+                        std::integral_constant{},
+                        std::integral_constant{},
+                        tail_num_ct);
+                });
            }
            else
            {
+                if(arg.k_batch_ > 1)
                {
+                    SelectTailNumber(tail_num, [&](auto tail_num_ct) {
+                        ave_time = launch_kernel(
+                            std::integral_constant{},
+                            std::integral_constant{},
+                            std::integral_constant{},
+                            tail_num_ct);
+                    });
+                }
+                else
+                {
+                    SelectTailNumber(tail_num, [&](auto tail_num_ct) {
+                        ave_time = launch_kernel(
+                            std::integral_constant{},
+ 
std::integral_constant{}, + std::integral_constant{}, + tail_num_ct); + }); } } + + + return ave_time; } + template + void SelectTailNumber(TailNumber tail_num, Lambda&& lambda) + { + switch(tail_num) + { + case TailNumber::Full: lambda(std::integral_constant{}); break; + case TailNumber::Empty: lambda(std::integral_constant{}); break; + case TailNumber::One: lambda(std::integral_constant{}); break; + case TailNumber::Two: lambda(std::integral_constant{}); break; + case TailNumber::Three: lambda(std::integral_constant{}); break; + case TailNumber::Four: lambda(std::integral_constant{}); break; + case TailNumber::Five: lambda(std::integral_constant{}); break; + case TailNumber::Six: lambda(std::integral_constant{}); break; + case TailNumber::Seven: lambda(std::integral_constant{}); break; + case TailNumber::Odd: lambda(std::integral_constant{}); break; + case TailNumber::Even: lambda(std::integral_constant{}); break; + default: lambda(std::integral_constant{}); break;; + } + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + return RunImp(arg, stream_config); + } + // polymorphic float Run(const BaseArgument* p_arg, const StreamConfig& stream_config = StreamConfig{}) override @@ -651,66 +957,35 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) { return false; } - if constexpr(std::is_same_v || - std::is_same_v) - { - if(arg.k_batch_ > 1 && ck::is_gfx11_supported()) - { - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - return false; - } - } - if constexpr(!std::is_same_v) + bool supported = true; + if constexpr(GemmSpec != GemmSpecialization::Default) { - if(arg.k_batch_ > 1) - { - // Using SplitK and a C element op would require a two stage kernel where the second - // stage applies the op on the accumulated results - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "C element operators are not supported when using SplitK. Set " - "K_BATCH to 1 or remove the operator." - << std::endl; - } - return false; - } - } + const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0; + const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0; - if constexpr(std::is_same_v || std::is_same_v || - std::is_same_v || std::is_same_v) - { - if(ck::is_gfx11_supported()) + for(index_t i = 0; i < arg.group_count_; ++i) { - return false; + const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number{}); + const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number{}); + + supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0); + supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0); } } - - bool supported = true; - for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); ++i) + if constexpr(std::is_same::value) { - const auto& a = arg.gemm_desc_kernel_arg_[i].karg_; - bool group_arg_valid = GridwiseGemm::CheckValidity(a); - - if(not group_arg_valid) - { - if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) - { - std::cout << "[" << __func__ << "] group id: " << i - << " has invalid GridwiseGemm settings!" 
<< std::endl; - a.Print(); - } - } - supported = supported && group_arg_valid; + supported = supported & (arg.k_batch_ == 1); } + return supported; } - // polymorphic + + bool IsSupportedArgument(const BaseArgument* p_arg) override { return IsSupportedArgument(*dynamic_cast(p_arg)); @@ -779,35 +1054,65 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK(p_arg); + if(arg_ptr) + { + arg_ptr->p_workspace_ = p_workspace; + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_NK::Argument structure!"); + + hip_check_error( + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_)); + } + void SetDeviceKernelArgs(BaseArgument* p_arg, void* kernel_args) const override { - return this->SetWorkSpacePointer(p_arg, kernel_args); + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->grouped_gemm_kernel_args_dev = kernel_args; + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_NK::Argument structure!"); } size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override { - auto p_arg_ = dynamic_cast(p_arg); - if(p_arg_) + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) { - return p_arg_->gemm_desc_kernel_arg_.size() * sizeof(GemmTransKernelArg); + return arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t); } else throw std::runtime_error("The argument pointer is not an object of " - "DeviceGroupedGemm_Wmma_CShuffleV3::Argument structure!"); + "DeviceGroupedGemm_Wmma_Fixed_NK::Argument structure!"); } size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override { - return GetWorkSpaceSize(p_arg); + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + return arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_NK::Argument structure!"); } - size_t GetHostKernelArgSize(const BaseArgument* p_arg) const { return GetWorkSpaceSize(p_arg); } - static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } // polymorphic - void SetKBatchSize(BaseArgument* p_arg, index_t k_batch) const override + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override { auto arg_ptr = dynamic_cast(p_arg); if(arg_ptr) @@ -816,22 +1121,20 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK(p_arg); - if(!pArg_) + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) { - throw std::runtime_error("Failed to cast argument pointer!"); + arg_ptr->UpdateKBatch(k_batch); } - - pArg_->gemm_kernel_host_args_ = p_host_kernel_args; - std::copy(pArg_->gemm_desc_kernel_arg_.begin(), - pArg_->gemm_desc_kernel_arg_.end(), - static_cast(pArg_->gemm_kernel_host_args_)); + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_Nk::Argument structure!"); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index a1cba118b28..70a85f33a91 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -318,6 +318,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 using Base::MakeAsGridDescriptor_AK0_M_AK1; using 
Base::MakeBsGridDescriptor_BK0_N_BK1; using Base::MakeDEGridDescriptor_M_N; + using Base::MakeEGridDescriptor_M_N; using Base::MakeDsGridDescriptor_M_N; using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 6efcfbaa955..3c8c07c816e 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -344,6 +344,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base using ThisThreadBlock = ThisThreadBlock; + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) return 2; @@ -629,7 +631,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base const std::array& StrideAs, const index_t AK0) { - using GemmSpecialization = tensor_operation::device::GemmSpecialization; + // using GemmSpecialization = tensor_operation::device::GemmSpecialization; constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding || GemmSpec == GemmSpecialization::MPadding || @@ -698,7 +700,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base const std::array& StrideBs, const index_t BK0) { - using GemmSpecialization = tensor_operation::device::GemmSpecialization; + // using GemmSpecialization = tensor_operation::device::GemmSpecialization; constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || GemmSpec == GemmSpecialization::NPadding || @@ -772,6 +774,30 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return BTransfer::template MakeWmmaTileDescriptor(); } + + template + __host__ __device__ static auto + MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) + { + constexpr auto matrix_padder = + ck::tensor_operation::device::MatrixPadder{ + MPerBlock, NPerBlock, KPerBlock}; + const auto e_grid_desc_mraw_nraw = [&]() { + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(StrideE, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), + make_tuple(I1, StrideE)); + } + }(); + + return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); + } + template __host__ __device__ static auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE) @@ -795,8 +821,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base make_tuple(Sequence<0>{}, Sequence<1>{})); // TODO: Investigate why this path is not used in the original // gridwise_gemm_xdl_cshuffle_v3.hpp -#if 0 - using GemmSpecialization = tensor_operation::device::GemmSpecialization; + #if 0 + // using GemmSpecialization = tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::MNPadding || GemmSpec == GemmSpecialization::MNKPadding) @@ -833,7 +859,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // not pad M or N return c_grid_desc_mraw_nraw; } -#endif + #endif } static constexpr auto MakeDsGridPointer() @@ -1043,34 +1069,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) && (NPerBlock % (NPerWmma * NRepeat)) == 0, "Invalid tuning param!"); - // if (!(MPerBlock % (MPerWmma * MRepeat) == 0 && NPerBlock % (NPerWmma * NRepeat) == 0)) - // { - // std::cout << 
"[DEBUG] Invalid tuning param!\n" - // << " MPerBlock: " << MPerBlock << "\n" - // << " NPerBlock: " << NPerBlock << "\n" - // << " MPerWmma : " << MPerWmma << "\n" - // << " NPerWmma : " << NPerWmma << "\n" - // << " MRepeat : " << MRepeat << "\n" - // << " NRepeat : " << NRepeat << "\n" - // << " Check: MPerBlock % (MPerWmma * MRepeat) == " - // << (MPerBlock % (MPerWmma * MRepeat)) << "\n" - // << " NPerBlock % (NPerWmma * NRepeat) == " - // << (NPerBlock % (NPerWmma * NRepeat)) << "\n"; - // } - // std::cout - // << "[CK_CHECK] " - // << "M=" << karg.M - // << " N=" << karg.N - // << " K=" << karg.K - // << " KBatch=" << karg.KBatch - // << " | MPerBlock=" << MPerBlock - // << " NPerBlock=" << NPerBlock - // << " KPerBlock=" << KPerBlock - // << " | MRepeat=" << MRepeat - // << " NRepeat=" << NRepeat - // << " | AK1=" << AK1Number - // << " BK1=" << BK1Number - // << std::endl; if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding || diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp index 253d179a995..8f49a25ee3d 100644 --- a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp @@ -63,8 +63,8 @@ using KernelTypes = ::testing::Types< std::tuple< Row, Row, Row, BF16, I8, BF16>, std::tuple< Row, Col, Row, BF16, I8, BF16>, - std::tuple, - std::tuple + std::tuple< Row, Row, Row, F16, I8, F16>, + std::tuple< Row, Col, Row, F16, I8, F16> >; // clang-format on diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc index 68647741b25..af39b6b0520 100644 --- a/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc @@ -3,12 +3,11 @@ #pragma once - TYPED_TEST(TestGroupedGemm, TinyCases) { const std::vector Ms{2, 2}; - constexpr int N = 512; - constexpr int K = 256; + constexpr int N = 768; + constexpr int K = 544; const std::vector Ns(Ms.size(), N); const std::vector Ks(Ms.size(), K); @@ -19,8 +18,8 @@ TYPED_TEST(TestGroupedGemm, TinyCases) TYPED_TEST(TestGroupedGemm, SmallCases) { const std::vector Ms{2, 2, 2, 2, 2}; - constexpr int N = 512; - constexpr int K = 256; + constexpr int N = 768; + constexpr int K = 544; const std::vector Ns(Ms.size(), N); const std::vector Ks(Ms.size(), K); @@ -30,9 +29,9 @@ TYPED_TEST(TestGroupedGemm, SmallCases) TYPED_TEST(TestGroupedGemm, MidCases) { - const std::vector Ms{167, 167, 167, 167, 167, 167}; - constexpr int N = 512; - constexpr int K = 256; + const std::vector Ms{128, 128, 128, 128, 128, 128}; + constexpr int N = 768; + constexpr int K = 544; const std::vector Ns(Ms.size(), N); const std::vector Ks(Ms.size(), K); @@ -42,9 +41,9 @@ TYPED_TEST(TestGroupedGemm, MidCases) TYPED_TEST(TestGroupedGemm, Regular) { - const std::vector Ms{64, 64, 64}; - constexpr int N = 512; - constexpr int K = 256; + const std::vector Ms{128, 128, 128}; + constexpr int N = 768; + constexpr int K = 320; const std::vector Ns(Ms.size(), N); const std::vector Ks(Ms.size(), K); @@ -52,12 +51,11 @@ TYPED_TEST(TestGroupedGemm, Regular) this->Run(Ms, Ns, Ks); } - TYPED_TEST(TestGroupedGemm, MNKPadded) { - const std::vector Ms{127, 127, 127, 127}; - constexpr int N = 128; - constexpr int K = 256; + const std::vector Ms{188, 188, 188, 188}; + constexpr int N = 136; + constexpr int K = 280; const std::vector Ns(Ms.size(), N); const std::vector 
Ks(Ms.size(), K); @@ -74,7 +72,7 @@ TYPED_TEST(TestGroupedGemm, TestLargeKBatch) "GFX11, or using CDE element-wise operation)"; const std::vector Ms{188, 188}; - constexpr int N = 4096; + constexpr int N = 768; constexpr int K = 4096; const std::vector Ns(Ms.size(), N); From b5a9aa56abb843a566ebba678e96b6aff1aee765 Mon Sep 17 00:00:00 2001 From: Marton Bidlek Date: Thu, 22 Jan 2026 14:38:26 +0000 Subject: [PATCH 09/16] changing test input and kernel tuple type --- .../device_grouped_gemm_wmma_fixed_nk.hpp | 2 +- .../profile_grouped_gemm_fixed_nk_impl.hpp | 5 ++-- .../test_grouped_gemm_fixed_nk.cpp | 27 ++++++++----------- .../test_grouped_gemm_fixed_nk_cases.inc | 12 ++++----- 4 files changed, 21 insertions(+), 25 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index 792c89a1b44..af826abd60f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -364,7 +364,7 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK, + Sequence, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp index efcb52d3093..d4869591980 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp @@ -76,9 +76,10 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, std::vector> c_m_n_host_results; std::vector> c_m_n_device_results; double max_abs_in_val = 0.f; - + int sum_of_m = 0; for(std::size_t i = 0; i < group_count; i++) { + sum_of_m += Ms[i]; a_m_k.push_back( Tensor(f_host_tensor_descriptor(Ms[i], Ks[i], StrideAs[i], ALayout{}))); b_k_n.push_back( @@ -151,7 +152,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, a_device_buf[i]->ToDevice(a_m_k[i].mData.data()); b_device_buf[i]->ToDevice(b_k_n[i].mData.data()); - gemm_descs.push_back({Ms[i], Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); + gemm_descs.push_back({sum_of_m, Ns[i], Ks[i], StrideAs[i], StrideBs[i], StrideCs[i], {}}); p_a.push_back(a_device_buf[i]->GetDeviceBuffer()); p_b.push_back(b_device_buf[i]->GetDeviceBuffer()); diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp index 8f49a25ee3d..d9ee6797f11 100644 --- a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp @@ -5,9 +5,8 @@ #include #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" #include "gtest/gtest.h" #include "test_grouped_gemm_util.hpp" @@ -20,10 +19,6 @@ using BF16 = ck::bhalf_t; using F8 = ck::f8_t; using I8 = int8_t; -using AElementOp = ck::tensor_operation::element_wise::PassThrough; -using BElementOp = ck::tensor_operation::element_wise::PassThrough; -using CDEElementOp = ck::tensor_operation::element_wise::PassThrough; - using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; @@ -50,21 +45,21 @@ using KernelTypes = ::testing::Types< #if (defined(CK_USE_XDL) 
&& (defined(__gfx9__) || defined(__gfx12__))) || (defined(CK_USE_WMMA) && defined(__gfx12__)) - std::tuple< Row, Row, Row, F16, F8, F16>, - std::tuple< Row, Col, Row, F16, F8, F16>, + ck::Tuple< Row, Row, Row, F16, F8, F16>, + ck::Tuple< Row, Col, Row, F16, F8, F16>, #endif - std::tuple< Row, Row, Row, F16, F16, F16>, - std::tuple< Row, Col, Row, F16, F16, F16>, + ck::Tuple< Row, Row, Row, F16, F16, F16>, + ck::Tuple< Row, Col, Row, F16, F16, F16>, - std::tuple< Row, Row, Row, BF16, BF16, BF16>, - std::tuple< Row, Col, Row, BF16, BF16, BF16>, - std::tuple< Row, Row, Row, BF16, I8, BF16>, - std::tuple< Row, Col, Row, BF16, I8, BF16>, + ck::Tuple< Row, Row, Row, BF16, BF16, BF16>, + ck::Tuple< Row, Col, Row, BF16, BF16, BF16>, + ck::Tuple< Row, Row, Row, BF16, I8, BF16>, + ck::Tuple< Row, Col, Row, BF16, I8, BF16>, - std::tuple< Row, Row, Row, F16, I8, F16>, - std::tuple< Row, Col, Row, F16, I8, F16> + ck::Tuple< Row, Row, Row, F16, I8, F16>, + ck::Tuple< Row, Col, Row, F16, I8, F16> >; // clang-format on diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc index af39b6b0520..f0b4ee61088 100644 --- a/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc @@ -5,7 +5,7 @@ TYPED_TEST(TestGroupedGemm, TinyCases) { - const std::vector Ms{2, 2}; + const std::vector Ms{2, 1}; constexpr int N = 768; constexpr int K = 544; @@ -17,7 +17,7 @@ TYPED_TEST(TestGroupedGemm, TinyCases) TYPED_TEST(TestGroupedGemm, SmallCases) { - const std::vector Ms{2, 2, 2, 2, 2}; + const std::vector Ms{2, 1, 3, 4, 5}; constexpr int N = 768; constexpr int K = 544; @@ -29,7 +29,7 @@ TYPED_TEST(TestGroupedGemm, SmallCases) TYPED_TEST(TestGroupedGemm, MidCases) { - const std::vector Ms{128, 128, 128, 128, 128, 128}; + const std::vector Ms{167, 183, 177, 153, 139, 204}; constexpr int N = 768; constexpr int K = 544; @@ -41,7 +41,7 @@ TYPED_TEST(TestGroupedGemm, MidCases) TYPED_TEST(TestGroupedGemm, Regular) { - const std::vector Ms{128, 128, 128}; + const std::vector Ms{64, 128, 256}; constexpr int N = 768; constexpr int K = 320; @@ -53,7 +53,7 @@ TYPED_TEST(TestGroupedGemm, Regular) TYPED_TEST(TestGroupedGemm, MNKPadded) { - const std::vector Ms{188, 188, 188, 188}; + const std::vector Ms{127, 150, 188, 210}; constexpr int N = 136; constexpr int K = 280; @@ -71,7 +71,7 @@ TYPED_TEST(TestGroupedGemm, TestLargeKBatch) GTEST_SKIP() << "Split-K not supported for for the current configuration (FP16/BF16 on " "GFX11, or using CDE element-wise operation)"; - const std::vector Ms{188, 188}; + const std::vector Ms{188, 210}; constexpr int N = 768; constexpr int K = 4096; From 254a038c5ec6b8d7e9005423d0276868073ca1dc Mon Sep 17 00:00:00 2001 From: Marton Bidlek Date: Mon, 26 Jan 2026 08:39:51 +0000 Subject: [PATCH 10/16] Working for kbatch=1, debugging the multi k versions --- .../grouped_gemm_wmma_fixed_nk_fp16.cpp | 25 +- .../device_grouped_gemm_wmma_fixed_nk.hpp | 540 ++++++++++-------- ...ce_grouped_gemm_wmma_fixed_nk_instance.hpp | 81 ++- 3 files changed, 349 insertions(+), 297 deletions(-) diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp index a31596d6fa1..bec76a366e0 100644 --- a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp @@ -52,7 +52,7 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = 
PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MKPadding; using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Fixed_Nk // clang-format off @@ -289,15 +289,15 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); } - // // Copy device tensors back to host - // for(std::size_t i = 0; i < c_device_tensors.size(); i++) - // { - // c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), - // c_device_tensors[i].mDesc.GetElementSize() * - // sizeof(EDataType)); + // Copy device tensors back to host + for(std::size_t i = 0; i < c_device_tensors.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); - // } + } // // Print out device and reference results for debugging // std::cout << "[CK GEMM RESULT TRACE]\n"; // for(std::size_t i = 0; i < c_device_tensors.size(); i++) @@ -368,9 +368,12 @@ int main(int argc, char* argv[]) for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ms.push_back(256); - problem_size.Ns.push_back(256); - problem_size.Ks.push_back(256); + // problem_size.Ms.push_back(256); + // problem_size.Ns.push_back(256); + // problem_size.Ks.push_back(256); + problem_size.Ms.push_back(128 + rand() % 128); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(1024); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index af826abd60f..06a3478566e 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -11,6 +11,7 @@ #include "ck/utility/env.hpp" #include "ck/host_utility/hip_check_error.hpp" #include "ck/utility/common_header.hpp" +#include "ck/utility/scheduler_enum.hpp" #include "ck/utility/tuple.hpp" #include "ck/tensor_description/tensor_descriptor.hpp" @@ -29,6 +30,9 @@ namespace device { template (); + using EpilogueType = typename std::conditional::type; + + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; const index_t block_id = get_block_1d_id(); @@ -75,7 +83,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) return; const index_t group_start = group_id * grid_size_grp; - const index_t M = gemm_desc_ptr[group_id].M; const index_t N = gemm_desc_ptr[group_id].N; const index_t K = gemm_desc_ptr[group_id].K; @@ -96,154 +103,88 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const auto local_grid_size = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); - constexpr auto NumDTensor = DsDataType::Size(); - - using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); - - DsGridPointer p_ds_grid_; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - // D pointer - p_ds_grid_(i) = static_cast(gemm_desc_ptr[group_id].p_ds_grid[i]); - }); - +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if 
constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + auto epilogue_args = EpilogueType{}; -// #if defined(__gfx11__) -// // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions -// using c_data_type = remove_cvref_t>; -// if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && -// (std::is_same_v || -// std::is_same_v))) -// { -// #endif - auto epilogue_args = - typename GridwiseGemm::EpilogueCShuffle{}; - + // constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); const auto& desc = gemm_desc_ptr[group_id]; const typename GridwiseGemm::Problem problem{ desc.M, desc.N, desc.K, - std::array{desc.StrideA}, - std::array{desc.StrideB}, + std::array{desc.StrideA}, + std::array{desc.StrideB}, desc.StrideDs, desc.StrideE, k_batch_ }; - using AsGridPointer = typename GridwiseGemm::AsGridPointer; - using ADataType0 = remove_cvref_t>; - - AsGridPointer p_as_grid_ = make_tuple( - static_cast(gemm_desc_ptr[group_id].p_a_grid) - ); - using BsGridPointer = typename GridwiseGemm::BsGridPointer; - using BDataType0 = remove_cvref_t>; - - BsGridPointer p_bs_grid_ = make_tuple( - static_cast(gemm_desc_ptr[group_id].p_b_grid) - ); - + + typename GridwiseGemm::AsGridPointer p_as_grid_; + typename GridwiseGemm::BsGridPointer p_bs_grid_; + typename GridwiseGemm::DsGridPointer p_ds_grid_; + + static_for<0, 1, 1>{}([&](auto i) { + using ADataType = remove_cvref_t; + p_as_grid_(i) = static_cast(desc.p_a_grid); + }); + + static_for<0, 1, 1>{}([&](auto i) { + using BDataType = remove_cvref_t; + p_bs_grid_(i) = static_cast(desc.p_b_grid); + }); + + static_for<0, 1, 1>{}([&](auto i) { + using DDataType = remove_cvref_t; + p_ds_grid_(i) = static_cast(desc.p_ds_grid[i]); + }); index_t id_off = 0; index_t id_local = get_block_1d_id() - group_start; while(id_local < local_grid_size) { - - // if(threadIdx.x == 0) - // { - // printf( - // "\n[CK GEMM TRACE]\n" - // " id_local = %d\n" - // " local_grid_size = %d\n", - // int(id_local), - // int(local_grid_size) - // ); - - // } - const auto block_2_etile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_start, id_off); - // auto tile_idx = block_2_etile_map.CalculateBottomIndex(make_multi_index(id_local)); - - // const index_t m_tile_idx = tile_idx[Number<0>{}]; - // const index_t n_tile_idx = tile_idx[Number<1>{}]; - // const index_t k_tile_idx = tile_idx[Number<2>{}]; - - // calculate ranges for each dimension - // const index_t m_start = m_tile_idx * MPerBlock; - // const index_t m_end = min(m_start + MPerBlock, M); - - // const index_t n_start = n_tile_idx * NPerBlock; - // const index_t n_end = min(n_start + NPerBlock, N); - - // const index_t k_start = k_tile_idx * KPerBlock; - // const index_t k_end = min(k_start + KPerBlock, K); - - // if(threadIdx.x == 0) - // { - // printf("[CK GEMM TRACE] grid_size=%d, group_id=%d, block_id=%d, " - // "m_tile=%d, n_tile=%d, k_tile=%d, " - // "M_range=[%d,%d), N_range=[%d,%d), K_range=[%d,%d)\n", - // int(local_grid_size), - // int(group_id), - // int(get_block_1d_id()), - // int(m_tile_idx), - // int(n_tile_idx), - // int(k_tile_idx), - // int(m_start), - // int(m_end), - // int(n_start), - // int(n_end), - // int(k_start), - // int(k_end)); - // } - - - GridwiseGemm::template Run, - typename GridwiseGemm::EpilogueCShuffle, + EpilogueType, 1, 2> (p_as_grid_, p_bs_grid_, p_ds_grid_, - static_cast(gemm_desc_ptr[group_id].p_e_grid), - static_cast(p_shared), + static_cast(desc.p_e_grid), + p_shared, 
problem, block_2_etile_map, a_element_op, b_element_op, c_element_op, epilogue_args); - - // if(threadIdx.x == 0) - // { - // printf( - // "\n[CK GEMM TRACE]\n" - // " id_local = %d\n" - // " local_grid_size = %d\n", - // int(id_local), - // int(local_grid_size)); - // } + id_off += grid_size_grp; id_local += grid_size_grp; } #undef TRACE_THREAD -// #if defined(__gfx11__) -// } -// #endif +#if defined(__gfx11__) + } +#endif #else ignore = gemm_descs_const; ignore = group_count; @@ -269,7 +210,6 @@ template + typename ComputeTypeB = ComputeTypeA> struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK, + Sequence, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, @@ -588,7 +526,7 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK; static constexpr bool CalculateHasMainKBlockLoop(const KernelArgument& karg) @@ -600,47 +538,41 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : Argument(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_element_op, + b_element_op, + c_element_op, + DefaultKBatch) { - k_batch_ = k_batch; - - if(k_batch_ < 1) - { - throw std::runtime_error("wrong! k_batch must be > 0"); - } - - const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); - - const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE; - const index_t N = gemm_desc_kernel_arg_[0].N; - - // const index_t m_padded = GridwiseGemm::CalculateMPadded(AverM); - // const index_t n_padded = GridwiseGemm::CalculateNPadded(N); - const auto e_grid_desc_m_n = - GridwiseGemm::template MakeEGridDescriptor_M_N( - AverM, N, StrideE); - - const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; - - grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); - - grid_size_ = grid_size_grp_ * group_count_; + // TODO: use occupancy api to calculate appropriate batch size. 
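            // Note: the delegating constructor above forwards to the k-batch overload
            // with DefaultKBatch (= 1), so split-K stays disabled unless the caller
            // invokes UpdateKBatch() / SetKBatch() afterwards.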
} - Argument(std::vector&, - std::vector&, - std::vector>&, - std::vector&, + Argument(std::vector&p_As, + std::vector&p_Bs, + std::vector>&p_Ds, + std::vector&p_Es, std::vector& gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, - CDEElementwiseOperation c_element_op) + CDEElementwiseOperation c_element_op, + index_t kbatch) : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} { grid_size_ = 0; - k_batch_ = 1; + k_batch_ = kbatch; grouped_gemm_kernel_args_dev = nullptr; @@ -668,7 +600,6 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK p_ds_grid; @@ -718,19 +649,22 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK{p_As[i]}, + std::array{p_Bs[i]}, + p_Ds[i], + type_convert(p_Es[i]), + AverM, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideE, + k_batch_ , + a_element_op, + b_element_op, + c_element_op, + false)); group_id++; } @@ -747,6 +681,33 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK 0"); + } + + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + + const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE; + const index_t N = gemm_desc_kernel_arg_[0].N; + + // const index_t m_padded = GridwiseGemm::CalculateMPadded(AverM); + // const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeEGridDescriptor_M_N( + AverM, N, StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + grid_size_ = grid_size_grp_ * group_count_; + } + // private: index_t group_count_; @@ -754,7 +715,7 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK gemm_desc_kernel_arg_; + std::vector gemm_desc_kernel_arg_; std::vector> a_mtx_mraw_kraw_; std::vector> b_mtx_nraw_kraw_; @@ -777,17 +738,34 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) { - constexpr bool has_main_k_block_loop = true; + bool all_have_kbatch_gt_one = arg.gemm_desc_kernel_arg_[0].KBatch > 1; + bool all_have_main_k0_block_loop = + CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[0]); + bool not_all_have_main_k0_block_loop_same = false; + bool not_all_have_kbatch_value_same = false; for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) { - const auto KPad = - GridwiseGemm::CalculateKPadded(arg.gemm_desc_kernel_arg_[i].K, arg.k_batch_); - if(GridwiseGemm::CalculateHasMainKBlockLoop(KPad) != has_main_k_block_loop) - { - throw std::runtime_error("wrong! not all gemm has_main_k_block_loop"); - } + not_all_have_main_k0_block_loop_same |= + all_have_main_k0_block_loop xor CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i]); + not_all_have_kbatch_value_same |= all_have_kbatch_gt_one xor (arg.gemm_desc_kernel_arg_[i].KBatch > 1); + } + + if(not_all_have_main_k0_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + // throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! 
" << " in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); } if(arg.grouped_gemm_kernel_args_dev == nullptr) @@ -801,50 +779,13 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK, - Tuple, - DsDataType, - EDataType, - e_global_memory_operation_, - Block2ETileMap, - GroupedGemmBlock2ETileMap, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - min_occupancy_, - tail_num_, - MPerBlock, - NPerBlock, - KPerBlock, - GemmSpec>; - - return launch_and_time_kernel(stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), - arg.gemm_desc_kernel_arg_.size(), - arg.grid_size_grp_, - arg.k_batch_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); - } - else - { const auto kernel = kernel_grouped_gemm_wmma_fixed_nk, Tuple, @@ -875,50 +816,101 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK::value) + if(all_have_main_k0_block_loop || not_all_have_main_k0_block_loop_same) { - SelectTailNumber(tail_num, [&](auto tail_num_ct) { - ave_time = launch_kernel( - std::integral_constant{}, - std::integral_constant{}, - std::integral_constant{}, - tail_num_ct); - }); + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(all_have_kbatch_gt_one) + { + + SelectTailNumber(tail_num, [&](auto tail_num_ct) { + ave_time = launch_kernel( + std::integral_constant{}, + std::integral_constant{}, + std::integral_constant{}, + tail_num_ct); + }); + } + else + { + SelectTailNumber(tail_num, [&](auto tail_num_ct) { + ave_time = launch_kernel( + std::integral_constant{}, + std::integral_constant{}, + std::integral_constant{}, + tail_num_ct); + }); + } + } } else { - if(arg.k_batch_ > 1) + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - SelectTailNumber(tail_num, [&](auto tail_num_ct) { + if(all_have_kbatch_gt_one) + { + SelectTailNumber(tail_num, [&](auto tail_num_ct) { ave_time = launch_kernel( - std::integral_constant{}, + std::integral_constant{}, std::integral_constant{}, std::integral_constant{}, tail_num_ct); }); - } - else - { - SelectTailNumber(tail_num, [&](auto tail_num_ct) { + } + else + { + SelectTailNumber(tail_num, [&](auto tail_num_ct) { ave_time = launch_kernel( - std::integral_constant{}, + std::integral_constant{}, std::integral_constant{}, std::integral_constant{}, tail_num_ct); }); + } } } - - - + // if constexpr(std::is_same::value) + // { + // SelectTailNumber(tail_num, [&](auto tail_num_ct) { + // ave_time = launch_kernel( + // std::integral_constant{}, + // std::integral_constant{}, + // std::integral_constant{}, + // tail_num_ct); + // }); + // } + // else + // { + // if(arg.k_batch_ > 1) + // { + // SelectTailNumber(tail_num, [&](auto tail_num_ct) { + // ave_time = launch_kernel( + // std::integral_constant{}, + // std::integral_constant{}, + // std::integral_constant{}, + // tail_num_ct); + // }); + // } + // else + // { + // SelectTailNumber(tail_num, [&](auto tail_num_ct) { + // ave_time = launch_kernel( + // std::integral_constant{}, + // std::integral_constant{}, + // std::integral_constant{}, + // tail_num_ct); + // }); + // } + // } return ave_time; } @@ -957,31 +949,75 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_) + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) { 
            return false;
        }

-
-        bool supported = true;
-        if constexpr(GemmSpec != GemmSpecialization::Default)
+        if constexpr(std::is_same_v<EDataType, half_t> ||
+                     std::is_same_v<EDataType, bhalf_t>)
         {
-            const auto a_raw_vector_dim = ABlockTransferSrcVectorDim != 1 ? 1 : 0;
-            const auto b_raw_vector_dim = BBlockTransferSrcVectorDim != 1 ? 1 : 0;
+            if(arg.k_batch_ > 1 && ck::is_gfx11_supported())
+            {
+                // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
+                return false;
+            }
+        }

-            for(index_t i = 0; i < arg.group_count_; ++i)
+        if constexpr(!std::is_same_v<CDEElementwiseOperation, element_wise::PassThrough>)
+        {
+            if(arg.k_batch_ > 1)
             {
-                const auto a_vector_dim = arg.a_mtx_mraw_kraw_[i].At(Number<a_raw_vector_dim>{});
-                const auto b_vector_dim = arg.b_mtx_nraw_kraw_[i].At(Number<b_raw_vector_dim>{});
+                // Using SplitK and a C element op would require a two-stage kernel where the
+                // second stage applies the op to the accumulated results
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "C element operators are not supported when using SplitK. Set "
+                                 "K_BATCH to 1 or remove the operator."
+                              << std::endl;
+                }
+                return false;
+            }
+        }

-                supported = supported & (a_vector_dim % ABlockTransferSrcScalarPerVector == 0);
-                supported = supported & (b_vector_dim % BBlockTransferSrcScalarPerVector == 0);
+        if constexpr(std::is_same_v<ADataType, f8_t> || std::is_same_v<ADataType, bf8_t> ||
+                     std::is_same_v<BDataType, f8_t> || std::is_same_v<BDataType, bf8_t>)
+        {
+            if(ck::is_gfx11_supported())
+            {
+                return false;
             }
         }

-        if constexpr(std::is_same::value)
+
+        if(ck::type_convert<index_t>(arg.gemm_desc_kernel_arg_.size()) != arg.group_count_)
         {
-            supported = supported & (arg.k_batch_ == 1);
+            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+            {
+                std::cout << "The number of kernel arguments does not match the group count!"
+                          << std::endl;
+            }
+            return false;
         }

+        bool supported = true;
+        for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); ++i)
+        {
+            const auto& a = arg.gemm_desc_kernel_arg_[i];
+            bool group_arg_valid = GridwiseGemm::CheckValidity(a);
+
+            if(not group_arg_valid)
+            {
+                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
+                {
+                    std::cout << "[" << __func__ << "] group id: " << i
+                              << " has invalid GridwiseGemm settings!"
<< std::endl; + a.Print(); + } + } + supported = supported && group_arg_valid; + } return supported; } @@ -1032,9 +1068,23 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + // clang-format off str << "DeviceGroupedGemm_Wmma_Fixed_Nk" << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," << BlockSize << ", " << MPerBlock << ", " << NPerBlock << ", " @@ -1043,11 +1093,15 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK"; // clang-format on diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp index 1a17454fea6..4ebb0d2703a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp @@ -59,13 +59,13 @@ template , S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + //#############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //#############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| 
_NBlock_NRepeat| _NRepeat| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> // clang-format on >; @@ -81,13 +81,13 @@ template , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + //#############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //#############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, 
GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> // clang-format on >; @@ -102,13 +102,13 @@ template , S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData|CShuffle| DsData| EData| A| B| CDE| GEMM|Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type|DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, 
DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> // clang-format on >; @@ -123,14 +123,14 @@ template , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> - // clang-format on + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 
0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on >; // List of instance variants to add (pipeline/scheduler/padding combinations) @@ -141,15 +141,13 @@ static constexpr InstanceVariant InstanceVariants[] = { make_tuple(GemmDefault, IntrawaveScheduler, PipelineV1), // make_tuple(GemmDefault, InterwaveScheduler, PipelineV1), - // make_tuple(GemmDefault, IntrawaveScheduler, PipelineV3), + make_tuple(GemmDefault, IntrawaveScheduler, PipelineV3), make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV1), // make_tuple(GemmMNKPadding, InterwaveScheduler, PipelineV1), // make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV3), }; -// Helper function to add a list of layout instances with specific A/B/E datatypes for all supported -// padding/scheduler/pipeline version combinations template typename LayoutInstances, - typename ADataType, // NOTE: type parameters as last so that they can be inferred from the - typename BDataType, // vector argument + typename CDEElementOp> + typename LayoutInstances, + typename ADataType, + typename BDataType, typename EDataType, typename AElementOp, typename BElementOp, @@ -177,7 +176,6 @@ void add_device_grouped_gemm_wmma_fixed_nk_instances( BElementOp, CDEElementOp>>>& instances) { - // Add all instances from our instance list static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { constexpr auto instance = InstanceVariants[i]; add_device_operation_instances(instances, @@ -190,8 +188,6 @@ void add_device_grouped_gemm_wmma_fixed_nk_instances( }); } -// Helper function to add a list of layout instances for instances with matching A/B/E data types -// for all supported padding/scheduler/pipeline version combinations template typename LayoutInstances, - typename AElementOp, // NOTE: element-wise op parameters as last so that they can be - typename BElementOp, // inferred from the vector argument + typename CDEElementOp> + typename LayoutInstances, + typename AElementOp, + typename BElementOp, typename CDEElementOp> void add_device_grouped_gemm_wmma_fixed_nk_instances( std::vector>>& instances) { - // Add all instances from our instance list static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { constexpr auto instance = InstanceVariants[i]; add_device_operation_instances(instances, @@ -260,7 +256,6 @@ void add_device_grouped_gemm_wmma_fixed_nk_irregular_instances( BElementOp, CDEElementOp>>>& instances) { - // Add all instances from our instance list static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { constexpr auto instance = InstanceVariants[i]; add_device_operation_instances(instances, From de7c62748d3e9af0163f2bacc0194b9287d13fd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A1rton=20Bidlek?= Date: Mon, 26 Jan 2026 12:26:10 +0000 Subject: [PATCH 11/16] reverting experimental changes --- .../device_grouped_gemm_wmma_fixed_nk.hpp | 148 ++++++------------ 1 file changed, 46 insertions(+), 102 deletions(-) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index 06a3478566e..f7b602620ba 100644 --- 
a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -61,16 +61,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation c_element_op) + { #if(defined(__gfx11__) || defined(__gfx12__)) - using EpilogueType = typename std::conditional::type; - - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); - + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); __shared__ char p_shared[LDS_size]; const index_t block_id = get_block_1d_id(); @@ -83,15 +79,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) return; const index_t group_start = group_id * grid_size_grp; - const index_t M = gemm_desc_ptr[group_id].M; - const index_t N = gemm_desc_ptr[group_id].N; - const index_t K = gemm_desc_ptr[group_id].K; + auto karg = gemm_desc_ptr[group_id]; + + const index_t M = karg.M; + const index_t N = karg.N; + const index_t K = karg.K; if(M == 0 || N == 0 || K == 0) return; - const auto StrideE = gemm_desc_ptr[group_id].StrideE; + const auto StrideE = karg.StrideE; // const index_t m_padded = GridwiseGemm::CalculateMPadded(M); // const index_t n_padded = GridwiseGemm::CalculateNPadded(N); @@ -103,9 +101,21 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const auto local_grid_size = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + constexpr auto NumDTensor = DsDataType::Size(); + + using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); + + DsGridPointer p_ds_grid_; + + static_for<0, NumDTensor, 1>{}([&](auto i) { + using DDataType = remove_cvref_t>; + // D pointer + p_ds_grid_(i) = static_cast(karg.p_ds_grid[i]); + }); + #if defined(__gfx11__) // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; + using c_data_type = remove_cvref_t>; if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && (std::is_same_v || std::is_same_v))) @@ -113,42 +123,34 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif - auto epilogue_args = EpilogueType{}; - - - - // constexpr auto NumDTensor = GridwiseGemm::DsGridPointer::Size(); - const auto& desc = gemm_desc_ptr[group_id]; + auto epilogue_args = + typename GridwiseGemm::EpilogueCShuffle{}; + + const auto desc = gemm_desc_ptr[group_id]; const typename GridwiseGemm::Problem problem{ desc.M, desc.N, desc.K, - std::array{desc.StrideA}, - std::array{desc.StrideB}, + std::array{desc.StrideA}, + std::array{desc.StrideB}, desc.StrideDs, desc.StrideE, k_batch_ }; - - typename GridwiseGemm::AsGridPointer p_as_grid_; - typename GridwiseGemm::BsGridPointer p_bs_grid_; - typename GridwiseGemm::DsGridPointer p_ds_grid_; - - static_for<0, 1, 1>{}([&](auto i) { - using ADataType = remove_cvref_t; - p_as_grid_(i) = static_cast(desc.p_a_grid); - }); - - static_for<0, 1, 1>{}([&](auto i) { - using BDataType = remove_cvref_t; - p_bs_grid_(i) = static_cast(desc.p_b_grid); - }); - - static_for<0, 1, 1>{}([&](auto i) { - using DDataType = remove_cvref_t; - p_ds_grid_(i) = static_cast(desc.p_ds_grid[i]); - }); + using AsGridPointer = typename GridwiseGemm::AsGridPointer; + using ADataType0 = remove_cvref_t>; + + AsGridPointer p_as_grid_ = make_tuple( + static_cast(karg.p_a_grid) + ); + using BsGridPointer = 
typename GridwiseGemm::BsGridPointer; + using BDataType0 = remove_cvref_t>; + + BsGridPointer p_bs_grid_ = make_tuple( + static_cast(karg.p_b_grid) + ); + index_t id_off = 0; index_t id_local = get_block_1d_id() - group_start; @@ -160,16 +162,16 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) GridwiseGemm::template Run, - EpilogueType, + typename GridwiseGemm::EpilogueCShuffle, 1, 2> (p_as_grid_, p_bs_grid_, p_ds_grid_, - static_cast(desc.p_e_grid), - p_shared, + static_cast(karg.p_e_grid), + static_cast(p_shared), problem, block_2_etile_map, a_element_op, @@ -196,6 +198,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif } + template m_tile = %d\n" - // " --> n_tile = %d\n" - // " --> k_tile = %d\n" - // "\n", - // int(block_1d_id), - // int(idx_ksplit), - // int(idx_M0), - // int(idx_N0), - // int(M01_), - // int(M01_adapt), - // int(idx_M00), - // int(idx_M01), - // int(idx_N0_M01_local), - // int(idx_N0_M01_local % M01_adapt + idx_M00 * M01_), - // int(idx_N0_M01_local / M01_adapt), - // int(idx_ksplit)); - // } - // #endif - return make_tuple(idx_ksplit, idx_N0_M01_local % M01_adapt + idx_M00 * M01_, idx_N0_M01_local / M01_adapt); @@ -526,7 +471,6 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK; static constexpr bool CalculateHasMainKBlockLoop(const KernelArgument& karg) From cca13c4f4015899a9d9a00f5fb5d97a5a1505864 Mon Sep 17 00:00:00 2001 From: Marton Bidlek Date: Tue, 27 Jan 2026 12:19:39 +0000 Subject: [PATCH 12/16] Using kernelargument to calculate ksplit offset --- .../grouped_gemm_wmma_fixed_nk_fp16.cpp | 16 +- .../device_grouped_gemm_wmma_fixed_nk.hpp | 157 ++++++++---------- 2 files changed, 76 insertions(+), 97 deletions(-) diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp index bec76a366e0..69068f94557 100644 --- a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp @@ -52,15 +52,15 @@ using AElementOp = PassThrough; using BElementOp = PassThrough; using CDEElementOp = PassThrough; -static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MKPadding; +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Fixed_Nk // clang-format off -//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | 
PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; // clang-format on struct ProblemSize final @@ -167,9 +167,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co } } - using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>; - - std::vector grouped_gemm_kernel_args_; + std::vector grouped_gemm_kernel_args_; grouped_gemm_kernel_args_.reserve(group_count); for(int i = 0; i < group_count; i++) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index f7b602620ba..ceb1b61ff98 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -46,9 +46,6 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -65,8 +62,12 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) { #if(defined(__gfx11__) || defined(__gfx12__)) - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); + using EpilogueType = typename std::conditional::type; + + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); __shared__ char p_shared[LDS_size]; const index_t block_id = get_block_1d_id(); @@ -79,20 +80,18 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) return; const index_t group_start = group_id * grid_size_grp; - auto karg = 
gemm_desc_ptr[group_id]; + auto gemmTransKernelArg = gemm_desc_ptr[group_id]; - const index_t M = karg.M; - const index_t N = karg.N; - const index_t K = karg.K; + const index_t M = gemmTransKernelArg.M; + const index_t N = gemmTransKernelArg.N; + const index_t K = gemmTransKernelArg.K; if(M == 0 || N == 0 || K == 0) return; - - const auto StrideE = karg.StrideE; + const auto StrideE = gemmTransKernelArg.StrideE; // const index_t m_padded = GridwiseGemm::CalculateMPadded(M); // const index_t n_padded = GridwiseGemm::CalculateNPadded(N); - const auto e_grid_desc_m_n = GridwiseGemm::template MakeEGridDescriptor_M_N( M, N, StrideE); @@ -101,56 +100,37 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const auto local_grid_size = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); - constexpr auto NumDTensor = DsDataType::Size(); - - using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer()); - - DsGridPointer p_ds_grid_; - - static_for<0, NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - // D pointer - p_ds_grid_(i) = static_cast(karg.p_ds_grid[i]); - }); + // constexpr auto NumDTensor = DsDataType::Size(); #if defined(__gfx11__) // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using c_data_type = remove_cvref_t>; + using c_data_type = remove_cvref_t>; if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && (std::is_same_v || std::is_same_v))) { #endif - - - auto epilogue_args = - typename GridwiseGemm::EpilogueCShuffle{}; - - const auto desc = gemm_desc_ptr[group_id]; - const typename GridwiseGemm::Problem problem{ - desc.M, - desc.N, - desc.K, - std::array{desc.StrideA}, - std::array{desc.StrideB}, - desc.StrideDs, - desc.StrideE, - k_batch_ + using KernelArgument = typename GridwiseGemm::Argument; + + KernelArgument kernel_arg{ + std::array{gemmTransKernelArg.p_a_grid}, + std::array{gemmTransKernelArg.p_b_grid}, + gemmTransKernelArg.p_ds_grid, + type_convert(gemmTransKernelArg.p_e_grid), + gemmTransKernelArg.M, + gemmTransKernelArg.N, + gemmTransKernelArg.K, + std::array{gemmTransKernelArg.StrideA}, + std::array{gemmTransKernelArg.StrideB}, + gemmTransKernelArg.StrideDs, + gemmTransKernelArg.StrideE, + k_batch_, + a_element_op, + b_element_op, + c_element_op, + false }; - using AsGridPointer = typename GridwiseGemm::AsGridPointer; - using ADataType0 = remove_cvref_t>; - - AsGridPointer p_as_grid_ = make_tuple( - static_cast(karg.p_a_grid) - ); - using BsGridPointer = typename GridwiseGemm::BsGridPointer; - using BDataType0 = remove_cvref_t>; - - BsGridPointer p_bs_grid_ = make_tuple( - static_cast(karg.p_b_grid) - ); - index_t id_off = 0; index_t id_local = get_block_1d_id() - group_start; @@ -160,24 +140,26 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const auto block_2_etile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_start, id_off); + auto tile_index = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + auto splitk_batch_offset = + typename GridwiseGemm::SplitKBatchOffset(kernel_arg, tile_index[Number<0>{}]); + + auto epilogue_args = EpilogueType{}; + GridwiseGemm::template Run, - typename GridwiseGemm::EpilogueCShuffle, - 1, - 2> - (p_as_grid_, - p_bs_grid_, - p_ds_grid_, - static_cast(karg.p_e_grid), - static_cast(p_shared), - problem, - block_2_etile_map, - a_element_op, - b_element_op, - c_element_op, - epilogue_args); + CGlobalMemoryDataOperation, + TailNum, + GroupedGemmBlock2ETileMap, + EpilogueType, + 1, // Block2CTileMap 
MBlock index + 2 // Block2CTileMap NBlock index + >(static_cast(p_shared), + splitk_batch_offset, + kernel_arg, + block_2_etile_map, + epilogue_args); id_off += grid_size_grp; id_local += grid_size_grp; @@ -305,7 +287,7 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK, + Sequence, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, @@ -743,9 +725,6 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK; return launch_and_time_kernel(stream_config, @@ -861,21 +840,23 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK void SelectTailNumber(TailNumber tail_num, Lambda&& lambda) { - switch(tail_num) - { - case TailNumber::Full: lambda(std::integral_constant{}); break; - case TailNumber::Empty: lambda(std::integral_constant{}); break; - case TailNumber::One: lambda(std::integral_constant{}); break; - case TailNumber::Two: lambda(std::integral_constant{}); break; - case TailNumber::Three: lambda(std::integral_constant{}); break; - case TailNumber::Four: lambda(std::integral_constant{}); break; - case TailNumber::Five: lambda(std::integral_constant{}); break; - case TailNumber::Six: lambda(std::integral_constant{}); break; - case TailNumber::Seven: lambda(std::integral_constant{}); break; - case TailNumber::Odd: lambda(std::integral_constant{}); break; - case TailNumber::Even: lambda(std::integral_constant{}); break; - default: lambda(std::integral_constant{}); break;; - } + ignore = tail_num; + lambda(std::integral_constant{}); + // switch(tail_num) + // { + // case TailNumber::Full: lambda(std::integral_constant{}); break; + // case TailNumber::Empty: lambda(std::integral_constant{}); break; + // case TailNumber::One: lambda(std::integral_constant{}); break; + // case TailNumber::Two: lambda(std::integral_constant{}); break; + // case TailNumber::Three: lambda(std::integral_constant{}); break; + // case TailNumber::Four: lambda(std::integral_constant{}); break; + // case TailNumber::Five: lambda(std::integral_constant{}); break; + // case TailNumber::Six: lambda(std::integral_constant{}); break; + // case TailNumber::Seven: lambda(std::integral_constant{}); break; + // case TailNumber::Odd: lambda(std::integral_constant{}); break; + // case TailNumber::Even: lambda(std::integral_constant{}); break; + // default: lambda(std::integral_constant{}); break;; + // } } float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) From 41a6d45b220b3cd39147fd8ec5d7fe7b67e25729 Mon Sep 17 00:00:00 2001 From: Erwin Terpstra Date: Tue, 27 Jan 2026 16:51:18 +0000 Subject: [PATCH 13/16] fix: split-k not working due to outdated values being used --- .../grouped_gemm_wmma_fixed_nk_fp16.cpp | 26 +- .../device_grouped_gemm_wmma_fixed_nk.hpp | 330 +++++++++--------- 2 files changed, 180 insertions(+), 176 deletions(-) diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp index 69068f94557..cefebf74e63 100644 --- a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp @@ -167,19 +167,23 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co } } - std::vector grouped_gemm_kernel_args_; + std::cout << "Sum of M: " << sum_of_m << std::endl; + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>; + + std::vector grouped_gemm_kernel_args_; grouped_gemm_kernel_args_.reserve(group_count); for(int i = 0; i < 
group_count; i++) { - a_tensors_device.emplace_back( - std::make_unique(sizeof(ADataType) * sum_of_m * problem_size.Ks[i])); + a_tensors_device.emplace_back(std::make_unique( + sizeof(ADataType) * problem_size.Ms[i] * problem_size.Ks[i])); b_tensors_device.emplace_back(std::make_unique( sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); - c_tensors_device.emplace_back( - std::make_unique(sizeof(EDataType) * sum_of_m * problem_size.Ns[i])); + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); @@ -284,7 +288,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co c_element_op); ref_invoker.Run(ref_argument); - + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); } // Copy device tensors back to host @@ -293,8 +297,6 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), c_device_tensors[i].mDesc.GetElementSize() * sizeof(EDataType)); - - } // // Print out device and reference results for debugging // std::cout << "[CK GEMM RESULT TRACE]\n"; @@ -325,10 +327,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co // } } - - - std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; - return pass; + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; + return pass; } int main(int argc, char* argv[]) @@ -369,7 +369,7 @@ int main(int argc, char* argv[]) // problem_size.Ms.push_back(256); // problem_size.Ns.push_back(256); // problem_size.Ks.push_back(256); - problem_size.Ms.push_back(128 + rand() % 128); + problem_size.Ms.push_back(128 + i * 128); problem_size.Ns.push_back(1024); problem_size.Ks.push_back(1024); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index ceb1b61ff98..da1b842690d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -73,7 +73,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) const index_t block_id = get_block_1d_id(); const auto gemm_desc_ptr = reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); - const index_t group_id = block_id / grid_size_grp; if(group_id >= group_count) @@ -89,12 +88,11 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) if(M == 0 || N == 0 || K == 0) return; - const auto StrideE = gemmTransKernelArg.StrideE; + const auto StrideE = gemmTransKernelArg.StrideE; // const index_t m_padded = GridwiseGemm::CalculateMPadded(M); // const index_t n_padded = GridwiseGemm::CalculateNPadded(N); const auto e_grid_desc_m_n = - GridwiseGemm::template MakeEGridDescriptor_M_N( - M, N, StrideE); + GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; @@ -112,36 +110,33 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif using KernelArgument = typename GridwiseGemm::Argument; - KernelArgument kernel_arg{ - std::array{gemmTransKernelArg.p_a_grid}, - std::array{gemmTransKernelArg.p_b_grid}, - gemmTransKernelArg.p_ds_grid, - 
type_convert(gemmTransKernelArg.p_e_grid), - gemmTransKernelArg.M, - gemmTransKernelArg.N, - gemmTransKernelArg.K, - std::array{gemmTransKernelArg.StrideA}, - std::array{gemmTransKernelArg.StrideB}, - gemmTransKernelArg.StrideDs, - gemmTransKernelArg.StrideE, - k_batch_, - a_element_op, - b_element_op, - c_element_op, - false - }; - - index_t id_off = 0; index_t id_local = get_block_1d_id() - group_start; while(id_local < local_grid_size) { + KernelArgument kernel_arg{std::array{gemmTransKernelArg.p_a_grid}, + std::array{gemmTransKernelArg.p_b_grid}, + gemmTransKernelArg.p_ds_grid, + type_convert(gemmTransKernelArg.p_e_grid), + gemmTransKernelArg.M, + gemmTransKernelArg.N, + gemmTransKernelArg.K, + std::array{gemmTransKernelArg.StrideA}, + std::array{gemmTransKernelArg.StrideB}, + gemmTransKernelArg.StrideDs, + gemmTransKernelArg.StrideE, + k_batch_, + a_element_op, + b_element_op, + c_element_op, + false}; + const auto block_2_etile_map = GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_start, id_off); auto tile_index = - block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(kernel_arg, tile_index[Number<0>{}]); @@ -149,17 +144,17 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) auto epilogue_args = EpilogueType{}; GridwiseGemm::template Run(static_cast(p_shared), - splitk_batch_offset, - kernel_arg, - block_2_etile_map, - epilogue_args); + CGlobalMemoryDataOperation, + TailNum, + GroupedGemmBlock2ETileMap, + EpilogueType, + 1, // Block2CTileMap MBlock index + 2 // Block2CTileMap NBlock index + >(static_cast(p_shared), + splitk_batch_offset, + kernel_arg, + block_2_etile_map, + epilogue_args); id_off += grid_size_grp; id_local += grid_size_grp; @@ -180,7 +175,6 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif } - template , + Sequence, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, @@ -299,8 +295,6 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK( 1, 1, 1))>; - - template struct OffsettedBlockToCTileMapMLoops { @@ -347,7 +341,6 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops { @@ -421,18 +414,17 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, const CTileDim& /* c_tile_dim */) const @@ -485,10 +477,10 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK&p_As, - std::vector&p_Bs, - std::vector>&p_Ds, - std::vector&p_Es, + Argument(std::vector&, + std::vector&, + std::vector>&, + std::vector&, std::vector& gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, @@ -526,11 +518,9 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK p_ds_grid; - - static_for<0, NumDTensor, 1>{}([&](auto j) { p_ds_grid[j] = nullptr; }); - std::array StrideDs; static_for<0, NumDTensor, 1>{}([&](auto j) { @@ -542,8 +532,10 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK{p_As[i]}, - std::array{p_Bs[i]}, - p_Ds[i], - type_convert(p_Es[i]), - AverM, - N, - K, - std::array{StrideA}, - std::array{StrideB}, - StrideDs, - StrideE, - k_batch_ , - a_element_op, - b_element_op, - c_element_op, - false)); + gemm_desc_kernel_arg_.push_back(KernelArgument(std::array{nullptr}, + std::array{nullptr}, + 
p_ds_grid, + nullptr, + AverM, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideE, + k_batch_, + a_element_op, + b_element_op, + c_element_op, + false)); group_id++; } @@ -598,10 +590,8 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK( - sum_of_m, gemm_desc_kernel_arg_[0].N, - gemm_desc_kernel_arg_[0].StrideE); + sum_of_m, gemm_desc_kernel_arg_[0].N, gemm_desc_kernel_arg_[0].StrideE); - const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, k_batch_}; barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); @@ -632,6 +622,11 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK 1); + all_have_main_k0_block_loop xor + CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i]); + not_all_have_kbatch_value_same |= + all_have_kbatch_gt_one xor (arg.gemm_desc_kernel_arg_[i].KBatch > 1); } if(not_all_have_main_k0_block_loop_same) @@ -705,47 +701,47 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK, - Tuple, - DsDataType, - EDataType, - e_global_memory_operation_, - Block2ETileMap, - GroupedGemmBlock2ETileMap, - AElementwiseOperation, - BElementwiseOperation, - CDEElementwiseOperation, - min_occupancy_, - tail_num_, - GemmSpec>; - - return launch_and_time_kernel(stream_config, - kernel, - dim3(arg.grid_size_), - dim3(BlockSize), - 0, - cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), - arg.gemm_desc_kernel_arg_.size(), - arg.grid_size_grp_, - arg.k_batch_, - arg.a_element_op_, - arg.b_element_op_, - arg.c_element_op_); + const auto kernel = kernel_grouped_gemm_wmma_fixed_nk, + Tuple, + DsDataType, + EDataType, + e_global_memory_operation_, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + min_occupancy_, + tail_num_, + GemmSpec>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); }; - // const auto tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(arg.gemm_desc_kernel_arg_[0].K); - const auto tail_num = TailNumber::Full; + // const auto tail_num = + // GridwiseGemm::CalculateKBlockLoopTailNum(arg.gemm_desc_kernel_arg_[0].K); + const auto tail_num = TailNumber::Full; constexpr index_t min_occupancy = 1; - if(all_have_main_k0_block_loop || not_all_have_main_k0_block_loop_same) { if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || @@ -755,22 +751,24 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK{}, - std::integral_constant{}, - std::integral_constant{}, - tail_num_ct); - }); + ave_time = launch_kernel( + std::integral_constant{}, + std::integral_constant{}, + std::integral_constant{}, + tail_num_ct); + }); } else { SelectTailNumber(tail_num, [&](auto tail_num_ct) { - ave_time = launch_kernel( - std::integral_constant{}, - std::integral_constant{}, - std::integral_constant{}, - tail_num_ct); - }); + ave_time = launch_kernel( + std::integral_constant{}, + std::integral_constant{}, + std::integral_constant{}, + tail_num_ct); + }); } } } @@ -781,22 +779,24 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK{}, - std::integral_constant{}, - std::integral_constant{}, - tail_num_ct); - }); + ave_time = launch_kernel( + 
std::integral_constant{}, + std::integral_constant{}, + std::integral_constant{}, + tail_num_ct); + }); } else { SelectTailNumber(tail_num, [&](auto tail_num_ct) { - ave_time = launch_kernel( - std::integral_constant{}, - std::integral_constant{}, - std::integral_constant{}, - tail_num_ct); - }); + ave_time = launch_kernel( + std::integral_constant{}, + std::integral_constant{}, + std::integral_constant{}, + tail_num_ct); + }); } } } @@ -806,9 +806,9 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK{}, - // std::integral_constant{}, - // std::integral_constant{}, - // tail_num_ct); + // std::integral_constant{}, std::integral_constant{}, tail_num_ct); // }); // } // else @@ -818,7 +818,8 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK{}, - // std::integral_constant{}, + // std::integral_constant{}, // std::integral_constant{}, // tail_num_ct); // }); @@ -828,9 +829,9 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK{}, - // std::integral_constant{}, - // std::integral_constant{}, - // tail_num_ct); + // std::integral_constant{}, std::integral_constant{}, tail_num_ct); // }); // } // } @@ -844,24 +845,30 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK{}); // switch(tail_num) // { - // case TailNumber::Full: lambda(std::integral_constant{}); break; - // case TailNumber::Empty: lambda(std::integral_constant{}); break; - // case TailNumber::One: lambda(std::integral_constant{}); break; - // case TailNumber::Two: lambda(std::integral_constant{}); break; - // case TailNumber::Three: lambda(std::integral_constant{}); break; - // case TailNumber::Four: lambda(std::integral_constant{}); break; - // case TailNumber::Five: lambda(std::integral_constant{}); break; - // case TailNumber::Six: lambda(std::integral_constant{}); break; - // case TailNumber::Seven: lambda(std::integral_constant{}); break; - // case TailNumber::Odd: lambda(std::integral_constant{}); break; - // case TailNumber::Even: lambda(std::integral_constant{}); break; - // default: lambda(std::integral_constant{}); break;; + // case TailNumber::Full: lambda(std::integral_constant{}); break; case TailNumber::Empty: + // lambda(std::integral_constant{}); break; case + // TailNumber::One: lambda(std::integral_constant{}); break; case TailNumber::Two: + // lambda(std::integral_constant{}); break; case + // TailNumber::Three: lambda(std::integral_constant{}); break; case TailNumber::Four: + // lambda(std::integral_constant{}); break; case + // TailNumber::Five: lambda(std::integral_constant{}); break; case TailNumber::Six: + // lambda(std::integral_constant{}); break; case + // TailNumber::Seven: lambda(std::integral_constant{}); break; case TailNumber::Odd: + // lambda(std::integral_constant{}); break; case + // TailNumber::Even: lambda(std::integral_constant{}); break; default: lambda(std::integral_constant{}); break;; // } } float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - return RunImp(arg, stream_config); + { + return RunImp(arg, stream_config); } // polymorphic @@ -946,7 +953,6 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK(p_arg)); @@ -1033,7 +1039,6 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK Date: Wed, 28 Jan 2026 10:33:18 +0000 Subject: [PATCH 14/16] removing unnecessary changes --- .../run_grouped_gemm_example.inc | 6 ---- .../device_grouped_gemm_wmma_fixed_nk.hpp | 36 +++++++++---------- 
.../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 1 - .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 24 ------------- 4 files changed, 17 insertions(+), 50 deletions(-) diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index b6b9835161a..72b60d6beb1 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -349,14 +349,8 @@ bool run_grouped_gemm_example(int argc, char* argv[]) for(int i = 0; i < problem_size.group_count; i++) { problem_size.Ms.push_back(256 + 256 * i); - -#ifdef FIX_NK - problem_size.Ns.push_back(512); - problem_size.Ks.push_back(512); -#else problem_size.Ns.push_back(128 + 128 * i); problem_size.Ks.push_back(128 + 64 * i); -#endif problem_size.stride_As.push_back( get_stride(ALayout{}, problem_size.Ms[i], problem_size.Ks[i])); diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index da1b842690d..e1b3c827f74 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -89,17 +89,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) return; const auto StrideE = gemmTransKernelArg.StrideE; - // const index_t m_padded = GridwiseGemm::CalculateMPadded(M); - // const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); const auto e_grid_desc_m_n = - GridwiseGemm::template MakeEGridDescriptor_M_N(M, N, StrideE); + GridwiseGemm::template MakeDEGridDescriptor_M_N(M, m_padded, N, n_padded, StrideE); const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; const auto local_grid_size = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); - // constexpr auto NumDTensor = DsDataType::Size(); - #if defined(__gfx11__) // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions using c_data_type = remove_cvref_t>; @@ -292,8 +290,8 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK; using CGridDesc_M_N = - remove_cvref_t( - 1, 1, 1))>; + remove_cvref_t( + 1, 1, 1, 1, 1))>; template struct OffsettedBlockToCTileMapMLoops @@ -536,11 +534,11 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK( - AverM, N, StrideE); + GridwiseGemm::template MakeDEGridDescriptor_M_N( + AverM, m_padded, N, n_padded, StrideE); // block-to-e-tile map const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; @@ -586,11 +584,11 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK( - sum_of_m, gemm_desc_kernel_arg_[0].N, gemm_desc_kernel_arg_[0].StrideE); + GridwiseGemm::template MakeDEGridDescriptor_M_N( + sum_of_m, sum_of_m_padded, gemm_desc_kernel_arg_[0].N, n_padded, gemm_desc_kernel_arg_[0].StrideE); const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, k_batch_}; @@ -611,11 +609,11 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK( - AverM, N, StrideE); + GridwiseGemm::template MakeDEGridDescriptor_M_N( + AverM, m_padded, N, n_padded, StrideE); const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp 
b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index 70a85f33a91..a1cba118b28 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -318,7 +318,6 @@ struct GridwiseGemm_wmma_cshuffle_v3 using Base::MakeAsGridDescriptor_AK0_M_AK1; using Base::MakeBsGridDescriptor_BK0_N_BK1; using Base::MakeDEGridDescriptor_M_N; - using Base::MakeEGridDescriptor_M_N; using Base::MakeDsGridDescriptor_M_N; using Base::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 3c8c07c816e..6374f7b24ac 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -774,30 +774,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_base return BTransfer::template MakeWmmaTileDescriptor(); } - - template - __host__ __device__ static auto - MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE) - { - constexpr auto matrix_padder = - ck::tensor_operation::device::MatrixPadder{ - MPerBlock, NPerBlock, KPerBlock}; - const auto e_grid_desc_mraw_nraw = [&]() { - if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), - make_tuple(StrideE, I1)); - } - else if constexpr(is_same::value) - { - return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), - make_tuple(I1, StrideE)); - } - }(); - - return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw); - } - template __host__ __device__ static auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE) From d472ade66378448afdb5a15f849c2c26b1519e58 Mon Sep 17 00:00:00 2001 From: Marton Bidlek Date: Wed, 28 Jan 2026 11:22:08 +0000 Subject: [PATCH 15/16] restoring NumGemmKPrefetchStage --- .../grouped_gemm_wmma_fixed_nk_fp16.cpp | 11 ++-- .../device_grouped_gemm_wmma_fixed_nk.hpp | 1 + ...ce_grouped_gemm_wmma_fixed_nk_instance.hpp | 58 +++++++++---------- 3 files changed, 36 insertions(+), 34 deletions(-) diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp index cefebf74e63..b1da5f76f5a 100644 --- a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp @@ -56,13 +56,14 @@ static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecializa using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Fixed_Nk // clang-format off -//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| -//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| 
SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| -//######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| -//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; // clang-format on + struct ProblemSize final { std::vector Ms; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index e1b3c827f74..e2c2a3cc3f9 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -187,6 +187,7 @@ template , S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 
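// The hunk above (patch 15, "restoring NumGemmKPrefetchStage") reintroduces the
// GEMM-K prefetch depth as an explicit template parameter of the device op,
// which is why the instance declarations and comment tables that follow gain a
// leading "1" in the new "NumGemmK Prefetch Stage" column. A reduced,
// hypothetical sketch of such a parameter from the instance side follows;
// DeviceGemmSketch and its short parameter list are stand-ins for illustration,
// not CK's real signature.

#include <cstdint>

using index_t = std::int32_t; // assumption: a 32-bit index type, as in CK

template <index_t NumGemmKPrefetchStage, // the restored prefetch-depth knob
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock>
struct DeviceGemmSketch
{
    static constexpr index_t prefetch_stages = NumGemmKPrefetchStage;
};

// Corresponds to the leading "1" inserted before "256, 128, 128, ..." in each
// DeviceGroupedGemm_Wmma_Fixed_Nk instance below.
using Instance = DeviceGemmSketch</*NumGemmKPrefetchStage=*/1, 256, 128, 128>;
static_assert(Instance::prefetch_stages == 1, "prefetch depth is pinned per instance");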
1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + //#############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //#############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> // clang-format on >; @@ -81,13 +81,13 @@ template , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + //#############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| 
GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //#############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> // clang-format on >; @@ -102,13 +102,13 @@ template , S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| 
BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> // clang-format on >; @@ -123,14 +123,14 @@ template , S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, - DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, TA, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> - // clang-format on + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| 
Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on >; // List of instance variants to add (pipeline/scheduler/padding combinations) From e2ccce23d66b8ebf5e6c9e5913b3668788ebdca2 Mon Sep 17 00:00:00 2001 From: Marton Bidlek Date: Thu, 29 Jan 2026 09:18:20 +0000 Subject: [PATCH 16/16] clang format --- .../grouped_gemm_wmma_fixed_nk_fp16.cpp | 1 - .../device_grouped_gemm_wmma_fixed_nk.hpp | 17 +- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 17 +- ...ce_grouped_gemm_wmma_fixed_nk_instance.hpp | 26 +- .../gpu/grouped_gemm_fixed_nk.hpp | 18 +- ...ed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp | 20 +- ...ed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp | 20 +- ...ixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp | 21 +- ...ixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp | 21 +- ...fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp | 21 +- ...fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp | 21 +- ...fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp | 21 +- ...fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp | 21 +- ..._fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp | 21 +- ..._fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp | 21 +- .../profile_grouped_gemm_fixed_nk_impl.hpp | 6 +- .../test_grouped_gemm_fixed_nk.cpp | 32 ++- test/grouped_gemm/test_grouped_gemm_util.hpp | 235 +++++++++--------- 18 files changed, 269 insertions(+), 291 deletions(-) diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp index b1da5f76f5a..06bdaea48b0 100644 --- a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp @@ -63,7 +63,6 @@ using DeviceGemmInstance = 
ck::tensor_operation::device::DeviceGroupedGemm_Wmma_ < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; // clang-format on - struct ProblemSize final { std::vector Ms; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp index e2c2a3cc3f9..9fcba0aa2ac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -88,7 +88,7 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) if(M == 0 || N == 0 || K == 0) return; - const auto StrideE = gemmTransKernelArg.StrideE; + const auto StrideE = gemmTransKernelArg.StrideE; const index_t m_padded = GridwiseGemm::CalculateMPadded(M); const index_t n_padded = GridwiseGemm::CalculateNPadded(N); const auto e_grid_desc_m_n = @@ -589,7 +589,11 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK( - sum_of_m, sum_of_m_padded, gemm_desc_kernel_arg_[0].N, n_padded, gemm_desc_kernel_arg_[0].StrideE); + sum_of_m, + sum_of_m_padded, + gemm_desc_kernel_arg_[0].N, + n_padded, + gemm_desc_kernel_arg_[0].StrideE); const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, k_batch_}; @@ -610,11 +614,10 @@ struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK( - AverM, m_padded, N, n_padded, StrideE); + const index_t m_padded = GridwiseGemm::CalculateMPadded(AverM); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const auto e_grid_desc_m_n = GridwiseGemm::template MakeDEGridDescriptor_M_N( + AverM, m_padded, N, n_padded, StrideE); const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 6374f7b24ac..621a5ff0d32 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -632,7 +632,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base const index_t AK0) { // using GemmSpecialization = tensor_operation::device::GemmSpecialization; - constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding || + constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding || GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding; @@ -701,7 +701,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base const index_t BK0) { // using GemmSpecialization = tensor_operation::device::GemmSpecialization; - constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding || + constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding; @@ -797,7 +797,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base make_tuple(Sequence<0>{}, Sequence<1>{})); // TODO: Investigate why this path is not used in the original // gridwise_gemm_xdl_cshuffle_v3.hpp - #if 0 +#if 0 // using GemmSpecialization = 
tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::MNPadding || @@ -835,7 +835,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // not pad M or N return c_grid_desc_mraw_nraw; } - #endif +#endif } static constexpr auto MakeDsGridPointer() @@ -1094,11 +1094,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_base if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! " - "K_Batch:" << karg.KBatch << " " << - "K0PerBlock:" << KPerBlock << " " << - "K1:" << AK1Number << " " << - "K:" << karg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; + "K_Batch:" + << karg.KBatch << " " << "K0PerBlock:" << KPerBlock << " " + << "K1:" << AK1Number << " " << "K:" << karg.K << " " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } return false; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp index a13e7964b72..4bd639cd19e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp @@ -56,9 +56,8 @@ template = false> -using device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_instances = - std::tuple< - // clang-format off +using device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_instances = std::tuple< + // clang-format off //#############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| //#############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| @@ -66,8 +65,8 @@ using device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_instances = DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 
1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> - // clang-format on - >; + // clang-format on + >; // Instances for 2 byte datatypes in RCR layout with ADataType = BDataType = EDataType template = false> -using device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_instances = - std::tuple< - // clang-format off +using device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_instances = std::tuple< + // clang-format off //#############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| //#############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| //#############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| @@ -88,8 +86,8 @@ using device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_instances = DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> - // clang-format on - >; + // clang-format on + >; template + typename CDEElementOp> typename LayoutInstances, typename ADataType, typename BDataType, @@ -239,7 +237,8 @@ template typename LayoutInstances, + typename CDEElementOp> + typename LayoutInstances, typename AElementOp, typename BElementOp, typename CDEElementOp> @@ -259,7 +258,8 @@ void add_device_grouped_gemm_wmma_fixed_nk_irregular_instances( 
static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { constexpr auto instance = InstanceVariants[i]; add_device_operation_instances(instances, - LayoutInstances{}), instance.At(Number<1>{}), instance.At(Number<2>{}), diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp index 3418f0c69a3..6e1668ffa5e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp @@ -69,7 +69,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_f16_f8_f16_mk_nk_mn_instances( F16, PassThrough, PassThrough, - PassThrough>>>& instances); + PassThrough>>>& instances); // i8_inputB void add_device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instances( @@ -124,7 +124,7 @@ void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances( F16, PassThrough, PassThrough, - PassThrough>>>& instances); + PassThrough>>>& instances); void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instances( std::vector>>& instances); + PassThrough>>>& instances); #endif // bf16_inputA i8_inputB @@ -193,7 +193,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( BF16, PassThrough, PassThrough, - PassThrough>>>& instances); + PassThrough>>>& instances); void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( std::vector>>& instances); -#endif + PassThrough>>>& instances); +#endif #endif // bf16_inputA bf16_inputB @@ -253,7 +253,6 @@ void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( PassThrough, PassThrough>>>& instances); - void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( std::vector>>& instances); - void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( std::vector>>& instances) + Row, + DsLayout, + Row, + BF16, + BF16, + DsDataType, + BF16, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_fixed_nk_instances< BF16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp index 939f21a88e3..d49141ed681 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -13,16 +13,16 @@ namespace instance { void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( std::vector>>& instances) + Col, + DsLayout, + Row, + BF16, + BF16, + DsDataType, + BF16, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_fixed_nk_instances< BF16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp index cfde259a4d3..9356dac4fc6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp +++ 
b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -11,19 +11,18 @@ namespace tensor_operation { namespace device { namespace instance { - void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( std::vector>>& instances) + Row, + DsLayout, + Row, + BF16, + I8, + DsDataType, + BF16, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< BF16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp index 64c96b8d913..8ce1b7ac1e6 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp @@ -11,19 +11,18 @@ namespace tensor_operation { namespace device { namespace instance { - void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( std::vector>>& instances) + Col, + DsLayout, + Row, + BF16, + I8, + DsDataType, + BF16, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< BF16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp index c5f08ca1ee2..d59a908c7c4 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -11,19 +11,18 @@ namespace tensor_operation { namespace device { namespace instance { - void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances( std::vector>>& instances) + Row, + DsLayout, + Row, + F16, + F16, + DsDataType, + F16, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_fixed_nk_instances< F16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp index eae317f8f73..eeacb3b1f9d 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -11,19 +11,18 @@ namespace tensor_operation { namespace device { namespace instance { - void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances( std::vector>>& instances) + Col, + DsLayout, + Row, + F16, + F16, + DsDataType, + F16, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_fixed_nk_instances< F16, diff --git 
a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp index 43a990064ba..a128afe8932 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp @@ -11,19 +11,18 @@ namespace tensor_operation { namespace device { namespace instance { - void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instances( std::vector>>& instances) + Row, + DsLayout, + Row, + F16, + F8, + DsDataType, + F16, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< F16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp index 115cc95558f..5d943a0f340 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp @@ -11,19 +11,18 @@ namespace tensor_operation { namespace device { namespace instance { - void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instances( std::vector>>& instances) + Col, + DsLayout, + Row, + F16, + F8, + DsDataType, + F16, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< F16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp index 3bb479cafe9..f118fd75b8c 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp @@ -11,19 +11,18 @@ namespace tensor_operation { namespace device { namespace instance { - void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instances( std::vector>>& instances) + Row, + DsLayout, + Row, + F16, + I8, + DsDataType, + F16, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< F16, diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp index 3e40d594557..86a27be9866 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp @@ -11,19 +11,18 @@ namespace tensor_operation { namespace device { 
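// The instance translation units above and below all follow one registration
// idiom: a free function appends concrete device-op variants, as owning
// base-class pointers, to a caller-provided vector. A reduced, self-contained
// sketch of that idiom follows; DeviceOpBase and ConcreteOp are illustrative
// stand-ins, not CK types.

#include <memory>
#include <vector>

struct DeviceOpBase
{
    virtual ~DeviceOpBase() = default;
};

template <int MPerBlock, int NPerBlock>
struct ConcreteOp : DeviceOpBase
{
};

// Mirrors the add_device_grouped_gemm_wmma_fixed_nk_*_instances functions:
// each call appends one tuning variant per supported tile shape.
void add_instances(std::vector<std::unique_ptr<DeviceOpBase>>& instances)
{
    instances.push_back(std::make_unique<ConcreteOp<128, 128>>());
    instances.push_back(std::make_unique<ConcreteOp<128, 64>>());
}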
namespace instance { - void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instances( std::vector>>& instances) + Col, + DsLayout, + Row, + F16, + I8, + DsDataType, + F16, + PassThrough, + PassThrough, + PassThrough>>>& instances) { add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< F16, diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp index d4869591980..988b7d002d7 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp @@ -46,7 +46,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, int n_warmup = 1, int n_iter = 10) { - bool pass = true; + bool pass = true; using ComputeDataType = ADataType; auto f_host_tensor_descriptor = @@ -76,7 +76,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, std::vector> c_m_n_host_results; std::vector> c_m_n_device_results; double max_abs_in_val = 0.f; - int sum_of_m = 0; + int sum_of_m = 0; for(std::size_t i = 0; i < group_count; i++) { sum_of_m += Ms[i]; @@ -352,7 +352,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, } else { - std::cout << "Instance: " << gemm_name + std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem (KBatch: " << kbatch_curr << ")" << std::endl; } diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp index d9ee6797f11..f855ef1ff5d 100644 --- a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp @@ -23,12 +23,13 @@ using Row = ck::tensor_layout::gemm::RowMajor; using Col = ck::tensor_layout::gemm::ColumnMajor; template -class TestGroupedGemm : public ck::test::TestGroupedGemm +class TestGroupedGemm + : public ck::test::TestGroupedGemm { public: void SetUp() override { - ck::test::TestGroupedGemm::SetUp(); + ck::test::TestGroupedGemm::SetUp(); #if defined(CK_USE_WMMA) // The old XDL tests didn't fail if instances were not supported, so we want to keep that @@ -40,27 +41,24 @@ class TestGroupedGemm : public ck::test::TestGroupedGemm, - ck::Tuple< Row, Col, Row, F16, F8, F16>, +#if(defined(CK_USE_XDL) && (defined(__gfx9__) || defined(__gfx12__))) || \ + (defined(CK_USE_WMMA) && defined(__gfx12__)) + ck::Tuple, + ck::Tuple, #endif - ck::Tuple< Row, Row, Row, F16, F16, F16>, - ck::Tuple< Row, Col, Row, F16, F16, F16>, - + ck::Tuple, + ck::Tuple, - ck::Tuple< Row, Row, Row, BF16, BF16, BF16>, - ck::Tuple< Row, Col, Row, BF16, BF16, BF16>, - ck::Tuple< Row, Row, Row, BF16, I8, BF16>, - ck::Tuple< Row, Col, Row, BF16, I8, BF16>, + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, - ck::Tuple< Row, Row, Row, F16, I8, F16>, - ck::Tuple< Row, Col, Row, F16, I8, F16> - >; + ck::Tuple, + ck::Tuple>; // clang-format on TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes); diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index c4e0ed68439..e372af92f3a 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -24,20 +24,18 @@ extern ck::index_t instance_index; namespace ck { namespace test { - struct DefaultGroupedGemmProfiler { - template < - typename ADataType, - typename BDataType, - typename EDataType, - typename AccDataType, - typename ALayout, - typename BLayout, - typename ELayout, - typename AElementOp, - typename BElementOp, - typename CDEElementOp> + 
template static bool Run(bool verify, int init_method, bool log, @@ -64,84 +62,79 @@ struct DefaultGroupedGemmProfiler AElementOp, BElementOp, CDEElementOp>( - verify, - init_method, - log, - bench, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatches, - n_warmup, - n_iter, - instance_index, - fail_if_no_supported_instances); + verify, + init_method, + log, + bench, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup, + n_iter, + instance_index, + fail_if_no_supported_instances); } }; struct FixedNKGroupedGemmProfiler { - template < - typename ADataType, - typename BDataType, - typename EDataType, - typename AccDataType, - typename ALayout, - typename BLayout, - typename CLayout> - static bool Run( - bool verify, - int init_method, - bool log, - bool bench, - const std::vector& Ms, - const std::vector& Ns, - const std::vector& Ks, - const std::vector& StrideAs, - const std::vector& StrideBs, - const std::vector& StrideCs, - const std::vector& kbatches, - int n_warmup, - int n_iter, - int /*instance_index*/, - bool /*fail_if_no_supported_instances*/) + template + static bool Run(bool verify, + int init_method, + bool log, + bool bench, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + const std::vector& kbatches, + int n_warmup, + int n_iter, + int /*instance_index*/, + bool /*fail_if_no_supported_instances*/) { bool pass = true; for(int kbatch : kbatches) { - pass &= ck::profiler::profile_grouped_gemm_fixed_nk_impl< - ADataType, - BDataType, - EDataType, - AccDataType, - ALayout, - BLayout, - CLayout>( - verify, - init_method, - log, - bench, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatch, - n_warmup, - n_iter); + pass &= ck::profiler::profile_grouped_gemm_fixed_nk_impl(verify, + init_method, + log, + bench, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); } return pass; } }; - -template + typename Profiler = ck::test::DefaultGroupedGemmProfiler> class TestGroupedGemm : public testing::Test { protected: @@ -264,61 +257,59 @@ class TestGroupedGemm : public testing::Test const std::vector& StrideCs, const std::vector& kbatches) { - bool pass = false; + bool pass = false; using AccDataType = float; - if constexpr (std::is_same_v) + if constexpr(std::is_same_v) { pass = Profiler::template Run( - verify_, - init_method_, - log_, - bench_, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatches, - n_warmup_, - n_iter_, - instance_index, - fail_if_no_supported_instances_); + BDataType, + EDataType, + AccDataType, + ALayout, + BLayout, + ELayout>(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup_, + n_iter_, + instance_index, + fail_if_no_supported_instances_); } else { pass = Profiler::template Run( - verify_, - init_method_, - log_, - bench_, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatches, - n_warmup_, - n_iter_, - instance_index, - fail_if_no_supported_instances_); + BDataType, + EDataType, + AccDataType, + ALayout, + BLayout, + ELayout, + AElementOp, + BElementOp, + CDEElementOp>(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup_, + n_iter_, + instance_index, + fail_if_no_supported_instances_); } EXPECT_TRUE(pass);
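The RunSingle refactor above keeps one compile-time branch: the fixed-NK profiler exposes a narrower Run signature (no element-op types) than the default profiler, so the fixture must select the call shape with if constexpr. A condensed sketch of that dispatch follows, using stand-in profiler types; the real entry points take many more parameters.

#include <type_traits>
#include <vector>

struct PassThrough {}; // stand-in element-wise op tag

struct DefaultProfilerSketch
{
    template <typename ADataType, typename CDEElementOp>
    static bool Run(const std::vector<int>& kbatches) { return !kbatches.empty(); }
};

struct FixedNKProfilerSketch
{
    template <typename ADataType>
    static bool Run(const std::vector<int>& kbatches)
    {
        bool pass = true;
        for(int kbatch : kbatches) // fixed-NK profiling is driven once per k-batch value
            pass &= (kbatch >= 1);
        return pass;
    }
};

// Mirrors RunSingle: the template-argument lists differ per profiler, so the
// branch has to be resolved at compile time rather than at runtime.
template <typename Profiler>
bool RunSingleSketch(const std::vector<int>& kbatches)
{
    if constexpr(std::is_same_v<Profiler, FixedNKProfilerSketch>)
        return Profiler::template Run<float>(kbatches);
    else
        return Profiler::template Run<float, PassThrough>(kbatches);
}

Called as RunSingleSketch<FixedNKProfilerSketch>({1, 2, 4}), the fixed-NK branch runs once per k-batch value, matching how FixedNKGroupedGemmProfiler::Run loops over kbatches in the hunk above.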