Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,8 @@ template <index_t NDimSpatial,
// in tuple for MultiAB), unpack if tuple was
// passed
typename BComputeDataType = AComputeDataType,
bool DirectLoad = false>
bool DirectLoad = false,
index_t NumGroupsToMerge = 1>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
ALayout,
Expand Down Expand Up @@ -418,6 +419,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
Wave32MaxMNPerXDL,
NXdlPerWave*(NPerXDL / Wave32MaxMNPerXDL)>();

static_assert(NumGroupsToMerge >= 1);

static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
static constexpr bool isMultiD = DsDataType::Size() > 0;
Expand Down Expand Up @@ -447,7 +450,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
ConvForwardSpecialization,
true /*SplitN*/,
ADataType,
EDataType>;
EDataType,
NumGroupsToMerge>;

using ComputePtrOffset = ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>;

Expand Down Expand Up @@ -784,8 +788,9 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
cde_element_op_{cde_element_op}
{
// A/B/E Batch/N Stride
compute_ptr_offset_of_groups_.BatchStrideA_ = a_g_n_c_wis_strides_[0];
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0];
compute_ptr_offset_of_groups_.BatchStrideA_ =
a_g_n_c_wis_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_groups_.BatchStrideB_ = b_g_k_c_xs_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;

// p_as and p_bs are pointers
Expand All @@ -796,7 +801,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
static_for<0, NumDTensor, 1>{}([&](auto i) {
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
// D batch stride
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
compute_ptr_offset_of_groups_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides_[i][0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;

Expand All @@ -816,7 +822,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
});

compute_ptr_offset_of_groups_.BatchStrideE_ = e_g_n_k_wos_strides_[0];
compute_ptr_offset_of_groups_.BatchStrideE_ =
e_g_n_k_wos_strides_[0] * NumGroupsToMerge;
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;

if constexpr(is_NGCHW_GKCYX_NGKHW<ALayout, BLayout, ELayout>() ||
Expand Down Expand Up @@ -999,7 +1006,7 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
std::tie(gdx, gdy, gdz) =
GridwiseGemm::CalculateGridSize(GemmM, GemmN, I1 /*arg.KBatch*/);

gdy = arg.num_group_;
gdy = arg.num_group_ / NumGroupsToMerge;
gdz = num_workgroups_per_Conv_N;

index_t K_split = (GemmK + KPerBlock - 1) / KPerBlock * KPerBlock;
Expand Down Expand Up @@ -1499,6 +1506,19 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
}

if constexpr(NumGroupsToMerge > 1)
{
if(G % NumGroupsToMerge != 0)
{
if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
{
std::cout << "Unsupported! G % NumGroupsToMerge != 0: G=" << G
<< ", NumGroupsToMerge=" << NumGroupsToMerge << std::endl;
}
return false;
}
}

if(get_device_name() == "gfx908")
{
// FIXME: re-enable fp64 when SWDEV-335738 is fixed
Expand Down Expand Up @@ -1595,6 +1615,22 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
}
}
}
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization::Filter3x3)
{
if(C != 1)
{
return false;
}
for(index_t i = 0; i < NDimSpatial; ++i)
{
const index_t filter_spatial_dim = arg.b_g_k_c_xs_lengths_[i + I3];

if(filter_spatial_dim != I3)
{
return false;
}
}
}

// check vector access of A
// FIXME: layout
Expand Down Expand Up @@ -2125,7 +2161,8 @@ struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer]
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< NumGroupsToMerge
<< ">";
// clang-format on

Expand Down
Loading
Loading