Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2097,6 +2097,120 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
const bool aggressiveReduceConstant;
};

// Concatenates the element values of `inputAttrs` along `concatAxis` and
// materializes them as a single DenseElementsAttr of type `outputType`.
//
// `ElementStorageType` selects the element representation and must be APInt
// (integer tensors) or APFloat (float tensors). `elementType` is only
// consulted in the APFloat case, where it supplies the float semantics used
// to pre-fill the result buffer.
template <typename ElementStorageType>
DenseElementsAttr
concatenateAttrs(const ShapedType outputType, ArrayRef<ElementsAttr> inputAttrs,
                 const uint32_t concatAxis, PatternRewriter &rewriter,
                 const Type elementType) {

  static_assert(std::is_same<ElementStorageType, APInt>::value ||
                    std::is_same<ElementStorageType, APFloat>::value,
                "ElementStorageType must be either APInt or APFloat");

  // Allocate the destination buffer up front. APInt may be overwritten
  // element-wise without initialization; APFloat has no default constructor,
  // so seed the buffer with zeros of the proper float semantics.
  SmallVector<ElementStorageType> mergedValues;
  if constexpr (std::is_same<ElementStorageType, APInt>::value) {
    mergedValues.resize_for_overwrite(outputType.getNumElements());
  } else {
    mergedValues.resize(
        outputType.getNumElements(),
        APFloat::getZero(cast<FloatType>(elementType).getFloatSemantics()));
  }

  const ArrayRef<int64_t> outputShape = outputType.getShape();

  // Each input occupies a contiguous slab of the concatenation axis starting
  // at `axisOffset`; scatter its elements into the output buffer.
  int64_t axisOffset = 0;
  for (const ElementsAttr &attr : inputAttrs) {
    const ArrayRef<int64_t> srcShape =
        cast<ShapedType>(attr.getType()).getShape();

    int64_t srcLinearIdx = 0;
    for (const ElementStorageType &value :
         attr.getValues<ElementStorageType>()) {
      // TODO: Could be optimized to work on slices instead of single value
      SmallVector<int64_t> dstIndex = offsetToIndex(srcShape, srcLinearIdx);
      dstIndex[concatAxis] += axisOffset;
      mergedValues[indexToOffset(outputShape, dstIndex)] = value;
      ++srcLinearIdx;
    }
    axisOffset += srcShape[concatAxis];
  }

  return DenseElementsAttr::get(outputType, mergedValues);
}

/// Folds a tosa.concat whose inputs are all compile-time constants into a
/// single tosa.const holding the concatenated values.
struct TosaFoldConstantConcat : public TosaFoldConstantBase<tosa::ConcatOp> {
  using TosaFoldConstantBase::TosaFoldConstantBase;

  LogicalResult matchAndRewrite(tosa::ConcatOp op,
                                PatternRewriter &rewriter) const override {
    auto inputs = op->getOperands();
    const uint32_t concatAxis = op.getAxis();
    const auto outputType = cast<ShapedType>(op.getType());
    if (!outputType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(
          op, "Output type must have static shape for concat folding.");
    }
    if (llvm::any_of(inputs, [](Value v) {
          return !cast<ShapedType>(v.getType()).hasStaticShape();
        })) {
      return rewriter.notifyMatchFailure(
          op, "All inputs to ConcatOp must have static shape for folding.");
    }

    const Type elementType = outputType.getElementType();
    if (!elementType.isIntOrIndexOrFloat()) {
      // Sanity check, this should always be the case
      return rewriter.notifyMatchFailure(
          op, "Output element type must be int, index, or float for folding.");
    }
    // Dispatch on FloatType so that both integer and index element types take
    // the APInt path: DenseElementsAttr stores index elements as APInt.
    // Discriminating on IntegerType instead would send index-typed tensors
    // down the float path and assert on cast<FloatType>.
    const bool isFloatElement = isa<FloatType>(elementType);

    SmallVector<ElementsAttr> inputAttrs;
    inputAttrs.reserve(inputs.size());

    for (Value inputVal : inputs) {
      ElementsAttr inputAsAttr;
      if (!matchPattern(inputVal, m_Constant(&inputAsAttr))) {
        // TODO: This could be extended to handle partial non-const inputs
        return rewriter.notifyMatchFailure(
            op, "All inputs to ConcatOp must be constant for folding.");
      }

      // Normalize splats to DenseElementsAttr so concatenateAttrs can read
      // them uniformly via getValues<>.
      if (inputAsAttr.isSplat()) {
        const ShapedType inputType = cast<ShapedType>(inputAsAttr.getType());
        if (isFloatElement) {
          inputAsAttr = DenseElementsAttr::get(
              inputType, inputAsAttr.getSplatValue<APFloat>());
        } else {
          inputAsAttr = DenseElementsAttr::get(
              inputType, inputAsAttr.getSplatValue<APInt>());
        }
      }
      // Heuristic: folding a multi-use non-splat constant would duplicate
      // potentially large data, so only fold splats or single-use inputs.
      if (foldSplatOrSingleUseOnly && !inputVal.hasOneUse() &&
          !inputAsAttr.isSplat()) {
        return rewriter.notifyMatchFailure(
            op, "Concat folding heuristic: non-splat constant inputs must have "
                "only a single use.");
      }
      inputAttrs.push_back(inputAsAttr);
    }

    DenseElementsAttr resultAttr;
    if (isFloatElement) {
      resultAttr = concatenateAttrs<APFloat>(outputType, inputAttrs, concatAxis,
                                             rewriter, elementType);
    } else {
      // TODO: This could be optimized to not go to APInt if the int size
      // matches c++ native types
      resultAttr = concatenateAttrs<APInt>(outputType, inputAttrs, concatAxis,
                                           rewriter, elementType);
    }

    assert(resultAttr && "Result attribute should not be null.");

    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, outputType, resultAttr);
    return success();
  }
};

} // namespace

void mlir::tosa::populateTosaFoldConstantPatterns(
Expand Down Expand Up @@ -2136,6 +2250,7 @@ void mlir::tosa::populateTosaFoldConstantPatterns(
patterns.add<TosaFoldConstantPad>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantSlice>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantMatMul>(ctx, options.foldSplatOrSingleUseOnly);
patterns.add<TosaFoldConstantConcat>(ctx, options.foldSplatOrSingleUseOnly);
if (options.enableTileFolding)
patterns.add<TosaFoldConstantTile>(ctx, options.foldSplatOrSingleUseOnly);
}
Expand Down
124 changes: 124 additions & 0 deletions mlir/test/Dialect/Tosa/constant-concat.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// RUN: mlir-opt --tosa-layerwise-constant-fold %s | FileCheck %s

// Two non-splat 2x2 i32 constants concatenated along axis 0 fold to one
// 4x2 constant.
// CHECK-LABEL: func.func @concat_i32_axis0
// CHECK-SAME: () -> tensor<4x2xi32> {
// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[1, 2], [3, 4], [5, 6], [7, 8]{{.}}> : tensor<4x2xi32>}> : () -> tensor<4x2xi32>
// CHECK: return [[VAR_0_]] : tensor<4x2xi32>
func.func @concat_i32_axis0() -> (tensor<4x2xi32>) {
%c0 = "tosa.const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
%c1 = "tosa.const"() {value = dense<[[5, 6], [7, 8]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
%0 = "tosa.concat"(%c0, %c1) {axis = 0 : i32} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<4x2xi32>
return %0 : tensor<4x2xi32>
}

// Float constants with differing axis-1 extents (2x2 + 2x1) fold to a
// single 2x3 f32 constant.
// CHECK-LABEL: func.func @concat_f32_axis1
// CHECK-SAME: () -> tensor<2x3xf32> {
// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[1.000000e+00, 2.000000e+00, 3.000000e+00], [4.000000e+00, 5.000000e+00, 6.000000e+00]{{.}}> : tensor<2x3xf32>}> : () -> tensor<2x3xf32>
// CHECK: return [[VAR_0_]] : tensor<2x3xf32>
func.func @concat_f32_axis1() -> (tensor<2x3xf32>) {
%c0 = "tosa.const"() {value = dense<[[1.0, 2.0], [4.0, 5.0]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32>
%c1 = "tosa.const"() {value = dense<[[3.0], [6.0]]> : tensor<2x1xf32>} : () -> tensor<2x1xf32>
%0 = "tosa.concat"(%c0, %c1) {axis = 1 : i32} : (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2x3xf32>
return %0 : tensor<2x3xf32>
}

// More than two operands: three i8 constants (1x2, 1x1, 1x2) fold to one
// 1x5 constant along axis 1.
// CHECK-LABEL: func.func @concat_i8_three_inputs_axis1
// CHECK-SAME: () -> tensor<1x5xi8> {
// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[1, 2, 3, 4, 5]{{.}}> : tensor<1x5xi8>}> : () -> tensor<1x5xi8>
// CHECK: return [[VAR_0_]] : tensor<1x5xi8>
func.func @concat_i8_three_inputs_axis1() -> (tensor<1x5xi8>) {
%c0 = "tosa.const"() {value = dense<[[1, 2]]> : tensor<1x2xi8>} : () -> tensor<1x2xi8>
%c1 = "tosa.const"() {value = dense<[[3]]> : tensor<1x1xi8>} : () -> tensor<1x1xi8>
%c2 = "tosa.const"() {value = dense<[[4, 5]]> : tensor<1x2xi8>} : () -> tensor<1x2xi8>
%0 = "tosa.concat"(%c0, %c1, %c2) {axis = 1 : i32} : (tensor<1x2xi8>, tensor<1x1xi8>, tensor<1x2xi8>) -> tensor<1x5xi8>
return %0 : tensor<1x5xi8>
}

// A splat input (dense<7> : 2x1) mixed with a dense input folds; the splat
// is expanded into the concatenated result.
// CHECK-LABEL: func.func @concat_i32_with_splat_axis0
// CHECK-SAME: () -> tensor<3x1xi32> {
// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[7], [7], [8]{{.}}> : tensor<3x1xi32>}> : () -> tensor<3x1xi32>
// CHECK: return [[VAR_0_]] : tensor<3x1xi32>
func.func @concat_i32_with_splat_axis0() -> (tensor<3x1xi32>) {
%c0 = "tosa.const"() {value = dense<7> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
%c1 = "tosa.const"() {value = dense<[[8]]> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
%0 = "tosa.concat"(%c0, %c1) {axis = 0 : i32} : (tensor<2x1xi32>, tensor<1x1xi32>) -> tensor<3x1xi32>
return %0 : tensor<3x1xi32>
}

// Boolean (i1) element type: 2x1 + 2x1 concatenated along axis 1 into 2x2.
// NOTE: renamed from @concat_bool_axis0 — the op uses axis = 1, so the old
// name contradicted the behavior under test.
// CHECK-LABEL: func.func @concat_bool_axis1
// CHECK-SAME: () -> tensor<2x2xi1> {
// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[true, false], [false, true]{{.}}> : tensor<2x2xi1>}> : () -> tensor<2x2xi1>
// CHECK: return [[VAR_0_]] : tensor<2x2xi1>
func.func @concat_bool_axis1() -> (tensor<2x2xi1>) {
%c0 = "tosa.const"() {value = dense<[[true], [false]]> : tensor<2x1xi1>} : () -> tensor<2x1xi1>
%c1 = "tosa.const"() {value = dense<[[false], [true]]> : tensor<2x1xi1>} : () -> tensor<2x1xi1>
%0 = "tosa.concat"(%c0, %c1) {axis = 1 : i32} : (tensor<2x1xi1>, tensor<2x1xi1>) -> tensor<2x2xi1>
return %0 : tensor<2x2xi1>
}

// Rank-1 tensors: 3 + 2 elements fold into a single 5-element constant.
// CHECK-LABEL: func.func @concat_rank1_i32_axis0
// CHECK-SAME: () -> tensor<5xi32> {
// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[1, 2, 3, 4, 5]> : tensor<5xi32>}> : () -> tensor<5xi32>
// CHECK: return [[VAR_0_]] : tensor<5xi32>
func.func @concat_rank1_i32_axis0() -> (tensor<5xi32>) {
%c0 = "tosa.const"() {value = dense<[1, 2, 3]> : tensor<3xi32>} : () -> tensor<3xi32>
%c1 = "tosa.const"() {value = dense<[4, 5]> : tensor<2xi32>} : () -> tensor<2xi32>
%0 = "tosa.concat"(%c0, %c1) {axis = 0 : i32} : (tensor<3xi32>, tensor<2xi32>) -> tensor<5xi32>
return %0 : tensor<5xi32>
}

// One operand is empty (0x2): concatenation folds to just the non-empty
// operand's values.
// CHECK-LABEL: func.func @concat_empty_tensor_axis0
// CHECK-SAME: () -> tensor<2x2xi32> {
// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[1, 2], [3, 4]{{.}}> : tensor<2x2xi32>}> : () -> tensor<2x2xi32>
// CHECK: return [[VAR_0_]] : tensor<2x2xi32>
func.func @concat_empty_tensor_axis0() -> (tensor<2x2xi32>) {
%c0 = "tosa.const"() {value = dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
%c1 = "tosa.const"() {value = dense<> : tensor<0x2xi32>} : () -> tensor<0x2xi32>
%0 = "tosa.concat"(%c0, %c1) {axis = 0 : i32} : (tensor<2x2xi32>, tensor<0x2xi32>) -> tensor<2x2xi32>
return %0 : tensor<2x2xi32>
}

// Degenerate case: every operand has zero elements; the fold produces an
// empty constant of the result type.
// CHECK-LABEL: func.func @concat_all_empty_tensors_axis1
// CHECK-SAME: () -> tensor<2x0xi32> {
// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<> : tensor<2x0xi32>}> : () -> tensor<2x0xi32>
// CHECK: return [[VAR_0_]] : tensor<2x0xi32>
func.func @concat_all_empty_tensors_axis1() -> (tensor<2x0xi32>) {
%c0 = "tosa.const"() {value = dense<> : tensor<2x0xi32>} : () -> tensor<2x0xi32>
%c1 = "tosa.const"() {value = dense<> : tensor<2x0xi32>} : () -> tensor<2x0xi32>
%0 = "tosa.concat"(%c0, %c1) {axis = 1 : i32} : (tensor<2x0xi32>, tensor<2x0xi32>) -> tensor<2x0xi32>
return %0 : tensor<2x0xi32>
}

// Splat / dense / splat interleaving along axis 1: the splat columns (1 and
// 2) surround the dense block in the folded 2x4 result.
// CHECK-LABEL: func.func @concat_i32_axis1_three_inputs_two_splats
// CHECK-SAME: () -> tensor<2x4xi32> {
// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[1, 10, 11, 2], [1, 12, 13, 2]{{.}}> : tensor<2x4xi32>}> : () -> tensor<2x4xi32>
// CHECK: return [[VAR_0_]] : tensor<2x4xi32>
func.func @concat_i32_axis1_three_inputs_two_splats() -> (tensor<2x4xi32>) {
%c0_splat = "tosa.const"() {value = dense<1> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
%c1_dense = "tosa.const"() {value = dense<[[10, 11], [12, 13]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
%c2_splat = "tosa.const"() {value = dense<2> : tensor<2x1xi32>} : () -> tensor<2x1xi32>
%0 = "tosa.concat"(%c0_splat, %c1_dense, %c2_splat) {axis = 1 : i32} : (tensor<2x1xi32>, tensor<2x2xi32>, tensor<2x1xi32>) -> tensor<2x4xi32>
return %0 : tensor<2x4xi32>
}

// Unsigned integer element type (ui16): 2x2 + 1x2 fold into 3x2 along axis 0.
// CHECK-LABEL: func.func @concat_ui16_axis0
// CHECK-SAME: () -> tensor<3x2xui16> {
// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}[100, 200], [300, 400], [500, 600]{{.}}> : tensor<3x2xui16>}> : () -> tensor<3x2xui16>
// CHECK: return [[VAR_0_]] : tensor<3x2xui16>
func.func @concat_ui16_axis0() -> (tensor<3x2xui16>) {
%c0 = "tosa.const"() {value = dense<[[100, 200], [300, 400]]> : tensor<2x2xui16>} : () -> tensor<2x2xui16>
%c1 = "tosa.const"() {value = dense<[[500, 600]]> : tensor<1x2xui16>} : () -> tensor<1x2xui16>
%0 = "tosa.concat"(%c0, %c1) {axis = 0 : i32} : (tensor<2x2xui16>, tensor<1x2xui16>) -> tensor<3x2xui16>
return %0 : tensor<3x2xui16>
}

// Rank-3 bf16 tensors concatenated along a middle axis (axis 1):
// 2x1x2 + 2x2x2 fold into 2x3x2, exercising the multi-dim index remapping.
// CHECK-LABEL: func.func @concat_3d_bf16_axis1
// CHECK-SAME: () -> tensor<2x3x2xbf16> {
// CHECK: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<{{.}}{{.}}[1.000000e+00, 2.000000e+00], [5.000000e+00, 6.000000e+00], [7.000000e+00, 8.000000e+00]{{.}}, {{.}}[3.000000e+00, 4.000000e+00], [9.000000e+00, 1.000000e+01], [1.100000e+01, 1.200000e+01]{{.}}{{.}}> : tensor<2x3x2xbf16>}> : () -> tensor<2x3x2xbf16>
// CHECK: return [[VAR_0_]] : tensor<2x3x2xbf16>
func.func @concat_3d_bf16_axis1() -> (tensor<2x3x2xbf16>) {
%c0 = "tosa.const"() {value = dense<[[[1.0, 2.0]], [[3.0, 4.0]]]> : tensor<2x1x2xbf16>} : () -> tensor<2x1x2xbf16>
%c1 = "tosa.const"() {value = dense<[[[5.0, 6.0], [7.0, 8.0]], [[9.0, 10.0], [11.0, 12.0]]]> : tensor<2x2x2xbf16>} : () -> tensor<2x2x2xbf16>
%0 = "tosa.concat"(%c0, %c1) {axis = 1 : i32} : (tensor<2x1x2xbf16>, tensor<2x2x2xbf16>) -> tensor<2x3x2xbf16>
return %0 : tensor<2x3x2xbf16>
}