diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h index 901df3a25a46f..7624b5e761e9f 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -112,9 +112,9 @@ class DenseElementsAttr : public Attribute { static DenseElementsAttr get(ShapedType type, ArrayRef values); /// Constructs a dense integer elements attribute from an array of integer - /// or floating-point values. Each value is expected to be the same bitwidth - /// of the element type of 'type'. 'type' must be a vector or tensor with - /// static shape. + /// or floating-point values. Each value is expected to be the same as the + /// storage bitwidth of the element type of 'type'. 'type' must be a vector or + /// tensor with static shape. template ::is_integer || is_valid_cpp_fp_type::value>> diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 112e3f376bd41..b90713e890889 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -1104,13 +1104,15 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type, /// invariants that the templatized 'getValues' method cannot. static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, bool isSigned) { - // Make sure that the data element size is the same as the type element width. - auto denseEltBitWidth = getDenseElementBitWidth(type); - auto dataSize = static_cast(dataEltSize * CHAR_BIT); - if (denseEltBitWidth != dataSize) { - LLVM_DEBUG(llvm::dbgs() << "expected dense element bit width " - << denseEltBitWidth << " to match data size " - << dataSize << " for type " << type << "\n"); + // Make sure that the data element size is the same as the type element + // storage width. + const size_t denseEltStorageBitWidth = getDenseElementStorageWidth(type); + const size_t dataSizeBitWidth = static_cast(dataEltSize * CHAR_BIT); + if (denseEltStorageBitWidth != dataSizeBitWidth) { + LLVM_DEBUG(llvm::dbgs() + << "expected dense element bit width " << denseEltStorageBitWidth + << " to match data size " << dataSizeBitWidth << " for type " + << type << "\n"); return false; } diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index 9203248a83baf..03c08ed22a77f 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -570,4 +570,26 @@ TEST(NonSplattedDenseElementAttrTest, GetNonSplatRawDataI16) { .getNonSplatRawData(), expected); } + +TEST(NonSplattedDenseElementAttrTest, GetFromRawI7) { + constexpr std::size_t numberOfElements = 6; + static constexpr std::array rawValues = {1, 2, 3, + 4, 5, 6}; + + mlir::MLIRContext context; + mlir::OpBuilder b(&context); + + auto values = mlir::DenseElementsAttr::get( + mlir::RankedTensorType::get({numberOfElements}, b.getIntegerType(7)), + ArrayRef(rawValues)); + auto fromRaw = mlir::DenseIntOrFPElementsAttr::getFromRawBuffer( + values.getType(), values.getRawData()); + + EXPECT_EQ(values, fromRaw); + EXPECT_EQ(fromRaw.getElementType(), b.getIntegerType(7)); + EXPECT_EQ(fromRaw.getNumElements(), numberOfElements); + for (auto [fr, e] : llvm::zip_equal(fromRaw.getValues(), rawValues)) { + EXPECT_EQ(fr, e); + } +} } // namespace