Draft
54 commits
4d77856
Make some functions return void explicitly instead of auto
SamiAario-AMD Dec 6, 2025
9691569
Use decltype for consistency in Interwave variant of BlockGemmImpl
SamiAario-AMD Nov 21, 2025
bda5a7a
Add braces
SamiAario-AMD Nov 19, 2025
825d17c
Fix a comment
SamiAario-AMD Dec 11, 2025
ca71cd7
Reduce the scope of KPack in MakeALdsBlockDescriptor
SamiAario-AMD Dec 17, 2025
994b8f4
Minor refactoring of load_interleaved_pk_type
SamiAario-AMD Nov 12, 2025
74533b4
Rename load_interleaved_pk_type to load_and_convert_tile
SamiAario-AMD Nov 27, 2025
3a094e2
Include ck_tile/core.hpp in load_interleaved_pk_type.hpp for better I…
SamiAario-AMD Nov 26, 2025
cfa11f2
Rename InterleavedPKTypeLoader to ConverterLoader, and load_int4_tile…
SamiAario-AMD Nov 27, 2025
9559a93
Make explicit that the tile window argument to load_tile_with_element…
SamiAario-AMD Dec 12, 2025
9633d3f
In GetAWindows and GetBWindows, use DataType from LDS tensor view
SamiAario-AMD Dec 17, 2025
9af4498
Remove the defaults for SrcDataType and DstDataType in GemmPipelineAg…
SamiAario-AMD Jan 7, 2026
514035e
In BQuantGemmPipelineAgBgCrCompV3, always convert BDatatype pk_int4_t…
SamiAario-AMD Jan 7, 2026
3d55a1e
No need to specify SrcDataType in load_and_convert_tile as WarpWindow…
SamiAario-AMD Dec 16, 2025
63a4559
No need to specify DstDataType in load_and_convert_tile as WarpTile k…
SamiAario-AMD Dec 16, 2025
8fc4030
Add an instance of load_tile_transpose that takes a reference to the …
SamiAario-AMD Jan 2, 2026
3216110
Remove an unused overload of load_tile_transpose_with_offset
SamiAario-AMD Jan 2, 2026
ca17ac3
When possible, use the overload of load_tile_transpose that does not …
SamiAario-AMD Jan 2, 2026
2edd077
Adjust whitespace with clang-format
SamiAario-AMD Jan 7, 2026
b91efe5
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 7, 2026
0a4388d
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 8, 2026
e62c96f
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 8, 2026
ea4e543
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 14, 2026
c020a42
Fix a build break introduced when merging
SamiAario-AMD Jan 14, 2026
35c620e
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 14, 2026
2ab79eb
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 16, 2026
d71fe5b
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 19, 2026
4b26eac
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 21, 2026
72fa29b
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 27, 2026
d0e9dc5
Merge branch 'develop' into LWPCK-3549-cleanups
SamiAario-AMD Jan 28, 2026
fc1b683
Fix a build break
SamiAario-AMD Jan 28, 2026
9185c25
Rename the parameters of load_interleaved_pk_type and load_and_conver…
SamiAario-AMD Jan 12, 2026
e1b8f6c
Add NumAccess as a template parameter to WarpGemmAttributeMfma::get_w…
SamiAario-AMD Nov 28, 2025
5744562
Introduce DetermineWarpPrecType for determining warp GEMM precision t…
SamiAario-AMD Oct 9, 2025
5a05dbf
Add and use load_with_type_convert
SamiAario-AMD Nov 12, 2025
44fd387
Add MFMA warp gemm for float, float, float, 32, 32, 16
SamiAario-AMD Nov 12, 2025
926546c
Add functionality and tests for bf16 x fp8 and fp8 x bf16
SamiAario-AMD Oct 9, 2025
07b103a
Add functionality and tests for fp16 x fp8 and fp8 x fp16
SamiAario-AMD Nov 12, 2025
f031cc0
Add type conversions to V4 pipeline, WIP!
SamiAario-AMD Oct 10, 2025
34e1913
Refactor type conversions out of MakeBLdsBlockDescriptor, WIP!
SamiAario-AMD Dec 18, 2025
068039a
Add and use load_tile_transpose_convert for mixed precision transpose…
SamiAario-AMD Jan 26, 2026
bc08c31
Restrict the range of FillUniformDistributionIntegerValue for A and B…
SamiAario-AMD Jan 26, 2026
89ab89d
Switch to an implementation of DetermineWarpPrecType that explicitly …
SamiAario-AMD Jan 28, 2026
8b97f9f
Formatting changes
SamiAario-AMD Jan 28, 2026
f35688c
Add a changelog entry
SamiAario-AMD Jan 28, 2026
2848c21
Add include statements added by remod.py
SamiAario-AMD Jan 29, 2026
1546020
fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma…
SamiAario-AMD Jan 29, 2026
72c4678
fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma…
SamiAario-AMD Jan 29, 2026
3aec759
fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma…
SamiAario-AMD Jan 29, 2026
447e41d
fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma…
SamiAario-AMD Jan 29, 2026
67b5da4
fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma…
SamiAario-AMD Jan 30, 2026
b6b3df4
fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma…
SamiAario-AMD Jan 30, 2026
1ec399c
fixup! Add NumAccess as a template parameter to WarpGemmAttributeMfma…
SamiAario-AMD Jan 30, 2026
c1e328a
fixup! Switch to an implementation of DetermineWarpPrecType that expl…
SamiAario-AMD Jan 30, 2026
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -25,6 +25,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
## Composable Kernel 1.2.0 for ROCm 7.2.0

### Added
* Added support for fp16 x fp8, bf16 x fp8, fp8 x fp16, and fp8 x bf16 for the V3 pipeline
* Added tests for f8 x bf8 on CompV3, and f8 x bf8 with K_BlockSize 32 on CompV4
* Added CK-Tile dispatcher - a unified kernel dispatch, code generation, and architecture-based kernel filtering system with C++ and Python frontends, starting with GEMM support.
* Added support for bf16 data type to grouped_gemm and grouped_gemm_preshuffle.
18 changes: 9 additions & 9 deletions include/ck_tile/core/tensor/load_tile.hpp
@@ -48,19 +48,19 @@ CK_TILE_DEVICE auto load_tile(const TileWindow_& tile_window,
* and an elementwise function. For each A = A0, A1… AN, the elementwise function
* is additionally applied during a single read.
*/
template <typename TileWindow_,
template <typename... TileWindow_,
typename ElementWise_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile_with_elementwise(const TileWindow_& tile_window,
CK_TILE_DEVICE auto load_tile_with_elementwise(const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
// TODO: Tile windows should works with unknow number of params
// Load element_wise API works only when the input typle is a tuple-tyupe
return tile_window[number<0>{}].load(
tile_window, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
// TODO: Tile windows should work with unknown number of params
// Load element_wise API works only when the input type is a tuple-type
return tile_windows[number<0>{}].load(
tile_windows, elementwise, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

// Per-lane read-offset tweaks allow swizzling patterns not representable by tile_distribution.
@@ -85,12 +85,12 @@ template <typename DistributedTensor_,
typename TileWindow_,
index_t i_access = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
CK_TILE_DEVICE void load_tile(DistributedTensor_& dst_tile,
const TileWindow_& tile_window,
number<i_access> = {},
bool_constant<oob_conditional_check> = {})
{
return tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
tile_window.load(dst_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

/**
@@ -131,7 +131,7 @@ template <typename T,
index_t i_access = -1,
bool oob_conditional_check = true,
bool pre_nop = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
CK_TILE_DEVICE void load_tile_raw(T& tile,
const tile_window_linear<BottomTensorView_,
WindowLengths_,
TileDistribution_,
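The elementwise overload above fuses an op into a single vectorized read of several same-shaped windows (A0, A1, ... AN). A minimal calling sketch, assuming two windows that share one tile distribution and an out-parameter elementwise signature (the callable's exact signature is not shown in this diff, so it is an assumption here):

// Sketch only: fuse an add of A0 and A1 into one vectorized read.
// Assumptions: a0_window and a1_window are tile_window_with_static_distribution
// instances with identical distributions, and the elementwise callable takes
// the destination element first (assumed form, not confirmed by this diff).
template <typename AWindow>
CK_TILE_DEVICE auto load_a0_plus_a1(const AWindow& a0_window, const AWindow& a1_window)
{
    auto add = [](auto& out, const auto& x0, const auto& x1) { out = x0 + x1; };
    // The windows travel as a ck_tile::tuple; the first window supplies the
    // distribution and data type of the returned tensor.
    return ck_tile::load_tile_with_elementwise(ck_tile::make_tuple(a0_window, a1_window), add);
}

The tuple form mirrors the member load in tile_window.hpp below, where tile_windows[number<0>{}] supplies the pre-computed coordinates used for every window in the tuple.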
180 changes: 166 additions & 14 deletions include/ck_tile/core/tensor/load_tile_transpose.hpp
@@ -343,6 +343,14 @@ template <typename TileDistributionEncoding_,
using InputTileDistributionTraits =
TransposeTileDistributionTraits<TileDistributionEncoding_, DataType_, Policy, true>;

// Mixed-precision policy that allows different input and output types
template <typename InputDataType, typename OutputDataType>
struct MixedPrecisionTranspose : public DefaultTranspose<InputDataType>
{
// Inherits quad pattern validation from input type
// but allows output type to differ
};

template <typename InnerEncode,
index_t kLeadIterPerWarp,
index_t kSecondIterPerWarp,
@@ -373,25 +381,27 @@ CK_TILE_HOST_DEVICE constexpr auto InputTileDistributionEncoding()
* element space size and vector length remain consistent between the input and output
* distributions.
*
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
* @tparam BottomTensorView_ The type of the bottom tensor view.
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution.
* @tparam NumCoord The number of coordinates (dimensions).
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
* The last, unnamed template parameter is SFINAE that ensures the tile distribution
* encoding is valid.
*
* @param out_tensor A statically distributed tensor that receives the transposed tile
* data.
* @param tile_window The tile window with static distribution to load and transpose.
* @param offset The offset (in elements) added to the base address before
* indexing.
*
* @note
* - The function uses compile-time checks to ensure the input and output tile distributions
* are compatible in terms of element space size and vector length.
* - The transpose operation is performed according to the specified Policy.
*/
template <
typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
@@ -401,21 +411,17 @@ template <
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE auto load_tile_transpose_with_offset(
CK_TILE_DEVICE void load_tile_transpose_with_offset(
DistributedTensor_& out_tensor,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& __restrict__ tile_window,
index_t offset)
{
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));
auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
constexpr auto input_distr = TileDistribution_{};
constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{};

constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
@@ -442,8 +448,6 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
number<iAccess>{},
trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
});

return out_tensor;
}

/**
@@ -455,23 +459,45 @@ CK_TILE_DEVICE auto load_tile_transpose_with_offset(
* element space size and vector length remain consistent between the input and output
* distributions.
*
* @tparam DistributedTensor_ The type of the tensor containing the transposed tile data.
* @tparam BottomTensorView_ The type of the bottom tensor view.
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution.
* @tparam NumCoord The number of coordinates (dimensions).
* @tparam Policy The transpose policy to use (defaults to DefaultTranspose).
* The last, unnamed template parameter is SFINAE that ensures the tile distribution
* encoding is valid.
*
* @param out_tensor A statically distributed tensor that receives the transposed tile
* data.
* @param tile_window The tile window with static distribution to load and transpose.
*
* @note
* - The function uses compile-time checks to ensure the input and output tile distributions
* are compatible in terms of element space size and vector length.
* - The transpose operation is performed according to the specified Policy.
*/
template <
typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE void
load_tile_transpose(DistributedTensor_& out_tensor,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& __restrict__ tile_window)
{
load_tile_transpose_with_offset(out_tensor, tile_window, 0);
}

template <
typename BottomTensorView_,
typename WindowLengths_,
@@ -488,7 +514,133 @@ load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_
TileDistribution_,
NumCoord>& __restrict__ tile_window)
{
return load_tile_transpose_with_offset(tile_window, 0);
using OutTileDstrEncode = typename OutputTileDistributionTraits<
typename TileDistribution_::DstrEncode,
typename BottomTensorView_::DataType>::TransposedDstrEncode;
auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
make_static_tile_distribution(OutTileDstrEncode{}));

load_tile_transpose_with_offset(out_tensor, tile_window, 0);

return out_tensor;
}

/**
* @brief Mixed-precision transpose load: converts input data type to output data type while
* transposing.
*
* This function enables transposing from one data type (e.g., fp8) to another (e.g., fp16) in a
* single operation. The input tile distribution encoding must be valid for the input data type,
* and the output distribution will be generated based on the output data type.
*
* @tparam DistributedTensor_ The output tensor type with desired output data type.
* @tparam BottomTensorView_ The input tensor view (may have different data type than output).
* @tparam WindowLengths_ The type representing the window lengths.
* @tparam TileDistribution_ The type representing the tile distribution for input.
* @tparam NumCoord The number of coordinates (dimensions).
* @tparam Policy The transpose policy (should validate against input type).
*
* @note
* - Input and output must have equal element space sizes (the Y-space element counts
*   match; the byte counts differ when the data types differ in width).
* - Type conversion is performed element-by-element during the copy.
* - The validation uses the input data type for quad pattern checking.
* - The output distribution is generated based on the output data type.
*/
template <
typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE void load_tile_transpose_convert_with_offset(
DistributedTensor_& out_tensor,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& __restrict__ tile_window,
index_t offset)
{
using InputDataType = typename BottomTensorView_::DataType;
using OutputDataType = typename DistributedTensor_::DataType;

auto trans_tensor = tile_window.template load_transpose_with_offset<Policy>(offset);
constexpr auto input_distr = TileDistribution_{};
constexpr auto output_distr = typename DistributedTensor_::StaticTileDistribution{};

constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();

constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y();
// constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y();

constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());

constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size();
constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();

// For mixed precision: element space size must be the same (element counts match)
static_assert(y_in_element_space_size == y_out_element_space_size,
"For mixed precision transpose, input and output element space size must match!");

// Allow different vector lengths (e.g., fp8 may vectorize 8 elems, fp16 may vectorize 4).
// Ensure total element counts are consistent and divisible by the input vector length.
constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1];
constexpr index_t total_elems_in =
reduce_on_sequence(y_in_lengths, multiplies<>{}, number<1>{});
constexpr index_t total_elems_out =
reduce_on_sequence(y_out_lengths, multiplies<>{}, number<1>{});
static_assert(total_elems_in == total_elems_out,
"For mixed precision transpose, input/output element counts must match!");
static_assert(total_elems_in % vecLoadSize == 0,
"Input vector length must evenly divide total elements.");

constexpr index_t num_of_access = total_elems_in / vecLoadSize;

// Read as input type, convert to output type
using InputDataVec = array<InputDataType, vecLoadSize>;
static_for<0, num_of_access, 1>{}([&](auto iAccess) {
auto input_vec =
trans_tensor.get_thread_buffer().template get_as<InputDataVec>(number<iAccess>{});

// Element-wise type conversion
// This will be unrolled by the compiler for each element in the vector
static_for<0, vecLoadSize, 1>{}([&](auto iElem) {
auto output_elem = type_convert<OutputDataType>(input_vec[iElem]);
out_tensor.get_thread_buffer()[number<iAccess * vecLoadSize + iElem>{}] = output_elem;
});
});
}

/**
* @brief Mixed-precision transpose load with zero offset.
*
* Convenience wrapper for load_tile_transpose_convert_with_offset with offset=0.
*/
template <
typename DistributedTensor_,
typename BottomTensorView_,
typename WindowLengths_,
typename TileDistribution_,
index_t NumCoord,
typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
typename BottomTensorView_::DataType,
Policy>::distr_encoding_valid,
Policy>>
CK_TILE_DEVICE void load_tile_transpose_convert(
DistributedTensor_& out_tensor,
const tile_window_with_static_distribution<BottomTensorView_,
WindowLengths_,
TileDistribution_,
NumCoord>& __restrict__ tile_window)
{
load_tile_transpose_convert_with_offset(out_tensor, tile_window, 0);
}

} // namespace ck_tile
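A calling sketch for the new mixed-precision transpose path, assuming an fp8 source window and an fp16 output distribution whose Y-space element count matches the input's (the static_asserts above enforce this; the concrete distribution is the caller's responsibility and is assumed here):

// Sketch only: transpose an fp8 tile and widen it to fp16 in one load.
// Assumption: out_distr satisfies the element-count checks enforced by
// load_tile_transpose_convert_with_offset.
template <typename BWindow, typename OutDistr>
CK_TILE_DEVICE auto load_b_transposed_as_fp16(const BWindow& b_window, OutDistr out_distr)
{
    auto out_tensor = ck_tile::make_static_distributed_tensor<ck_tile::fp16_t>(out_distr);
    ck_tile::load_tile_transpose_convert(out_tensor, b_window); // reads fp8, writes fp16
    return out_tensor;
}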
18 changes: 9 additions & 9 deletions include/ck_tile/core/tensor/tile_window.hpp
@@ -182,32 +182,32 @@ struct tile_window_with_static_distribution
* The same thread, during vectorized reading, accesses the same set of
* data from A0, A1, A2, … AN.
*/
template <typename TileWindow_,
template <typename... TileWindow_,
typename ElementWise_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
CK_TILE_DEVICE auto load(const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
{
constexpr auto tile_dstr = typename Base::TileDstr{};
auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
load(dst_tensor,
tile_window,
tile_windows,
elementwise,
number<i_access_unsupport_>{},
bool_constant<oob_conditional_check>{});
return dst_tensor;
}

template <typename DistributedTensor,
typename TileWindow_,
typename... TileWindow_,
typename ElementWise_,
index_t i_access_unsupport_ = -1,
bool oob_conditional_check = true>
CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
const TileWindow_& tile_window,
const ck_tile::tuple<TileWindow_...>& tile_windows,
ElementWise_ elementwise,
number<i_access_unsupport_> = {},
bool_constant<oob_conditional_check> = {}) const
@@ -218,14 +218,14 @@ struct tile_window_with_static_distribution
using SFC_Ys = typename Traits::SFC_Ys;

constexpr auto tile_dstr = typename Base::TileDstr{};
constexpr auto sizeOfTuple = TileWindow_::size();
constexpr auto sizeOfTuple = remove_cvref_t<decltype(tile_windows)>::size();
// loop over thread tensor space [y0, y1, ...]
static_for<0, NumCoord, 1>{}([&](auto iCoord) {
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto window_adaptor_thread_coord =
tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
tile_windows[number<0>{}].pre_computed_coords_[iCoord][I0];
auto bottom_tensor_thread_coord =
tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];
tile_windows[number<0>{}].pre_computed_coords_[iCoord][I1];

static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
@@ -236,7 +236,7 @@
// read from bottom tensor
const auto idx_vec_value = generate_tuple(
[&](auto jj) {
return tile_window[number<jj>{}]
return tile_windows[number<jj>{}]
.get_bottom_tensor_view()
.template get_vectorized_elements<vector_t>(
bottom_tensor_thread_coord,
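The void member load above pairs with caller-side allocation through the free load_tile overload in load_tile.hpp. A sketch of that caller-allocates pattern, assuming the window exposes get_tile_distribution() and a BottomTensorView::DataType alias (both assumptions of this sketch):

// Sketch only: the caller allocates the destination; the void overload fills it.
template <typename TileWindow>
CK_TILE_DEVICE auto load_into_owned_tensor(const TileWindow& window)
{
    using DataType = typename TileWindow::BottomTensorView::DataType; // assumed alias
    auto dst = ck_tile::make_static_distributed_tensor<DataType>(
        window.get_tile_distribution()); // assumed accessor
    ck_tile::load_tile(dst, window); // the overload made void in this diff
    return dst;
}

Returning the filled tensor keeps call sites terse while the library function itself stays void, matching this PR's move away from auto return types.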