diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp index ffd728f2593..a8eb7ef9bca 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp @@ -30,8 +30,9 @@ template + bool TransposeC = false, + bool ALdsScalarLoadToVgpr = false, + bool BLdsScalarLoadToVgpr = false> struct BlockwiseGemmXdlops_pipeline_base { static constexpr auto I0 = Number<0>{}; @@ -386,7 +387,7 @@ struct BlockwiseGemmXdlops_pipeline_base Sequence<1, 1, 1, KPack>, Sequence<0, 1, 2, 3>, 3, - LdsScalarLoadToVgpr ? 1 : A_K1, + ALdsScalarLoadToVgpr ? 1 : A_K1, A_K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, Sequence<0, 1, 2, 3>, 3, - LdsScalarLoadToVgpr ? 1 : B_K1, + BLdsScalarLoadToVgpr ? 1 : B_K1, B_K1>; AThreadCopy a_thread_copy_; diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp index 461ca513f97..2f12cf1d335 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp @@ -33,11 +33,12 @@ template + bool ALdsScalarLoadToVgpr = false, + bool BLdsScalarLoadToVgpr = false> constexpr auto BlockGemmPipeline_Selector() { // Supported for Direct Load and V1 - if constexpr(LdsScalarLoadToVgpr) + if constexpr(ALdsScalarLoadToVgpr || BLdsScalarLoadToVgpr) { static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1); } @@ -65,7 +66,8 @@ constexpr auto BlockGemmPipeline_Selector() MRepeat, NRepeat, KPack, - LdsScalarLoadToVgpr>{}; + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>{}; } else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp index ae4504d6baa..91bdbd0a8c1 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp @@ -759,7 +759,8 @@ template + bool ALdsScalarLoadToVgpr = false, + bool BLdsScalarLoadToVgpr = false> struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1 { }; @@ -784,7 +785,8 @@ template + bool ALdsScalarLoadToVgpr, + bool BLdsScalarLoadToVgpr> struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1 + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr> : BlockwiseGemmXdlops_pipeline_base + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr> { using Base = BlockwiseGemmXdlops_pipeline_base; + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>; using Base::I0; using Base::KRepeat; using Base::xdlops_gemm; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp new file mode 100644 index 00000000000..d6ee099930c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp @@ -0,0 +1,1218 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include + +#include + +#include "ck/library/utility/numeric.hpp" +#include "ck/utility/common_header.hpp" +#include "ck/utility/env.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp" +#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/operator_transform/transform_conv_ngchw_to_nhwgc.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/host_utility/io.hpp" + +#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +namespace { + +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_grouped_conv_bwd_data_xdl_cshuffle_v3( + typename GridwiseGemm::Argument karg, + const std::array gemm_kernel_args, + const index_t gemms_count, + const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, + const index_t num_k_per_block) +{ +#if defined(__gfx9__) + // offset base pointer for each work-group + const index_t block_args_id = __builtin_amdgcn_readfirstlane(blockIdx.x); + const index_t g_idx = __builtin_amdgcn_readfirstlane(blockIdx.y); + const index_t k_idx = __builtin_amdgcn_readfirstlane(blockIdx.z * num_k_per_block); + + const long_index_t a_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)); + const long_index_t b_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)); + const long_index_t e_batch_offset = + amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)); + + __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte(get_device_arch())]; + + index_t left = 0; + index_t right = gemms_count; + index_t group_id = index_t((left + right) / 2); + while((!(block_args_id >= gemm_kernel_args[group_id].BlockStart_ && + block_args_id < gemm_kernel_args[group_id].BlockEnd_)) && + left <= right) + { + if(block_args_id < gemm_kernel_args[group_id].BlockStart_) + { + right = group_id; + } + else + { + left = group_id; + } + group_id = index_t((left + right) / 2); + } + + if constexpr(GridwiseGemm::DirectLoadEnabled) + { +#if defined(__gfx950__) + const auto a_grid_desc_ak0_m_ak1_transformed = + GridwiseGemm::template TransformGrid( + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_); + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1_transformed, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } + else + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + a_grid_desc_ak0_m_ak1_transformed, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } +#endif + } + else + { + if(gemm_kernel_args[group_id].HasMainKBlockLoop_) + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } + else + { + GridwiseGemm::template Run( + karg.p_a_grid + a_batch_offset, + karg.p_b_grid + b_batch_offset, + karg.p_c_grid + e_batch_offset, + p_shared, + karg, + gemm_kernel_args[group_id].a_grid_desc_ak0_m_ak1_, + gemm_kernel_args[group_id].b_grid_desc_bk0_n_bk1_, + gemm_kernel_args[group_id].e_grid_desc_mblock_mperblock_nblock_nperblock_, + k_idx, + gridDim.z, + blockIdx.x - gemm_kernel_args[group_id].BlockStart_); + } + } +#else + ignore = karg; + ignore = gemm_kernel_args; + ignore = gemms_count; + ignore = compute_ptr_offset_of_batch; + ignore = num_k_per_block; + +#endif // End of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) +} +} // namespace + +// Conv backward data multiple D: +// input : output image A: [G, N, K, Ho, Wo] +// input : weight B: [G, K, C, Y, X], +// input : D0, D1, ... : [G, N, K, Ho, Wo] +// output : input image E: [G, N, C, Hi, Wi] +// C = a_op(A) * b_op(B) +// E = cde_op(C, D0, D1, ...) +template +struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3 + : public DeviceGroupedConvBwdDataMultipleD +{ + // TODO: Extend support for more spatial dimensions. + static_assert(NDimSpatial == 2 || NDimSpatial == 3, + "wrong! only implemented for 2D and 3D now"); + + static_assert(std::is_same_v, "A not NGHWC"); + static_assert(std::is_same_v, "B not GKYXC"); + static_assert(std::is_same_v, "C not NGHWK"); + + // MaxGroupedGemmGroupsNum is used to specify number of gemm args in compile time. With this + // implementation we can avoid copy data to workspace before kernel launch since number of + // groups is runtime parameter. If number of groups is larger than MaxGroupedGemmGroupsNum then + // we run this kernel in the loop. + static constexpr index_t MaxGroupedGemmGroupsNum = + ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0 + ? 1 + : 32; + + using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3; + + static constexpr index_t NumDTensor = DsDataType::Size(); + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static_assert(NumDTensor == 0, "Not supported"); + // static_assert(DirectLoad, "Not supported"); + + static constexpr GemmSpecialization GemmSpec = GemmSpecialization::MNKPadding; + static constexpr bool IsSplitKSupported = false; + + // TODO: Add support for different A and B data types. + using ABDataType = ADataType; + + using ConvToGemmBwdDataTransform = TransformConvBwdDataToGemm_v1; + + // Dummy function just used to create an alias to Grid Descriptors + static auto + GetDummyABDsEGridDescriptor(const ConvToGemmBwdDataTransform& conv_to_gemm_transform) + { + const auto a_grid_desc_ak0_m_ak1 = conv_to_gemm_transform.MakeADescriptor_AK0_M_AK1(); + + const auto b_grid_desc_bk0_n_bk1 = conv_to_gemm_transform.MakeBDescriptor_BK0_N_BK1(); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + conv_to_gemm_transform.MakeCDescriptor_M_N(), 1, 1); + + return make_tuple(a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + e_grid_desc_mblock_mperblock_nblock_nperblock); + } + + static constexpr index_t ABlockTransferSrcScalarPerVectorAligned = + ABlockTransferSrcScalarPerVector * sizeof(ADataType) == 8 + ? 4 / sizeof(ADataType) + : ABlockTransferSrcScalarPerVector; + static constexpr index_t BBlockTransferSrcScalarPerVectorAligned = + BBlockTransferSrcScalarPerVector * sizeof(BDataType) == 8 + ? 4 / sizeof(BDataType) + : BBlockTransferSrcScalarPerVector; + + static constexpr bool ALdsScalarLoadToVgpr = false; + static constexpr bool BLdsScalarLoadToVgpr = true; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_xdl_cshuffle_conv_v3< + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::RowMajor, + ADataType, + BDataType, + AccDataType, + EDataType, + EDataType, + AElementwiseOp, + BElementwiseOp, + CDEElementwiseOp, + GemmSpec, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXdl, + NPerXdl, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + DirectLoad ? ABlockTransferSrcScalarPerVectorAligned : ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + DirectLoad ? BBlockTransferSrcScalarPerVectorAligned : BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector, + BlkGemmPipeSched, + BlkGemmPipelineVer, + AComputeType, + BComputeType, + DirectLoad, + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>; + + template + static auto transform_k0_m_k1_to_m_k(const Desc_K0_M_K1& desc_k0_m_k1) + { + const auto grid_desc_m_k = transform_tensor_descriptor( + desc_k0_m_k1, + make_tuple(make_pass_through_transform(desc_k0_m_k1.GetLength(I1)), + make_merge_transform( + make_tuple(desc_k0_m_k1.GetLength(I0), desc_k0_m_k1.GetLength(I2)))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return grid_desc_m_k; + } + + // Note: the dummy function is used just to create the alias + constexpr static ConvToGemmBwdDataTransform dummy_conv_to_gemm_transform; + using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor(dummy_conv_to_gemm_transform)); + + using AGridDesc_AK0_M_AK1 = remove_cvref_t>; + using BGridDesc_BK0_N_BK1 = remove_cvref_t>; + using EGridDesc_MPerBlock_NBlock_NPerBlock = remove_cvref_t>; + + using AGridDesc_M_K = decltype(transform_k0_m_k1_to_m_k(AGridDesc_AK0_M_AK1{})); + using BGridDesc_N_K = decltype(transform_k0_m_k1_to_m_k(BGridDesc_BK0_N_BK1{})); + + struct GemmArgs + { + GemmArgs() = default; + GemmArgs(AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, + EGridDesc_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, + index_t BlockStart, + index_t BlockEnd, + bool HasMainKBlockLoop) + : a_grid_desc_ak0_m_ak1_(a_grid_desc_ak0_m_ak1), + b_grid_desc_bk0_n_bk1_(b_grid_desc_bk0_n_bk1), + e_grid_desc_mblock_mperblock_nblock_nperblock_( + e_grid_desc_mblock_mperblock_nblock_nperblock), + BlockStart_(BlockStart), + BlockEnd_(BlockEnd), + HasMainKBlockLoop_(HasMainKBlockLoop) + + { + } + // tensor descriptors for block/thread-wise copy + AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; + BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; + EGridDesc_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_; + index_t BlockStart_, BlockEnd_; + bool HasMainKBlockLoop_; + }; + // block-to-e-tile map for elementwise kernels + using Block2TileMapInOutElementwise = BlockToCTileMap_M00_N0_M01Adapt; + using Block2TileMapWeiElementwise = BlockToCTileMap_M00_N0_M01Adapt; + static constexpr index_t ClusterLengthMPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1); + static constexpr index_t ClusterLengthNPerBlock = + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3); + + static constexpr index_t ElementwiseBlocksize = ClusterLengthMPerBlock * ClusterLengthNPerBlock; + + // Argument + struct Argument : public BaseArgument + { + Argument(const void* p_a, // output image + const void* p_b, // weight + const std::array&, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, + const std::array& a_g_n_k_wos_strides, + const std::array& b_g_k_c_xs_lengths, + const std::array& b_g_k_c_xs_strides, + const std::array, NumDTensor>&, + const std::array, NumDTensor>&, + const std::array& e_g_n_c_wis_lengths, + const std::array& e_g_n_c_wis_strides, + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + ck::index_t split_k = 1) + : p_a_grid_{static_cast(p_a)}, + p_b_grid_{static_cast(p_b)}, + p_e_grid_{static_cast(p_e)}, + num_group_{a_g_n_k_wos_lengths[0]}, + a_element_op_{a_element_op}, + b_element_op_{b_element_op}, + cde_element_op_{cde_element_op}, + a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths}, + b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths}, + e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths}, + conv_filter_strides_{conv_filter_strides}, + input_left_pads_{input_left_pads}, + input_right_pads_{input_right_pads}, + k_batch_{split_k} + { + bool image_covered_dilation = true; + bool image_covered_strides = true; + for(index_t d = 0; d < NDimSpatial; d++) + { + // If dilation and stride is not equal we will have some empty places + image_covered_dilation &= + conv_filter_dilations[d] == 1 || conv_filter_strides[d] == 1; + // If stride is larger than windows size then we will have some empty places + image_covered_strides &= conv_filter_strides[d] <= b_g_k_c_xs_lengths[d + I3]; + } + bool if_d_is_output_mem = false; + bwd_needs_zero_out = k_batch_ > 1 || !image_covered_dilation || !image_covered_strides; + + // Temporary workaround untill prove/fix above conditions. + bwd_needs_zero_out = !if_d_is_output_mem; + e_space_size_bytes = + ck::accumulate_n( + e_g_n_c_wis_lengths_.begin(), NDimSpatial + I3, 1, std::multiplies<>()) * + sizeof(EDataType); + + static constexpr auto NonSpatialDimsNum = Number<3>{}; + + static constexpr auto DIdx = Number{}; + static constexpr auto HIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto WIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + static constexpr auto ZIdx = Number{}; + static constexpr auto YIdx = + NDimSpatial == 2 ? Number{} : Number{}; + static constexpr auto XIdx = NDimSpatial == 2 ? Number{} + : Number{}; + + // problem definition + const index_t Z = b_g_k_c_xs_lengths[ZIdx]; + const index_t Y = b_g_k_c_xs_lengths[YIdx]; + const index_t X = b_g_k_c_xs_lengths[XIdx]; + + const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum]; + const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum]; + const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum]; + + const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum]; + const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum]; + const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum]; + + const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD); + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1; + const auto YTilde = ConvStrideH / GcdStrideDilationH; + const auto XTilde = ConvStrideW / GcdStrideDilationW; + + index_t grid_size = 0; + // Allocate place for sets of gemms + gemm_kernel_args_.resize( + math::integer_divide_ceil(ZTilde * YTilde * XTilde, MaxGroupedGemmGroupsNum)); + + for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde) + { + for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde) + { + for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde) + { + // check slice is valid + const auto ZDotSlice = + NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1; + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); + + if(YDotSlice * XDotSlice * ZDotSlice <= 0) + { + continue; + } + + std::array tildes; + if constexpr(NDimSpatial == 2) + { + tildes = {i_ytilde, i_xtilde}; + } + else if constexpr(NDimSpatial == 3) + { + tildes = {i_ztilde, i_ytilde, i_xtilde}; + } + else + { + throw std::runtime_error("wrong! only implemented for 2D and 3D now"); + } + + ConvToGemmBwdDataTransform conv_to_gemm_transform_{a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + tildes, + k_batch_}; + + conv_N_per_block_ = conv_to_gemm_transform_.N_; + + const auto a_grid_desc_ak0_m_ak1 = [&]() { + return conv_to_gemm_transform_.MakeADescriptor_AK0_M_AK1(); + }(); + + const auto b_grid_desc_bk0_n_bk1 = [&]() { + return conv_to_gemm_transform_.MakeBDescriptor_BK0_N_BK1(); + }(); + + // desc for problem definition + const auto a_grid_desc_m_k = + transform_k0_m_k1_to_m_k(a_grid_desc_ak0_m_ak1); + const auto b_grid_desc_n_k = + transform_k0_m_k1_to_m_k(b_grid_desc_bk0_n_bk1); + + const auto GemmM = a_grid_desc_m_k.GetLength(I0); + const auto GemmN = b_grid_desc_n_k.GetLength(I0); + + const auto e_grid_desc_mblock_mperblock_nblock_nperblock = + GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( + conv_to_gemm_transform_.MakeCDescriptor_M_N(), + GridwiseGemm::CalculateMBlock(GemmM), + GridwiseGemm::CalculateNBlock(GemmN)); + + a_grid_desc_m_k_container_.push_back(a_grid_desc_m_k); + b_grid_desc_n_k_container_.push_back(b_grid_desc_n_k); + e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back( + e_grid_desc_mblock_mperblock_nblock_nperblock); + + const index_t grid_size_grp = + std::get<0>(GridwiseGemm::CalculateGridSize(GemmM, GemmN, 1, 1)); + const index_t BlockStart = grid_size; + const index_t BlockEnd = grid_size + grid_size_grp; + + grid_size += grid_size_grp; + + // const index_t GemmM = a_grid_desc_m_k.GetLength(I0); + // const index_t GemmN = b_grid_desc_n_k.GetLength(I0); + const index_t GemmK = a_grid_desc_m_k.GetLength(I1); + + // onst auto MBlock = GridwiseGemmCTranspose::CalculateMBlock(GemmM); + // onst auto NBlock = GridwiseGemmCTranspose::CalculateNBlock(GemmN); + + index_t k_grain = split_k * KPerBlock; + index_t K_split = (GemmK + k_grain - 1) / k_grain * KPerBlock; + + const bool HasMainKBlockLoop = + GridwiseGemm::CalculateHasMainKBlockLoop(K_split); + + gemm_kernel_args_[gemms_count_ / MaxGroupedGemmGroupsNum] + [gemms_count_ % MaxGroupedGemmGroupsNum] = + GemmArgs{a_grid_desc_ak0_m_ak1, + b_grid_desc_bk0_n_bk1, + e_grid_desc_mblock_mperblock_nblock_nperblock, + BlockStart, + BlockEnd, + HasMainKBlockLoop}; + gemms_count_++; + if(gemms_count_ % MaxGroupedGemmGroupsNum == 0) + { + gemms_grid_size_.push_back(grid_size); + grid_size = 0; + } + } + } + } + gemm_kernel_args_.resize( + math::integer_divide_ceil(gemms_count_, MaxGroupedGemmGroupsNum)); + gemms_grid_size_.push_back(grid_size); + + // A/B/Ds/E Batch Stride + compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0]; + compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0]; + + num_workgroups_per_Conv_N_ = a_g_n_k_wos_lengths_[I1] / conv_N_per_block_; + } + + std::size_t GetWorkspaceSizeBytes() const { return 0; } + + void Print() const + { + for(std::size_t i = 0; i < a_grid_desc_m_k_container_.size(); i++) + { + std::cout << "a_grid_desc_m_ak_container_" << a_grid_desc_m_k_container_[i] + << std::endl; + + std::cout << "b_grid_desc_n_bk_container_" << b_grid_desc_n_k_container_[i] + << std::endl; + + std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_" + << e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i] + << std::endl; + } + } + + // pointers + const ADataType* p_a_grid_; + const BDataType* p_b_grid_; + EDataType* p_e_grid_; + + // tensor descriptor for problem definition + index_t num_group_; + index_t conv_N_per_block_; + std::vector a_grid_desc_m_k_container_; + std::vector b_grid_desc_n_k_container_; + std::vector + e_grid_desc_mblock_mperblock_nblock_nperblock_container_; + + // tensor descriptor for block-wise copy + std::vector a_grid_desc_ak0_m_ak1_container_; + std::vector b_grid_desc_bk0_n_bk1_container_; + + // for computing batch offset + ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; + + // element-wise op + AElementwiseOp a_element_op_; + BElementwiseOp b_element_op_; + CDEElementwiseOp cde_element_op_; + + std::array a_g_n_k_wos_lengths_; + std::array b_g_k_c_xs_lengths_; + std::array e_g_n_c_wis_lengths_; + std::array conv_filter_strides_; + std::array input_left_pads_; + std::array input_right_pads_; + + const index_t k_batch_; + index_t num_workgroups_per_Conv_N_; + std::vector gemms_grid_size_; + index_t gemms_count_ = 0; + std::vector> gemm_kernel_args_; + + bool bwd_needs_zero_out; + long_index_t e_space_size_bytes; + }; + + // Invoker + struct Invoker : public BaseInvoker + { + using Argument = DeviceOp::Argument; + + template + float RunMultiDGemm(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + const index_t gdy = arg.num_group_; + const index_t gdz = arg.k_batch_; + + const ADataType* p_a_grid = arg.p_a_grid_; + const BDataType* p_b_grid = arg.p_b_grid_; + EDataType* p_e_grid = arg.p_e_grid_; + + for(std::size_t gemm_set_id = 0; gemm_set_id < arg.gemm_kernel_args_.size(); + gemm_set_id++) + { + const index_t GemmM = arg.a_grid_desc_m_k_container_[gemm_set_id].GetLength(I0); + const index_t GemmN = arg.b_grid_desc_n_k_container_[gemm_set_id].GetLength(I0); + const index_t GemmK = arg.a_grid_desc_m_k_container_[gemm_set_id].GetLength(I1); + typename GridwiseGemm::Argument gemm_arg{ + p_a_grid, p_b_grid, p_e_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_}; + + const index_t gdx = arg.gemms_grid_size_[gemm_set_id]; + + const index_t gemms_count_for_set = + gemm_set_id == arg.gemm_kernel_args_.size() - 1 + ? arg.gemms_count_ - MaxGroupedGemmGroupsNum * gemm_set_id + : MaxGroupedGemmGroupsNum; + + const std::array& gemm_kernel_args = + arg.gemm_kernel_args_[gemm_set_id]; + + const auto clear_workspace = [&]() { + if(arg.bwd_needs_zero_out && gemm_set_id == 0) + { + hip_check_error(hipMemsetAsync( + p_e_grid, 0, arg.e_space_size_bytes, stream_config.stream_id_)); + } + }; + + bool has_loop_in_all_gemm = true; + bool no_loop_in_all_gemm = true; + for(auto i = 0; i < gemms_count_for_set; i++) + { + has_loop_in_all_gemm &= gemm_kernel_args[i].HasMainKBlockLoop_; + no_loop_in_all_gemm &= !gemm_kernel_args[i].HasMainKBlockLoop_; + } + + auto launch_kernel = [&](auto has_main_k_block_loop_, auto no_main_k_block_loop) { + constexpr bool has_main_loop = has_main_k_block_loop_.value; + constexpr bool no_main_loop = no_main_k_block_loop.value; + const auto kernel = kernel_grouped_conv_bwd_data_xdl_cshuffle_v3< + GridwiseGemm, + DeviceOp::AGridDesc_AK0_M_AK1, + DeviceOp::BGridDesc_BK0_N_BK1, + DeviceOp::EGridDesc_MPerBlock_NBlock_NPerBlock, + MaxGroupedGemmGroupsNum, + GemmArgs, + ComputePtrOffsetOfStridedBatch, + ComputePtrOffsetOfStridedBatch, + ElementOp, + has_main_loop, + no_main_loop>; + + return launch_and_time_kernel_with_preprocess(stream_config, + clear_workspace, + kernel, + dim3(gdx, gdy, gdz), + dim3(BlockSize), + 0, + gemm_arg, + gemm_kernel_args, + gemms_count_for_set, + arg.compute_ptr_offset_of_batch_, + 1); + }; + if(has_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else if(no_loop_in_all_gemm) + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + else + { + ave_time += launch_kernel(integral_constant{}, + integral_constant{}); + } + } + + return ave_time; + } + + float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) + { + float ave_time = 0; + + if(stream_config.log_level_ > 0) + { + arg.Print(); + } + if(arg.k_batch_ > 1) + { + if constexpr(IsSplitKSupported) + { + ave_time += + RunMultiDGemm(arg, stream_config); + } + } + else + { + ave_time += RunMultiDGemm(arg, stream_config); + } + + arg.Print(); + + return ave_time; + } + + float Run(const BaseArgument* p_arg, + const StreamConfig& stream_config = StreamConfig{}) override + { + return Run(*dynamic_cast(p_arg), stream_config); + } + }; + + static bool IsSupportedArgument(const Argument& arg) + { + // check device + if constexpr(DirectLoad) + { + if(get_device_name() != "gfx950") + { + return false; + } + } + + if constexpr(!IsSplitKSupported) + { + if(arg.k_batch_ > 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "SplitK tests are not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + + if(ck::is_gfx11_supported() && arg.k_batch_ > 1) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "SplitK tests are not supported!" << " In " << __FILE__ << ":" + << __LINE__ << ", in function: " << __func__ << std::endl; + } + + return false; + } + + const index_t ConvK = arg.b_g_k_c_xs_lengths_[1]; + const index_t ConvC = arg.b_g_k_c_xs_lengths_[2]; + + // Specialization + if constexpr(ConvBackwardDataSpecialization == + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0) + { + // check if it's 1x1, stride=1 pad = 0 conv + for(int i = 0; i < NDimSpatial; i++) + { + if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 && + arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "ConvBwdDataSpecialization is unsupported!" << " In " + << __FILE__ << ":" << __LINE__ << ", in function: " << __func__ + << std::endl; + } + + return false; + } + } + } + + // vector load for A matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v) + { + if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported A Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // vector load for B matrix from global memory to LDS + if constexpr(is_same_v || + is_same_v) + { + + if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported B Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // vector store for E + if constexpr(is_same_v || + is_same_v) + { + // vector store C matrix into global memory + if(!(ConvC % CShuffleBlockTransferScalarPerVector == 0)) + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + + std::cout << "VectorDim is wrong!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + } + else + { + if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING))) + { + std::cout << "Unsupported E Layout!" << " In " << __FILE__ << ":" << __LINE__ + << ", in function: " << __func__ << std::endl; + } + + return false; + } + + // Check gridwise gemm validity + for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++) + { + const index_t GemmM = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I1); + const index_t GemmN = arg.b_grid_desc_bk0_n_bk1_container_[i].GetLength(I1); + const index_t GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) * + arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2); + // Create gemm arguments with dummy values to check for validity + typename GridwiseGemm::Argument gemm_arg{nullptr, // p_as_grid + nullptr, // p_bs_grid + nullptr, // p_e_grid + GemmM, // M + GemmN, // N + GemmK, // K + I0, // StrideAs + I0, // StrideBs + I0, // StrideE + arg.k_batch_}; + + const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / AK1); + if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1) + { + if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages) + { + return false; + } + } + } + + return true; + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto + MakeArgument(const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) + { + return Argument{p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + split_k}; + } + + static auto MakeInvoker() { return Invoker{}; } + + std::unique_ptr MakeArgumentPointer( + const void* p_a, // output image + const void* p_b, // weight + const std::array& p_ds, // bias + void* p_e, // input image + const std::array& a_g_n_k_wos_lengths, // output image + const std::array& a_g_n_k_wos_strides, // output image + const std::array& b_g_k_c_xs_lengths, // weight + const std::array& b_g_k_c_xs_strides, // weight + const std::array, NumDTensor>& + ds_g_n_c_wis_lengths, // bias + const std::array, NumDTensor>& + ds_g_n_c_wis_strides, // bias + const std::array& e_g_n_c_wis_lengths, // input image + const std::array& e_g_n_c_wis_strides, // input image + const std::array& conv_filter_strides, + const std::array& conv_filter_dilations, + const std::array& input_left_pads, + const std::array& input_right_pads, + const AElementwiseOp& a_element_op, + const BElementwiseOp& b_element_op, + const CDEElementwiseOp& cde_element_op, + const ck::index_t split_k = 1) override + { + return std::make_unique(p_a, + p_b, + p_ds, + p_e, + a_g_n_k_wos_lengths, + a_g_n_k_wos_strides, + b_g_k_c_xs_lengths, + b_g_k_c_xs_strides, + ds_g_n_c_wis_lengths, + ds_g_n_c_wis_strides, + e_g_n_c_wis_lengths, + e_g_n_c_wis_strides, + conv_filter_strides, + conv_filter_dilations, + input_left_pads, + input_right_pads, + a_element_op, + b_element_op, + cde_element_op, + split_k); + } + + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + // clang-format off + str << "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3" + << (DirectLoad ? "_DirectLoad" : "") + << "<" + << BlockSize << ", " + << MPerBlock << ", " + << NPerBlock << ", " + << KPerBlock << ", " + << AK1 << ", " + << BK1 << ", " + << getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", " + << MPerXdl << ", " + << NPerXdl << ", " + << MRepeat << ", " + << NRepeat << ", " + << ABlockTransferSrcScalarPerVector << ", " + << BBlockTransferSrcScalarPerVector << ", " + << CShuffleMRepeatPerShuffle << ", " + << CShuffleNRepeatPerShuffle; + + str << ">"; + + return str.str(); + } + + size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override + { + auto arg = dynamic_cast(p_arg); + if(arg) + { + return arg->GetWorkspaceSizeBytes(); + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3::Argument structure!"); + } + + void SetWorkSpacePointer(BaseArgument* p_arg, + void* p_workspace, + const StreamConfig& = StreamConfig{}) const override + { + auto p_arg_ = dynamic_cast(p_arg); + if(p_arg_) + { + p_arg_->p_workspace_ = p_workspace; + } + else + throw std::runtime_error( + "The argument pointer is not an object of " + "DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3::Argument structure!"); + } +}; + +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp index ac83cee2510..c1b6740fd5b 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle_v3.hpp @@ -408,10 +408,21 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 ? 4 / sizeof(BDataType) : BBlockTransferSrcScalarPerVector; + static constexpr bool ALdsScalarLoadToVgpr = (DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false); + static constexpr bool BLdsScalarLoadToVgpr = (DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ? true : false); + + // Note: Direct load use layout to create proper block and mmtile descriptor + // TODO: Fix and verify RC layout for not direct load (currently it returns wrong results) template using GridwiseGemmBase = GridwiseGemm_xdl_cshuffle_conv_v3< - tensor_layout::gemm::RowMajor, - tensor_layout::gemm::ColumnMajor, + std::conditional_t< + DirectLoad, + tensor_layout::gemm::ColumnMajor, + tensor_layout::gemm::RowMajor>, + std::conditional_t< + DirectLoad, + tensor_layout::gemm::RowMajor, + tensor_layout::gemm::ColumnMajor>, tensor_layout::gemm::RowMajor, ADataType, BDataType, @@ -456,7 +467,9 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB, - DirectLoad>; + DirectLoad, + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>; using GridwiseGemm64 = GridwiseGemmBase; using GridwiseGemm32 = GridwiseGemmBase; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp index 5289e209fbe..c584a3dc12b 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp @@ -63,7 +63,9 @@ template + bool DirectLoad = false, + bool ALdsScalarLoadToVgpr = false, + bool BLdsScalarLoadToVgpr = false> struct GridwiseGemm_xdl_cshuffle_conv_v3 : public GridwiseGemm_xdl_cshuffle_base< ALayout, @@ -246,19 +248,90 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 return math::integer_divide_ceil(N, NPerBlock); } - template + template + __host__ __device__ static auto TransformGrid(const GridDesc_K0_MN_K1_T& desc) + { + + if constexpr(!DirectLoad) + { + return desc; + } + else + { + const index_t K = desc.GetLength(I0) * desc.GetLength(I2); + const index_t MN = desc.GetLength(I1); + + const auto desc_unmerged = transform_tensor_descriptor( + desc, + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, K0Number)), + make_pass_through_transform(MN), + make_pass_through_transform(K1Value)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto desc_permuted = transform_tensor_descriptor( + desc_unmerged, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(MN, K0Number)), + make_pass_through_transform(K1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + return transform_tensor_descriptor( + desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, K0Number)), + make_pass_through_transform(MN), + make_pass_through_transform(K1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + } + } + + template __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) { - constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); - constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); - - return transform_tensor_descriptor( - TileDesc_K0_MN_K1{}, - make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), - make_unmerge_transform(make_tuple( - Number{}, Number{}, Number{}))), - make_tuple(Sequence<0, 2>{}, Sequence<1>{}), - make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + if constexpr(DirectLoad && IsKContinous) + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{}); + + constexpr auto desc = transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple(make_xor_with_modulo_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + desc, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } + else + { + constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); + + return transform_tensor_descriptor( + TileDesc_K0_MN_K1{}, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{}))), + make_tuple(Sequence<0, 2>{}, Sequence<1>{}), + make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{})); + } } template @@ -267,7 +340,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 { constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl); - return MakeGemmMmaTileDescriptor(ABlockDesc_AK0_M_AK1{}); + return MakeGemmMmaTileDescriptor::value>( + ABlockDesc_AK0_M_AK1{}); } template @@ -276,7 +353,11 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 { constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); - return MakeGemmMmaTileDescriptor(BBlockDesc_BK0_N_BK1{}); + return MakeGemmMmaTileDescriptor::value>( + BBlockDesc_BK0_N_BK1{}); } struct Problem @@ -363,9 +444,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 { if constexpr(DirectLoad) { - return make_naive_tensor_descriptor( - make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(Number{}, I1, Number{})); + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(AK1Number, Number{}, I1)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(AK0Number, Number{}, AK1Number), + make_tuple(Number{}, I1, Number{})); + } } else if constexpr(is_same_v) { @@ -386,9 +476,18 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 { if constexpr(DirectLoad) { - return make_naive_tensor_descriptor( - make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(Number{}, I1, Number{})); + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(BK1Number, Number{}, I1)); + } + else + { + return make_naive_tensor_descriptor( + make_tuple(BK0Number, Number{}, BK1Number), + make_tuple(Number{}, I1, Number{})); + } } else if constexpr(is_same_v) { @@ -407,7 +506,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // Disable vector load from lds to vgpr for direct load (backward weight store with continous M // or N dimension) - static constexpr bool LdsScalarLoadToVgpr = DirectLoad; + //static constexpr bool LdsScalarLoadToVgpr = DirectLoad; using BlockwiseGemmPipe = remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, @@ -434,7 +533,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 NXdlPerWave, KPack, DirectLoad, - LdsScalarLoadToVgpr>())>; + ALdsScalarLoadToVgpr, + BLdsScalarLoadToVgpr>())>; template __device__ static constexpr index_t GetSharedMemoryNumberOfByte(DeviceArch) @@ -514,8 +614,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const index_t k_id = 0, - const index_t k_batch = 1) + const index_t k_id = 0, + const index_t k_batch = 1, + const index_t block_idx_x = static_cast(blockIdx.x)) { const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1; const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1; @@ -532,8 +633,8 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; - const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( - make_multi_index(static_cast(blockIdx.x))); + const auto block_work_idx = + block_2_ctile_map.CalculateBottomIndex(make_multi_index(block_idx_x)); if(!block_2_ctile_map.ValidCTileIndex( block_work_idx, @@ -567,23 +668,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 auto get_a_blockwise_copy = [&]() { if constexpr(DirectLoad) { - return ThreadGroupTensorSliceTransfer_DirectLoad< - ThisThreadBlock, - Sequence, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - 1, - ABlockTransferSrcScalarPerVector>( - a_grid_desc_ak0_m_ak1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0)); + return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType, + decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, + is_same::value ? 2 : 1, + ABlockTransferSrcScalarPerVector > + (a_grid_desc_ak0_m_ak1, + make_multi_index( + SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); } else { @@ -623,23 +720,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 auto get_b_blockwise_copy = [&]() { if constexpr(DirectLoad) { - return ThreadGroupTensorSliceTransfer_DirectLoad< - ThisThreadBlock, - Sequence, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - 1, - BBlockTransferSrcScalarPerVector>( - b_grid_desc_bk0_n_bk1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0)); + return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType, + decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, + is_same::value ? 2 : 1, + BBlockTransferSrcScalarPerVector > + (b_grid_desc_bk0_n_bk1, + make_multi_index( + SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); } else { @@ -747,8 +840,9 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock& c_grid_desc_mblock_mperblock_nblock_nperblock, - const index_t k_id = 0, - const index_t k_batch = 1) + const index_t k_id = 0, + const index_t k_batch = 1, + const index_t block_idx_x = static_cast(blockIdx.x)) { const long_index_t a_space_size_divisor = SplitKOffsetHack ? k_batch : 1; const long_index_t b_space_size_divisor = SplitKOffsetHack ? k_batch : 1; @@ -768,7 +862,7 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex( - make_multi_index(static_cast(blockIdx.x))); + make_multi_index(static_cast(block_idx_x))); if(!block_2_ctile_map.ValidCTileIndex( block_work_idx, @@ -802,23 +896,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 auto get_a_blockwise_copy = [&]() { if constexpr(DirectLoad) { - return ThreadGroupTensorSliceTransfer_DirectLoad< - ThisThreadBlock, - Sequence, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - 1, - ABlockTransferSrcScalarPerVector>( - a_grid_desc_ak0_m_ak1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), - a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0)); + return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock, + Sequence, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, ADataType, ADataType, + decltype(a_grid_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, + is_same::value ? 2 : 1, + ABlockTransferSrcScalarPerVector > + (a_grid_desc_ak0_m_ak1, + make_multi_index( + SplitKOffsetHack ? 0 : k_id, m_block_data_idx_on_grid, 0), + a_block_desc_ak0_m_ak1, + make_multi_index(0, 0, 0)); } else { @@ -858,23 +948,19 @@ struct GridwiseGemm_xdl_cshuffle_conv_v3 auto get_b_blockwise_copy = [&]() { if constexpr(DirectLoad) { - return ThreadGroupTensorSliceTransfer_DirectLoad< - ThisThreadBlock, - Sequence, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - 1, - BBlockTransferSrcScalarPerVector>( - b_grid_desc_bk0_n_bk1, - make_multi_index(SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), - b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0)); + return ThreadGroupTensorSliceTransfer_DirectLoad < ThisThreadBlock, + Sequence, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, BDataType, BDataType, + decltype(b_grid_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, + is_same::value ? 2 : 1, + BBlockTransferSrcScalarPerVector > + (b_grid_desc_bk0_n_bk1, + make_multi_index( + SplitKOffsetHack ? 0 : k_id, n_block_data_idx_on_grid, 0), + b_block_desc_bk0_n_bk1, + make_multi_index(0, 0, 0)); } else { diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp new file mode 100644 index 00000000000..b666ef474dd --- /dev/null +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp @@ -0,0 +1,115 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using BF16 = ck::bhalf_t; +using F16 = ck::half_t; +using F32 = float; +using BF8 = ck::bf8_t; +using F8 = ck::f8_t; + +using Empty_Tuple = ck::Tuple<>; + +template +using S = ck::Sequence; + +using namespace ck::tensor_layout::convolution; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +static constexpr auto ConvBwdDataDefault = ConvolutionBackwardDataSpecialization::Default; + +static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 = + ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0; + +template +using device_grouped_conv_bwd_data_xdl_v3_f16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 256, 64, 8, 8, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F16, F16, true> + // clang-format on + >; + +template +using device_grouped_conv_bwd_data_xdl_v3_bf16_instances = std::tuple< + // clang-format off + // ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CDEBlockTransfer| CDEBlockTransfer| + // ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| MRepeat| NRepeat | _MBlock_MPerBlock| ScalarPerVector| + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| | + // ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 32, 64, 8, 8, 16, 16, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 128, 16, 64, 64, 8, 8, 16, 16, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<2, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 16, 1, 8>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 32, 64, 8, 8, 32, 32, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 64, 8, 8, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 32, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 32, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 64, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 4>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 32, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 64, 64, 8, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 16, 64, 8, 8, 16, 16, 4, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 2, 1, 0, 1, 1, S<1, 64, 1, 4>, 4, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 128, 64, 8, 8, 32, 32, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 256, 64, 8, 8, 32, 32, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 256, 256, 32, 8, 8, 32, 32, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 64, 32, 8, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 8, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 128, 32, 8, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 16, 4>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<4, 32, 2>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true>, + DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffleV3< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 256, 128, 32, 64, 8, 8, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 1, 0, S<8, 4, 8>, S<0, 2, 1>,S<0, 2, 1>, 1, 8, 1, 0, 1, 1, S<1, 64, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, BF16, BF16, true> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp index f784b6ea510..3c4efb81c16 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp @@ -108,6 +108,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { + add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances( op_ptrs); @@ -148,6 +150,8 @@ struct DeviceOperationInstanceFactory< is_same_v && is_same_v && is_same_v) { + add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances( + op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances( op_ptrs); add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances( @@ -355,7 +359,7 @@ struct DeviceOperationInstanceFactory< is_same_v) { add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances( - op_ptrs); + op_ptrs); } #endif } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc index 7c61f3ee667..8dae166dd1a 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc +++ b/library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data_xdl.inc @@ -56,6 +56,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances( #endif #ifdef CK_ENABLE_FP16 +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances); + void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances( std::vector>>& instances); + +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances); #endif #ifdef CK_ENABLE_FP16 diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt index 19e27cf173b..7f2363affd9 100644 --- a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/CMakeLists.txt @@ -32,6 +32,8 @@ add_instance_library( xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_vec_transpose_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_vec_transpose_instance.cpp xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_vec_transpose_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp + xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp new file mode 100644 index 00000000000..8a536e2e561 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_bf16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_bf16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_bf16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp new file mode 100644 index 00000000000..9d1fb4b93a6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/grouped_conv2d_bwd_data/xdl/device_grouped_conv2d_bwd_data_xdl_v3_nhwgc_gkyxc_nhwgk_f16_instance.cpp @@ -0,0 +1,49 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" +#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_v3_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_grouped_conv2d_bwd_data_xdl_v3_nhwgk_gkyxc_nhwgc_f16_instances( + std::vector>>& instances) +{ + // 1. Default + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_f16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataDefault>{}); + // 2. Filter1x1Stride1Pad0 + add_device_operation_instances( + instances, + device_grouped_conv_bwd_data_xdl_v3_f16_instances<2, + NHWGK, + GKYXC, + Empty_Tuple, + NHWGC, + ConvBwdDataFilter1x1Stride1Pad0>{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck