diff --git a/CHANGELOG.md b/CHANGELOG.md index c99fc1d0657..9d7dd9d97bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added +* Added support for fp16 x fp8, bf16 x fp8, fp8 x fp16, and fp8 x bf16 for the V3 pipeline * Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4 * Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with with C++ and Python frontends starting with GEMM support. * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index af0f81e832c..d1c06d43780 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -48,19 +48,19 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, * and an elementwise function. For each A = A0, A1… AN, the elementwise function * is additionally applied during a single read. */ -template -CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, +CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple& tile_windows, ElementWise_ elementwise, number = {}, bool_constant = {}) { - // TODO: Tile windows should works with unknow number of params - // Load element_wise API works only when the input typle is a tuple-tyupe - return tile_window[number<0>{}].load( - tile_window, elementwise, number{}, bool_constant{}); + // TODO: Tile windows should work with unknown number of params + // Load element_wise API works only when the input type is a tuple-type + return tile_windows[number<0>{}].load( + tile_windows, elementwise, number{}, bool_constant{}); } // Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. 
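The fused elementwise load changed above can be pictured with a small host-side sketch (plain C++, not the ck_tile API; the function name and the std::vector containers are illustrative only): every destination element is produced from the corresponding element of each source A0…AN during a single pass, so no intermediate per-source tile is materialized.

```cpp
#include <cstddef>
#include <vector>

// Host-side analogue of "the elementwise function is additionally applied during a single read",
// shown for two sources A0 and A1. The real load_tile_with_elementwise operates on a tuple of
// tile windows and produces a distributed tensor on the GPU.
template <typename T, typename F>
std::vector<T> fused_load_elementwise(const std::vector<T>& a0, const std::vector<T>& a1, F&& op)
{
    std::vector<T> dst(a0.size());
    for(std::size_t i = 0; i < dst.size(); ++i)
        dst[i] = op(a0[i], a1[i]); // one traversal: read a0[i] and a1[i], combine, store
    return dst;
}
```

For example, `fused_load_elementwise(a0, a1, [](float x, float y) { return x + y; })` yields the elementwise sum of the two sources while reading each of them only once.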
@@ -85,12 +85,12 @@ template -CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, +CK_TILE_DEVICE void load_tile(DistributedTensor_& dst_tile, const TileWindow_& tile_window, number = {}, bool_constant = {}) { - return tile_window.load(dst_tile, number{}, bool_constant{}); + tile_window.load(dst_tile, number{}, bool_constant{}); } /** @@ -131,7 +131,7 @@ template -CK_TILE_DEVICE auto load_tile_raw(T& tile, +CK_TILE_DEVICE void load_tile_raw(T& tile, const tile_window_linear; +// Mixed-precision policy that allows different input and output types +template +struct MixedPrecisionTranspose : public DefaultTranspose +{ + // Inherits quad pattern validation from input type + // but allows output type to differ +}; + template ::distr_encoding_valid, Policy>> -CK_TILE_DEVICE auto load_tile_transpose_with_offset( +CK_TILE_DEVICE void load_tile_transpose_with_offset( + DistributedTensor_& out_tensor, const tile_window_with_static_distribution& __restrict__ tile_window, index_t offset) { - using OutTileDstrEncode = typename OutputTileDistributionTraits< - typename TileDistribution_::DstrEncode, - typename BottomTensorView_::DataType>::TransposedDstrEncode; - auto out_tensor = make_static_distributed_tensor( - make_static_tile_distribution(OutTileDstrEncode{})); auto trans_tensor = tile_window.template load_transpose_with_offset(offset); constexpr auto input_distr = TileDistribution_{}; - constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{}); + constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{}; constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor(); constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor(); @@ -442,8 +448,6 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( number{}, trans_tensor.get_thread_buffer().template get_as(number{})); }); - - return out_tensor; } /** @@ -455,6 +459,7 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( * element space size and vector length remain consistent between the input and output * distributions. * + * @tparam DistributedTensor_ The type of the tensor containing the transposed tile data. * @tparam BottomTensorView_ The type of the bottom tensor view. * @tparam WindowLengths_ The type representing the window lengths. * @tparam TileDistribution_ The type representing the tile distribution. @@ -462,16 +467,37 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( * @tparam Policy The transpose policy to use (defaults to DefaultTranspose). * the last is SFINAE to ensure the tile distribution encoding is valid. * + * @param out_tensor A statically distributed tensor containing the transposed tile + * data. * @param tile_window The tile window with static distribution to load and transpose. * indexing. * - * @return A statically distributed tensor containing the transposed tile data. - * * @note * - The function uses compile-time checks to ensure the input and output tile distributions * are compatible in terms of element space size and vector length. * - The transpose operation is performed according to the specified Policy. 
*/ +template < + typename DistributedTensor_, + typename BottomTensorView_, + typename WindowLengths_, + typename TileDistribution_, + index_t NumCoord, + typename Policy = DefaultTranspose, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE void +load_tile_transpose(DistributedTensor_& out_tensor, + const tile_window_with_static_distribution& __restrict__ tile_window) +{ + load_tile_transpose_with_offset(out_tensor, tile_window, 0); +} + template < typename BottomTensorView_, typename WindowLengths_, @@ -488,7 +514,133 @@ load_tile_transpose(const tile_window_with_static_distribution& __restrict__ tile_window) { - return load_tile_transpose_with_offset(tile_window, 0); + using OutTileDstrEncode = typename OutputTileDistributionTraits< + typename TileDistribution_::DstrEncode, + typename BottomTensorView_::DataType>::TransposedDstrEncode; + auto out_tensor = make_static_distributed_tensor( + make_static_tile_distribution(OutTileDstrEncode{})); + + load_tile_transpose_with_offset(out_tensor, tile_window, 0); + + return out_tensor; +} + +/** + * @brief Mixed-precision transpose load: converts input data type to output data type while + * transposing. + * + * This function enables transposing from one data type (e.g., fp8) to another (e.g., fp16) in a + * single operation. The input tile distribution encoding must be valid for the input data type, + * and the output distribution will be generated based on the output data type. + * + * @tparam DistributedTensor_ The output tensor type with desired output data type. + * @tparam BottomTensorView_ The input tensor view (may have different data type than output). + * @tparam WindowLengths_ The type representing the window lengths. + * @tparam TileDistribution_ The type representing the tile distribution for input. + * @tparam NumCoord The number of coordinates (dimensions). + * @tparam Policy The transpose policy (should validate against input type). + * + * @note + * - Input and output must have compatible element space sizes (total byte count per Y-space). + * - Type conversion is performed element-by-element during the copy. + * - The validation uses the input data type for quad pattern checking. + * - The output distribution is generated based on the output data type. 
+ */ +template < + typename DistributedTensor_, + typename BottomTensorView_, + typename WindowLengths_, + typename TileDistribution_, + index_t NumCoord, + typename Policy = DefaultTranspose, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE void load_tile_transpose_convert_with_offset( + DistributedTensor_& out_tensor, + const tile_window_with_static_distribution& __restrict__ tile_window, + index_t offset) +{ + using InputDataType = typename BottomTensorView_::DataType; + using OutputDataType = typename DistributedTensor_::DataType; + + auto trans_tensor = tile_window.template load_transpose_with_offset(offset); + constexpr auto input_distr = TileDistribution_{}; + constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{}; + + constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor(); + constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor(); + + constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y(); + // constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y(); + + constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths()); + constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths()); + + constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size(); + constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size(); + + // For mixed precision: element space size must be the same (total bytes match) + static_assert(y_in_element_space_size == y_out_element_space_size, + "For mixed precision transpose, input and output element space size must match!"); + + // Allow different vector lengths (e.g., fp8 may vectorize 8 elems, fp16 may vectorize 4). + // Ensure total element counts are consistent and divisible by the input vector length. + constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1]; + constexpr index_t total_elems_in = + reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{}); + constexpr index_t total_elems_out = + reduce_on_sequence(y_out_lengths, multiplies<>{}, number<1>{}); + static_assert(total_elems_in == total_elems_out, + "For mixed precision transpose, input/output element counts must match!"); + static_assert(total_elems_in % vecLoadSize == 0, + "Input vector length must evenly divide total elements."); + + constexpr index_t num_of_access = total_elems_in / vecLoadSize; + + // Read as input type, convert to output type + using InputDataVec = array; + static_for<0, num_of_access, 1>{}([&](auto iAccess) { + auto input_vec = + trans_tensor.get_thread_buffer().template get_as(number{}); + + // Element-wise type conversion + // This will be unrolled by the compiler for each element in the vector + static_for<0, vecLoadSize, 1>{}([&](auto iElem) { + auto output_elem = type_convert(input_vec[iElem]); + out_tensor.get_thread_buffer()[number{}] = output_elem; + }); + }); +} + +/** + * @brief Mixed-precision transpose load with zero offset. + * + * Convenience wrapper for load_tile_transpose_convert_with_offset with offset=0. 
+ */ +template < + typename DistributedTensor_, + typename BottomTensorView_, + typename WindowLengths_, + typename TileDistribution_, + index_t NumCoord, + typename Policy = DefaultTranspose, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE void load_tile_transpose_convert( + DistributedTensor_& out_tensor, + const tile_window_with_static_distribution& __restrict__ tile_window) +{ + load_tile_transpose_convert_with_offset(out_tensor, tile_window, 0); } } // namespace ck_tile diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index d39da82a627..da90675fdd4 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -182,11 +182,11 @@ struct tile_window_with_static_distribution * The same thread, during vectorized reading, accesses the same set of * data from A0, A1, A2, … AN. */ - template - CK_TILE_DEVICE auto load(const TileWindow_& tile_window, + CK_TILE_DEVICE auto load(const ck_tile::tuple& tile_windows, ElementWise_ elementwise, number = {}, bool_constant = {}) const @@ -194,7 +194,7 @@ struct tile_window_with_static_distribution constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); load(dst_tensor, - tile_window, + tile_windows, elementwise, number{}, bool_constant{}); @@ -202,12 +202,12 @@ struct tile_window_with_static_distribution } template CK_TILE_DEVICE void load(DistributedTensor& dst_tensor, - const TileWindow_& tile_window, + const ck_tile::tuple& tile_windows, ElementWise_ elementwise, number = {}, bool_constant = {}) const @@ -218,14 +218,14 @@ struct tile_window_with_static_distribution using SFC_Ys = typename Traits::SFC_Ys; constexpr auto tile_dstr = typename Base::TileDstr{}; - constexpr auto sizeOfTuple = TileWindow_::size(); + constexpr auto sizeOfTuple = remove_cvref_t::size(); // loop over thread tensor space [y0, y1, ...] 
static_for<0, NumCoord, 1>{}([&](auto iCoord) { /// TODO: use structure binding (to be captured later) if compiled in C++20 auto window_adaptor_thread_coord = - tile_window[number<0>{}].pre_computed_coords_[iCoord][I0]; + tile_windows[number<0>{}].pre_computed_coords_[iCoord][I0]; auto bottom_tensor_thread_coord = - tile_window[number<0>{}].pre_computed_coords_[iCoord][I1]; + tile_windows[number<0>{}].pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { constexpr auto iAccess = number{}; @@ -236,7 +236,7 @@ struct tile_window_with_static_distribution // read from bottom tensor const auto idx_vec_value = generate_tuple( [&](auto jj) { - return tile_window[number{}] + return tile_windows[number{}] .get_bottom_tensor_view() .template get_vectorized_elements( bottom_tensor_thread_coord, diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 00234b20cf9..a62bbe981cc 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -7,8 +7,9 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_contraction.hpp b/include/ck_tile/ops/batched_contraction.hpp index 45fa52e5051..71919b61873 100644 --- a/include/ck_tile/ops/batched_contraction.hpp +++ b/include/ck_tile/ops/batched_contraction.hpp @@ -5,8 +5,9 @@ #include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp" #include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp" #include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index b23e45c2331..924db5fb60e 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -10,8 +10,9 @@ #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git 
a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 94243e674f5..0113d8c9a28 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -2,8 +2,9 @@ // SPDX-License-Identifier: MIT #pragma once +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common/determine_warp_prec_type.hpp b/include/ck_tile/ops/common/determine_warp_prec_type.hpp new file mode 100644 index 00000000000..ae11ff13146 --- /dev/null +++ b/include/ck_tile/ops/common/determine_warp_prec_type.hpp @@ -0,0 +1,109 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" + +// DetermineWarpPrecType is a set of rules that determines which precision type each warp GEMM +// operand should use, given the other operand's precision type. Type conversions are sometimes +// needed to obtain a pair of types that is compatible with the available hardware matrix +// instructions. A typical use case is mixed-precision GEMM. + +namespace ck_tile { +// For the most general case, default to no conversion. +template <typename APrecType, typename BPrecType> +struct DetermineWarpPrecType +{ + using a_prec_type = APrecType; + using b_prec_type = BPrecType; +}; + +// For pk_fp4_t x pk_fp4_t, keep pk_fp4_t +template <> +struct DetermineWarpPrecType<ck_tile::pk_fp4_t, ck_tile::pk_fp4_t> +{ + using a_prec_type = ck_tile::pk_fp4_t; + using b_prec_type = ck_tile::pk_fp4_t; +}; + +// For pk_int4_t x B, use the B type. +template <typename BPrecType> +struct DetermineWarpPrecType<ck_tile::pk_int4_t, BPrecType> +{ + using a_prec_type = BPrecType; + using b_prec_type = BPrecType; +}; + +// For A x pk_int4_t, use the A type. +template <typename APrecType> +struct DetermineWarpPrecType<APrecType, ck_tile::pk_int4_t> +{ + using a_prec_type = APrecType; + using b_prec_type = APrecType; +}; + +// For pk_fp4_t x B, use the B type. +template <typename BPrecType> +struct DetermineWarpPrecType<ck_tile::pk_fp4_t, BPrecType> +{ + using a_prec_type = BPrecType; + using b_prec_type = BPrecType; +}; + +// For A x pk_fp4_t, use the A type. +template <typename APrecType> +struct DetermineWarpPrecType<APrecType, ck_tile::pk_fp4_t> +{ + using a_prec_type = APrecType; + using b_prec_type = APrecType; +}; + +// For pk_fp4_raw_t x B, use the B type. +template <typename BPrecType> +struct DetermineWarpPrecType<ck_tile::pk_fp4_raw_t, BPrecType> +{ + using a_prec_type = BPrecType; + using b_prec_type = BPrecType; +}; + +// For A x pk_fp4_raw_t, use the A type.
+template <typename APrecType> +struct DetermineWarpPrecType<APrecType, ck_tile::pk_fp4_raw_t> +{ + using a_prec_type = APrecType; + using b_prec_type = APrecType; +}; + +// For fp8 x bf16, use fp8 +template <> +struct DetermineWarpPrecType<ck_tile::fp8_t, ck_tile::bf16_t> +{ + using a_prec_type = ck_tile::fp8_t; + using b_prec_type = ck_tile::fp8_t; +}; + +// For bf16 x fp8, use bf16 +template <> +struct DetermineWarpPrecType<ck_tile::bf16_t, ck_tile::fp8_t> +{ + using a_prec_type = ck_tile::bf16_t; + using b_prec_type = ck_tile::bf16_t; +}; + +// For fp8 x fp16, use fp8 +template <> +struct DetermineWarpPrecType<ck_tile::fp8_t, ck_tile::half_t> +{ + using a_prec_type = ck_tile::fp8_t; + using b_prec_type = ck_tile::fp8_t; +}; + +// For fp16 x fp8, use fp16 +template <> +struct DetermineWarpPrecType<ck_tile::half_t, ck_tile::fp8_t> +{ + using a_prec_type = ck_tile::half_t; + using b_prec_type = ck_tile::half_t; +}; +}; // namespace ck_tile diff --git a/include/ck_tile/ops/common/load_and_convert_tile.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp new file mode 100644 index 00000000000..ee315b9f61a --- /dev/null +++ b/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -0,0 +1,77 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" + +namespace ck_tile { + +template +struct ConverterLoader +{ + template + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src_window) + { + static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); + static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); + constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; + const auto src = load_tile(src_window); + + using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize))); + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + const element_wise::PassThroughPack8 elementwise_op{}; + + elementwise_op(dst.get_thread_buffer().template get_as()(i), + src.get_thread_buffer().template get_as()[i]); + }); + } + + template + CK_TILE_DEVICE static void load_with_type_convert(WarpTile& dst, const WarpWindow& src_window) + { + if constexpr(LoadTranspose) + { + if constexpr(std::is_same_v) + { + load_tile_transpose(dst, src_window); + } + else + { + load_tile_transpose_convert(dst, src_window); + } + } + else + { + if constexpr(std::is_same_v) + { + load_tile(dst, src_window); + } + else + { + auto tmp = load_tile(src_window); + sweep_tile([&](auto i) { + element_wise::PassThrough elementwise_op{}; + elementwise_op(dst(i), tmp(i)); + }); + } + } + } +}; + +template +CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src_window) +{ + if constexpr(std::is_same_v) + { + ConverterLoader::load_interleaved_pk_type( + dst, src_window); + } + else + { + ConverterLoader:: + template load_with_type_convert(dst, src_window); + } +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp deleted file mode 100644 index 10c2a1e4df7..00000000000 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ /dev/null @@ -1,55 +0,0 @@ -// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
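To make the selection rules in determine_warp_prec_type.hpp concrete, the compile-time checks below restate a few representative mappings. This is a sketch rather than part of the change itself, and it assumes the new header and the ck_tile type aliases (fp8_t, bf16_t, half_t, pk_int4_t) are visible on the include path.

```cpp
#include <type_traits>

#include "ck_tile/ops/common/determine_warp_prec_type.hpp"

using ck_tile::DetermineWarpPrecType;

// fp16 x fp8: the warp GEMM computes in fp16, i.e. B is converted to A's type.
static_assert(std::is_same_v<DetermineWarpPrecType<ck_tile::half_t, ck_tile::fp8_t>::b_prec_type,
                             ck_tile::half_t>);
// fp8 x bf16: the warp GEMM computes in fp8, i.e. B is converted to A's type.
static_assert(std::is_same_v<DetermineWarpPrecType<ck_tile::fp8_t, ck_tile::bf16_t>::b_prec_type,
                             ck_tile::fp8_t>);
// A x pk_int4_t: packed-int4 weights are promoted to the other operand's type.
static_assert(std::is_same_v<DetermineWarpPrecType<ck_tile::half_t, ck_tile::pk_int4_t>::b_prec_type,
                             ck_tile::half_t>);
// Matching types fall through the primary template unchanged.
static_assert(std::is_same_v<DetermineWarpPrecType<ck_tile::bf16_t, ck_tile::bf16_t>::a_prec_type,
                             ck_tile::bf16_t>);
```

CShuffleEpilogue, BlockUniversalGemmAsBsCr, and the universal GEMM pipeline policy below consume exactly these a_prec_type/b_prec_type aliases in place of the previous nested std::conditional_t chains.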
-// SPDX-License-Identifier: MIT - -#pragma once - -#include "ck_tile/core/config.hpp" -#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" - -namespace ck_tile { - -template -struct InterleavedPKTypeLoader -{ - template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, - const WarpWindow& warp_window) - { - const element_wise::PassThroughPack8 elementwise_op{}; - - static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); - constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); - - using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize))); - static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); - }); - } -}; - -template -CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) -{ - if constexpr(std::is_same_v) - { - static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); - InterleavedPKTypeLoader::load_interleaved_pk_type(dst, src); - } - else if constexpr(LoadTranspose) - { - dst = load_tile_transpose(src); - } - else - { - load_tile(dst, src); - } -} - -} // namespace ck_tile diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 5752703ab60..2c0ae4ad093 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -8,8 +8,9 @@ #include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp" #include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 433462b22e2..0eb9e59e723 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -10,8 +10,9 @@ #include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 4f636b59625..10a64be1731 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -7,7 +7,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/utils.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include @@ -92,16 +92,8 @@ struct CShuffleEpilogue using ADataType = 
remove_cvref_t{}, AsDataTypeTuple>>; using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; - using ATypeToUse = std::conditional_t || - std::is_same_v, - BDataType, - ADataType>; - // Used for weight-only quantization kernel, B would be dequantized to the same data type as A - using BTypeToUse = std::conditional_t || - std::is_same_v || - std::is_same_v, - ADataType, - BDataType>; + using ATypeToUse = typename DetermineWarpPrecType::a_prec_type; + using BTypeToUse = typename DetermineWarpPrecType::b_prec_type; using ELayout = remove_cvref_t; using CDElementwise = remove_cvref_t; diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 2d3a819e804..2e71957ac77 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -21,8 +21,9 @@ #include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index eb4aa16d054..2068dfeefe0 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -60,8 +60,9 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 9aeabaa8c22..16212c0d130 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -530,7 +530,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR s_acc = gemm_0(q_reg_tensor, k_reg_tensor); dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr); - dot_reg_tensor = load_tile_transpose(dot_lds_read_window); + load_tile_transpose(dot_reg_tensor, dot_lds_read_window); } if constexpr(is_epilogue) { @@ -634,7 +634,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr); - qt_reg_tensor = load_tile_transpose(qt_lds_read_window); + load_tile_transpose(qt_reg_tensor, qt_lds_read_window); // STAGE 3, P^T@OGrad^T Gemm1 auto pt_reg_tensor = make_static_distributed_tensor( @@ -715,7 +715,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR } if constexpr(is_epilogue) { - ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 
0}); } if constexpr(is_main_body) @@ -728,7 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR static_for<0, k4_loops, 1>{}([&](auto i_k4) { if constexpr(i_k4 < k4_loops - 1) { - ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } auto kt_reg_tensor_slice = get_slice_tile( // diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 3d21928cedf..37b4ae41a3f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -455,10 +455,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR async_load_tile(q_lds_write_window, q_dram_window); async_load_tile(do_lds_write_window, do_dram_window); __builtin_amdgcn_s_waitcnt(0); - qt_reg_tensor = load_tile_transpose(qt_lds_read_window); - q_reg_tensor = load_tile(q_lds_read_window); - dot_reg_tensor = load_tile_transpose(dot_lds_read_window); - do_reg_tensor = load_tile(do_lds_read_window); + load_tile_transpose(qt_reg_tensor, qt_lds_read_window); + q_reg_tensor = load_tile(q_lds_read_window); + load_tile_transpose(dot_reg_tensor, dot_lds_read_window); + do_reg_tensor = load_tile(do_lds_read_window); lse_block_tile = load_tile(lse_dram_window); d_block_tile = load_tile(d_dram_window); @@ -490,9 +490,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR async_load_tile(v_lds_write_window, v_dram_window); move_tile_window(v_dram_window, {kN0, 0}); s_waitcnt(); - k_reg_tensor = load_tile(k_lds_read_window); - v_reg_tensor = load_tile(v_lds_read_window); - kt_reg_tensor = load_tile_transpose(kt_lds_read_window); + k_reg_tensor = load_tile(k_lds_read_window); + v_reg_tensor = load_tile(v_lds_read_window); + load_tile_transpose(kt_reg_tensor, kt_lds_read_window); } if constexpr(is_epilogue) { @@ -668,7 +668,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR block_sync_lds(); if constexpr(is_epilogue) { - ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } if constexpr(is_main_body) @@ -680,7 +680,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR static_for<0, k4_loops, 1>{}([&](auto i_k4) { if constexpr(i_k4 < k4_loops - 1) { - ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } auto kt_reg_tensor_slice = get_slice_tile( // diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index c25f57632fa..4cca604ff15 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -718,7 +718,7 @@ struct BlockFmhaFwdV3Pipeline }; auto V_lds_load = [&](auto v_lds_read_idx) { - kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx)); + load_tile_transpose(kv_tile.v_tile, v_lds_window_load(v_lds_read_idx)); }; decltype(m) m_old; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index aab79c52ae9..6bf6d2b5033 100644 --- 
a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -591,7 +591,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload // loop over along the [V]alue Sequence length move_tile_window(v_lds_read_window, {kK1, 0}); - v_tile = load_tile_transpose(v_lds_read_window); + load_tile_transpose(v_tile, v_lds_read_window); }); // move back to the origin move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index e6802e82dce..2eb4abd6411 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -14,8 +14,9 @@ #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 2c3a1611216..abfafb6bb4b 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -76,8 +76,9 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 79030fcd513..847d6c782a8 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" @@ -94,12 +94,8 @@ struct BlockUniversalGemmAsBsCr using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; - using ATypeToUse = - std::conditional_t, BDataType, ADataType>; - using BTypeToUse = std::conditional_t || - std::is_same_v, - ADataType, - BDataType>; + using ATypeToUse = typename DetermineWarpPrecType::a_prec_type; + using BTypeToUse = typename DetermineWarpPrecType::b_prec_type; using WarpGemm = remove_cvref_t; @@ -139,6 +135,7 @@ struct BlockUniversalGemmAsBsCr using I0 = number<0>; using I1 = number<1>; + template CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { constexpr index_t KPerThread = Traits::KPerThread; @@ -158,12 +155,18 @@ struct BlockUniversalGemmAsBsCr tuple>, sequence<1, 2>, sequence<0, 0>>{}; + using Attr = typename 
WarpGemm::WarpGemmAttribute; + constexpr auto NumAccessA = + convert ? Attr::AttrNumAccessV * sizeof(ADataType) / sizeof(ComputeDataType) + : Attr::AttrNumAccessV; constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + a_block_outer_dstr_encoding, + WarpGemm::WarpGemmAttribute::template get_awarp_dstr_encoding()); return a_block_dstr_encode; } + template CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { constexpr index_t KPerThread = Traits::KPerThread; @@ -183,8 +186,13 @@ struct BlockUniversalGemmAsBsCr tuple>, sequence<1, 2>, sequence<0, 0>>{}; + using Attr = typename WarpGemm::WarpGemmAttribute; + constexpr auto NumAccessB = + convert ? Attr::AttrNumAccessV * sizeof(BDataType) / sizeof(ComputeDataType) + : Attr::AttrNumAccessV; constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + b_block_outer_dstr_encoding, + WarpGemm::WarpGemmAttribute::template get_bwarp_dstr_encoding()); return b_block_dstr_encode; } @@ -217,10 +225,8 @@ struct BlockUniversalGemmAsBsCr bool_constant = {}, bool_constant = {}) { - load_int4_tile(a_warp_tile_, - a_block_window); - load_int4_tile(b_warp_tile_, - b_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); } // C += A * B @@ -289,9 +295,9 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread; static constexpr auto ALdsTileDistr = - make_static_tile_distribution(MakeABlockDistributionEncode()); + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; static constexpr auto BLdsTileDistr = - make_static_tile_distribution(MakeBBlockDistributionEncode()); + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); @@ -348,10 +354,8 @@ struct BlockUniversalGemmAsBsCr auto b_lds_gemm_window = make_tile_window( b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); - load_int4_tile(a_warp_tile_, - a_lds_gemm_window); - load_int4_tile(b_warp_tile_, - b_lds_gemm_window); + load_and_convert_tile(a_warp_tile_, a_lds_gemm_window); + load_and_convert_tile(b_warp_tile_, b_lds_gemm_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 4973d9c9410..358101d1db1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -64,9 +64,7 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } - template @@ -74,7 +72,7 @@ struct GemmPipelineAgBgCrImplBase SrcTileWindow& dram_tile_window, const DramTileWindowStep& dram_tile_window_step) const { - load_int4_tile(dst_block_tile, dram_tile_window); + load_and_convert_tile(dst_block_tile, dram_tile_window); move_tile_window(dram_tile_window, dram_tile_window_step); } @@ -109,7 +107,7 @@ struct GemmPipelineAgBgCrImplBase bool_constant = {}) const { if constexpr(LoadTranspose) - dst_block_tile = load_tile_transpose(lds_tile_window); + load_tile_transpose(dst_block_tile, lds_tile_window); else 
load_tile(dst_block_tile, lds_tile_window); } @@ -237,12 +235,16 @@ struct GemmPipelineAgBgCrImplBase auto a_lds_load_tile_distr = []() { if constexpr(is_a_load_tr) + { return make_static_tile_distribution( typename InputTileDistributionTraits< typename ALdsLoadTileDistr::DstrEncode, - typename Problem::ADataType>::TransposedDstrEncode{}); + typename ALdsTensorView::DataType>::TransposedDstrEncode{}); + } else + { return ALdsLoadTileDistr{}; + } }(); auto a_lds_gemm_window = @@ -313,19 +315,18 @@ struct GemmPipelineAgBgCrImplBase auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); - using BLdsDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; - auto b_lds_load_tile_distr = []() { if constexpr(is_b_load_tr) + { return make_static_tile_distribution( - typename InputTileDistributionTraits::TransposedDstrEncode{}); - + typename InputTileDistributionTraits< + typename BLdsLoadTileDistr::DstrEncode, + typename BLdsTensorView::DataType>::TransposedDstrEncode{}); + } else + { return BLdsLoadTileDistr{}; + } }(); auto b_lds_gemm_window = diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 8fae7042037..df1b5207b14 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -440,10 +440,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + constexpr bool is_load_tr = is_a_load_tr_v || is_b_load_tr_v; + constexpr auto a_lds_load_tile_distr = make_static_tile_distribution( + BlockGemm::template MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = make_static_tile_distribution( + BlockGemm::template MakeBBlockDistributionEncode()); // A DRAM tile window for load // A LDS tile window for store diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index 777537a83a0..583d94a9f12 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" @@ -22,6 +23,10 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using ATypeToUse = typename DetermineWarpPrecType::a_prec_type; + using BTypeToUse = typename DetermineWarpPrecType::b_prec_type; constexpr index_t vector_size = DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType); @@ -33,8 +38,8 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy : vector_size * 4 == thread_elements ? 
WGAttrNumAccessEnum::Quad : WGAttrNumAccessEnum::Invalid; - using WarpGemm = WarpGemmDispatcher; - using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 8074994fdd3..f9bdb2e125e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -112,7 +112,6 @@ struct UniversalGemmBasePolicy using ADataType = OverrideADataType; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = Derived::template GetSmemPackA(); if constexpr(is_a_load_tr) { @@ -246,6 +245,7 @@ struct UniversalGemmBasePolicy } else // A is in RowMajor { + constexpr index_t KPack = Derived::template GetSmemPackA(); constexpr auto DataTypeSize = sizeof(ADataType); constexpr uint64_t MinLdsLayer = 1ULL; constexpr auto MLdsLayer = @@ -302,15 +302,12 @@ struct UniversalGemmBasePolicy * @tparam Problem Gemm pipeline problem. * @return B tensor LDS block descriptor. */ - template + template > CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using BLayout = remove_cvref_t; - using BDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; - + using BLayout = remove_cvref_t; + using BDataType = OverrideBDataType; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; @@ -895,14 +892,10 @@ struct UniversalGemmPipelineAgBgCrPolicy : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad : WGAttrNumAccessEnum::Invalid; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using ATypeToUse = - std::conditional_t, BDataType, ADataType>; - using BTypeToUse = std::conditional_t || - std::is_same_v, - ADataType, - BDataType>; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using ATypeToUse = typename DetermineWarpPrecType::a_prec_type; + using BTypeToUse = typename DetermineWarpPrecType::b_prec_type; using WarpGemm = WarpGemmDispatcher( - b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); // Prefill A0 Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); @@ -652,7 +651,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 do { { - Base::template GlobalPrefetch( + Base::GlobalPrefetch( b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); Base::GlobalPrefetch( @@ -666,7 +665,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 HotLoopScheduler(); } { - Base::template GlobalPrefetch( + Base::GlobalPrefetch( b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); Base::GlobalPrefetch( @@ -687,7 +686,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 if constexpr(TailNum == TailNumber::Even) { { - Base::template GlobalPrefetch( + Base::GlobalPrefetch( b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); block_weight_preshuffle( diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 00512424752..db6eb5d9108 100644 --- 
a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -17,6 +17,12 @@ namespace ck_tile { using WarpGemmMfmaF32F32F32M16N16K4 = WarpGemmImpl< WarpGemmAttributeMfma>>; +template +using WarpGemmMfmaF32F32F32M32N32K16 = WarpGemmImpl, + 8, + AttrNumAccess>>; + template using WarpGemmMfmaF32F32F32M16N16K16 = WarpGemmImpl, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 3c7944a4277..d752d2818bf 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -44,12 +44,12 @@ struct WarpGemmAttributeMfma static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - template + template static constexpr auto get_warp_dstr_encoding() { - static_assert(kKPerThread % AttrNumAccessV == 0, + static_assert(NumAccess != 0 && kKPerThread % NumAccess == 0, "kKPerThread must be divisible by NumAccess"); - if constexpr(AttrNumAccessV == 1) + if constexpr(NumAccess == 1) return tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -61,14 +61,30 @@ struct WarpGemmAttributeMfma return tile_distribution_encoding< sequence<>, tuple, - sequence>, + sequence>, tuple>, tuple>, sequence<2, 2>, sequence<0, 2>>{}; } - using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); - using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + + template + static constexpr auto get_awarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + + template + static constexpr auto get_bwarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -151,14 +167,17 @@ struct WarpGemmAttributeMfmaIterateK static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1, "Multi-block on both M & N directions is not supported"); - template + template CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding() { if constexpr(kMNBlock == 1 && kNMBlock == 1) { static_assert(kKPerThread % AttrNumAccessV == 0, "kKPerThread must be divisible by NumAccess"); - if constexpr(AttrNumAccessV == 1) + if constexpr(NumAccess == 1) return tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -172,7 +191,7 @@ struct WarpGemmAttributeMfmaIterateK tuple, sequence>, + Impl::kABKPerLane * kKIter / NumAccess>>, tuple>, tuple>, sequence<2, 2>, @@ -180,7 +199,7 @@ struct WarpGemmAttributeMfmaIterateK } else if constexpr(kMNBlock == 1 && 1 < kNMBlock) { - static_assert(AttrNumAccessV == 1, + static_assert(NumAccess == 1, "Multiple access is not supported when using multi-block"); // each M/N blocks share the same data return tile_distribution_encoding< @@ -193,7 +212,7 @@ struct WarpGemmAttributeMfmaIterateK } else if constexpr(1 < kMNBlock && kNMBlock == 1) { - static_assert(AttrNumAccessV == 1, + static_assert(NumAccess == 1, "Multiple access is not supported when using multi-block"); // single block to multi-block thread mapping return tile_distribution_encoding< @@ -207,6 +226,18 @@ struct WarpGemmAttributeMfmaIterateK } } + template + CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + + template + CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() + { + return 
get_warp_dstr_encoding(); + } + CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding() { if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) @@ -245,10 +276,12 @@ struct WarpGemmAttributeMfmaIterateK } } - using AWarpDstrEncoding = - decltype(get_warp_dstr_encoding()); - using BWarpDstrEncoding = - decltype(get_warp_dstr_encoding()); + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); + using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); // c_vec += a_vec * b_vec @@ -327,10 +360,25 @@ struct WarpGemmAttributeMfmaTransposedCDistribution static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - using AWarpDstrEncoding = - typename WarpGemmAttributeMfma::BWarpDstrEncoding; - using BWarpDstrEncoding = - typename WarpGemmAttributeMfma::AWarpDstrEncoding; + template + CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() + { + return WarpGemmAttributeMfma::template get_bwarp_dstr_encoding(); + } + + template + CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() + { + return WarpGemmAttributeMfma::template get_awarp_dstr_encoding(); + } + + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -384,6 +432,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + template using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -392,6 +441,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB sequence<2>, sequence<1>>; #if 0 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple>; #else // TODO: more test not only 32x32 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution { - using Impl = remove_cvref_t; - static constexpr auto AttrNumAccess = AttrNumAccess_; + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); // swap A and B using ADataType = typename Impl::BDataType; @@ -521,10 +573,12 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution } } - using AWarpDstrEncoding = - typename WarpGemmAttributeMfmaIterateK::BWarpDstrEncoding; - using BWarpDstrEncoding = - typename WarpGemmAttributeMfmaIterateK::AWarpDstrEncoding; + template + using AWarpDstrEncoding = typename WarpGemmAttributeMfmaIterateK:: + template BWarpDstrEncoding; + template + using BWarpDstrEncoding = typename WarpGemmAttributeMfmaIterateK:: + template AWarpDstrEncoding; using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); // c_vec += a_vec * b_vec @@ -603,6 +657,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + template using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -611,6 +666,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB sequence<2>, sequence<1>>; #if 0 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple>; #else 
// TODO: more test not only 32x32 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence<1>>; + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index ef31d06c9c2..4a275848b2b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -67,6 +67,10 @@ template struct WarpGemmAttributeWmma { using Impl = remove_cvref_t; + // AttrNumAccessV is required for compatibility with the block GEMM, and is currently ignored + // within WarpGemmAttributeWmma + static constexpr auto AttrNumAccess = WGAttrNumAccessEnum::Single; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); // When kTransC is true and A/B types differ, we need an impl with swapped types using TransposedImpl = @@ -99,8 +103,22 @@ struct WarpGemmAttributeWmma // 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2 // 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4 - using AWarpDstrEncoding = typename AWarpDstrEncodingTrait::type; - using BWarpDstrEncoding = typename BWarpDstrEncodingTrait::type; + template + static constexpr auto get_awarp_dstr_encoding() + { + return typename AWarpDstrEncodingTrait::type{}; + } + + template + static constexpr auto get_bwarp_dstr_encoding() + { + return typename BWarpDstrEncodingTrait::type{}; + } + + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); // kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16 using CWarpDstrEncoding = diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index a7d71d4fa3c..2689cb8b0ec 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -37,6 +37,8 @@ template<> struct Dispatcher { using Typ template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M32N32K8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M32N32K16<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M32N32K16; }; // fp16 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index ca7c32b6af5..5ff0660f49c 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -25,8 +25,8 @@ struct WarpGemmImpl using BDataType = typename WarpGemmAttribute::BDataType; using CDataType = typename WarpGemmAttribute::CDataType; - using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding; - using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding; + using AWarpDstrEncoding = typename WarpGemmAttribute::template AWarpDstrEncoding<>; + using BWarpDstrEncoding = typename 
        WarpGemmAttribute::template BWarpDstrEncoding<>;
     using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding;
     using AWarpDstr = remove_cvref_t;
diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp
index 696de378aaf..df4a7c79778 100644
--- a/include/ck_tile/ops/gemm_quant.hpp
+++ b/include/ck_tile/ops/gemm_quant.hpp
@@ -30,8 +30,9 @@
 #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp"
 #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp"
 #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp"
+#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/common/load_and_convert_tile.hpp"
 #include "ck_tile/ops/common/streamk_common.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp
index 3fb80c21ffe..12b7581abb4 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp
@@ -261,11 +261,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase
                    bool_constant = {},
                    bool_constant = {})
     {
-        load_int4_tile(
-            a_warp_tile_, a_block_window);
+        load_and_convert_tile(a_warp_tile_, a_block_window);
         // If B datatype were pkint4 it would be converted prior to storing in LDS
-        load_int4_tile(
-            b_warp_tile_, b_block_window);
+        load_and_convert_tile(b_warp_tile_, b_block_window);
     }
     // C += A * B
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
index 9d19e902e5c..ed5a613d07a 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp
@@ -248,10 +248,8 @@ struct AQuantBlockUniversalGemmAsBsCr
         // while ADatatype might not be the same as BDataType at the time of problem
         // initialization, we can safely use BDataType here because when A would be int4 we will
         // ensure A is converted to BDataType prior to loading
-        load_int4_tile(
-            a_warp_tile_, a_block_window);
-        load_int4_tile(
-            b_warp_tile_, b_block_window);
+        load_and_convert_tile(a_warp_tile_, a_block_window);
+        load_and_convert_tile(b_warp_tile_, b_block_window);
     }
     // C += A * B
@@ -395,10 +393,8 @@ struct AQuantBlockUniversalGemmAsBsCr
         auto b_lds_gemm_window = make_tile_window(
             b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr);
-        load_int4_tile(
-            a_warp_tile_, a_lds_gemm_window);
-        load_int4_tile(
-            b_warp_tile_, b_lds_gemm_window);
+        load_and_convert_tile(a_warp_tile_, a_lds_gemm_window);
+        load_and_convert_tile(b_warp_tile_, b_lds_gemm_window);
     }
     // C += A * B with quantization support
diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
index 03b9dfe34db..7b0f18c44fb 100644
--- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
+++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp
@@ -258,11 +258,9 @@ struct BQuantBlockUniversalGemmAsBsCr
                    bool_constant = {},
                    bool_constant = {})
     {
-        load_int4_tile(
-            a_warp_tile_, a_block_window);
+        load_and_convert_tile(a_warp_tile_, a_block_window);
         // If B datatype were pkint4 it would be converted prior to storing in LDS
-        load_int4_tile(
-            b_warp_tile_, b_block_window);
+        load_and_convert_tile(b_warp_tile_, b_block_window);
     }
     // C += A * B
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp
index cd70c2ca862..afe68370249 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp
@@ -197,20 +197,16 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
            (a_block_tile, a_dram_window);
+        load_and_convert_tile(a_block_tile, a_dram_window);
     }
     template
     CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile,
                                                    const BDramWindow& b_dram_window)
     {
-        using DestDataType = typename BBlockTile_::DataType;
-        using SrcDataType = typename BDramWindow::Base::TileWindowBase::DataType;
         constexpr index_t UnaryOpSize = 8;
-        load_int4_tile(b_block_tile, b_dram_window);
+        load_and_convert_tile(b_block_tile, b_dram_window);
     }
     template
            ADramWindow& a_dram_window,
            const DramTileWindowStep& dram_tile_window_step)
     {
-        using DestDataType = typename ABlockTile_::DataType;
-        using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType;
         constexpr index_t UnaryOpSize = 8;
-        load_int4_tile(a_block_tile, a_dram_window);
+        load_and_convert_tile(a_block_tile, a_dram_window);
         move_tile_window(a_dram_window, dram_tile_window_step);
     }
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp
index 71e4a744003..3d10bf76fa9 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp
@@ -174,10 +174,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
            (a_block_tile, a_dram_window);
+        load_and_convert_tile(a_block_tile, a_dram_window);
         move_tile_window(a_dram_window, dram_tile_window_step);
     }
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp
index be91002cdbd..c92f94e46ec 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp
@@ -40,10 +40,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
            &&
-                           std::is_same_v,
-                           ADataType,
-                           BDataType>;
+        std::conditional_t, ADataType, BDataType>;
     static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
     using I0 = number<0>;
@@ -185,10 +182,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3
            (b_block_tile, b_dram_window);
+        load_and_convert_tile(b_block_tile, b_dram_window);
     }
     template
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp
index 566f0b6153f..8a0847daf91 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp
@@ -354,8 +354,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
                 move_tile_window(b_flat_dram_windows(nIter)(kIter),
                                  {nIter * flatNPerWarp, kIter * flatKPerWarp});
-                load_int4_tile(
-                    b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
         });
         // move B window to next flat K
@@ -434,8 +434,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
                 move_tile_window(b_flat_dram_windows(nIter)(kIter),
                                  {nIter * flatNPerWarp, kIter * flatKPerWarp});
-                load_int4_tile(
-                    b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
         });
         move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -463,8 +463,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
                 move_tile_window(b_flat_dram_windows(nIter)(kIter),
                                  {nIter * flatNPerWarp, kIter * flatKPerWarp});
-                load_int4_tile(
-                    b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
         });
         move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -513,8 +513,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe
                 move_tile_window(b_flat_dram_windows(nIter)(kIter),
                                  {nIter * flatNPerWarp, kIter * flatKPerWarp});
-                load_int4_tile(
-                    b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
         });
         aq_block_tile_2 = load_tile(aq_copy_dram_window);
diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp
index a2a8c89e0aa..636c7f0a1de 100644
--- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp
+++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp
@@ -344,8 +344,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
                 move_tile_window(b_flat_dram_windows(nIter)(kIter),
                                  {nIter * flatNPerWarp, kIter * flatKPerWarp});
-                load_int4_tile(
-                    b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
         });
         // move B window to next flat K
@@ -430,8 +430,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
                 move_tile_window(b_flat_dram_windows(nIter)(kIter),
                                  {nIter * flatNPerWarp, kIter * flatKPerWarp});
-                load_int4_tile(
-                    b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
         });
         move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -467,8 +467,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
                 move_tile_window(b_flat_dram_windows(nIter)(kIter),
                                  {nIter * flatNPerWarp, kIter * flatKPerWarp});
-                load_int4_tile(
-                    b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
         });
         move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock});
@@ -525,8 +525,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
                 move_tile_window(b_flat_dram_windows(nIter)(kIter),
                                  {nIter * flatNPerWarp, kIter * flatKPerWarp});
-                load_int4_tile(
-                    b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter));
+                load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter),
+                                      b_flat_dram_windows(nIter)(kIter));
             });
         });
         bq_block_tile_2 = load_tile(bq_copy_dram_window);
diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp
index 6743e466131..5dafc554203 100644
--- a/include/ck_tile/ops/grouped_convolution.hpp
+++ b/include/ck_tile/ops/grouped_convolution.hpp
@@ -11,8 +11,9 @@
 #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp"
 #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp"
 #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp"
+#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/common/load_and_convert_tile.hpp"
 #include "ck_tile/ops/common/streamk_common.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp
index 1d33ebf39d8..faa165a8b00 100644
--- a/include/ck_tile/ops/image_to_column.hpp
+++ b/include/ck_tile/ops/image_to_column.hpp
@@ -5,8 +5,9 @@
 #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
 #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
 #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
+#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/common/load_and_convert_tile.hpp"
 #include "ck_tile/ops/common/streamk_common.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp
index ebb20aebf47..2266c138729 100644
--- a/include/ck_tile/ops/layernorm2d.hpp
+++ b/include/ck_tile/ops/layernorm2d.hpp
@@ -8,8 +8,9 @@
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
 #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
+#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/common/load_and_convert_tile.hpp"
 #include "ck_tile/ops/common/streamk_common.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp
index 469a98c256e..9f572ff5cbb 100644
--- a/include/ck_tile/ops/norm_reduce.hpp
+++ b/include/ck_tile/ops/norm_reduce.hpp
@@ -5,8 +5,9 @@
 #include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"
 #include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp"
 #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp"
+#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/common/load_and_convert_tile.hpp"
 #include "ck_tile/ops/common/streamk_common.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp
index 88a3d8a137e..c7747a67e70 100644
--- a/include/ck_tile/ops/permute.hpp
+++ b/include/ck_tile/ops/permute.hpp
@@ -4,8 +4,9 @@
 #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
 #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
+#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/common/load_and_convert_tile.hpp"
 #include "ck_tile/ops/common/streamk_common.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
diff --git a/include/ck_tile/ops/pooling.hpp b/include/ck_tile/ops/pooling.hpp
index 3e44122afab..43b24c7f8ca 100644
--- a/include/ck_tile/ops/pooling.hpp
+++ b/include/ck_tile/ops/pooling.hpp
@@ -6,8 +6,9 @@
 #include "ck_tile/ops/pooling/pipeline/pool_default_policy.hpp"
 #include "ck_tile/ops/pooling/pipeline/pool_problem.hpp"
 #include "ck_tile/ops/pooling/pipeline/pool_shape.hpp"
+#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/common/load_and_convert_tile.hpp"
 #include "ck_tile/ops/common/streamk_common.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp
index 9e31b7bbe26..e680d257453 100644
--- a/include/ck_tile/ops/reduce.hpp
+++ b/include/ck_tile/ops/reduce.hpp
@@ -13,8 +13,9 @@
 #include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp"
 #include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp"
 #include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp"
+#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/common/load_and_convert_tile.hpp"
 #include "ck_tile/ops/common/streamk_common.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp
index ad23a708b79..7ee67334d6b 100644
--- a/include/ck_tile/ops/rmsnorm2d.hpp
+++ b/include/ck_tile/ops/rmsnorm2d.hpp
@@ -9,8 +9,9 @@
 #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
 #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
 #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
+#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/common/load_and_convert_tile.hpp"
 #include "ck_tile/ops/common/streamk_common.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp
index 13372f32899..ad984c033f0 100644
--- a/include/ck_tile/ops/smoothquant.hpp
+++ b/include/ck_tile/ops/smoothquant.hpp
@@ -8,8 +8,9 @@
 #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp"
 #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp"
 #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp"
+#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/common/load_and_convert_tile.hpp"
 #include "ck_tile/ops/common/streamk_common.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp
index 9cf3e08319f..b810a57dda0 100644
--- a/include/ck_tile/ops/softmax.hpp
+++ b/include/ck_tile/ops/softmax.hpp
@@ -4,8 +4,9 @@
 #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
 #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
+#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/common/load_and_convert_tile.hpp"
 #include "ck_tile/ops/common/streamk_common.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp
index 090ad0919f5..13d818174e6 100644
--- a/include/ck_tile/ops/topk.hpp
+++ b/include/ck_tile/ops/topk.hpp
@@ -4,8 +4,9 @@
 #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
 #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
+#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/common/load_and_convert_tile.hpp"
 #include "ck_tile/ops/common/streamk_common.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp
index 7afce1708b4..b7219511faa 100644
--- a/include/ck_tile/ops/topk_softmax.hpp
+++ b/include/ck_tile/ops/topk_softmax.hpp
@@ -6,8 +6,9 @@
 #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
 #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
 #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
+#include "ck_tile/ops/common/determine_warp_prec_type.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
-#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
+#include "ck_tile/ops/common/load_and_convert_tile.hpp"
 #include "ck_tile/ops/common/streamk_common.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
 #include "ck_tile/ops/common/utils.hpp"
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp
index 4bef581254e..70449b5adf8 100644
--- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp
+++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp
@@ -79,41 +79,57 @@ using KernelTypesMemWmma = ::testing::Types<
 using KernelTypesCompV3 = ::testing::Types<
     std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Row, Row, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Row, Row, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Row, Row, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Row, Row, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Row, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Row, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Row, Col, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Col, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Row, Col, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Row, Col, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Row, Col, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Row, Col, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Col, Row, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Col, Row, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Col, Row, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Col, Row, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Row, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Row, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Col, Col, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Col, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Col, Col, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Col, Col, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
+    std::tuple< Col, Col, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
     std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>,
diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
index 8dc2e884302..a3da751a82d 100644
--- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
+++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp
@@ -299,8 +299,8 @@ class TestCkTileGemmPipeline : public ::testing::Test
         ck_tile::HostTensor c_m_n_dev_result(
             ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
-        ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11939}(a_m_k);
-        ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11940}(b_k_n);
+        ck_tile::FillUniformDistributionIntegerValue{-0.5, 0.5, 11939}(a_m_k);
+        ck_tile::FillUniformDistributionIntegerValue{-0.5, 0.5, 11940}(b_k_n);
         ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
         ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());