Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
fe8fc5f
fix typos
okkwon Nov 4, 2022
a2da08e
WIP: allow unaligned tensor sizes for tensor core
okkwon Nov 7, 2022
c2793ea
filter out ops with unaligned tensor sizes
okkwon Nov 9, 2022
cf7b50d
add a marker for workgroup specialization
okkwon Nov 9, 2022
dab6480
add findAncestorWithMarker(op, marker)
okkwon Nov 9, 2022
9c74271
Do warp distribution only when an op is aligned
okkwon Nov 9, 2022
3e56226
apply thread-level tiling for unaligned matmul op
okkwon Nov 10, 2022
846d488
Do not fail for LLVMGPUMultiBuffering
okkwon Nov 10, 2022
9395678
move aligned and unaligned op filters to Utils
okkwon Nov 10, 2022
e5d4f83
do tensorcore vectorization only when the candidate op is aligned
okkwon Nov 10, 2022
38b5787
Bail out unaligned K dimension from tensorcore specialization
okkwon Nov 10, 2022
8211f99
update the unit tests
okkwon Nov 11, 2022
157d3f1
Add a filter param to populateContractPromotionPatterns()
okkwon Nov 16, 2022
1309942
Enable unaligned K for tensorcore specialization
okkwon Nov 16, 2022
d5c6cb0
Mark tensorcore and SIMT lowering candidates
okkwon Nov 16, 2022
fd4b182
Use the tensorcore marker to serialize the K loop
okkwon Nov 17, 2022
bccc04c
do not handle genericOp in the tensorcore path
okkwon Nov 17, 2022
74a23ee
Do warp-level tiling selectively using a marker
okkwon Nov 17, 2022
04dc39b
do thread-level tiling for the partial K and unspecialized ops
okkwon Nov 17, 2022
b799615
promote inputs and output always
okkwon Nov 18, 2022
3b67ab8
NFC: remove default value from tileToSerialLoops
okkwon Nov 18, 2022
c8daa2b
Remove onlyReduction from tileToSerialLoops and add peel
okkwon Nov 18, 2022
a4cb7b5
Use `vectorize_for_tensorcore` marker
okkwon Nov 18, 2022
6d7305b
match by default for the non-tensorcore flow
okkwon Nov 18, 2022
d7fe191
add missing genericOp support for tensorcore lowering
okkwon Nov 18, 2022
23e68db
add more debug print to LLVMGPUTileAndDistribute
okkwon Nov 20, 2022
1f8c8c9
Do not handle the fused op with warp-level tiling
okkwon Nov 20, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ struct GPUMultiBufferingPass
});
// Apply multi-buffering to all of them.
for (memref::AllocOp alloc : allocs) {
if (failed(memref::multiBuffer(alloc, numBuffers)))
// Stop if any buffer cannot be multi buffered as pipelining will assume
// this happened.
return signalPassFailure();
if (failed(memref::multiBuffer(alloc, numBuffers))) {
// There can be a failing case. Continue processing eligible ones.
continue;
}
}
}

Expand Down
17 changes: 11 additions & 6 deletions compiler/src/iree/compiler/Codegen/Common/GPUPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "iree/compiler/Codegen/Common/GPUPatterns.h"

#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "iree-dialects/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
Expand Down Expand Up @@ -165,8 +166,11 @@ void populateVectorTransferToGPUMMAPreparationPatterns(
patterns.add<FlattenTransferReadOp>(patterns.getContext());
}

using LinalgTransformationFilter = IREE::LinalgExt::LinalgTransformationFilter;

void populateContractPromotionPatterns(RewritePatternSet &patterns,
ArrayRef<int64_t> operandsToPromote) {
ArrayRef<int64_t> operandsToPromote,
LinalgTransformationFilter *filter) {
MLIRContext *context = patterns.getContext();
patterns.insert<LinalgPromotionPattern<linalg::MatmulOp>,
LinalgPromotionPattern<linalg::BatchMatmulOp>,
Expand All @@ -178,11 +182,12 @@ void populateContractPromotionPatterns(RewritePatternSet &patterns,
.setCopyInOutFns(copyToWorkgroupMemory, copyToWorkgroupMemory)
.setOperandsToPromote(operandsToPromote)
.setUseFullTileBuffers({false, false}),
IREE::LinalgExt::LinalgTransformationFilter(
{StringAttr::get(context, getWorkgroupKTiledMarker())},
StringAttr::get(context, getWorkgroupMemoryMarker()))
.setMatchByDefault()
.addFilter(contractOpFilter));
filter ? *filter
: IREE::LinalgExt::LinalgTransformationFilter(
{StringAttr::get(context, getWorkgroupKTiledMarker())},
StringAttr::get(context, getWorkgroupMemoryMarker()))
.setMatchByDefault()
.addFilter(contractOpFilter));
}

} // namespace iree_compiler
Expand Down
7 changes: 5 additions & 2 deletions compiler/src/iree/compiler/Codegen/Common/GPUPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
#ifndef IREE_COMPILER_CODEGEN_COMMON_GPUPATTERNS_H_
#define IREE_COMPILER_CODEGEN_COMMON_GPUPATTERNS_H_

#include "iree-dialects/Dialect/LinalgExt/Passes/Passes.h"
#include "mlir/IR/PatternMatch.h"

namespace mlir {
namespace iree_compiler {

Expand All @@ -18,8 +20,9 @@ void populateVectorTransferToGPUMMAPreparationPatterns(

/// Adds patterns for promoting Linalg contract op's operands to use GPU shared
/// memory.
void populateContractPromotionPatterns(RewritePatternSet &patterns,
ArrayRef<int64_t> operandsToPromote);
void populateContractPromotionPatterns(
RewritePatternSet &patterns, ArrayRef<int64_t> operandsToPromote,
IREE::LinalgExt::LinalgTransformationFilter *filter = nullptr);

} // namespace iree_compiler
} // namespace mlir
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "iree/compiler/Codegen/Interfaces/PartitionableLoopsInterface.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "iree/compiler/Codegen/Passes.h"
#include "iree/compiler/Codegen/Utils/MarkerUtils.h"
#include "iree/compiler/Codegen/Utils/Utils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
Expand Down Expand Up @@ -220,6 +221,7 @@ static void specializeDistributionLoops(

// generate scf.if %cond
auto ifOp = builder.create<scf::IfOp>(loc, cond, /*withElseRegion=*/true);
setMarker(ifOp, getWorkgroupSpecializationMarker());

// Transfer the original body to the scf.else body.
auto origBodyBegin = ++Block::iterator(ifOp);
Expand Down
7 changes: 3 additions & 4 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,8 @@ static LogicalResult setContractConfig(func::FuncOp entryPoint,
// Pick the best configuration where the original shape is aligned on the
// tile size.
for (TileWorkgroupSizePair &config : TCtileSizeConfig) {
if (sizeK % config.tileSize[2] == 0 &&
sizeN % config.tileSize[1] == 0 &&
sizeM % config.tileSize[0] == 0) {
if (sizeK >= config.tileSize[2] && sizeN >= config.tileSize[1] &&
sizeM >= config.tileSize[0]) {
return setMatmulConfig(
config.tileSize[0], config.tileSize[1], config.tileSize[2],
config.workgroupSize,
Expand Down Expand Up @@ -410,7 +409,7 @@ static LogicalResult setRootDefaultConfig(func::FuncOp entryPoint,
}

auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
// Pick a vectorSize of 1 for op that we know won't get vectorizedd.
// Pick a vectorSize of 1 for op that we know won't get vectorized.
// Also skip vectorization for linalg on memref (no result) as the pipeline
// relies on tensor level tiling.
// TODO(thomasraoux): This could be improved by checking if the linalg op
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ struct LLVMGPUTensorAllocPass
auto funcOp = getOperation();

// Tile the reduction first to reduce the alloc size.
if (failed(tileToSerialLoops(funcOp))) {
if (failed(tileToSerialLoops(funcOp, /*peel=*/false))) {
return signalPassFailure();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ extern llvm::cl::opt<bool> llvmgpuUseMMASync;
//====---------------------------------------------------------------------===//

static void populateVectorizationPatterns(RewritePatternSet &patterns) {
IREE::LinalgExt::LinalgTransformationFilter f(
StringAttr::get(patterns.getContext(), getVectorizeMarker()));
IREE::LinalgExt::LinalgTransformationFilter f(StringAttr::get(
patterns.getContext(), getVectorizeForTensorCoreMarker()));
VectorizationPatterns<linalg::FillOp, linalg::GenericOp>::insert(patterns, f);
patterns.add<LinalgVectorizationPattern>(
patterns.getContext(), f.addOpFilter<linalg::ContractionOpInterface>());
Expand Down
175 changes: 160 additions & 15 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUTileAndDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,14 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
Expand Down Expand Up @@ -117,9 +124,10 @@ static void populateTilingToWarpPatterns(
auto getWarpProcInfoFn = [warpPerWorkgroup](
OpBuilder &builder, Location loc,
ArrayRef<Range> parallelLoopRanges) {
return getSubgroupIdsAndCounts(builder, loc, /*warpSize=*/32u,
return getSubgroupIdsAndCounts(builder, loc, kWarpSize,
parallelLoopRanges.size(), warpPerWorkgroup);
};

linalg::LinalgLoopDistributionOptions warpDistributionOptions;
warpDistributionOptions.procInfo = getWarpProcInfoFn;

Expand All @@ -129,17 +137,18 @@ static void populateTilingToWarpPatterns(
.setDistributionOptions(warpDistributionOptions);
MLIRContext *context = patterns.getContext();
IREE::LinalgExt::LinalgTransformationFilter filter(
{StringAttr::get(context, getWorkgroupKTiledMarker()),
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getVectorizeMarker()));
filter.setMatchByDefault();
{StringAttr::get(context, getGPUWarpLevelTilingReqMarker())},
StringAttr::get(context, getVectorizeForTensorCoreMarker()));
TilingPatterns<linalg::MatmulOp, linalg::FillOp, linalg::BatchMatmulOp,
linalg::GenericOp>::insert(patterns, tilingOptions, filter);
}

using FilterFunction = std::function<LogicalResult(Operation *)>;

/// Patterns for thread level tiling.
static void populateTilingToInvocationPatterns(
RewritePatternSet &patterns, SmallVectorImpl<int64_t> &workgroupSize) {
RewritePatternSet &patterns, SmallVectorImpl<int64_t> &workgroupSize,
bool matchByDefault = true) {
linalg::TileSizeComputationFunction getInnerTileSizeFn =
[&](OpBuilder &builder, Operation *operation) {
return calculateDistributedTileSize(workgroupSize, builder, operation);
Expand All @@ -162,17 +171,107 @@ static void populateTilingToInvocationPatterns(
MLIRContext *context = patterns.getContext();
IREE::LinalgExt::LinalgTransformationFilter f(
{StringAttr::get(context, getWorkgroupKTiledMarker()),
StringAttr::get(context, getWorkgroupMemoryMarker())},
StringAttr::get(context, getWorkgroupMemoryMarker()),
StringAttr::get(context, getGPUSimtLoweringReqMarker())},
StringAttr::get(context, getVectorizeMarker()));
f.addFilter([](Operation *op) {
// FFT doesn't support second level of tiling yet.
return success(!isa<IREE::LinalgExt::FftOp>(op));
}).setMatchByDefault();
// FFT doesn't support second level of tiling yet.
return success(!isa<IREE::LinalgExt::FftOp>(op));
});
if (matchByDefault) f.setMatchByDefault();
patterns.insert<IREE::LinalgExt::LinalgTilingPattern,
IREE::LinalgExt::TilingInterfaceTilingPattern>(
context, tilingOptions, f);
}

/// Tags every matmul-like op in `funcOp` with the lowering path it should
/// take: the tensor-core marker when the op passes `alignedOpFilter`, and the
/// SIMT marker otherwise. Later tiling/vectorization stages key off these
/// markers to pick the right patterns.
static void markCandidates(func::FuncOp funcOp) {
  funcOp.walk([](linalg::LinalgOp linalgOp) {
    // Only matmul-shaped ops are candidates for either path.
    if (!isa<linalg::MatmulOp, linalg::BatchMatmulOp>(linalgOp))
      return WalkResult::skip();

    if (failed(alignedOpFilter(linalgOp))) {
      // Unaligned shapes fall back to the SIMT lowering.
      setMarker(linalgOp, getGPUSimtLoweringReqMarker());
    } else {
      setMarker(linalgOp, getGPUTensorCoreLoweringReqMarker());
    }
    return WalkResult::advance();
  });
}

/// Tiles the reduction (K) dimension of ops marked for tensor-core lowering
/// into serial scf.for loops, then runs tiling canonicalizations. On match,
/// the tensor-core marker is replaced with the workgroup-K-tiled marker.
static LogicalResult tileTensorCoreKDim(func::FuncOp funcOp) {
  // Mark which linalg ops take the tensor-core path vs. the SIMT path.
  markCandidates(funcOp);

  auto context = funcOp.getContext();
  RewritePatternSet patterns(context);
  // Tile-size callback: start from the level-0 (workgroup) tile sizes and
  // zero out every partitionable (parallel) loop so that only the reduction
  // dimension(s) actually get tiled here.
  auto tileSizesFn = [](OpBuilder &builder,
                        Operation *op) -> SmallVector<Value, 4> {
    auto interfaceOp = cast<PartitionableLoopsInterface>(*op);
    auto partitionedLoops =
        interfaceOp.getPartitionableLoops(kNumMaxParallelDims);
    SmallVector<Value, 4> tileSizes = getTileSizes(builder, op, 0);
    auto zero = builder.create<arith::ConstantIndexOp>(op->getLoc(), 0);
    for (unsigned depth : partitionedLoops) {
      if (depth < tileSizes.size()) {
        tileSizes[depth] = zero;
      }
    }
    return tileSizes;
  };

  auto tilingOptions =
      linalg::LinalgTilingOptions()
          .setLoopType(linalg::LinalgTilingLoopType::Loops)
          .setTileSizeComputationFunction(tileSizesFn)
          .setPeeledLoops({0});  // peel off the partial iterations

  // Only rewrite ops carrying the tensor-core marker; tag the result with the
  // workgroup-K-tiled marker so the op is not tiled again.
  IREE::LinalgExt::LinalgTransformationFilter filter(
      ArrayRef<StringAttr>{
          StringAttr::get(context, getGPUTensorCoreLoweringReqMarker())},
      StringAttr::get(context, getWorkgroupKTiledMarker()));

  TilingPatterns<linalg::MatmulOp, linalg::BatchMatmulOp>::insert(
      patterns, tilingOptions, filter);

  if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns)))) {
    return failure();
  }

  // Clean up affine.min/scf.for artifacts produced by tiling and peeling.
  RewritePatternSet wgTilingCanonicalizationPatterns =
      linalg::getLinalgTilingCanonicalizationPatterns(funcOp.getContext());
  populateAffineMinSCFCanonicalizationPattern(wgTilingCanonicalizationPatterns);
  scf::populateSCFForLoopCanonicalizationPatterns(
      wgTilingCanonicalizationPatterns);
  if (failed(applyPatternsAndFoldGreedily(
          funcOp, std::move(wgTilingCanonicalizationPatterns)))) {
    return failure();
  }

  return success();
}

/// Returns the static K (reduction) dimension size of a matmul-like op by
/// locating the LHS dimension mapped to the single reduction loop. Returns
/// ShapedType::kDynamicSize for non-matmul ops, multi-reduction ops, or when
/// no LHS dimension maps to the reduction loop.
static int64_t getSizeK(linalg::LinalgOp op) {
  if (!isa<linalg::BatchMatmulOp, linalg::MatmulOp>(op))
    return ShapedType::kDynamicSize;

  SmallVector<unsigned> reductionDims;
  op.getReductionDims(reductionDims);
  // A matmul is expected to have exactly one reduction loop.
  if (reductionDims.size() != 1) return ShapedType::kDynamicSize;

  OpOperand *lhs = op.getDpsInputOperand(0);
  auto lhsShape = lhs->get().getType().cast<ShapedType>().getShape();
  auto lhsMap = op.getMatchingIndexingMap(lhs);
  // Find the LHS dimension indexed by the reduction loop.
  for (unsigned i = 0, e = lhsShape.size(); i < e; ++i) {
    if (lhsMap.getDimPosition(i) == reductionDims[0]) return lhsShape[i];
  }
  return ShapedType::kDynamicSize;
}

namespace {
struct LLVMGPUTileAndDistributePass
: public LLVMGPUTileAndDistributeBase<LLVMGPUTileAndDistributePass> {
Expand All @@ -191,7 +290,7 @@ struct LLVMGPUTileAndDistributePass
auto funcOp = getOperation();
if (!isEntryPoint(funcOp)) return;

// Promote C matrix and propagate the potential fill producer into the temp
// Promote C matrix and propagate the potential fill producer into the temp
// allocation. This needs to be done before reduction tiling.
{
RewritePatternSet promotionPatterns(&getContext());
Expand All @@ -200,13 +299,24 @@ struct LLVMGPUTileAndDistributePass
std::move(promotionPatterns)))) {
return signalPassFailure();
}
LLVM_DEBUG({
llvm::dbgs() << "After promote C:\n";
funcOp.dump();
});

propagateSharedMemoryCopy(funcOp);

LLVM_DEBUG({
llvm::dbgs() << "After propagateSharedMemoryCopy():\n";
funcOp.dump();
});
}

// Tile again at the workgroup level since reduction dimension were
// ignored. Dimensions already tiled will be ignore since we tile to the
// same size.
if (failed(tileToSerialLoops(funcOp))) {
// same size. For distributing to warps, peel the partial iterations as
// a separate loop, since the warp distribution is requested for wmma.
if (failed(tileToSerialLoops(funcOp, /*peel=*/distributeToWarp))) {
return signalPassFailure();
}

Expand All @@ -226,7 +336,6 @@ struct LLVMGPUTileAndDistributePass
RewritePatternSet promotionPatterns(&getContext());

populateContractPromotionPatterns(promotionPatterns, {0, 1});

if (failed(applyPatternsAndFoldGreedily(funcOp,
std::move(promotionPatterns)))) {
return signalPassFailure();
Expand All @@ -250,14 +359,50 @@ struct LLVMGPUTileAndDistributePass
});

if (distributeToWarp) {
// Apply last level of tiling and distribute to warps.
// mark candidates for the warp level tiling
funcOp.walk([&](linalg::LinalgOp op) {
if (failed(alignedOpFilter(op))) return WalkResult::skip();
if (!isa<linalg::BatchMatmulOp, linalg::MatmulOp, linalg::FillOp,
linalg::GenericOp>(op))
return WalkResult::skip();

if (isa<linalg::GenericOp>(op) &&
hasMarker(op, getCopyToWorkgroupMemoryMarker())) {
// The GPUDistributeSharedMemoryCopy pass will handle it later.
return WalkResult::skip();
}

// check if K is a multiple of Tile-K.
int64_t sizeK = getSizeK(op);
if (sizeK != ShapedType::kDynamicSize) {
// WG tile sizes
auto wgTileSizes = getTileSizes(op, 0);

if (sizeK % wgTileSizes[wgTileSizes.size() - 1] != 0)
return WalkResult::skip();
}

setMarker(op, getGPUWarpLevelTilingReqMarker());
return WalkResult::advance();
});

// Apply last level of tiling and distribute to warps for aligned ops.
RewritePatternSet warpLevelTilingPatterns(context);
populateTilingToWarpPatterns(warpLevelTilingPatterns, workgroupSize);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(warpLevelTilingPatterns)))) {
return signalPassFailure();
}

// Apply last level of tiling and distribute to threads for unaligned ops.
RewritePatternSet threadLevelTilingPatterns(context);
populateTilingToInvocationPatterns(threadLevelTilingPatterns,
workgroupSize,
/*matchByDefault=*/false);
if (failed(applyPatternsAndFoldGreedily(
funcOp, std::move(threadLevelTilingPatterns)))) {
return signalPassFailure();
}
} else {
// Apply last level of tiling and distribute to threads.
RewritePatternSet threadLevelTilingPatterns(context);
Expand Down
Loading