diff --git a/cpp/src/arrow/extension/fixed_shape_tensor.cc b/cpp/src/arrow/extension/fixed_shape_tensor.cc index 5be855ffcb1d..544616988746 100644 --- a/cpp/src/arrow/extension/fixed_shape_tensor.cc +++ b/cpp/src/arrow/extension/fixed_shape_tensor.cc @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +#include #include #include @@ -109,8 +110,8 @@ Result> FixedShapeTensorType::Deserialize( return Status::Invalid("Expected FixedSizeList storage type, got ", storage_type->ToString()); } - auto value_type = - internal::checked_pointer_cast(storage_type)->value_type(); + auto fsl_type = internal::checked_pointer_cast(storage_type); + auto value_type = fsl_type->value_type(); rj::Document document; if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() || !document.IsObject() || !document.HasMember("shape") || @@ -119,21 +120,45 @@ Result> FixedShapeTensorType::Deserialize( } std::vector shape; - for (auto& x : document["shape"].GetArray()) { + for (const auto& x : document["shape"].GetArray()) { + if (!x.IsInt64()) { + return Status::Invalid("shape must contain integers, got ", + internal::JsonTypeName(x)); + } shape.emplace_back(x.GetInt64()); } + std::vector permutation; if (document.HasMember("permutation")) { - for (auto& x : document["permutation"].GetArray()) { + const auto& json_permutation = document["permutation"]; + if (!json_permutation.IsArray()) { + return Status::Invalid("permutation must be an array, got ", + internal::JsonTypeName(json_permutation)); + } + for (const auto& x : json_permutation.GetArray()) { + if (!x.IsInt64()) { + return Status::Invalid("permutation must contain integers, got ", + internal::JsonTypeName(x)); + } permutation.emplace_back(x.GetInt64()); } if (shape.size() != permutation.size()) { return Status::Invalid("Invalid permutation"); } + RETURN_NOT_OK(internal::IsPermutationValid(permutation)); } std::vector dim_names; if (document.HasMember("dim_names")) { - for (auto& x : document["dim_names"].GetArray()) { + const auto& json_dim_names = document["dim_names"]; + if (!json_dim_names.IsArray()) { + return Status::Invalid("dim_names must be an array, got ", + internal::JsonTypeName(json_dim_names)); + } + for (const auto& x : json_dim_names.GetArray()) { + if (!x.IsString()) { + return Status::Invalid("dim_names must contain strings, got ", + internal::JsonTypeName(x)); + } dim_names.emplace_back(x.GetString()); } if (shape.size() != dim_names.size()) { @@ -141,7 +166,20 @@ Result> FixedShapeTensorType::Deserialize( } } - return fixed_shape_tensor(value_type, shape, permutation, dim_names); + // Validate product of shape dimensions matches storage type list_size. + // This check is intentionally after field parsing so that metadata-level errors + // (type mismatches, size mismatches) are reported first. + ARROW_ASSIGN_OR_RAISE(auto ext_type, FixedShapeTensorType::Make( + value_type, shape, permutation, dim_names)); + const auto& fst_type = internal::checked_cast(*ext_type); + ARROW_ASSIGN_OR_RAISE(const int64_t expected_size, + internal::ComputeShapeProduct(fst_type.shape())); + if (expected_size != fsl_type->list_size()) { + return Status::Invalid("Product of shape dimensions (", expected_size, + ") does not match FixedSizeList size (", fsl_type->list_size(), + ")"); + } + return ext_type; } std::shared_ptr FixedShapeTensorType::MakeArray( @@ -310,8 +348,7 @@ const Result> FixedShapeTensorArray::ToTensor() const { } std::vector shape = ext_type.shape(); - auto cell_size = std::accumulate(shape.begin(), shape.end(), static_cast(1), - std::multiplies<>()); + ARROW_ASSIGN_OR_RAISE(const int64_t cell_size, internal::ComputeShapeProduct(shape)); shape.insert(shape.begin(), 1, this->length()); internal::Permute(permutation, &shape); @@ -330,6 +367,11 @@ Result> FixedShapeTensorType::Make( const std::shared_ptr& value_type, const std::vector& shape, const std::vector& permutation, const std::vector& dim_names) { const size_t ndim = shape.size(); + for (auto dim : shape) { + if (dim < 0) { + return Status::Invalid("shape must have non-negative values, got ", dim); + } + } if (!permutation.empty() && ndim != permutation.size()) { return Status::Invalid("permutation size must match shape size. Expected: ", ndim, " Got: ", permutation.size()); @@ -342,8 +384,12 @@ Result> FixedShapeTensorType::Make( RETURN_NOT_OK(internal::IsPermutationValid(permutation)); } - const int64_t size = std::accumulate(shape.begin(), shape.end(), - static_cast(1), std::multiplies<>()); + ARROW_ASSIGN_OR_RAISE(const int64_t size, internal::ComputeShapeProduct(shape)); + if (size > std::numeric_limits::max()) { + return Status::Invalid("Product of shape dimensions (", size, + ") exceeds maximum FixedSizeList size (", + std::numeric_limits::max(), ")"); + } return std::make_shared(value_type, static_cast(size), shape, permutation, dim_names); } diff --git a/cpp/src/arrow/extension/tensor_extension_array_test.cc b/cpp/src/arrow/extension/tensor_extension_array_test.cc index 5c6dbe216281..531fc3c01cf5 100644 --- a/cpp/src/arrow/extension/tensor_extension_array_test.cc +++ b/cpp/src/arrow/extension/tensor_extension_array_test.cc @@ -219,6 +219,73 @@ TEST_F(TestFixedShapeTensorType, MetadataSerializationRoundtrip) { CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[3],"dim_names":["x","y"]})", "Invalid dim_names"); + + // Validate shape values must be integers. Error message should include the + // JSON type name of the offending value. + CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[3.5,4]})", + "shape must contain integers, got Number"); + CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":["3","4"]})", + "shape must contain integers, got String"); + CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[null]})", + "shape must contain integers, got Null"); + CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[true]})", + "shape must contain integers, got True"); + CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[false]})", + "shape must contain integers, got False"); + + // Validate shape values must be non-negative + CheckDeserializationRaises(ext_type_, fixed_size_list(int64(), 1), R"({"shape":[-1]})", + "shape must have non-negative values"); + + // Validate product of shape matches storage list_size + CheckDeserializationRaises(ext_type_, storage_type, R"({"shape":[3,3]})", + "Product of shape dimensions"); + + // Validate permutation member must be an array with integer values + CheckDeserializationRaises(ext_type_, storage_type, + R"({"shape":[3,4],"permutation":"invalid"})", + "permutation must be an array, got String"); + CheckDeserializationRaises(ext_type_, storage_type, + R"({"shape":[3,4],"permutation":{"a":1}})", + "permutation must be an array, got Object"); + CheckDeserializationRaises(ext_type_, storage_type, + R"({"shape":[3,4],"permutation":[1.5,0.5]})", + "permutation must contain integers, got Number"); + CheckDeserializationRaises(ext_type_, storage_type, + R"({"shape":[3,4],"permutation":["a","b"]})", + "permutation must contain integers, got String"); + + // Validate permutation values must be unique integers in [0, N-1] + CheckDeserializationRaises(ext_type_, storage_type, + R"({"shape":[3,4],"permutation":[0,0]})", + "Permutation indices"); + CheckDeserializationRaises(ext_type_, storage_type, + R"({"shape":[3,4],"permutation":[0,5]})", + "Permutation indices"); + CheckDeserializationRaises(ext_type_, storage_type, + R"({"shape":[3,4],"permutation":[-1,0]})", + "Permutation indices"); + + // Validate dim_names member must be an array with string values + CheckDeserializationRaises(ext_type_, storage_type, + R"({"shape":[3,4],"dim_names":"invalid"})", + "dim_names must be an array, got String"); + CheckDeserializationRaises(ext_type_, storage_type, + R"({"shape":[3,4],"dim_names":[1,2]})", + "dim_names must contain strings, got Number"); + CheckDeserializationRaises(ext_type_, storage_type, + R"({"shape":[3,4],"dim_names":[null,null]})", + "dim_names must contain strings, got Null"); +} + +TEST_F(TestFixedShapeTensorType, MakeValidatesShape) { + // Negative shape values should be rejected + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("shape must have non-negative values"), + FixedShapeTensorType::Make(value_type_, {-1})); + EXPECT_RAISES_WITH_MESSAGE_THAT( + Invalid, testing::HasSubstr("shape must have non-negative values"), + FixedShapeTensorType::Make(value_type_, {3, -1, 4})); } TEST_F(TestFixedShapeTensorType, RoundtripBatch) { @@ -794,6 +861,32 @@ TEST_F(TestVariableShapeTensorType, MetadataSerializationRoundtrip) { "Invalid: permutation"); CheckDeserializationRaises(ext_type_, storage_type, R"({"dim_names":["x","y"]})", "Invalid: dim_names"); + + // Validate permutation member must be an array with integer values. Error + // message should include the JSON type name of the offending value. + CheckDeserializationRaises(ext_type_, storage_type, R"({"permutation":"invalid"})", + "permutation must be an array, got String"); + CheckDeserializationRaises(ext_type_, storage_type, R"({"permutation":[1.5,0.5,2.5]})", + "permutation must contain integers, got Number"); + CheckDeserializationRaises(ext_type_, storage_type, + R"({"permutation":[null,null,null]})", + "permutation must contain integers, got Null"); + + // Validate dim_names member must be an array with string values + CheckDeserializationRaises(ext_type_, storage_type, R"({"dim_names":"invalid"})", + "dim_names must be an array, got String"); + CheckDeserializationRaises(ext_type_, storage_type, R"({"dim_names":[1,2,3]})", + "dim_names must contain strings, got Number"); + + // Validate uniform_shape member must be an array with integer-or-null values + CheckDeserializationRaises(ext_type_, storage_type, R"({"uniform_shape":"invalid"})", + "uniform_shape must be an array, got String"); + CheckDeserializationRaises(ext_type_, storage_type, + R"({"uniform_shape":[1.5,null,null]})", + "uniform_shape must contain integers or nulls, got Number"); + CheckDeserializationRaises(ext_type_, storage_type, + R"({"uniform_shape":["x",null,null]})", + "uniform_shape must contain integers or nulls, got String"); } TEST_F(TestVariableShapeTensorType, RoundtripBatch) { diff --git a/cpp/src/arrow/extension/tensor_internal.cc b/cpp/src/arrow/extension/tensor_internal.cc index 37862b7689f1..e94ea9a1d181 100644 --- a/cpp/src/arrow/extension/tensor_internal.cc +++ b/cpp/src/arrow/extension/tensor_internal.cc @@ -30,6 +30,31 @@ namespace arrow::internal { +namespace { + +// Names indexed by rapidjson::Type enum value: +// kNullType=0, kFalseType=1, kTrueType=2, kObjectType=3, +// kArrayType=4, kStringType=5, kNumberType=6. +constexpr const char* kJsonTypeNames[] = {"Null", "False", "True", "Object", + "Array", "String", "Number"}; + +} // namespace + +const char* JsonTypeName(const ::arrow::rapidjson::Value& v) { + return kJsonTypeNames[v.GetType()]; +} + +Result ComputeShapeProduct(std::span shape) { + int64_t product = 1; + for (const auto dim : shape) { + if (MultiplyWithOverflow(product, dim, &product)) { + return Status::Invalid( + "Product of tensor shape dimensions would not fit in 64-bit integer"); + } + } + return product; +} + bool IsPermutationTrivial(std::span permutation) { for (size_t i = 1; i < permutation.size(); ++i) { if (permutation[i - 1] + 1 != permutation[i]) { @@ -105,12 +130,7 @@ Result> SliceTensorBuffer(const Array& data_array, const DataType& value_type, std::span shape) { const int64_t byte_width = value_type.byte_width(); - int64_t size = 1; - for (const auto dim : shape) { - if (MultiplyWithOverflow(size, dim, &size)) { - return Status::Invalid("Tensor size would not fit in 64-bit integer"); - } - } + ARROW_ASSIGN_OR_RAISE(const int64_t size, ComputeShapeProduct(shape)); if (size != data_array.length()) { return Status::Invalid("Expected data array of length ", size, ", got ", data_array.length()); diff --git a/cpp/src/arrow/extension/tensor_internal.h b/cpp/src/arrow/extension/tensor_internal.h index b5ed5ebe1197..19665bf2cd4c 100644 --- a/cpp/src/arrow/extension/tensor_internal.h +++ b/cpp/src/arrow/extension/tensor_internal.h @@ -21,11 +21,25 @@ #include #include +#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep #include "arrow/result.h" #include "arrow/type_fwd.h" +#include + namespace arrow::internal { +/// \brief Return the name of a RapidJSON value's type (e.g., "Null", "Array", "Number"). +ARROW_EXPORT +const char* JsonTypeName(const ::arrow::rapidjson::Value& v); + +/// \brief Compute the product of the given shape dimensions. +/// +/// Returns Status::Invalid if the product would overflow int64_t. +/// An empty shape returns 1 (the multiplicative identity). +ARROW_EXPORT +Result ComputeShapeProduct(std::span shape); + ARROW_EXPORT bool IsPermutationTrivial(std::span permutation); diff --git a/cpp/src/arrow/extension/variable_shape_tensor.cc b/cpp/src/arrow/extension/variable_shape_tensor.cc index 7e27bbdb749f..b1b12583d7fe 100644 --- a/cpp/src/arrow/extension/variable_shape_tensor.cc +++ b/cpp/src/arrow/extension/variable_shape_tensor.cc @@ -159,26 +159,31 @@ Result> VariableShapeTensorType::Deserialize( if (document.HasMember("permutation")) { const auto& json_permutation = document["permutation"]; if (!json_permutation.IsArray()) { - return Status::Invalid("permutation must be an array"); + return Status::Invalid("permutation must be an array, got ", + internal::JsonTypeName(json_permutation)); } permutation.reserve(ndim); for (const auto& x : json_permutation.GetArray()) { if (!x.IsInt64()) { - return Status::Invalid("permutation must contain integers"); + return Status::Invalid("permutation must contain integers, got ", + internal::JsonTypeName(x)); } permutation.emplace_back(x.GetInt64()); } + RETURN_NOT_OK(internal::IsPermutationValid(permutation)); } std::vector dim_names; if (document.HasMember("dim_names")) { const auto& json_dim_names = document["dim_names"]; if (!json_dim_names.IsArray()) { - return Status::Invalid("dim_names must be an array"); + return Status::Invalid("dim_names must be an array, got ", + internal::JsonTypeName(json_dim_names)); } dim_names.reserve(ndim); for (const auto& x : json_dim_names.GetArray()) { if (!x.IsString()) { - return Status::Invalid("dim_names must contain strings"); + return Status::Invalid("dim_names must contain strings, got ", + internal::JsonTypeName(x)); } dim_names.emplace_back(x.GetString()); } @@ -188,7 +193,8 @@ Result> VariableShapeTensorType::Deserialize( if (document.HasMember("uniform_shape")) { const auto& json_uniform_shape = document["uniform_shape"]; if (!json_uniform_shape.IsArray()) { - return Status::Invalid("uniform_shape must be an array"); + return Status::Invalid("uniform_shape must be an array, got ", + internal::JsonTypeName(json_uniform_shape)); } uniform_shape.reserve(ndim); for (const auto& x : json_uniform_shape.GetArray()) { @@ -197,7 +203,8 @@ Result> VariableShapeTensorType::Deserialize( } else if (x.IsInt64()) { uniform_shape.emplace_back(x.GetInt64()); } else { - return Status::Invalid("uniform_shape must contain integers or nulls"); + return Status::Invalid("uniform_shape must contain integers or nulls, got ", + internal::JsonTypeName(x)); } } }