Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false,
bool LdsScalarLoadToVgpr = false>
bool TransposeC = false,
bool ALdsScalarLoadToVgpr = false,
bool BLdsScalarLoadToVgpr = false>
struct BlockwiseGemmXdlops_pipeline_base
{
static constexpr auto I0 = Number<0>{};
Expand Down Expand Up @@ -386,7 +387,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
LdsScalarLoadToVgpr ? 1 : A_K1,
ALdsScalarLoadToVgpr ? 1 : A_K1,
A_K1>;

using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
Expand All @@ -396,7 +397,7 @@ struct BlockwiseGemmXdlops_pipeline_base
Sequence<1, 1, 1, KPack>,
Sequence<0, 1, 2, 3>,
3,
LdsScalarLoadToVgpr ? 1 : B_K1,
BLdsScalarLoadToVgpr ? 1 : B_K1,
B_K1>;

AThreadCopy a_thread_copy_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,12 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t NRepeat,
index_t KPack,
bool DirectLoad = false,
bool LdsScalarLoadToVgpr = false>
bool ALdsScalarLoadToVgpr = false,
bool BLdsScalarLoadToVgpr = false>
constexpr auto BlockGemmPipeline_Selector()
{
// Supported for Direct Load and V1
if constexpr(LdsScalarLoadToVgpr)
if constexpr(ALdsScalarLoadToVgpr || BLdsScalarLoadToVgpr)
{
static_assert(DirectLoad && BlkGemmPipelineVer == BlockGemmPipelineVersion::v1);
}
Expand Down Expand Up @@ -65,7 +66,8 @@ constexpr auto BlockGemmPipeline_Selector()
MRepeat,
NRepeat,
KPack,
LdsScalarLoadToVgpr>{};
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -759,7 +759,8 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPacks,
bool LdsScalarLoadToVgpr = false>
bool ALdsScalarLoadToVgpr = false,
bool BLdsScalarLoadToVgpr = false>
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1
{
};
Expand All @@ -784,7 +785,8 @@ template <index_t BlockSize,
index_t NRepeat,
index_t KPack,
// ,bool TransposeC //disable transposec right now...
bool LdsScalarLoadToVgpr>
bool ALdsScalarLoadToVgpr,
bool BLdsScalarLoadToVgpr>
struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
ADataType,
Expand All @@ -805,7 +807,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
MRepeat,
NRepeat,
KPack,
LdsScalarLoadToVgpr>
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>
: BlockwiseGemmXdlops_pipeline_base<BlockSize,
ADataType,
BDataType,
Expand All @@ -826,7 +829,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NRepeat,
KPack,
false /*TransposeC*/,
LdsScalarLoadToVgpr>
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>

{
using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
Expand All @@ -849,7 +853,8 @@ struct BlockwiseGemmXdlopsDirectLoad_pipeline_v1<BlockGemmPipelineScheduler::Int
NRepeat,
KPack,
false /*TransposeC*/,
LdsScalarLoadToVgpr>;
ALdsScalarLoadToVgpr,
BLdsScalarLoadToVgpr>;
using Base::I0;
using Base::KRepeat;
using Base::xdlops_gemm;
Expand Down
Loading
Loading