diff --git a/example/15_grouped_gemm/CMakeLists.txt b/example/15_grouped_gemm/CMakeLists.txt index a7dae9dcd82..6b75ff76ced 100644 --- a/example/15_grouped_gemm/CMakeLists.txt +++ b/example/15_grouped_gemm/CMakeLists.txt @@ -47,6 +47,10 @@ add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_spl add_example_executable(example_grouped_gemm_multiple_d_wmma_fp16 grouped_gemm_multiple_d_wmma_fp16.cpp) add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_multiple_d_wmma_fp16) +add_example_executable(example_grouped_gemm_wmma_fixed_nk_fp16 grouped_gemm_wmma_fixed_nk_fp16.cpp) +add_example_dependencies(example_grouped_gemm_wmma example_grouped_gemm_wmma_fixed_nk_fp16) + + list(APPEND gpu_list_tf32 gfx942 gfx950) set(target 0) foreach(gpu IN LISTS GPU_TARGETS) diff --git a/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp new file mode 100644 index 00000000000..06bdaea48b0 --- /dev/null +++ b/example/15_grouped_gemm/grouped_gemm_wmma_fixed_nk_fp16.cpp @@ -0,0 +1,382 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/utility/check_err.hpp" +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" + +using ::ck::DeviceMem; +using ::ck::hip_check_error; +using ::ck::HostTensorDescriptor; +using ::ck::Tensor; + +template +using S = ck::Sequence; + +using F16 = ck::half_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using ADataType = F16; +using BDataType = F16; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = F16; + +using ALayout = Row; +using BLayout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding; + +using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemm_Wmma_Fixed_Nk + // clang-format off +//######| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +//######| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| 
SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| +//######| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| +//######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + < ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementOp, BElementOp, CDEElementOp, GemmDefault, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>; +// clang-format on + +struct ProblemSize final +{ + std::vector Ms; + std::vector Ns; + std::vector Ks; + + std::vector stride_As; + std::vector stride_Bs; + std::vector stride_Cs; + + ck::index_t group_count; +}; + +struct ExecutionConfig final +{ + bool do_verification = true; + int init_method = 1; + int k_batch = 1; + bool time_kernel = false; +}; + +bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& config) +{ + auto group_count = problem_size.group_count; + + // GEMM shape + std::vector gemm_descs; + std::vector p_Cs; + + gemm_descs.reserve(group_count); + + int sum_of_m = 0; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + std::vector> a_tensors; + std::vector> b_tensors; + std::vector> c_host_tensors; + std::vector> c_device_tensors; + + a_tensors.reserve(group_count); + b_tensors.reserve(group_count); + c_host_tensors.reserve(group_count); + c_device_tensors.reserve(group_count); + + using DeviceMemPtr = std::unique_ptr; + + std::vector a_tensors_device, b_tensors_device, c_tensors_device; + + a_tensors_device.reserve(group_count); + b_tensors_device.reserve(group_count); + c_tensors_device.reserve(group_count); + + std::size_t flop = 0, num_btype = 0; + + for(int i = 0; i < group_count; i++) + { + sum_of_m += problem_size.Ms[i]; + a_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ks[i], problem_size.stride_As[i], ALayout{}))); + b_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ks[i], problem_size.Ns[i], problem_size.stride_Bs[i], BLayout{}))); + c_host_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + c_device_tensors.push_back(Tensor(f_host_tensor_descriptor( + problem_size.Ms[i], problem_size.Ns[i], problem_size.stride_Cs[i], ELayout{}))); + std::cout << "gemm[" << i << "] a_m_k: " << a_tensors[i].mDesc + << " b_k_n: " << b_tensors[i].mDesc << " c_m_n: " << c_device_tensors[i].mDesc + << std::endl; + + flop += std::size_t(2) * problem_size.Ms[i] * problem_size.Ks[i] * problem_size.Ns[i]; + num_btype += sizeof(ADataType) * a_tensors[i].mDesc.GetElementSize() + + sizeof(BDataType) * b_tensors[i].mDesc.GetElementSize() + + sizeof(EDataType) * c_device_tensors[i].mDesc.GetElementSize(); + + switch(config.init_method) + { + case 0: break; + case 1: + 
a_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}); + break; + case 2: + a_tensors[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + break; + default: + a_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + b_tensors[i].GenerateTensorValue(GeneratorTensor_Sequential{}); + } + } + + std::cout << "Sum of M: " << sum_of_m << std::endl; + + using GroupedGemmKernelArgument = ck::tensor_operation::device::GroupedGemmKernelArgument<>; + + std::vector grouped_gemm_kernel_args_; + grouped_gemm_kernel_args_.reserve(group_count); + + for(int i = 0; i < group_count; i++) + { + a_tensors_device.emplace_back(std::make_unique( + sizeof(ADataType) * problem_size.Ms[i] * problem_size.Ks[i])); + + b_tensors_device.emplace_back(std::make_unique( + sizeof(BDataType) * problem_size.Ns[i] * problem_size.Ks[i])); + + c_tensors_device.emplace_back(std::make_unique( + sizeof(EDataType) * problem_size.Ms[i] * problem_size.Ns[i])); + + a_tensors_device[i]->ToDevice(a_tensors[i].mData.data(), + a_tensors[i].mDesc.GetElementSpaceSize() * sizeof(ADataType)); + b_tensors_device[i]->ToDevice(b_tensors[i].mData.data(), + b_tensors[i].mDesc.GetElementSpaceSize() * sizeof(BDataType)); + c_tensors_device[i]->SetZero(); + + p_Cs.push_back(c_tensors_device[i]->GetDeviceBuffer()); + + gemm_descs.push_back({sum_of_m, + problem_size.Ns[i], + problem_size.Ks[i], + 1, + problem_size.stride_Bs[i], + 1, + {}}); + + grouped_gemm_kernel_args_.push_back({a_tensors_device[i]->GetDeviceBuffer(), + b_tensors_device[i]->GetDeviceBuffer(), + {}, + c_tensors_device[i]->GetDeviceBuffer(), + problem_size.Ms[i], + problem_size.Ns[i], + problem_size.Ks[i], + problem_size.stride_As[i], + problem_size.stride_Bs[i], + {}, + problem_size.stride_Cs[i]}); + } + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto c_element_op = CDEElementOp{}; + + auto gemm = DeviceGemmInstance{}; + auto invoker = gemm.MakeInvoker(); + + std::vector p_As = {}; + std::vector p_Bs = {}; + std::vector> p_Ds = {}; + + // do GEMM + auto argument = gemm.MakeArgument( + p_As, p_Bs, p_Ds, p_Cs, gemm_descs, a_element_op, b_element_op, c_element_op); + + DeviceMem gemm_arg_dev_mem(gemm.GetDeviceKernelArgSize(&argument)); + DeviceMem gemm_workspace_dev(gemm.GetWorkSpaceSize(&argument)); + + gemm.SetWorkSpacePointer(&argument, gemm_workspace_dev.GetDeviceBuffer()); + + hip_check_error(hipMemcpy(gemm_arg_dev_mem.GetDeviceBuffer(), + grouped_gemm_kernel_args_.data(), + gemm.GetDeviceKernelArgSize(&argument), + hipMemcpyHostToDevice)); + + if(!gemm.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! 
device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + gemm.SetDeviceKernelArgs(&argument, gemm_arg_dev_mem.GetDeviceBuffer()); + gemm.SetKBatch(argument, config.k_batch); + + invoker.Run(argument, StreamConfig{nullptr, false}); + + if(config.time_kernel) + { + float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel}); + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec + << " GB/s, " << gemm.GetTypeString() << std::endl; + } + + bool pass = true; + if(config.do_verification) + { + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = ref_gemm.MakeArgument(a_tensors[i], + b_tensors[i], + c_host_tensors[i], + a_element_op, + b_element_op, + c_element_op); + + ref_invoker.Run(ref_argument); + + pass &= ck::utils::check_err(c_device_tensors[i], c_host_tensors[i]); + } + // Copy device tensors back to host + for(std::size_t i = 0; i < c_device_tensors.size(); i++) + { + c_tensors_device[i]->FromDevice(c_device_tensors[i].mData.data(), + c_device_tensors[i].mDesc.GetElementSize() * + sizeof(EDataType)); + } + // // Print out device and reference results for debugging + // std::cout << "[CK GEMM RESULT TRACE]\n"; + // for(std::size_t i = 0; i < c_device_tensors.size(); i++) + // { + // std::cout << "GEMM[" << i << "] device C:\n"; + // auto& devC = c_device_tensors[i].mData; + // for(std::size_t m = 0; m < static_cast(problem_size.Ms[i]); m++) + // { + // for(std::size_t n = 0; n < static_cast(problem_size.Ns[i]); n++) + // { + // std::cout << static_cast(devC[m * problem_size.Ns[i] + n]) << " "; + // } + // std::cout << "\n"; + // } + + // std::cout << "GEMM[" << i << "] reference C:\n"; + // auto& refC = c_host_tensors[i].mData; + // for(std::size_t m = 0; m < static_cast(problem_size.Ms[i]); m++) + // { + // for(std::size_t n = 0; n < static_cast(problem_size.Ns[i]); n++) + // { + // std::cout << static_cast(refC[m * problem_size.Ns[i] + n]) << " "; + // } + // std::cout << "\n"; + // } + // std::cout << "--------------------------------------\n"; + // } + } + + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" 
<< std::endl; + return pass; +} + +int main(int argc, char* argv[]) +{ + ProblemSize problem_size; + ExecutionConfig config; + + problem_size.group_count = 16; + + if(argc == 5) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + } + else if(argc == 6) + { + config.do_verification = std::stoi(argv[1]); + config.init_method = std::stoi(argv[2]); + config.time_kernel = std::stoi(argv[3]); + config.k_batch = std::stoi(argv[4]); + problem_size.group_count = std::stoi(argv[5]); + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4: k_batch (> 0)\n"); + printf("arg5: group count (default=16)\n"); + + exit(0); + } + + // The fixed-NK grouped GEMM requires every group to share the same N and K; only M varies per group. + for(int i = 0; i < problem_size.group_count; i++) + { + // problem_size.Ms.push_back(256); + // problem_size.Ns.push_back(256); + // problem_size.Ks.push_back(256); + problem_size.Ms.push_back(128 + i * 128); + problem_size.Ns.push_back(1024); + problem_size.Ks.push_back(1024); + + problem_size.stride_As.push_back(problem_size.Ks[i]); + problem_size.stride_Bs.push_back(problem_size.Ks[i]); + problem_size.stride_Cs.push_back(problem_size.Ns[i]); + } + + return !run_grouped_gemm(problem_size, config); +} diff --git a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp index 61f03907b72..d8759eb5461 100644 --- a/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp +++ b/example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_fp16.cpp @@ -291,6 +291,7 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co } } + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; return pass; } @@ -329,9 +330,9 @@ int main(int argc, char* argv[]) for(int i = 0; i < problem_size.group_count; i++) { - problem_size.Ms.push_back(128 + rand() % 128); - problem_size.Ns.push_back(1024); - problem_size.Ks.push_back(1024); + problem_size.Ms.push_back(256); + problem_size.Ns.push_back(256); + problem_size.Ks.push_back(256); problem_size.stride_As.push_back(problem_size.Ks[i]); problem_size.stride_Bs.push_back(problem_size.Ks[i]); diff --git a/example/15_grouped_gemm/run_grouped_gemm_example.inc b/example/15_grouped_gemm/run_grouped_gemm_example.inc index ffd0c5e9b7b..72b60d6beb1 100644 --- a/example/15_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/15_grouped_gemm/run_grouped_gemm_example.inc @@ -272,6 +272,8 @@ bool run_grouped_gemm(const ProblemSize& problem_size, const ExecutionConfig& co ComputeDataType>(c_device_tensors[i], c_host_tensors[i]); #endif } + + std::cout << "Verification: " << (pass ? "SUCCESS" : "FAILURE") << "!" << std::endl; } if(config.time_kernel) diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp new file mode 100644 index 00000000000..9fcba0aa2ac --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp @@ -0,0 +1,1128 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/env.hpp" +#include "ck/host_utility/hip_check_error.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/scheduler_enum.hpp" +#include "ck/utility/tuple.hpp" + +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_gemm_wmma_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const, + const index_t group_count, + const index_t grid_size_grp, + const index_t k_batch_, + const AElementwiseOperation a_element_op, + const BElementwiseOperation b_element_op, + const CDEElementwiseOperation c_element_op) + +{ +#if(defined(__gfx11__) || defined(__gfx12__)) + + using EpilogueType = typename std::conditional::type; + + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte(); + __shared__ char p_shared[LDS_size]; + + const index_t block_id = get_block_1d_id(); + const auto gemm_desc_ptr = + reinterpret_cast(cast_pointer_to_generic_address_space(gemm_descs_const)); + + const index_t group_id = block_id / grid_size_grp; + if(group_id >= group_count) + return; + const index_t group_start = group_id * grid_size_grp; + + auto gemmTransKernelArg = gemm_desc_ptr[group_id]; + + const index_t M = gemmTransKernelArg.M; + const index_t N = gemmTransKernelArg.N; + const index_t K = gemmTransKernelArg.K; + + if(M == 0 || N == 0 || K == 0) + return; + + const auto StrideE = gemmTransKernelArg.StrideE; + const index_t m_padded = GridwiseGemm::CalculateMPadded(M); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N(M, m_padded, N, n_padded, StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + const auto local_grid_size = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using c_data_type = remove_cvref_t>; + if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + using KernelArgument = typename GridwiseGemm::Argument; + + index_t id_off = 0; + index_t id_local = get_block_1d_id() - group_start; + + while(id_local < local_grid_size) + { + KernelArgument kernel_arg{std::array{gemmTransKernelArg.p_a_grid}, + std::array{gemmTransKernelArg.p_b_grid}, + gemmTransKernelArg.p_ds_grid, + type_convert(gemmTransKernelArg.p_e_grid), + gemmTransKernelArg.M, + gemmTransKernelArg.N, + gemmTransKernelArg.K, + std::array{gemmTransKernelArg.StrideA}, + std::array{gemmTransKernelArg.StrideB}, + gemmTransKernelArg.StrideDs, + gemmTransKernelArg.StrideE, + k_batch_, + a_element_op, + b_element_op, + c_element_op, + false}; + + const auto block_2_etile_map = + 
GroupedGemmBlock2ETileMap(local_b2c_tile_map, group_start, id_off); + + auto tile_index = + block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + auto splitk_batch_offset = + typename GridwiseGemm::SplitKBatchOffset(kernel_arg, tile_index[Number<0>{}]); + + auto epilogue_args = EpilogueType{}; + + GridwiseGemm::template Run(static_cast(p_shared), + splitk_batch_offset, + kernel_arg, + block_2_etile_map, + epilogue_args); + + id_off += grid_size_grp; + id_local += grid_size_grp; + } + +#undef TRACE_THREAD +#if defined(__gfx11__) + } +#endif +#else + ignore = gemm_descs_const; + ignore = group_count; + ignore = grid_size_grp; + ignore = k_batch_; + ignore = a_element_op; + ignore = b_element_op; + ignore = c_element_op; +#endif +} + +template +struct DeviceGroupedGemm_Wmma_Fixed_Nk : public DeviceGroupedGemmFixedNK +{ + using DeviceOp = DeviceGroupedGemm_Wmma_Fixed_Nk; + + static constexpr index_t NumDTensor = DsDataType::Size(); + + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3< + ALayout, + BLayout, + DsLayout, + ELayout, + Tuple, + Tuple, + AccDataType, + CShuffleDataType, + DsDataType, + EDataType, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + Sequence, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + false, + false>; + + using CGridDesc_M_N = + remove_cvref_t( + 1, 1, 1, 1, 1))>; + + template + struct OffsettedBlockToCTileMapMLoops + { + using underlying_type = UnderlyingBlockToCTileMap; + + __host__ __device__ OffsettedBlockToCTileMapMLoops( + UnderlyingBlockToCTileMap block_to_ctile_map, index_t block_start, index_t id_off = 0) + { + block_to_ctile_map_ = block_to_ctile_map; + block_start_ = block_start; + id_off_ = id_off; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto idx_bot = block_to_ctile_map_.CalculateBottomIndex( + make_multi_index(idx_top[Number<0>{}] - block_start_ + id_off_)); + + return make_tuple(idx_bot[Number<0>{}], idx_bot[Number<1>{}], idx_bot[Number<2>{}]); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx, + const CTileDim& c_tile_dim) const + { + return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n); + } + + template + __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n); + } + + UnderlyingBlockToCTileMap block_to_ctile_map_; + index_t block_start_; + index_t id_off_; + }; + + template + struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops + { + static 
constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default; + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops& + operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default; + + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(index_t M, + index_t N, + index_t KBatch, + index_t M01 = 8) + : M_(M), N_(N), KBatch_(KBatch), M01_(M01) + { + } + + template + __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8) + : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops( + c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01) + { + } + + __host__ __device__ constexpr index_t CalculateGridSize(index_t M, index_t N) const + { + const auto M0 = math::integer_divide_ceil(M, MPerBlock); + const auto N0 = math::integer_divide_ceil(N, NPerBlock); + + return M0 * N0 * KBatch_; + } + + template + __host__ __device__ constexpr index_t + CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const + { + return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); + } + + template + __host__ bool CheckValidity(const CGridDesc_M_N&) const + { + return true; + } + + template + __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const + { + auto block_1d_id = idx_top[I0]; + + const auto M0 = math::integer_divide_ceil(M_, MPerBlock_); + const auto N0 = math::integer_divide_ceil(N_, NPerBlock_); + + const auto total_tiles_per_group = M0 * N0 * KBatch_; + + // wrap block id into this group + block_1d_id = block_1d_id % total_tiles_per_group; + + const index_t idx_ksplit = block_1d_id / (M0 * N0); + block_1d_id = block_1d_id % (M0 * N0); + + index_t idx_N0 = block_1d_id % N0; + index_t idx_M0 = block_1d_id / N0; + + const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? 
M01_ : M0 % M01_; + + index_t idx_M00 = idx_M0 / M01_; + index_t idx_M01 = idx_M0 % M01_; + index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0; + + return make_tuple(idx_ksplit, + idx_N0_M01_local % M01_adapt + idx_M00 * M01_, + idx_N0_M01_local / M01_adapt); + } + + template + __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */, + const CTileDim& /* c_tile_dim */) const + { + return true; // always valid provided that user gets grid size from CalculateGridSize() + } + + private: + index_t M_; + index_t N_; + index_t KBatch_; + index_t M01_; + }; + + using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops; + using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops; + + static constexpr index_t DefaultKBatch = 1; + using KernelArgument = typename GridwiseGemm::Argument; + + using GemmTransKernelArg = GroupedGemmKernelArgument; + + static constexpr bool CalculateHasMainKBlockLoop(const KernelArgument& karg) + { + index_t k_grain = karg.KBatch * KPerBlock; + index_t K_split = (karg.K + k_grain - 1) / karg.KBatch; + return GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + } + + struct Argument : public BaseArgument + { + Argument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + : Argument(p_As, + p_Bs, + p_Ds, + p_Es, + gemm_descs, + a_element_op, + b_element_op, + c_element_op, + DefaultKBatch) + { + // TODO: use occupancy api to calculate appropriate batch size. + } + + Argument(std::vector&, + std::vector&, + std::vector>&, + std::vector&, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op, + index_t kbatch) + : a_element_op_{a_element_op}, b_element_op_{b_element_op}, c_element_op_{c_element_op} + { + grid_size_ = 0; + + k_batch_ = kbatch; + + grouped_gemm_kernel_args_dev = nullptr; + + group_count_ = ck::type_convert(gemm_descs.size()); + + gemm_desc_kernel_arg_.reserve(group_count_); + + index_t group_id = 0; + + sum_of_m = gemm_descs[0].M_; + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + const index_t N = gemm_descs[0].N_; + const index_t K = gemm_descs[0].K_; + + for(std::size_t i = 0; i < gemm_descs.size(); i++) + { + if(sum_of_m != gemm_descs[i].M_ || N != gemm_descs[i].N_ || K != gemm_descs[i].K_) + { + throw std::runtime_error("wrong! M/N/K is not identical"); + } + + a_mtx_mraw_kraw_.emplace_back(sum_of_m, K); + b_mtx_nraw_kraw_.emplace_back(N, K); + + const index_t StrideA = gemm_descs[i].stride_A_; + const index_t StrideB = gemm_descs[i].stride_B_; + const index_t StrideE = gemm_descs[i].stride_C_; + + // pointer + std::array p_ds_grid; + std::array StrideDs; + + static_for<0, NumDTensor, 1>{}([&](auto j) { + // using DLayout = remove_cvref_t>; + + if(gemm_descs[i].stride_Ds_.size() != NumDTensor) + { + throw std::runtime_error( + "wrong! 
gemm_descs[i].stride_Ds_.size() does not match NumDTensor"); + } + + p_ds_grid[j] = nullptr; + StrideDs[j] = gemm_descs[i].stride_Ds_[j]; + }); + + const index_t m_padded = GridwiseGemm::CalculateMPadded(AverM); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const auto e_grid_desc_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + AverM, m_padded, N, n_padded, StrideE); + + // block-to-e-tile map + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + if(group_id * grid_size_grp_ != grid_size_) + { + throw std::runtime_error("wrong! grid_size_grp_ is not identical!"); + } + + grid_size_ += grid_size_grp_; + + if(!local_b2c_tile_map.CheckValidity(e_grid_desc_m_n)) + { + throw std::runtime_error("wrong! block_2_etile_map validation failed"); + } + + // if(!GridwiseGemm::CheckValidity(arg)) + // { + // std::ostringstream err; + // err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__ + // << ":" << __LINE__ << ", in function: " << __func__; + // throw std::runtime_error(err.str()); + // } + + gemm_desc_kernel_arg_.push_back(KernelArgument(std::array{nullptr}, + std::array{nullptr}, + p_ds_grid, + nullptr, + AverM, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideE, + k_batch_, + a_element_op, + b_element_op, + c_element_op, + false)); + + group_id++; + } + const index_t sum_of_m_padded = GridwiseGemm::CalculateMPadded(sum_of_m); + const index_t n_padded = GridwiseGemm::CalculateNPadded(gemm_desc_kernel_arg_[0].N); + const auto e_grid_desc_sum_m_n = + GridwiseGemm::template MakeDEGridDescriptor_M_N( + sum_of_m, + sum_of_m_padded, + gemm_desc_kernel_arg_[0].N, + n_padded, + gemm_desc_kernel_arg_[0].StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_sum_m_n, k_batch_}; + + barrier_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_sum_m_n); + } + + void UpdateKBatch(index_t k_batch) + { + k_batch_ = k_batch; + + if(k_batch_ < 1) + { + throw std::runtime_error("wrong! 
k_batch must be > 0"); + } + + const index_t AverM = math::integer_divide_ceil(sum_of_m, group_count_); + + const index_t StrideE = gemm_desc_kernel_arg_[0].StrideE; + const index_t N = gemm_desc_kernel_arg_[0].N; + + const index_t m_padded = GridwiseGemm::CalculateMPadded(AverM); + const index_t n_padded = GridwiseGemm::CalculateNPadded(N); + const auto e_grid_desc_m_n = GridwiseGemm::template MakeDEGridDescriptor_M_N( + AverM, m_padded, N, n_padded, StrideE); + + const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch_}; + + grid_size_grp_ = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n); + + grid_size_ = grid_size_grp_ * group_count_; + + for(std::size_t i = 0; i < gemm_desc_kernel_arg_.size(); i++) + { + gemm_desc_kernel_arg_[i].KBatch = k_batch_; + } + } + + // private: + index_t group_count_; + + AElementwiseOperation a_element_op_; + BElementwiseOperation b_element_op_; + CDEElementwiseOperation c_element_op_; + + std::vector gemm_desc_kernel_arg_; + std::vector> a_mtx_mraw_kraw_; + std::vector> b_mtx_nraw_kraw_; + + const void* grouped_gemm_kernel_args_dev; + + index_t grid_size_; + index_t grid_size_grp_; + index_t barrier_size_grp_; + index_t sum_of_m; + + index_t k_batch_; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + template + float RunImp(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + bool all_have_kbatch_gt_one = arg.gemm_desc_kernel_arg_[0].KBatch > 1; + bool all_have_main_k0_block_loop = + CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[0]); + + bool not_all_have_main_k0_block_loop_same = false; + bool not_all_have_kbatch_value_same = false; + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++) + { + + not_all_have_main_k0_block_loop_same |= + all_have_main_k0_block_loop xor + CalculateHasMainKBlockLoop(arg.gemm_desc_kernel_arg_[i]); + not_all_have_kbatch_value_same |= + all_have_kbatch_gt_one xor (arg.gemm_desc_kernel_arg_[i].KBatch > 1); + } + + if(not_all_have_main_k0_block_loop_same) + { + std::ostringstream err; + err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + // throw std::runtime_error(err.str()); + } + + if(not_all_have_kbatch_value_same) + { + std::ostringstream err; + err << "Not all gemms have same kbatch value (=1 or >1)! " << " in " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__; + throw std::runtime_error(err.str()); + } + + if(arg.grouped_gemm_kernel_args_dev == nullptr) + { + throw std::runtime_error("wrong! 
grouped_gemm_kernel_args_dev is nullpr"); + } + + float ave_time = 0; + + auto launch_kernel = [&](auto has_main_k_block_loop_, + auto e_global_memory_operation_, + auto min_occupancy_, + auto tail_num_) { + const auto kernel = kernel_grouped_gemm_wmma_fixed_nk, + Tuple, + DsDataType, + EDataType, + e_global_memory_operation_, + Block2ETileMap, + GroupedGemmBlock2ETileMap, + AElementwiseOperation, + BElementwiseOperation, + CDEElementwiseOperation, + min_occupancy_, + tail_num_, + GemmSpec>; + + return launch_and_time_kernel( + stream_config, + kernel, + dim3(arg.grid_size_), + dim3(BlockSize), + 0, + cast_pointer_to_constant_address_space(arg.grouped_gemm_kernel_args_dev), + arg.gemm_desc_kernel_arg_.size(), + arg.grid_size_grp_, + arg.k_batch_, + arg.a_element_op_, + arg.b_element_op_, + arg.c_element_op_); + }; + + // const auto tail_num = + // GridwiseGemm::CalculateKBlockLoopTailNum(arg.gemm_desc_kernel_arg_[0].K); + const auto tail_num = TailNumber::Full; + constexpr index_t min_occupancy = 1; + + if(all_have_main_k0_block_loop || not_all_have_main_k0_block_loop_same) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + { + if(all_have_kbatch_gt_one) + { + + SelectTailNumber(tail_num, [&](auto tail_num_ct) { + ave_time = launch_kernel( + std::integral_constant{}, + std::integral_constant{}, + std::integral_constant{}, + tail_num_ct); + }); + } + else + { + SelectTailNumber(tail_num, [&](auto tail_num_ct) { + ave_time = launch_kernel( + std::integral_constant{}, + std::integral_constant{}, + std::integral_constant{}, + tail_num_ct); + }); + } + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(all_have_kbatch_gt_one) + { + SelectTailNumber(tail_num, [&](auto tail_num_ct) { + ave_time = launch_kernel( + std::integral_constant{}, + std::integral_constant{}, + std::integral_constant{}, + tail_num_ct); + }); + } + else + { + SelectTailNumber(tail_num, [&](auto tail_num_ct) { + ave_time = launch_kernel( + std::integral_constant{}, + std::integral_constant{}, + std::integral_constant{}, + tail_num_ct); + }); + } + } + } + + // if constexpr(std::is_same::value) + // { + // SelectTailNumber(tail_num, [&](auto tail_num_ct) { + // ave_time = launch_kernel( + // std::integral_constant{}, + // std::integral_constant{}, std::integral_constant{}, tail_num_ct); + // }); + // } + // else + // { + // if(arg.k_batch_ > 1) + // { + // SelectTailNumber(tail_num, [&](auto tail_num_ct) { + // ave_time = launch_kernel( + // std::integral_constant{}, + // std::integral_constant{}, + // std::integral_constant{}, + // tail_num_ct); + // }); + // } + // else + // { + // SelectTailNumber(tail_num, [&](auto tail_num_ct) { + // ave_time = launch_kernel( + // std::integral_constant{}, + // std::integral_constant{}, std::integral_constant{}, tail_num_ct); + // }); + // } + // } + return ave_time; + } + + template + void SelectTailNumber(TailNumber tail_num, Lambda&& lambda) + { + ignore = tail_num; + lambda(std::integral_constant{}); + // switch(tail_num) + // { + // case TailNumber::Full: lambda(std::integral_constant{}); break; case TailNumber::Empty: + // lambda(std::integral_constant{}); break; case + // TailNumber::One: lambda(std::integral_constant{}); break; case TailNumber::Two: + // lambda(std::integral_constant{}); break; case + // TailNumber::Three: lambda(std::integral_constant{}); break; case TailNumber::Four: + // lambda(std::integral_constant{}); break; case + // TailNumber::Five: 
lambda(std::integral_constant{}); break; case TailNumber::Six: + // lambda(std::integral_constant{}); break; case + // TailNumber::Seven: lambda(std::integral_constant{}); break; case TailNumber::Odd: + // lambda(std::integral_constant{}); break; case + // TailNumber::Even: lambda(std::integral_constant{}); break; default: lambda(std::integral_constant{}); break;; + // } + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + return RunImp(arg, stream_config); + } + + // polymorphic + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported()) + { + return false; + } + if constexpr(std::is_same_v || + std::is_same_v) + { + if(arg.k_batch_ > 1 && ck::is_gfx11_supported()) + { + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + return false; + } + } + + if constexpr(!std::is_same_v) + { + if(arg.k_batch_ > 1) + { + // Using SplitK and a C element op would require a two stage kernel where the second + // stage applies the op on the accumulated results + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "C element operators are not supported when using SplitK. Set " + "K_BATCH to 1 or remove the operator." + << std::endl; + } + return false; + } + } + + if constexpr(std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) + { + if(ck::is_gfx11_supported()) + { + return false; + } + } + + if((ck::type_convert(arg.gemm_desc_kernel_arg_.size())) != arg.group_count_) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "The group count is not equal to sum of skipped groups " + "and kernel args size!" + << std::endl; + } + return false; + } + + bool supported = true; + for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); ++i) + { + + const auto& a = arg.gemm_desc_kernel_arg_[i]; + bool group_arg_valid = GridwiseGemm::CheckValidity(a); + + if(not group_arg_valid) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "[" << __func__ << "] group id: " << i + << " has invalid GridwiseGemm settings!" 
<< std::endl; + a.Print(); + } + } + supported = supported && group_arg_valid; + } + return supported; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) + { + return Argument{ + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(std::vector& p_As, + std::vector& p_Bs, + std::vector>& p_Ds, + std::vector& p_Es, + std::vector& gemm_descs, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation c_element_op) override + { + return std::make_unique( + p_As, p_Bs, p_Ds, p_Es, gemm_descs, a_element_op, b_element_op, c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGroupedGemm_Wmma_Fixed_Nk" + << "<" + << std::string(ALayout::name)[0] << "," + << std::string(BLayout::name)[0] << "," + << std::string(ELayout::name)[0] << "," + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << MPerWmma << ", " + << NPerWmma << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle << ", " + << getGemmSpecializationString(GemmSpec) << ", " + << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", " + << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] + << ">"; + // clang-format on + + return str.str(); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& stream_config = StreamConfig{}) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->p_workspace_ = p_workspace; + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_NK::Argument structure!"); + + hip_check_error( + hipMemsetAsync(p_workspace, 0, GetWorkSpaceSize(arg_ptr), stream_config.stream_id_)); + } + + void SetDeviceKernelArgs(BaseArgument* p_arg, void* kernel_args) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->grouped_gemm_kernel_args_dev = kernel_args; + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_NK::Argument structure!"); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + return arg_ptr->group_count_ * arg_ptr->barrier_size_grp_ * sizeof(uint32_t); + } + else + throw 
std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_NK::Argument structure!"); + } + + size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + return arg_ptr->group_count_ * sizeof(GroupedGemmKernelArgument); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_NK::Argument structure!"); + } + + static void SetKBatch(Argument& arg, index_t k_batch) { arg.UpdateKBatch(k_batch); } + + // polymorphic + void SetKBatch(BaseArgument* p_arg, index_t k_batch) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->UpdateKBatch(k_batch); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_NK::Argument structure!"); + } + + // polymorphic + void SetKBatchSize(BaseArgument* p_arg, index_t k_batch) const override + { + auto arg_ptr = dynamic_cast(p_arg); + if(arg_ptr) + { + arg_ptr->UpdateKBatch(k_batch); + } + else + throw std::runtime_error("The argument pointer is not an object of " + "DeviceGroupedGemm_Wmma_Fixed_Nk::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index bcf131003c2..621a5ff0d32 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -344,6 +344,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base using ThisThreadBlock = ThisThreadBlock; + using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization; + static constexpr index_t APackedSize = []() { if constexpr(is_same_v, pk_i4_t>) return 2; @@ -629,8 +631,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base const std::array& StrideAs, const index_t AK0) { - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding || + // using GemmSpecialization = tensor_operation::device::GemmSpecialization; + constexpr bool padM = GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding || GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding; @@ -698,8 +700,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base const std::array& StrideBs, const index_t BK0) { - using GemmSpecialization = tensor_operation::device::GemmSpecialization; - constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding || + // using GemmSpecialization = tensor_operation::device::GemmSpecialization; + constexpr bool padN = GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding || GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding; @@ -796,7 +798,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base // TODO: Investigate why this path is not used in the original // gridwise_gemm_xdl_cshuffle_v3.hpp #if 0 - using GemmSpecialization = tensor_operation::device::GemmSpecialization; + // using GemmSpecialization = tensor_operation::device::GemmSpecialization; if constexpr(GemmSpec == GemmSpecialization::MNPadding || GemmSpec == GemmSpecialization::MNKPadding) @@ -1091,9 +1093,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_base { 
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) { - std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " - << karg.K << " " << __FILE__ << ":" << __LINE__ - << ", in function: " << __func__ << std::endl; + std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! " + "K_Batch:" + << karg.KBatch << " " << "K0PerBlock:" << KPerBlock << " " + << "K1:" << AK1Number << " " << "K:" << karg.K << " " << __FILE__ + << ":" << __LINE__ << ", in function: " << __func__ << std::endl; } return false; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp new file mode 100644 index 00000000000..4bd639cd19e --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp @@ -0,0 +1,275 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_gemm_wmma_fixed_nk.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp" +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/utility/loop_scheduler.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = ck::f8_t; +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F32 = float; +using I8 = int8_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +using S = ck::Sequence; + +using Empty_Tuple = ck::Tuple<>; +using PassThrough = ck::tensor_operation::element_wise::PassThrough; +using FastGelu = ck::tensor_operation::element_wise::FastGelu; + +using AccDataType = F32; +using DsDataType = Empty_Tuple; + +using DsLayout = Empty_Tuple; +using ELayout = Row; + +static constexpr auto PipelineV1 = BlockGemmPipelineVersion::v1; +static constexpr auto PipelineV3 = BlockGemmPipelineVersion::v3; +static constexpr auto IntrawaveScheduler = BlockGemmPipelineScheduler::Intrawave; +static constexpr auto InterwaveScheduler = BlockGemmPipelineScheduler::Interwave; +static constexpr auto GemmMNKPadding = device::GemmSpecialization::MNKPadding; +static constexpr auto GemmDefault = device::GemmSpecialization::Default; + +// Instances for 2 byte datatypes in RRR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_instances = std::tuple< + // clang-format off + //#############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| 
ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //#############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// Instances for 2 byte datatypes in RCR layout with ADataType = BDataType = EDataType +template = false> +using device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_instances = std::tuple< + // clang-format off + //#############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //#############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //#############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 
1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, T, T, AccDataType, T, DsDataType, T, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +template +using device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_irregular_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 2, 2, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Row, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<2, 0, 1>, S<2, 0, 1>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +template +using device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_irregular_instances = + std::tuple< + // clang-format off + //##############################| ALayout| BLayout| 
DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| GEMM| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| + //##############################| | | | | Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Spacialization| Prefetch| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MRepeat| ScalarPerVector| + //##############################| | | | | | | | | | | Operation| Operation| Operation| | Stage| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NRepeat| _NRepeat| + //##############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 64, 2, 2, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer>, + DeviceGroupedGemm_Wmma_Fixed_Nk< Row, Col, DsLayout, ELayout, TA, TB, AccDataType, AccDataType, DsDataType, TA, AElementOp, BElementOp, CDEElementOp, GemmSpec, 1, 256, 128, 128, 32, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 64, 1, 4>, 8, BlkGemmPipeSched, BlkGemmPipelineVer> + // clang-format on + >; + +// List of instance variants to add (pipeline/scheduler/padding combinations) +// Some are disabled now, can be re-enabled if needed +using InstanceVariant = + ck::Tuple; +static constexpr InstanceVariant InstanceVariants[] = { + + make_tuple(GemmDefault, IntrawaveScheduler, PipelineV1), + // make_tuple(GemmDefault, InterwaveScheduler, PipelineV1), + make_tuple(GemmDefault, IntrawaveScheduler, PipelineV3), + + make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV1), + // make_tuple(GemmMNKPadding, InterwaveScheduler, PipelineV1), + // make_tuple(GemmMNKPadding, IntrawaveScheduler, PipelineV3), +}; + +template + typename LayoutInstances, + typename ADataType, + typename BDataType, + typename EDataType, + typename AElementOp, + typename BElementOp, + typename CDEElementOp> +void add_device_grouped_gemm_wmma_fixed_nk_instances( + std::vector>>& instances) +{ + static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { + constexpr auto instance = InstanceVariants[i]; + add_device_operation_instances(instances, + LayoutInstances{}), + instance.At(Number<1>{}), + instance.At(Number<2>{}), + AElementOp, + BElementOp, + 
CDEElementOp>{}); + }); +} + +template + typename LayoutInstances, + typename AElementOp, + typename BElementOp, + typename CDEElementOp> +void add_device_grouped_gemm_wmma_fixed_nk_instances( + std::vector>>& instances) +{ + static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { + constexpr auto instance = InstanceVariants[i]; + add_device_operation_instances(instances, + LayoutInstances{}), + instance.At(Number<1>{}), + instance.At(Number<2>{}), + AElementOp, + BElementOp, + CDEElementOp>{}); + }); +} + +template + typename LayoutInstances, + typename AElementOp, + typename BElementOp, + typename CDEElementOp> +void add_device_grouped_gemm_wmma_fixed_nk_irregular_instances( + std::vector>>& instances) +{ + static_for<0, std::size(InstanceVariants), 1>{}([&](auto i) { + constexpr auto instance = InstanceVariants[i]; + add_device_operation_instances(instances, + LayoutInstances{}), + instance.At(Number<1>{}), + instance.At(Number<2>{}), + AElementOp, + BElementOp, + CDEElementOp>{}); + }); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp index 3f17a930a0b..6e1668ffa5e 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_fixed_nk.hpp @@ -17,6 +17,7 @@ namespace device { namespace instance { // fp16_output +#if defined(CK_USE_XDL) void add_device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instances( std::vector>>& instances); +#endif + +#if defined(CK_USE_WMMA) +void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instances( + std::vector>>& instances); +#endif // bf16_inputA i8_inputB #if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) +#if defined(CK_USE_XDL) void add_device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( std::vector>>& instances); #endif +#if defined(CK_USE_WMMA) +void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( + std::vector>>& instances); +#endif +#endif // bf16_inputA bf16_inputB #if defined(CK_ENABLE_BF16) +#if defined(CK_USE_XDL) void add_device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( std::vector>>& instances); +#endif +#if defined(CK_USE_WMMA) +void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& instances); + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& instances); +#endif #endif // CK_ENABLE_BF16 template > op_ptrs; +#if defined(CK_USE_XDL) // fp16_output if constexpr(is_same_v && is_same_v && is_same_v) @@ -273,6 +414,93 @@ struct DeviceOperationInstanceFactory< } } #endif // CK_ENABLE_BF16 +#endif // CK_USE_XDL 
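// ----------------------------------------------------------------------------
// Illustration only (not part of the patch): the add_device_grouped_gemm_wmma_
// fixed_nk_*_instances helpers above iterate a constexpr InstanceVariants list
// with ck::static_for, so every enabled (GemmSpec, scheduler, pipeline-version)
// combination produces one add_device_operation_instances call at compile time.
// Below is a minimal, self-contained C++17 sketch of that expansion pattern; it
// uses std::index_sequence in place of ck::static_for and strings in place of
// device-op instance types, and every name in it is hypothetical.

#include <array>
#include <string>
#include <utility>
#include <vector>

enum class Spec { Default, MNKPadding };
enum class Sched { Intrawave, Interwave };
enum class Pipe { V1, V3 };

struct Variant
{
    Spec spec;
    Sched sched;
    Pipe pipe;
};

// One entry per (specialization, scheduler, pipeline) combination to register,
// mirroring the InstanceVariants array; disabled combinations are simply omitted.
constexpr std::array<Variant, 3> kVariants{{
    {Spec::Default, Sched::Intrawave, Pipe::V1},
    {Spec::Default, Sched::Intrawave, Pipe::V3},
    {Spec::MNKPadding, Sched::Intrawave, Pipe::V1},
}};

// Stand-in for add_device_operation_instances(): each "instance" is just a string here.
template <Spec S, Sched Sc, Pipe P>
std::string make_instance_name()
{
    return "instance<spec=" + std::to_string(static_cast<int>(S)) +
           ",sched=" + std::to_string(static_cast<int>(Sc)) +
           ",pipe=" + std::to_string(static_cast<int>(P)) + ">";
}

// Compile-time loop over kVariants: one registration per variant, like static_for.
template <std::size_t... Is>
void add_all_variants(std::vector<std::string>& out, std::index_sequence<Is...>)
{
    (out.push_back(
         make_instance_name<kVariants[Is].spec, kVariants[Is].sched, kVariants[Is].pipe>()),
     ...);
}

void add_all_variants(std::vector<std::string>& out)
{
    add_all_variants(out, std::make_index_sequence<kVariants.size()>{});
}
// --------------------------- end of illustration -----------------------------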
+ +#if defined(CK_USE_WMMA) + // fp16_output + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances(op_ptrs); + } + } + + // fp8_input + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instances(op_ptrs); + } + } + + // i8_input + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instances(op_ptrs); + } + } + +// bf16_i8_input +#if defined(CK_ENABLE_BF16) && defined(CK_ENABLE_INT8) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances(op_ptrs); + } + } +#endif + +// bf16_inputA bf16_inputB +#if defined(CK_ENABLE_BF16) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances(op_ptrs); + } + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances(op_ptrs); + } + } +#endif // CK_ENABLE_BF16 +#endif // CK_USE_WMMA return op_ptrs; } diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt index e56df524cfa..f7ce0a70730 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/CMakeLists.txt @@ -1,18 +1,31 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GROUPED_GEMM_FIXED_NK_INSTANCES) -list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp - device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp) +list(APPEND GROUPED_GEMM_FIXED_NK_INSTANCES + device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_xdl_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp + + device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp + device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp + ) add_instance_library(device_grouped_gemm_fixed_nk_instance ${GROUPED_GEMM_FIXED_NK_INSTANCES}) \ No newline at end of file diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..8082fd3136e --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_instances< + BF16, + Row, + Row, + device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..d49141ed681 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_bf16_bf16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_instances< + BF16, + Row, + Col, + device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..9356dac4fc6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< + BF16, + I8, + Row, + Row, + device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_irregular_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..8ce1b7ac1e6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_fixed_nk_bf16_i8_bf16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< + BF16, + I8, + Row, + Col, + device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_irregular_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..d59a908c7c4 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_instances< + F16, + Row, + Row, + device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..eeacb3b1f9d --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_fixed_nk_f16_f16_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_instances< + F16, + Row, + Col, + device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..a128afe8932 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< + F16, + F8, + Row, + Row, + device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_irregular_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..5d943a0f340 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_fixed_nk_f16_fp8_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< + F16, + F8, + Row, + Col, + device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_irregular_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp new file mode 100644 index 00000000000..f118fd75b8c --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_kn_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< + F16, + I8, + Row, + Row, + device_grouped_gemm_wmma_fixed_nk_mk_kn_mn_irregular_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp new file mode 100644 index 00000000000..86a27be9866 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_gemm_fixed_nk/device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_gemm/device_grouped_gemm_wmma_fixed_nk_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_gemm_wmma_fixed_nk_f16_i8_f16_mk_nk_mn_instances( + std::vector>>& instances) +{ + add_device_grouped_gemm_wmma_fixed_nk_irregular_instances< + F16, + I8, + Row, + Col, + device_grouped_gemm_wmma_fixed_nk_mk_nk_mn_irregular_instances>(instances); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp index f551a16a1b9..988b7d002d7 100644 --- a/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp +++ b/profiler/include/profiler/profile_grouped_gemm_fixed_nk_impl.hpp @@ -46,7 +46,8 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, int n_warmup = 1, int n_iter = 10) { - bool pass = true; + bool pass = true; + using ComputeDataType = ADataType; auto f_host_tensor_descriptor = [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { @@ -54,11 +55,11 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, if(is_same::value) { - return HostTensorDescriptor({row, col}, {stride, 1_uz}); + return HostTensorDescriptor({row, col}, {stride, 1_uz}, layout); } else { - return HostTensorDescriptor({row, col}, {1_uz, stride}); + return HostTensorDescriptor({row, col}, {1_uz, stride}, layout); } }; @@ -74,8 +75,8 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, std::vector> b_k_n; std::vector> c_m_n_host_results; std::vector> c_m_n_device_results; - int sum_of_m = 0; - + double max_abs_in_val = 0.f; + int sum_of_m = 0; for(std::size_t i = 0; i < group_count; i++) { sum_of_m += Ms[i]; @@ -95,17 +96,18 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, << i << "]:" << b_k_n[i].mDesc << ", c_m_n_device_results[" << i << "]:" << c_m_n_device_results[i].mDesc << std::endl; } - std::size_t num_thread = 1; switch(init_method) { case 0: break; 
case 1: - a_m_k[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); - b_k_n[i].GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_m_k[i]); + ck::utils::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_k_n[i]); + max_abs_in_val = 10.f; break; default: - a_m_k[i].GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); - b_k_n[i].GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + ck::utils::FillUniformDistribution{0.0f, 1.0f}(a_m_k[i]); + ck::utils::FillUniformDistribution{-0.5f, 0.5f}(b_k_n[i]); + max_abs_in_val = 1.0f; } } @@ -282,23 +284,18 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, bool instance_pass = true; for(std::size_t i = 0; i < gemm_descs.size(); i++) { - c_device_buf[i]->FromDevice(c_m_n_device_results[i].mData.data()); - - if(std::is_same_v && kbatch_curr > 1) - { - instance_pass = - instance_pass && ck::utils::check_err(c_m_n_device_results[i], - c_m_n_host_results[i], - "Error: Incorrect results!", - 0.06); - } - else - { - instance_pass = - instance_pass && ck::utils::check_err(c_m_n_device_results[i], - c_m_n_host_results[i]); - } + auto atol = ck::utils::get_absolute_threshold( + max_abs_in_val, gemm_descs[i].K_); + auto rtol = ck::utils::get_relative_threshold( + gemm_descs[i].K_); + + instance_pass = + instance_pass && ck::utils::check_err(c_m_n_device_results[i], + c_m_n_host_results[i], + "Error: Incorrect results!", + rtol, + atol); if(do_log) { @@ -315,7 +312,7 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, } } - std::cout << "Instance: " << gemm_name << " verification " + std::cout << "Instance: " << gemm_name << "; KBatch: " << kbatch_curr << " " << (instance_pass ? "SUCCEED" : "FAILED") << std::endl; pass = pass && instance_pass; @@ -355,7 +352,8 @@ bool profile_grouped_gemm_fixed_nk_impl(int do_verification, } else { - std::cout << "Instance: " << gemm_name << ", does not support this GEMM problem" + std::cout << "Instance: " << gemm_name + << ", does not support this GEMM problem (KBatch: " << kbatch_curr << ")" << std::endl; } } diff --git a/profiler/src/CMakeLists.txt b/profiler/src/CMakeLists.txt index 012d6e15027..f507d418661 100644 --- a/profiler/src/CMakeLists.txt +++ b/profiler/src/CMakeLists.txt @@ -298,3 +298,32 @@ message(VERBOSE "ckProfiler libs: ${PROFILER_LIBS}") target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE ${PROFILER_LIBS}) rocm_install(TARGETS ${PROFILER_EXECUTABLE} COMPONENT profiler) + + +## Defining specific operation targets + +macro(define_profiler_target NAME SOURCES LIBS) + add_executable(${NAME} profiler.cpp ${SOURCES}) + target_compile_options(${NAME} PRIVATE -Wno-global-constructors) + + if(NOT WIN32 AND ${hip_VERSION_FLAT} GREATER 600241132) + target_compile_options(${NAME} PRIVATE --offload-compress) + endif() + + target_link_libraries(${NAME} PRIVATE utility getopt::getopt ${LIBS}) + + rocm_install(TARGETS ${NAME} COMPONENT profiler) +endmacro() + + +define_profiler_target(ckProfiler_gemm + "profile_gemm.cpp" + "device_gemm_instance") + +define_profiler_target(ckProfiler_gemm_universal + "profile_gemm_universal.cpp" + "device_gemm_universal_instance") + +define_profiler_target(ckProfiler_gemm_fixed_nk + "profile_grouped_gemm_fixed_nk.cpp" + "device_grouped_gemm_fixed_nk_instance") diff --git a/test/grouped_gemm/CMakeLists.txt b/test/grouped_gemm/CMakeLists.txt index 450950cbd66..3091667da3f 100644 --- a/test/grouped_gemm/CMakeLists.txt +++ b/test/grouped_gemm/CMakeLists.txt @@ -18,6 +18,12 @@ 
if (CK_USE_XDL OR CK_USE_WMMA) target_link_libraries(test_grouped_gemm_fastgelu PRIVATE utility device_grouped_gemm_fastgelu_instance) add_dependencies(test_grouped_gemm test_grouped_gemm_fastgelu) endif() + + add_gtest_executable(test_grouped_gemm_fixed_nk test_grouped_gemm_fixed_nk.cpp) + if(result EQUAL 0) + target_link_libraries(test_grouped_gemm_fixed_nk PRIVATE utility device_grouped_gemm_fixed_nk_instance) + add_dependencies(test_grouped_gemm test_grouped_gemm_fixed_nk) + endif() endif() add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp) diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp new file mode 100644 index 00000000000..f855ef1ff5d --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk.cpp @@ -0,0 +1,82 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/utility/data_type.hpp" +#include "ck/utility/tuple.hpp" + +#include "gtest/gtest.h" +#include "test_grouped_gemm_util.hpp" + +ck::index_t param_mask = 0xffffff; +ck::index_t instance_index = -1; + +using F16 = ck::half_t; +using BF16 = ck::bhalf_t; +using F8 = ck::f8_t; +using I8 = int8_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +template +class TestGroupedGemm + : public ck::test::TestGroupedGemm +{ + public: + void SetUp() override + { + ck::test::TestGroupedGemm::SetUp(); + +#if defined(CK_USE_WMMA) + // The old XDL tests didn't fail if instances were not supported, so we want to keep that + // behaviour When compiling WMMA instances and WMMA is supported, then we'll fail if a + // specific case is not supported + this->fail_if_no_supported_instances_ = + ck::is_gfx11_supported() || ck::is_gfx12_supported(); +#endif + } +}; + +using KernelTypes = ::testing::Types< + +#if(defined(CK_USE_XDL) && (defined(__gfx9__) || defined(__gfx12__))) || \ + (defined(CK_USE_WMMA) && defined(__gfx12__)) + ck::Tuple, + ck::Tuple, +#endif + + ck::Tuple, + ck::Tuple, + + ck::Tuple, + ck::Tuple, + ck::Tuple, + ck::Tuple, + + ck::Tuple, + ck::Tuple>; +// clang-format on + +TYPED_TEST_SUITE(TestGroupedGemm, KernelTypes); + +#include "test_grouped_gemm_fixed_nk_cases.inc" +int main(int argc, char** argv) +{ + testing::InitGoogleTest(&argc, argv); + if(argc == 1) {} + else if(argc == 3) + { + param_mask = strtol(argv[1], nullptr, 0); + instance_index = atoi(argv[2]); + } + else + { + std::cout << "Usage of " << argv[0] << std::endl; + std::cout << "Arg1,2: param_mask instance_index(-1 means all)" << std::endl; + } + return RUN_ALL_TESTS(); +} diff --git a/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc new file mode 100644 index 00000000000..f0b4ee61088 --- /dev/null +++ b/test/grouped_gemm/test_grouped_gemm_fixed_nk_cases.inc @@ -0,0 +1,84 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +TYPED_TEST(TestGroupedGemm, TinyCases) +{ + const std::vector Ms{2, 1}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemm, SmallCases) +{ + const std::vector Ms{2, 1, 3, 4, 5}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemm, MidCases) +{ + const std::vector Ms{167, 183, 177, 153, 139, 204}; + constexpr int N = 768; + constexpr int K = 544; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemm, Regular) +{ + const std::vector Ms{64, 128, 256}; + constexpr int N = 768; + constexpr int K = 320; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemm, MNKPadded) +{ + const std::vector Ms{127, 150, 188, 210}; + constexpr int N = 136; + constexpr int K = 280; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->Run(Ms, Ns, Ks); +} + +TYPED_TEST(TestGroupedGemm, TestLargeKBatch) +{ + // In some cases Split K is not supported. Running this test would fail since no instance will + // be supported, so we skip the test + if(!this->IsSplitKSupported()) + GTEST_SKIP() << "Split-K not supported for for the current configuration (FP16/BF16 on " + "GFX11, or using CDE element-wise operation)"; + + const std::vector Ms{188, 210}; + constexpr int N = 768; + constexpr int K = 4096; + + const std::vector Ns(Ms.size(), N); + const std::vector Ks(Ms.size(), K); + + this->k_batches_ = {32, 64}; + + this->Run(Ms, Ns, Ks); +} diff --git a/test/grouped_gemm/test_grouped_gemm_util.hpp b/test/grouped_gemm/test_grouped_gemm_util.hpp index ee95fe03c66..e372af92f3a 100644 --- a/test/grouped_gemm/test_grouped_gemm_util.hpp +++ b/test/grouped_gemm/test_grouped_gemm_util.hpp @@ -16,6 +16,7 @@ #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "profiler/profile_grouped_gemm_impl.hpp" +#include "profiler/profile_grouped_gemm_fixed_nk_impl.hpp" extern ck::index_t param_mask; extern ck::index_t instance_index; @@ -23,7 +24,117 @@ extern ck::index_t instance_index; namespace ck { namespace test { -template +struct DefaultGroupedGemmProfiler +{ + template + static bool Run(bool verify, + int init_method, + bool log, + bool bench, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + const std::vector& kbatches, + int n_warmup, + int n_iter, + int instance_index, + bool fail_if_no_supported_instances) + { + return ck::profiler::profile_grouped_gemm_impl( + verify, + init_method, + log, + bench, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup, + n_iter, + instance_index, + fail_if_no_supported_instances); + } +}; + +struct FixedNKGroupedGemmProfiler +{ + template + static bool Run(bool verify, + int init_method, + bool log, + bool bench, + const std::vector& Ms, + const std::vector& Ns, + const std::vector& Ks, + const std::vector& StrideAs, + const std::vector& StrideBs, + const std::vector& StrideCs, + const std::vector& kbatches, + int n_warmup, + int n_iter, + 
int /*instance_index*/, + bool /*fail_if_no_supported_instances*/) + { + bool pass = true; + for(int kbatch : kbatches) + { + pass &= ck::profiler::profile_grouped_gemm_fixed_nk_impl(verify, + init_method, + log, + bench, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatch, + n_warmup, + n_iter); + } + return pass; + } +}; + +template class TestGroupedGemm : public testing::Test { protected: @@ -76,7 +187,7 @@ class TestGroupedGemm : public testing::Test } else { - k_batches_ = {1, 2, 3, 5, 8}; + k_batches_ = {1, 2, 3, 4, 8}; } } @@ -146,31 +257,61 @@ class TestGroupedGemm : public testing::Test const std::vector& StrideCs, const std::vector& kbatches) { - bool pass = - ck::profiler::profile_grouped_gemm_impl(verify_, - init_method_, - log_, - bench_, - Ms, - Ns, - Ks, - StrideAs, - StrideBs, - StrideCs, - kbatches, - n_warmup_, - n_iter_, - instance_index, - fail_if_no_supported_instances_); + bool pass = false; + using AccDataType = float; + + if constexpr(std::is_same_v) + { + pass = Profiler::template Run(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup_, + n_iter_, + instance_index, + fail_if_no_supported_instances_); + } + else + { + pass = Profiler::template Run(verify_, + init_method_, + log_, + bench_, + Ms, + Ns, + Ks, + StrideAs, + StrideBs, + StrideCs, + kbatches, + n_warmup_, + n_iter_, + instance_index, + fail_if_no_supported_instances_); + } + EXPECT_TRUE(pass); } };
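// ----------------------------------------------------------------------------
// Illustration only (not part of the patch): the test fixture above selects its
// profiling backend through the Profiler policy parameter (DefaultGroupedGemmProfiler
// vs. FixedNKGroupedGemmProfiler) and branches at compile time on the element
// type to decide which template arguments the policy's static Run is invoked
// with. A minimal, self-contained sketch of that policy-plus-trait pattern
// follows; every name in it is hypothetical, and the int8-to-int32 accumulator
// choice is shown only as a common example, not taken from the patch.

#include <cstdint>
#include <iostream>
#include <type_traits>

struct DefaultBackend
{
    template <typename ADataType, typename AccDataType>
    static bool Run()
    {
        std::cout << "default grouped-GEMM profiling path\n";
        return true;
    }
};

struct FixedNKBackend
{
    template <typename ADataType, typename AccDataType>
    static bool Run()
    {
        std::cout << "fixed-NK grouped-GEMM profiling path\n";
        return true;
    }
};

template <typename ADataType, typename Backend = DefaultBackend>
struct GroupedGemmTestSketch
{
    // Example accumulator selection: int8 inputs accumulate in int32,
    // everything else in float.
    using AccDataType =
        std::conditional_t<std::is_same_v<ADataType, int8_t>, int32_t, float>;

    bool Run() const { return Backend::template Run<ADataType, AccDataType>(); }
};

int main()
{
    GroupedGemmTestSketch<int8_t, FixedNKBackend> fixed_nk_int8_test;
    GroupedGemmTestSketch<float> default_fp32_test;
    return (fixed_nk_int8_test.Run() && default_fp32_test.Run()) ? 0 : 1;
}
// --------------------------- end of illustration -----------------------------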