diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index 4932ce87d57b7..3cf48e66ec5da 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -2097,6 +2097,44 @@ struct ReduceConstantOptimization : public OpRewritePattern { const bool aggressiveReduceConstant; }; +struct TosaFoldConstantConcat : TosaFoldConstantBase { + using TosaFoldConstantBase::TosaFoldConstantBase; + + LogicalResult matchAndRewrite(tosa::ConcatOp op, + PatternRewriter &rewriter) const override { + auto inputsRange = op.getInput1(); + auto axis = op.getAxis(); + + // TODO: Matching constraints + + // collect all inputvalues + SmallVector inputValuesArr; + SmallVector inputTypesArr; + for (auto input : inputsRange) { + DenseElementsAttr inputValues; + if (!matchPattern(input, m_Constant(&inputValues))) { + return failure(); + } + inputValuesArr.push_back(inputValues); + inputTypesArr.push_back(cast(input.getType())); + } + + // compute result + auto result = this->concat(inputValuesArr, inputTypesArr, axis); + + rewriter.replaceOpWithNewOp(op, result.first, result.second); + } + + std::pair + concat(SmallVector &inputValuesArr, + SmallVector &inputTypesArr, uint32_t axis) const { + auto baseType = inputTypesArr[0].getElementType(); + switch (dyn_cast(baseType).getWidth()) { + // TODO: + } + } +} + } // namespace void mlir::tosa::populateTosaFoldConstantPatterns( @@ -2136,6 +2174,7 @@ void mlir::tosa::populateTosaFoldConstantPatterns( patterns.add(ctx, options.foldSplatOrSingleUseOnly); patterns.add(ctx, options.foldSplatOrSingleUseOnly); patterns.add(ctx, options.foldSplatOrSingleUseOnly); + patterns.add(ctx, options.foldSplatOrSingleUseOnly); if (options.enableTileFolding) patterns.add(ctx, options.foldSplatOrSingleUseOnly); }