diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index cb87c9279b575..dbc4dd63e067d 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1693,6 +1693,14 @@ OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) { if (getAxis() != producer.getAxis()) continue; + // Only inline the producer concat when this op is its sole user. If any + // other op also uses it, the producer must still be materialized, so + // inlining its operands here would just repeat the concatenation. + const bool allConcatUsersAreThisConcat = llvm::all_of( + producer->getUsers(), [&](Operation *user) { return *this == user; }); + if (!allConcatUsersAreThisConcat) + continue; + // Replace the original operand with all incoming operands foundFoldableConcat = true; concatOperands.pop_back(); diff --git a/mlir/test/Dialect/Tosa/fold_concats.mlir b/mlir/test/Dialect/Tosa/fold_concats.mlir index adc5875d943b0..0e82782df4d96 100644 --- a/mlir/test/Dialect/Tosa/fold_concats.mlir +++ b/mlir/test/Dialect/Tosa/fold_concats.mlir @@ -62,6 +62,43 @@ func.func @nested_fold(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x8x7x7xf32> { // ----- +func.func @concat_multiple_users(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> (tensor<1x3x7x7xf32>, tensor<1x2x7x7xf32>) { + %tmp = tensor.empty() : tensor<1x1x7x7xf32> + %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> + %1 = tosa.concat %tmp, %0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x3x7x7xf32> + %2 = tosa.add %0, %0 : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32> + return %1, %2 : tensor<1x3x7x7xf32>, tensor<1x2x7x7xf32> +} + +// CHECK-LABEL: func.func @concat_multiple_users +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>, [[PARAM_1_:%.+]]: tensor<1x1x7x7xf32>) -> (tensor<1x3x7x7xf32>, 
tensor<1x2x7x7xf32>) { +// CHECK-DAG: [[VAR_0_:%.+]] = tensor.empty() : tensor<1x1x7x7xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> +// CHECK-NOT: separator of consecutive DAGs +// CHECK-DAG: [[VAR_2_:%.+]] = tosa.concat [[VAR_0_]], [[VAR_1_]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x3x7x7xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = tosa.add [[VAR_1_]], [[VAR_1_]] : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32> +// CHECK: return [[VAR_2_]], [[VAR_3_]] : tensor<1x3x7x7xf32>, tensor<1x2x7x7xf32> +// CHECK: } + +// ----- + +func.func @concat_diamond_shape(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>, %arg2: tensor<1x1x7x7xf32>, %arg3: tensor<1x1x7x7xf32>) -> tensor<1x6x7x7xf32> { + %tmp = tensor.empty() : tensor<1x1x7x7xf32> + %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> + %1 = tosa.concat %0, %arg2 {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x3x7x7xf32> + %2 = tosa.concat %0, %arg3 {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x3x7x7xf32> + %3 = tosa.concat %1, %2 {axis = 1 : i32} : (tensor<1x3x7x7xf32>, tensor<1x3x7x7xf32>) -> tensor<1x6x7x7xf32> + return %3 : tensor<1x6x7x7xf32> +} + +// CHECK-LABEL: func.func @concat_diamond_shape +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x7x7xf32>, [[PARAM_1_:%.+]]: tensor<1x1x7x7xf32>, [[PARAM_2_:%.+]]: tensor<1x1x7x7xf32>, [[PARAM_3_:%.+]]: tensor<1x1x7x7xf32>) -> tensor<1x6x7x7xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_1_]], [[PARAM_2_]], [[PARAM_0_]], [[PARAM_1_]], [[PARAM_3_]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x6x7x7xf32> +// CHECK: return [[VAR_0_]] : tensor<1x6x7x7xf32> +// CHECK: 
} + +// ----- + func.func @wide_fold(%arg0: tensor<1x1x7x7xf32>, %arg1: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { %0 = tosa.concat %arg0, %arg0 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> %1 = tosa.concat %arg1, %arg1 {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> @@ -91,4 +128,4 @@ func.func @partially_foldable(%arg0: tensor<1x1x8x8xf32>, %arg1: tensor<1x2x4x8x // CHECK: [[VAR_1_:%.+]] = tosa.tile [[PARAM_1_]], [[VAR_0_]] : (tensor<1x2x4x8xf32>, !tosa.shape<4>) -> tensor<1x2x8x8xf32> // CHECK: [[VAR_2_:%.+]] = tosa.concat [[PARAM_0_]], [[PARAM_0_]], [[VAR_1_]] {axis = 1 : i32} : (tensor<1x1x8x8xf32>, tensor<1x1x8x8xf32>, tensor<1x2x8x8xf32>) -> tensor<1x4x8x8xf32> // CHECK: return [[VAR_2_]] : tensor<1x4x8x8xf32> -// CHECK: } \ No newline at end of file +// CHECK: }