From adb8f67b4f4b1598afcae25a1e6c43136ce31203 Mon Sep 17 00:00:00 2001 From: AviralGoelAMD Date: Thu, 29 Jan 2026 12:45:18 +0000 Subject: [PATCH] feat: add new optimized tutorial kernels - Add 01_naive_gemm baseline implementation - Add 02_padding_k_first with PADDING_K_FIRST + MFMA_32x32x16 - Add 03_mfma_16x16x16 with PADDING_K_FIRST + MFMA_16x16x16 - Share common reference_gemm.hpp in parent gemm/ directory --- tutorial/ck_tile/01_naive_gemm/CMakeLists.txt | 10 - ...ce_gemm_host_pipeline_agmem_bgmem_creg.hpp | 92 ------ ...tice_gemm_host_policy_agmem_bgmem_creg.hpp | 54 ---- .../ck_tile/01_naive_gemm/practice_gemm.cpp | 141 --------- .../ck_tile/01_naive_gemm/practice_gemm.hpp | 74 ----- tutorial/ck_tile/CMakeLists.txt | 3 +- .../01_naive_gemm/BLOCK_LEVEL_PIPELINE.md | 0 .../ck_tile/gemm/01_naive_gemm/CMakeLists.txt | 17 ++ .../01_naive_gemm/HOST_LEVEL_PIPELINE.md | 0 .../01_naive_gemm/KERNEL_ENTRY_POINT.md | 0 .../{ => gemm}/01_naive_gemm/README.md | 0 .../01_naive_gemm/TILE_DISTRIBUTION.md | 0 .../{ => gemm}/01_naive_gemm/WALKTHROUGH.md | 0 .../block_gemm_pipeline_agmem_bgmem_creg.hpp} | 71 ++--- ...gemm_pipeline_agmem_bgmem_creg_policy.hpp} | 88 +++--- .../01_naive_gemm/host_level/grid_gemm.hpp | 72 +++++ .../gemm/01_naive_gemm/practice_gemm.cpp | 155 ++++++++++ .../gemm/01_naive_gemm/practice_gemm.hpp | 139 +++++++++ .../block_gemm_asmem_bsmem_creg.hpp} | 123 ++++++-- .../block_gemm_asmem_bsmem_creg_policy.hpp} | 11 +- .../gemm/02_padding_k_first/CMakeLists.txt | 17 ++ .../block_gemm_asmem_bsmem_creg.hpp | 285 ++++++++++++++++++ .../block_gemm_asmem_bsmem_creg_policy.hpp | 43 +++ .../block_gemm_pipeline_agmem_bgmem_creg.hpp | 166 ++++++++++ ..._gemm_pipeline_agmem_bgmem_creg_policy.hpp | 129 ++++++++ .../ck_tile/gemm/02_padding_k_first/gemm.cpp | 158 ++++++++++ .../ck_tile/gemm/02_padding_k_first/gemm.hpp | 139 +++++++++ .../gemm/02_padding_k_first/grid_gemm.hpp | 72 +++++ .../gemm/03_mfma_16x16x16/CMakeLists.txt | 17 ++ .../block_gemm_asmem_bsmem_creg.hpp | 285 ++++++++++++++++++ .../block_gemm_asmem_bsmem_creg_policy.hpp | 43 +++ .../block_gemm_pipeline_agmem_bgmem_creg.hpp | 166 ++++++++++ ..._gemm_pipeline_agmem_bgmem_creg_policy.hpp | 129 ++++++++ .../ck_tile/gemm/03_mfma_16x16x16/gemm.cpp | 158 ++++++++++ .../ck_tile/gemm/03_mfma_16x16x16/gemm.hpp | 139 +++++++++ .../gemm/03_mfma_16x16x16/grid_gemm.hpp | 72 +++++ tutorial/ck_tile/gemm/CMakeLists.txt | 10 + .../reference_gemm.hpp | 3 +- 38 files changed, 2604 insertions(+), 477 deletions(-) delete mode 100644 tutorial/ck_tile/01_naive_gemm/CMakeLists.txt delete mode 100644 tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp delete mode 100644 tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp delete mode 100644 tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp delete mode 100644 tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp rename tutorial/ck_tile/{ => gemm}/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md (100%) create mode 100644 tutorial/ck_tile/gemm/01_naive_gemm/CMakeLists.txt rename tutorial/ck_tile/{ => gemm}/01_naive_gemm/HOST_LEVEL_PIPELINE.md (100%) rename tutorial/ck_tile/{ => gemm}/01_naive_gemm/KERNEL_ENTRY_POINT.md (100%) rename tutorial/ck_tile/{ => gemm}/01_naive_gemm/README.md (100%) rename tutorial/ck_tile/{ => gemm}/01_naive_gemm/TILE_DISTRIBUTION.md (100%) rename tutorial/ck_tile/{ => gemm}/01_naive_gemm/WALKTHROUGH.md (100%) rename tutorial/ck_tile/{01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp => gemm/01_naive_gemm/block_level/block_gemm_pipeline_agmem_bgmem_creg.hpp} (69%) rename tutorial/ck_tile/{01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp => gemm/01_naive_gemm/block_level/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp} (58%) create mode 100644 tutorial/ck_tile/gemm/01_naive_gemm/host_level/grid_gemm.hpp create mode 100644 tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.cpp create mode 100644 tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.hpp rename tutorial/ck_tile/{01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp => gemm/01_naive_gemm/warp_level/block_gemm_asmem_bsmem_creg.hpp} (61%) rename tutorial/ck_tile/{01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp => gemm/01_naive_gemm/warp_level/block_gemm_asmem_bsmem_creg_policy.hpp} (66%) create mode 100644 tutorial/ck_tile/gemm/02_padding_k_first/CMakeLists.txt create mode 100644 tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_asmem_bsmem_creg.hpp create mode 100644 tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_asmem_bsmem_creg_policy.hpp create mode 100644 tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_pipeline_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp create mode 100644 tutorial/ck_tile/gemm/02_padding_k_first/gemm.cpp create mode 100644 tutorial/ck_tile/gemm/02_padding_k_first/gemm.hpp create mode 100644 tutorial/ck_tile/gemm/02_padding_k_first/grid_gemm.hpp create mode 100644 tutorial/ck_tile/gemm/03_mfma_16x16x16/CMakeLists.txt create mode 100644 tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_asmem_bsmem_creg.hpp create mode 100644 tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_asmem_bsmem_creg_policy.hpp create mode 100644 tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_pipeline_agmem_bgmem_creg.hpp create mode 100644 tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp create mode 100644 tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.cpp create mode 100644 tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.hpp create mode 100644 tutorial/ck_tile/gemm/03_mfma_16x16x16/grid_gemm.hpp create mode 100644 tutorial/ck_tile/gemm/CMakeLists.txt rename tutorial/ck_tile/{01_naive_gemm => gemm}/reference_gemm.hpp (95%) diff --git a/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt b/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt deleted file mode 100644 index 19549805320..00000000000 --- a/tutorial/ck_tile/01_naive_gemm/CMakeLists.txt +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -# SPDX-License-Identifier: MIT - -add_executable(tile_tutorial_naive_gemm practice_gemm.cpp) - -target_compile_options(tile_tutorial_naive_gemm PRIVATE - -mllvm -enable-noalias-to-md-conversion=0 -) - -add_dependencies(tutorials tile_tutorial_naive_gemm) \ No newline at end of file diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp deleted file mode 100644 index 45f439e8fa2..00000000000 --- a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp +++ /dev/null @@ -1,92 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" - -namespace ck_tile { -template -struct PracticeGemmHostPipeline -{ - using ADataType = typename Problem_::ADataType; - using BDataType = typename Problem_::BDataType; - using CDataType = typename Problem_::CDataType; - using AccDataType = typename Problem_::AccDataType; - - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - - using BlockTile = typename Problem::Shape::BlockTile; - using WaveTile = typename Problem::Shape::WaveTile; - - template - CK_TILE_DEVICE void operator()(const ADRAMTensorView& a_dram, - const BDRAMTensorView& b_dram, - CDRAMTensorView& c_dram) const - { - - // Size of the entire problem - const auto M = a_dram.get_tensor_descriptor().get_length(number<0>{}); // M x K - const auto N = c_dram.get_tensor_descriptor().get_length(number<1>{}); // M x N - const auto K = a_dram.get_tensor_descriptor().get_length(number<1>{}); // M x K - - // Size of the block tile - const auto MPerBlock = BlockTile::at(number<0>{}); - const auto NPerBlock = BlockTile::at(number<1>{}); - const auto KPerBlock = BlockTile::at(number<2>{}); - - // Number of block tile in the N direction to cover C (resultant) matrix - const auto num_tile_n = integer_divide_ceil(N, NPerBlock); - // Number of block tile in the M direction to cover C (resultant) matrix - const auto num_tile_m = integer_divide_ceil(M, MPerBlock); - - // if(get_thread_id() == 0 && get_block_id() == 0) - // { - // printf("num_tile_m: %d, num_tile_n: %d\n", num_tile_m, num_tile_n); - // printf("total number of tiles: %d\n", num_tile_m * num_tile_n); - // } - - // Get block id - const auto id_block = - get_block_id(); // 0 to (M_block/BlockTile_M) * (N_block/BlockTile_N) - 1 - - // Map block id to tile id - const auto block2tile = Policy::MakeBlock2TileMap(num_tile_m, num_tile_n); - - const auto tile_id = block2tile(id_block); - - const auto tile_id_m = tile_id.at(number<0>{}); - const auto tile_id_n = tile_id.at(number<1>{}); - - // if(get_thread_id() == 0 && get_block_id() == 15) - // { - // printf("tile_id_m: %d, tile_id_n: %d\n", tile_id_m, tile_id_n); - // } - - const auto tile_origin_m = tile_id_m * MPerBlock; - const auto tile_origin_n = tile_id_n * NPerBlock; - - // create a tile window over dram for A and B - const auto a_block_window = make_tile_window( - a_dram, make_tuple(number{}, number{}), {tile_origin_m, 0}); - - const auto b_block_window = make_tile_window( - b_dram, make_tuple(number{}, number{}), {tile_origin_n, 0}); - - constexpr auto block_gemm_pipeline = - Policy::template GetPracticeGemmBlockPipeline(); - - int num_loops_k = integer_divide_ceil(K, KPerBlock); - - __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLDSSize()]; - const auto c_block_tile = - block_gemm_pipeline(a_block_window, b_block_window, num_loops_k, p_smem_char); - auto c_window = make_tile_window(c_dram, - make_tuple(number{}, number{}), - {tile_origin_m, tile_origin_n}); - store_tile(c_window, c_block_tile); - } -}; -} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp b/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp deleted file mode 100644 index 1c100796cbd..00000000000 --- a/tutorial/ck_tile/01_naive_gemm/host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/host.hpp" -#include "ck_tile/core.hpp" - -#include "../block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp" -#include "../block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp" - -namespace ck_tile { - -template -struct PracticeGemmHostProblem -{ - using ADataType = ADataType_; - using BDataType = BDataType_; - using CDataType = CDataType_; - using AccDataType = AccDataType_; - using Shape = remove_cvref_t; -}; - -struct PracticeGemmHostPolicy -{ - CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) - { - const auto unmerge = make_merge_transform(make_tuple(N0, M0)); - - return [unmerge](index_t block_id) { - multi_index<2> unmerged; - unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); - - return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); - }; - } - - template - CK_TILE_HOST_DEVICE static constexpr auto GetPracticeGemmBlockPipeline() - { - using PracticeGemmBlockPipelineProblem_ = - PracticeGemmBlockPipelineProblem; - return PracticeGemmBlockPipelineAGmemBGmemCreg{}; - } -}; -} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp deleted file mode 100644 index 7635c9376b3..00000000000 --- a/tutorial/ck_tile/01_naive_gemm/practice_gemm.cpp +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#include -#include "ck_tile/host.hpp" -#include "practice_gemm.hpp" -#include "reference_gemm.hpp" - -int main(int argc, char* argv[]) -{ - // TODO: GemmTypeConfig - using ADataType = ck_tile::half_t; - using BDataType = ck_tile::half_t; - using CDataType = float; - using AccDataType = float; - - // Setup simple argument parser for M, N, K - ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "512", "m dimension") - .insert("n", "256", "n dimension") - .insert("k", "64", "k dimension") - .insert("v", "1", "verification: 0=off, 1=on"); - - auto result = arg_parser.parse(argc, argv); - if(!result) - return -1; - - // Get problem dimensions from command line - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - ck_tile::index_t K = arg_parser.get_int("k"); - ck_tile::index_t verification = arg_parser.get_int("v"); - - ck_tile::index_t stride_a = K; - ck_tile::index_t stride_b = K; - ck_tile::index_t stride_c = N; - - auto a_lengths = std::array{M, K}; - auto b_lengths = std::array{N, K}; - auto c_lengths = std::array{M, N}; - - auto a_strides = std::array{stride_a, 1}; - auto b_strides = std::array{stride_b, 1}; - auto c_strides = std::array{stride_c, 1}; - - // tensors on host (cpu) - ck_tile::HostTensor a_host(a_lengths, a_strides); - ck_tile::HostTensor b_host(b_lengths, b_strides); - ck_tile::HostTensor c_host(c_lengths, c_strides); - - // initialize tensors - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_host); - ck_tile::FillUniformDistribution{-5.f, 5.f}(b_host); - c_host.SetZero(); - - // Print the tensors using the new print_first_n member function - // std::cout << "Tensor A (first 10 elements): "; - // a_host.print_first_n(10); - // std::cout << std::endl; - - // std::cout << "Tensor B (first 10 elements): "; - // b_host.print_first_n(10); - // std::cout << std::endl; - - // std::cout << "Tensor C (first 10 elements): "; - // c_host.print_first_n(10); - // std::cout << std::endl; - - // Create device tensors of same size as host tensors and copy data - ck_tile::DeviceMem a_device(a_host); - ck_tile::DeviceMem b_device(b_host); - ck_tile::DeviceMem c_device(c_host); - - // TODO: BlockTileConfig - using BlockTile = ck_tile::sequence<256, 128, 32>; - using WaveTile = ck_tile::sequence<16, 16, 16>; - - std::cout << "Creating PracticeGemmShape, PracticeGemmProblem, PracticeGemmPolicy" << std::endl; - using PracticeGemmShape = ck_tile::PracticeGemmShape; - std::cout << "PracticeGemmShape: " << PracticeGemmShape::GetName() << std::endl; - using PracticeGemmHostProblem = ck_tile:: - PracticeGemmHostProblem; - using PracticeGemmHostPolicy = ck_tile::PracticeGemmHostPolicy; - - ck_tile::index_t kGridSize = ck_tile::integer_divide_ceil(M, PracticeGemmShape::BlockTile_M) * - ck_tile::integer_divide_ceil(N, PracticeGemmShape::BlockTile_N); - - std::cout << "Total number of thread blocks: " << kGridSize << std::endl; - constexpr ck_tile::index_t kBlockPerCU = 1; // 1 block per CU - - // Block size is now derived from the shape configuration - constexpr ck_tile::index_t kBlockSize = PracticeGemmShape::kBlockSize; - std::cout << "Number of threads per block: " << kBlockSize << std::endl; - std::cout << "Number of blocks per compute unit: " << kBlockPerCU << std::endl; - - using gemm_kernel = - ck_tile::PracticeGemmKernel; - - float ave_time = ck_tile::launch_kernel( - ck_tile::stream_config{nullptr, true, 0, 0, 1}, - ck_tile::make_kernel(gemm_kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(a_device.GetDeviceBuffer()), - static_cast(b_device.GetDeviceBuffer()), - static_cast(c_device.GetDeviceBuffer()), - M, - N, - K, - stride_a, - stride_b, - stride_c)); - - auto pass = true; - - if(verification) - { - // reference gemm - ck_tile::HostTensor c_host_ref(c_lengths, c_strides); - reference_basic_gemm( - a_host, b_host, c_host_ref); - ck_tile::HostTensor c_host_dev(c_lengths, c_strides); - c_device.FromDevice(c_host_dev.mData.data()); - pass &= ck_tile::check_err(c_host_dev, c_host_ref, "Error: Incorrect results!", 1e-3, 1e-3); - std::cout << "valid:" << (pass ? "y" : "n") << std::endl; - } - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_btype = - sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" - << std::endl; - - return !pass; -} diff --git a/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp b/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp deleted file mode 100644 index 91d7fae90c0..00000000000 --- a/tutorial/ck_tile/01_naive_gemm/practice_gemm.hpp +++ /dev/null @@ -1,74 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. -// SPDX-License-Identifier: MIT - -#pragma once - -#include -#include "ck_tile/core.hpp" -#include "host_level/practice_gemm_host_policy_agmem_bgmem_creg.hpp" -#include "host_level/practice_gemm_host_pipeline_agmem_bgmem_creg.hpp" - -namespace ck_tile { - -template -struct PracticeGemmShape -{ - using BlockTile = remove_cvref_t; - using WaveTile = remove_cvref_t; - - static constexpr index_t BlockTile_M = BlockTile::at(number<0>{}); - static constexpr index_t BlockTile_N = BlockTile::at(number<1>{}); - static constexpr index_t BlockTile_K = BlockTile::at(number<2>{}); - - static constexpr index_t WaveTile_M = WaveTile::at(number<0>{}); - static constexpr index_t WaveTile_N = WaveTile::at(number<1>{}); - static constexpr index_t WaveTile_K = WaveTile::at(number<2>{}); - - // Thread block configuration - static constexpr index_t kWarpSize = 64; // AMD GPU warp size (also called wavefront) - static constexpr index_t kBlockSize = 256; // Total threads per block (4 warps × 64 threads) - - CK_TILE_HOST static std::string GetName() - { - // clang-format off - return concat('_', "practice_gemm_shape", - concat('x', BlockTile_M, BlockTile_N, BlockTile_K), - concat('x', WaveTile_M, WaveTile_N, WaveTile_K)); - // clang-format on - } -}; - -template -struct PracticeGemmKernel -{ - using Problem = remove_cvref_t; - using Policy = remove_cvref_t; - - // Derive block size from the shape configuration - static constexpr index_t kBlockSize = Problem::Shape::kBlockSize; - - CK_TILE_DEVICE void operator()(const typename Problem::ADataType* p_a, - const typename Problem::BDataType* p_b, - typename Problem::CDataType* p_c, - const index_t M, - const index_t N, - const index_t K, - const index_t stride_a, - const index_t stride_b, - const index_t stride_c) const - { - - auto a_dram = make_naive_tensor_view( - p_a, make_tuple(M, K), make_tuple(stride_a, 1), number<8>{}, number<1>{}); - - auto b_dram = make_naive_tensor_view( - p_b, make_tuple(N, K), make_tuple(stride_b, 1), number<8>{}, number<1>{}); - - const auto c_dram = make_naive_tensor_view( - p_c, make_tuple(M, N), make_tuple(stride_c, 1), number<8>{}, number<1>{}); - - PracticeGemmHostPipeline{}(a_dram, b_dram, c_dram); - } -}; - -} // namespace ck_tile diff --git a/tutorial/ck_tile/CMakeLists.txt b/tutorial/ck_tile/CMakeLists.txt index f9073acffc9..239270d8334 100644 --- a/tutorial/ck_tile/CMakeLists.txt +++ b/tutorial/ck_tile/CMakeLists.txt @@ -6,5 +6,4 @@ include_directories(AFTER ) add_subdirectory(00_copy_kernel) -add_subdirectory(01_naive_gemm) - +add_subdirectory(gemm) diff --git a/tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md b/tutorial/ck_tile/gemm/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md similarity index 100% rename from tutorial/ck_tile/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md rename to tutorial/ck_tile/gemm/01_naive_gemm/BLOCK_LEVEL_PIPELINE.md diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/CMakeLists.txt b/tutorial/ck_tile/gemm/01_naive_gemm/CMakeLists.txt new file mode 100644 index 00000000000..ccdf2f8d90d --- /dev/null +++ b/tutorial/ck_tile/gemm/01_naive_gemm/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +set(EXAMPLE_NAIVE_GEMM "tile_tutorial_naive_gemm") + +message(DEBUG "adding example ${EXAMPLE_NAIVE_GEMM}") + +add_executable(${EXAMPLE_NAIVE_GEMM} EXCLUDE_FROM_ALL practice_gemm.cpp) +target_include_directories(${EXAMPLE_NAIVE_GEMM} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_NAIVE_GEMM_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_NAIVE_GEMM_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -Wno-ctad-maybe-unsupported) + +target_compile_options(${EXAMPLE_NAIVE_GEMM} PRIVATE ${EXAMPLE_NAIVE_GEMM_COMPILE_OPTIONS}) + +add_dependencies(tutorials ${EXAMPLE_NAIVE_GEMM}) diff --git a/tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md b/tutorial/ck_tile/gemm/01_naive_gemm/HOST_LEVEL_PIPELINE.md similarity index 100% rename from tutorial/ck_tile/01_naive_gemm/HOST_LEVEL_PIPELINE.md rename to tutorial/ck_tile/gemm/01_naive_gemm/HOST_LEVEL_PIPELINE.md diff --git a/tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md b/tutorial/ck_tile/gemm/01_naive_gemm/KERNEL_ENTRY_POINT.md similarity index 100% rename from tutorial/ck_tile/01_naive_gemm/KERNEL_ENTRY_POINT.md rename to tutorial/ck_tile/gemm/01_naive_gemm/KERNEL_ENTRY_POINT.md diff --git a/tutorial/ck_tile/01_naive_gemm/README.md b/tutorial/ck_tile/gemm/01_naive_gemm/README.md similarity index 100% rename from tutorial/ck_tile/01_naive_gemm/README.md rename to tutorial/ck_tile/gemm/01_naive_gemm/README.md diff --git a/tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md b/tutorial/ck_tile/gemm/01_naive_gemm/TILE_DISTRIBUTION.md similarity index 100% rename from tutorial/ck_tile/01_naive_gemm/TILE_DISTRIBUTION.md rename to tutorial/ck_tile/gemm/01_naive_gemm/TILE_DISTRIBUTION.md diff --git a/tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md b/tutorial/ck_tile/gemm/01_naive_gemm/WALKTHROUGH.md similarity index 100% rename from tutorial/ck_tile/01_naive_gemm/WALKTHROUGH.md rename to tutorial/ck_tile/gemm/01_naive_gemm/WALKTHROUGH.md diff --git a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/gemm/01_naive_gemm/block_level/block_gemm_pipeline_agmem_bgmem_creg.hpp similarity index 69% rename from tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp rename to tutorial/ck_tile/gemm/01_naive_gemm/block_level/block_gemm_pipeline_agmem_bgmem_creg.hpp index 76c4a58c1d5..d8006b7eb3e 100644 --- a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_pipeline_agmem_bgmem_creg.hpp +++ b/tutorial/ck_tile/gemm/01_naive_gemm/block_level/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -3,34 +3,34 @@ #pragma once +#include "block_gemm_pipeline_agmem_bgmem_creg_policy.hpp" + #include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" namespace ck_tile { -template -struct PracticeGemmBlockPipelineAGmemBGmemCreg +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BlockGemmPipelineAGmemBGmemCReg { - using ADataType = typename Problem::ADataType; - using BDataType = typename Problem::BDataType; - using CDataType = typename Problem::CDataType; - using AccDataType = typename Problem::AccDataType; - - using BlockTile = typename Problem::Shape::BlockTile; - using WaveTile = typename Problem::Shape::WaveTile; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; - static constexpr index_t MPerBlock = BlockTile::at(number<0>{}); - static constexpr index_t NPerBlock = BlockTile::at(number<1>{}); - static constexpr index_t KPerBlock = BlockTile::at(number<2>{}); + static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t MPerWave = WaveTile::at(number<0>{}); - static constexpr index_t NPerWave = WaveTile::at(number<1>{}); - static constexpr index_t KPerWave = WaveTile::at(number<2>{}); + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; - using BlockGemm = - remove_cvref_t())>; + using BlockGemm = remove_cvref_t())>; - CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLDSSize() + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() { return integer_divide_ceil( sizeof(ADataType) * @@ -52,9 +52,9 @@ struct PracticeGemmBlockPipelineAGmemBGmemCreg std::is_same_v>, "wrong!"); - static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); // ----------------------------------------------------------------------------------------- @@ -82,38 +82,38 @@ struct PracticeGemmBlockPipelineAGmemBGmemCreg // A DRAM tile window for load auto a_copy_dram_window = make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), a_dram_block_window_tmp.get_window_origin(), Policy::template MakeADramTileDistribution()); // A LDS tile window for store auto a_copy_lds_window = make_tile_window(a_lds_block, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, a_copy_dram_window.get_tile_distribution()); // B DRAM tile window for load auto b_copy_dram_window = make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), b_dram_block_window_tmp.get_window_origin(), Policy::template MakeBDramTileDistribution()); // B LDS tile window for store auto b_copy_lds_window = make_tile_window(b_lds_block, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, b_copy_dram_window.get_tile_distribution()); // A LDS tile for block GEMM auto a_lds_gemm_window = make_tile_window( - a_lds_block, make_tuple(number{}, number{}), {0, 0}); + a_lds_block, make_tuple(number{}, number{}), {0, 0}); // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}); + b_lds_block, make_tuple(number{}, number{}), {0, 0}); // Block GEMM auto block_gemm = BlockGemm(); @@ -131,28 +131,29 @@ struct PracticeGemmBlockPipelineAGmemBGmemCreg BBlockTile b_block_tile; using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; - constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, KPerBlock); - constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, KPerBlock); + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock); // ------------------------------------------------------------------------------------- // Gemm pipeline start // Initialize C tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + // non-prefetch index_t iCounter = num_loop; while(iCounter > 0) { - a_block_tile = load_tile(a_copy_dram_window); // from DRAM to registers - b_block_tile = load_tile(b_copy_dram_window); // from DRAM to registers + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); move_tile_window(a_copy_dram_window, a_dram_tile_window_step); move_tile_window(b_copy_dram_window, b_dram_tile_window_step); - store_tile(a_copy_lds_window, a_block_tile); // from registers to LDS - store_tile(b_copy_lds_window, b_block_tile); // from registers to LDS + store_tile(a_copy_lds_window, a_block_tile); + store_tile(b_copy_lds_window, b_block_tile); block_sync_lds(); - block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); // from LDS to registers + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_sync_lds(); iCounter--; diff --git a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp b/tutorial/ck_tile/gemm/01_naive_gemm/block_level/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp similarity index 58% rename from tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp rename to tutorial/ck_tile/gemm/01_naive_gemm/block_level/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp index a3ed9824886..421a63649fc 100644 --- a/tutorial/ck_tile/01_naive_gemm/block_level/practice_gemm_block_policy_agmem_bgmem_creg.hpp +++ b/tutorial/ck_tile/gemm/01_naive_gemm/block_level/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp @@ -3,41 +3,27 @@ #pragma once -#include "ck_tile/host.hpp" -#include "ck_tile/core.hpp" +#include "../warp_level/block_gemm_asmem_bsmem_creg.hpp" -#include "../warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp" -#include "../warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" namespace ck_tile { -template -struct PracticeGemmBlockPipelineProblem -{ - using ADataType = ADataType_; - using BDataType = BDataType_; - using CDataType = CDataType_; - using AccDataType = AccDataType_; - using Shape = Shape_; -}; - -struct PracticeGemmBlockPolicy +// Default policy for BlockGemmPipelineAGmemBGmemCReg +// Default policy class should not be templated, put template on member functions instead +struct BlockGemmPipelineAGmemBGmemCRegPolicy { - template - CK_TILE_HOST_DEVICE static constexpr auto GetPracticeWaveGemmPipeline() - { - return PracticeGemmWarpPipelineASmemBSmemCreg{}; - } - + // 3d + no padding (NAIVE_IMPLEMENTATION) template CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() { - constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{}); - constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPack = 8; constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( @@ -52,14 +38,16 @@ struct PracticeGemmBlockPolicy make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), make_tuple(sequence<0>{}, sequence<1, 2>{}), make_tuple(sequence<0>{}, sequence<1>{})); + return a_lds_block_desc; } + // 3d + no padding (NAIVE_IMPLEMENTATION) template CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{}); - constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPack = 8; constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( @@ -81,14 +69,12 @@ struct PracticeGemmBlockPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() { - using ADataType = remove_cvref_t; - using BlockGemm = remove_cvref_t())>; - constexpr index_t kMWarp = BlockGemm::MWarp; - constexpr index_t kNWarp = BlockGemm::NWarp; - constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size(); + using ADataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kMPerBlock = Problem::Shape::BlockTile::at(number<0>{}); - constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t K1 = 16 / sizeof(ADataType); constexpr index_t K0 = kKPerBlock / K1; @@ -98,25 +84,23 @@ struct PracticeGemmBlockPolicy constexpr index_t M0 = kMPerBlock / (M2 * M1); return make_static_tile_distribution( - tile_distribution_encoding, // replication - tuple, sequence>, // hierarchy - tuple, sequence<1, 2>>, // parallelism - tuple, sequence<2, 0>>, // paralleism - sequence<1, 2>, // yield - sequence<0, 1>>{}); // yield + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); } template CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() { - using BDataType = remove_cvref_t; - using BlockGemm = remove_cvref_t())>; - constexpr index_t kMWarp = BlockGemm::MWarp; - constexpr index_t kNWarp = BlockGemm::NWarp; - constexpr index_t kBlockSize = kMWarp * kNWarp * get_warp_size(); + using BDataType = remove_cvref_t; - constexpr index_t kNPerBlock = Problem::Shape::BlockTile::at(number<1>{}); - constexpr index_t kKPerBlock = Problem::Shape::BlockTile::at(number<2>{}); + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t K1 = 16 / sizeof(BDataType); constexpr index_t K0 = kKPerBlock / K1; @@ -133,6 +117,12 @@ struct PracticeGemmBlockPolicy sequence<1, 2>, sequence<0, 1>>{}); } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + return BlockGemmASmemBSmemCReg{}; + } }; } // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/host_level/grid_gemm.hpp b/tutorial/ck_tile/gemm/01_naive_gemm/host_level/grid_gemm.hpp new file mode 100644 index 00000000000..559d271af01 --- /dev/null +++ b/tutorial/ck_tile/gemm/01_naive_gemm/host_level/grid_gemm.hpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile { + +template +struct GridGemm +{ + using ADataType = typename Problem::ADataType; + using BDataType = typename Problem::BDataType; + using CDataType = typename Problem::CDataType; + using AccDataType = typename Problem::AccDataType; + using CElementFunction = typename Problem::CElementFunction; + + static constexpr auto kMPerBlock = Policy::kMPerBlock; + static constexpr auto kNPerBlock = Policy::kNPerBlock; + static constexpr auto kKPerBlock = Policy::kKPerBlock; + + template + CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid, + const BGridTensorView& b_grid, + CGridTensorView& c_grid, + const CElementFunction& c_element_func) const + { + const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{}); + const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{}); + const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{}); + + // divide problem + const auto id_block = get_block_id(); + + const auto num_tile_m = integer_divide_ceil(M, kMPerBlock); + const auto num_tile_n = integer_divide_ceil(N, kNPerBlock); + + const auto block2tile = Policy::template MakeBlock2TileMap(num_tile_m, num_tile_n); + + const auto id_tile = block2tile(id_block); + + const auto iM = __builtin_amdgcn_readfirstlane(id_tile.template at<0>() * kMPerBlock); + const auto iN = __builtin_amdgcn_readfirstlane(id_tile.template at<1>() * kNPerBlock); + + // A block window + auto a_block_window = make_tile_window( + a_grid, make_tuple(number{}, number{}), {iM, 0}); + + // B block window + auto b_block_window = make_tile_window( + b_grid, make_tuple(number{}, number{}), {iN, 0}); + + constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline(); + + __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()]; + + const auto acc_block_tile = + block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char); + + // cast to CDataType and apply CElementFunction + const auto c_block_tile = tile_elementwise_in( + [&](const auto& acc) { return c_element_func(type_convert(acc)); }, + acc_block_tile); + + // store C + auto c_window = make_tile_window( + c_grid, make_tuple(number{}, number{}), {iM, iN}); + + store_tile(c_window, c_block_tile); + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.cpp b/tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.cpp new file mode 100644 index 00000000000..8c1404bf599 --- /dev/null +++ b/tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.cpp @@ -0,0 +1,155 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include "ck_tile/host.hpp" +#include "practice_gemm.hpp" +#include "../reference_gemm.hpp" + +/* + * Naive GEMM implementation (no optimizations) + * A [M, K] + * B [N, K] + * C [M, N] + */ + +// elementwise lambda +struct CElementFunction +{ + template + CK_TILE_HOST_DEVICE auto operator()(const X& x) const + { + return x; + } +}; + +int main(int argc, char* argv[]) +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + + ck_tile::index_t verification = 0; + ck_tile::index_t M = 3328; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + + if(argc == 2) + { + verification = std::stoi(argv[1]); + } + if(argc == 5) + { + verification = std::stoi(argv[1]); + M = std::stoi(argv[2]); + N = std::stoi(argv[3]); + K = std::stoi(argv[4]); + } + + printf("*** Naive implementation test ***\n"); + + const ck_tile::index_t Lda = K; + const ck_tile::index_t Ldb = K; + const ck_tile::index_t Ldc = N; + + const auto a_lengths = std::array{M, K}; + const auto a_strides = std::array{Lda, 1}; + + const auto b_lengths = std::array{N, K}; + const auto b_strides = std::array{Ldb, 1}; + + const auto c_lengths = std::array{M, N}; + const auto c_strides = std::array{Ldc, 1}; + + // host verify + ck_tile::HostTensor a_host(a_lengths, a_strides); + ck_tile::HostTensor b_host(b_lengths, b_strides); + ck_tile::HostTensor c_host_dev(c_lengths, c_strides); + + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_host); + + ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.mData.data()); + b_buf.ToDevice(b_host.mData.data()); + + // Alignment + constexpr ck_tile::index_t kAAlignment = 8; + constexpr ck_tile::index_t kBAlignment = 8; + constexpr ck_tile::index_t kCAlignment = 8; + + constexpr ck_tile::index_t kBlockSize = 256; + + constexpr ck_tile::index_t kGemmMPerBlock = 256; + constexpr ck_tile::index_t kGemmKPerBlock = 32; + constexpr ck_tile::index_t kGemmNPerBlock = 128; + + ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock); + + std::cout << "grid size " << kGridSize << std::endl; + + constexpr ck_tile::index_t kWarpSize = 64; // AMD GPU warp size + constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / kWarpSize; + constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; + + using gemm_kernel = ck_tile::Gemm; + + float ave_time = ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true, 0, 5, 1000}, + ck_tile::make_kernel(gemm_kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(a_buf.GetDeviceBuffer()), + static_cast(b_buf.GetDeviceBuffer()), + static_cast(c_buf.GetDeviceBuffer()), + M, + N, + K, + Lda, + Ldb, + Ldc, + CElementFunction{})); + auto pass = true; + + if(verification) + { + // reference gemm + ck_tile::HostTensor c_host_ref(c_lengths, c_strides); + reference_basic_gemm( + a_host, b_host, c_host_ref); + c_buf.FromDevice(c_host_dev.mData.data()); + pass &= ck_tile::check_err(c_host_dev, c_host_ref); + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + return !pass; +} diff --git a/tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.hpp b/tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.hpp new file mode 100644 index 00000000000..50a49d23fb3 --- /dev/null +++ b/tutorial/ck_tile/gemm/01_naive_gemm/practice_gemm.hpp @@ -0,0 +1,139 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +#include "block_level/block_gemm_pipeline_agmem_bgmem_creg.hpp" +#include "host_level/grid_gemm.hpp" + +namespace ck_tile { + +template +struct GridGemmProblem +{ + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using CDataType = CDataType_; + + using CElementFunction = CElementFunction_; +}; + +template +struct TileGemmShape +{ + static constexpr index_t kM = kMPerTile; + static constexpr index_t kN = kNPerTile; + static constexpr index_t kK = kKPerTile; +}; + +template +struct BlockGemmPipelineProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +// C = A * B +template +struct Gemm +{ + static constexpr index_t kBlockSize = kBlockSize_; + + using GridGemmProblem_ = + GridGemmProblem; + + struct GridGemmPolicy + { + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kMPerBlock = kMPerBlock_; + static constexpr index_t kNPerBlock = kNPerBlock_; + static constexpr index_t kKPerBlock = kKPerBlock_; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) + { + const auto unmerge = make_merge_transform(make_tuple(N0, M0)); + + return [unmerge](index_t block_id) { + multi_index<2> unmerged; + unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); + + return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline() + { + using BlockGemmPipelineProblem_ = + BlockGemmPipelineProblem>; + return BlockGemmPipelineAGmemBGmemCReg{}; + } + }; + + using GridGemm_ = GridGemm; + + CK_TILE_DEVICE void operator()(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t Lda, + const index_t Ldb, + const index_t Ldc, + const CElementFunction& c_element_func) const + { + const auto a_dram = [&] { + return make_naive_tensor_view( + p_a, make_tuple(M, K), make_tuple(Lda, 1), number{}, number<1>{}); + }(); + + const auto b_dram = [&] { + return make_naive_tensor_view( + p_b, make_tuple(N, K), make_tuple(Ldb, 1), number{}, number<1>{}); + }(); + + const auto c_dram = [&] { + return make_naive_tensor_view( + p_c, make_tuple(M, N), make_tuple(Ldc, 1), number{}, number<1>{}); + }(); + + GridGemm_{}(a_dram, b_dram, c_dram, c_element_func); + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp b/tutorial/ck_tile/gemm/01_naive_gemm/warp_level/block_gemm_asmem_bsmem_creg.hpp similarity index 61% rename from tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp rename to tutorial/ck_tile/gemm/01_naive_gemm/warp_level/block_gemm_asmem_bsmem_creg.hpp index a329357fe8a..2e4b35e44ec 100644 --- a/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_pipeline_asmem_bsmem_creg.hpp +++ b/tutorial/ck_tile/gemm/01_naive_gemm/warp_level/block_gemm_asmem_bsmem_creg.hpp @@ -4,18 +4,21 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/host.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "block_gemm_asmem_bsmem_creg_policy.hpp" namespace ck_tile { -template -struct PracticeGemmWarpPipelineASmemBSmemCreg +// A is block window on shared memory +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmASmemBSmemCReg { - - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using CDataType = remove_cvref_t; - using WaveGemmShape = remove_cvref_t; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; using WarpGemm = remove_cvref_t< decltype(Policy::template GetWarpGemmMWarpNWarp().template get<0>())>; @@ -58,16 +61,14 @@ struct PracticeGemmWarpPipelineASmemBSmemCreg constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; - static_assert(MPerBlock == WaveGemmShape::BlockTile_M && - NPerBlock == WaveGemmShape::BlockTile_N && - KPerBlock == WaveGemmShape::BlockTile_K, + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, "wrong!"); constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; -#if !defined(ENABLE_PREFETCH) constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; @@ -116,20 +117,17 @@ struct PracticeGemmWarpPipelineASmemBSmemCreg {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); }); }); -#endif // hot loop: static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // Read A warp tensor from A block tensor AWarpTensor a_warp_tensor; - a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { // Read B warp tensor from B block tensor BWarpTensor b_warp_tensor; - b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); // Read C warp tensor from C block tensor @@ -165,13 +163,62 @@ struct PracticeGemmWarpPipelineASmemBSmemCreg constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; - static_assert(MPerBlock == WaveGemmShape::BlockTile_M && - NPerBlock == WaveGemmShape::BlockTile_N && - KPerBlock == WaveGemmShape::BlockTile_K, + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, "wrong!"); constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // Construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); static_assert(std::is_same_v, "wrong!"); @@ -191,6 +238,46 @@ struct PracticeGemmWarpPipelineASmemBSmemCreg auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + // Hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // Read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // Read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + // Warp GEMM + if constexpr(KIterPerWarp == 0) + { + // c = a * b + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + // c += a * b + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + + // Write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + return c_block_tensor; } }; diff --git a/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp b/tutorial/ck_tile/gemm/01_naive_gemm/warp_level/block_gemm_asmem_bsmem_creg_policy.hpp similarity index 66% rename from tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp rename to tutorial/ck_tile/gemm/01_naive_gemm/warp_level/block_gemm_asmem_bsmem_creg_policy.hpp index 4be530cdd35..188e481c654 100644 --- a/tutorial/ck_tile/01_naive_gemm/warp_level/practice_gemm_warp_policy_asmem_bsmem_creg.hpp +++ b/tutorial/ck_tile/gemm/01_naive_gemm/warp_level/block_gemm_asmem_bsmem_creg_policy.hpp @@ -10,14 +10,16 @@ namespace ck_tile { // Default policy for BlockGemmASmemBSmemCReg // Default policy class should not be templated, put template on member functions instead -struct PracticeGemmWarpPolicy +struct BlockGemmASmemBSmemCRegPolicy { template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { + // NAIVE_IMPLEMENTATION uses 4x1 warp configuration constexpr index_t kMWarp = 4; constexpr index_t kNWarp = 1; + // NAIVE_IMPLEMENTATION uses mfma m32 n32 k8 if constexpr(std::is_same_v && std::is_same_v && std::is_same_v) @@ -25,6 +27,13 @@ struct PracticeGemmWarpPolicy return make_tuple( WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, kMWarp, kNWarp); + } else { static_assert(false, "Unsupported data type configuration for GEMM warp execution."); diff --git a/tutorial/ck_tile/gemm/02_padding_k_first/CMakeLists.txt b/tutorial/ck_tile/gemm/02_padding_k_first/CMakeLists.txt new file mode 100644 index 00000000000..ed1f3c7ec6b --- /dev/null +++ b/tutorial/ck_tile/gemm/02_padding_k_first/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +set(EXAMPLE_PADDING_K_FIRST "tile_tutorial_padding_k_first") + +message(DEBUG "adding example ${EXAMPLE_PADDING_K_FIRST}") + +add_executable(${EXAMPLE_PADDING_K_FIRST} EXCLUDE_FROM_ALL gemm.cpp) +target_include_directories(${EXAMPLE_PADDING_K_FIRST} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_PADDING_K_FIRST_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_PADDING_K_FIRST_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -Wno-ctad-maybe-unsupported) + +target_compile_options(${EXAMPLE_PADDING_K_FIRST} PRIVATE ${EXAMPLE_PADDING_K_FIRST_COMPILE_OPTIONS}) + +add_dependencies(tutorials ${EXAMPLE_PADDING_K_FIRST}) diff --git a/tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_asmem_bsmem_creg.hpp b/tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_asmem_bsmem_creg.hpp new file mode 100644 index 00000000000..2e4b35e44ec --- /dev/null +++ b/tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_asmem_bsmem_creg.hpp @@ -0,0 +1,285 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "block_gemm_asmem_bsmem_creg_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmASmemBSmemCReg +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using WarpGemm = remove_cvref_t< + decltype(Policy::template GetWarpGemmMWarpNWarp().template get<0>())>; + static constexpr index_t MWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<1>(); + static constexpr index_t NWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<2>(); + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + static constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + [[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // Construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // Read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // Read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // Warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // Write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // Construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + static_assert(std::is_same_v, "wrong!"); + + // Construct C-Block-Tensor + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + // Hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // Read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // Read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + // Warp GEMM + if constexpr(KIterPerWarp == 0) + { + // c = a * b + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + // c += a * b + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + + // Write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_asmem_bsmem_creg_policy.hpp b/tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_asmem_bsmem_creg_policy.hpp new file mode 100644 index 00000000000..3cfc1e38f65 --- /dev/null +++ b/tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_asmem_bsmem_creg_policy.hpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +// Policy for BlockGemmASmemBSmemCReg with MFMA_32x32x16 (8x2) instruction +struct BlockGemmASmemBSmemCRegPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + // KERNEL_A uses 4x1 warp configuration + constexpr index_t kMWarp = 4; + constexpr index_t kNWarp = 1; + + // KERNEL_A uses mfma m32 n32 k16 (8x2 variant) + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution{}, kMWarp, kNWarp); + } + else + { + static_assert(false, "Unsupported data type configuration for GEMM warp execution."); + } + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_pipeline_agmem_bgmem_creg.hpp new file mode 100644 index 00000000000..d8006b7eb3e --- /dev/null +++ b/tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -0,0 +1,166 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "block_gemm_pipeline_agmem_bgmem_creg_policy.hpp" + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BlockGemmPipelineAGmemBGmemCReg +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + using BlockGemm = remove_cvref_t())>; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() + { + return integer_divide_ceil( + sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor().get_element_space_size(), + 16) * + 16 + + sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // ----------------------------------------------------------------------------------------- + // Definitions of all needed tiles + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_block_tile; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock); + + // ------------------------------------------------------------------------------------- + // Gemm pipeline start + + // Initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // non-prefetch + index_t iCounter = num_loop; + + while(iCounter > 0) + { + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + store_tile(a_copy_lds_window, a_block_tile); + store_tile(b_copy_lds_window, b_block_tile); + + block_sync_lds(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + iCounter--; + } + + return c_block_tile; + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp b/tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp new file mode 100644 index 00000000000..09cba6c12e1 --- /dev/null +++ b/tutorial/ck_tile/gemm/02_padding_k_first/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp @@ -0,0 +1,129 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "block_gemm_asmem_bsmem_creg.hpp" + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" + +namespace ck_tile { + +// Policy for BlockGemmPipelineAGmemBGmemCReg with PADDING_K_FIRST optimization +struct BlockGemmPipelineAGmemBGmemCRegPolicy +{ + // 3d + PADDING_K_FIRST - adds padding to K dimension to avoid bank conflicts + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + // PADDING_K_FIRST: stride is (kKPerBlock / kKPack + 1) * kKPack instead of kKPerBlock + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + + // 3d + no padding for B (PADDING_K_FIRST only pads A in version2) + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + // B uses same layout as NAIVE (no padding) + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return b_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + return BlockGemmASmemBSmemCReg{}; + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/02_padding_k_first/gemm.cpp b/tutorial/ck_tile/gemm/02_padding_k_first/gemm.cpp new file mode 100644 index 00000000000..6618eb6825a --- /dev/null +++ b/tutorial/ck_tile/gemm/02_padding_k_first/gemm.cpp @@ -0,0 +1,158 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "ck_tile/host.hpp" +#include "gemm.hpp" +#include "../reference_gemm.hpp" + +/* + * KERNEL_A: GEMM with PADDING_K_FIRST + MFMA_32x32x16 (8x2) + * A [M, K] + * B [N, K] + * C [M, N] + */ + +// elementwise lambda +struct CElementFunction +{ + template + CK_TILE_HOST_DEVICE auto operator()(const X& x) const + { + return x; + } +}; + +int main(int argc, char* argv[]) +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + + ck_tile::index_t verification = 0; + ck_tile::index_t M = 3328; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + + if(argc == 2) + { + verification = std::stoi(argv[1]); + } + if(argc == 5) + { + verification = std::stoi(argv[1]); + M = std::stoi(argv[2]); + N = std::stoi(argv[3]); + K = std::stoi(argv[4]); + } + + printf("*** Kernel A test ***\n"); + printf(" --> Using PADDING_K_FIRST\n"); + printf(" --> Using mfma_32x32x(8x2)\n"); + + const ck_tile::index_t Lda = K; + const ck_tile::index_t Ldb = K; + const ck_tile::index_t Ldc = N; + + const auto a_lengths = std::array{M, K}; + const auto a_strides = std::array{Lda, 1}; + + const auto b_lengths = std::array{N, K}; + const auto b_strides = std::array{Ldb, 1}; + + const auto c_lengths = std::array{M, N}; + const auto c_strides = std::array{Ldc, 1}; + + // host verify + ck_tile::HostTensor a_host(a_lengths, a_strides); + ck_tile::HostTensor b_host(b_lengths, b_strides); + ck_tile::HostTensor c_host_dev(c_lengths, c_strides); + + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_host); + + ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.mData.data()); + b_buf.ToDevice(b_host.mData.data()); + + // Alignment + constexpr ck_tile::index_t kAAlignment = 8; + constexpr ck_tile::index_t kBAlignment = 8; + constexpr ck_tile::index_t kCAlignment = 8; + + constexpr ck_tile::index_t kBlockSize = 256; + + constexpr ck_tile::index_t kGemmMPerBlock = 256; + constexpr ck_tile::index_t kGemmKPerBlock = 32; + constexpr ck_tile::index_t kGemmNPerBlock = 128; + + ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock); + + std::cout << "grid size " << kGridSize << std::endl; + + constexpr ck_tile::index_t kWarpSize = 64; // AMD GPU warp size + constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / kWarpSize; + constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; + + using gemm_kernel = ck_tile::Gemm; + + float ave_time = ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true, 0, 5, 1000}, + ck_tile::make_kernel(gemm_kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(a_buf.GetDeviceBuffer()), + static_cast(b_buf.GetDeviceBuffer()), + static_cast(c_buf.GetDeviceBuffer()), + M, + N, + K, + Lda, + Ldb, + Ldc, + CElementFunction{})); + auto pass = true; + + if(verification) + { + // reference gemm + ck_tile::HostTensor c_host_ref(c_lengths, c_strides); + reference_basic_gemm( + a_host, b_host, c_host_ref); + c_buf.FromDevice(c_host_dev.mData.data()); + pass &= ck_tile::check_err(c_host_dev, c_host_ref); + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + return !pass; +} diff --git a/tutorial/ck_tile/gemm/02_padding_k_first/gemm.hpp b/tutorial/ck_tile/gemm/02_padding_k_first/gemm.hpp new file mode 100644 index 00000000000..2c4137837f1 --- /dev/null +++ b/tutorial/ck_tile/gemm/02_padding_k_first/gemm.hpp @@ -0,0 +1,139 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +#include "block_gemm_pipeline_agmem_bgmem_creg.hpp" +#include "grid_gemm.hpp" + +namespace ck_tile { + +template +struct GridGemmProblem +{ + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using CDataType = CDataType_; + + using CElementFunction = CElementFunction_; +}; + +template +struct TileGemmShape +{ + static constexpr index_t kM = kMPerTile; + static constexpr index_t kN = kNPerTile; + static constexpr index_t kK = kKPerTile; +}; + +template +struct BlockGemmPipelineProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +// C = A * B +template +struct Gemm +{ + static constexpr index_t kBlockSize = kBlockSize_; + + using GridGemmProblem_ = + GridGemmProblem; + + struct GridGemmPolicy + { + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kMPerBlock = kMPerBlock_; + static constexpr index_t kNPerBlock = kNPerBlock_; + static constexpr index_t kKPerBlock = kKPerBlock_; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) + { + const auto unmerge = make_merge_transform(make_tuple(N0, M0)); + + return [unmerge](index_t block_id) { + multi_index<2> unmerged; + unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); + + return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline() + { + using BlockGemmPipelineProblem_ = + BlockGemmPipelineProblem>; + return BlockGemmPipelineAGmemBGmemCReg{}; + } + }; + + using GridGemm_ = GridGemm; + + CK_TILE_DEVICE void operator()(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t Lda, + const index_t Ldb, + const index_t Ldc, + const CElementFunction& c_element_func) const + { + const auto a_dram = [&] { + return make_naive_tensor_view( + p_a, make_tuple(M, K), make_tuple(Lda, 1), number{}, number<1>{}); + }(); + + const auto b_dram = [&] { + return make_naive_tensor_view( + p_b, make_tuple(N, K), make_tuple(Ldb, 1), number{}, number<1>{}); + }(); + + const auto c_dram = [&] { + return make_naive_tensor_view( + p_c, make_tuple(M, N), make_tuple(Ldc, 1), number{}, number<1>{}); + }(); + + GridGemm_{}(a_dram, b_dram, c_dram, c_element_func); + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/02_padding_k_first/grid_gemm.hpp b/tutorial/ck_tile/gemm/02_padding_k_first/grid_gemm.hpp new file mode 100644 index 00000000000..559d271af01 --- /dev/null +++ b/tutorial/ck_tile/gemm/02_padding_k_first/grid_gemm.hpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile { + +template +struct GridGemm +{ + using ADataType = typename Problem::ADataType; + using BDataType = typename Problem::BDataType; + using CDataType = typename Problem::CDataType; + using AccDataType = typename Problem::AccDataType; + using CElementFunction = typename Problem::CElementFunction; + + static constexpr auto kMPerBlock = Policy::kMPerBlock; + static constexpr auto kNPerBlock = Policy::kNPerBlock; + static constexpr auto kKPerBlock = Policy::kKPerBlock; + + template + CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid, + const BGridTensorView& b_grid, + CGridTensorView& c_grid, + const CElementFunction& c_element_func) const + { + const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{}); + const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{}); + const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{}); + + // divide problem + const auto id_block = get_block_id(); + + const auto num_tile_m = integer_divide_ceil(M, kMPerBlock); + const auto num_tile_n = integer_divide_ceil(N, kNPerBlock); + + const auto block2tile = Policy::template MakeBlock2TileMap(num_tile_m, num_tile_n); + + const auto id_tile = block2tile(id_block); + + const auto iM = __builtin_amdgcn_readfirstlane(id_tile.template at<0>() * kMPerBlock); + const auto iN = __builtin_amdgcn_readfirstlane(id_tile.template at<1>() * kNPerBlock); + + // A block window + auto a_block_window = make_tile_window( + a_grid, make_tuple(number{}, number{}), {iM, 0}); + + // B block window + auto b_block_window = make_tile_window( + b_grid, make_tuple(number{}, number{}), {iN, 0}); + + constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline(); + + __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()]; + + const auto acc_block_tile = + block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char); + + // cast to CDataType and apply CElementFunction + const auto c_block_tile = tile_elementwise_in( + [&](const auto& acc) { return c_element_func(type_convert(acc)); }, + acc_block_tile); + + // store C + auto c_window = make_tile_window( + c_grid, make_tuple(number{}, number{}), {iM, iN}); + + store_tile(c_window, c_block_tile); + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/03_mfma_16x16x16/CMakeLists.txt b/tutorial/ck_tile/gemm/03_mfma_16x16x16/CMakeLists.txt new file mode 100644 index 00000000000..4bf5975c9bb --- /dev/null +++ b/tutorial/ck_tile/gemm/03_mfma_16x16x16/CMakeLists.txt @@ -0,0 +1,17 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +set(EXAMPLE_MFMA_16X16X16 "tile_tutorial_mfma_16x16x16") + +message(DEBUG "adding example ${EXAMPLE_MFMA_16X16X16}") + +add_executable(${EXAMPLE_MFMA_16X16X16} EXCLUDE_FROM_ALL gemm.cpp) +target_include_directories(${EXAMPLE_MFMA_16X16X16} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +set(EXAMPLE_MFMA_16X16X16_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_MFMA_16X16X16_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal -Wno-ctad-maybe-unsupported) + +target_compile_options(${EXAMPLE_MFMA_16X16X16} PRIVATE ${EXAMPLE_MFMA_16X16X16_COMPILE_OPTIONS}) + +add_dependencies(tutorials ${EXAMPLE_MFMA_16X16X16}) diff --git a/tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_asmem_bsmem_creg.hpp b/tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_asmem_bsmem_creg.hpp new file mode 100644 index 00000000000..2e4b35e44ec --- /dev/null +++ b/tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_asmem_bsmem_creg.hpp @@ -0,0 +1,285 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "block_gemm_asmem_bsmem_creg_policy.hpp" + +namespace ck_tile { + +// A is block window on shared memory +// B is block window on shared memory +// C is block distributed tensor +template +struct BlockGemmASmemBSmemCReg +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + using WarpGemm = remove_cvref_t< + decltype(Policy::template GetWarpGemmMWarpNWarp().template get<0>())>; + static constexpr index_t MWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<1>(); + static constexpr index_t NWarp = + Policy::template GetWarpGemmMWarpNWarp().template get<2>(); + + using AWarpDstr = typename WarpGemm::AWarpDstr; + using BWarpDstr = typename WarpGemm::BWarpDstr; + using CWarpDstr = typename WarpGemm::CWarpDstr; + + using AWarpTensor = typename WarpGemm::AWarpTensor; + using BWarpTensor = typename WarpGemm::BWarpTensor; + using CWarpTensor = typename WarpGemm::CWarpTensor; + + static constexpr auto a_warp_y_lengths = + to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto b_warp_y_lengths = + to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + static constexpr auto c_warp_y_lengths = + to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); + + static constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t{}; + static constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t{}; + + // C += A * B + template + CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, + [[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // Construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // Read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // Read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + // Warp GEMM + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + + // Write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + } + + // C = A * B + template + CK_TILE_DEVICE auto operator()([[maybe_unused]] const ABlockWindowTmp& a_block_window_tmp, + [[maybe_unused]] const BBlockWindowTmp& b_block_window_tmp) const + { + static_assert(std::is_same_v && + std::is_same_v, + "wrong!"); + + constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; + constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; + + static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && + KPerBlock == BlockGemmShape::kK, + "wrong!"); + + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM); + constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK; + + constexpr index_t MPerBlockPerIter = MPerBlock / MIterPerWarp; + constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; + constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; + + const index_t iMWarp = get_warp_id() / NWarp; + const index_t iNWarp = get_warp_id() % NWarp; + + // Construct A-warp-window + auto a_warp_window_tmp = make_tile_window( + a_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {a_block_window_tmp.get_window_origin().at(number<0>{}) + iMWarp * WarpGemm::kM, + a_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::AWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + MIterPerWarp> + a_warp_windows; + + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + a_warp_windows(mIter)(kIter) = a_warp_window_tmp; + move_tile_window(a_warp_windows(mIter)(kIter), + {mIter * MPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + // Construct B-warp-window + auto b_warp_window_tmp = make_tile_window( + b_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + {b_block_window_tmp.get_window_origin().at(number<0>{}) + iNWarp * WarpGemm::kN, + b_block_window_tmp.get_window_origin().at(number<1>{})}, + make_static_tile_distribution(typename WarpGemm::BWarpDstrEncoding{})); + + statically_indexed_array< + statically_indexed_array, + NIterPerWarp> + b_warp_windows; + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + b_warp_windows(nIter)(kIter) = b_warp_window_tmp; + move_tile_window(b_warp_windows(nIter)(kIter), + {nIter * NPerBlockPerIter, kIter * KPerBlockPerIter}); + }); + }); + + static_assert(std::is_same_v, "wrong!"); + + // Construct C-Block-Tensor + constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{}); + + constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); + + auto c_block_tensor = make_static_distributed_tensor(c_block_dstr); + + // Hot loop: + static_for<0, KIterPerWarp, 1>{}([&](auto kIter) { + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { + // Read A warp tensor from A block tensor + AWarpTensor a_warp_tensor; + a_warp_tensor = load_tile(a_warp_windows(mIter)(kIter)); + + static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { + // Read B warp tensor from B block tensor + BWarpTensor b_warp_tensor; + b_warp_tensor = load_tile(b_warp_windows(nIter)(kIter)); + + // Read C warp tensor from C block tensor + CWarpTensor c_warp_tensor; + + // Warp GEMM + if constexpr(KIterPerWarp == 0) + { + // c = a * b + c_warp_tensor = WarpGemm{}(a_warp_tensor, b_warp_tensor); + } + else + { + // c += a * b + c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); + + WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); + } + + // Write C warp tensor into C block tensor + c_block_tensor.set_y_sliced_thread_data( + merge_sequences(sequence{}, c_warp_y_index_zeros), + merge_sequences(sequence<1, 1>{}, c_warp_y_lengths), + c_warp_tensor.get_thread_buffer()); + }); + }); + }); + + return c_block_tensor; + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_asmem_bsmem_creg_policy.hpp b/tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_asmem_bsmem_creg_policy.hpp new file mode 100644 index 00000000000..8ed98bd5302 --- /dev/null +++ b/tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_asmem_bsmem_creg_policy.hpp @@ -0,0 +1,43 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +namespace ck_tile { + +// Policy for BlockGemmASmemBSmemCReg with MFMA_16x16x16 instruction +struct BlockGemmASmemBSmemCRegPolicy +{ + template + CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() + { + // KERNEL_B uses 4x1 warp configuration + constexpr index_t kMWarp = 4; + constexpr index_t kNWarp = 1; + + // KERNEL_B uses mfma m16 n16 k16 + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp); + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + return make_tuple( + WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution{}, kMWarp, kNWarp); + } + else + { + static_assert(false, "Unsupported data type configuration for GEMM warp execution."); + } + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_pipeline_agmem_bgmem_creg.hpp b/tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_pipeline_agmem_bgmem_creg.hpp new file mode 100644 index 00000000000..d8006b7eb3e --- /dev/null +++ b/tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_pipeline_agmem_bgmem_creg.hpp @@ -0,0 +1,166 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "block_gemm_pipeline_agmem_bgmem_creg_policy.hpp" + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" +#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp" + +namespace ck_tile { + +// A Tile Window: global memory +// B Tile Window: global memory +// C Distributed tensor: register +template +struct BlockGemmPipelineAGmemBGmemCReg +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = Problem::kBlockSize; + + static constexpr index_t kMPerBlock = BlockGemmShape::kM; + static constexpr index_t kNPerBlock = BlockGemmShape::kN; + static constexpr index_t kKPerBlock = BlockGemmShape::kK; + + using BlockGemm = remove_cvref_t())>; + + CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize() + { + return integer_divide_ceil( + sizeof(ADataType) * + Policy::template MakeALdsBlockDescriptor().get_element_space_size(), + 16) * + 16 + + sizeof(BDataType) * + Policy::template MakeBLdsBlockDescriptor().get_element_space_size(); + } + + template + CK_TILE_HOST_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, + const BDramBlockWindowTmp& b_dram_block_window_tmp, + index_t num_loop, + void* p_smem) const + { + static_assert( + std::is_same_v> && + std::is_same_v>, + "wrong!"); + + static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + "wrong!"); + + // ----------------------------------------------------------------------------------------- + // Definitions of all needed tiles + + // A tile in LDS + ADataType* p_a_lds = static_cast(p_smem); + + constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor(); + + auto a_lds_block = make_tensor_view(p_a_lds, a_lds_block_desc); + + constexpr index_t a_lds_block_space_size_aligned = + integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16) * + 16; + + // B tile in LDS + BDataType* p_b_lds = static_cast( + static_cast(static_cast(p_smem) + a_lds_block_space_size_aligned)); + + constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor(); + + auto b_lds_block = make_tensor_view(p_b_lds, b_lds_block_desc); + + // A DRAM tile window for load + auto a_copy_dram_window = + make_tile_window(a_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + a_dram_block_window_tmp.get_window_origin(), + Policy::template MakeADramTileDistribution()); + + // A LDS tile window for store + auto a_copy_lds_window = + make_tile_window(a_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + a_copy_dram_window.get_tile_distribution()); + + // B DRAM tile window for load + auto b_copy_dram_window = + make_tile_window(b_dram_block_window_tmp.get_bottom_tensor_view(), + make_tuple(number{}, number{}), + b_dram_block_window_tmp.get_window_origin(), + Policy::template MakeBDramTileDistribution()); + + // B LDS tile window for store + auto b_copy_lds_window = + make_tile_window(b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, + b_copy_dram_window.get_tile_distribution()); + + // A LDS tile for block GEMM + auto a_lds_gemm_window = make_tile_window( + a_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // B LDS tile for block GEMM + auto b_lds_gemm_window = make_tile_window( + b_lds_block, make_tuple(number{}, number{}), {0, 0}); + + // Block GEMM + auto block_gemm = BlockGemm(); + + // Acc register tile + auto c_block_tile = decltype(block_gemm(a_lds_gemm_window, b_lds_gemm_window)){}; + + using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); + using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); + + using ABlockTile = decltype(make_static_distributed_tensor(ABlockTileDistr{})); + using BBlockTile = decltype(make_static_distributed_tensor(BBlockTileDistr{})); + + ABlockTile a_block_tile; + BBlockTile b_block_tile; + using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex; + using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex; + constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock); + constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kKPerBlock); + + // ------------------------------------------------------------------------------------- + // Gemm pipeline start + + // Initialize C + tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); + + // non-prefetch + index_t iCounter = num_loop; + + while(iCounter > 0) + { + a_block_tile = load_tile(a_copy_dram_window); + b_block_tile = load_tile(b_copy_dram_window); + move_tile_window(a_copy_dram_window, a_dram_tile_window_step); + move_tile_window(b_copy_dram_window, b_dram_tile_window_step); + store_tile(a_copy_lds_window, a_block_tile); + store_tile(b_copy_lds_window, b_block_tile); + + block_sync_lds(); + block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); + block_sync_lds(); + + iCounter--; + } + + return c_block_tile; + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp b/tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp new file mode 100644 index 00000000000..09cba6c12e1 --- /dev/null +++ b/tutorial/ck_tile/gemm/03_mfma_16x16x16/block_gemm_pipeline_agmem_bgmem_creg_policy.hpp @@ -0,0 +1,129 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "block_gemm_asmem_bsmem_creg.hpp" + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" +#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" + +namespace ck_tile { + +// Policy for BlockGemmPipelineAGmemBGmemCReg with PADDING_K_FIRST optimization +struct BlockGemmPipelineAGmemBGmemCRegPolicy +{ + // 3d + PADDING_K_FIRST - adds padding to K dimension to avoid bank conflicts + template + CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() + { + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + // PADDING_K_FIRST: stride is (kKPerBlock / kKPack + 1) * kKPack instead of kKPerBlock + constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number<(kKPerBlock / kKPack + 1) * kKPack>{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto a_lds_block_desc = transform_tensor_descriptor( + a_lds_block_desc_0, + make_tuple(make_pass_through_transform(kMPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return a_lds_block_desc; + } + + // 3d + no padding for B (PADDING_K_FIRST only pads A in version2) + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() + { + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + constexpr index_t kKPack = 8; + + // B uses same layout as NAIVE (no padding) + constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, number{}, number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); + + constexpr auto b_lds_block_desc = transform_tensor_descriptor( + b_lds_block_desc_0, + make_tuple(make_pass_through_transform(kNPerBlock), + make_merge_transform(make_tuple(kKPerBlock / kKPack, kKPack))), + make_tuple(sequence<0>{}, sequence<1, 2>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return b_lds_block_desc; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() + { + using ADataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(ADataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t M2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t M1 = kBlockSize / get_warp_size(); + constexpr index_t M0 = kMPerBlock / (M2 * M1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() + { + using BDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + + constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; + constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; + + constexpr index_t K1 = 16 / sizeof(BDataType); + constexpr index_t K0 = kKPerBlock / K1; + constexpr index_t N2 = get_warp_size() / K0; + // coalesce reading for each blocks + constexpr index_t N1 = kBlockSize / get_warp_size(); + constexpr index_t N0 = kNPerBlock / (N2 * N1); + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 0>>, + sequence<1, 2>, + sequence<0, 1>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() + { + return BlockGemmASmemBSmemCReg{}; + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.cpp b/tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.cpp new file mode 100644 index 00000000000..0b2d41ad868 --- /dev/null +++ b/tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.cpp @@ -0,0 +1,158 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include + +#include "ck_tile/host.hpp" +#include "gemm.hpp" +#include "../reference_gemm.hpp" + +/* + * KERNEL_B: GEMM with PADDING_K_FIRST + MFMA_16x16x16 + * A [M, K] + * B [N, K] + * C [M, N] + */ + +// elementwise lambda +struct CElementFunction +{ + template + CK_TILE_HOST_DEVICE auto operator()(const X& x) const + { + return x; + } +}; + +int main(int argc, char* argv[]) +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + + ck_tile::index_t verification = 0; + ck_tile::index_t M = 3328; + ck_tile::index_t N = 4096; + ck_tile::index_t K = 4096; + + if(argc == 2) + { + verification = std::stoi(argv[1]); + } + if(argc == 5) + { + verification = std::stoi(argv[1]); + M = std::stoi(argv[2]); + N = std::stoi(argv[3]); + K = std::stoi(argv[4]); + } + + printf("*** Kernel B test ***\n"); + printf(" --> Using PADDING_K_FIRST\n"); + printf(" --> Using mfma_16x16x16\n"); + + const ck_tile::index_t Lda = K; + const ck_tile::index_t Ldb = K; + const ck_tile::index_t Ldc = N; + + const auto a_lengths = std::array{M, K}; + const auto a_strides = std::array{Lda, 1}; + + const auto b_lengths = std::array{N, K}; + const auto b_strides = std::array{Ldb, 1}; + + const auto c_lengths = std::array{M, N}; + const auto c_strides = std::array{Ldc, 1}; + + // host verify + ck_tile::HostTensor a_host(a_lengths, a_strides); + ck_tile::HostTensor b_host(b_lengths, b_strides); + ck_tile::HostTensor c_host_dev(c_lengths, c_strides); + + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(a_host); + ck_tile::FillUniformDistributionIntegerValue{-5.f, 5.f}(b_host); + + ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_buf(c_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.mData.data()); + b_buf.ToDevice(b_host.mData.data()); + + // Alignment + constexpr ck_tile::index_t kAAlignment = 8; + constexpr ck_tile::index_t kBAlignment = 8; + constexpr ck_tile::index_t kCAlignment = 8; + + constexpr ck_tile::index_t kBlockSize = 256; + + constexpr ck_tile::index_t kGemmMPerBlock = 256; + constexpr ck_tile::index_t kGemmKPerBlock = 32; + constexpr ck_tile::index_t kGemmNPerBlock = 128; + + ck_tile::index_t kGridSize = (M / kGemmMPerBlock) * (N / kGemmNPerBlock); + + std::cout << "grid size " << kGridSize << std::endl; + + constexpr ck_tile::index_t kWarpSize = 64; // AMD GPU warp size + constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / kWarpSize; + constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; + + using gemm_kernel = ck_tile::Gemm; + + float ave_time = ck_tile::launch_kernel( + ck_tile::stream_config{nullptr, true, 0, 5, 1000}, + ck_tile::make_kernel(gemm_kernel{}, + kGridSize, + kBlockSize, + 0, + static_cast(a_buf.GetDeviceBuffer()), + static_cast(b_buf.GetDeviceBuffer()), + static_cast(c_buf.GetDeviceBuffer()), + M, + N, + K, + Lda, + Ldb, + Ldc, + CElementFunction{})); + auto pass = true; + + if(verification) + { + // reference gemm + ck_tile::HostTensor c_host_ref(c_lengths, c_strides); + reference_basic_gemm( + a_host, b_host, c_host_ref); + c_buf.FromDevice(c_host_dev.mData.data()); + pass &= ck_tile::check_err(c_host_dev, c_host_ref); + std::cout << "valid:" << (pass ? "y" : "n") << std::endl; + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N; + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s" + << std::endl; + + return !pass; +} diff --git a/tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.hpp b/tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.hpp new file mode 100644 index 00000000000..2c4137837f1 --- /dev/null +++ b/tutorial/ck_tile/gemm/03_mfma_16x16x16/gemm.hpp @@ -0,0 +1,139 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/core/tensor/tile_distribution.hpp" +#include "ck_tile/ops/common.hpp" +#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" + +#include "block_gemm_pipeline_agmem_bgmem_creg.hpp" +#include "grid_gemm.hpp" + +namespace ck_tile { + +template +struct GridGemmProblem +{ + using ADataType = ADataType_; + using BDataType = BDataType_; + using AccDataType = AccDataType_; + using CDataType = CDataType_; + + using CElementFunction = CElementFunction_; +}; + +template +struct TileGemmShape +{ + static constexpr index_t kM = kMPerTile; + static constexpr index_t kN = kNPerTile; + static constexpr index_t kK = kKPerTile; +}; + +template +struct BlockGemmPipelineProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using CDataType = remove_cvref_t; + using BlockGemmShape = remove_cvref_t; + + static constexpr index_t kBlockSize = kBlockSize_; +}; + +// C = A * B +template +struct Gemm +{ + static constexpr index_t kBlockSize = kBlockSize_; + + using GridGemmProblem_ = + GridGemmProblem; + + struct GridGemmPolicy + { + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kMPerBlock = kMPerBlock_; + static constexpr index_t kNPerBlock = kNPerBlock_; + static constexpr index_t kKPerBlock = kKPerBlock_; + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) + { + const auto unmerge = make_merge_transform(make_tuple(N0, M0)); + + return [unmerge](index_t block_id) { + multi_index<2> unmerged; + unmerge.calculate_lower_index(unmerged, make_multi_index(block_id)); + + return make_multi_index(unmerged.at(number<1>{}), unmerged.at(number<0>{})); + }; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemmPipeline() + { + using BlockGemmPipelineProblem_ = + BlockGemmPipelineProblem>; + return BlockGemmPipelineAGmemBGmemCReg{}; + } + }; + + using GridGemm_ = GridGemm; + + CK_TILE_DEVICE void operator()(const ADataType* p_a, + const BDataType* p_b, + CDataType* p_c, + const index_t M, + const index_t N, + const index_t K, + const index_t Lda, + const index_t Ldb, + const index_t Ldc, + const CElementFunction& c_element_func) const + { + const auto a_dram = [&] { + return make_naive_tensor_view( + p_a, make_tuple(M, K), make_tuple(Lda, 1), number{}, number<1>{}); + }(); + + const auto b_dram = [&] { + return make_naive_tensor_view( + p_b, make_tuple(N, K), make_tuple(Ldb, 1), number{}, number<1>{}); + }(); + + const auto c_dram = [&] { + return make_naive_tensor_view( + p_c, make_tuple(M, N), make_tuple(Ldc, 1), number{}, number<1>{}); + }(); + + GridGemm_{}(a_dram, b_dram, c_dram, c_element_func); + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/03_mfma_16x16x16/grid_gemm.hpp b/tutorial/ck_tile/gemm/03_mfma_16x16x16/grid_gemm.hpp new file mode 100644 index 00000000000..559d271af01 --- /dev/null +++ b/tutorial/ck_tile/gemm/03_mfma_16x16x16/grid_gemm.hpp @@ -0,0 +1,72 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +namespace ck_tile { + +template +struct GridGemm +{ + using ADataType = typename Problem::ADataType; + using BDataType = typename Problem::BDataType; + using CDataType = typename Problem::CDataType; + using AccDataType = typename Problem::AccDataType; + using CElementFunction = typename Problem::CElementFunction; + + static constexpr auto kMPerBlock = Policy::kMPerBlock; + static constexpr auto kNPerBlock = Policy::kNPerBlock; + static constexpr auto kKPerBlock = Policy::kKPerBlock; + + template + CK_TILE_DEVICE void operator()(const AGridTensorView& a_grid, + const BGridTensorView& b_grid, + CGridTensorView& c_grid, + const CElementFunction& c_element_func) const + { + const auto M = a_grid.get_tensor_descriptor().get_length(number<0>{}); + const auto N = c_grid.get_tensor_descriptor().get_length(number<1>{}); + const auto K = a_grid.get_tensor_descriptor().get_length(number<1>{}); + + // divide problem + const auto id_block = get_block_id(); + + const auto num_tile_m = integer_divide_ceil(M, kMPerBlock); + const auto num_tile_n = integer_divide_ceil(N, kNPerBlock); + + const auto block2tile = Policy::template MakeBlock2TileMap(num_tile_m, num_tile_n); + + const auto id_tile = block2tile(id_block); + + const auto iM = __builtin_amdgcn_readfirstlane(id_tile.template at<0>() * kMPerBlock); + const auto iN = __builtin_amdgcn_readfirstlane(id_tile.template at<1>() * kNPerBlock); + + // A block window + auto a_block_window = make_tile_window( + a_grid, make_tuple(number{}, number{}), {iM, 0}); + + // B block window + auto b_block_window = make_tile_window( + b_grid, make_tuple(number{}, number{}), {iN, 0}); + + constexpr auto block_gemm_pipeline = Policy::template GetBlockGemmPipeline(); + + __shared__ char p_smem_char[block_gemm_pipeline.GetStaticLdsSize()]; + + const auto acc_block_tile = + block_gemm_pipeline(a_block_window, b_block_window, K / kKPerBlock, p_smem_char); + + // cast to CDataType and apply CElementFunction + const auto c_block_tile = tile_elementwise_in( + [&](const auto& acc) { return c_element_func(type_convert(acc)); }, + acc_block_tile); + + // store C + auto c_window = make_tile_window( + c_grid, make_tuple(number{}, number{}), {iM, iN}); + + store_tile(c_window, c_block_tile); + } +}; + +} // namespace ck_tile diff --git a/tutorial/ck_tile/gemm/CMakeLists.txt b/tutorial/ck_tile/gemm/CMakeLists.txt new file mode 100644 index 00000000000..3f50eac3a0d --- /dev/null +++ b/tutorial/ck_tile/gemm/CMakeLists.txt @@ -0,0 +1,10 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +include_directories(AFTER + ${CMAKE_CURRENT_LIST_DIR} +) + +add_subdirectory(01_naive_gemm) +add_subdirectory(02_padding_k_first) +add_subdirectory(03_mfma_16x16x16) diff --git a/tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp b/tutorial/ck_tile/gemm/reference_gemm.hpp similarity index 95% rename from tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp rename to tutorial/ck_tile/gemm/reference_gemm.hpp index 786cf140d56..b93c9451420 100644 --- a/tutorial/ck_tile/01_naive_gemm/reference_gemm.hpp +++ b/tutorial/ck_tile/gemm/reference_gemm.hpp @@ -32,5 +32,6 @@ void reference_basic_gemm(const ck_tile::HostTensor& a_m_k, } }; - ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])(1); + ck_tile::make_ParallelTensorFunctor(f, c_m_n.mDesc.get_lengths()[0])( + std::thread::hardware_concurrency()); }