From 4d77856be54372ac750b81274d37c18ca707463b Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Sat, 6 Dec 2025 17:17:16 +0000 Subject: [PATCH 01/44] Make some functions return void explicitly instead of auto --- include/ck_tile/core/tensor/load_tile.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index af0f81e832c..43cd0ab9c4d 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -85,12 +85,12 @@ template -CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile, +CK_TILE_DEVICE void load_tile(DistributedTensor_& dst_tile, const TileWindow_& tile_window, number = {}, bool_constant = {}) { - return tile_window.load(dst_tile, number{}, bool_constant{}); + tile_window.load(dst_tile, number{}, bool_constant{}); } /** @@ -131,7 +131,7 @@ template -CK_TILE_DEVICE auto load_tile_raw(T& tile, +CK_TILE_DEVICE void load_tile_raw(T& tile, const tile_window_linear Date: Fri, 21 Nov 2025 10:53:14 +0000 Subject: [PATCH 02/44] Use decltype for consistency in Interwave variant of BlockGemmImpl --- .../ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index f6e26ad206d..7882df885d8 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -366,9 +366,9 @@ struct BlockUniversalGemmAsBsCr static constexpr index_t KInnerLoopIter = KPerInnerLoop / WarpGemm::kKPerThread; static constexpr auto ALdsTileDistr = - make_static_tile_distribution(MakeABlockDistributionEncode()); + decltype(make_static_tile_distribution(MakeABlockDistributionEncode())){}; static constexpr auto BLdsTileDistr = - make_static_tile_distribution(MakeBBlockDistributionEncode()); + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; using ALdsTile = decltype(make_static_distributed_tensor(ALdsTileDistr)); using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); From bda5a7aa2d0a49cd672ce0ec4290adcfbae2c6bd Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 19 Nov 2025 09:08:27 +0000 Subject: [PATCH 03/44] Add braces --- .../ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 4973d9c9410..dcc11015e75 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -237,12 +237,16 @@ struct GemmPipelineAgBgCrImplBase auto a_lds_load_tile_distr = []() { if constexpr(is_a_load_tr) + { return make_static_tile_distribution( typename InputTileDistributionTraits< typename ALdsLoadTileDistr::DstrEncode, typename Problem::ADataType>::TransposedDstrEncode{}); + } else + { return ALdsLoadTileDistr{}; + } }(); auto a_lds_gemm_window = @@ -320,12 +324,15 @@ struct GemmPipelineAgBgCrImplBase auto b_lds_load_tile_distr = []() { if constexpr(is_b_load_tr) + { return make_static_tile_distribution( typename InputTileDistributionTraits::TransposedDstrEncode{}); - + } else + { return BLdsLoadTileDistr{}; + } }(); auto b_lds_gemm_window = From 825d17c3d772b0533afdccb0aba91a85537234d8 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Thu, 11 Dec 2025 12:32:52 +0000 Subject: [PATCH 04/44] Fix a comment --- include/ck_tile/core/tensor/load_tile.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 43cd0ab9c4d..39249cec9c3 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -57,8 +57,8 @@ CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, number = {}, bool_constant = {}) { - // TODO: Tile windows should works with unknow number of params - // Load element_wise API works only when the input typle is a tuple-tyupe + // TODO: Tile windows should work with unknown number of params + // Load element_wise API works only when the input type is a tuple-type return tile_window[number<0>{}].load( tile_window, elementwise, number{}, bool_constant{}); } From ca71cd75fc1486012a1717878201dcbfefdc1d78 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 17 Dec 2025 14:52:24 +0000 Subject: [PATCH 05/44] Reduce the scope of KPack in MakeALdsBlockDescriptor --- .../gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 6199142d986..4bb7829fb64 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -112,7 +112,6 @@ struct UniversalGemmBasePolicy using ADataType = OverrideADataType; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; - constexpr index_t KPack = GetSmemPackA(); if constexpr(is_a_load_tr) { @@ -246,6 +245,7 @@ struct UniversalGemmBasePolicy } else // A is in RowMajor { + constexpr index_t KPack = GetSmemPackA(); constexpr auto DataTypeSize = sizeof(ADataType); constexpr uint64_t MinLdsLayer = 1ULL; constexpr auto MLdsLayer = From 994b8f4c22974a24f4f45652f9250673cc157023 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 12 Nov 2025 08:34:13 +0000 Subject: [PATCH 06/44] Minor refactoring of load_interleaved_pk_type --- .../ck_tile/ops/common/load_interleaved_pk_type.hpp | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp index 10c2a1e4df7..0cc4e43d557 100644 --- a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp +++ b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -12,19 +12,18 @@ template struct InterleavedPKTypeLoader { template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, - const WarpWindow& warp_window) + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src) { - const element_wise::PassThroughPack8 elementwise_op{}; - static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); + const auto tmp = load_tile(src); using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize))); static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); + const element_wise::PassThroughPack8 elementwise_op{}; + + elementwise_op(dst.get_thread_buffer().template get_as()(i), + tmp.get_thread_buffer().template get_as()[i]); }); } }; From 74533b475505499477c2b86c43724b74cf6e9d2a Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Thu, 27 Nov 2025 09:28:27 +0000 Subject: [PATCH 07/44] Rename load_interleaved_pk_type to load_and_convert_tile --- include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | 2 +- include/ck_tile/ops/batched_contraction.hpp | 2 +- include/ck_tile/ops/batched_transpose.hpp | 2 +- include/ck_tile/ops/common.hpp | 2 +- .../{load_interleaved_pk_type.hpp => load_and_convert_tile.hpp} | 0 include/ck_tile/ops/elementwise.hpp | 2 +- include/ck_tile/ops/epilogue.hpp | 2 +- include/ck_tile/ops/flatmm.hpp | 2 +- include/ck_tile/ops/fmha.hpp | 2 +- include/ck_tile/ops/fused_moe.hpp | 2 +- include/ck_tile/ops/gemm.hpp | 2 +- .../ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp | 2 +- .../ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp | 2 +- include/ck_tile/ops/gemm_quant.hpp | 2 +- include/ck_tile/ops/grouped_convolution.hpp | 2 +- include/ck_tile/ops/image_to_column.hpp | 2 +- include/ck_tile/ops/layernorm2d.hpp | 2 +- include/ck_tile/ops/norm_reduce.hpp | 2 +- include/ck_tile/ops/permute.hpp | 2 +- include/ck_tile/ops/pooling.hpp | 2 +- include/ck_tile/ops/reduce.hpp | 2 +- include/ck_tile/ops/rmsnorm2d.hpp | 2 +- include/ck_tile/ops/smoothquant.hpp | 2 +- include/ck_tile/ops/softmax.hpp | 2 +- include/ck_tile/ops/topk.hpp | 2 +- include/ck_tile/ops/topk_softmax.hpp | 2 +- 26 files changed, 25 insertions(+), 25 deletions(-) rename include/ck_tile/ops/common/{load_interleaved_pk_type.hpp => load_and_convert_tile.hpp} (100%) diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index 00234b20cf9..aa0f632c216 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -8,7 +8,7 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_contraction.hpp b/include/ck_tile/ops/batched_contraction.hpp index 45fa52e5051..9c90db67edd 100644 --- a/include/ck_tile/ops/batched_contraction.hpp +++ b/include/ck_tile/ops/batched_contraction.hpp @@ -6,7 +6,7 @@ #include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp" #include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index b23e45c2331..9cac035c445 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -11,7 +11,7 @@ #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index 94243e674f5..ad7da5c1833 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -3,7 +3,7 @@ #pragma once #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp similarity index 100% rename from include/ck_tile/ops/common/load_interleaved_pk_type.hpp rename to include/ck_tile/ops/common/load_and_convert_tile.hpp diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index 5752703ab60..bc72f3b0ba1 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -9,7 +9,7 @@ #include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index 433462b22e2..d1b38a8bca6 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -11,7 +11,7 @@ #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index 2d3a819e804..e08fac48c7e 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -22,7 +22,7 @@ #include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index eb4aa16d054..0639fa1b36e 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -61,7 +61,7 @@ #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index e6802e82dce..60f5bd1c4e3 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -15,7 +15,7 @@ #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 2c3a1611216..8dbf111048e 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -77,7 +77,7 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 7882df885d8..a0fa732d1a4 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index c9499106de7..9939a9586e7 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -4,7 +4,7 @@ #pragma once #include "ck_tile/core.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/host/concat.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 696de378aaf..6aee73cda1d 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -31,7 +31,7 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index 6743e466131..eeb9b1d8a81 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -12,7 +12,7 @@ #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index 1d33ebf39d8..07d99890869 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -6,7 +6,7 @@ #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index ebb20aebf47..8f9ab205ac4 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -9,7 +9,7 @@ #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp index 469a98c256e..eae0ea14a33 100644 --- a/include/ck_tile/ops/norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -6,7 +6,7 @@ #include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp" #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index 88a3d8a137e..4d37f4fbc12 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -5,7 +5,7 @@ #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/pooling.hpp b/include/ck_tile/ops/pooling.hpp index 3e44122afab..faa77d53273 100644 --- a/include/ck_tile/ops/pooling.hpp +++ b/include/ck_tile/ops/pooling.hpp @@ -7,7 +7,7 @@ #include "ck_tile/ops/pooling/pipeline/pool_problem.hpp" #include "ck_tile/ops/pooling/pipeline/pool_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index 57f3f3c80a8..46bb96af181 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -10,7 +10,7 @@ #include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index ad23a708b79..f271be50068 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -10,7 +10,7 @@ #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp index 13372f32899..4c2fe9bee43 100644 --- a/include/ck_tile/ops/smoothquant.hpp +++ b/include/ck_tile/ops/smoothquant.hpp @@ -9,7 +9,7 @@ #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp" #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index 9cf3e08319f..c79ba06abfe 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -5,7 +5,7 @@ #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index 090ad0919f5..474ba932270 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -5,7 +5,7 @@ #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index 7afce1708b4..066fbf5feea 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -7,7 +7,7 @@ #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" +#include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/utils.hpp" From 3a094e2f8bf954a5fa46ac6fe3379509759f063b Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 26 Nov 2025 11:33:27 +0000 Subject: [PATCH 08/44] Include ck_tile/core.hpp in load_interleaved_pk_type.hpp for better IDE integration --- include/ck_tile/ops/common/load_and_convert_tile.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/common/load_and_convert_tile.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp index 0cc4e43d557..0da1f112296 100644 --- a/include/ck_tile/ops/common/load_and_convert_tile.hpp +++ b/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -3,7 +3,7 @@ #pragma once -#include "ck_tile/core/config.hpp" +#include "ck_tile/core.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" namespace ck_tile { From cfa11f2d1fad3cf19b0fbae601349bce4525024d Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Thu, 27 Nov 2025 08:35:18 +0000 Subject: [PATCH 09/44] Rename InterleavedPKTypeLoader to ConverterLoader, and load_int4_tile to load_and_convert_tile --- .../ops/common/load_and_convert_tile.hpp | 6 ++--- .../block/block_universal_gemm_as_bs_cr.hpp | 24 +++++++++---------- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 3 ++- ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 4 ++-- .../block_universal_gemm_as_aquant_bs_cr.hpp | 4 ++-- .../block_universal_gemm_as_bs_bquant_cr.hpp | 4 ++-- .../gemm_abquant_pipeline_ag_bg_cr_v3.hpp | 6 +++-- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 3 ++- .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 3 ++- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 3 ++- .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 8 +++---- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 8 +++---- 12 files changed, 41 insertions(+), 35 deletions(-) diff --git a/include/ck_tile/ops/common/load_and_convert_tile.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp index 0da1f112296..f2ee23e98bd 100644 --- a/include/ck_tile/ops/common/load_and_convert_tile.hpp +++ b/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -9,7 +9,7 @@ namespace ck_tile { template -struct InterleavedPKTypeLoader +struct ConverterLoader { template CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src) @@ -34,12 +34,12 @@ template -CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) +CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src) { if constexpr(std::is_same_v) { static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); - InterleavedPKTypeLoader::load_interleaved_pk_type(dst, src); + ConverterLoader::load_interleaved_pk_type(dst, src); } else if constexpr(LoadTranspose) { diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index a0fa732d1a4..040051a5e8d 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -228,10 +228,10 @@ struct BlockUniversalGemmAsBsCr "The ADataType and BDataType as defined in " "traits should be the same as correspoinding block window data type!"); - load_int4_tile(a_warp_tile_, - a_block_window); - load_int4_tile(b_warp_tile_, - b_block_window); + load_and_convert_tile( + a_warp_tile_, a_block_window); + load_and_convert_tile( + b_warp_tile_, b_block_window); // hot loop: static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -294,10 +294,10 @@ struct BlockUniversalGemmAsBsCr bool_constant = {}, bool_constant = {}) { - load_int4_tile(a_warp_tile_, - a_block_window); - load_int4_tile(b_warp_tile_, - b_block_window); + load_and_convert_tile( + a_warp_tile_, a_block_window); + load_and_convert_tile( + b_warp_tile_, b_block_window); } // C += A * B @@ -425,10 +425,10 @@ struct BlockUniversalGemmAsBsCr auto b_lds_gemm_window = make_tile_window( b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); - load_int4_tile(a_warp_tile_, - a_lds_gemm_window); - load_int4_tile(b_warp_tile_, - b_lds_gemm_window); + load_and_convert_tile( + a_warp_tile_, a_lds_gemm_window); + load_and_convert_tile( + b_warp_tile_, b_lds_gemm_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index dcc11015e75..74632ee5b02 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -74,7 +74,8 @@ struct GemmPipelineAgBgCrImplBase SrcTileWindow& dram_tile_window, const DramTileWindowStep& dram_tile_window_step) const { - load_int4_tile(dst_block_tile, dram_tile_window); + load_and_convert_tile(dst_block_tile, + dram_tile_window); move_tile_window(dram_tile_window, dram_tile_window_step); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index c44d330d139..132b31ed620 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -261,10 +261,10 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase bool_constant = {}, bool_constant = {}) { - load_int4_tile( + load_and_convert_tile( a_warp_tile_, a_block_window); // If B datatype were pkint4 it would be converted prior to storing in LDS - load_int4_tile( + load_and_convert_tile( b_warp_tile_, b_block_window); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 705a992b526..b40168b2afb 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -248,9 +248,9 @@ struct AQuantBlockUniversalGemmAsBsCr // while ADatatype might not be the same as BDataType at the time of problem // initialization, we can safely use BDataType here because when A would be int4 we will // ensure A is converted to BDataType prior to loading - load_int4_tile( + load_and_convert_tile( a_warp_tile_, a_block_window); - load_int4_tile( + load_and_convert_tile( b_warp_tile_, b_block_window); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 313e449c7b5..ece393b40d7 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -258,10 +258,10 @@ struct BQuantBlockUniversalGemmAsBsCr bool_constant = {}, bool_constant = {}) { - load_int4_tile( + load_and_convert_tile( a_warp_tile_, a_block_window); // If B datatype were pkint4 it would be converted prior to storing in LDS - load_int4_tile( + load_and_convert_tile( b_warp_tile_, b_block_window); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index cd70c2ca862..b6fd25139ec 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -200,7 +200,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, + a_dram_window); } template @@ -210,7 +211,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_dram_window); + load_and_convert_tile(b_block_tile, + b_dram_window); } template using DestDataType = typename ABlockTile_::DataType; using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; constexpr index_t UnaryOpSize = 8; - load_int4_tile(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, + a_dram_window); move_tile_window(a_dram_window, dram_tile_window_step); } diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 22dd78e0707..0a4793bb125 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -171,7 +171,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, + a_dram_window); } template (b_block_tile, b_dram_window); + load_and_convert_tile(b_block_tile, + b_dram_window); } template diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index 0f3951ffccc..49064bdb763 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -349,7 +349,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -430,7 +430,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -455,7 +455,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -503,7 +503,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index e4de7e42116..5455944de0c 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -335,7 +335,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -421,7 +421,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -458,7 +458,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -516,7 +516,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_int4_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); From 9559a934329058091aa8967751f51b9b9209593a Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Fri, 12 Dec 2025 09:53:26 +0000 Subject: [PATCH 10/44] Make explicit that the tile window argument to load_tile_with_elementwise and the two load methods it uses are tuples --- include/ck_tile/core/tensor/load_tile.hpp | 8 ++++---- include/ck_tile/core/tensor/tile_window.hpp | 18 +++++++++--------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index 39249cec9c3..d1c06d43780 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -48,19 +48,19 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window, * and an elementwise function. For each A = A0, A1… AN, the elementwise function * is additionally applied during a single read. */ -template -CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window, +CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple& tile_windows, ElementWise_ elementwise, number = {}, bool_constant = {}) { // TODO: Tile windows should work with unknown number of params // Load element_wise API works only when the input type is a tuple-type - return tile_window[number<0>{}].load( - tile_window, elementwise, number{}, bool_constant{}); + return tile_windows[number<0>{}].load( + tile_windows, elementwise, number{}, bool_constant{}); } // Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution. diff --git a/include/ck_tile/core/tensor/tile_window.hpp b/include/ck_tile/core/tensor/tile_window.hpp index d39da82a627..da90675fdd4 100644 --- a/include/ck_tile/core/tensor/tile_window.hpp +++ b/include/ck_tile/core/tensor/tile_window.hpp @@ -182,11 +182,11 @@ struct tile_window_with_static_distribution * The same thread, during vectorized reading, accesses the same set of * data from A0, A1, A2, … AN. */ - template - CK_TILE_DEVICE auto load(const TileWindow_& tile_window, + CK_TILE_DEVICE auto load(const ck_tile::tuple& tile_windows, ElementWise_ elementwise, number = {}, bool_constant = {}) const @@ -194,7 +194,7 @@ struct tile_window_with_static_distribution constexpr auto tile_dstr = typename Base::TileDstr{}; auto dst_tensor = make_static_distributed_tensor(tile_dstr); load(dst_tensor, - tile_window, + tile_windows, elementwise, number{}, bool_constant{}); @@ -202,12 +202,12 @@ struct tile_window_with_static_distribution } template CK_TILE_DEVICE void load(DistributedTensor& dst_tensor, - const TileWindow_& tile_window, + const ck_tile::tuple& tile_windows, ElementWise_ elementwise, number = {}, bool_constant = {}) const @@ -218,14 +218,14 @@ struct tile_window_with_static_distribution using SFC_Ys = typename Traits::SFC_Ys; constexpr auto tile_dstr = typename Base::TileDstr{}; - constexpr auto sizeOfTuple = TileWindow_::size(); + constexpr auto sizeOfTuple = remove_cvref_t::size(); // loop over thread tensor space [y0, y1, ...] static_for<0, NumCoord, 1>{}([&](auto iCoord) { /// TODO: use structure binding (to be captured later) if compiled in C++20 auto window_adaptor_thread_coord = - tile_window[number<0>{}].pre_computed_coords_[iCoord][I0]; + tile_windows[number<0>{}].pre_computed_coords_[iCoord][I0]; auto bottom_tensor_thread_coord = - tile_window[number<0>{}].pre_computed_coords_[iCoord][I1]; + tile_windows[number<0>{}].pre_computed_coords_[iCoord][I1]; static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) { constexpr auto iAccess = number{}; @@ -236,7 +236,7 @@ struct tile_window_with_static_distribution // read from bottom tensor const auto idx_vec_value = generate_tuple( [&](auto jj) { - return tile_window[number{}] + return tile_windows[number{}] .get_bottom_tensor_view() .template get_vectorized_elements( bottom_tensor_thread_coord, From 9633d3f5bb3b6eea369f678ebfa1187481d6907c Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 17 Dec 2025 14:41:31 +0000 Subject: [PATCH 11/44] In GetAWindows and GetBWindows, use DataType from LDS tensor view --- .../gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 74632ee5b02..3bc1f7c095b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -242,7 +242,7 @@ struct GemmPipelineAgBgCrImplBase return make_static_tile_distribution( typename InputTileDistributionTraits< typename ALdsLoadTileDistr::DstrEncode, - typename Problem::ADataType>::TransposedDstrEncode{}); + typename ALdsTensorView::DataType>::TransposedDstrEncode{}); } else { @@ -318,17 +318,13 @@ struct GemmPipelineAgBgCrImplBase auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}); - using BLdsDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; - auto b_lds_load_tile_distr = []() { if constexpr(is_b_load_tr) { return make_static_tile_distribution( - typename InputTileDistributionTraits::TransposedDstrEncode{}); + typename InputTileDistributionTraits< + typename BLdsLoadTileDistr::DstrEncode, + typename BLdsTensorView::DataType>::TransposedDstrEncode{}); } else { From 9af4498194004f865572e53d8bf0ec771eb17151 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 7 Jan 2026 13:48:38 +0000 Subject: [PATCH 12/44] Remove the defaults for SrcDataType and DstDataType in GemmPipelineAgBgCrImplBase::GlobalPrefetch --- .../gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 4 ++-- .../pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp | 12 ++++++------ .../pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 12 ++++++------ .../pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 12 ++++++------ .../pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 14 +++++++------- .../pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp | 12 ++++++------ 6 files changed, 33 insertions(+), 33 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 3bc1f7c095b..6959e9e05a6 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -64,8 +64,8 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } - template ( aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step); - Base::GlobalPrefetch( + Base::template GlobalPrefetch( bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -436,10 +436,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(aq_block_tile[(currIdx + 1) % 2], aq_copy_dram_window, aq_dram_tile_window_step); - Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], + Base::template GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], bq_copy_dram_window, bq_dram_tile_window_step); @@ -471,10 +471,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(aq_block_tile[(currIdx + 1) % 2], aq_copy_dram_window, aq_dram_tile_window_step); - Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], + Base::template GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], bq_copy_dram_window, bq_dram_tile_window_step); block_gemm(c_block_tile, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 2433563ef04..1742dbf6394 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -286,9 +286,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // Global prefetch initialization - DRAM to VGPRs LoadAndConvertATile( a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch( + Base::template GlobalPrefetch( b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); - Base::GlobalPrefetch( + Base::template GlobalPrefetch( aq_block_tiles.get(I0{}), aq_copy_dram_window, aq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -321,10 +321,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem LoadAndConvertATile(a_block_tiles.get(number{}), a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tiles.get(number{}), + Base::template GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, b_dram_tile_window_step); - Base::GlobalPrefetch(aq_block_tiles.get(number{}), + Base::template GlobalPrefetch(aq_block_tiles.get(number{}), aq_copy_dram_window, aq_dram_tile_window_step); }); @@ -381,10 +381,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem LoadAndConvertATile(a_block_tiles.get(number{}), a_copy_dram_window, a_dram_tile_window_step); - Base::GlobalPrefetch(b_block_tiles.get(number{}), + Base::template GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, b_dram_tile_window_step); - Base::GlobalPrefetch(aq_block_tiles.get(number{}), + Base::template GlobalPrefetch(aq_block_tiles.get(number{}), aq_copy_dram_window, aq_dram_tile_window_step); }); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 0a4793bb125..114a89d95f4 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -277,8 +277,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::template GlobalPrefetch( aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -309,7 +309,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); @@ -352,8 +352,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::template GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], aq_copy_dram_window, aq_dram_tile_window_step); @@ -379,7 +379,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(aq_block_tile[(currIdx + 1) % 2], aq_copy_dram_window, aq_dram_tile_window_step); block_gemm( diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 140e6be79c6..34dd7ba6ad7 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -203,7 +203,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); } } @@ -313,10 +313,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); // B tile gets converted to A datatype during loading BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - Base::GlobalPrefetch( + Base::template GlobalPrefetch( bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -346,7 +346,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); // B tile gets converted to A datatype during loading BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); @@ -390,10 +390,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); // B tile gets converted to A datatype during loading BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], + Base::template GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], bq_copy_dram_window, bq_dram_tile_window_step); @@ -419,7 +419,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(bq_block_tile[(currIdx + 1) % 2], bq_copy_dram_window, bq_dram_tile_window_step); block_gemm( diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp index b63a3124896..e8af4bd8938 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp @@ -419,8 +419,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::template GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); // BDataType auto b_block_tile = make_static_distributed_tensor( Policy::template MakeBRegTileDistribution()); @@ -480,8 +480,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::template GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); bq_block_tile = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); @@ -544,8 +544,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::template GlobalPrefetch( b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); bq_block_tile = load_tile(bq_copy_dram_window); From 514035e6cf747d096a42b197e051b9758b06f5f6 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 7 Jan 2026 14:33:24 +0000 Subject: [PATCH 13/44] In BQuantGemmPipelineAgBgCrCompV3, always convert BDatatype pk_int4_t to ADataType regardless of BLayout --- .../gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 34dd7ba6ad7..e0e67355e44 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -40,8 +40,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 && - std::is_same_v, + std::conditional_t, ADataType, BDataType>; From 3d55a1e6828e3f57b0e45a1bd4e4bc7ab6fec024 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Tue, 16 Dec 2025 10:24:23 +0000 Subject: [PATCH 14/44] No need to specify SrcDataType in load_and_convert_tile as WarpWindow knows its DataType --- .../ops/common/load_and_convert_tile.hpp | 5 ++-- .../block/block_universal_gemm_as_bs_cr.hpp | 24 +++++++++---------- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 6 ++--- .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 8 +++---- ..._universal_gemm_as_aquant_bs_bquant_cr.hpp | 4 ++-- .../block_universal_gemm_as_aquant_bs_cr.hpp | 8 +++---- .../block_universal_gemm_as_bs_bquant_cr.hpp | 8 +++---- .../gemm_abquant_pipeline_ag_bg_cr_v3.hpp | 20 +++++++--------- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 16 ++++++------- .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 16 ++++++------- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 18 +++++++------- .../gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp | 12 +++++----- .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 8 +++---- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 10 ++++---- 14 files changed, 75 insertions(+), 88 deletions(-) diff --git a/include/ck_tile/ops/common/load_and_convert_tile.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp index f2ee23e98bd..b268ef68119 100644 --- a/include/ck_tile/ops/common/load_and_convert_tile.hpp +++ b/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -28,15 +28,14 @@ struct ConverterLoader } }; -template CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); ConverterLoader::load_interleaved_pk_type(dst, src); diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 040051a5e8d..a22b0dcf652 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -228,10 +228,10 @@ struct BlockUniversalGemmAsBsCr "The ADataType and BDataType as defined in " "traits should be the same as correspoinding block window data type!"); - load_and_convert_tile( - a_warp_tile_, a_block_window); - load_and_convert_tile( - b_warp_tile_, b_block_window); + load_and_convert_tile(a_warp_tile_, + a_block_window); + load_and_convert_tile(b_warp_tile_, + b_block_window); // hot loop: static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -294,10 +294,10 @@ struct BlockUniversalGemmAsBsCr bool_constant = {}, bool_constant = {}) { - load_and_convert_tile( - a_warp_tile_, a_block_window); - load_and_convert_tile( - b_warp_tile_, b_block_window); + load_and_convert_tile(a_warp_tile_, + a_block_window); + load_and_convert_tile(b_warp_tile_, + b_block_window); } // C += A * B @@ -425,10 +425,10 @@ struct BlockUniversalGemmAsBsCr auto b_lds_gemm_window = make_tile_window( b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); - load_and_convert_tile( - a_warp_tile_, a_lds_gemm_window); - load_and_convert_tile( - b_warp_tile_, b_lds_gemm_window); + load_and_convert_tile(a_warp_tile_, + a_lds_gemm_window); + load_and_convert_tile(b_warp_tile_, + b_lds_gemm_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index 6959e9e05a6..e0556f6a6a5 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -64,8 +64,7 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } - template (dst_block_tile, - dram_tile_window); + load_and_convert_tile(dst_block_tile, dram_tile_window); move_tile_window(dram_tile_window, dram_tile_window_step); } diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 9939a9586e7..3f1e8dfc814 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -627,7 +627,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // // Prefetch A0 Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch( + Base::template GlobalPrefetch( b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); // Prefill A0 @@ -652,7 +652,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 do { { - Base::template GlobalPrefetch( + Base::template GlobalPrefetch( b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); Base::GlobalPrefetch( @@ -666,7 +666,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 HotLoopScheduler(); } { - Base::template GlobalPrefetch( + Base::template GlobalPrefetch( b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); Base::GlobalPrefetch( @@ -687,7 +687,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 if constexpr(TailNum == TailNumber::Even) { { - Base::template GlobalPrefetch( + Base::template GlobalPrefetch( b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); block_weight_preshuffle( diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 132b31ed620..5b4056e699f 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -261,10 +261,10 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase bool_constant = {}, bool_constant = {}) { - load_and_convert_tile( + load_and_convert_tile( a_warp_tile_, a_block_window); // If B datatype were pkint4 it would be converted prior to storing in LDS - load_and_convert_tile( + load_and_convert_tile( b_warp_tile_, b_block_window); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index b40168b2afb..ea411441ff0 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -248,10 +248,10 @@ struct AQuantBlockUniversalGemmAsBsCr // while ADatatype might not be the same as BDataType at the time of problem // initialization, we can safely use BDataType here because when A would be int4 we will // ensure A is converted to BDataType prior to loading - load_and_convert_tile( - a_warp_tile_, a_block_window); - load_and_convert_tile( - b_warp_tile_, b_block_window); + load_and_convert_tile(a_warp_tile_, + a_block_window); + load_and_convert_tile(b_warp_tile_, + b_block_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index ece393b40d7..cddc8b0dcd6 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -258,11 +258,11 @@ struct BQuantBlockUniversalGemmAsBsCr bool_constant = {}, bool_constant = {}) { - load_and_convert_tile( - a_warp_tile_, a_block_window); + load_and_convert_tile(a_warp_tile_, + a_block_window); // If B datatype were pkint4 it would be converted prior to storing in LDS - load_and_convert_tile( - b_warp_tile_, b_block_window); + load_and_convert_tile(b_warp_tile_, + b_block_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index bcb38bd3e87..2ff477e5ecd 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -198,10 +198,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, - a_dram_window); + load_and_convert_tile(a_block_tile, a_dram_window); } template @@ -209,10 +207,8 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, - b_dram_window); + load_and_convert_tile(b_block_tile, b_dram_window); } template ( + Base::template GlobalPrefetch( aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step); - Base::template GlobalPrefetch( + Base::template GlobalPrefetch( bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -436,10 +432,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(aq_block_tile[(currIdx + 1) % 2], + Base::template GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], aq_copy_dram_window, aq_dram_tile_window_step); - Base::template GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], + Base::template GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], bq_copy_dram_window, bq_dram_tile_window_step); @@ -471,10 +467,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(aq_block_tile[(currIdx + 1) % 2], + Base::template GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], aq_copy_dram_window, aq_dram_tile_window_step); - Base::template GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], + Base::template GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], bq_copy_dram_window, bq_dram_tile_window_step); block_gemm(c_block_tile, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index 1742dbf6394..fc1d14a7371 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -175,10 +175,8 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem const DramTileWindowStep& dram_tile_window_step) { using DestDataType = typename ABlockTile_::DataType; - using SrcDataType = typename ADramWindow::Base::TileWindowBase::DataType; constexpr index_t UnaryOpSize = 8; - load_and_convert_tile(a_block_tile, - a_dram_window); + load_and_convert_tile(a_block_tile, a_dram_window); move_tile_window(a_dram_window, dram_tile_window_step); } @@ -286,9 +284,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // Global prefetch initialization - DRAM to VGPRs LoadAndConvertATile( a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch( + Base::template GlobalPrefetch( b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch( + Base::template GlobalPrefetch( aq_block_tiles.get(I0{}), aq_copy_dram_window, aq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -321,10 +319,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem LoadAndConvertATile(a_block_tiles.get(number{}), a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch(b_block_tiles.get(number{}), + Base::template GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch(aq_block_tiles.get(number{}), + Base::template GlobalPrefetch(aq_block_tiles.get(number{}), aq_copy_dram_window, aq_dram_tile_window_step); }); @@ -381,10 +379,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem LoadAndConvertATile(a_block_tiles.get(number{}), a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch(b_block_tiles.get(number{}), + Base::template GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch(aq_block_tiles.get(number{}), + Base::template GlobalPrefetch(aq_block_tiles.get(number{}), aq_copy_dram_window, aq_dram_tile_window_step); }); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 114a89d95f4..e2d9f502993 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -169,10 +169,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, - a_dram_window); + load_and_convert_tile(a_block_tile, a_dram_window); } template (b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch( + Base::template GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::template GlobalPrefetch( aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -309,7 +307,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::template GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); @@ -352,8 +350,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], + Base::template GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::template GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], aq_copy_dram_window, aq_dram_tile_window_step); @@ -379,7 +377,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(aq_block_tile[(currIdx + 1) % 2], + Base::template GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], aq_copy_dram_window, aq_dram_tile_window_step); block_gemm( diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index e0e67355e44..7ef79bd32d6 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -183,10 +183,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, - b_dram_window); + load_and_convert_tile(b_block_tile, b_dram_window); } template @@ -202,7 +200,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::template GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); } } @@ -312,10 +310,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::template GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); // B tile gets converted to A datatype during loading BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch( + Base::template GlobalPrefetch( bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -345,7 +343,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::template GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); // B tile gets converted to A datatype during loading BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); @@ -389,10 +387,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::template GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); // B tile gets converted to A datatype during loading BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], + Base::template GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], bq_copy_dram_window, bq_dram_tile_window_step); @@ -418,7 +416,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(bq_block_tile[(currIdx + 1) % 2], + Base::template GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], bq_copy_dram_window, bq_dram_tile_window_step); block_gemm( diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp index e8af4bd8938..7d090a788a1 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp @@ -419,8 +419,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::template GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::template GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); // BDataType auto b_block_tile = make_static_distributed_tensor( Policy::template MakeBRegTileDistribution()); @@ -480,8 +480,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::template GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::template GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); bq_block_tile = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); @@ -544,8 +544,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch( + Base::template GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::template GlobalPrefetch( b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); bq_block_tile = load_tile(bq_copy_dram_window); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index 49064bdb763..f3886146ee4 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -349,7 +349,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( + load_and_convert_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -430,7 +430,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -455,7 +455,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( + load_and_convert_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -503,7 +503,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 5455944de0c..3ba7064d09b 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -335,8 +335,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); // move B window to next flat K @@ -421,7 +421,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -458,7 +458,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( + load_and_convert_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -516,7 +516,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); From 63a455952a843c1f02580b8f858c86cd93dce012 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Tue, 16 Dec 2025 11:32:30 +0000 Subject: [PATCH 15/44] No need to specify DstDataType in load_and_convert_tile as WarpTile knows its DataType --- .../ops/common/load_and_convert_tile.hpp | 9 +++------ .../block/block_universal_gemm_as_bs_cr.hpp | 18 ++++++------------ .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 5 ++--- .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 9 ++++----- ...k_universal_gemm_as_aquant_bs_bquant_cr.hpp | 4 ++-- .../block_universal_gemm_as_aquant_bs_cr.hpp | 6 ++---- .../block_universal_gemm_as_bs_bquant_cr.hpp | 6 ++---- .../gemm_abquant_pipeline_ag_bg_cr_v3.hpp | 18 ++++++++---------- .../gemm_aquant_pipeline_ag_bg_cr_mem.hpp | 15 +++++++-------- .../gemm_aquant_pipeline_ag_bg_cr_v3.hpp | 15 +++++++-------- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 17 ++++++++--------- .../gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp | 12 ++++++------ .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 8 ++++---- .../gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp | 16 ++++++++-------- 14 files changed, 69 insertions(+), 89 deletions(-) diff --git a/include/ck_tile/ops/common/load_and_convert_tile.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp index b268ef68119..eb22fbb5a27 100644 --- a/include/ck_tile/ops/common/load_and_convert_tile.hpp +++ b/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -28,17 +28,14 @@ struct ConverterLoader } }; -template +template CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src) { if constexpr(std::is_same_v) { static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); - ConverterLoader::load_interleaved_pk_type(dst, src); + ConverterLoader::load_interleaved_pk_type(dst, + src); } else if constexpr(LoadTranspose) { diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index a22b0dcf652..381a5513eab 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -228,10 +228,8 @@ struct BlockUniversalGemmAsBsCr "The ADataType and BDataType as defined in " "traits should be the same as correspoinding block window data type!"); - load_and_convert_tile(a_warp_tile_, - a_block_window); - load_and_convert_tile(b_warp_tile_, - b_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); // hot loop: static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) { static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { @@ -294,10 +292,8 @@ struct BlockUniversalGemmAsBsCr bool_constant = {}, bool_constant = {}) { - load_and_convert_tile(a_warp_tile_, - a_block_window); - load_and_convert_tile(b_warp_tile_, - b_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); } // C += A * B @@ -425,10 +421,8 @@ struct BlockUniversalGemmAsBsCr auto b_lds_gemm_window = make_tile_window( b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); - load_and_convert_tile(a_warp_tile_, - a_lds_gemm_window); - load_and_convert_tile(b_warp_tile_, - b_lds_gemm_window); + load_and_convert_tile(a_warp_tile_, a_lds_gemm_window); + load_and_convert_tile(b_warp_tile_, b_lds_gemm_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index e0556f6a6a5..ffac45bed41 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -64,8 +64,7 @@ struct GemmPipelineAgBgCrImplBase CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; } - template @@ -73,7 +72,7 @@ struct GemmPipelineAgBgCrImplBase SrcTileWindow& dram_tile_window, const DramTileWindowStep& dram_tile_window_step) const { - load_and_convert_tile(dst_block_tile, dram_tile_window); + load_and_convert_tile(dst_block_tile, dram_tile_window); move_tile_window(dram_tile_window, dram_tile_window_step); } diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 3f1e8dfc814..93999757b07 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -627,8 +627,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 // // Prefetch A0 Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch( - b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); // Prefill A0 Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); @@ -652,7 +651,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 do { { - Base::template GlobalPrefetch( + Base::GlobalPrefetch( b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); Base::GlobalPrefetch( @@ -666,7 +665,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 HotLoopScheduler(); } { - Base::template GlobalPrefetch( + Base::GlobalPrefetch( b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile); Base::GlobalPrefetch( @@ -687,7 +686,7 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 if constexpr(TailNum == TailNumber::Even) { { - Base::template GlobalPrefetch( + Base::GlobalPrefetch( b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step); Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile); block_weight_preshuffle( diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 5b4056e699f..12b699a9934 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -261,10 +261,10 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase bool_constant = {}, bool_constant = {}) { - load_and_convert_tile( + load_and_convert_tile( a_warp_tile_, a_block_window); // If B datatype were pkint4 it would be converted prior to storing in LDS - load_and_convert_tile( + load_and_convert_tile( b_warp_tile_, b_block_window); } diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index ea411441ff0..361855e722f 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -248,10 +248,8 @@ struct AQuantBlockUniversalGemmAsBsCr // while ADatatype might not be the same as BDataType at the time of problem // initialization, we can safely use BDataType here because when A would be int4 we will // ensure A is converted to BDataType prior to loading - load_and_convert_tile(a_warp_tile_, - a_block_window); - load_and_convert_tile(b_warp_tile_, - b_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index cddc8b0dcd6..18452e6ffa2 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -258,11 +258,9 @@ struct BQuantBlockUniversalGemmAsBsCr bool_constant = {}, bool_constant = {}) { - load_and_convert_tile(a_warp_tile_, - a_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); // If B datatype were pkint4 it would be converted prior to storing in LDS - load_and_convert_tile(b_warp_tile_, - b_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp index 2ff477e5ecd..afe68370249 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_abquant_pipeline_ag_bg_cr_v3.hpp @@ -197,18 +197,16 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, a_dram_window); } template CK_TILE_DEVICE static void LoadAndConvertBTile(BBlockTile_& b_block_tile, const BDramWindow& b_dram_window) { - using DestDataType = typename BBlockTile_::DataType; constexpr index_t UnaryOpSize = 8; - load_and_convert_tile(b_block_tile, b_dram_window); + load_and_convert_tile(b_block_tile, b_dram_window); } template ( + Base::GlobalPrefetch( aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step); - Base::template GlobalPrefetch( + Base::GlobalPrefetch( bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -432,10 +430,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(aq_block_tile[(currIdx + 1) % 2], + Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], aq_copy_dram_window, aq_dram_tile_window_step); - Base::template GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], + Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], bq_copy_dram_window, bq_dram_tile_window_step); @@ -467,10 +465,10 @@ struct ABQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(aq_block_tile[(currIdx + 1) % 2], + Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], aq_copy_dram_window, aq_dram_tile_window_step); - Base::template GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], + Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], bq_copy_dram_window, bq_dram_tile_window_step); block_gemm(c_block_tile, diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp index fc1d14a7371..7c1d841e4fd 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_mem.hpp @@ -174,9 +174,8 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem ADramWindow& a_dram_window, const DramTileWindowStep& dram_tile_window_step) { - using DestDataType = typename ABlockTile_::DataType; constexpr index_t UnaryOpSize = 8; - load_and_convert_tile(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, a_dram_window); move_tile_window(a_dram_window, dram_tile_window_step); } @@ -284,9 +283,9 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem // Global prefetch initialization - DRAM to VGPRs LoadAndConvertATile( a_block_tiles.get(I0{}), a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch( + Base::GlobalPrefetch( b_block_tiles.get(I0{}), b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch( + Base::GlobalPrefetch( aq_block_tiles.get(I0{}), aq_copy_dram_window, aq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -319,10 +318,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem LoadAndConvertATile(a_block_tiles.get(number{}), a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch(b_block_tiles.get(number{}), + Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch(aq_block_tiles.get(number{}), + Base::GlobalPrefetch(aq_block_tiles.get(number{}), aq_copy_dram_window, aq_dram_tile_window_step); }); @@ -379,10 +378,10 @@ struct AQuantGemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem LoadAndConvertATile(a_block_tiles.get(number{}), a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch(b_block_tiles.get(number{}), + Base::GlobalPrefetch(b_block_tiles.get(number{}), b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch(aq_block_tiles.get(number{}), + Base::GlobalPrefetch(aq_block_tiles.get(number{}), aq_copy_dram_window, aq_dram_tile_window_step); }); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index e2d9f502993..598472c92e4 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -168,9 +168,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_dram_window); + load_and_convert_tile(a_block_tile, a_dram_window); } template (b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch( + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch( aq_block_tile[currIdx], aq_copy_dram_window, aq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -307,7 +306,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); block_sync_lds(); @@ -350,8 +349,8 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], aq_copy_dram_window, aq_dram_tile_window_step); @@ -377,7 +376,7 @@ struct AQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(aq_block_tile[(currIdx + 1) % 2], + Base::GlobalPrefetch(aq_block_tile[(currIdx + 1) % 2], aq_copy_dram_window, aq_dram_tile_window_step); block_gemm( diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 7ef79bd32d6..a5ad0816cc6 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -182,9 +182,8 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_dram_window); + load_and_convert_tile(b_block_tile, b_dram_window); } template @@ -200,7 +199,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); } } @@ -310,10 +309,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); // B tile gets converted to A datatype during loading BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch( + Base::GlobalPrefetch( bq_block_tile[currIdx], bq_copy_dram_window, bq_dram_tile_window_step); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); @@ -343,7 +342,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); // B tile gets converted to A datatype during loading BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); @@ -387,10 +386,10 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); // B tile gets converted to A datatype during loading BGlobalPrefetch(b_block_tile, b_copy_dram_window, b_dram_tile_window_step); - Base::template GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], + Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], bq_copy_dram_window, bq_dram_tile_window_step); @@ -416,7 +415,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(bq_block_tile[(currIdx + 1) % 2], + Base::GlobalPrefetch(bq_block_tile[(currIdx + 1) % 2], bq_copy_dram_window, bq_dram_tile_window_step); block_gemm( diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp index 7d090a788a1..b63a3124896 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp @@ -419,8 +419,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); // BDataType auto b_block_tile = make_static_distributed_tensor( Policy::template MakeBRegTileDistribution()); @@ -480,8 +480,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); bq_block_tile = load_tile(bq_copy_dram_window); move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step}); @@ -544,8 +544,8 @@ struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); - Base::template GlobalPrefetch( + Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step); + Base::GlobalPrefetch( b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step); bq_block_tile = load_tile(bq_copy_dram_window); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index f3886146ee4..3f75cdc01de 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -349,7 +349,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( + load_and_convert_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -430,7 +430,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -455,7 +455,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( + load_and_convert_tile( b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -503,7 +503,7 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( + load_and_convert_tile( b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp index 3ba7064d09b..5babb785df2 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp @@ -335,8 +335,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), - b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); // move B window to next flat K @@ -421,8 +421,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -458,8 +458,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -516,8 +516,8 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); bq_block_tile_2 = load_tile(bq_copy_dram_window); From 8fc4030a5703624f82e54752eabd598c19215080 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Fri, 2 Jan 2026 14:47:32 +0000 Subject: [PATCH 16/44] Add an instance of load_tile_transpose that takes a reference to the output tensor as an input --- .../core/tensor/load_tile_transpose.hpp | 78 ++++++++++++++++--- 1 file changed, 66 insertions(+), 12 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index 0ac2ded5f6a..0266fc653f0 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -373,6 +373,7 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding() * element space size and vector length remain consistent between the input and output * distributions. * + * @tparam DistributedTensor_ The type of the tensor containing the transposed tile data. * @tparam BottomTensorView_ The type of the bottom tensor view. * @tparam WindowLengths_ The type representing the window lengths. * @tparam TileDistribution_ The type representing the tile distribution. @@ -380,18 +381,19 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding() * @tparam Policy The transpose policy to use (defaults to DefaultTranspose). * the last is SFINAE to ensure the tile distribution encoding is valid. * + * @param out_tensor A statically distributed tensor containing the transposed tile + * data. * @param tile_window The tile window with static distribution to load and transpose. * @param offset The offset (in elements) added to the base address before * indexing. * - * @return A statically distributed tensor containing the transposed tile data. - * * @note * - The function uses compile-time checks to ensure the input and output tile distributions * are compatible in terms of element space size and vector length. * - The transpose operation is performed according to the specified Policy. */ template < + typename DistributedTensor_, typename BottomTensorView_, typename WindowLengths_, typename TileDistribution_, @@ -401,21 +403,17 @@ template < typename BottomTensorView_::DataType, Policy>::distr_encoding_valid, Policy>> -CK_TILE_DEVICE auto load_tile_transpose_with_offset( +CK_TILE_DEVICE void load_tile_transpose_with_offset( + DistributedTensor_& out_tensor, const tile_window_with_static_distribution& __restrict__ tile_window, index_t offset) { - using OutTileDstrEncode = typename OutputTileDistributionTraits< - typename TileDistribution_::DstrEncode, - typename BottomTensorView_::DataType>::TransposedDstrEncode; - auto out_tensor = make_static_distributed_tensor( - make_static_tile_distribution(OutTileDstrEncode{})); auto trans_tensor = tile_window.template load_transpose_with_offset(offset); constexpr auto input_distr = TileDistribution_{}; - constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{}); + constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{}; constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor(); constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor(); @@ -442,6 +440,32 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( number{}, trans_tensor.get_thread_buffer().template get_as(number{})); }); +} + +template < + typename BottomTensorView_, + typename WindowLengths_, + typename TileDistribution_, + index_t NumCoord, + typename Policy = DefaultTranspose, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE auto load_tile_transpose_with_offset( + const tile_window_with_static_distribution& __restrict__ tile_window, + index_t offset) +{ + using OutTileDstrEncode = typename OutputTileDistributionTraits< + typename TileDistribution_::DstrEncode, + typename BottomTensorView_::DataType>::TransposedDstrEncode; + auto out_tensor = make_static_distributed_tensor( + make_static_tile_distribution(OutTileDstrEncode{})); + + load_tile_transpose_with_offset(out_tensor, tile_window, offset); return out_tensor; } @@ -455,6 +479,7 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( * element space size and vector length remain consistent between the input and output * distributions. * + * @tparam DistributedTensor_ The type of the tensor containing the transposed tile data. * @tparam BottomTensorView_ The type of the bottom tensor view. * @tparam WindowLengths_ The type representing the window lengths. * @tparam TileDistribution_ The type representing the tile distribution. @@ -462,16 +487,37 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset( * @tparam Policy The transpose policy to use (defaults to DefaultTranspose). * the last is SFINAE to ensure the tile distribution encoding is valid. * + * @param out_tensor A statically distributed tensor containing the transposed tile + * data. * @param tile_window The tile window with static distribution to load and transpose. * indexing. * - * @return A statically distributed tensor containing the transposed tile data. - * * @note * - The function uses compile-time checks to ensure the input and output tile distributions * are compatible in terms of element space size and vector length. * - The transpose operation is performed according to the specified Policy. */ +template < + typename DistributedTensor_, + typename BottomTensorView_, + typename WindowLengths_, + typename TileDistribution_, + index_t NumCoord, + typename Policy = DefaultTranspose, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE void +load_tile_transpose(DistributedTensor_& out_tensor, + const tile_window_with_static_distribution& __restrict__ tile_window) +{ + load_tile_transpose_with_offset(out_tensor, tile_window, 0); +} + template < typename BottomTensorView_, typename WindowLengths_, @@ -488,7 +534,15 @@ load_tile_transpose(const tile_window_with_static_distribution& __restrict__ tile_window) { - return load_tile_transpose_with_offset(tile_window, 0); + using OutTileDstrEncode = typename OutputTileDistributionTraits< + typename TileDistribution_::DstrEncode, + typename BottomTensorView_::DataType>::TransposedDstrEncode; + auto out_tensor = make_static_distributed_tensor( + make_static_tile_distribution(OutTileDstrEncode{})); + + load_tile_transpose_with_offset(out_tensor, tile_window, 0); + + return out_tensor; } } // namespace ck_tile From 321611081f9731367193d646111827222d6f065a Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Fri, 2 Jan 2026 15:41:54 +0000 Subject: [PATCH 17/44] Remove an unused overload of load_tile_transpose_with_offset --- .../core/tensor/load_tile_transpose.hpp | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index 0266fc653f0..9288d740380 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -442,34 +442,6 @@ CK_TILE_DEVICE void load_tile_transpose_with_offset( }); } -template < - typename BottomTensorView_, - typename WindowLengths_, - typename TileDistribution_, - index_t NumCoord, - typename Policy = DefaultTranspose, - typename = std::enable_if_t::distr_encoding_valid, - Policy>> -CK_TILE_DEVICE auto load_tile_transpose_with_offset( - const tile_window_with_static_distribution& __restrict__ tile_window, - index_t offset) -{ - using OutTileDstrEncode = typename OutputTileDistributionTraits< - typename TileDistribution_::DstrEncode, - typename BottomTensorView_::DataType>::TransposedDstrEncode; - auto out_tensor = make_static_distributed_tensor( - make_static_tile_distribution(OutTileDstrEncode{})); - - load_tile_transpose_with_offset(out_tensor, tile_window, offset); - - return out_tensor; -} - /** * @brief transpose loads tile from a tensor and returns the resulting tensor with a new * (transposed) tile distribution. use SFINAE to ensure the tile distribution encoding is valid. From ca17ac3358d75ef09da00c9a17dfd244df58ef5c Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Fri, 2 Jan 2026 15:43:35 +0000 Subject: [PATCH 18/44] When possible, use the overload of load_tile_transpose that does not require assignment --- .../ops/common/load_and_convert_tile.hpp | 2 +- ..._bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp | 8 ++++---- ...bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp | 18 +++++++++--------- .../pipeline/block_fmha_fwd_v3_pipeline.hpp | 2 +- ...ock_fmha_pipeline_qr_ks_vs_async_trload.hpp | 2 +- .../pipeline/gemm_pipeline_ag_bg_cr_base.hpp | 2 +- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/include/ck_tile/ops/common/load_and_convert_tile.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp index eb22fbb5a27..4e05ecc59c8 100644 --- a/include/ck_tile/ops/common/load_and_convert_tile.hpp +++ b/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -39,7 +39,7 @@ CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src) } else if constexpr(LoadTranspose) { - dst = load_tile_transpose(src); + load_tile_transpose(dst, src); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp index 9aeabaa8c22..16212c0d130 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_kr_ktr_vr.hpp @@ -530,7 +530,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR s_acc = gemm_0(q_reg_tensor, k_reg_tensor); dot_lds_read_window.set_bottom_tensor_view_data_ptr(do_lds_ptr_curr); - dot_reg_tensor = load_tile_transpose(dot_lds_read_window); + load_tile_transpose(dot_reg_tensor, dot_lds_read_window); } if constexpr(is_epilogue) { @@ -634,7 +634,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); qt_lds_read_window.set_bottom_tensor_view_data_ptr(q_lds_ptr_curr); - qt_reg_tensor = load_tile_transpose(qt_lds_read_window); + load_tile_transpose(qt_reg_tensor, qt_lds_read_window); // STAGE 3, P^T@OGrad^T Gemm1 auto pt_reg_tensor = make_static_distributed_tensor( @@ -715,7 +715,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR } if constexpr(is_epilogue) { - ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } if constexpr(is_main_body) @@ -728,7 +728,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadKRKTRVR static_for<0, k4_loops, 1>{}([&](auto i_k4) { if constexpr(i_k4 < k4_loops - 1) { - ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } auto kt_reg_tensor_slice = get_slice_tile( // diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp index 3d21928cedf..37b4ae41a3f 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_trload_qr_qtr_dor.hpp @@ -455,10 +455,10 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR async_load_tile(q_lds_write_window, q_dram_window); async_load_tile(do_lds_write_window, do_dram_window); __builtin_amdgcn_s_waitcnt(0); - qt_reg_tensor = load_tile_transpose(qt_lds_read_window); - q_reg_tensor = load_tile(q_lds_read_window); - dot_reg_tensor = load_tile_transpose(dot_lds_read_window); - do_reg_tensor = load_tile(do_lds_read_window); + load_tile_transpose(qt_reg_tensor, qt_lds_read_window); + q_reg_tensor = load_tile(q_lds_read_window); + load_tile_transpose(dot_reg_tensor, dot_lds_read_window); + do_reg_tensor = load_tile(do_lds_read_window); lse_block_tile = load_tile(lse_dram_window); d_block_tile = load_tile(d_dram_window); @@ -490,9 +490,9 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR async_load_tile(v_lds_write_window, v_dram_window); move_tile_window(v_dram_window, {kN0, 0}); s_waitcnt(); - k_reg_tensor = load_tile(k_lds_read_window); - v_reg_tensor = load_tile(v_lds_read_window); - kt_reg_tensor = load_tile_transpose(kt_lds_read_window); + k_reg_tensor = load_tile(k_lds_read_window); + v_reg_tensor = load_tile(v_lds_read_window); + load_tile_transpose(kt_reg_tensor, kt_lds_read_window); } if constexpr(is_epilogue) { @@ -668,7 +668,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR block_sync_lds(); if constexpr(is_epilogue) { - ds_reg_tensor = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } if constexpr(is_main_body) @@ -680,7 +680,7 @@ struct BlockFmhaBwdDQDKDVPipelineTrLoadQRQTRDOR static_for<0, k4_loops, 1>{}([&](auto i_k4) { if constexpr(i_k4 < k4_loops - 1) { - ds_reg_tensor_next = load_tile_transpose(ds_lds_read_window); + load_tile_transpose(ds_reg_tensor_next, ds_lds_read_window); move_tile_window(ds_lds_read_window, {kK4, 0}); } auto kt_reg_tensor_slice = get_slice_tile( // diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp index c25f57632fa..4cca604ff15 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp @@ -718,7 +718,7 @@ struct BlockFmhaFwdV3Pipeline }; auto V_lds_load = [&](auto v_lds_read_idx) { - kv_tile.v_tile = load_tile_transpose(v_lds_window_load(v_lds_read_idx)); + load_tile_transpose(kv_tile.v_tile, v_lds_window_load(v_lds_read_idx)); }; decltype(m) m_old; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp index 26662dafeb9..3e958ea5317 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp @@ -567,7 +567,7 @@ struct BlockFmhaPipelineQRKSVSAsyncTrload // loop over along the [V]alue Sequence length move_tile_window(v_lds_read_window, {kK1, 0}); - v_tile = load_tile_transpose(v_lds_read_window); + load_tile_transpose(v_tile, v_lds_read_window); }); // move back to the origin move_tile_window(v_lds_read_window, {-kK1 * (k1_loops - 1), 0}); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp index ffac45bed41..358101d1db1 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp @@ -107,7 +107,7 @@ struct GemmPipelineAgBgCrImplBase bool_constant = {}) const { if constexpr(LoadTranspose) - dst_block_tile = load_tile_transpose(lds_tile_window); + load_tile_transpose(dst_block_tile, lds_tile_window); else load_tile(dst_block_tile, lds_tile_window); } From 2edd077b50f552da936f30b4977de0fe6353a405 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 7 Jan 2026 16:19:35 +0000 Subject: [PATCH 19/44] Adjust whitespace with clang-format --- ...ock_universal_gemm_as_aquant_bs_bquant_cr.hpp | 6 ++---- .../gemm_bquant_pipeline_ag_bg_cr_v3.hpp | 4 +--- .../gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp | 16 ++++++++-------- 3 files changed, 11 insertions(+), 15 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp index 12b699a9934..3e16a078e2c 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_bquant_cr.hpp @@ -261,11 +261,9 @@ struct ABQuantBlockUniversalGemmAsBsCr : public BlockGemmQuantBase bool_constant = {}, bool_constant = {}) { - load_and_convert_tile( - a_warp_tile_, a_block_window); + load_and_convert_tile(a_warp_tile_, a_block_window); // If B datatype were pkint4 it would be converted prior to storing in LDS - load_and_convert_tile( - b_warp_tile_, b_block_window); + load_and_convert_tile(b_warp_tile_, b_block_window); } // C += A * B diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index a5ad0816cc6..deebd4d6c63 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -40,9 +40,7 @@ struct BQuantGemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3, - ADataType, - BDataType>; + std::conditional_t, ADataType, BDataType>; static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!"); using I0 = number<0>; diff --git a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp index 3f75cdc01de..e3b82940d64 100644 --- a/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_wp_abquant_pipeline_ag_bg_cr_v2.hpp @@ -349,8 +349,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); // move B window to next flat K @@ -430,8 +430,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -455,8 +455,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( - b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_ping(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); move_tile_window(b_flat_dram_window, {0, BlockGemmShape::flatKPerBlock}); @@ -503,8 +503,8 @@ struct WPABQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRe move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * flatNPerWarp, kIter * flatKPerWarp}); - load_and_convert_tile( - b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); + load_and_convert_tile(b_warp_tensor_pong(nIter)(kIter), + b_flat_dram_windows(nIter)(kIter)); }); }); aq_block_tile_2 = load_tile(aq_copy_dram_window); From c020a4279708fd0e0c9d4f8b915fe8af33993c8f Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 14 Jan 2026 17:44:48 +0000 Subject: [PATCH 20/44] Fix a build break introduced when merging --- .../gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index 4bb7829fb64..b1d66540c57 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -245,7 +245,7 @@ struct UniversalGemmBasePolicy } else // A is in RowMajor { - constexpr index_t KPack = GetSmemPackA(); + constexpr index_t KPack = Derived::template GetSmemPackA(); constexpr auto DataTypeSize = sizeof(ADataType); constexpr uint64_t MinLdsLayer = 1ULL; constexpr auto MLdsLayer = From fc1b683d18c09976b63fd25f9c3f0b6f2779e10c Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 28 Jan 2026 15:37:13 +0000 Subject: [PATCH 21/44] Fix a build break --- .../block/block_universal_gemm_as_aquant_bs_cr.hpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 39cd36029dc..ed5a613d07a 100644 --- a/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -393,10 +393,8 @@ struct AQuantBlockUniversalGemmAsBsCr auto b_lds_gemm_window = make_tile_window( b_block_window.get_bottom_tensor_view(), b_lds_shape, b_offset, b_lds_load_distr); - load_int4_tile( - a_warp_tile_, a_lds_gemm_window); - load_int4_tile( - b_warp_tile_, b_lds_gemm_window); + load_and_convert_tile(a_warp_tile_, a_lds_gemm_window); + load_and_convert_tile(b_warp_tile_, b_lds_gemm_window); } // C += A * B with quantization support From 9185c25573b4930c79ffc69c58d2987d53d97548 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Mon, 12 Jan 2026 09:42:03 +0000 Subject: [PATCH 22/44] Rename the parameters of load_interleaved_pk_type and load_and_convert_tile --- .../ck_tile/ops/common/load_and_convert_tile.hpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/common/load_and_convert_tile.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp index 4e05ecc59c8..82a4f2d9d27 100644 --- a/include/ck_tile/ops/common/load_and_convert_tile.hpp +++ b/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -12,38 +12,38 @@ template struct ConverterLoader { template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src) + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src_window) { static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto tmp = load_tile(src); + const auto src = load_tile(src_window); using DstVectorType = DstDataType __attribute__((ext_vector_type(UnaryOpSize))); static_for<0, thread_buffer_size, 1>{}([&](auto i) { const element_wise::PassThroughPack8 elementwise_op{}; elementwise_op(dst.get_thread_buffer().template get_as()(i), - tmp.get_thread_buffer().template get_as()[i]); + src.get_thread_buffer().template get_as()[i]); }); } }; template -CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src) +CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src_window) { if constexpr(std::is_same_v) { static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); - ConverterLoader::load_interleaved_pk_type(dst, - src); + ConverterLoader::load_interleaved_pk_type( + dst, src_window); } else if constexpr(LoadTranspose) { - load_tile_transpose(dst, src); + load_tile_transpose(dst, src_window); } else { - load_tile(dst, src); + load_tile(dst, src_window); } } From e1b8f6ca769336e923814c63df52c0012627ee10 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Fri, 28 Nov 2025 09:20:19 +0000 Subject: [PATCH 23/44] Add NumAccess as a template parameter to WarpGemmAttributeMfma::get_warp_dstr_encoding --- .../block/block_universal_gemm_as_bs_cr.hpp | 16 +++- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 9 +- .../gemm/warp/warp_gemm_attribute_mfma.hpp | 86 ++++++++++++++----- .../ck_tile/ops/gemm/warp/warp_gemm_impl.hpp | 4 +- 4 files changed, 86 insertions(+), 29 deletions(-) diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 6fb5cf433b1..7f8f7e59977 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -139,6 +139,7 @@ struct BlockUniversalGemmAsBsCr using I0 = number<0>; using I1 = number<1>; + template CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() { constexpr index_t KPerThread = Traits::KPerThread; @@ -158,12 +159,18 @@ struct BlockUniversalGemmAsBsCr tuple>, sequence<1, 2>, sequence<0, 0>>{}; + using Attr = WarpGemm::WarpGemmAttribute; + constexpr auto NumAccessA = + convert ? Attr::AttrNumAccessV * sizeof(ADataType) / sizeof(ComputeDataType) + : Attr::AttrNumAccessV; constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{}); + a_block_outer_dstr_encoding, + WarpGemm::WarpGemmAttribute::template get_awarp_dstr_encoding()); return a_block_dstr_encode; } + template CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { constexpr index_t KPerThread = Traits::KPerThread; @@ -183,8 +190,13 @@ struct BlockUniversalGemmAsBsCr tuple>, sequence<1, 2>, sequence<0, 0>>{}; + using Attr = WarpGemm::WarpGemmAttribute; + constexpr auto NumAccessB = + convert ? Attr::AttrNumAccessV * sizeof(BDataType) / sizeof(ComputeDataType) + : Attr::AttrNumAccessV; constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( - b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + b_block_outer_dstr_encoding, + WarpGemm::WarpGemmAttribute::template get_bwarp_dstr_encoding()); return b_block_dstr_encode; } diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp index 8fae7042037..71e69621f5b 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp @@ -440,10 +440,11 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3 auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); // Tile distribution for load from lds - constexpr auto a_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode()); - constexpr auto b_lds_load_tile_distr = - make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode()); + constexpr bool is_load_tr = is_a_load_tr_v || is_b_load_tr_v; + constexpr auto a_lds_load_tile_distr = make_static_tile_distribution( + BlockGemm::template MakeABlockDistributionEncode()); + constexpr auto b_lds_load_tile_distr = make_static_tile_distribution( + BlockGemm::template MakeBBlockDistributionEncode()); // A DRAM tile window for load // A LDS tile window for store diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 3c7944a4277..0f91e13bf08 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -44,12 +44,12 @@ struct WarpGemmAttributeMfma static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - template + template static constexpr auto get_warp_dstr_encoding() { - static_assert(kKPerThread % AttrNumAccessV == 0, + static_assert(NumAccess != 0 && kKPerThread % NumAccess == 0, "kKPerThread must be divisible by NumAccess"); - if constexpr(AttrNumAccessV == 1) + if constexpr(NumAccess == 1) return tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -61,14 +61,30 @@ struct WarpGemmAttributeMfma return tile_distribution_encoding< sequence<>, tuple, - sequence>, + sequence>, tuple>, tuple>, sequence<2, 2>, sequence<0, 2>>{}; } - using AWarpDstrEncoding = decltype(get_warp_dstr_encoding()); - using BWarpDstrEncoding = decltype(get_warp_dstr_encoding()); + + template + static constexpr auto get_awarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + + template + static constexpr auto get_bwarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -151,14 +167,14 @@ struct WarpGemmAttributeMfmaIterateK static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1, "Multi-block on both M & N directions is not supported"); - template + template CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding() { if constexpr(kMNBlock == 1 && kNMBlock == 1) { static_assert(kKPerThread % AttrNumAccessV == 0, "kKPerThread must be divisible by NumAccess"); - if constexpr(AttrNumAccessV == 1) + if constexpr(NumAccess == 1) return tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -172,7 +188,7 @@ struct WarpGemmAttributeMfmaIterateK tuple, sequence>, + Impl::kABKPerLane * kKIter / NumAccess>>, tuple>, tuple>, sequence<2, 2>, @@ -180,7 +196,7 @@ struct WarpGemmAttributeMfmaIterateK } else if constexpr(kMNBlock == 1 && 1 < kNMBlock) { - static_assert(AttrNumAccessV == 1, + static_assert(NumAccess == 1, "Multiple access is not supported when using multi-block"); // each M/N blocks share the same data return tile_distribution_encoding< @@ -193,7 +209,7 @@ struct WarpGemmAttributeMfmaIterateK } else if constexpr(1 < kMNBlock && kNMBlock == 1) { - static_assert(AttrNumAccessV == 1, + static_assert(NumAccess == 1, "Multiple access is not supported when using multi-block"); // single block to multi-block thread mapping return tile_distribution_encoding< @@ -207,6 +223,18 @@ struct WarpGemmAttributeMfmaIterateK } } + template + CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + + template + CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() + { + return get_warp_dstr_encoding(); + } + CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding() { if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) @@ -245,10 +273,12 @@ struct WarpGemmAttributeMfmaIterateK } } - using AWarpDstrEncoding = - decltype(get_warp_dstr_encoding()); - using BWarpDstrEncoding = - decltype(get_warp_dstr_encoding()); + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); + using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); // c_vec += a_vec * b_vec @@ -327,10 +357,23 @@ struct WarpGemmAttributeMfmaTransposedCDistribution static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); - using AWarpDstrEncoding = - typename WarpGemmAttributeMfma::BWarpDstrEncoding; - using BWarpDstrEncoding = - typename WarpGemmAttributeMfma::AWarpDstrEncoding; + template + CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() + { + return WarpGemmAttributeMfma::template get_bwarp_dstr_encoding(); + } + + template + CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() + { + return WarpGemmAttributeMfma::template get_awarp_dstr_encoding(); + } + + template + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + template + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); using CWarpDstrEncoding = tile_distribution_encoding< sequence<>, @@ -459,8 +502,9 @@ template struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution { - using Impl = remove_cvref_t; - static constexpr auto AttrNumAccess = AttrNumAccess_; + using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = AttrNumAccess_; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); // swap A and B using ADataType = typename Impl::BDataType; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp index ca7c32b6af5..5ff0660f49c 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp @@ -25,8 +25,8 @@ struct WarpGemmImpl using BDataType = typename WarpGemmAttribute::BDataType; using CDataType = typename WarpGemmAttribute::CDataType; - using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding; - using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding; + using AWarpDstrEncoding = typename WarpGemmAttribute::template AWarpDstrEncoding<>; + using BWarpDstrEncoding = typename WarpGemmAttribute::template BWarpDstrEncoding<>; using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding; using AWarpDstr = remove_cvref_t; From 5744562a5b2e8c53355f5018649bba214e4b2aa7 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Thu, 9 Oct 2025 08:07:04 +0000 Subject: [PATCH 24/44] Introduce DetermineWarpPrecType for determining warp GEMM precision types --- include/ck_tile/ops/common.hpp | 1 + .../ops/common/determine_warp_prec_type.hpp | 41 +++++++++++++++++++ .../ops/epilogue/cshuffle_epilogue.hpp | 14 ++----- .../block/block_universal_gemm_as_bs_cr.hpp | 8 +--- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 12 ++---- 5 files changed, 51 insertions(+), 25 deletions(-) create mode 100644 include/ck_tile/ops/common/determine_warp_prec_type.hpp diff --git a/include/ck_tile/ops/common.hpp b/include/ck_tile/ops/common.hpp index ad7da5c1833..0113d8c9a28 100644 --- a/include/ck_tile/ops/common.hpp +++ b/include/ck_tile/ops/common.hpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #pragma once +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/common/determine_warp_prec_type.hpp b/include/ck_tile/ops/common/determine_warp_prec_type.hpp new file mode 100644 index 00000000000..2e28073c9c8 --- /dev/null +++ b/include/ck_tile/ops/common/determine_warp_prec_type.hpp @@ -0,0 +1,41 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" + +// DetermineWarpPrecType is a set of rules to determine the right precision type to use +// for the warp GEMM, given the other precision type. This gives rise to a type conversion: +// type conversions are sometimes needed to obtain a pair of types that are compatible with +// the hardware matrix operations available. A typical use case is mixed precision GEMMs. + +namespace ck_tile { +// For the most general case, we default to no conversion. +template +struct DetermineWarpPrecType +{ + using prec_type = PrecType; +}; + +// For pk_int4_t, we convert to the other precision type. +template +struct DetermineWarpPrecType +{ + using prec_type = OtherPrecType; +}; + +// For pk_fp4_t, we convert to the other precision type. +template +struct DetermineWarpPrecType +{ + using prec_type = OtherPrecType; +}; + +// For pk_fp4_raw_t, we convert to the other precision type. +template +struct DetermineWarpPrecType +{ + using prec_type = OtherPrecType; +}; +}; // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 4f636b59625..36addf625cd 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -7,7 +7,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common/utils.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" #include @@ -92,16 +92,8 @@ struct CShuffleEpilogue using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; - using ATypeToUse = std::conditional_t || - std::is_same_v, - BDataType, - ADataType>; - // Used for weight-only quantization kernel, B would be dequantized to the same data type as A - using BTypeToUse = std::conditional_t || - std::is_same_v || - std::is_same_v, - ADataType, - BDataType>; + using ATypeToUse = typename DetermineWarpPrecType::prec_type; + using BTypeToUse = typename DetermineWarpPrecType::prec_type; using ELayout = remove_cvref_t; using CDElementwise = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 7f8f7e59977..6956bfa8c80 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -94,12 +94,8 @@ struct BlockUniversalGemmAsBsCr using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; - using ATypeToUse = - std::conditional_t, BDataType, ADataType>; - using BTypeToUse = std::conditional_t || - std::is_same_v, - ADataType, - BDataType>; + using ATypeToUse = typename DetermineWarpPrecType::prec_type; + using BTypeToUse = typename DetermineWarpPrecType::prec_type; using WarpGemm = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index fae37010492..c900cd6ca59 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -895,14 +895,10 @@ struct UniversalGemmPipelineAgBgCrPolicy : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad : WGAttrNumAccessEnum::Invalid; - using ADataType = remove_cvref_t; - using BDataType = remove_cvref_t; - using ATypeToUse = - std::conditional_t, BDataType, ADataType>; - using BTypeToUse = std::conditional_t || - std::is_same_v, - ADataType, - BDataType>; + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using ATypeToUse = typename DetermineWarpPrecType::prec_type; + using BTypeToUse = typename DetermineWarpPrecType::prec_type; using WarpGemm = WarpGemmDispatcher Date: Wed, 12 Nov 2025 09:04:15 +0000 Subject: [PATCH 25/44] Add and use load_with_type_convert --- .../ops/common/load_and_convert_tile.hpp | 49 ++++++++++++++++--- 1 file changed, 41 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/ops/common/load_and_convert_tile.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp index 82a4f2d9d27..afb3975acc0 100644 --- a/include/ck_tile/ops/common/load_and_convert_tile.hpp +++ b/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -8,12 +8,13 @@ namespace ck_tile { -template +template struct ConverterLoader { template CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& dst, const WarpWindow& src_window) { + static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; const auto src = load_tile(src_window); @@ -26,6 +27,43 @@ struct ConverterLoader src.get_thread_buffer().template get_as()[i]); }); } + + template + CK_TILE_DEVICE static void load_with_type_convert(WarpTile& dst, const WarpWindow& src_window) + { + if constexpr(LoadTranspose) + { + static_assert(sizeof(SrcDataType) == sizeof(DstDataType), + "SrcDataType and DstDataType must have the same sizes."); + if constexpr(std::is_same_v) + { + dst = load_tile_transpose(src_window); + } + else + { + auto tmp = load_tile_transpose(src_window); + sweep_tile([&](auto i) { + element_wise::PassThrough elementwise_op{}; + elementwise_op(dst(i), tmp(i)); + }); + } + } + else + { + if constexpr(std::is_same_v) + { + load_tile(dst, src_window); + } + else + { + auto tmp = load_tile(src_window); + sweep_tile([&](auto i) { + element_wise::PassThrough elementwise_op{}; + elementwise_op(dst(i), tmp(i)); + }); + } + } + } }; template @@ -33,18 +71,13 @@ CK_TILE_DEVICE void load_and_convert_tile(WarpTile& dst, const WarpWindow& src_w { if constexpr(std::is_same_v) { - static_assert(!LoadTranspose, "LoadTranspose not supported with pk_int4_t"); ConverterLoader::load_interleaved_pk_type( dst, src_window); } - else if constexpr(LoadTranspose) - { - load_tile_transpose(dst, src_window); - } else { - load_tile(dst, src_window); + ConverterLoader:: + template load_with_type_convert(dst, src_window); } } - } // namespace ck_tile From 44fd387896eadcd19d66acef5bf060090996f5c6 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 12 Nov 2025 12:38:04 +0000 Subject: [PATCH 26/44] Add MFMA warp gemm for float, float, float, 32, 32, 16 --- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 6 ++++++ include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp | 2 ++ 2 files changed, 8 insertions(+) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 00512424752..db6eb5d9108 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -17,6 +17,12 @@ namespace ck_tile { using WarpGemmMfmaF32F32F32M16N16K4 = WarpGemmImpl< WarpGemmAttributeMfma>>; +template +using WarpGemmMfmaF32F32F32M32N32K16 = WarpGemmImpl, + 8, + AttrNumAccess>>; + template using WarpGemmMfmaF32F32F32M16N16K16 = WarpGemmImpl, diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index a7d71d4fa3c..2689cb8b0ec 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -37,6 +37,8 @@ template<> struct Dispatcher { using Typ template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M32N32K8<>; }; template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M32N32K16<>; }; +template<> struct Dispatcher { using Type = WarpGemmMfmaF32F32F32M32N32K16; }; // fp16 // ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity template<> struct Dispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8; }; From 926546ce3531bf3668ea79cf2563720f843f170f Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Thu, 9 Oct 2025 09:04:13 +0000 Subject: [PATCH 27/44] Add functionality and tests for bf16 x fp8 and fp8 x bf16 --- .../ops/common/determine_warp_prec_type.hpp | 14 ++++++++++++++ .../gemm/test_gemm_pipeline_kernel_types.hpp | 8 ++++++++ 2 files changed, 22 insertions(+) diff --git a/include/ck_tile/ops/common/determine_warp_prec_type.hpp b/include/ck_tile/ops/common/determine_warp_prec_type.hpp index 2e28073c9c8..8d30f60bdb2 100644 --- a/include/ck_tile/ops/common/determine_warp_prec_type.hpp +++ b/include/ck_tile/ops/common/determine_warp_prec_type.hpp @@ -38,4 +38,18 @@ struct DetermineWarpPrecType { using prec_type = OtherPrecType; }; + +// For fp8 x bf16 or bf16 x fp8, convert fp8 to float +template <> +struct DetermineWarpPrecType +{ + using prec_type = float; +}; + +// For fp8 x bf16 or bf16 x fp8, convert bf16 to float +template <> +struct DetermineWarpPrecType +{ + using prec_type = float; +}; }; // namespace ck_tile diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 4bef581254e..3bfcd9e2c09 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -81,9 +81,11 @@ using KernelTypesCompV3 = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, @@ -91,9 +93,11 @@ using KernelTypesCompV3 = ::testing::Types< std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, @@ -101,9 +105,11 @@ using KernelTypesCompV3 = ::testing::Types< std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, @@ -111,9 +117,11 @@ using KernelTypesCompV3 = ::testing::Types< std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, From 07b103a46cfd2964fcfc03863154cdf0cef44ffc Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 12 Nov 2025 15:09:01 +0000 Subject: [PATCH 28/44] Add functionality and tests for fp16 x fp8 and fp8 x fp16 --- .../ops/common/determine_warp_prec_type.hpp | 14 ++++++++++++++ .../gemm/test_gemm_pipeline_kernel_types.hpp | 8 ++++++++ 2 files changed, 22 insertions(+) diff --git a/include/ck_tile/ops/common/determine_warp_prec_type.hpp b/include/ck_tile/ops/common/determine_warp_prec_type.hpp index 8d30f60bdb2..13094f591c8 100644 --- a/include/ck_tile/ops/common/determine_warp_prec_type.hpp +++ b/include/ck_tile/ops/common/determine_warp_prec_type.hpp @@ -52,4 +52,18 @@ struct DetermineWarpPrecType { using prec_type = float; }; + +// For fp8 x fp16 or fp16 x fp8, convert fp8 to float +template <> +struct DetermineWarpPrecType +{ + using prec_type = float; +}; + +// For fp8 x fp16 or fp16 x fp8, convert fp16 to float +template <> +struct DetermineWarpPrecType +{ + using prec_type = float; +}; }; // namespace ck_tile diff --git a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp index 3bfcd9e2c09..70449b5adf8 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_kernel_types.hpp @@ -79,48 +79,56 @@ using KernelTypesMemWmma = ::testing::Types< using KernelTypesCompV3 = ::testing::Types< std::tuple< Row, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Row, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Row, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Row, Col, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Row, Col, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Row, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Row, Row, BF8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F16, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF16, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF16, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, BF16, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, INT8, INT8, INT32, INT32, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, F8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, + std::tuple< Col, Col, Row, F8, F16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, BF16, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, BF8, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, std::tuple< Col, Col, Row, F8, I4, F32, F16, I256, I256, I64, I32, I32, I16, Intrawave, CompV3>, From f031cc03d018073cf63765b4cbb52a019bf6a975 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Fri, 10 Oct 2025 08:39:04 +0000 Subject: [PATCH 29/44] Add type conversions to V4 pipeline, WIP! --- ...emm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index 777537a83a0..6b1e9f2871f 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" @@ -22,6 +23,10 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; + using ATypeToUse = typename DetermineWarpPrecType::prec_type; + using BTypeToUse = typename DetermineWarpPrecType::prec_type; constexpr index_t vector_size = DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType); @@ -33,8 +38,8 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy : vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad : WGAttrNumAccessEnum::Invalid; - using WarpGemm = WarpGemmDispatcher; - using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy; From 34e191307f56d9a9b597ecbfc82c519e991b6aea Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Thu, 18 Dec 2025 09:14:11 +0000 Subject: [PATCH 30/44] Refactor type conversions out of MakeBLdsBlockDescriptor, WIP! --- .../gemm_universal_pipeline_ag_bg_cr_policy.hpp | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index c900cd6ca59..e68e2330ee2 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -302,15 +302,12 @@ struct UniversalGemmBasePolicy * @tparam Problem Gemm pipeline problem. * @return B tensor LDS block descriptor. */ - template + template > CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor() { - using BLayout = remove_cvref_t; - using BDataType = - std::conditional_t, - typename Problem::ADataType, - typename Problem::BDataType>; - + using BLayout = remove_cvref_t; + using BDataType = OverrideBDataType; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; From 068039a24f58ac37c8263b9563c22c356ee45569 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Mon, 26 Jan 2026 09:26:59 +0000 Subject: [PATCH 31/44] Add and use load_tile_transpose_convert for mixed precision transpose loading --- .../core/tensor/load_tile_transpose.hpp | 126 ++++++++++++++++++ .../ops/common/load_and_convert_tile.hpp | 10 +- 2 files changed, 128 insertions(+), 8 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index a0756ad21a3..7ac6793296b 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -343,6 +343,14 @@ template ; +// Mixed-precision policy that allows different input and output types +template +struct MixedPrecisionTranspose : public DefaultTranspose +{ + // Inherits quad pattern validation from input type + // but allows output type to differ +}; + template , + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE void load_tile_transpose_convert_with_offset( + DistributedTensor_& out_tensor, + const tile_window_with_static_distribution& __restrict__ tile_window, + index_t offset) +{ + using InputDataType = typename BottomTensorView_::DataType; + using OutputDataType = typename DistributedTensor_::DataType; + + auto trans_tensor = tile_window.template load_transpose_with_offset(offset); + constexpr auto input_distr = TileDistribution_{}; + constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{}; + + constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor(); + constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor(); + + constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y(); + //constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y(); + + constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths()); + constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths()); + + constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size(); + constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size(); + + // For mixed precision: element space size must be the same (total bytes match) + static_assert(y_in_element_space_size == y_out_element_space_size, + "For mixed precision transpose, input and output element space size must match!"); + + // Allow different vector lengths (e.g., fp8 may vectorize 8 elems, fp16 may vectorize 4). + // Ensure total element counts are consistent and divisible by the input vector length. + constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1]; + constexpr index_t total_elems_in = reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{}); + constexpr index_t total_elems_out = reduce_on_sequence(y_out_lengths, multiplies<>{}, number<1>{}); + static_assert(total_elems_in == total_elems_out, + "For mixed precision transpose, input/output element counts must match!"); + static_assert(total_elems_in % vecLoadSize == 0, + "Input vector length must evenly divide total elements."); + + constexpr index_t num_of_access = total_elems_in / vecLoadSize; + + // Read as input type, convert to output type + using InputDataVec = array; + static_for<0, num_of_access, 1>{}([&](auto iAccess) { + auto input_vec = trans_tensor.get_thread_buffer().template get_as( + number{}); + + // Element-wise type conversion + // This will be unrolled by the compiler for each element in the vector + static_for<0, vecLoadSize, 1>{}([&](auto iElem) { + auto output_elem = type_convert(input_vec[iElem]); + out_tensor.get_thread_buffer()[number{}] = + output_elem; + }); + }); +} + +/** + * @brief Mixed-precision transpose load with zero offset. + * + * Convenience wrapper for load_tile_transpose_convert_with_offset with offset=0. + */ +template < + typename DistributedTensor_, + typename BottomTensorView_, + typename WindowLengths_, + typename TileDistribution_, + index_t NumCoord, + typename Policy = DefaultTranspose, + typename = std::enable_if_t::distr_encoding_valid, + Policy>> +CK_TILE_DEVICE void +load_tile_transpose_convert( + DistributedTensor_& out_tensor, + const tile_window_with_static_distribution& __restrict__ tile_window) +{ + load_tile_transpose_convert_with_offset(out_tensor, tile_window, 0); +} + } // namespace ck_tile diff --git a/include/ck_tile/ops/common/load_and_convert_tile.hpp b/include/ck_tile/ops/common/load_and_convert_tile.hpp index afb3975acc0..ee315b9f61a 100644 --- a/include/ck_tile/ops/common/load_and_convert_tile.hpp +++ b/include/ck_tile/ops/common/load_and_convert_tile.hpp @@ -33,19 +33,13 @@ struct ConverterLoader { if constexpr(LoadTranspose) { - static_assert(sizeof(SrcDataType) == sizeof(DstDataType), - "SrcDataType and DstDataType must have the same sizes."); if constexpr(std::is_same_v) { - dst = load_tile_transpose(src_window); + load_tile_transpose(dst, src_window); } else { - auto tmp = load_tile_transpose(src_window); - sweep_tile([&](auto i) { - element_wise::PassThrough elementwise_op{}; - elementwise_op(dst(i), tmp(i)); - }); + load_tile_transpose_convert(dst, src_window); } } else From bc08c31812b39562f858c641fb9e67d46ed4e968 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Mon, 26 Jan 2026 14:59:28 +0000 Subject: [PATCH 32/44] Restrict the range of FillUniformDistributionIntegerValue for A and B to make tests pass --- test/ck_tile/gemm/test_gemm_pipeline_util.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp index 8dc2e884302..a3da751a82d 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm/test_gemm_pipeline_util.hpp @@ -299,8 +299,8 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::HostTensor c_m_n_dev_result( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11939}(a_m_k); - ck_tile::FillUniformDistributionIntegerValue{-5, 5, 11940}(b_k_n); + ck_tile::FillUniformDistributionIntegerValue{-0.5, 0.5, 11939}(a_m_k); + ck_tile::FillUniformDistributionIntegerValue{-0.5, 0.5, 11940}(b_k_n); ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); From 89ab89dc5c936e7ee83325083d0014e40a823ce9 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 28 Jan 2026 10:05:18 +0000 Subject: [PATCH 33/44] Switch to an implementation of DetermineWarpPrecType that explicitly defines the A and B types - This is for improved clarity and finer control of the datatypes to use --- .../ops/common/determine_warp_prec_type.hpp | 78 +++++++++++++------ .../ops/epilogue/cshuffle_epilogue.hpp | 4 +- .../block/block_universal_gemm_as_bs_cr.hpp | 4 +- ...peline_ag_bg_cr_comp_v4_default_policy.hpp | 6 +- ...emm_universal_pipeline_ag_bg_cr_policy.hpp | 4 +- 5 files changed, 64 insertions(+), 32 deletions(-) diff --git a/include/ck_tile/ops/common/determine_warp_prec_type.hpp b/include/ck_tile/ops/common/determine_warp_prec_type.hpp index 13094f591c8..866d0635ab9 100644 --- a/include/ck_tile/ops/common/determine_warp_prec_type.hpp +++ b/include/ck_tile/ops/common/determine_warp_prec_type.hpp @@ -11,59 +11,91 @@ // the hardware matrix operations available. A typical use case is mixed precision GEMMs. namespace ck_tile { -// For the most general case, we default to no conversion. -template +// For the most general case, default to no conversion. +template struct DetermineWarpPrecType { - using prec_type = PrecType; + using a_prec_type = APrecType; + using b_prec_type = BPrecType; }; -// For pk_int4_t, we convert to the other precision type. -template -struct DetermineWarpPrecType +// For pk_int4_t x B, use the B type. +template +struct DetermineWarpPrecType { - using prec_type = OtherPrecType; + using a_prec_type = BPrecType; + using b_prec_type = BPrecType; }; -// For pk_fp4_t, we convert to the other precision type. -template -struct DetermineWarpPrecType +// For A x pk_int4_t, use the A type. +template +struct DetermineWarpPrecType { - using prec_type = OtherPrecType; + using a_prec_type = APrecType; + using b_prec_type = APrecType; }; -// For pk_fp4_raw_t, we convert to the other precision type. -template -struct DetermineWarpPrecType +// For pk_fp4_t x B, use the B type. +template +struct DetermineWarpPrecType { - using prec_type = OtherPrecType; + using a_prec_type = BPrecType; + using b_prec_type = BPrecType; }; -// For fp8 x bf16 or bf16 x fp8, convert fp8 to float +// For A x pk_fp4_t, use the A type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = APrecType; + using b_prec_type = APrecType; +}; + +// For B x pk_fp4_raw_t, use the B type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = BPrecType; + using b_prec_type = BPrecType; +}; + +// For A x pk_fp4_raw_t, use the A type. +template +struct DetermineWarpPrecType +{ + using a_prec_type = APrecType; + using b_prec_type = APrecType; +}; + +// For fp8 x bf16, use fp8 template <> struct DetermineWarpPrecType { - using prec_type = float; + using a_prec_type = ck_tile::fp8_t; + using b_prec_type = ck_tile::fp8_t; }; -// For fp8 x bf16 or bf16 x fp8, convert bf16 to float +// For bf16 x fp8, use bf16 template <> struct DetermineWarpPrecType { - using prec_type = float; + using a_prec_type = ck_tile::bf16_t; + using b_prec_type = ck_tile::bf16_t; }; -// For fp8 x fp16 or fp16 x fp8, convert fp8 to float +// For fp8 x fp16, use fp8 template <> struct DetermineWarpPrecType { - using prec_type = float; + using a_prec_type = ck_tile::fp8_t; + using b_prec_type = ck_tile::fp8_t; }; -// For fp8 x fp16 or fp16 x fp8, convert fp16 to float +// For fp16 x fp8, use fp16 template <> struct DetermineWarpPrecType { - using prec_type = float; + using a_prec_type = ck_tile::half_t; + using b_prec_type = ck_tile::half_t; }; }; // namespace ck_tile diff --git a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp index 36addf625cd..10a64be1731 100644 --- a/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp @@ -92,8 +92,8 @@ struct CShuffleEpilogue using ADataType = remove_cvref_t{}, AsDataTypeTuple>>; using BDataType = remove_cvref_t{}, BsDataTypeTuple>>; - using ATypeToUse = typename DetermineWarpPrecType::prec_type; - using BTypeToUse = typename DetermineWarpPrecType::prec_type; + using ATypeToUse = typename DetermineWarpPrecType::a_prec_type; + using BTypeToUse = typename DetermineWarpPrecType::b_prec_type; using ELayout = remove_cvref_t; using CDElementwise = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 6956bfa8c80..7c46bec3dfe 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -94,8 +94,8 @@ struct BlockUniversalGemmAsBsCr using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; - using ATypeToUse = typename DetermineWarpPrecType::prec_type; - using BTypeToUse = typename DetermineWarpPrecType::prec_type; + using ATypeToUse = typename DetermineWarpPrecType::a_prec_type; + using BTypeToUse = typename DetermineWarpPrecType::b_prec_type; using WarpGemm = remove_cvref_t; diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp index 6b1e9f2871f..583d94a9f12 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp @@ -24,9 +24,9 @@ struct GemmPipelineAgBgCrCompV4DefaultPolicy using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; using ATypeToUse = typename DetermineWarpPrecType::prec_type; - using BTypeToUse = typename DetermineWarpPrecType::prec_type; + typename Problem::BDataType>::a_prec_type; + using BTypeToUse = typename DetermineWarpPrecType::b_prec_type; constexpr index_t vector_size = DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType); diff --git a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp index e68e2330ee2..f9bdb2e125e 100644 --- a/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp @@ -894,8 +894,8 @@ struct UniversalGemmPipelineAgBgCrPolicy using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; - using ATypeToUse = typename DetermineWarpPrecType::prec_type; - using BTypeToUse = typename DetermineWarpPrecType::prec_type; + using ATypeToUse = typename DetermineWarpPrecType::a_prec_type; + using BTypeToUse = typename DetermineWarpPrecType::b_prec_type; using WarpGemm = WarpGemmDispatcher Date: Wed, 28 Jan 2026 14:23:26 +0000 Subject: [PATCH 34/44] Formatting changes --- .../core/tensor/load_tile_transpose.hpp | 22 +++++++++---------- .../gemm_pipeline_ag_bg_cr_comp_v3.hpp | 2 +- .../gemm/warp/warp_gemm_attribute_mfma.hpp | 11 +++++++--- 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/include/ck_tile/core/tensor/load_tile_transpose.hpp b/include/ck_tile/core/tensor/load_tile_transpose.hpp index 7ac6793296b..b49861fc410 100644 --- a/include/ck_tile/core/tensor/load_tile_transpose.hpp +++ b/include/ck_tile/core/tensor/load_tile_transpose.hpp @@ -575,8 +575,8 @@ CK_TILE_DEVICE void load_tile_transpose_convert_with_offset( constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor(); constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor(); - constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y(); - //constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y(); + constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y(); + // constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y(); constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths()); constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths()); @@ -590,9 +590,11 @@ CK_TILE_DEVICE void load_tile_transpose_convert_with_offset( // Allow different vector lengths (e.g., fp8 may vectorize 8 elems, fp16 may vectorize 4). // Ensure total element counts are consistent and divisible by the input vector length. - constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1]; - constexpr index_t total_elems_in = reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{}); - constexpr index_t total_elems_out = reduce_on_sequence(y_out_lengths, multiplies<>{}, number<1>{}); + constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1]; + constexpr index_t total_elems_in = + reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{}); + constexpr index_t total_elems_out = + reduce_on_sequence(y_out_lengths, multiplies<>{}, number<1>{}); static_assert(total_elems_in == total_elems_out, "For mixed precision transpose, input/output element counts must match!"); static_assert(total_elems_in % vecLoadSize == 0, @@ -603,15 +605,14 @@ CK_TILE_DEVICE void load_tile_transpose_convert_with_offset( // Read as input type, convert to output type using InputDataVec = array; static_for<0, num_of_access, 1>{}([&](auto iAccess) { - auto input_vec = trans_tensor.get_thread_buffer().template get_as( - number{}); + auto input_vec = + trans_tensor.get_thread_buffer().template get_as(number{}); // Element-wise type conversion // This will be unrolled by the compiler for each element in the vector static_for<0, vecLoadSize, 1>{}([&](auto iElem) { auto output_elem = type_convert(input_vec[iElem]); - out_tensor.get_thread_buffer()[number{}] = - output_elem; + out_tensor.get_thread_buffer()[number{}] = output_elem; }); }); } @@ -632,8 +633,7 @@ template < typename BottomTensorView_::DataType, Policy>::distr_encoding_valid, Policy>> -CK_TILE_DEVICE void -load_tile_transpose_convert( +CK_TILE_DEVICE void load_tile_transpose_convert( DistributedTensor_& out_tensor, const tile_window_with_static_distribution auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem); // Tile distribution for load from lds - constexpr bool is_load_tr = is_a_load_tr_v || is_b_load_tr_v; + constexpr bool is_load_tr = is_a_load_tr_v || is_b_load_tr_v; constexpr auto a_lds_load_tile_distr = make_static_tile_distribution( BlockGemm::template MakeABlockDistributionEncode()); constexpr auto b_lds_load_tile_distr = make_static_tile_distribution( diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 0f91e13bf08..c37b09d3e66 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -167,7 +167,10 @@ struct WarpGemmAttributeMfmaIterateK static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1, "Multi-block on both M & N directions is not supported"); - template + template CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding() { if constexpr(kMNBlock == 1 && kNMBlock == 1) @@ -360,13 +363,15 @@ struct WarpGemmAttributeMfmaTransposedCDistribution template CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() { - return WarpGemmAttributeMfma::template get_bwarp_dstr_encoding(); + return WarpGemmAttributeMfma::template get_bwarp_dstr_encoding(); } template CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() { - return WarpGemmAttributeMfma::template get_awarp_dstr_encoding(); + return WarpGemmAttributeMfma::template get_awarp_dstr_encoding(); } template From f35688cf2ad9d6c440f5ae4bdb8d69144122d814 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Wed, 28 Jan 2026 14:24:47 +0000 Subject: [PATCH 35/44] Add a changelog entry --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c99fc1d0657..9d7dd9d97bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.2.0 ### Added +* Added support for fp16 x fp8, bf16 x fp8, fp8 x fp16, and fp8 x bf16 for the V3 pipeline * Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4 * Added CK-Tile dispatcher - a unified kernel dispatch, code generation and architecture-based kernel filtering system with with C++ and Python frontends starting with GEMM support. * Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle. From 2848c213a971e3449539a810a597089e7775fe89 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Thu, 29 Jan 2026 08:03:57 +0000 Subject: [PATCH 36/44] Add include statements added by remod.py --- include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | 1 + include/ck_tile/ops/batched_contraction.hpp | 1 + include/ck_tile/ops/batched_transpose.hpp | 1 + include/ck_tile/ops/elementwise.hpp | 1 + include/ck_tile/ops/epilogue.hpp | 1 + include/ck_tile/ops/flatmm.hpp | 1 + include/ck_tile/ops/fmha.hpp | 1 + include/ck_tile/ops/fused_moe.hpp | 1 + include/ck_tile/ops/gemm.hpp | 1 + include/ck_tile/ops/gemm_quant.hpp | 1 + include/ck_tile/ops/grouped_convolution.hpp | 1 + include/ck_tile/ops/image_to_column.hpp | 1 + include/ck_tile/ops/layernorm2d.hpp | 1 + include/ck_tile/ops/norm_reduce.hpp | 1 + include/ck_tile/ops/permute.hpp | 1 + include/ck_tile/ops/pooling.hpp | 1 + include/ck_tile/ops/reduce.hpp | 1 + include/ck_tile/ops/rmsnorm2d.hpp | 1 + include/ck_tile/ops/smoothquant.hpp | 1 + include/ck_tile/ops/softmax.hpp | 1 + include/ck_tile/ops/topk.hpp | 1 + include/ck_tile/ops/topk_softmax.hpp | 1 + 22 files changed, 22 insertions(+) diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index aa0f632c216..a62bbe981cc 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -7,6 +7,7 @@ #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/batched_contraction.hpp b/include/ck_tile/ops/batched_contraction.hpp index 9c90db67edd..71919b61873 100644 --- a/include/ck_tile/ops/batched_contraction.hpp +++ b/include/ck_tile/ops/batched_contraction.hpp @@ -5,6 +5,7 @@ #include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp" #include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp" #include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/batched_transpose.hpp b/include/ck_tile/ops/batched_transpose.hpp index 9cac035c445..924db5fb60e 100644 --- a/include/ck_tile/ops/batched_transpose.hpp +++ b/include/ck_tile/ops/batched_transpose.hpp @@ -10,6 +10,7 @@ #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp" #include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/elementwise.hpp b/include/ck_tile/ops/elementwise.hpp index bc72f3b0ba1..2c0ae4ad093 100644 --- a/include/ck_tile/ops/elementwise.hpp +++ b/include/ck_tile/ops/elementwise.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp" #include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp" #include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/epilogue.hpp b/include/ck_tile/ops/epilogue.hpp index d1b38a8bca6..0eb9e59e723 100644 --- a/include/ck_tile/ops/epilogue.hpp +++ b/include/ck_tile/ops/epilogue.hpp @@ -10,6 +10,7 @@ #include "ck_tile/ops/epilogue/default_2d_and_dynamic_quant_epilogue.hpp" #include "ck_tile/ops/epilogue/default_2d_epilogue.hpp" #include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/flatmm.hpp b/include/ck_tile/ops/flatmm.hpp index e08fac48c7e..2e71957ac77 100644 --- a/include/ck_tile/ops/flatmm.hpp +++ b/include/ck_tile/ops/flatmm.hpp @@ -21,6 +21,7 @@ #include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/flatmm/pipeline/mx_flatmm_pipeline_agmem_bgmem_creg_v1_policy.hpp" #include "ck_tile/ops/flatmm/pipeline/tile_flatmm_shape.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/fmha.hpp b/include/ck_tile/ops/fmha.hpp index 0639fa1b36e..2068dfeefe0 100644 --- a/include/ck_tile/ops/fmha.hpp +++ b/include/ck_tile/ops/fmha.hpp @@ -60,6 +60,7 @@ #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp" #include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/fused_moe.hpp b/include/ck_tile/ops/fused_moe.hpp index 60f5bd1c4e3..2eb4abd6411 100644 --- a/include/ck_tile/ops/fused_moe.hpp +++ b/include/ck_tile/ops/fused_moe.hpp @@ -14,6 +14,7 @@ #include "ck_tile/ops/fused_moe/pipeline/fused_moegemm_traits.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp" #include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/gemm.hpp b/include/ck_tile/ops/gemm.hpp index 8dbf111048e..abfafb6bb4b 100644 --- a/include/ck_tile/ops/gemm.hpp +++ b/include/ck_tile/ops/gemm.hpp @@ -76,6 +76,7 @@ #include "ck_tile/ops/gemm/warp/warp_gemm_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_smfmac_impl.hpp" #include "ck_tile/ops/gemm/warp/warp_wmma_gemm.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp index 6aee73cda1d..df4a7c79778 100644 --- a/include/ck_tile/ops/gemm_quant.hpp +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -30,6 +30,7 @@ #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp" #include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp" #include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/grouped_convolution.hpp b/include/ck_tile/ops/grouped_convolution.hpp index eeb9b1d8a81..5dafc554203 100644 --- a/include/ck_tile/ops/grouped_convolution.hpp +++ b/include/ck_tile/ops/grouped_convolution.hpp @@ -11,6 +11,7 @@ #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_data_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_bwd_weight_to_gemm.hpp" #include "ck_tile/ops/grouped_convolution/utils/transform_conv_fwd_to_gemm.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/image_to_column.hpp b/include/ck_tile/ops/image_to_column.hpp index 07d99890869..faa165a8b00 100644 --- a/include/ck_tile/ops/image_to_column.hpp +++ b/include/ck_tile/ops/image_to_column.hpp @@ -5,6 +5,7 @@ #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp" #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp" #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/layernorm2d.hpp b/include/ck_tile/ops/layernorm2d.hpp index 8f9ab205ac4..2266c138729 100644 --- a/include/ck_tile/ops/layernorm2d.hpp +++ b/include/ck_tile/ops/layernorm2d.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/norm_reduce.hpp b/include/ck_tile/ops/norm_reduce.hpp index eae0ea14a33..9f572ff5cbb 100644 --- a/include/ck_tile/ops/norm_reduce.hpp +++ b/include/ck_tile/ops/norm_reduce.hpp @@ -5,6 +5,7 @@ #include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp" #include "ck_tile/ops/norm_reduce/block/block_norm_reduce_problem.hpp" #include "ck_tile/ops/norm_reduce/thread/thread_welford.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/permute.hpp b/include/ck_tile/ops/permute.hpp index 4d37f4fbc12..c7747a67e70 100644 --- a/include/ck_tile/ops/permute.hpp +++ b/include/ck_tile/ops/permute.hpp @@ -4,6 +4,7 @@ #include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp" #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/pooling.hpp b/include/ck_tile/ops/pooling.hpp index faa77d53273..43b24c7f8ca 100644 --- a/include/ck_tile/ops/pooling.hpp +++ b/include/ck_tile/ops/pooling.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/pooling/pipeline/pool_default_policy.hpp" #include "ck_tile/ops/pooling/pipeline/pool_problem.hpp" #include "ck_tile/ops/pooling/pipeline/pool_shape.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index b5e53283e48..e680d257453 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -13,6 +13,7 @@ #include "ck_tile/ops/reduce/pipeline/reduce2d_default_policy.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_problem.hpp" #include "ck_tile/ops/reduce/pipeline/reduce2d_shape.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index f271be50068..7ee67334d6b 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -9,6 +9,7 @@ #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp index 4c2fe9bee43..ad984c033f0 100644 --- a/include/ck_tile/ops/smoothquant.hpp +++ b/include/ck_tile/ops/smoothquant.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp" #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp" #include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/softmax.hpp b/include/ck_tile/ops/softmax.hpp index c79ba06abfe..b810a57dda0 100644 --- a/include/ck_tile/ops/softmax.hpp +++ b/include/ck_tile/ops/softmax.hpp @@ -4,6 +4,7 @@ #include "ck_tile/ops/softmax/block/block_softmax_2d.hpp" #include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/topk.hpp b/include/ck_tile/ops/topk.hpp index 474ba932270..13d818174e6 100644 --- a/include/ck_tile/ops/topk.hpp +++ b/include/ck_tile/ops/topk.hpp @@ -4,6 +4,7 @@ #include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp" #include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" diff --git a/include/ck_tile/ops/topk_softmax.hpp b/include/ck_tile/ops/topk_softmax.hpp index 066fbf5feea..b7219511faa 100644 --- a/include/ck_tile/ops/topk_softmax.hpp +++ b/include/ck_tile/ops/topk_softmax.hpp @@ -6,6 +6,7 @@ #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp" #include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp" +#include "ck_tile/ops/common/determine_warp_prec_type.hpp" #include "ck_tile/ops/common/generic_2d_block_shape.hpp" #include "ck_tile/ops/common/load_and_convert_tile.hpp" #include "ck_tile/ops/common/streamk_common.hpp" From 154602051e8d223023bf9c4ef07eccc1c1396288 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Thu, 29 Jan 2026 11:49:38 +0000 Subject: [PATCH 37/44] fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma::get_warp_dstr_encoding --- include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index c37b09d3e66..2f2c85e9d90 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -432,6 +432,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + template using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -440,6 +441,7 @@ struct WarpGemmAttributeMfmaTransposedCDistribution_SwizzleB sequence<2>, sequence<1>>; #if 0 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple>; #else // TODO: more test not only 32x32 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple Date: Thu, 29 Jan 2026 12:16:47 +0000 Subject: [PATCH 38/44] fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma::get_warp_dstr_encoding --- include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 2f2c85e9d90..6182490f4ee 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -781,6 +781,7 @@ struct WarpGemmAttributeMfmaIterateK_SwizzleA static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + template using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence<1>>; + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, From 3aec759e04dea0ab2459a9c3443ea34efcd903ef Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Thu, 29 Jan 2026 12:21:35 +0000 Subject: [PATCH 39/44] fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma::get_warp_dstr_encoding --- .../ops/gemm/warp/warp_gemm_attribute_mfma.hpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 6182490f4ee..d752d2818bf 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -573,10 +573,12 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution } } - using AWarpDstrEncoding = - typename WarpGemmAttributeMfmaIterateK::BWarpDstrEncoding; - using BWarpDstrEncoding = - typename WarpGemmAttributeMfmaIterateK::AWarpDstrEncoding; + template + using AWarpDstrEncoding = typename WarpGemmAttributeMfmaIterateK:: + template BWarpDstrEncoding; + template + using BWarpDstrEncoding = typename WarpGemmAttributeMfmaIterateK:: + template AWarpDstrEncoding; using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); // c_vec += a_vec * b_vec @@ -655,6 +657,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1, "Multi-block WarpGemmAttributeMfmaImpl is not supported"); + template using AWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple, sequence>, @@ -663,6 +666,7 @@ struct WarpGemmAttributeMfmaIterateKAndTransposedCDistribution_SwizzleB sequence<2>, sequence<1>>; #if 0 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple>; #else // TODO: more test not only 32x32 + template using BWarpDstrEncoding = tile_distribution_encoding< sequence<>, tuple Date: Thu, 29 Jan 2026 14:55:39 +0000 Subject: [PATCH 40/44] fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma::get_warp_dstr_encoding --- .../ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index 7c46bec3dfe..847d6c782a8 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -155,7 +155,7 @@ struct BlockUniversalGemmAsBsCr tuple>, sequence<1, 2>, sequence<0, 0>>{}; - using Attr = WarpGemm::WarpGemmAttribute; + using Attr = typename WarpGemm::WarpGemmAttribute; constexpr auto NumAccessA = convert ? Attr::AttrNumAccessV * sizeof(ADataType) / sizeof(ComputeDataType) : Attr::AttrNumAccessV; @@ -186,7 +186,7 @@ struct BlockUniversalGemmAsBsCr tuple>, sequence<1, 2>, sequence<0, 0>>{}; - using Attr = WarpGemm::WarpGemmAttribute; + using Attr = typename WarpGemm::WarpGemmAttribute; constexpr auto NumAccessB = convert ? Attr::AttrNumAccessV * sizeof(BDataType) / sizeof(ComputeDataType) : Attr::AttrNumAccessV; From 67b5da44a19473484c675bb551088167fd9782ab Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Fri, 30 Jan 2026 09:14:03 +0000 Subject: [PATCH 41/44] fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma::get_warp_dstr_encoding --- .../ops/gemm/warp/warp_gemm_attribute_wmma.hpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index ef31d06c9c2..6bccc1e816a 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -67,6 +67,8 @@ template struct WarpGemmAttributeWmma { using Impl = remove_cvref_t; + static constexpr auto AttrNumAccess = WGAttrNumAccessEnum::Single; + static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); // When kTransC is true and A/B types differ, we need an impl with swapped types using TransposedImpl = @@ -99,7 +101,21 @@ struct WarpGemmAttributeWmma // 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2 // 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4 + template + static constexpr auto get_awarp_dstr_encoding() + { + return typename AWarpDstrEncodingTrait::type{}; + } + + template + static constexpr auto get_bwarp_dstr_encoding() + { + return typename BWarpDstrEncodingTrait::type{}; + } + + template using AWarpDstrEncoding = typename AWarpDstrEncodingTrait::type; + template using BWarpDstrEncoding = typename BWarpDstrEncodingTrait::type; // kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16 From b6b3df4c012501d5c39b2ed63f8e8254850a38b6 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Fri, 30 Jan 2026 09:17:25 +0000 Subject: [PATCH 42/44] fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma::get_warp_dstr_encoding --- include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index 6bccc1e816a..63bd822291f 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -114,9 +114,9 @@ struct WarpGemmAttributeWmma } template - using AWarpDstrEncoding = typename AWarpDstrEncodingTrait::type; + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); template - using BWarpDstrEncoding = typename BWarpDstrEncodingTrait::type; + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); // kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16 using CWarpDstrEncoding = From 1ec399c72ac9dbc7ef07c339c3095498c09db6f1 Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Fri, 30 Jan 2026 09:21:46 +0000 Subject: [PATCH 43/44] fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma::get_warp_dstr_encoding --- include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp index 63bd822291f..4a275848b2b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp @@ -67,6 +67,8 @@ template struct WarpGemmAttributeWmma { using Impl = remove_cvref_t; + // AttrNumAccessV is required for compatibility with the block GEMM, and is currently ignored + // within WarpGemmAttributeWmma static constexpr auto AttrNumAccess = WGAttrNumAccessEnum::Single; static constexpr auto AttrNumAccessV = static_cast(AttrNumAccess); From c1e328a78d4b52e6e777386b842265478f8b34ab Mon Sep 17 00:00:00 2001 From: Sami Aario Date: Fri, 30 Jan 2026 15:30:01 +0000 Subject: [PATCH 44/44] fixup! Switch to an implementation of DetermineWarpPrecType that explicitly defines the A and B types --- include/ck_tile/ops/common/determine_warp_prec_type.hpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/include/ck_tile/ops/common/determine_warp_prec_type.hpp b/include/ck_tile/ops/common/determine_warp_prec_type.hpp index 866d0635ab9..ae11ff13146 100644 --- a/include/ck_tile/ops/common/determine_warp_prec_type.hpp +++ b/include/ck_tile/ops/common/determine_warp_prec_type.hpp @@ -19,6 +19,14 @@ struct DetermineWarpPrecType using b_prec_type = BPrecType; }; +// For pk_fp4_t x pk_fp4_t, keep pk_fp4_t +template <> +struct DetermineWarpPrecType +{ + using a_prec_type = ck_tile::pk_fp4_t; + using b_prec_type = ck_tile::pk_fp4_t; +}; + // For pk_int4_t x B, use the B type. template struct DetermineWarpPrecType