From e980fd0867de5022f8761f96f29e848d8656d5ab Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Fri, 5 Jun 2026 14:10:35 -0700 Subject: [PATCH 1/4] GH-45946: [C++][Parquet] Variant decoding --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/extension/CMakeLists.txt | 3 +- cpp/src/arrow/extension/meson.build | 5 +- cpp/src/arrow/extension/variant_internal.cc | 1020 ++++++++ cpp/src/arrow/extension/variant_internal.h | 347 +++ .../arrow/extension/variant_internal_test.cc | 2128 +++++++++++++++++ cpp/src/arrow/extension/variant_test_util.h | 137 ++ cpp/src/arrow/meson.build | 1 + 8 files changed, 3640 insertions(+), 2 deletions(-) create mode 100644 cpp/src/arrow/extension/variant_internal.cc create mode 100644 cpp/src/arrow/extension/variant_internal.h create mode 100644 cpp/src/arrow/extension/variant_internal_test.cc create mode 100644 cpp/src/arrow/extension/variant_test_util.h diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 45cd7e838121..530d3e5ff3b8 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -391,6 +391,7 @@ set(ARROW_SRCS extension/bool8.cc extension/json.cc extension/parquet_variant.cc + extension/variant_internal.cc extension/uuid.cc pretty_print.cc record_batch.cc diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index ae52bc32a998..582825027c74 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc) +set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc + variant_internal_test.cc) if(ARROW_JSON) list(APPEND CANONICAL_EXTENSION_TESTS tensor_extension_array_test.cc opaque_test.cc) diff --git a/cpp/src/arrow/extension/meson.build b/cpp/src/arrow/extension/meson.build index 84dafe4bbe32..6c6d3a7b67a8 100644 --- a/cpp/src/arrow/extension/meson.build +++ b/cpp/src/arrow/extension/meson.build @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -canonical_extension_tests = ['bool8_test.cc', 'json_test.cc', 'uuid_test.cc'] +canonical_extension_tests = ['bool8_test.cc', 'json_test.cc', 'uuid_test.cc', 'variant_internal_test.cc'] if needs_json canonical_extension_tests += [ @@ -40,5 +40,8 @@ install_headers( 'parquet_variant.h', 'uuid.h', 'variable_shape_tensor.h', + # variant_internal.h: public API for variant binary encoding/decoding. + # "internal" refers to the binary encoding internals, not visibility. + 'variant_internal.h', ], ) diff --git a/cpp/src/arrow/extension/variant_internal.cc b/cpp/src/arrow/extension/variant_internal.cc new file mode 100644 index 000000000000..2ee3fd09ba4a --- /dev/null +++ b/cpp/src/arrow/extension/variant_internal.cc @@ -0,0 +1,1020 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/variant_internal.h" + +#include + +#include "arrow/util/endian.h" +#include "arrow/util/logging_internal.h" + +namespace arrow::extension::variant_internal { + +namespace { + +// --------------------------------------------------------------------------- +// Helpers for reading little-endian integers of variable size (1-4 bytes) +// --------------------------------------------------------------------------- + +/// \brief Read an unsigned integer of 1-4 bytes in little-endian order. +/// +/// On big-endian platforms, FromLittleEndian byte-swaps the full 32-bit +/// word after memcpy; the mask then discards any bytes beyond num_bytes. +/// +/// \param[in] data Pointer to the bytes (must have at least num_bytes valid) +/// \param[in] num_bytes Number of bytes to read (1, 2, 3, or 4) +/// \return The decoded unsigned integer value +inline uint32_t ReadUnsignedLE(const uint8_t* data, int32_t num_bytes) { + uint32_t result = 0; + std::memcpy(&result, data, num_bytes); + result = bit_util::FromLittleEndian(result); + if (num_bytes < 4) { + result &= (static_cast(1) << (num_bytes * 8)) - 1; + } + return result; +} + +/// \brief Validate that an offset array is monotonically non-decreasing +/// and within the buffer bounds. +Status ValidateOffsets(const std::vector& offsets, int64_t data_length) { + for (size_t i = 1; i < offsets.size(); ++i) { + if (offsets[i] < offsets[i - 1]) { + return Status::Invalid( + "Variant metadata: string offsets are not monotonically " + "non-decreasing at index ", + i); + } + } + if (!offsets.empty() && offsets.back() > static_cast(data_length)) { + return Status::Invalid("Variant metadata: last string offset ", offsets.back(), + " exceeds data length ", data_length); + } + return Status::OK(); +} + +// --------------------------------------------------------------------------- +// Value decoding helpers +// --------------------------------------------------------------------------- + +/// \brief Decode a single variant value at the given offset and invoke +/// the visitor. Returns the number of bytes consumed. +/// +/// This is the core recursive function. +Status DecodeValueAt(const VariantMetadata& metadata, const uint8_t* data, int64_t length, + int64_t offset, VariantVisitor* visitor, int64_t* bytes_consumed, + int32_t depth); + +/// \brief Decode a primitive value at data[offset]. +Status DecodePrimitive(const uint8_t* data, int64_t length, int64_t offset, + uint8_t header, VariantVisitor* visitor, int64_t* bytes_consumed) { + auto primitive_type = GetPrimitiveType(header); + int64_t pos = offset + 1; // skip header byte + + auto check_remaining = [&](int64_t needed) -> Status { + if (pos + needed > length) { + return Status::Invalid("Variant value: truncated primitive at offset ", offset, + ", need ", needed, " bytes but only ", length - pos, + " remaining"); + } + return Status::OK(); + }; + + switch (primitive_type) { + case PrimitiveType::kNull: + ARROW_RETURN_NOT_OK(visitor->Null()); + *bytes_consumed = 1; + return Status::OK(); + + case PrimitiveType::kTrue: + ARROW_RETURN_NOT_OK(visitor->Bool(true)); + *bytes_consumed = 1; + return Status::OK(); + + case PrimitiveType::kFalse: + ARROW_RETURN_NOT_OK(visitor->Bool(false)); + *bytes_consumed = 1; + return Status::OK(); + + case PrimitiveType::kInt8: { + ARROW_RETURN_NOT_OK(check_remaining(1)); + auto value = static_cast(data[pos]); + ARROW_RETURN_NOT_OK(visitor->Int8(value)); + *bytes_consumed = 2; + return Status::OK(); + } + + case PrimitiveType::kInt16: { + ARROW_RETURN_NOT_OK(check_remaining(2)); + int16_t value; + std::memcpy(&value, data + pos, 2); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Int16(value)); + *bytes_consumed = 3; + return Status::OK(); + } + + case PrimitiveType::kInt32: { + ARROW_RETURN_NOT_OK(check_remaining(4)); + int32_t value; + std::memcpy(&value, data + pos, 4); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Int32(value)); + *bytes_consumed = 5; + return Status::OK(); + } + + case PrimitiveType::kInt64: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Int64(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kFloat: { + ARROW_RETURN_NOT_OK(check_remaining(4)); + float value; + std::memcpy(&value, data + pos, 4); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Float(value)); + *bytes_consumed = 5; + return Status::OK(); + } + + case PrimitiveType::kDouble: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + double value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Double(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kDecimal4: { + // Spec: 1 byte scale in range [0, 38], followed by 4 bytes LE unscaled value. + // Note: scale is not validated during decode to remain lenient with + // forward-compatible data. The encoder validates scale <= 38. + ARROW_RETURN_NOT_OK(check_remaining(5)); + auto scale = static_cast(data[pos]); + ARROW_RETURN_NOT_OK(visitor->Decimal4(data + pos + 1, scale)); + *bytes_consumed = 6; + return Status::OK(); + } + + case PrimitiveType::kDecimal8: { + // Spec: 1 byte scale, followed by 8 bytes LE unscaled value + ARROW_RETURN_NOT_OK(check_remaining(9)); + auto scale = static_cast(data[pos]); + ARROW_RETURN_NOT_OK(visitor->Decimal8(data + pos + 1, scale)); + *bytes_consumed = 10; + return Status::OK(); + } + + case PrimitiveType::kDecimal16: { + // Spec: 1 byte scale, followed by 16 bytes LE unscaled value + ARROW_RETURN_NOT_OK(check_remaining(17)); + auto scale = static_cast(data[pos]); + ARROW_RETURN_NOT_OK(visitor->Decimal16(data + pos + 1, scale)); + *bytes_consumed = 18; + return Status::OK(); + } + + case PrimitiveType::kDate: { + ARROW_RETURN_NOT_OK(check_remaining(4)); + int32_t value; + std::memcpy(&value, data + pos, 4); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->Date(value)); + *bytes_consumed = 5; + return Status::OK(); + } + + case PrimitiveType::kTimestampMicros: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimestampMicros(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kTimestampMicrosNTZ: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimestampMicrosNTZ(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kBinary: { + // 4-byte length prefix + data + ARROW_RETURN_NOT_OK(check_remaining(4)); + uint32_t bin_length; + std::memcpy(&bin_length, data + pos, 4); + bin_length = bit_util::FromLittleEndian(bin_length); + ARROW_RETURN_NOT_OK(check_remaining(4 + static_cast(bin_length))); + auto view = + std::string_view(reinterpret_cast(data + pos + 4), bin_length); + ARROW_RETURN_NOT_OK(visitor->Binary(view)); + *bytes_consumed = 1 + 4 + static_cast(bin_length); + return Status::OK(); + } + + case PrimitiveType::kString: { + // 4-byte length prefix + data + ARROW_RETURN_NOT_OK(check_remaining(4)); + uint32_t str_length; + std::memcpy(&str_length, data + pos, 4); + str_length = bit_util::FromLittleEndian(str_length); + ARROW_RETURN_NOT_OK(check_remaining(4 + static_cast(str_length))); + auto view = + std::string_view(reinterpret_cast(data + pos + 4), str_length); + ARROW_RETURN_NOT_OK(visitor->String(view)); + *bytes_consumed = 1 + 4 + static_cast(str_length); + return Status::OK(); + } + + case PrimitiveType::kTimeNTZ: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimeNTZ(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kTimestampNanos: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimestampNanos(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kTimestampNanosNTZ: { + ARROW_RETURN_NOT_OK(check_remaining(8)); + int64_t value; + std::memcpy(&value, data + pos, 8); + value = bit_util::FromLittleEndian(value); + ARROW_RETURN_NOT_OK(visitor->TimestampNanosNTZ(value)); + *bytes_consumed = 9; + return Status::OK(); + } + + case PrimitiveType::kUUID: { + // UUID is 16 bytes in big-endian order + ARROW_RETURN_NOT_OK(check_remaining(16)); + ARROW_RETURN_NOT_OK(visitor->UUID(data + pos)); + *bytes_consumed = 17; + return Status::OK(); + } + + default: + return Status::Invalid("Variant value: unknown primitive type ", + static_cast(primitive_type)); + } +} + +/// \brief Decode a short string (basic_type == 1). The length is encoded +/// in bits 2-7 of the header byte (max 63 bytes). +Status DecodeShortString(const uint8_t* data, int64_t length, int64_t offset, + uint8_t header, VariantVisitor* visitor, + int64_t* bytes_consumed) { + int32_t str_len = (header >> 2) & 0x3F; + int64_t pos = offset + 1; + if (pos + str_len > length) { + return Status::Invalid("Variant value: truncated short string at offset ", offset, + ", need ", str_len, " bytes but only ", length - pos, + " remaining"); + } + auto view = std::string_view(reinterpret_cast(data + pos), str_len); + ARROW_RETURN_NOT_OK(visitor->String(view)); + *bytes_consumed = 1 + str_len; + return Status::OK(); +} + +/// \brief Decode an object value (basic_type == 2). +/// +/// Object layout per spec: +/// header (1 byte): +/// bits 0-1: basic_type = 2 +/// bits 2-3: field_offset_size_minus_one +/// bits 4-5: field_id_size_minus_one +/// bit 6: is_large (0 → 1-byte num_elements, 1 → 4-byte) +/// num_elements: 1 or 4 bytes (unsigned LE) +/// field_ids: num_elements × field_id_size bytes +/// field_offsets: (num_elements + 1) × field_offset_size bytes +/// field values: concatenated variant values +Status DecodeObject(const VariantMetadata& metadata, const uint8_t* data, int64_t length, + int64_t offset, uint8_t header, VariantVisitor* visitor, + int64_t* bytes_consumed, int32_t depth) { + // Variant Encoding Spec: object value_header layout (bits 2-7 of full byte): + // bits 2-3 (type_info bits 0-1): field_offset_size_minus_one + // bits 4-5 (type_info bits 2-3): field_id_size_minus_one + // bit 6 (type_info bit 4): is_large (0 = 1-byte num_elements, 1 = 4-byte) + // bit 7 (type_info bit 5): unused + uint8_t type_info = (header >> 2) & 0x3F; + int32_t field_offset_size = (type_info & 0x03) + 1; + int32_t field_id_size = ((type_info >> 2) & 0x03) + 1; + bool is_large = ((type_info >> 4) & 0x01) != 0; + int32_t num_fields_size = is_large ? 4 : 1; + + int64_t pos = offset + 1; // skip header + + // Read num_fields + if (pos + num_fields_size > length) { + return Status::Invalid("Variant value: truncated object num_fields at offset ", + offset); + } + auto num_fields = static_cast(ReadUnsignedLE(data + pos, num_fields_size)); + pos += num_fields_size; + + // Read field IDs + int64_t field_ids_size = static_cast(num_fields) * field_id_size; + if (pos + field_ids_size > length) { + return Status::Invalid("Variant value: truncated object field_ids at offset ", + offset); + } + // TODO: Consider using a stack-allocated small_vector (e.g. arrow::internal:: + // SmallVector) for field_ids and value_offsets to avoid heap allocation for + // the common case of small objects (< 16 fields). Acceptable for a + // correctness-first implementation; optimize if profiling shows pressure. + std::vector field_ids(num_fields); + // NOTE: Per spec, field IDs must be in lexicographic order of corresponding + // key names. We do not validate this ordering here for performance; see + // FindObjectField which relies on this invariant for binary search. + for (int32_t i = 0; i < num_fields; ++i) { + field_ids[i] = ReadUnsignedLE(data + pos, field_id_size); + pos += field_id_size; + } + + // Read value offsets (num_fields + 1 entries) + int64_t offsets_size = (static_cast(num_fields) + 1) * field_offset_size; + if (pos + offsets_size > length) { + return Status::Invalid("Variant value: truncated object offsets at offset ", offset); + } + std::vector value_offsets(num_fields + 1); + for (int32_t i = 0; i <= num_fields; ++i) { + value_offsets[i] = ReadUnsignedLE(data + pos, field_offset_size); + pos += field_offset_size; + } + + // Note: per spec, object field offsets are NOT required to be + // monotonically increasing because field values may be stored + // in a different order than field IDs. + + // The field data starts at pos + int64_t data_start = pos; + int64_t total_data_size = static_cast(value_offsets[num_fields]); + + if (data_start + total_data_size > length) { + return Status::Invalid("Variant value: object data exceeds buffer at offset ", + offset); + } + + // Validate each field offset is within the data region. + // Unlike arrays, object offsets need not be monotonic, but each must + // point within the valid data area. + for (int32_t i = 0; i < num_fields; ++i) { + if (value_offsets[i] > static_cast(total_data_size)) { + return Status::Invalid("Variant value: object field offset ", value_offsets[i], + " at index ", i, " exceeds data size ", total_data_size); + } + } + + ARROW_RETURN_NOT_OK(visitor->StartObject(num_fields)); + + for (int32_t i = 0; i < num_fields; ++i) { + // Resolve field name from metadata dictionary + auto field_id = field_ids[i]; + if (field_id >= metadata.strings.size()) { + return Status::Invalid("Variant value: field_id ", field_id, + " exceeds metadata dictionary size ", + metadata.strings.size()); + } + ARROW_RETURN_NOT_OK(visitor->FieldName(metadata.strings[field_id])); + + // Decode the field value. Pass data_start + total_data_size as the effective + // length to restrict field value decoding within this object's data region. + // NOTE: We do not validate that consumed bytes match the expected field size + // (value_offsets[i+1] - value_offsets[i]) because object offsets are not + // required to be monotonic, making per-field size inference unreliable. + // TODO: Consider optional strict validation for untrusted input. + int64_t field_offset = data_start + value_offsets[i]; + int64_t consumed = 0; + ARROW_RETURN_NOT_OK(DecodeValueAt(metadata, data, data_start + total_data_size, + field_offset, visitor, &consumed, depth)); + } + + ARROW_RETURN_NOT_OK(visitor->EndObject()); + + *bytes_consumed = (data_start - offset) + total_data_size; + return Status::OK(); +} + +/// \brief Decode an array value (basic_type == 3). +/// +/// Array layout per spec: +/// header (1 byte): +/// bits 0-1: basic_type = 3 +/// bits 2-3: field_offset_size_minus_one +/// bit 4: is_large (0 → 1-byte num_elements, 1 → 4-byte) +/// bits 5-7: unused +/// num_elements: 1 or 4 bytes (unsigned LE) +/// field_offsets: (num_elements + 1) × field_offset_size bytes +/// element values: concatenated variant values +Status DecodeArray(const VariantMetadata& metadata, const uint8_t* data, int64_t length, + int64_t offset, uint8_t header, VariantVisitor* visitor, + int64_t* bytes_consumed, int32_t depth) { + // Variant Encoding Spec: array value_header layout (bits 2-7 of full byte): + // bits 2-3 (type_info bits 0-1): field_offset_size_minus_one + // bit 4 (type_info bit 2): is_large (0 = 1-byte num_elements, 1 = 4-byte) + // bits 5-7 (type_info bits 3-5): unused + uint8_t type_info = (header >> 2) & 0x3F; + int32_t field_offset_size = (type_info & 0x03) + 1; + bool is_large = ((type_info >> 2) & 0x01) != 0; + int32_t num_elements_size = is_large ? 4 : 1; + + int64_t pos = offset + 1; // skip header + + // Read num_elements + if (pos + num_elements_size > length) { + return Status::Invalid("Variant value: truncated array num_elements at offset ", + offset); + } + auto num_elements = static_cast(ReadUnsignedLE(data + pos, num_elements_size)); + pos += num_elements_size; + + // Read offsets (num_elements + 1 entries) + int64_t offsets_size = (static_cast(num_elements) + 1) * field_offset_size; + if (pos + offsets_size > length) { + return Status::Invalid("Variant value: truncated array offsets at offset ", offset); + } + // TODO: Consider stack-allocated small_vector for the common case of small arrays. + std::vector value_offsets(num_elements + 1); + for (int32_t i = 0; i <= num_elements; ++i) { + value_offsets[i] = ReadUnsignedLE(data + pos, field_offset_size); + pos += field_offset_size; + } + + // Validate value offsets are monotonically non-decreasing + for (int32_t i = 1; i <= num_elements; ++i) { + if (value_offsets[i] < value_offsets[i - 1]) { + return Status::Invalid( + "Variant value: array value offsets are not monotonically " + "non-decreasing at index ", + i); + } + } + + // The element data starts at pos + int64_t data_start = pos; + int64_t total_data_size = static_cast(value_offsets[num_elements]); + + if (data_start + total_data_size > length) { + return Status::Invalid("Variant value: array data exceeds buffer at offset ", offset); + } + + ARROW_RETURN_NOT_OK(visitor->StartArray(num_elements)); + + for (int32_t i = 0; i < num_elements; ++i) { + // Pass data_start + total_data_size as the effective length to restrict + // element value decoding within this array's data region. + // NOTE: consumed bytes are not validated against expected element size + // (value_offsets[i+1] - value_offsets[i]). Monotonicity of offsets is + // already validated above, but we do not check that each element exactly + // fills its allocated slot. TODO: Consider optional strict validation. + int64_t elem_offset = data_start + value_offsets[i]; + int64_t consumed = 0; + ARROW_RETURN_NOT_OK(DecodeValueAt(metadata, data, data_start + total_data_size, + elem_offset, visitor, &consumed, depth)); + } + + ARROW_RETURN_NOT_OK(visitor->EndArray()); + + *bytes_consumed = (data_start - offset) + total_data_size; + return Status::OK(); +} + +Status DecodeValueAt(const VariantMetadata& metadata, const uint8_t* data, int64_t length, + int64_t offset, VariantVisitor* visitor, int64_t* bytes_consumed, + int32_t depth) { + if (offset >= length) { + return Status::Invalid("Variant value: offset ", offset, + " is at or beyond buffer length ", length); + } + if (depth > kMaxNestingDepth) { + return Status::Invalid("Variant value: nesting depth exceeds maximum of ", + kMaxNestingDepth); + } + + uint8_t header = data[offset]; + auto basic_type = GetBasicType(header); + + switch (basic_type) { + case BasicType::kPrimitive: + return DecodePrimitive(data, length, offset, header, visitor, bytes_consumed); + case BasicType::kShortString: + return DecodeShortString(data, length, offset, header, visitor, bytes_consumed); + case BasicType::kObject: + return DecodeObject(metadata, data, length, offset, header, visitor, bytes_consumed, + depth + 1); + case BasicType::kArray: + return DecodeArray(metadata, data, length, offset, header, visitor, bytes_consumed, + depth + 1); + default: + return Status::Invalid("Variant value: unknown basic type ", + static_cast(basic_type)); + } +} + +} // namespace + +// --------------------------------------------------------------------------- +// Public API implementations +// --------------------------------------------------------------------------- + +int32_t PrimitiveValueSize(PrimitiveType primitive_type) { + switch (primitive_type) { + case PrimitiveType::kNull: + case PrimitiveType::kTrue: + case PrimitiveType::kFalse: + return 0; + case PrimitiveType::kInt8: + return 1; + case PrimitiveType::kInt16: + return 2; + case PrimitiveType::kInt32: + case PrimitiveType::kFloat: + case PrimitiveType::kDate: + return 4; + case PrimitiveType::kInt64: + case PrimitiveType::kDouble: + case PrimitiveType::kTimestampMicros: + case PrimitiveType::kTimestampMicrosNTZ: + case PrimitiveType::kTimeNTZ: + case PrimitiveType::kTimestampNanos: + case PrimitiveType::kTimestampNanosNTZ: + return 8; + case PrimitiveType::kDecimal4: + return 5; // 1 byte scale + 4 bytes value + case PrimitiveType::kDecimal8: + return 9; // 1 byte scale + 8 bytes value + case PrimitiveType::kDecimal16: + return 17; // 1 byte scale + 16 bytes value + case PrimitiveType::kUUID: + return 16; + case PrimitiveType::kBinary: + case PrimitiveType::kString: + return -1; // variable length + default: + return -1; + } +} + +Result DecodeMetadata(const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("Variant metadata: buffer is null or empty"); + } + + // Variant Encoding Spec §2: Metadata encoding + // Header byte: bits 0-3 = version, bit 4 = sorted, bit 5 = reserved, + // bits 6-7 = offset_size-1 + uint8_t header = data[0]; + uint8_t version = header & 0x0F; + if (version != kVariantVersion) { + return Status::Invalid("Variant metadata: unsupported version ", + static_cast(version), ", expected ", + static_cast(kVariantVersion)); + } + + // Bit 5 is reserved and must be zero in version 1 + if ((header >> 5) & 0x01) { + return Status::Invalid("Variant metadata: reserved bit 5 is set in header"); + } + + bool is_sorted = ((header >> 4) & 0x01) != 0; + int32_t offset_size = ((header >> 6) & 0x03) + 1; + + int64_t pos = 1; + + // Read dictionary size + if (pos + offset_size > length) { + return Status::Invalid("Variant metadata: truncated dictionary size at byte ", pos); + } + auto dict_size = static_cast(ReadUnsignedLE(data + pos, offset_size)); + pos += offset_size; + + // Read string offsets: (dict_size + 1) offsets + int64_t offsets_bytes = static_cast(dict_size + 1) * offset_size; + if (pos + offsets_bytes > length) { + return Status::Invalid("Variant metadata: truncated string offsets, need ", + offsets_bytes, " bytes at position ", pos, + " but buffer length is ", length); + } + + std::vector offsets(dict_size + 1); + for (int32_t i = 0; i <= dict_size; ++i) { + offsets[i] = ReadUnsignedLE(data + pos, offset_size); + pos += offset_size; + } + + // Validate offsets + int64_t string_data_length = length - pos; + ARROW_RETURN_NOT_OK(ValidateOffsets(offsets, string_data_length)); + + // Extract string views + std::vector strings(dict_size); + for (int32_t i = 0; i < dict_size; ++i) { + auto start = static_cast(offsets[i]); + auto end = static_cast(offsets[i + 1]); + strings[i] = + std::string_view(reinterpret_cast(data + pos + start), end - start); + } + + VariantMetadata result; + result.version = version; + result.is_sorted = is_sorted; + result.offset_size = offset_size; + result.strings = std::move(strings); + return result; +} + +Status DecodeVariantValue(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, VariantVisitor* visitor) { + if (data == nullptr || length < 1) { + return Status::Invalid("Variant value: buffer is null or empty"); + } + DCHECK_NE(visitor, nullptr); + int64_t bytes_consumed = 0; + return DecodeValueAt(metadata, data, length, 0, visitor, &bytes_consumed, /*depth=*/0); +} + +Result GetValueBasicType(const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("Variant value: buffer is null or empty"); + } + return GetBasicType(data[0]); +} + +Result GetObjectFieldCount(const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("Variant value: buffer is null or empty"); + } + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kObject) { + return Status::Invalid("Variant value: expected object but got basic type ", + static_cast(GetBasicType(header))); + } + // type_info bit 4 = is_large (bit 6 of full byte) + uint8_t type_info = (header >> 2) & 0x3F; + bool is_large = ((type_info >> 4) & 0x01) != 0; + int32_t num_fields_size = is_large ? 4 : 1; + if (1 + num_fields_size > length) { + return Status::Invalid("Variant value: truncated object header"); + } + return static_cast(ReadUnsignedLE(data + 1, num_fields_size)); +} + +Result GetArrayElementCount(const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("Variant value: buffer is null or empty"); + } + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kArray) { + return Status::Invalid("Variant value: expected array but got basic type ", + static_cast(GetBasicType(header))); + } + // type_info bit 2 = is_large (bit 4 of full byte) + uint8_t type_info = (header >> 2) & 0x3F; + bool is_large = ((type_info >> 2) & 0x01) != 0; + int32_t num_elements_size = is_large ? 4 : 1; + if (1 + num_elements_size > length) { + return Status::Invalid("Variant value: truncated array header"); + } + return static_cast(ReadUnsignedLE(data + 1, num_elements_size)); +} + +Result ValueSize(const uint8_t* data, int64_t length) { + if (data == nullptr || length < 1) { + return Status::Invalid("ValueSize: buffer is null or empty"); + } + + uint8_t header = data[0]; + auto basic_type = GetBasicType(header); + uint8_t type_info = (header >> 2) & 0x3F; + + switch (basic_type) { + case BasicType::kShortString: + return 1 + static_cast(type_info); + + case BasicType::kObject: { + // type_info bit 4 = is_large (bit 6 of full byte) + bool is_large = ((type_info >> 4) & 0x01) != 0; + int32_t sz_bytes = is_large ? 4 : 1; + if (1 + sz_bytes > length) { + return Status::Invalid("ValueSize: truncated object header"); + } + auto num_elements = static_cast(ReadUnsignedLE(data + 1, sz_bytes)); + int32_t id_size = ((type_info >> 2) & 0x03) + 1; + int32_t offset_size = (type_info & 0x03) + 1; + int64_t id_start = 1 + sz_bytes; + int64_t offset_start = id_start + num_elements * id_size; + int64_t data_start = offset_start + (num_elements + 1) * offset_size; + // Last offset = total data size + int64_t last_offset_pos = offset_start + num_elements * offset_size; + if (last_offset_pos + offset_size > length) { + return Status::Invalid("ValueSize: truncated object offsets"); + } + auto total_data = + static_cast(ReadUnsignedLE(data + last_offset_pos, offset_size)); + return data_start + total_data; + } + + case BasicType::kArray: { + // type_info bit 2 = is_large (bit 4 of full byte) + // Note: Go's valueSize() in arrow-go (prior to fix PR) incorrectly + // used (typeInfo >> 4) for arrays, which reads bit 6 — the object's + // is_large position. The spec places array is_large at bit 4 of the + // full header byte. See: apache/arrow-go#839. + bool is_large = ((type_info >> 2) & 0x01) != 0; + int32_t sz_bytes = is_large ? 4 : 1; + if (1 + sz_bytes > length) { + return Status::Invalid("ValueSize: truncated array header"); + } + auto num_elements = static_cast(ReadUnsignedLE(data + 1, sz_bytes)); + int32_t offset_size = (type_info & 0x03) + 1; + int64_t offset_start = 1 + sz_bytes; + int64_t data_start = offset_start + (num_elements + 1) * offset_size; + // Last offset = total data size + int64_t last_offset_pos = offset_start + num_elements * offset_size; + if (last_offset_pos + offset_size > length) { + return Status::Invalid("ValueSize: truncated array offsets"); + } + auto total_data = + static_cast(ReadUnsignedLE(data + last_offset_pos, offset_size)); + return data_start + total_data; + } + + case BasicType::kPrimitive: { + auto ptype = static_cast(type_info); + int32_t payload_size = PrimitiveValueSize(ptype); + if (payload_size >= 0) { + return 1 + static_cast(payload_size); + } + // Variable-length: Binary or String (4-byte length prefix) + if (1 + 4 > length) { + return Status::Invalid("ValueSize: truncated variable-length header"); + } + uint32_t var_len; + std::memcpy(&var_len, data + 1, 4); + var_len = bit_util::FromLittleEndian(var_len); + return 1 + 4 + static_cast(var_len); + } + + default: + return Status::Invalid("ValueSize: unknown basic type"); + } +} + +Status FindObjectField(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, std::string_view field_name, int64_t* field_offset, + int64_t* field_size) { + *field_offset = -1; + *field_size = 0; + + if (data == nullptr || length < 1) { + return Status::Invalid("FindObjectField: buffer is null or empty"); + } + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kObject) { + return Status::Invalid("FindObjectField: not an object"); + } + + uint8_t type_info = (header >> 2) & 0x3F; + int32_t field_offset_size = (type_info & 0x03) + 1; + int32_t field_id_size = ((type_info >> 2) & 0x03) + 1; + bool is_large = ((type_info >> 4) & 0x01) != 0; + int32_t num_fields_size = is_large ? 4 : 1; + + if (1 + num_fields_size > length) { + return Status::Invalid("FindObjectField: truncated header"); + } + auto num_fields = static_cast(ReadUnsignedLE(data + 1, num_fields_size)); + + int64_t id_start = 1 + num_fields_size; + int64_t offset_start = id_start + static_cast(num_fields) * field_id_size; + int64_t data_start = + offset_start + (static_cast(num_fields) + 1) * field_offset_size; + + if (data_start > length) { + return Status::Invalid("FindObjectField: truncated object"); + } + + // Per spec, field IDs are in lexicographic order of corresponding keys. + // Use binary search for large objects, linear scan for small ones. + // NOTE: If the input violates this ordering invariant (malformed data), + // binary search may return incorrect results. We do not validate sorting + // here for performance; callers should use DecodeVariantValue() for full + // validation of untrusted input. + constexpr int32_t kBinarySearchThreshold = 32; + + // Note: get_key_at returns an empty string_view for out-of-range field IDs. + // For the binary search path, this could theoretically misorder comparisons, + // but out-of-range IDs indicate a malformed variant. The function will simply + // not find the requested key and return field_offset=-1 (not found), which is + // a safe degradation for corrupted data. + auto get_key_at = [&](int32_t i) -> std::string_view { + auto id = ReadUnsignedLE(data + id_start + i * field_id_size, field_id_size); + if (id < metadata.strings.size()) { + return metadata.strings[id]; + } + return {}; + }; + + auto get_value_offset = [&](int32_t i) -> int64_t { + return data_start + + static_cast(ReadUnsignedLE( + data + offset_start + i * field_offset_size, field_offset_size)); + }; + + int32_t found_index = -1; + + if (num_fields < kBinarySearchThreshold) { + // Linear scan for small objects + for (int32_t i = 0; i < num_fields; ++i) { + if (get_key_at(i) == field_name) { + found_index = i; + break; + } + } + } else { + // Binary search for large objects (keys are in lex order). + // Note: int32_t is used deliberately for lo/hi to avoid unsigned + // underflow when hi = mid - 1 and mid == 0. The Go implementation + // (ObjectValue.ValueByKey) uses uint32 which wraps to MaxUint32. + int32_t lo = 0, hi = num_fields - 1; + while (lo <= hi) { + int32_t mid = lo + (hi - lo) / 2; + auto key = get_key_at(mid); + if (key == field_name) { + found_index = mid; + break; + } else if (key < field_name) { + lo = mid + 1; + } else { + hi = mid - 1; + } + } + } + + if (found_index >= 0) { + *field_offset = get_value_offset(found_index); + ARROW_ASSIGN_OR_RAISE(auto size, + ValueSize(data + *field_offset, length - *field_offset)); + *field_size = size; + } + + return Status::OK(); +} + +Status GetArrayElement(const uint8_t* data, int64_t length, int32_t index, + int64_t* element_offset, int64_t* element_size) { + if (data == nullptr || length < 1) { + return Status::Invalid("GetArrayElement: buffer is null or empty"); + } + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kArray) { + return Status::Invalid("GetArrayElement: not an array"); + } + + uint8_t type_info = (header >> 2) & 0x3F; + int32_t field_offset_size = (type_info & 0x03) + 1; + bool is_large = ((type_info >> 2) & 0x01) != 0; + int32_t num_elements_size = is_large ? 4 : 1; + + if (1 + num_elements_size > length) { + return Status::Invalid("GetArrayElement: truncated header"); + } + auto num_elements = static_cast(ReadUnsignedLE(data + 1, num_elements_size)); + + if (index < 0 || index >= num_elements) { + return Status::Invalid("GetArrayElement: index ", index, " out of range [0, ", + num_elements, ")"); + } + + int64_t offset_start = 1 + num_elements_size; + int64_t data_start = + offset_start + (static_cast(num_elements) + 1) * field_offset_size; + + auto elem_offset = static_cast( + ReadUnsignedLE(data + offset_start + index * field_offset_size, field_offset_size)); + *element_offset = data_start + elem_offset; + ARROW_ASSIGN_OR_RAISE(auto size, + ValueSize(data + *element_offset, length - *element_offset)); + *element_size = size; + return Status::OK(); +} + +Status GetObjectFieldAt(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, int32_t index, std::string_view* field_name, + int64_t* field_offset, int64_t* field_size) { + if (data == nullptr || length < 1) { + return Status::Invalid("GetObjectFieldAt: buffer is null or empty"); + } + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kObject) { + return Status::Invalid("GetObjectFieldAt: not an object"); + } + + uint8_t type_info = (header >> 2) & 0x3F; + int32_t obj_offset_size = (type_info & 0x03) + 1; + int32_t field_id_size = ((type_info >> 2) & 0x03) + 1; + bool is_large = ((type_info >> 4) & 0x01) != 0; + int32_t num_fields_size = is_large ? 4 : 1; + + if (1 + num_fields_size > length) { + return Status::Invalid("GetObjectFieldAt: truncated header"); + } + auto num_fields = static_cast(ReadUnsignedLE(data + 1, num_fields_size)); + + if (index < 0 || index >= num_fields) { + return Status::Invalid("GetObjectFieldAt: index ", index, " out of range [0, ", + num_fields, ")"); + } + + int64_t id_start = 1 + num_fields_size; + int64_t offset_start = id_start + static_cast(num_fields) * field_id_size; + int64_t data_start = + offset_start + (static_cast(num_fields) + 1) * obj_offset_size; + + // Get field name from dictionary + auto field_id = ReadUnsignedLE(data + id_start + index * field_id_size, field_id_size); + if (field_id >= metadata.strings.size()) { + return Status::Invalid("GetObjectFieldAt: field_id ", field_id, + " exceeds dictionary size ", metadata.strings.size()); + } + *field_name = metadata.strings[field_id]; + + // Get field value offset + auto value_offset = static_cast( + ReadUnsignedLE(data + offset_start + index * obj_offset_size, obj_offset_size)); + *field_offset = data_start + value_offset; + ARROW_ASSIGN_OR_RAISE(auto size, + ValueSize(data + *field_offset, length - *field_offset)); + *field_size = size; + return Status::OK(); +} + +int32_t FindMetadataKey(const VariantMetadata& metadata, std::string_view key) { + if (metadata.is_sorted) { + // Binary search on sorted dictionary + int32_t lo = 0; + int32_t hi = static_cast(metadata.strings.size()) - 1; + while (lo <= hi) { + int32_t mid = lo + (hi - lo) / 2; + int cmp = metadata.strings[mid].compare(key); + if (cmp == 0) { + return mid; + } else if (cmp < 0) { + lo = mid + 1; + } else { + hi = mid - 1; + } + } + return -1; + } + + // Linear scan for unsorted dictionary + for (int32_t i = 0; i < static_cast(metadata.strings.size()); ++i) { + if (metadata.strings[i] == key) { + return i; + } + } + return -1; +} + +} // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/extension/variant_internal.h b/cpp/src/arrow/extension/variant_internal.h new file mode 100644 index 000000000000..6f19e9fe6769 --- /dev/null +++ b/cpp/src/arrow/extension/variant_internal.h @@ -0,0 +1,347 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +namespace arrow::extension::variant_internal { + +/// \file variant_internal.h +/// \brief Utilities for Variant binary encoding/decoding. +/// +/// Implements parsing logic per the Variant Encoding Spec: +/// https://github.com/apache/parquet-format/blob/master/VariantEncoding.md +/// +/// The "internal" in the filename refers to the binary encoding internals +/// of the Variant type, not the visibility of this header. This header is +/// installed and provides the public C++ API for working with Variant +/// binary data (independent of the VariantExtensionType in parquet_variant.h). + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +/// Variant encoding spec version 1. +constexpr uint8_t kVariantVersion = 1; + +/// Maximum nesting depth for recursive value decoding. +/// Prevents stack overflow on deeply nested (possibly malicious) input. +constexpr int32_t kMaxNestingDepth = 128; + +// --------------------------------------------------------------------------- +// Enumerations +// --------------------------------------------------------------------------- + +/// \brief Basic type codes from bits 0-1 of the value header byte. +/// +/// Variant Encoding Spec §3: "Value encoding" +enum class BasicType : uint8_t { + kPrimitive = 0, + kShortString = 1, + kObject = 2, + kArray = 3, +}; + +/// \brief Primitive type codes from bits 2-7 when basic_type == kPrimitive. +/// +/// Variant Encoding Spec §3.1: "Primitive types" +enum class PrimitiveType : uint8_t { + kNull = 0, + kTrue = 1, + kFalse = 2, + kInt8 = 3, + kInt16 = 4, + kInt32 = 5, + kInt64 = 6, + kDouble = 7, + kDecimal4 = 8, + kDecimal8 = 9, + kDecimal16 = 10, + kDate = 11, + kTimestampMicros = 12, + kTimestampMicrosNTZ = 13, + kFloat = 14, + kBinary = 15, + kString = 16, + kTimeNTZ = 17, + kTimestampNanos = 18, + kTimestampNanosNTZ = 19, + kUUID = 20, +}; + +// --------------------------------------------------------------------------- +// Metadata +// --------------------------------------------------------------------------- + +/// \brief Parsed variant metadata (string dictionary). +/// +/// The metadata buffer contains a header byte followed by a dictionary of +/// interned strings. String views reference the raw buffer and are valid +/// only as long as the underlying buffer is alive. +struct ARROW_EXPORT VariantMetadata { + /// Spec version (must be kVariantVersion). + uint8_t version = 0; + + /// Whether the dictionary strings are sorted lexicographically. + bool is_sorted = false; + + /// Number of bytes used for each offset (1, 2, 3, or 4). + int32_t offset_size = 0; + + /// Dictionary of interned strings. Views into the raw metadata buffer. + std::vector strings; +}; + +/// \brief Decode a variant metadata buffer. +/// +/// Parses the header byte and string dictionary from the raw metadata +/// buffer. The returned VariantMetadata contains string_views that +/// reference the input buffer directly (zero-copy). +/// +/// \param[in] data Pointer to the metadata buffer (must not be null) +/// \param[in] length Length of the metadata buffer in bytes +/// \return Parsed VariantMetadata on success, Status::Invalid on +/// malformed input +/// +/// \note The input buffer must outlive the returned VariantMetadata. +ARROW_EXPORT Result DecodeMetadata(const uint8_t* data, int64_t length); + +// --------------------------------------------------------------------------- +// Value header utilities +// --------------------------------------------------------------------------- + +/// \brief Extract the basic type from a value header byte. +/// +/// \param[in] header The first byte of a variant value +/// \return The BasicType (bits 0-1) +inline BasicType GetBasicType(uint8_t header) { + return static_cast(header & 0x03); +} + +/// \brief Extract the primitive type from a value header byte. +/// +/// Only valid when GetBasicType(header) == BasicType::kPrimitive. +/// +/// \param[in] header The first byte of a variant value +/// \return The PrimitiveType (bits 2-7) +inline PrimitiveType GetPrimitiveType(uint8_t header) { + return static_cast((header >> 2) & 0x3F); +} + +/// \brief Get the byte size of a primitive value (excluding header). +/// +/// \param[in] primitive_type The primitive type code +/// \return Number of bytes for the value payload, or -1 for +/// variable-length types (Binary, String) +ARROW_EXPORT int32_t PrimitiveValueSize(PrimitiveType primitive_type); + +// --------------------------------------------------------------------------- +// Value decoding +// --------------------------------------------------------------------------- + +/// \brief Visitor interface for variant value decoding. +/// +/// Implement this interface to receive callbacks during variant value +/// traversal. The visitor pattern avoids materializing a tree of objects, +/// which is important when scanning millions of rows. +/// +/// All methods return Status::OK() to continue traversal, or any error +/// Status to abort. +/// +/// \note String values passed to String() and FieldName() are raw bytes from +/// the variant buffer without UTF-8 validation. Per spec, all strings +/// must be valid UTF-8, but validation is the responsibility of a +/// higher-level consumer (e.g., when materializing to Arrow StringArray). +class ARROW_EXPORT VariantVisitor { + public: + virtual ~VariantVisitor() = default; + + /// @name Primitive value callbacks + /// @{ + virtual Status Null() = 0; + virtual Status Bool(bool value) = 0; + virtual Status Int8(int8_t value) = 0; + virtual Status Int16(int16_t value) = 0; + virtual Status Int32(int32_t value) = 0; + virtual Status Int64(int64_t value) = 0; + virtual Status Float(float value) = 0; + virtual Status Double(double value) = 0; + virtual Status Decimal4(const uint8_t* bytes, int32_t scale) = 0; + virtual Status Decimal8(const uint8_t* bytes, int32_t scale) = 0; + virtual Status Decimal16(const uint8_t* bytes, int32_t scale) = 0; + virtual Status Date(int32_t days_since_epoch) = 0; + virtual Status TimestampMicros(int64_t micros_since_epoch) = 0; + virtual Status TimestampMicrosNTZ(int64_t micros_since_epoch) = 0; + virtual Status String(std::string_view value) = 0; + virtual Status Binary(std::string_view value) = 0; + virtual Status TimeNTZ(int64_t micros_since_midnight) = 0; + virtual Status TimestampNanos(int64_t nanos_since_epoch) = 0; + virtual Status TimestampNanosNTZ(int64_t nanos_since_epoch) = 0; + virtual Status UUID(const uint8_t* bytes) = 0; + /// @} + + /// @name Container callbacks + /// @{ + + /// \brief Called at the start of an object with the number of fields. + virtual Status StartObject(int32_t num_fields) = 0; + + /// \brief Called for each object field name, before the field value. + virtual Status FieldName(std::string_view name) = 0; + + /// \brief Called after all fields of an object have been visited. + virtual Status EndObject() = 0; + + /// \brief Called at the start of an array with the number of elements. + virtual Status StartArray(int32_t num_elements) = 0; + + /// \brief Called after all elements of an array have been visited. + virtual Status EndArray() = 0; + /// @} +}; + +/// \brief Decode a variant value buffer using a visitor. +/// +/// Recursively traverses the variant value, calling the appropriate +/// visitor methods for each element. Objects and arrays trigger +/// Start/End pairs with nested visits for their contents. +/// +/// \param[in] metadata Parsed metadata (for resolving string dictionary) +/// \param[in] data Pointer to the value buffer +/// \param[in] length Length of the value buffer in bytes +/// \param[in] visitor Callback interface for decoded values +/// \return Status::OK on success, Status::Invalid on malformed input +/// +/// \note The data buffer must remain valid for the duration of the call. +ARROW_EXPORT Status DecodeVariantValue(const VariantMetadata& metadata, + const uint8_t* data, int64_t length, + VariantVisitor* visitor); + +/// \brief Get the basic type of a variant value without full decoding. +/// +/// \param[in] data Pointer to the value buffer +/// \param[in] length Length of the value buffer in bytes +/// \return The BasicType of the value, or Status::Invalid if the +/// buffer is empty +ARROW_EXPORT Result GetValueBasicType(const uint8_t* data, int64_t length); + +/// \brief Get the number of fields in a variant object. +/// +/// \param[in] data Pointer to the value buffer (must start with an object) +/// \param[in] length Length of the value buffer in bytes +/// \return The number of fields, or Status::Invalid if not an object +ARROW_EXPORT Result GetObjectFieldCount(const uint8_t* data, int64_t length); + +/// \brief Get the number of elements in a variant array. +/// +/// \param[in] data Pointer to the value buffer (must start with an array) +/// \param[in] length Length of the value buffer in bytes +/// \return The number of elements, or Status::Invalid if not an array +ARROW_EXPORT Result GetArrayElementCount(const uint8_t* data, int64_t length); + +// --------------------------------------------------------------------------- +// Value size computation +// --------------------------------------------------------------------------- + +/// \brief Compute the total byte size of a variant value (header + data). +/// +/// Determines how many bytes a variant value occupies by examining +/// its header and (for containers/variable-length types) reading +/// size information. Does NOT recursively validate the contents. +/// +/// \param[in] data Pointer to the start of a variant value +/// \param[in] length Maximum bytes available +/// \return Total byte count of the value, or Status::Invalid if truncated +ARROW_EXPORT Result ValueSize(const uint8_t* data, int64_t length); + +// --------------------------------------------------------------------------- +// Random access utilities +// --------------------------------------------------------------------------- + +/// \brief Find an object field by name and return the offset/size of its value. +/// +/// Searches the field IDs in the object, resolving each against the +/// metadata dictionary. Per spec, field IDs are in lexicographic order +/// of their corresponding key names, enabling binary search for large +/// objects (>=32 fields). For smaller objects, linear scan is used. +/// +/// \param[in] metadata Parsed metadata (for resolving field IDs to names) +/// \param[in] data Pointer to the object value buffer +/// \param[in] length Length of the value buffer in bytes +/// \param[in] field_name The field name to search for +/// \param[out] field_offset Set to the byte offset of the field's value +/// within data, or -1 if not found +/// \param[out] field_size Set to the byte size of the field's value, +/// or 0 if not found +/// \return Status::OK if search completed (field may or may not exist), +/// Status::Invalid if the buffer is malformed +ARROW_EXPORT Status FindObjectField(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, std::string_view field_name, + int64_t* field_offset, int64_t* field_size); + +/// \brief Get the i-th element of a variant array by index (O(1) access). +/// +/// Uses the offset table for random access without traversing preceding +/// elements. +/// +/// \param[in] data Pointer to the array value buffer +/// \param[in] length Length of the value buffer in bytes +/// \param[in] index Zero-based element index +/// \param[out] element_offset Set to the byte offset of the element within data +/// \param[out] element_size Set to the byte size of the element +/// \return Status::OK on success, Status::Invalid if not an array or +/// index is out of range +ARROW_EXPORT Status GetArrayElement(const uint8_t* data, int64_t length, int32_t index, + int64_t* element_offset, int64_t* element_size); + +/// \brief Get the i-th field of a variant object by position. +/// +/// Returns both the field name (resolved from metadata) and a pointer +/// to the field's value. +/// +/// \param[in] metadata Parsed metadata +/// \param[in] data Pointer to the object value buffer +/// \param[in] length Length of the value buffer in bytes +/// \param[in] index Zero-based field index +/// \param[out] field_name Set to the field's key name +/// \param[out] field_offset Set to the byte offset of the field's value +/// \param[out] field_size Set to the byte size of the field's value +/// \return Status::OK on success, Status::Invalid if not an object or +/// index is out of range +ARROW_EXPORT Status GetObjectFieldAt(const VariantMetadata& metadata, const uint8_t* data, + int64_t length, int32_t index, + std::string_view* field_name, int64_t* field_offset, + int64_t* field_size); + +/// \brief Find the dictionary ID for a given key name. +/// +/// Uses binary search if the metadata is sorted, otherwise linear scan. +/// +/// \param[in] metadata Parsed metadata +/// \param[in] key The key to search for +/// \return The dictionary ID if found, or -1 if not present +ARROW_EXPORT int32_t FindMetadataKey(const VariantMetadata& metadata, + std::string_view key); + +} // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/extension/variant_internal_test.cc b/cpp/src/arrow/extension/variant_internal_test.cc new file mode 100644 index 000000000000..9cfdb665aa82 --- /dev/null +++ b/cpp/src/arrow/extension/variant_internal_test.cc @@ -0,0 +1,2128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/variant_internal.h" +#include "arrow/extension/variant_test_util.h" + +#include +#include +#include +#include +#include + +#include "arrow/testing/gtest_util.h" + +namespace arrow::extension::variant_internal { + +// =========================================================================== +// Test helpers +// =========================================================================== + +/// \brief Build a metadata buffer from a list of strings. +/// +/// Uses offset_size=1, version=1, sorted flag as specified. +std::vector BuildMetadataBuffer(const std::vector& strings, + bool sorted = false, int32_t offset_size = 1) { + std::vector buffer; + + // Header byte: version=1, sorted flag, offset_size + uint8_t header = kVariantVersion; + if (sorted) { + header |= (1 << 4); + } + header |= static_cast((offset_size - 1) << 6); + buffer.push_back(header); + + // Dictionary size + auto dict_size = static_cast(strings.size()); + for (int32_t b = 0; b < offset_size; ++b) { + buffer.push_back(static_cast((dict_size >> (b * 8)) & 0xFF)); + } + + // Compute string offsets + std::vector offsets(dict_size + 1); + offsets[0] = 0; + for (uint32_t i = 0; i < dict_size; ++i) { + offsets[i + 1] = offsets[i] + static_cast(strings[i].size()); + } + + // Write offsets + for (uint32_t i = 0; i <= dict_size; ++i) { + for (int32_t b = 0; b < offset_size; ++b) { + buffer.push_back(static_cast((offsets[i] >> (b * 8)) & 0xFF)); + } + } + + // Write string data + for (const auto& s : strings) { + buffer.insert(buffer.end(), s.begin(), s.end()); + } + + return buffer; +} + +/// \brief Build a primitive value header byte. +uint8_t PrimitiveHeader(PrimitiveType type) { + return static_cast(BasicType::kPrimitive) | (static_cast(type) << 2); +} + +/// \brief Build a short string value buffer. +std::vector BuildShortString(const std::string& s) { + std::vector buffer; + auto len = static_cast(s.size()); + uint8_t header = static_cast(BasicType::kShortString) | (len << 2); + buffer.push_back(header); + buffer.insert(buffer.end(), s.begin(), s.end()); + return buffer; +} + +/// \brief Build an object value buffer. +/// +/// \param field_ids Dictionary indices for each field name +/// \param field_values Serialized variant values for each field +/// \param field_id_size Bytes per field ID (1-4) +/// \param field_offset_size Bytes per offset (1-4) +std::vector BuildObject(const std::vector& field_ids, + const std::vector>& field_values, + int32_t field_id_size = 1, + int32_t field_offset_size = 1) { + auto num_fields = static_cast(field_ids.size()); + bool is_large = (num_fields > 255); + + std::vector buffer; + + // Header per spec: basic_type=2 in bits 0-1, + // bits 2-3: field_offset_size-1 + // bits 4-5: field_id_size-1 + // bit 6: is_large + uint8_t header = static_cast(BasicType::kObject); + header |= static_cast((field_offset_size - 1) << 2); + header |= static_cast((field_id_size - 1) << 4); + if (is_large) { + header |= (1 << 6); + } + buffer.push_back(header); + + // num_fields: 1 byte or 4 bytes depending on is_large + int32_t num_fields_size = is_large ? 4 : 1; + for (int32_t b = 0; b < num_fields_size; ++b) { + buffer.push_back(static_cast((num_fields >> (b * 8)) & 0xFF)); + } + + // field_ids + for (auto fid : field_ids) { + for (int32_t b = 0; b < field_id_size; ++b) { + buffer.push_back(static_cast((fid >> (b * 8)) & 0xFF)); + } + } + + // Compute offsets + std::vector offsets(num_fields + 1); + offsets[0] = 0; + for (uint32_t i = 0; i < num_fields; ++i) { + offsets[i + 1] = offsets[i] + static_cast(field_values[i].size()); + } + + // Write offsets + for (uint32_t i = 0; i <= num_fields; ++i) { + for (int32_t b = 0; b < field_offset_size; ++b) { + buffer.push_back(static_cast((offsets[i] >> (b * 8)) & 0xFF)); + } + } + + // Write field value data + for (const auto& fv : field_values) { + buffer.insert(buffer.end(), fv.begin(), fv.end()); + } + + return buffer; +} + +/// \brief Build an array value buffer. +/// +/// \param elements Serialized variant values for each element +/// \param field_offset_size Bytes per offset (1-4) +std::vector BuildArray(const std::vector>& elements, + int32_t field_offset_size = 1) { + auto num_elements = static_cast(elements.size()); + bool is_large = (num_elements > 255); + + std::vector buffer; + + // Header per spec: basic_type=3 in bits 0-1, + // bits 2-3: field_offset_size-1 + // bit 4: is_large + uint8_t header = static_cast(BasicType::kArray); + header |= static_cast((field_offset_size - 1) << 2); + if (is_large) { + header |= (1 << 4); + } + buffer.push_back(header); + + // num_elements: 1 byte or 4 bytes depending on is_large + int32_t num_elements_size = is_large ? 4 : 1; + for (int32_t b = 0; b < num_elements_size; ++b) { + buffer.push_back(static_cast((num_elements >> (b * 8)) & 0xFF)); + } + + // Compute offsets + std::vector offsets(num_elements + 1); + offsets[0] = 0; + for (uint32_t i = 0; i < num_elements; ++i) { + offsets[i + 1] = offsets[i] + static_cast(elements[i].size()); + } + + // Write offsets + for (uint32_t i = 0; i <= num_elements; ++i) { + for (int32_t b = 0; b < field_offset_size; ++b) { + buffer.push_back(static_cast((offsets[i] >> (b * 8)) & 0xFF)); + } + } + + // Write element data + for (const auto& elem : elements) { + buffer.insert(buffer.end(), elem.begin(), elem.end()); + } + + return buffer; +} + +// =========================================================================== +// Metadata decoding tests +// =========================================================================== + +class VariantMetadataTest : public ::testing::Test {}; + +TEST_F(VariantMetadataTest, EmptyDictionary) { + auto buf = BuildMetadataBuffer({}); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.version, 1); + ASSERT_FALSE(metadata.is_sorted); + ASSERT_EQ(metadata.offset_size, 1); + ASSERT_EQ(metadata.strings.size(), 0); +} + +TEST_F(VariantMetadataTest, SingleString) { + auto buf = BuildMetadataBuffer({"hello"}); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.strings.size(), 1); + ASSERT_EQ(metadata.strings[0], "hello"); +} + +TEST_F(VariantMetadataTest, MultipleStrings) { + auto buf = BuildMetadataBuffer({"name", "age", "scores"}); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.strings.size(), 3); + ASSERT_EQ(metadata.strings[0], "name"); + ASSERT_EQ(metadata.strings[1], "age"); + ASSERT_EQ(metadata.strings[2], "scores"); +} + +TEST_F(VariantMetadataTest, SortedFlag) { + auto buf = BuildMetadataBuffer({"age", "name", "score"}, true); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_TRUE(metadata.is_sorted); +} + +TEST_F(VariantMetadataTest, OffsetSize2) { + auto buf = BuildMetadataBuffer({"key1", "key2"}, false, 2); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.offset_size, 2); + ASSERT_EQ(metadata.strings.size(), 2); + ASSERT_EQ(metadata.strings[0], "key1"); + ASSERT_EQ(metadata.strings[1], "key2"); +} + +TEST_F(VariantMetadataTest, OffsetSize4) { + auto buf = BuildMetadataBuffer({"a", "bb", "ccc"}, false, 4); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.offset_size, 4); + ASSERT_EQ(metadata.strings.size(), 3); + ASSERT_EQ(metadata.strings[0], "a"); + ASSERT_EQ(metadata.strings[1], "bb"); + ASSERT_EQ(metadata.strings[2], "ccc"); +} + +TEST_F(VariantMetadataTest, EmptyStrings) { + auto buf = BuildMetadataBuffer({"", "nonempty", ""}); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.strings.size(), 3); + ASSERT_EQ(metadata.strings[0], ""); + ASSERT_EQ(metadata.strings[1], "nonempty"); + ASSERT_EQ(metadata.strings[2], ""); +} + +// Error cases + +TEST_F(VariantMetadataTest, NullBuffer) { + ASSERT_RAISES(Invalid, DecodeMetadata(nullptr, 0)); +} + +TEST_F(VariantMetadataTest, EmptyBuffer) { + uint8_t data = 0; + ASSERT_RAISES(Invalid, DecodeMetadata(&data, 0)); +} + +TEST_F(VariantMetadataTest, UnsupportedVersion) { + // Version 2 (unsupported) + uint8_t data[] = {0x02, 0x00}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantMetadataTest, TruncatedDictionarySize) { + // Header says offset_size=2 (bits 6-7 = 01), but only 1 byte follows + uint8_t data[] = {0x41, 0x00}; // version=1, offset_size=2 + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantMetadataTest, TruncatedStringOffsets) { + // Claims dict_size=5 but buffer is too short for offsets + uint8_t data[] = {0x01, 0x05, 0x00}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantMetadataTest, OffsetSize3) { + auto buf = BuildMetadataBuffer({"foo", "bar"}, false, 3); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.offset_size, 3); + ASSERT_EQ(metadata.strings.size(), 2); + ASSERT_EQ(metadata.strings[0], "foo"); + ASSERT_EQ(metadata.strings[1], "bar"); +} + +TEST_F(VariantMetadataTest, ReservedBit5Set) { + // Header with bit 5 set: 0x21 = version=1, bit5=1 + uint8_t data[] = {0x21, 0x00, 0x00}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantMetadataTest, NonMonotonicStringOffsets) { + // Manually construct metadata where string offsets are NOT monotonically + // non-decreasing. ValidateOffsets should reject this. + // Header: version=1, offset_size=1 + // dict_size=2, offsets=[0, 5, 3] — 3 < 5, non-monotonic + // String data: "helloabc" (8 bytes, but offsets claim 3 as last) + uint8_t data[] = { + 0x01, // header: version=1, offset_size=1 + 0x02, // dict_size = 2 + 0x00, 0x05, 0x03, // offsets: [0, 5, 3] — non-monotonic + 'h', 'e', 'l', 'l', 'o', 'a', 'b', 'c'}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +// =========================================================================== +// Primitive value decoding tests +// =========================================================================== + +class VariantPrimitiveTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantPrimitiveTest, DecodeNull) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events.size(), 1); + ASSERT_EQ(visitor.events[0], "Null"); +} + +TEST_F(VariantPrimitiveTest, DecodeTrue) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kTrue)}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events.size(), 1); + ASSERT_EQ(visitor.events[0], "Bool(true)"); +} + +TEST_F(VariantPrimitiveTest, DecodeFalse) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kFalse)}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events.size(), 1); + ASSERT_EQ(visitor.events[0], "Bool(false)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt8) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt8), 0x2A}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int8(42)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt8Negative) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt8), 0xD6}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int8(-42)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt16) { + // 300 = 0x012C in little-endian: 0x2C, 0x01 + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt16), 0x2C, 0x01}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int16(300)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt32) { + // 100000 = 0x000186A0 in LE: A0 86 01 00 + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt32), 0xA0, 0x86, 0x01, 0x00}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int32(100000)"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt32Max) { + int32_t val = std::numeric_limits::max(); + uint8_t data[5]; + data[0] = PrimitiveHeader(PrimitiveType::kInt32); + std::memcpy(data + 1, &val, 4); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int32(" + std::to_string(val) + ")"); +} + +TEST_F(VariantPrimitiveTest, DecodeInt64) { + int64_t val = 1234567890123LL; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kInt64); + std::memcpy(data + 1, &val, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int64(" + std::to_string(val) + ")"); +} + +TEST_F(VariantPrimitiveTest, DecodeFloat) { + float val = 3.14f; + uint8_t data[5]; + data[0] = PrimitiveHeader(PrimitiveType::kFloat); + std::memcpy(data + 1, &val, 4); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + // Float string representation may vary; just check it starts with Float( + ASSERT_TRUE(visitor.events[0].find("Float(") == 0); +} + +TEST_F(VariantPrimitiveTest, DecodeDouble) { + double val = 2.718281828459045; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kDouble); + std::memcpy(data + 1, &val, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_TRUE(visitor.events[0].find("Double(") == 0); +} + +TEST_F(VariantPrimitiveTest, DecodeDate) { + // Days since epoch: 19000 (approximately 2022-01-01) + int32_t days = 19000; + uint8_t data[5]; + data[0] = PrimitiveHeader(PrimitiveType::kDate); + std::memcpy(data + 1, &days, 4); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Date(19000)"); +} + +TEST_F(VariantPrimitiveTest, DecodeTimestampMicros) { + int64_t micros = 1654041600000000LL; // some timestamp + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimestampMicros); + std::memcpy(data + 1, µs, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimestampMicros(" + std::to_string(micros) + ")"); +} + +TEST_F(VariantPrimitiveTest, DecodeTimestampMicrosNTZ) { + int64_t micros = 1654041600000000LL; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimestampMicrosNTZ); + std::memcpy(data + 1, µs, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimestampMicrosNTZ(" + std::to_string(micros) + ")"); +} + +TEST_F(VariantPrimitiveTest, DecodeDecimal4) { + // Spec layout: 1 byte scale, then 4 bytes LE unscaled value + uint8_t data[6]; + data[0] = PrimitiveHeader(PrimitiveType::kDecimal4); + data[1] = 2; // scale = 2 + int32_t val = 12345; + std::memcpy(data + 2, &val, 4); // unscaled value + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Decimal4(scale=2)"); +} + +TEST_F(VariantPrimitiveTest, DecodeDecimal4MaxScale) { + // Scale at maximum per spec: 38 + uint8_t data[6]; + data[0] = PrimitiveHeader(PrimitiveType::kDecimal4); + data[1] = 38; // scale = 38 (maximum per spec) + int32_t val = 12345; + std::memcpy(data + 2, &val, 4); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Decimal4(scale=38)"); +} + +TEST_F(VariantPrimitiveTest, DecodeDecimal8) { + // Spec layout: 1 byte scale, then 8 bytes LE unscaled value + uint8_t data[10]; + data[0] = PrimitiveHeader(PrimitiveType::kDecimal8); + data[1] = 5; // scale = 5 + int64_t val = 123456789012345LL; + std::memcpy(data + 2, &val, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Decimal8(scale=5)"); +} + +TEST_F(VariantPrimitiveTest, DecodeDecimal16) { + // Spec layout: 1 byte scale, then 16 bytes LE unscaled value + uint8_t data[18]; + data[0] = PrimitiveHeader(PrimitiveType::kDecimal16); + data[1] = 10; // scale = 10 + std::memset(data + 2, 0, 16); + data[2] = 0x01; // low byte = 1 + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Decimal16(scale=10)"); +} + +TEST_F(VariantPrimitiveTest, DecodeLongString) { + // Long string: primitive type kString with 4-byte length prefix + std::string test_str = "hello world, this is a long string"; + auto str_len = static_cast(test_str.size()); + + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kString)); + // 4-byte little-endian length + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((str_len >> (b * 8)) & 0xFF)); + } + data.insert(data.end(), test_str.begin(), test_str.end()); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"hello world, this is a long string\")"); +} + +TEST_F(VariantPrimitiveTest, DecodeBinary) { + std::vector bin_bytes = {0x00, 0x01, 0x02, 0x03}; + auto bin_len = static_cast(bin_bytes.size()); + + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kBinary)); + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((bin_len >> (b * 8)) & 0xFF)); + } + data.insert(data.end(), bin_bytes.begin(), bin_bytes.end()); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "Binary(len=4)"); +} + +// Truncation errors + +TEST_F(VariantPrimitiveTest, TruncatedInt32) { + // Only 2 bytes after header, but Int32 needs 4 + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt32), 0x00, 0x00}; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, + DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); +} + +TEST_F(VariantPrimitiveTest, EmptyValueBuffer) { + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, DecodeVariantValue(empty_metadata_, nullptr, 0, &visitor)); +} + +// =========================================================================== +// Short string tests +// =========================================================================== + +class VariantShortStringTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantShortStringTest, EmptyShortString) { + auto data = BuildShortString(""); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"\")"); +} + +TEST_F(VariantShortStringTest, SimpleShortString) { + auto data = BuildShortString("hi"); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"hi\")"); +} + +TEST_F(VariantShortStringTest, MaxLengthShortString) { + // Maximum short string is 63 bytes + std::string max_str(63, 'x'); + auto data = BuildShortString(max_str); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"" + max_str + "\")"); +} + +TEST_F(VariantShortStringTest, TruncatedShortString) { + // Header says length=10 but buffer only has 3 bytes total + uint8_t data[] = {static_cast(BasicType::kShortString) | (10 << 2), 'a', 'b'}; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, + DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); +} + +// =========================================================================== +// Object decoding tests +// =========================================================================== + +class VariantObjectTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"name", "age", "scores"}; + } +}; + +TEST_F(VariantObjectTest, EmptyObject) { + auto data = BuildObject({}, {}); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + ASSERT_EQ(visitor.events.size(), 2); + ASSERT_EQ(visitor.events[0], "StartObject(0)"); + ASSERT_EQ(visitor.events[1], "EndObject"); +} + +TEST_F(VariantObjectTest, SingleField) { + // Object with one field: name -> "Alice" (short string) + auto value = BuildShortString("Alice"); + auto data = BuildObject({0}, {value}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + ASSERT_EQ(visitor.events.size(), 4); + ASSERT_EQ(visitor.events[0], "StartObject(1)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"Alice\")"); + ASSERT_EQ(visitor.events[3], "EndObject"); +} + +TEST_F(VariantObjectTest, MultipleFields) { + // Object: {name: "Bob", age: 30} + auto name_val = BuildShortString("Bob"); + // age: Int32(30) + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + + auto data = BuildObject({0, 1}, {name_val, age_val}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + ASSERT_EQ(visitor.events.size(), 6); + ASSERT_EQ(visitor.events[0], "StartObject(2)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"Bob\")"); + ASSERT_EQ(visitor.events[3], "FieldName(\"age\")"); + ASSERT_EQ(visitor.events[4], "Int32(30)"); + ASSERT_EQ(visitor.events[5], "EndObject"); +} + +TEST_F(VariantObjectTest, InvalidFieldId) { + // field_id=99 exceeds dictionary size of 3 + auto value = BuildShortString("oops"); + auto data = BuildObject({99}, {value}); + + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, DecodeVariantValue(metadata_, data.data(), + static_cast(data.size()), &visitor)); +} + +TEST_F(VariantObjectTest, ThreeByteOffsetSize) { + // Exercises value decoding with 3-byte field_offset_size and field_id_size. + // Object with 2 fields: {name: "test", age: 42} + auto name_val = BuildShortString("test"); + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + auto data = BuildObject({0, 1}, {name_val, age_val}, + /*field_id_size=*/3, /*field_offset_size=*/3); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + ASSERT_EQ(visitor.events.size(), 6); + ASSERT_EQ(visitor.events[0], "StartObject(2)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"test\")"); + ASSERT_EQ(visitor.events[3], "FieldName(\"age\")"); + ASSERT_EQ(visitor.events[4], "Int32(42)"); + ASSERT_EQ(visitor.events[5], "EndObject"); +} + +// =========================================================================== +// Array decoding tests +// =========================================================================== + +class VariantArrayTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantArrayTest, EmptyArray) { + auto data = BuildArray({}); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events.size(), 2); + ASSERT_EQ(visitor.events[0], "StartArray(0)"); + ASSERT_EQ(visitor.events[1], "EndArray"); +} + +TEST_F(VariantArrayTest, SingleElement) { + std::vector elem = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + auto data = BuildArray({elem}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events.size(), 3); + ASSERT_EQ(visitor.events[0], "StartArray(1)"); + ASSERT_EQ(visitor.events[1], "Int32(42)"); + ASSERT_EQ(visitor.events[2], "EndArray"); +} + +TEST_F(VariantArrayTest, HeterogeneousElements) { + // Array with mixed types: [42, "hello", true] + std::vector int_elem = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + auto str_elem = BuildShortString("hello"); + std::vector bool_elem = {PrimitiveHeader(PrimitiveType::kTrue)}; + + auto data = BuildArray({int_elem, str_elem, bool_elem}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events.size(), 5); + ASSERT_EQ(visitor.events[0], "StartArray(3)"); + ASSERT_EQ(visitor.events[1], "Int32(42)"); + ASSERT_EQ(visitor.events[2], "String(\"hello\")"); + ASSERT_EQ(visitor.events[3], "Bool(true)"); + ASSERT_EQ(visitor.events[4], "EndArray"); +} + +TEST_F(VariantArrayTest, LargeArrayIsLargeFlag) { + // Build an array with 256 elements to exercise is_large=true (4-byte + // num_elements). Each element is a Null primitive (1 byte each). + // Use field_offset_size=2 since total data (256 bytes) exceeds 1-byte max. + std::vector> elements; + elements.reserve(256); + for (int i = 0; i < 256; ++i) { + elements.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildArray(elements, /*field_offset_size=*/2); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + // StartArray(256) + 256 Nulls + EndArray = 258 events + ASSERT_EQ(visitor.events.size(), 258); + ASSERT_EQ(visitor.events[0], "StartArray(256)"); + ASSERT_EQ(visitor.events[1], "Null"); + ASSERT_EQ(visitor.events[256], "Null"); + ASSERT_EQ(visitor.events[257], "EndArray"); +} + +// =========================================================================== +// Nested structure tests +// =========================================================================== + +class VariantNestedTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"name", "scores", "inner"}; + } +}; + +TEST_F(VariantNestedTest, ObjectWithNestedArray) { + // {name: "Alice", scores: [95, 87]} + auto name_val = BuildShortString("Alice"); + + // scores array: [Int32(95), Int32(87)] + std::vector score1 = {PrimitiveHeader(PrimitiveType::kInt32), 95, 0, 0, 0}; + std::vector score2 = {PrimitiveHeader(PrimitiveType::kInt32), 87, 0, 0, 0}; + auto scores_val = BuildArray({score1, score2}); + + auto data = BuildObject({0, 1}, {name_val, scores_val}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + + // Expected events: + // StartObject(2), FieldName("name"), String("Alice"), + // FieldName("scores"), StartArray(2), Int32(95), Int32(87), EndArray, + // EndObject + ASSERT_EQ(visitor.events.size(), 9); + ASSERT_EQ(visitor.events[0], "StartObject(2)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"Alice\")"); + ASSERT_EQ(visitor.events[3], "FieldName(\"scores\")"); + ASSERT_EQ(visitor.events[4], "StartArray(2)"); + ASSERT_EQ(visitor.events[5], "Int32(95)"); + ASSERT_EQ(visitor.events[6], "Int32(87)"); + ASSERT_EQ(visitor.events[7], "EndArray"); + ASSERT_EQ(visitor.events[8], "EndObject"); +} + +TEST_F(VariantNestedTest, NestedObjects) { + // {inner: {name: "deep"}} + auto deep_name = BuildShortString("deep"); + auto inner_obj = BuildObject({0}, {deep_name}); + auto data = BuildObject({2}, {inner_obj}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + + ASSERT_EQ(visitor.events.size(), 7); + ASSERT_EQ(visitor.events[0], "StartObject(1)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"inner\")"); + ASSERT_EQ(visitor.events[2], "StartObject(1)"); + ASSERT_EQ(visitor.events[3], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[4], "String(\"deep\")"); + ASSERT_EQ(visitor.events[5], "EndObject"); + ASSERT_EQ(visitor.events[6], "EndObject"); +} + +TEST_F(VariantNestedTest, ArrayOfObjects) { + // [{name: "a"}, {name: "b"}] + auto val_a = BuildShortString("a"); + auto obj_a = BuildObject({0}, {val_a}); + + auto val_b = BuildShortString("b"); + auto obj_b = BuildObject({0}, {val_b}); + + auto data = BuildArray({obj_a, obj_b}); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + + ASSERT_EQ(visitor.events.size(), 10); + ASSERT_EQ(visitor.events[0], "StartArray(2)"); + ASSERT_EQ(visitor.events[1], "StartObject(1)"); + ASSERT_EQ(visitor.events[2], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[3], "String(\"a\")"); + ASSERT_EQ(visitor.events[4], "EndObject"); + ASSERT_EQ(visitor.events[5], "StartObject(1)"); + ASSERT_EQ(visitor.events[6], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[7], "String(\"b\")"); + ASSERT_EQ(visitor.events[8], "EndObject"); + ASSERT_EQ(visitor.events[9], "EndArray"); +} + +// =========================================================================== +// Recursion depth limit test +// =========================================================================== + +class VariantDepthTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"x"}; + } +}; + +TEST_F(VariantDepthTest, ExceedsMaxNestingDepth) { + // Build a deeply nested array: [[[[...]]]] + // Each level wraps the inner in a 1-element array with offset_size=2 + // to allow buffers larger than 255 bytes. + std::vector inner = {PrimitiveHeader(PrimitiveType::kNull)}; + + // Wrap 130 times (exceeds kMaxNestingDepth=128) + for (int i = 0; i < 130; ++i) { + inner = BuildArray({inner}, /*field_offset_size=*/2); + } + + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, + DecodeVariantValue(metadata_, inner.data(), + static_cast(inner.size()), &visitor)); +} + +TEST_F(VariantDepthTest, AtMaxNestingDepthSucceeds) { + // Build 50 levels of nesting — well within kMaxNestingDepth=128 + // and within offset_size=1 limits (each level adds ~4 bytes). + std::vector inner = {PrimitiveHeader(PrimitiveType::kNull)}; + + for (int i = 0; i < 50; ++i) { + inner = BuildArray({inner}); + } + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, inner.data(), + static_cast(inner.size()), &visitor)); +} + +// =========================================================================== +// Utility function tests +// =========================================================================== + +class VariantUtilTest : public ::testing::Test {}; + +TEST_F(VariantUtilTest, GetValueBasicTypePrimitive) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt32), 0, 0, 0, 0}; + ASSERT_OK_AND_ASSIGN(auto bt, GetValueBasicType(data, sizeof(data))); + ASSERT_EQ(bt, BasicType::kPrimitive); +} + +TEST_F(VariantUtilTest, GetValueBasicTypeShortString) { + auto data = BuildShortString("test"); + ASSERT_OK_AND_ASSIGN(auto bt, + GetValueBasicType(data.data(), static_cast(data.size()))); + ASSERT_EQ(bt, BasicType::kShortString); +} + +TEST_F(VariantUtilTest, GetValueBasicTypeObject) { + VariantMetadata meta; + meta.version = 1; + meta.strings = {"key"}; + auto val = BuildShortString("val"); + auto data = BuildObject({0}, {val}); + ASSERT_OK_AND_ASSIGN(auto bt, + GetValueBasicType(data.data(), static_cast(data.size()))); + ASSERT_EQ(bt, BasicType::kObject); +} + +TEST_F(VariantUtilTest, GetValueBasicTypeArray) { + auto data = BuildArray({}); + ASSERT_OK_AND_ASSIGN(auto bt, + GetValueBasicType(data.data(), static_cast(data.size()))); + ASSERT_EQ(bt, BasicType::kArray); +} + +TEST_F(VariantUtilTest, GetValueBasicTypeEmptyBuffer) { + ASSERT_RAISES(Invalid, GetValueBasicType(nullptr, 0)); +} + +TEST_F(VariantUtilTest, GetObjectFieldCount) { + VariantMetadata meta; + meta.version = 1; + meta.strings = {"a", "b", "c"}; + auto v1 = BuildShortString("x"); + auto v2 = BuildShortString("y"); + auto data = BuildObject({0, 1}, {v1, v2}); + ASSERT_OK_AND_ASSIGN( + auto count, GetObjectFieldCount(data.data(), static_cast(data.size()))); + ASSERT_EQ(count, 2); +} + +TEST_F(VariantUtilTest, GetArrayElementCount) { + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + std::vector e2 = {PrimitiveHeader(PrimitiveType::kTrue)}; + std::vector e3 = {PrimitiveHeader(PrimitiveType::kFalse)}; + auto data = BuildArray({e1, e2, e3}); + ASSERT_OK_AND_ASSIGN( + auto count, GetArrayElementCount(data.data(), static_cast(data.size()))); + ASSERT_EQ(count, 3); +} + +TEST_F(VariantUtilTest, PrimitiveValueSizes) { + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kNull), 0); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTrue), 0); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kFalse), 0); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kInt8), 1); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kInt16), 2); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kInt32), 4); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kInt64), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kFloat), 4); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDouble), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDate), 4); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimestampMicros), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimestampMicrosNTZ), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimeNTZ), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimestampNanos), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kTimestampNanosNTZ), 8); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kUUID), 16); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDecimal4), 5); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDecimal8), 9); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kDecimal16), 17); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kBinary), -1); + ASSERT_EQ(PrimitiveValueSize(PrimitiveType::kString), -1); +} + +// =========================================================================== +// Integration: Metadata + Value decoding together +// =========================================================================== + +class VariantIntegrationTest : public ::testing::Test {}; + +TEST_F(VariantIntegrationTest, FullRoundTrip) { + // Build a complete variant: {name: "Alice", age: 30, scores: [95, 87]} + auto meta_buf = BuildMetadataBuffer({"name", "age", "scores"}); + + auto name_val = BuildShortString("Alice"); + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + std::vector s1 = {PrimitiveHeader(PrimitiveType::kInt32), 95, 0, 0, 0}; + std::vector s2 = {PrimitiveHeader(PrimitiveType::kInt32), 87, 0, 0, 0}; + auto scores_val = BuildArray({s1, s2}); + + auto value_buf = BuildObject({0, 1, 2}, {name_val, age_val, scores_val}); + + // Decode metadata + ASSERT_OK_AND_ASSIGN( + auto metadata, + DecodeMetadata(meta_buf.data(), static_cast(meta_buf.size()))); + ASSERT_EQ(metadata.strings.size(), 3); + + // Decode value + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, value_buf.data(), + static_cast(value_buf.size()), &visitor)); + + // Verify full event sequence + ASSERT_EQ(visitor.events.size(), 11); + ASSERT_EQ(visitor.events[0], "StartObject(3)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"name\")"); + ASSERT_EQ(visitor.events[2], "String(\"Alice\")"); + ASSERT_EQ(visitor.events[3], "FieldName(\"age\")"); + ASSERT_EQ(visitor.events[4], "Int32(30)"); + ASSERT_EQ(visitor.events[5], "FieldName(\"scores\")"); + ASSERT_EQ(visitor.events[6], "StartArray(2)"); + ASSERT_EQ(visitor.events[7], "Int32(95)"); + ASSERT_EQ(visitor.events[8], "Int32(87)"); + ASSERT_EQ(visitor.events[9], "EndArray"); + ASSERT_EQ(visitor.events[10], "EndObject"); +} + +// =========================================================================== +// Visitor early abort test +// =========================================================================== + +/// \brief A visitor that aborts after receiving a specific number of events. +class AbortingVisitor : public VariantVisitor { + public: + int32_t abort_after; + int32_t count = 0; + + explicit AbortingVisitor(int32_t abort_after) : abort_after(abort_after) {} + + Status MaybeAbort() { + ++count; + if (count >= abort_after) { + return Status::Cancelled("Visitor aborted after ", count, " events"); + } + return Status::OK(); + } + + Status Null() override { return MaybeAbort(); } + Status Bool(bool /*value*/) override { return MaybeAbort(); } + Status Int8(int8_t /*value*/) override { return MaybeAbort(); } + Status Int16(int16_t /*value*/) override { return MaybeAbort(); } + Status Int32(int32_t /*value*/) override { return MaybeAbort(); } + Status Int64(int64_t /*value*/) override { return MaybeAbort(); } + Status Float(float /*value*/) override { return MaybeAbort(); } + Status Double(double /*value*/) override { return MaybeAbort(); } + Status Decimal4(const uint8_t* /*bytes*/, int32_t /*s*/) override { + return MaybeAbort(); + } + Status Decimal8(const uint8_t* /*bytes*/, int32_t /*s*/) override { + return MaybeAbort(); + } + Status Decimal16(const uint8_t* /*bytes*/, int32_t /*s*/) override { + return MaybeAbort(); + } + Status Date(int32_t /*days*/) override { return MaybeAbort(); } + Status TimestampMicros(int64_t /*micros*/) override { return MaybeAbort(); } + Status TimestampMicrosNTZ(int64_t /*micros*/) override { return MaybeAbort(); } + Status String(std::string_view /*value*/) override { return MaybeAbort(); } + Status Binary(std::string_view /*value*/) override { return MaybeAbort(); } + Status TimeNTZ(int64_t /*micros*/) override { return MaybeAbort(); } + Status TimestampNanos(int64_t /*nanos*/) override { return MaybeAbort(); } + Status TimestampNanosNTZ(int64_t /*nanos*/) override { return MaybeAbort(); } + Status UUID(const uint8_t* /*bytes*/) override { return MaybeAbort(); } + Status StartObject(int32_t /*num_fields*/) override { return MaybeAbort(); } + Status FieldName(std::string_view /*name*/) override { return MaybeAbort(); } + Status EndObject() override { return MaybeAbort(); } + Status StartArray(int32_t /*num_elements*/) override { return MaybeAbort(); } + Status EndArray() override { return MaybeAbort(); } +}; + +class VariantAbortTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"name", "age"}; + } +}; + +TEST_F(VariantAbortTest, VisitorAbortsEarly) { + // Object: {name: "Alice", age: 30} + auto name_val = BuildShortString("Alice"); + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + auto data = BuildObject({0, 1}, {name_val, age_val}); + + // Abort after 3 events (StartObject, FieldName, String) + // Should NOT reach the second field + AbortingVisitor visitor(3); + auto status = DecodeVariantValue(metadata_, data.data(), + static_cast(data.size()), &visitor); + ASSERT_TRUE(status.IsCancelled()); + ASSERT_EQ(visitor.count, 3); +} + +TEST_F(VariantAbortTest, VisitorAbortsOnFirstEvent) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + AbortingVisitor visitor(1); + auto status = DecodeVariantValue(metadata_, data, sizeof(data), &visitor); + ASSERT_TRUE(status.IsCancelled()); +} + +// =========================================================================== +// Spec-conformance test with hardcoded byte sequences +// =========================================================================== + +class VariantSpecTest : public ::testing::Test {}; + +TEST_F(VariantSpecTest, HandcraftedNullValue) { + // Variant Encoding Spec: Null is basic_type=0, primitive_type=0 + // Header byte: 0x00 (bits 0-1=00 for primitive, bits 2-7=000000 for null) + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; // v1, 0 strings, offset[0]=0 + uint8_t value_bytes[] = {0x00}; // null + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_EQ(visitor.events.size(), 1); + ASSERT_EQ(visitor.events[0], "Null"); +} + +TEST_F(VariantSpecTest, HandcraftedInt32Value) { + // Int32(42): basic_type=0, primitive_type=5 + // Header: (5 << 2) | 0 = 0x14 + // Value: 42 as LE int32 = 2A 00 00 00 + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; + uint8_t value_bytes[] = {0x14, 0x2A, 0x00, 0x00, 0x00}; + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_EQ(visitor.events[0], "Int32(42)"); +} + +TEST_F(VariantSpecTest, HandcraftedShortString) { + // Short string "hello": basic_type=1, length=5 + // Header: (5 << 2) | 1 = 0x15 + // Followed by 5 bytes of UTF-8 "hello" + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; + uint8_t value_bytes[] = {0x15, 'h', 'e', 'l', 'l', 'o'}; + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"hello\")"); +} + +TEST_F(VariantSpecTest, HandcraftedSimpleObject) { + // Object {"a": 1} with metadata dictionary ["a"] + // + // Metadata: version=1, sorted=false, offset_size=1 + // header=0x01, dict_size=0x01, offsets=[0x00, 0x01], data="a" + uint8_t metadata_bytes[] = {0x01, 0x01, 0x00, 0x01, 'a'}; + // + // Value: object with 1 field + // header: basic_type=2, field_id_size=1(bits2-3=00), + // offset_size=1(bits4-5=00), num_fields_size=1(bits6-7=00) + // = 0x02 + // num_fields: 0x01 + // field_ids: [0x00] (index into metadata for "a") + // offsets: [0x00, 0x05] (field 0 at offset 0, total size 5) + // field value: Int32(1) = header 0x14 + LE bytes 01 00 00 00 + uint8_t value_bytes[] = { + 0x02, // object header + 0x01, // num_fields = 1 + 0x00, // field_id[0] = 0 + 0x00, 0x05, // offsets: [0, 5] + 0x14, 0x01, 0x00, 0x00, 0x00 // Int32(1) + }; + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_EQ(visitor.events.size(), 4); + ASSERT_EQ(visitor.events[0], "StartObject(1)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"a\")"); + ASSERT_EQ(visitor.events[2], "Int32(1)"); + ASSERT_EQ(visitor.events[3], "EndObject"); +} + +TEST_F(VariantSpecTest, HandcraftedTrueAndFalse) { + // True: basic_type=0, primitive_type=1 → header = (1<<2)|0 = 0x04 + // False: basic_type=0, primitive_type=2 → header = (2<<2)|0 = 0x08 + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; + + uint8_t true_bytes[] = {0x04}; + uint8_t false_bytes[] = {0x08}; + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + + RecordingVisitor v1; + ASSERT_OK(DecodeVariantValue(metadata, true_bytes, sizeof(true_bytes), &v1)); + ASSERT_EQ(v1.events[0], "Bool(true)"); + + RecordingVisitor v2; + ASSERT_OK(DecodeVariantValue(metadata, false_bytes, sizeof(false_bytes), &v2)); + ASSERT_EQ(v2.events[0], "Bool(false)"); +} + +TEST_F(VariantSpecTest, HandcraftedDouble) { + // Double: basic_type=0, primitive_type=7 → header = (7<<2)|0 = 0x1C + // Value: 3.14 as IEEE 754 double LE + uint8_t metadata_bytes[] = {0x01, 0x00, 0x00}; + uint8_t value_bytes[9]; + value_bytes[0] = 0x1C; + double val = 3.14; + std::memcpy(value_bytes + 1, &val, 8); + + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(metadata_bytes, sizeof(metadata_bytes))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, value_bytes, sizeof(value_bytes), &visitor)); + ASSERT_TRUE(visitor.events[0].find("Double(") == 0); +} + +// =========================================================================== +// ValueSize tests +// =========================================================================== + +class VariantValueSizeTest : public ::testing::Test {}; + +TEST_F(VariantValueSizeTest, NullSize) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(data, sizeof(data))); + ASSERT_EQ(size, 1); +} + +TEST_F(VariantValueSizeTest, Int32Size) { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt32), 0, 0, 0, 0}; + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(data, sizeof(data))); + ASSERT_EQ(size, 5); +} + +TEST_F(VariantValueSizeTest, ShortStringSize) { + auto data = BuildShortString("hello"); + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, 6); // 1 header + 5 chars +} + +TEST_F(VariantValueSizeTest, ObjectSize) { + VariantMetadata meta; + meta.version = 1; + meta.strings = {"key"}; + auto val = BuildShortString("val"); + auto data = BuildObject({0}, {val}); + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +TEST_F(VariantValueSizeTest, ArraySize) { + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + std::vector e2 = {PrimitiveHeader(PrimitiveType::kTrue)}; + auto data = BuildArray({e1, e2}); + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +TEST_F(VariantValueSizeTest, UUIDSize) { + uint8_t data[17]; + data[0] = PrimitiveHeader(PrimitiveType::kUUID); + std::memset(data + 1, 0, 16); + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(data, sizeof(data))); + ASSERT_EQ(size, 17); +} + +// =========================================================================== +// Random access tests +// =========================================================================== + +class VariantRandomAccessTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = true; + metadata_.offset_size = 1; + // Sorted lexicographically for binary search + metadata_.strings = {"age", "name", "score"}; + } +}; + +TEST_F(VariantRandomAccessTest, FindObjectFieldExists) { + // Object: {age: 30, name: "Alice", score: 95} + // field_ids must be in lex order of keys: age=0, name=1, score=2 + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + auto name_val = BuildShortString("Alice"); + std::vector score_val = {PrimitiveHeader(PrimitiveType::kInt32), 95, 0, 0, 0}; + auto data = BuildObject({0, 1, 2}, {age_val, name_val, score_val}); + + int64_t offset = -1, size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "name", &offset, &size)); + ASSERT_GT(offset, 0); + ASSERT_EQ(size, 6); // short string "Alice" = 1 + 5 + + // Verify we can decode just that field + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data() + offset, size, &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"Alice\")"); +} + +TEST_F(VariantRandomAccessTest, FindObjectFieldNotFound) { + auto val = BuildShortString("x"); + auto data = BuildObject({0}, {val}); + + int64_t offset = -1, size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "nonexistent", &offset, &size)); + ASSERT_EQ(offset, -1); + ASSERT_EQ(size, 0); +} + +TEST_F(VariantRandomAccessTest, GetArrayElementFirst) { + std::vector e0 = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + auto data = BuildArray({e0, e1}); + + int64_t offset = 0, size = 0; + ASSERT_OK( + GetArrayElement(data.data(), static_cast(data.size()), 0, &offset, &size)); + ASSERT_EQ(size, 5); // Int32 = 5 bytes + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data() + offset, size, &visitor)); + ASSERT_EQ(visitor.events[0], "Int32(42)"); +} + +TEST_F(VariantRandomAccessTest, GetArrayElementLast) { + std::vector e0 = {PrimitiveHeader(PrimitiveType::kInt32), 42, 0, 0, 0}; + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + auto data = BuildArray({e0, e1}); + + int64_t offset = 0, size = 0; + ASSERT_OK( + GetArrayElement(data.data(), static_cast(data.size()), 1, &offset, &size)); + ASSERT_EQ(size, 1); // Null = 1 byte +} + +TEST_F(VariantRandomAccessTest, GetArrayElementOutOfRange) { + std::vector e0 = {PrimitiveHeader(PrimitiveType::kNull)}; + auto data = BuildArray({e0}); + + int64_t offset = 0, size = 0; + ASSERT_RAISES(Invalid, GetArrayElement(data.data(), static_cast(data.size()), + 5, &offset, &size)); +} + +TEST_F(VariantRandomAccessTest, GetObjectFieldAtByIndex) { + std::vector age_val = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + auto name_val = BuildShortString("Bob"); + auto data = BuildObject({0, 1}, {age_val, name_val}); + + std::string_view name; + int64_t offset = 0, size = 0; + ASSERT_OK(GetObjectFieldAt(metadata_, data.data(), static_cast(data.size()), 1, + &name, &offset, &size)); + ASSERT_EQ(name, "name"); + ASSERT_EQ(size, 4); // short string "Bob" = 1 + 3 +} + +TEST_F(VariantRandomAccessTest, GetObjectFieldAtOutOfRange) { + auto val = BuildShortString("x"); + auto data = BuildObject({0}, {val}); + + std::string_view name; + int64_t offset = 0, size = 0; + ASSERT_RAISES( + Invalid, GetObjectFieldAt(metadata_, data.data(), static_cast(data.size()), + 99, &name, &offset, &size)); +} + +// =========================================================================== +// FindMetadataKey tests +// =========================================================================== + +class VariantFindMetadataKeyTest : public ::testing::Test {}; + +TEST_F(VariantFindMetadataKeyTest, SortedFound) { + VariantMetadata meta; + meta.is_sorted = true; + meta.strings = {"age", "name", "score"}; + ASSERT_EQ(FindMetadataKey(meta, "name"), 1); + ASSERT_EQ(FindMetadataKey(meta, "age"), 0); + ASSERT_EQ(FindMetadataKey(meta, "score"), 2); +} + +TEST_F(VariantFindMetadataKeyTest, SortedNotFound) { + VariantMetadata meta; + meta.is_sorted = true; + meta.strings = {"age", "name", "score"}; + ASSERT_EQ(FindMetadataKey(meta, "missing"), -1); +} + +TEST_F(VariantFindMetadataKeyTest, UnsortedFound) { + VariantMetadata meta; + meta.is_sorted = false; + meta.strings = {"name", "age", "score"}; + ASSERT_EQ(FindMetadataKey(meta, "age"), 1); +} + +TEST_F(VariantFindMetadataKeyTest, UnsortedNotFound) { + VariantMetadata meta; + meta.is_sorted = false; + meta.strings = {"name", "age"}; + ASSERT_EQ(FindMetadataKey(meta, "missing"), -1); +} + +// =========================================================================== +// ValueSize regression tests (Go bug: array is_large bit position) +// =========================================================================== + +class VariantValueSizeRegressionTest : public ::testing::Test {}; + +TEST_F(VariantValueSizeRegressionTest, LargeArrayIsLargeBit) { + // Build a large array with 300 elements (>255) to trigger is_large=true. + // This verifies the is_large bit is read at bit 2 of type_info (bit 4 of + // full byte), NOT bit 4 of type_info (bit 6 of full byte) which was the + // Go bug (apache/arrow-go#839). + std::vector> elements; + elements.reserve(300); + for (int i = 0; i < 300; ++i) { + elements.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildArray(elements, /*field_offset_size=*/2); + + // Verify the header byte is correctly structured + uint8_t header = data[0]; + ASSERT_EQ(GetBasicType(header), BasicType::kArray); + // is_large should be set at bit 4 of the full header byte + ASSERT_TRUE(((header >> 4) & 0x01) != 0); + + // ValueSize must return the total size of the buffer + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +TEST_F(VariantValueSizeRegressionTest, SmallArrayIsLargeFalse) { + // Array with 3 elements — is_large=false + std::vector e1 = {PrimitiveHeader(PrimitiveType::kNull)}; + std::vector e2 = {PrimitiveHeader(PrimitiveType::kTrue)}; + std::vector e3 = {PrimitiveHeader(PrimitiveType::kFalse)}; + auto data = BuildArray({e1, e2, e3}); + + // Verify is_large is NOT set + uint8_t header = data[0]; + ASSERT_FALSE(((header >> 4) & 0x01) != 0); + + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +TEST_F(VariantValueSizeRegressionTest, LargeObjectIsLargeBit) { + // Object with 300 fields to trigger is_large=true (bit 6 of full byte) + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 300; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = + BuildObject(field_ids, values, /*field_id_size=*/2, /*field_offset_size=*/2); + + // Verify is_large is set at bit 6 of the full header byte + uint8_t header = data[0]; + ASSERT_EQ(GetBasicType(header), BasicType::kObject); + ASSERT_TRUE(((header >> 6) & 0x01) != 0); + + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, static_cast(data.size())); +} + +// =========================================================================== +// Additional primitive decoding tests +// =========================================================================== + +class VariantPrimitiveExtraTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantPrimitiveExtraTest, DecodeTimeNTZ) { + int64_t micros = 43200000000LL; // 12:00:00 in microseconds + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimeNTZ); + std::memcpy(data + 1, µs, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimeNTZ(" + std::to_string(micros) + ")"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeTimestampNanos) { + int64_t nanos = 1654041600000000000LL; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimestampNanos); + std::memcpy(data + 1, &nanos, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimestampNanos(" + std::to_string(nanos) + ")"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeTimestampNanosNTZ) { + int64_t nanos = 1654041600000000000LL; + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kTimestampNanosNTZ); + std::memcpy(data + 1, &nanos, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "TimestampNanosNTZ(" + std::to_string(nanos) + ")"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeUUID) { + uint8_t data[17]; + data[0] = PrimitiveHeader(PrimitiveType::kUUID); + // Fill UUID with recognizable pattern (big-endian per spec) + for (int i = 0; i < 16; ++i) { + data[1 + i] = static_cast(i + 1); + } + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "UUID"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeInt8Boundaries) { + // INT8_MIN = -128 + { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt8), 0x80}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int8(-128)"); + } + // INT8_MAX = 127 + { + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kInt8), 0x7F}; + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int8(127)"); + } +} + +TEST_F(VariantPrimitiveExtraTest, DecodeInt16Boundaries) { + // INT16_MIN = -32768 + { + int16_t val = std::numeric_limits::min(); + uint8_t data[3]; + data[0] = PrimitiveHeader(PrimitiveType::kInt16); + std::memcpy(data + 1, &val, 2); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int16(-32768)"); + } + // INT16_MAX = 32767 + { + int16_t val = std::numeric_limits::max(); + uint8_t data[3]; + data[0] = PrimitiveHeader(PrimitiveType::kInt16); + std::memcpy(data + 1, &val, 2); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int16(32767)"); + } +} + +TEST_F(VariantPrimitiveExtraTest, DecodeInt64Min) { + int64_t val = std::numeric_limits::min(); + uint8_t data[9]; + data[0] = PrimitiveHeader(PrimitiveType::kInt64); + std::memcpy(data + 1, &val, 8); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); + ASSERT_EQ(visitor.events[0], "Int64(" + std::to_string(val) + ")"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeEmptyBinary) { + // Binary with zero length + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kBinary)); + uint32_t len = 0; + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((len >> (b * 8)) & 0xFF)); + } + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "Binary(len=0)"); +} + +TEST_F(VariantPrimitiveExtraTest, DecodeEmptyLongString) { + // Long string with zero length + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kString)); + uint32_t len = 0; + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((len >> (b * 8)) & 0xFF)); + } + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(empty_metadata_, data.data(), + static_cast(data.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "String(\"\")"); +} + +// =========================================================================== +// Object with non-monotonic offsets (spec-compliant) +// =========================================================================== + +class VariantObjectNonMonotonicTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = true; + metadata_.offset_size = 1; + // Sorted lexicographically + metadata_.strings = {"a", "b", "c"}; + } +}; + +TEST_F(VariantObjectNonMonotonicTest, NonMonotonicObjectOffsets) { + // Per spec: "field IDs and offsets must be listed in the order of the + // corresponding field names, sorted lexicographically" but "the actual + // value entries do not need to be in any particular order" and "the + // field_offset values may not be monotonically increasing." + // + // Construct: {a: 1, b: 2, c: 3} where values are stored as [3, 1, 2] + // in the data area but offsets point to them in key-sorted order. + std::vector val_a = {PrimitiveHeader(PrimitiveType::kInt8), 1}; + std::vector val_b = {PrimitiveHeader(PrimitiveType::kInt8), 2}; + std::vector val_c = {PrimitiveHeader(PrimitiveType::kInt8), 3}; + + // Data area stores: val_c (2 bytes) | val_a (2 bytes) | val_b (2 bytes) + // Offsets: a->2, b->4, c->0, end->6 + uint8_t header = static_cast(BasicType::kObject); // offset_size=1, id_size=1 + std::vector data; + data.push_back(header); + data.push_back(3); // num_fields = 3 + data.push_back(0); // field_id[0] = 0 ("a") + data.push_back(1); // field_id[1] = 1 ("b") + data.push_back(2); // field_id[2] = 2 ("c") + data.push_back(2); // offset[0] = 2 (val_a starts at byte 2) + data.push_back(4); // offset[1] = 4 (val_b starts at byte 4) + data.push_back(0); // offset[2] = 0 (val_c starts at byte 0) + data.push_back(6); // offset[3] = 6 (total data size) + // Data area: val_c, val_a, val_b + data.insert(data.end(), val_c.begin(), val_c.end()); + data.insert(data.end(), val_a.begin(), val_a.end()); + data.insert(data.end(), val_b.begin(), val_b.end()); + + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata_, data.data(), static_cast(data.size()), + &visitor)); + // Field iteration order follows field_ids (sorted by key): a, b, c + ASSERT_EQ(visitor.events.size(), 8); + ASSERT_EQ(visitor.events[0], "StartObject(3)"); + ASSERT_EQ(visitor.events[1], "FieldName(\"a\")"); + ASSERT_EQ(visitor.events[2], "Int8(1)"); + ASSERT_EQ(visitor.events[3], "FieldName(\"b\")"); + ASSERT_EQ(visitor.events[4], "Int8(2)"); + ASSERT_EQ(visitor.events[5], "FieldName(\"c\")"); + ASSERT_EQ(visitor.events[6], "Int8(3)"); + ASSERT_EQ(visitor.events[7], "EndObject"); +} + +TEST_F(VariantObjectNonMonotonicTest, FindFieldWithNonMonotonicOffsets) { + // Same layout as above: values stored out-of-order + uint8_t header = static_cast(BasicType::kObject); + std::vector data; + data.push_back(header); + data.push_back(3); + data.push_back(0); + data.push_back(1); + data.push_back(2); + data.push_back(2); // a -> offset 2 + data.push_back(4); // b -> offset 4 + data.push_back(0); // c -> offset 0 + data.push_back(6); // end = 6 + // Data: [Int8(3), Int8(1), Int8(2)] + data.push_back(PrimitiveHeader(PrimitiveType::kInt8)); + data.push_back(3); + data.push_back(PrimitiveHeader(PrimitiveType::kInt8)); + data.push_back(1); + data.push_back(PrimitiveHeader(PrimitiveType::kInt8)); + data.push_back(2); + + // FindObjectField should find "c" at offset 0 of data area + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "c", &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); + ASSERT_EQ(field_size, 2); // Int8 = 2 bytes + + // Decode the value at that offset and verify it's 3 (val_c) + RecordingVisitor v; + ASSERT_OK(DecodeVariantValue(metadata_, data.data() + field_offset, field_size, &v)); + ASSERT_EQ(v.events[0], "Int8(3)"); +} + +// =========================================================================== +// ValueSize for variable-length primitives +// =========================================================================== + +class VariantValueSizeVarLenTest : public ::testing::Test {}; + +TEST_F(VariantValueSizeVarLenTest, LongStringSize) { + // Long string "hello" (5 chars): header(1) + length(4) + data(5) = 10 + std::string s = "hello"; + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kString)); + auto len = static_cast(s.size()); + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((len >> (b * 8)) & 0xFF)); + } + data.insert(data.end(), s.begin(), s.end()); + + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, 10); +} + +TEST_F(VariantValueSizeVarLenTest, BinarySize) { + // Binary with 4 bytes: header(1) + length(4) + data(4) = 9 + std::vector data; + data.push_back(PrimitiveHeader(PrimitiveType::kBinary)); + uint32_t len = 4; + for (int b = 0; b < 4; ++b) { + data.push_back(static_cast((len >> (b * 8)) & 0xFF)); + } + data.push_back(0x00); + data.push_back(0x01); + data.push_back(0x02); + data.push_back(0x03); + + ASSERT_OK_AND_ASSIGN(auto size, + ValueSize(data.data(), static_cast(data.size()))); + ASSERT_EQ(size, 9); +} + +TEST_F(VariantValueSizeVarLenTest, TruncatedLongString) { + // Only header byte, no length field + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kString)}; + ASSERT_RAISES(Invalid, ValueSize(data, sizeof(data))); +} + +// =========================================================================== +// Unknown/invalid type tests +// =========================================================================== + +class VariantUnknownTypeTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantUnknownTypeTest, UnknownPrimitiveType) { + // Primitive type ID 25 (beyond kUUID=20) should produce an error. + // Header: (25 << 2) | 0 = 0x64 + uint8_t data[] = {0x64}; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, + DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); +} + +TEST_F(VariantUnknownTypeTest, UnknownPrimitiveTypeValueSize) { + // ValueSize on an unknown primitive type should still return a value + // (PrimitiveValueSize returns -1, triggering variable-length path). + // With only 1 byte, variable-length path requires 5 bytes → truncated. + uint8_t data[] = {0x64}; + ASSERT_RAISES(Invalid, ValueSize(data, sizeof(data))); +} + +// =========================================================================== +// Array non-monotonic offset rejection test +// =========================================================================== + +class VariantArrayNonMonotonicTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantArrayNonMonotonicTest, RejectsNonMonotonicOffsets) { + // Manually craft an array with 2 elements where offsets go [0, 3, 1] + // (non-monotonic: 1 < 3). This should be rejected. + // header: basic_type=3, offset_size=1, is_large=false → 0x03 + // num_elements: 2 + // offsets: [0, 3, 1] — non-monotonic + // data: 3 bytes of nulls + uint8_t data[] = { + 0x03, // array header: basic_type=3, offset_size=1, is_large=false + 0x02, // num_elements = 2 + 0x00, 0x03, 0x01, // offsets: [0, 3, 1] — non-monotonic! + PrimitiveHeader(PrimitiveType::kNull), + PrimitiveHeader(PrimitiveType::kNull), + PrimitiveHeader(PrimitiveType::kNull), + }; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, + DecodeVariantValue(empty_metadata_, data, sizeof(data), &visitor)); +} + +// =========================================================================== +// Object field offset out-of-bounds test +// =========================================================================== + +class VariantObjectOffsetBoundsTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = false; + metadata_.offset_size = 1; + metadata_.strings = {"key"}; + } +}; + +TEST_F(VariantObjectOffsetBoundsTest, FieldOffsetExceedsDataSize) { + // Object with 1 field where field_offset[0] = 99 (beyond total_data_size). + // header: basic_type=2, offset_size=1, id_size=1, is_large=false → 0x02 + // num_fields: 1 + // field_ids: [0] + // offsets: [99, 2] — field 0 at offset 99, total=2 + // data: 2 bytes (Null) + uint8_t data[] = { + 0x02, // object header + 0x01, // num_fields = 1 + 0x00, // field_id[0] = 0 + 0x63, 0x02, // offsets: [99, 2] — 99 > total_data_size(2) + PrimitiveHeader(PrimitiveType::kNull), + PrimitiveHeader(PrimitiveType::kNull), + }; + RecordingVisitor visitor; + ASSERT_RAISES(Invalid, + DecodeVariantValue(metadata_, data, sizeof(data), &visitor)); +} + +// =========================================================================== +// Empty metadata with various offset sizes +// =========================================================================== + +class VariantMetadataOffsetSizeTest : public ::testing::Test {}; + +TEST_F(VariantMetadataOffsetSizeTest, EmptyDictionaryOffsetSize4) { + // Valid metadata with 0 strings but offset_size=4. + auto buf = BuildMetadataBuffer({}, false, 4); + ASSERT_OK_AND_ASSIGN(auto metadata, DecodeMetadata(buf.data(), buf.size())); + ASSERT_EQ(metadata.version, 1); + ASSERT_EQ(metadata.offset_size, 4); + ASSERT_EQ(metadata.strings.size(), 0); +} + +// =========================================================================== +// FindObjectField with binary search (large object >= 32 fields) +// =========================================================================== + +class VariantFindFieldBinarySearchTest : public ::testing::Test { + protected: + VariantMetadata metadata_; + // Backing storage for string_views in metadata (must outlive metadata_). + // Do NOT modify key_storage_ after SetUp(); reallocation invalidates + // the string_views stored in metadata_.strings. + std::vector key_storage_; + + void SetUp() override { + metadata_.version = 1; + metadata_.is_sorted = true; + metadata_.offset_size = 1; + // 40 keys in sorted order to trigger binary search path + key_storage_.reserve(40); + for (int i = 0; i < 40; ++i) { + std::string key = "k" + std::string(i < 10 ? "0" : "") + std::to_string(i); + key_storage_.emplace_back(key); + } + for (const auto& k : key_storage_) { + metadata_.strings.push_back(k); + } + } +}; + +TEST_F(VariantFindFieldBinarySearchTest, FindMiddleField) { + // Build object with 40 fields, all null values + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 40; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildObject(field_ids, values); + + // Search for "k20" (middle of the sorted range) + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "k20", &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); + ASSERT_EQ(field_size, 1); // Null = 1 byte +} + +TEST_F(VariantFindFieldBinarySearchTest, FindFirstField) { + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 40; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildObject(field_ids, values); + + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "k00", &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); +} + +TEST_F(VariantFindFieldBinarySearchTest, FindLastField) { + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 40; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildObject(field_ids, values); + + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "k39", &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); +} + +TEST_F(VariantFindFieldBinarySearchTest, NotFoundInLargeObject) { + std::vector field_ids; + std::vector> values; + for (int i = 0; i < 40; ++i) { + field_ids.push_back(static_cast(i)); + values.push_back({PrimitiveHeader(PrimitiveType::kNull)}); + } + auto data = BuildObject(field_ids, values); + + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(metadata_, data.data(), static_cast(data.size()), + "zzz", &field_offset, &field_size)); + ASSERT_EQ(field_offset, -1); +} + +// =========================================================================== +// GetArrayElement middle index +// =========================================================================== + +class VariantGetArrayElementExtraTest : public ::testing::Test {}; + +TEST_F(VariantGetArrayElementExtraTest, MiddleElement) { + // Array of [Int32(10), Int32(20), Int32(30)] + std::vector e0 = {PrimitiveHeader(PrimitiveType::kInt32), 10, 0, 0, 0}; + std::vector e1 = {PrimitiveHeader(PrimitiveType::kInt32), 20, 0, 0, 0}; + std::vector e2 = {PrimitiveHeader(PrimitiveType::kInt32), 30, 0, 0, 0}; + auto data = BuildArray({e0, e1, e2}); + + int64_t elem_offset = 0, elem_size = 0; + ASSERT_OK(GetArrayElement(data.data(), static_cast(data.size()), 1, + &elem_offset, &elem_size)); + ASSERT_EQ(elem_size, 5); // Int32 = 5 bytes + + // Decode the middle element + VariantMetadata meta; + meta.version = 1; + RecordingVisitor v; + ASSERT_OK(DecodeVariantValue(meta, data.data() + elem_offset, elem_size, &v)); + ASSERT_EQ(v.events[0], "Int32(20)"); +} + +TEST_F(VariantGetArrayElementExtraTest, EmptyArrayOutOfRange) { + auto data = BuildArray({}); + int64_t elem_offset = 0, elem_size = 0; + ASSERT_RAISES(Invalid, GetArrayElement(data.data(), static_cast(data.size()), + 0, &elem_offset, &elem_size)); +} + +// =========================================================================== +// Additional error case tests (missing coverage) +// =========================================================================== + +class VariantErrorCaseTest : public ::testing::Test { + protected: + VariantMetadata empty_metadata_; + + void SetUp() override { + empty_metadata_.version = 1; + empty_metadata_.is_sorted = false; + empty_metadata_.offset_size = 1; + } +}; + +TEST_F(VariantErrorCaseTest, MetadataVersionZero) { + // Version 0 is not supported (only version 1 is valid per spec) + uint8_t data[] = {0x00, 0x00, 0x00}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantErrorCaseTest, GetObjectFieldCountOnArray) { + // Calling GetObjectFieldCount on an array value should produce an error + auto data = BuildArray({}); + ASSERT_RAISES(Invalid, + GetObjectFieldCount(data.data(), static_cast(data.size()))); +} + +TEST_F(VariantErrorCaseTest, GetArrayElementCountOnObject) { + // Calling GetArrayElementCount on an object value should produce an error + auto data = BuildObject({}, {}); + ASSERT_RAISES(Invalid, + GetArrayElementCount(data.data(), static_cast(data.size()))); +} + +TEST_F(VariantErrorCaseTest, GetObjectFieldCountOnPrimitive) { + // Calling GetObjectFieldCount on a primitive should produce an error + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + ASSERT_RAISES(Invalid, GetObjectFieldCount(data, sizeof(data))); +} + +TEST_F(VariantErrorCaseTest, GetArrayElementCountOnPrimitive) { + // Calling GetArrayElementCount on a primitive should produce an error + uint8_t data[] = {PrimitiveHeader(PrimitiveType::kNull)}; + ASSERT_RAISES(Invalid, GetArrayElementCount(data, sizeof(data))); +} + +TEST_F(VariantErrorCaseTest, MetadataStringOffsetExceedsBuffer) { + // Metadata where the last string offset claims more data than the buffer + // contains. This exercises the ValidateOffsets check for offsets.back() > + // data_length. + // Header: version=1, offset_size=1 + // dict_size=1, offsets=[0, 100] — but only 3 bytes of string data + uint8_t data[] = { + 0x01, // header: version=1, offset_size=1 + 0x01, // dict_size = 1 + 0x00, 0x64, // offsets: [0, 100] — 100 exceeds available string data + 'a', 'b', 'c'}; + ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); +} + +TEST_F(VariantErrorCaseTest, GetArrayElementNegativeIndex) { + std::vector e0 = {PrimitiveHeader(PrimitiveType::kNull)}; + auto data = BuildArray({e0}); + int64_t elem_offset = 0, elem_size = 0; + ASSERT_RAISES(Invalid, GetArrayElement(data.data(), static_cast(data.size()), + -1, &elem_offset, &elem_size)); +} + +TEST_F(VariantErrorCaseTest, FindObjectFieldOnNonObject) { + // Calling FindObjectField on an array should produce an error + auto data = BuildArray({}); + int64_t field_offset = -1, field_size = 0; + ASSERT_RAISES(Invalid, + FindObjectField(empty_metadata_, data.data(), + static_cast(data.size()), "key", + &field_offset, &field_size)); +} + +// TODO: Add fuzz targets for DecodeMetadata and DecodeVariantValue to exercise +// adversarial/malformed input. Fuzz tests in Arrow are typically registered as +// separate executables under cpp/src/arrow/testing/fuzzing/ — see GH-45948. + +} // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/extension/variant_test_util.h b/cpp/src/arrow/extension/variant_test_util.h new file mode 100644 index 000000000000..9e20947697d7 --- /dev/null +++ b/cpp/src/arrow/extension/variant_test_util.h @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +// This file is for tests only and is not installed as a public header. + +#include +#include + +#include "arrow/extension/variant_internal.h" + +namespace arrow::extension::variant_internal { + +/// \brief A visitor that records all callbacks as a vector of strings +/// for easy assertion in tests. +class RecordingVisitor : public VariantVisitor { + public: + std::vector events; + + Status Null() override { + events.push_back("Null"); + return Status::OK(); + } + Status Bool(bool value) override { + events.push_back(std::string("Bool(") + (value ? "true" : "false") + ")"); + return Status::OK(); + } + Status Int8(int8_t value) override { + events.push_back("Int8(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Int16(int16_t value) override { + events.push_back("Int16(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Int32(int32_t value) override { + events.push_back("Int32(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Int64(int64_t value) override { + events.push_back("Int64(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Float(float value) override { + events.push_back("Float(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Double(double value) override { + events.push_back("Double(" + std::to_string(value) + ")"); + return Status::OK(); + } + Status Decimal4(const uint8_t* /*bytes*/, int32_t scale) override { + events.push_back("Decimal4(scale=" + std::to_string(scale) + ")"); + return Status::OK(); + } + Status Decimal8(const uint8_t* /*bytes*/, int32_t scale) override { + events.push_back("Decimal8(scale=" + std::to_string(scale) + ")"); + return Status::OK(); + } + Status Decimal16(const uint8_t* /*bytes*/, int32_t scale) override { + events.push_back("Decimal16(scale=" + std::to_string(scale) + ")"); + return Status::OK(); + } + Status Date(int32_t days) override { + events.push_back("Date(" + std::to_string(days) + ")"); + return Status::OK(); + } + Status TimestampMicros(int64_t micros) override { + events.push_back("TimestampMicros(" + std::to_string(micros) + ")"); + return Status::OK(); + } + Status TimestampMicrosNTZ(int64_t micros) override { + events.push_back("TimestampMicrosNTZ(" + std::to_string(micros) + ")"); + return Status::OK(); + } + Status String(std::string_view value) override { + events.push_back("String(\"" + std::string(value) + "\")"); + return Status::OK(); + } + Status Binary(std::string_view value) override { + events.push_back("Binary(len=" + std::to_string(value.size()) + ")"); + return Status::OK(); + } + Status TimeNTZ(int64_t micros) override { + events.push_back("TimeNTZ(" + std::to_string(micros) + ")"); + return Status::OK(); + } + Status TimestampNanos(int64_t nanos) override { + events.push_back("TimestampNanos(" + std::to_string(nanos) + ")"); + return Status::OK(); + } + Status TimestampNanosNTZ(int64_t nanos) override { + events.push_back("TimestampNanosNTZ(" + std::to_string(nanos) + ")"); + return Status::OK(); + } + Status UUID(const uint8_t* /*bytes*/) override { + events.push_back("UUID"); + return Status::OK(); + } + Status StartObject(int32_t num_fields) override { + events.push_back("StartObject(" + std::to_string(num_fields) + ")"); + return Status::OK(); + } + Status FieldName(std::string_view name) override { + events.push_back("FieldName(\"" + std::string(name) + "\")"); + return Status::OK(); + } + Status EndObject() override { + events.push_back("EndObject"); + return Status::OK(); + } + Status StartArray(int32_t num_elements) override { + events.push_back("StartArray(" + std::to_string(num_elements) + ")"); + return Status::OK(); + } + Status EndArray() override { + events.push_back("EndArray"); + return Status::OK(); + } +}; + +} // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/meson.build b/cpp/src/arrow/meson.build index 4b8faebecfd7..dc16985255c2 100644 --- a/cpp/src/arrow/meson.build +++ b/cpp/src/arrow/meson.build @@ -142,6 +142,7 @@ arrow_components = { 'extension/bool8.cc', 'extension/json.cc', 'extension/parquet_variant.cc', + 'extension/variant_internal.cc', 'extension/uuid.cc', 'pretty_print.cc', 'record_batch.cc', From b0c22987b9a6b610489981c8495c73e9d80b0379 Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Sat, 20 Jun 2026 12:44:52 -0700 Subject: [PATCH 2/4] Fix spec section references in variant_internal.h comments --- cpp/src/arrow/extension/variant_internal.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/extension/variant_internal.h b/cpp/src/arrow/extension/variant_internal.h index 6f19e9fe6769..87354b50600d 100644 --- a/cpp/src/arrow/extension/variant_internal.h +++ b/cpp/src/arrow/extension/variant_internal.h @@ -55,7 +55,7 @@ constexpr int32_t kMaxNestingDepth = 128; /// \brief Basic type codes from bits 0-1 of the value header byte. /// -/// Variant Encoding Spec §3: "Value encoding" +/// See: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#encoding-types enum class BasicType : uint8_t { kPrimitive = 0, kShortString = 1, @@ -65,7 +65,7 @@ enum class BasicType : uint8_t { /// \brief Primitive type codes from bits 2-7 when basic_type == kPrimitive. /// -/// Variant Encoding Spec §3.1: "Primitive types" +/// See: https://github.com/apache/parquet-format/blob/master/VariantEncoding.md#encoding-types enum class PrimitiveType : uint8_t { kNull = 0, kTrue = 1, From 8ab28f0a34d93c42a1ca8314c713b9a362672a0f Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Sat, 6 Jun 2026 17:54:30 -0700 Subject: [PATCH 3/4] GH-45947: [C++][Parquet] Variant encoding --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/extension/CMakeLists.txt | 2 +- cpp/src/arrow/extension/meson.build | 3 +- cpp/src/arrow/extension/variant_builder.cc | 475 +++++++ .../arrow/extension/variant_builder_test.cc | 1180 +++++++++++++++++ cpp/src/arrow/extension/variant_internal.h | 132 ++ cpp/src/arrow/meson.build | 1 + 7 files changed, 1792 insertions(+), 2 deletions(-) create mode 100644 cpp/src/arrow/extension/variant_builder.cc create mode 100644 cpp/src/arrow/extension/variant_builder_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 530d3e5ff3b8..ec076e1321a3 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -391,6 +391,7 @@ set(ARROW_SRCS extension/bool8.cc extension/json.cc extension/parquet_variant.cc + extension/variant_builder.cc extension/variant_internal.cc extension/uuid.cc pretty_print.cc diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index 582825027c74..e66e82c4bcc6 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -16,7 +16,7 @@ # under the License. set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc - variant_internal_test.cc) + variant_internal_test.cc variant_builder_test.cc) if(ARROW_JSON) list(APPEND CANONICAL_EXTENSION_TESTS tensor_extension_array_test.cc opaque_test.cc) diff --git a/cpp/src/arrow/extension/meson.build b/cpp/src/arrow/extension/meson.build index 6c6d3a7b67a8..e6362721041c 100644 --- a/cpp/src/arrow/extension/meson.build +++ b/cpp/src/arrow/extension/meson.build @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -canonical_extension_tests = ['bool8_test.cc', 'json_test.cc', 'uuid_test.cc', 'variant_internal_test.cc'] +canonical_extension_tests = ['bool8_test.cc', 'json_test.cc', 'uuid_test.cc', + 'variant_internal_test.cc', 'variant_builder_test.cc'] if needs_json canonical_extension_tests += [ diff --git a/cpp/src/arrow/extension/variant_builder.cc b/cpp/src/arrow/extension/variant_builder.cc new file mode 100644 index 000000000000..eb4f6c790909 --- /dev/null +++ b/cpp/src/arrow/extension/variant_builder.cc @@ -0,0 +1,475 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/variant_internal.h" + +#include +#include +#include + +#include "arrow/util/endian.h" +#include "arrow/util/logging_internal.h" + +namespace arrow::extension::variant_internal { + +namespace { + +/// \brief Compute the minimum number of bytes needed to represent a value. +/// \param[in] value Must be non-negative and fit in 4 bytes (represents a size or ID). +int32_t IntSize(int64_t value) { + DCHECK_GE(value, 0); + DCHECK_LE(value, static_cast(std::numeric_limits::max())); + if (value <= 0xFF) return 1; + if (value <= 0xFFFF) return 2; + if (value <= 0xFFFFFF) return 3; + return 4; +} + +/// \brief Write an unsigned integer in little-endian using nbytes bytes. +void WriteUnsignedLE(uint8_t* buf, int64_t value, int32_t nbytes) { + for (int32_t i = 0; i < nbytes; ++i) { + buf[i] = static_cast((value >> (i * 8)) & 0xFF); + } +} + +/// \brief Write a little-endian value into a vector at a given position. +void WriteUnsignedLEAt(std::vector& buf, int64_t pos, int64_t value, + int32_t nbytes) { + for (int32_t i = 0; i < nbytes; ++i) { + buf[pos + i] = static_cast((value >> (i * 8)) & 0xFF); + } +} + +/// \brief Construct a primitive header byte. +uint8_t MakePrimitiveHeader(PrimitiveType type) { + return static_cast(BasicType::kPrimitive) | (static_cast(type) << 2); +} + +/// \brief Write a fixed-size numeric primitive into the buffer. +template +void WritePrimitive(std::vector& buf, PrimitiveType type, T value) { + buf.push_back(MakePrimitiveHeader(type)); + value = bit_util::ToLittleEndian(value); + auto ptr = reinterpret_cast(&value); + buf.insert(buf.end(), ptr, ptr + sizeof(T)); +} + +} // namespace + +// --------------------------------------------------------------------------- +// VariantBuilder implementation +// --------------------------------------------------------------------------- + +VariantBuilder::VariantBuilder() = default; + +VariantBuilder::VariantBuilder(const VariantMetadata& existing_metadata) { + for (int32_t i = 0; i < static_cast(existing_metadata.strings.size()); ++i) { + std::string key(existing_metadata.strings[i]); + dict_[key] = static_cast(i); + dict_keys_.push_back(std::move(key)); + } +} + +uint32_t VariantBuilder::AddKey(std::string_view key) { + // Reuse lookup_buf_ to avoid per-call std::string allocation for the + // hash map lookup. In column-scan workloads (same keys repeated across + // millions of rows), this eliminates the dominant allocation cost. + // The assign() reuses the existing buffer capacity for keys that fit. + lookup_buf_.assign(key.data(), key.size()); + auto it = dict_.find(lookup_buf_); + if (it != dict_.end()) { + return it->second; + } + // Key is new — insert into the dictionary. The map and dict_keys_ each + // hold their own copy of the string (matching original behavior). + auto id = static_cast(dict_keys_.size()); + dict_[lookup_buf_] = id; + dict_keys_.push_back(std::move(lookup_buf_)); + return id; +} + +void VariantBuilder::Reset() { + buffer_.clear(); + dict_.clear(); + dict_keys_.clear(); + lookup_buf_.clear(); +} + +int64_t VariantBuilder::Offset() const { return static_cast(buffer_.size()); } + +int64_t VariantBuilder::NextElement(int64_t start) const { return Offset() - start; } + +VariantBuilder::FieldEntry VariantBuilder::NextField(int64_t start, + std::string_view key) { + auto id = AddKey(key); + return FieldEntry{std::string(key), id, Offset() - start}; +} + +// --- Primitive setters --- + +Status VariantBuilder::Null() { + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kNull)); + return Status::OK(); +} + +Status VariantBuilder::Bool(bool value) { + buffer_.push_back( + MakePrimitiveHeader(value ? PrimitiveType::kTrue : PrimitiveType::kFalse)); + return Status::OK(); +} + +Status VariantBuilder::Int(int64_t value) { + if (value >= std::numeric_limits::min() && + value <= std::numeric_limits::max()) { + return Int8(static_cast(value)); + } + if (value >= std::numeric_limits::min() && + value <= std::numeric_limits::max()) { + return Int16(static_cast(value)); + } + if (value >= std::numeric_limits::min() && + value <= std::numeric_limits::max()) { + return Int32(static_cast(value)); + } + return Int64(value); +} + +Status VariantBuilder::Int8(int8_t value) { + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kInt8)); + buffer_.push_back(static_cast(value)); + return Status::OK(); +} + +Status VariantBuilder::Int16(int16_t value) { + WritePrimitive(buffer_, PrimitiveType::kInt16, value); + return Status::OK(); +} + +Status VariantBuilder::Int32(int32_t value) { + WritePrimitive(buffer_, PrimitiveType::kInt32, value); + return Status::OK(); +} + +Status VariantBuilder::Int64(int64_t value) { + WritePrimitive(buffer_, PrimitiveType::kInt64, value); + return Status::OK(); +} + +Status VariantBuilder::Float(float value) { + WritePrimitive(buffer_, PrimitiveType::kFloat, value); + return Status::OK(); +} + +Status VariantBuilder::Double(double value) { + WritePrimitive(buffer_, PrimitiveType::kDouble, value); + return Status::OK(); +} + +Status VariantBuilder::Date(int32_t days_since_epoch) { + WritePrimitive(buffer_, PrimitiveType::kDate, days_since_epoch); + return Status::OK(); +} + +Status VariantBuilder::TimestampMicros(int64_t micros) { + WritePrimitive(buffer_, PrimitiveType::kTimestampMicros, micros); + return Status::OK(); +} + +Status VariantBuilder::TimestampMicrosNTZ(int64_t micros) { + WritePrimitive(buffer_, PrimitiveType::kTimestampMicrosNTZ, micros); + return Status::OK(); +} + +Status VariantBuilder::TimeNTZ(int64_t micros) { + WritePrimitive(buffer_, PrimitiveType::kTimeNTZ, micros); + return Status::OK(); +} + +Status VariantBuilder::TimestampNanos(int64_t nanos) { + WritePrimitive(buffer_, PrimitiveType::kTimestampNanos, nanos); + return Status::OK(); +} + +Status VariantBuilder::TimestampNanosNTZ(int64_t nanos) { + WritePrimitive(buffer_, PrimitiveType::kTimestampNanosNTZ, nanos); + return Status::OK(); +} + +Status VariantBuilder::Decimal4(uint8_t scale, const uint8_t* value_bytes) { + if (scale > 38) { + return Status::Invalid("Variant decimal scale must be in range [0, 38], got ", + static_cast(scale)); + } + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kDecimal4)); + buffer_.push_back(scale); + buffer_.insert(buffer_.end(), value_bytes, value_bytes + 4); + return Status::OK(); +} + +Status VariantBuilder::Decimal8(uint8_t scale, const uint8_t* value_bytes) { + if (scale > 38) { + return Status::Invalid("Variant decimal scale must be in range [0, 38], got ", + static_cast(scale)); + } + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kDecimal8)); + buffer_.push_back(scale); + buffer_.insert(buffer_.end(), value_bytes, value_bytes + 8); + return Status::OK(); +} + +Status VariantBuilder::Decimal16(uint8_t scale, const uint8_t* value_bytes) { + if (scale > 38) { + return Status::Invalid("Variant decimal scale must be in range [0, 38], got ", + static_cast(scale)); + } + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kDecimal16)); + buffer_.push_back(scale); + buffer_.insert(buffer_.end(), value_bytes, value_bytes + 16); + return Status::OK(); +} + +Status VariantBuilder::String(std::string_view value) { + if (value.size() <= 63) { + // Short string: length encoded in header bits 2-7 + uint8_t header = static_cast(BasicType::kShortString) | + (static_cast(value.size()) << 2); + buffer_.push_back(header); + } else { + // Long string: primitive type kString + 4-byte LE length + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kString)); + auto len = static_cast(value.size()); + len = bit_util::ToLittleEndian(len); + auto ptr = reinterpret_cast(&len); + buffer_.insert(buffer_.end(), ptr, ptr + 4); + } + buffer_.insert(buffer_.end(), value.begin(), value.end()); + return Status::OK(); +} + +Status VariantBuilder::Binary(std::string_view value) { + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kBinary)); + auto len = static_cast(value.size()); + len = bit_util::ToLittleEndian(len); + auto ptr = reinterpret_cast(&len); + buffer_.insert(buffer_.end(), ptr, ptr + 4); + buffer_.insert(buffer_.end(), value.begin(), value.end()); + return Status::OK(); +} + +Status VariantBuilder::UUID(const uint8_t* bytes) { + buffer_.push_back(MakePrimitiveHeader(PrimitiveType::kUUID)); + buffer_.insert(buffer_.end(), bytes, bytes + 16); + return Status::OK(); +} + +// --- Container construction --- + +Status VariantBuilder::FinishArray(int64_t start, const std::vector& offsets) { + // Note: offset fields are at most 4 bytes, so individual variant values + // cannot exceed ~4GB. This is not validated here; such values are not + // practically expected (Parquet row group sizes are bounded well below this). + auto data_size = Offset() - start; + if (data_size < 0) { + return Status::Invalid("VariantBuilder::FinishArray: invalid start position"); + } + + auto num_elements = static_cast(offsets.size()); + bool is_large = num_elements > 255; + int32_t size_bytes = is_large ? 4 : 1; + int32_t offset_size = IntSize(data_size); + int64_t header_size = 1 + size_bytes + (num_elements + 1) * offset_size; + + // Validate offsets are non-negative (caller-provided) + for (int64_t i = 0; i < num_elements; ++i) { + if (offsets[i] < 0) { + return Status::Invalid("VariantBuilder::FinishArray: negative offset at index ", i); + } + } + + // Shift existing data to make room for the header + buffer_.resize(buffer_.size() + header_size); + std::memmove(buffer_.data() + start + header_size, buffer_.data() + start, data_size); + + // Write header byte + uint8_t header = static_cast(BasicType::kArray) | + (static_cast(offset_size - 1) << 2); + if (is_large) { + header |= (1 << 4); + } + buffer_[start] = header; + + // Write num_elements + WriteUnsignedLEAt(buffer_, start + 1, num_elements, size_bytes); + + // Write offsets + int64_t offset_pos = start + 1 + size_bytes; + for (int64_t i = 0; i < num_elements; ++i) { + WriteUnsignedLEAt(buffer_, offset_pos + i * offset_size, offsets[i], offset_size); + } + // Last offset = total data size + WriteUnsignedLEAt(buffer_, offset_pos + num_elements * offset_size, data_size, + offset_size); + + return Status::OK(); +} + +Status VariantBuilder::FinishObject(int64_t start, std::vector& fields) { + auto data_size = Offset() - start; + if (data_size < 0) { + return Status::Invalid("VariantBuilder::FinishObject: invalid start position"); + } + + // Sort fields by key name lexicographically (spec requirement) + std::sort(fields.begin(), fields.end(), + [](const FieldEntry& a, const FieldEntry& b) { return a.key < b.key; }); + + // Check for duplicate keys + for (size_t i = 1; i < fields.size(); ++i) { + if (fields[i].key == fields[i - 1].key) { + return Status::Invalid("VariantBuilder: duplicate key '", fields[i].key, "'"); + } + } + + auto num_fields = static_cast(fields.size()); + bool is_large = num_fields > 255; + int32_t size_bytes = is_large ? 4 : 1; + + // Compute id_size from max dictionary ID + uint32_t max_id = 0; + for (const auto& f : fields) { + max_id = std::max(max_id, f.id); + } + int32_t id_size = IntSize(static_cast(max_id)); + int32_t offset_size = IntSize(data_size); + + int64_t header_size = + 1 + size_bytes + num_fields * id_size + (num_fields + 1) * offset_size; + + // Shift existing data to make room for the header + buffer_.resize(buffer_.size() + header_size); + std::memmove(buffer_.data() + start + header_size, buffer_.data() + start, data_size); + + // Write header byte: basic_type=2, offset_size in bits 2-3, id_size in bits 4-5, + // is_large in bit 6 + uint8_t header = static_cast(BasicType::kObject) | + (static_cast(offset_size - 1) << 2) | + (static_cast(id_size - 1) << 4); + if (is_large) { + header |= (1 << 6); + } + buffer_[start] = header; + + // Write num_fields + WriteUnsignedLEAt(buffer_, start + 1, num_fields, size_bytes); + + // Write field IDs (sorted by key) + int64_t id_pos = start + 1 + size_bytes; + for (int64_t i = 0; i < num_fields; ++i) { + WriteUnsignedLEAt(buffer_, id_pos + i * id_size, fields[i].id, id_size); + } + + // Write field offsets (sorted by key) + int64_t offset_pos = id_pos + num_fields * id_size; + for (int64_t i = 0; i < num_fields; ++i) { + WriteUnsignedLEAt(buffer_, offset_pos + i * offset_size, fields[i].offset, + offset_size); + } + // Last offset = total data size + WriteUnsignedLEAt(buffer_, offset_pos + num_fields * offset_size, data_size, + offset_size); + + return Status::OK(); +} + +Result VariantBuilder::Finish() { + // Build metadata + auto num_keys = static_cast(dict_keys_.size()); + + // Compute total string data size + int64_t total_string_size = 0; + for (const auto& k : dict_keys_) { + total_string_size += static_cast(k.size()); + } + + // Validate sizes fit within the spec's 4-byte offset limit. + // Note: Go implementation enforces a stricter 128MB limit (metadataMaxSizeLimit). + // We only enforce the spec's 4-byte offset maximum (~4GB), which is the correct + // upper bound per the encoding format. + if (total_string_size > static_cast(std::numeric_limits::max())) { + return Status::Invalid( + "VariantBuilder: total dictionary string data (", total_string_size, + " bytes) exceeds maximum representable by 4-byte offsets"); + } + + // Compute the offset_size: must accommodate both the largest string offset + // (total_string_size) and the dictionary_size field itself, since both use + // offset_size bytes in the metadata encoding. + int32_t offset_size = + IntSize(std::max(total_string_size, static_cast(num_keys))); + + // Check if dictionary is sorted. + // Uniqueness is guaranteed by dict_ (AddKey prevents duplicates), + // so std::is_sorted with default < is sufficient for the "sorted and unique" + // semantics required by the spec. + // TODO: Cache the sorted state incrementally (check only newly-added keys + // against the previous last key) to avoid O(n) rescan on every Finish() call. + bool is_sorted = std::is_sorted(dict_keys_.begin(), dict_keys_.end()); + + // Build metadata buffer + std::vector metadata; + // Header byte + uint8_t meta_header = kVariantVersion; + if (is_sorted) { + meta_header |= (1 << 4); + } + meta_header |= static_cast((offset_size - 1) << 6); + metadata.push_back(meta_header); + + // Dictionary size + metadata.resize(metadata.size() + offset_size); + WriteUnsignedLE(metadata.data() + 1, num_keys, offset_size); + + // String offsets + int64_t cur_offset = 0; + for (int32_t i = 0; i <= num_keys; ++i) { + size_t pos = metadata.size(); + metadata.resize(pos + offset_size); + WriteUnsignedLE(metadata.data() + pos, cur_offset, offset_size); + if (i < num_keys) { + cur_offset += static_cast(dict_keys_[i].size()); + } + } + + // String data + for (const auto& k : dict_keys_) { + metadata.insert(metadata.end(), k.begin(), k.end()); + } + + EncodedVariant result; + result.metadata = std::move(metadata); + result.value = std::move(buffer_); + + // Note: dict_ and dict_keys_ are intentionally NOT cleared here. + // The dictionary is preserved so the builder can encode multiple values + // sharing the same key schema without re-adding keys. Call Reset() + // explicitly to clear everything. + buffer_.clear(); + + return result; +} + +} // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/extension/variant_builder_test.cc b/cpp/src/arrow/extension/variant_builder_test.cc new file mode 100644 index 000000000000..8831add77cc9 --- /dev/null +++ b/cpp/src/arrow/extension/variant_builder_test.cc @@ -0,0 +1,1180 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/variant_internal.h" +#include "arrow/extension/variant_test_util.h" + +#include +#include +#include +#include + +#include "arrow/testing/gtest_util.h" + +namespace arrow::extension::variant_internal { + +// =========================================================================== +// Helper: decode an EncodedVariant and return visitor events +// =========================================================================== + +/// Encode with builder, decode, return events. +/// Note: Uses .ValueOrDie() because ASSERT_OK_AND_ASSIGN cannot be used +/// in a non-void function. Test-only; will crash with a descriptive message +/// on failure rather than producing a clean test failure. +std::vector RoundTrip(VariantBuilder& builder) { + auto result = builder.Finish().ValueOrDie(); + auto metadata = + DecodeMetadata(result.metadata.data(), static_cast(result.metadata.size())) + .ValueOrDie(); + RecordingVisitor visitor; + DecodeVariantValue(metadata, result.value.data(), + static_cast(result.value.size()), &visitor) + .ok(); + return visitor.events; +} + +// =========================================================================== +// Primitive round-trip tests +// =========================================================================== + +class VariantBuilderPrimitiveTest : public ::testing::Test {}; + +TEST_F(VariantBuilderPrimitiveTest, Null) { + VariantBuilder b; + ASSERT_OK(b.Null()); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Null"); +} + +TEST_F(VariantBuilderPrimitiveTest, BoolTrue) { + VariantBuilder b; + ASSERT_OK(b.Bool(true)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Bool(true)"); +} + +TEST_F(VariantBuilderPrimitiveTest, BoolFalse) { + VariantBuilder b; + ASSERT_OK(b.Bool(false)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Bool(false)"); +} + +TEST_F(VariantBuilderPrimitiveTest, IntAutoSizesInt8) { + VariantBuilder b; + ASSERT_OK(b.Int(42)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int8(42)"); +} + +TEST_F(VariantBuilderPrimitiveTest, IntAutoSizesInt16) { + VariantBuilder b; + ASSERT_OK(b.Int(300)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int16(300)"); +} + +TEST_F(VariantBuilderPrimitiveTest, IntAutoSizesInt32) { + VariantBuilder b; + ASSERT_OK(b.Int(100000)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int32(100000)"); +} + +TEST_F(VariantBuilderPrimitiveTest, IntAutoSizesInt64) { + VariantBuilder b; + ASSERT_OK(b.Int(5000000000LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int64(5000000000)"); +} + +TEST_F(VariantBuilderPrimitiveTest, IntNegative) { + VariantBuilder b; + ASSERT_OK(b.Int(-42)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int8(-42)"); +} + +TEST_F(VariantBuilderPrimitiveTest, ShortString) { + VariantBuilder b; + ASSERT_OK(b.String("hello")); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "String(\"hello\")"); +} + +TEST_F(VariantBuilderPrimitiveTest, LongString) { + std::string long_str(100, 'x'); + VariantBuilder b; + ASSERT_OK(b.String(long_str)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "String(\"" + long_str + "\")"); +} + +TEST_F(VariantBuilderPrimitiveTest, ShortStringBoundary63) { + std::string str63(63, 'a'); + VariantBuilder b; + ASSERT_OK(b.String(str63)); + ASSERT_OK_AND_ASSIGN(auto result, b.Finish()); + // Should use short string encoding: 1 byte header + 63 bytes + ASSERT_EQ(result.value.size(), 64); +} + +TEST_F(VariantBuilderPrimitiveTest, LongStringBoundary64) { + std::string str64(64, 'a'); + VariantBuilder b; + ASSERT_OK(b.String(str64)); + ASSERT_OK_AND_ASSIGN(auto result, b.Finish()); + // Should use long string encoding: 1 byte header + 4 byte length + 64 bytes + ASSERT_EQ(result.value.size(), 69); +} + +TEST_F(VariantBuilderPrimitiveTest, Date) { + VariantBuilder b; + ASSERT_OK(b.Date(19000)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Date(19000)"); +} + +TEST_F(VariantBuilderPrimitiveTest, Double) { + VariantBuilder b; + ASSERT_OK(b.Double(3.14)); + auto events = RoundTrip(b); + ASSERT_TRUE(events[0].find("Double(") == 0); +} + +// =========================================================================== +// Array round-trip tests +// =========================================================================== + +class VariantBuilderArrayTest : public ::testing::Test {}; + +TEST_F(VariantBuilderArrayTest, EmptyArray) { + VariantBuilder b; + auto start = b.Offset(); + std::vector offsets; + ASSERT_OK(b.FinishArray(start, offsets)); + auto events = RoundTrip(b); + ASSERT_EQ(events.size(), 2); + ASSERT_EQ(events[0], "StartArray(0)"); + ASSERT_EQ(events[1], "EndArray"); +} + +TEST_F(VariantBuilderArrayTest, SimpleArray) { + VariantBuilder b; + auto start = b.Offset(); + std::vector offsets; + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(1)); + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(2)); + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(3)); + ASSERT_OK(b.FinishArray(start, offsets)); + + auto events = RoundTrip(b); + ASSERT_EQ(events.size(), 5); + ASSERT_EQ(events[0], "StartArray(3)"); + ASSERT_EQ(events[1], "Int8(1)"); + ASSERT_EQ(events[2], "Int8(2)"); + ASSERT_EQ(events[3], "Int8(3)"); + ASSERT_EQ(events[4], "EndArray"); +} + +TEST_F(VariantBuilderArrayTest, NestedArray) { + VariantBuilder b; + auto start = b.Offset(); + std::vector offsets; + + // First element: nested array [10, 20] + offsets.push_back(b.NextElement(start)); + auto inner_start = b.Offset(); + std::vector inner_offsets; + inner_offsets.push_back(b.NextElement(inner_start)); + ASSERT_OK(b.Int(10)); + inner_offsets.push_back(b.NextElement(inner_start)); + ASSERT_OK(b.Int(20)); + ASSERT_OK(b.FinishArray(inner_start, inner_offsets)); + + // Second element: 30 + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(30)); + + ASSERT_OK(b.FinishArray(start, offsets)); + + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "StartArray(2)"); + ASSERT_EQ(events[1], "StartArray(2)"); + ASSERT_EQ(events[2], "Int8(10)"); + ASSERT_EQ(events[3], "Int8(20)"); + ASSERT_EQ(events[4], "EndArray"); + ASSERT_EQ(events[5], "Int8(30)"); + ASSERT_EQ(events[6], "EndArray"); +} + +// =========================================================================== +// Object round-trip tests +// =========================================================================== + +class VariantBuilderObjectTest : public ::testing::Test {}; + +TEST_F(VariantBuilderObjectTest, EmptyObject) { + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + ASSERT_OK(b.FinishObject(start, fields)); + auto events = RoundTrip(b); + ASSERT_EQ(events.size(), 2); + ASSERT_EQ(events[0], "StartObject(0)"); + ASSERT_EQ(events[1], "EndObject"); +} + +TEST_F(VariantBuilderObjectTest, SimpleObject) { + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "name")); + ASSERT_OK(b.String("Alice")); + fields.push_back(b.NextField(start, "age")); + ASSERT_OK(b.Int(30)); + ASSERT_OK(b.FinishObject(start, fields)); + + auto events = RoundTrip(b); + // Fields sorted by key: "age" before "name" + ASSERT_EQ(events[0], "StartObject(2)"); + ASSERT_EQ(events[1], "FieldName(\"age\")"); + ASSERT_EQ(events[2], "Int8(30)"); + ASSERT_EQ(events[3], "FieldName(\"name\")"); + ASSERT_EQ(events[4], "String(\"Alice\")"); + ASSERT_EQ(events[5], "EndObject"); +} + +TEST_F(VariantBuilderObjectTest, NestedObject) { + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "inner")); + { + auto inner_start = b.Offset(); + std::vector inner_fields; + inner_fields.push_back(b.NextField(inner_start, "key")); + ASSERT_OK(b.String("value")); + ASSERT_OK(b.FinishObject(inner_start, inner_fields)); + } + ASSERT_OK(b.FinishObject(start, fields)); + + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "StartObject(1)"); + ASSERT_EQ(events[1], "FieldName(\"inner\")"); + ASSERT_EQ(events[2], "StartObject(1)"); + ASSERT_EQ(events[3], "FieldName(\"key\")"); + ASSERT_EQ(events[4], "String(\"value\")"); + ASSERT_EQ(events[5], "EndObject"); + ASSERT_EQ(events[6], "EndObject"); +} + +TEST_F(VariantBuilderObjectTest, DuplicateKeyError) { + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "key")); + ASSERT_OK(b.Int(1)); + fields.push_back(b.NextField(start, "key")); + ASSERT_OK(b.Int(2)); + ASSERT_RAISES(Invalid, b.FinishObject(start, fields)); +} + +TEST_F(VariantBuilderObjectTest, FieldsSortedByKey) { + // Insert fields in reverse order; verify they come out sorted + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "z_last")); + ASSERT_OK(b.Int(3)); + fields.push_back(b.NextField(start, "a_first")); + ASSERT_OK(b.Int(1)); + fields.push_back(b.NextField(start, "m_middle")); + ASSERT_OK(b.Int(2)); + ASSERT_OK(b.FinishObject(start, fields)); + + auto events = RoundTrip(b); + ASSERT_EQ(events[1], "FieldName(\"a_first\")"); + ASSERT_EQ(events[2], "Int8(1)"); + ASSERT_EQ(events[3], "FieldName(\"m_middle\")"); + ASSERT_EQ(events[4], "Int8(2)"); + ASSERT_EQ(events[5], "FieldName(\"z_last\")"); + ASSERT_EQ(events[6], "Int8(3)"); +} + +// =========================================================================== +// Builder features +// =========================================================================== + +class VariantBuilderFeatureTest : public ::testing::Test {}; + +TEST_F(VariantBuilderFeatureTest, Reset) { + VariantBuilder b; + ASSERT_OK(b.Int(42)); + auto events1 = RoundTrip(b); + ASSERT_EQ(events1[0], "Int8(42)"); + + b.Reset(); + ASSERT_OK(b.String("hello")); + auto events2 = RoundTrip(b); + ASSERT_EQ(events2[0], "String(\"hello\")"); +} + +TEST_F(VariantBuilderFeatureTest, BuilderFromExistingMetadata) { + // First, build a variant to get metadata + VariantBuilder b1; + auto start = b1.Offset(); + std::vector fields; + fields.push_back(b1.NextField(start, "name")); + ASSERT_OK(b1.String("Alice")); + ASSERT_OK(b1.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded1, b1.Finish()); + + // Decode the metadata + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded1.metadata.data(), + static_cast(encoded1.metadata.size()))); + + // Build a new variant reusing the same metadata + VariantBuilder b2(meta); + auto start2 = b2.Offset(); + std::vector fields2; + fields2.push_back(b2.NextField(start2, "name")); + ASSERT_OK(b2.String("Bob")); + ASSERT_OK(b2.FinishObject(start2, fields2)); + + auto events = RoundTrip(b2); + ASSERT_EQ(events[1], "FieldName(\"name\")"); + ASSERT_EQ(events[2], "String(\"Bob\")"); +} + +TEST_F(VariantBuilderFeatureTest, MetadataSortedFlag) { + // If keys are inserted in sorted order, metadata should have sorted flag + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "alpha")); + ASSERT_OK(b.Int(1)); + fields.push_back(b.NextField(start, "beta")); + ASSERT_OK(b.Int(2)); + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + ASSERT_TRUE(meta.is_sorted); +} + +TEST_F(VariantBuilderFeatureTest, MetadataUnsortedFlag) { + // If keys are inserted out of order, sorted flag should be false + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "beta")); + ASSERT_OK(b.Int(1)); + fields.push_back(b.NextField(start, "alpha")); + ASSERT_OK(b.Int(2)); + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + ASSERT_FALSE(meta.is_sorted); +} + +// =========================================================================== +// Integration: full round-trip of complex structure +// =========================================================================== + +class VariantBuilderIntegrationTest : public ::testing::Test {}; + +TEST_F(VariantBuilderIntegrationTest, ComplexObject) { + // {"name": "Alice", "scores": [95, 87, 92], "active": true} + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + + fields.push_back(b.NextField(start, "name")); + ASSERT_OK(b.String("Alice")); + + fields.push_back(b.NextField(start, "scores")); + { + auto arr_start = b.Offset(); + std::vector arr_offsets; + arr_offsets.push_back(b.NextElement(arr_start)); + ASSERT_OK(b.Int(95)); + arr_offsets.push_back(b.NextElement(arr_start)); + ASSERT_OK(b.Int(87)); + arr_offsets.push_back(b.NextElement(arr_start)); + ASSERT_OK(b.Int(92)); + ASSERT_OK(b.FinishArray(arr_start, arr_offsets)); + } + + fields.push_back(b.NextField(start, "active")); + ASSERT_OK(b.Bool(true)); + + ASSERT_OK(b.FinishObject(start, fields)); + + auto events = RoundTrip(b); + // Fields sorted: "active", "name", "scores" + ASSERT_EQ(events[0], "StartObject(3)"); + ASSERT_EQ(events[1], "FieldName(\"active\")"); + ASSERT_EQ(events[2], "Bool(true)"); + ASSERT_EQ(events[3], "FieldName(\"name\")"); + ASSERT_EQ(events[4], "String(\"Alice\")"); + ASSERT_EQ(events[5], "FieldName(\"scores\")"); + ASSERT_EQ(events[6], "StartArray(3)"); + ASSERT_EQ(events[7], "Int8(95)"); + ASSERT_EQ(events[8], "Int8(87)"); + ASSERT_EQ(events[9], "Int8(92)"); + ASSERT_EQ(events[10], "EndArray"); + ASSERT_EQ(events[11], "EndObject"); +} + +TEST_F(VariantBuilderIntegrationTest, LargeMetadataOffsetSize) { + // Build an object with enough unique keys to trigger 2-byte metadata offsets. + // 300 keys of ~4 chars each = ~1200 bytes total string data > 255. + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + for (int i = 0; i < 300; ++i) { + std::string key = "k" + std::to_string(i); + fields.push_back(b.NextField(start, key)); + ASSERT_OK(b.Int(i)); + } + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + // Verify metadata can be decoded + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + ASSERT_EQ(static_cast(meta.strings.size()), 300); + // offset_size should be >= 2 (total string data > 255 bytes) + ASSERT_GE(meta.offset_size, 2); + + // Verify value can be decoded + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(meta, encoded.value.data(), + static_cast(encoded.value.size()), &visitor)); + // StartObject(300) + 300*(FieldName + Int8) + EndObject = 602 events + ASSERT_EQ(visitor.events.size(), 602); + ASSERT_EQ(visitor.events[0], "StartObject(300)"); + ASSERT_EQ(visitor.events[601], "EndObject"); +} + +TEST_F(VariantBuilderIntegrationTest, MetadataOffsetSizeFromKeyCount) { + // Verify that offset_size is computed from max(total_string_size, num_keys). + // Use 260 single-character keys: total_string_size=260 (>255, needs 2 bytes) + // but num_keys=260 also exceeds 255. This ensures the formula handles both. + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + // Generate 260 unique 1-char keys using characters + numeric suffixes + for (int i = 0; i < 260; ++i) { + // Use 2-char keys to guarantee uniqueness: "a0" through "z9", then "A0"... + char c1 = (i < 260) ? static_cast('a' + (i / 10) % 26) : 'A'; + char c2 = static_cast('0' + (i % 10)); + std::string key = {c1, c2}; + fields.push_back(b.NextField(start, key)); + ASSERT_OK(b.Null()); + } + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + // 260 keys of 2 chars = 520 bytes total string data > 255, needs 2-byte offsets + ASSERT_GE(meta.offset_size, 2); + // Also verify num_keys is correctly stored + ASSERT_EQ(static_cast(meta.strings.size()), 260); + + // Verify round-trip + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(meta, encoded.value.data(), + static_cast(encoded.value.size()), &visitor)); + ASSERT_EQ(visitor.events[0], "StartObject(260)"); +} + +TEST_F(VariantBuilderIntegrationTest, InvalidStartPosition) { + VariantBuilder b; + ASSERT_OK(b.Int(42)); + // start=999 is beyond the buffer — should fail + std::vector offsets; + ASSERT_RAISES(Invalid, b.FinishArray(999, offsets)); + + std::vector fields; + ASSERT_RAISES(Invalid, b.FinishObject(999, fields)); +} + +TEST_F(VariantBuilderIntegrationTest, NegativeArrayOffsetRejected) { + VariantBuilder b; + auto start = b.Offset(); + ASSERT_OK(b.Int(1)); + std::vector offsets = {-1}; + ASSERT_RAISES(Invalid, b.FinishArray(start, offsets)); +} + +// =========================================================================== +// Additional primitive round-trip tests (coverage gaps) +// =========================================================================== + +class VariantBuilderPrimitiveExtraTest : public ::testing::Test {}; + +TEST_F(VariantBuilderPrimitiveExtraTest, FloatRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.Float(2.5f)); + auto events = RoundTrip(b); + ASSERT_TRUE(events[0].find("Float(") == 0); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, BinaryRoundTrip) { + std::string_view bin_data("\x00\x01\x02\x03", 4); + VariantBuilder b; + ASSERT_OK(b.Binary(bin_data)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Binary(len=4)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, EmptyBinaryRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.Binary("")); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Binary(len=0)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, UUIDRoundTrip) { + uint8_t uuid_bytes[16]; + for (int i = 0; i < 16; ++i) uuid_bytes[i] = static_cast(i + 1); + VariantBuilder b; + ASSERT_OK(b.UUID(uuid_bytes)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "UUID"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, TimestampMicrosRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.TimestampMicros(1654041600000000LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "TimestampMicros(1654041600000000)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, TimestampMicrosNTZRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.TimestampMicrosNTZ(1654041600000000LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "TimestampMicrosNTZ(1654041600000000)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, TimestampNanosRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.TimestampNanos(1654041600000000000LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "TimestampNanos(1654041600000000000)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, TimestampNanosNTZRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.TimestampNanosNTZ(1654041600000000000LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "TimestampNanosNTZ(1654041600000000000)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, TimeNTZRoundTrip) { + VariantBuilder b; + ASSERT_OK(b.TimeNTZ(43200000000LL)); // 12:00:00 in microseconds + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "TimeNTZ(43200000000)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, Decimal4RoundTrip) { + int32_t val = 12345; + uint8_t bytes[4]; + std::memcpy(bytes, &val, 4); + VariantBuilder b; + ASSERT_OK(b.Decimal4(2, bytes)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Decimal4(scale=2)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, Decimal8RoundTrip) { + int64_t val = 123456789012345LL; + uint8_t bytes[8]; + std::memcpy(bytes, &val, 8); + VariantBuilder b; + ASSERT_OK(b.Decimal8(5, bytes)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Decimal8(scale=5)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, Decimal16RoundTrip) { + uint8_t bytes[16] = {}; + bytes[0] = 0x01; // value = 1 in low byte + VariantBuilder b; + ASSERT_OK(b.Decimal16(10, bytes)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Decimal16(scale=10)"); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, DecimalScaleValidation) { + uint8_t bytes[16] = {}; + VariantBuilder b; + // Scale 39 exceeds spec maximum of 38 + ASSERT_RAISES(Invalid, b.Decimal4(39, bytes)); + ASSERT_RAISES(Invalid, b.Decimal8(39, bytes)); + ASSERT_RAISES(Invalid, b.Decimal16(39, bytes)); + // Scale 38 is valid + ASSERT_OK(b.Decimal4(38, bytes)); +} + +TEST_F(VariantBuilderPrimitiveExtraTest, EmptyString) { + VariantBuilder b; + ASSERT_OK(b.String("")); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "String(\"\")"); +} + +// =========================================================================== +// Special float/double values: NaN, ±Inf +// =========================================================================== + +class VariantBuilderSpecialFloatTest : public ::testing::Test {}; + +TEST_F(VariantBuilderSpecialFloatTest, FloatNaN) { + VariantBuilder b; + ASSERT_OK(b.Float(std::numeric_limits::quiet_NaN())); + ASSERT_OK_AND_ASSIGN(auto result, b.Finish()); + // Verify it round-trips (NaN != NaN, so just check we get a Float event) + ASSERT_OK_AND_ASSIGN( + auto metadata, + DecodeMetadata(result.metadata.data(), static_cast(result.metadata.size()))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, result.value.data(), + static_cast(result.value.size()), &visitor)); + ASSERT_TRUE(visitor.events[0].find("Float(") == 0); +} + +TEST_F(VariantBuilderSpecialFloatTest, FloatPositiveInf) { + VariantBuilder b; + ASSERT_OK(b.Float(std::numeric_limits::infinity())); + auto events = RoundTrip(b); + ASSERT_TRUE(events[0].find("Float(") == 0); +} + +TEST_F(VariantBuilderSpecialFloatTest, FloatNegativeInf) { + VariantBuilder b; + ASSERT_OK(b.Float(-std::numeric_limits::infinity())); + auto events = RoundTrip(b); + ASSERT_TRUE(events[0].find("Float(") == 0); +} + +TEST_F(VariantBuilderSpecialFloatTest, DoubleNaN) { + VariantBuilder b; + ASSERT_OK(b.Double(std::numeric_limits::quiet_NaN())); + ASSERT_OK_AND_ASSIGN(auto result, b.Finish()); + ASSERT_OK_AND_ASSIGN( + auto metadata, + DecodeMetadata(result.metadata.data(), static_cast(result.metadata.size()))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(metadata, result.value.data(), + static_cast(result.value.size()), &visitor)); + ASSERT_TRUE(visitor.events[0].find("Double(") == 0); +} + +TEST_F(VariantBuilderSpecialFloatTest, DoublePositiveInf) { + VariantBuilder b; + ASSERT_OK(b.Double(std::numeric_limits::infinity())); + auto events = RoundTrip(b); + ASSERT_TRUE(events[0].find("Double(") == 0); +} + +TEST_F(VariantBuilderSpecialFloatTest, DoubleNegativeInf) { + VariantBuilder b; + ASSERT_OK(b.Double(-std::numeric_limits::infinity())); + auto events = RoundTrip(b); + ASSERT_TRUE(events[0].find("Double(") == 0); +} + +// =========================================================================== +// Int auto-sizing boundary tests +// =========================================================================== + +class VariantBuilderIntBoundaryTest : public ::testing::Test {}; + +TEST_F(VariantBuilderIntBoundaryTest, Int8MaxBecomesInt8) { + VariantBuilder b; + ASSERT_OK(b.Int(127)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int8(127)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int8MaxPlusOneBecomesInt16) { + VariantBuilder b; + ASSERT_OK(b.Int(128)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int16(128)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int8MinBecomesInt8) { + VariantBuilder b; + ASSERT_OK(b.Int(-128)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int8(-128)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int8MinMinusOneBecomesInt16) { + VariantBuilder b; + ASSERT_OK(b.Int(-129)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int16(-129)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int16MaxBecomesInt16) { + VariantBuilder b; + ASSERT_OK(b.Int(32767)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int16(32767)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int16MaxPlusOneBecomesInt32) { + VariantBuilder b; + ASSERT_OK(b.Int(32768)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int32(32768)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int32MaxBecomesInt32) { + VariantBuilder b; + ASSERT_OK(b.Int(2147483647LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int32(2147483647)"); +} + +TEST_F(VariantBuilderIntBoundaryTest, Int32MaxPlusOneBecomesInt64) { + VariantBuilder b; + ASSERT_OK(b.Int(2147483648LL)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int64(2147483648)"); +} + +// =========================================================================== +// Large array round-trip (is_large flag) +// =========================================================================== + +class VariantBuilderLargeContainerTest : public ::testing::Test {}; + +TEST_F(VariantBuilderLargeContainerTest, LargeArrayIsLarge) { + // Build an array with 300 elements (>255) to trigger is_large=true. + // This exercises the same code path as the Go bug (apache/arrow-go#839). + VariantBuilder b; + auto start = b.Offset(); + std::vector offsets; + for (int i = 0; i < 300; ++i) { + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Null()); + } + ASSERT_OK(b.FinishArray(start, offsets)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + // Verify the header byte has is_large set correctly + ASSERT_FALSE(encoded.value.empty()); + uint8_t header = encoded.value[0]; + ASSERT_EQ(GetBasicType(header), BasicType::kArray); + // is_large at bit 4 of full byte + ASSERT_TRUE(((header >> 4) & 0x01) != 0); + + // Verify round-trip: decode and check element count + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(meta, encoded.value.data(), + static_cast(encoded.value.size()), &visitor)); + // StartArray(300) + 300 Nulls + EndArray = 302 events + ASSERT_EQ(visitor.events.size(), 302); + ASSERT_EQ(visitor.events[0], "StartArray(300)"); + ASSERT_EQ(visitor.events[301], "EndArray"); + + // Also verify ValueSize works correctly on this large array + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(encoded.value.data(), + static_cast(encoded.value.size()))); + ASSERT_EQ(size, static_cast(encoded.value.size())); +} + +TEST_F(VariantBuilderLargeContainerTest, LargeObjectIsLarge) { + // Build an object with 300 fields (>255) to trigger is_large=true. + // Verifies that the encoder correctly sets is_large at bit 6 of the + // full header byte (bit 4 of the 6-bit type_info / value_header). + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + for (int i = 0; i < 300; ++i) { + std::string key = "field_" + std::to_string(i); + fields.push_back(b.NextField(start, key)); + ASSERT_OK(b.Null()); + } + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + // Verify the header byte has is_large set correctly at bit 6 + ASSERT_FALSE(encoded.value.empty()); + uint8_t header = encoded.value[0]; + ASSERT_EQ(GetBasicType(header), BasicType::kObject); + // Object is_large at bit 6 of full byte (bit 4 of type_info) + ASSERT_TRUE(((header >> 6) & 0x01) != 0); + + // Verify round-trip: decode and check field count + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + ASSERT_OK_AND_ASSIGN(auto field_count, + GetObjectFieldCount(encoded.value.data(), + static_cast(encoded.value.size()))); + ASSERT_EQ(field_count, 300); + + // Verify full decode + RecordingVisitor visitor; + ASSERT_OK(DecodeVariantValue(meta, encoded.value.data(), + static_cast(encoded.value.size()), &visitor)); + // StartObject(300) + 300*(FieldName + Null) + EndObject = 602 events + ASSERT_EQ(visitor.events.size(), 602); + ASSERT_EQ(visitor.events[0], "StartObject(300)"); + ASSERT_EQ(visitor.events[601], "EndObject"); + + // Verify ValueSize matches buffer size + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(encoded.value.data(), + static_cast(encoded.value.size()))); + ASSERT_EQ(size, static_cast(encoded.value.size())); +} + +// =========================================================================== +// Decoder utility round-trips through builder output +// =========================================================================== + +class VariantBuilderDecoderUtilTest : public ::testing::Test {}; + +TEST_F(VariantBuilderDecoderUtilTest, FindObjectFieldOnBuilderOutput) { + // Build {alpha: 1, beta: "two", gamma: true} and verify FindObjectField works + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "alpha")); + ASSERT_OK(b.Int(1)); + fields.push_back(b.NextField(start, "beta")); + ASSERT_OK(b.String("two")); + fields.push_back(b.NextField(start, "gamma")); + ASSERT_OK(b.Bool(true)); + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + + // Find "beta" + int64_t field_offset = -1, field_size = 0; + ASSERT_OK(FindObjectField(meta, encoded.value.data(), + static_cast(encoded.value.size()), "beta", + &field_offset, &field_size)); + ASSERT_GT(field_offset, 0); + ASSERT_GT(field_size, 0); + + // Decode the field value + RecordingVisitor v; + ASSERT_OK( + DecodeVariantValue(meta, encoded.value.data() + field_offset, field_size, &v)); + ASSERT_EQ(v.events[0], "String(\"two\")"); + + // Find non-existent key + int64_t nf_offset = -1, nf_size = 0; + ASSERT_OK(FindObjectField(meta, encoded.value.data(), + static_cast(encoded.value.size()), "missing", + &nf_offset, &nf_size)); + ASSERT_EQ(nf_offset, -1); +} + +TEST_F(VariantBuilderDecoderUtilTest, GetArrayElementOnBuilderOutput) { + // Build [10, 20, 30] and verify GetArrayElement works + VariantBuilder b; + auto start = b.Offset(); + std::vector offsets; + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(10)); + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(20)); + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Int(30)); + ASSERT_OK(b.FinishArray(start, offsets)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + + // Access element at index 2 + int64_t elem_offset = 0, elem_size = 0; + ASSERT_OK(GetArrayElement(encoded.value.data(), + static_cast(encoded.value.size()), 2, &elem_offset, + &elem_size)); + ASSERT_GT(elem_offset, 0); + ASSERT_EQ(elem_size, 2); // Int8(30) = 2 bytes + + RecordingVisitor v; + ASSERT_OK(DecodeVariantValue(meta, encoded.value.data() + elem_offset, elem_size, &v)); + ASSERT_EQ(v.events[0], "Int8(30)"); +} + +TEST_F(VariantBuilderDecoderUtilTest, GetObjectFieldAtOnBuilderOutput) { + // Build {x: 100, y: 200} and access by positional index + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "x")); + ASSERT_OK(b.Int(100)); + fields.push_back(b.NextField(start, "y")); + ASSERT_OK(b.Int(200)); + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(encoded.metadata.data(), + static_cast(encoded.metadata.size()))); + + // Fields are sorted by key: "x" at index 0, "y" at index 1 + std::string_view field_name; + int64_t field_offset = 0, field_size = 0; + ASSERT_OK(GetObjectFieldAt(meta, encoded.value.data(), + static_cast(encoded.value.size()), 0, &field_name, + &field_offset, &field_size)); + ASSERT_EQ(field_name, "x"); + + RecordingVisitor v; + ASSERT_OK( + DecodeVariantValue(meta, encoded.value.data() + field_offset, field_size, &v)); + ASSERT_EQ(v.events[0], "Int8(100)"); +} + +TEST_F(VariantBuilderDecoderUtilTest, ValueSizeOnBuilderOutput) { + // Build a nested structure and verify ValueSize matches buffer size + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "data")); + { + auto arr_start = b.Offset(); + std::vector arr_offsets; + arr_offsets.push_back(b.NextElement(arr_start)); + ASSERT_OK(b.String("hello")); + arr_offsets.push_back(b.NextElement(arr_start)); + ASSERT_OK(b.Int(42)); + ASSERT_OK(b.FinishArray(arr_start, arr_offsets)); + } + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto encoded, b.Finish()); + + // ValueSize of the top-level value should equal the total buffer size + ASSERT_OK_AND_ASSIGN(auto size, ValueSize(encoded.value.data(), + static_cast(encoded.value.size()))); + ASSERT_EQ(size, static_cast(encoded.value.size())); +} + +// =========================================================================== +// Direct integer type method tests (verify explicit types not auto-sized) +// =========================================================================== + +class VariantBuilderDirectIntTest : public ::testing::Test {}; + +TEST_F(VariantBuilderDirectIntTest, ExplicitInt8) { + VariantBuilder b; + ASSERT_OK(b.Int8(42)); + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int8(42)"); +} + +TEST_F(VariantBuilderDirectIntTest, ExplicitInt16) { + VariantBuilder b; + ASSERT_OK(b.Int16(42)); // Would be Int8 if auto-sized + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int16(42)"); +} + +TEST_F(VariantBuilderDirectIntTest, ExplicitInt32) { + VariantBuilder b; + ASSERT_OK(b.Int32(42)); // Would be Int8 if auto-sized + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int32(42)"); +} + +TEST_F(VariantBuilderDirectIntTest, ExplicitInt64) { + VariantBuilder b; + ASSERT_OK(b.Int64(42)); // Would be Int8 if auto-sized + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "Int64(42)"); +} + +// =========================================================================== +// Builder reuse: multiple Finish() calls with preserved dictionary +// =========================================================================== + +class VariantBuilderReuseTest : public ::testing::Test {}; + +TEST_F(VariantBuilderReuseTest, MultipleFinishPreservesDictionary) { + VariantBuilder b; + + // Build first value: {name: "Alice"} + auto start1 = b.Offset(); + std::vector fields1; + fields1.push_back(b.NextField(start1, "name")); + ASSERT_OK(b.String("Alice")); + ASSERT_OK(b.FinishObject(start1, fields1)); + ASSERT_OK_AND_ASSIGN(auto encoded1, b.Finish()); + + // Build second value: {name: "Bob"} — reuses dictionary from first build + auto start2 = b.Offset(); + std::vector fields2; + fields2.push_back(b.NextField(start2, "name")); + ASSERT_OK(b.String("Bob")); + ASSERT_OK(b.FinishObject(start2, fields2)); + ASSERT_OK_AND_ASSIGN(auto encoded2, b.Finish()); + + // Verify first value decodes correctly + ASSERT_OK_AND_ASSIGN( + auto meta1, + DecodeMetadata(encoded1.metadata.data(), + static_cast(encoded1.metadata.size()))); + RecordingVisitor v1; + ASSERT_OK(DecodeVariantValue(meta1, encoded1.value.data(), + static_cast(encoded1.value.size()), &v1)); + ASSERT_EQ(v1.events[1], "FieldName(\"name\")"); + ASSERT_EQ(v1.events[2], "String(\"Alice\")"); + + // Verify second value decodes correctly + ASSERT_OK_AND_ASSIGN( + auto meta2, + DecodeMetadata(encoded2.metadata.data(), + static_cast(encoded2.metadata.size()))); + RecordingVisitor v2; + ASSERT_OK(DecodeVariantValue(meta2, encoded2.value.data(), + static_cast(encoded2.value.size()), &v2)); + ASSERT_EQ(v2.events[1], "FieldName(\"name\")"); + ASSERT_EQ(v2.events[2], "String(\"Bob\")"); + + // Both should have the same dictionary content (same metadata structure) + ASSERT_EQ(meta1.strings.size(), meta2.strings.size()); + ASSERT_EQ(meta1.strings[0], "name"); + ASSERT_EQ(meta2.strings[0], "name"); +} + +TEST_F(VariantBuilderReuseTest, DictionaryGrowsAcrossFinishCalls) { + VariantBuilder b; + + // Build first value with key "x" + auto start1 = b.Offset(); + std::vector fields1; + fields1.push_back(b.NextField(start1, "x")); + ASSERT_OK(b.Int(1)); + ASSERT_OK(b.FinishObject(start1, fields1)); + ASSERT_OK_AND_ASSIGN(auto encoded1, b.Finish()); + + // Build second value with keys "x" and "y" — dictionary should grow + auto start2 = b.Offset(); + std::vector fields2; + fields2.push_back(b.NextField(start2, "x")); + ASSERT_OK(b.Int(2)); + fields2.push_back(b.NextField(start2, "y")); + ASSERT_OK(b.Int(3)); + ASSERT_OK(b.FinishObject(start2, fields2)); + ASSERT_OK_AND_ASSIGN(auto encoded2, b.Finish()); + + // First metadata has 1 key + ASSERT_OK_AND_ASSIGN( + auto meta1, + DecodeMetadata(encoded1.metadata.data(), + static_cast(encoded1.metadata.size()))); + ASSERT_EQ(meta1.strings.size(), 1); + + // Second metadata has 2 keys (dictionary grew) + ASSERT_OK_AND_ASSIGN( + auto meta2, + DecodeMetadata(encoded2.metadata.data(), + static_cast(encoded2.metadata.size()))); + ASSERT_EQ(meta2.strings.size(), 2); + + // Verify second value decodes correctly + RecordingVisitor v2; + ASSERT_OK(DecodeVariantValue(meta2, encoded2.value.data(), + static_cast(encoded2.value.size()), &v2)); + ASSERT_EQ(v2.events[0], "StartObject(2)"); + // Fields sorted: "x" before "y" + ASSERT_EQ(v2.events[1], "FieldName(\"x\")"); + ASSERT_EQ(v2.events[2], "Int8(2)"); + ASSERT_EQ(v2.events[3], "FieldName(\"y\")"); + ASSERT_EQ(v2.events[4], "Int8(3)"); +} + +// =========================================================================== +// Edge case: FinishObject/FinishArray with pre-existing buffer content +// =========================================================================== + +class VariantBuilderPreExistingBufferTest : public ::testing::Test {}; + +TEST_F(VariantBuilderPreExistingBufferTest, ObjectAfterPrimitive) { + // Write a primitive value first, then build an object. This exercises + // the case where start > 0 (data_size = buffer.size() - start). + // The builder is designed for single top-level values, but this tests + // the internal arithmetic correctness. + VariantBuilder b; + // Write a "prefix" value that occupies buffer space before our object + ASSERT_OK(b.Int(99)); + int64_t prefix_size = b.Offset(); // should be 2 (Int8 header + 1 byte) + ASSERT_EQ(prefix_size, 2); + + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "key")); + ASSERT_OK(b.String("val")); + ASSERT_OK(b.FinishObject(start, fields)); + + // The buffer now contains [Int8(99)] + [Object{key: "val"}]. + // We can't call Finish() meaningfully for a two-value buffer, + // but verify no crash or corruption occurred and the object portion + // is correctly sized. + ASSERT_GT(b.Offset(), prefix_size); +} + +TEST_F(VariantBuilderPreExistingBufferTest, ArrayAfterPrimitive) { + // Same as above but for arrays. + VariantBuilder b; + ASSERT_OK(b.Int(99)); + int64_t prefix_size = b.Offset(); + + auto start = b.Offset(); + std::vector offsets; + offsets.push_back(b.NextElement(start)); + ASSERT_OK(b.Null()); + ASSERT_OK(b.FinishArray(start, offsets)); + + ASSERT_GT(b.Offset(), prefix_size); +} + +} // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/extension/variant_internal.h b/cpp/src/arrow/extension/variant_internal.h index 87354b50600d..6a61085069e5 100644 --- a/cpp/src/arrow/extension/variant_internal.h +++ b/cpp/src/arrow/extension/variant_internal.h @@ -18,7 +18,9 @@ #pragma once #include +#include #include +#include #include #include "arrow/result.h" @@ -344,4 +346,134 @@ ARROW_EXPORT Status GetObjectFieldAt(const VariantMetadata& metadata, const uint ARROW_EXPORT int32_t FindMetadataKey(const VariantMetadata& metadata, std::string_view key); +// --------------------------------------------------------------------------- +// Variant Builder (Encoder) +// --------------------------------------------------------------------------- + +/// \brief Builder for constructing Variant binary values. +/// +/// Mirrors the Go implementation's Builder pattern. Values are written +/// into an internal buffer; containers (objects/arrays) use a start-offset +/// + finish pattern that shifts data to insert headers. +/// +/// Usage: +/// VariantBuilder builder; +/// auto start = builder.Offset(); +/// std::vector fields; +/// fields.push_back(builder.NextField(start, "name")); +/// builder.String("Alice"); +/// fields.push_back(builder.NextField(start, "age")); +/// builder.Int(30); +/// builder.FinishObject(start, fields); +/// ARROW_ASSIGN_OR_RAISE(auto result, builder.Finish()); +class ARROW_EXPORT VariantBuilder { + public: + VariantBuilder(); + explicit VariantBuilder(const VariantMetadata& existing_metadata); + ~VariantBuilder() = default; + + VariantBuilder(VariantBuilder&&) noexcept = default; + VariantBuilder& operator=(VariantBuilder&&) noexcept = default; + + // Non-copyable (owns dictionary state) + VariantBuilder(const VariantBuilder&) = delete; + VariantBuilder& operator=(const VariantBuilder&) = delete; + + /// @name Primitive value setters + /// @{ + Status Null(); + Status Bool(bool value); + Status Int(int64_t value); ///< Auto-selects smallest int type + Status Int8(int8_t value); + Status Int16(int16_t value); + Status Int32(int32_t value); + Status Int64(int64_t value); + Status Float(float value); + Status Double(double value); + Status Decimal4(uint8_t scale, const uint8_t* value_bytes); + Status Decimal8(uint8_t scale, const uint8_t* value_bytes); + Status Decimal16(uint8_t scale, const uint8_t* value_bytes); + Status Date(int32_t days_since_epoch); + Status TimestampMicros(int64_t micros); + Status TimestampMicrosNTZ(int64_t micros); + Status TimeNTZ(int64_t micros); + Status TimestampNanos(int64_t nanos); + Status TimestampNanosNTZ(int64_t nanos); + Status String(std::string_view value); ///< Auto short-string for <=63 bytes + Status Binary(std::string_view value); + Status UUID(const uint8_t* bytes); + /// @} + + /// @name Container construction + /// @{ + + /// \brief Current buffer offset. Use as the start of a container. + int64_t Offset() const; + + /// \brief Compute the next element offset relative to start (for arrays). + int64_t NextElement(int64_t start) const; + + /// \brief A field entry for object construction. + struct FieldEntry { + std::string key; + uint32_t id; + int64_t offset; + }; + + /// \brief Record the next field for an object (adds key to dictionary). + FieldEntry NextField(int64_t start, std::string_view key); + + /// \brief Finalize an array value in the buffer. + Status FinishArray(int64_t start, const std::vector& offsets); + + /// \brief Finalize an object value. Sorts fields in-place by key, + /// then rejects duplicate keys. + /// + /// The fields vector is modified: entries are reordered to + /// lexicographic key order per the spec requirement that field IDs + /// and offsets are listed in lexicographic order of keys. + Status FinishObject(int64_t start, std::vector& fields); + /// @} + + /// @name Output + /// @{ + + /// \brief Encoded output produced by Finish(). + struct EncodedVariant { + std::vector metadata; + std::vector value; + }; + + /// \brief Finalize and produce the encoded metadata + value buffers. + /// \note The internal dictionary is preserved after Finish() so subsequent + /// values can share the same key schema. Call Reset() to clear all state. + Result Finish(); + + /// \brief Reset the builder for reuse. + void Reset(); + /// @} + + // TODO GH-45948: Add BuildWithoutMeta() to return raw value bytes without + // metadata, needed for shredded variant encoding. + + // TODO GH-45948: Add UnsafeAppendEncoded(const uint8_t* data, int64_t size) + // to append pre-encoded variant value bytes for composition/shredding. + + // TODO GH-45948: Add SetAllowDuplicates(bool) for duplicate key tolerance + // with last-value-wins semantics (uses ValueSize for compaction). + + private: + /// \brief Add a key to the dictionary, returning its ID. + uint32_t AddKey(std::string_view key); + + std::vector buffer_; + std::unordered_map dict_; + std::vector dict_keys_; + // Reusable buffer for dictionary lookups. Avoids allocating a new std::string + // on every AddKey() call when the key already exists (the common case in + // column-scan workloads where the same field names repeat across rows). + // Only used inside AddKey(); not part of the builder's logical state. + std::string lookup_buf_; +}; + } // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/meson.build b/cpp/src/arrow/meson.build index dc16985255c2..d8c81b868fe4 100644 --- a/cpp/src/arrow/meson.build +++ b/cpp/src/arrow/meson.build @@ -142,6 +142,7 @@ arrow_components = { 'extension/bool8.cc', 'extension/json.cc', 'extension/parquet_variant.cc', + 'extension/variant_builder.cc', 'extension/variant_internal.cc', 'extension/uuid.cc', 'pretty_print.cc', From c92cb110b08a2677786642e40be84c78a35514de Mon Sep 17 00:00:00 2001 From: Jared Yu Date: Tue, 9 Jun 2026 20:00:51 -0700 Subject: [PATCH 4/4] GH-45948: [C++][Parquet] Variant shredding Implements variant shredding infrastructure per the VariantShredding spec, achieving functional parity with arrow-rs parquet-variant-compute for the primitive shredding path. == VariantBuilder extensions == - BuildWithoutMeta(): raw value bytes without metadata (shared metadata) - UnsafeAppendEncoded(data, size): zero-copy append of pre-encoded bytes - SetAllowDuplicates(bool): last-value-wins duplicate key compaction - FinishObject() updated with deterministic duplicate resolution == VariantExtensionType evolution == - Supports shredded storage: struct{metadata, value?, typed_value?} - IsSupportedStorageType() accepts both unshredded (2 required fields) and shredded (required metadata + optional value + optional typed_value) - Added typed_value() accessor and is_shredded() query - Constructor finds fields by name (not position) for robustness == Shredding schema + type compatibility == - VariantShreddingSchema: Primitive/Object/Array tree schema - ToArrowType(): converts to Arrow DataType with proper struct wrapping - IsVariantCompatibleWithType(): 21-type compatibility matrix with widening rules (Int8->Int64, Float->Double, etc.) == Shredding kernel (ShredVariantColumn) == - Per-row decode + type-match + routing to typed_value or residual value - Compatible values routed to typed_value column (value set null) - Incompatible values kept in value column (typed_value set null) - Handles Variant Null as compatible with any target type == Reconstruction kernel (ReconstructVariantColumn) == - Merges typed_value and residual value back into complete variant binary - Handles all 4 states: (null,null)=missing, (v,null)=unshredded, (null,t)=shredded, (v,t)=partial object (errors for primitives) == Tests == - 10 builder extension tests (BuildWithoutMeta, UnsafeAppendEncoded, SetAllowDuplicates with various scenarios) - 20 type compatibility tests covering all primitive types - 6 schema construction tests - 4 shred/reconstruct round-trip tests proving identity: Reconstruct(Shred(v)) == v for matching, mixed, and null values == Remaining work (object/array shredding, Parquet bridge) == - Object shredding: field-level routing + residual object encoding - Array shredding: element-wise shredding with list builders - Parquet bridge: VariantToNode/NodeToArrow in schema.cc - Native typed_value extraction (currently stores variant bytes; full native-type extraction deferred to variant_get kernel) --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/extension/CMakeLists.txt | 3 +- cpp/src/arrow/extension/meson.build | 4 +- cpp/src/arrow/extension/parquet_variant.cc | 102 +- cpp/src/arrow/extension/parquet_variant.h | 17 +- cpp/src/arrow/extension/variant_builder.cc | 72 +- .../arrow/extension/variant_builder_test.cc | 230 +- cpp/src/arrow/extension/variant_internal.h | 43 +- .../arrow/extension/variant_internal_test.cc | 35 +- cpp/src/arrow/extension/variant_shredding.cc | 2116 ++++++++++++++++ cpp/src/arrow/extension/variant_shredding.h | 199 ++ .../arrow/extension/variant_shredding_test.cc | 2118 +++++++++++++++++ cpp/src/arrow/meson.build | 1 + 13 files changed, 4836 insertions(+), 105 deletions(-) create mode 100644 cpp/src/arrow/extension/variant_shredding.cc create mode 100644 cpp/src/arrow/extension/variant_shredding.h create mode 100644 cpp/src/arrow/extension/variant_shredding_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index ec076e1321a3..2143ba31b032 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -393,6 +393,7 @@ set(ARROW_SRCS extension/parquet_variant.cc extension/variant_builder.cc extension/variant_internal.cc + extension/variant_shredding.cc extension/uuid.cc pretty_print.cc record_batch.cc diff --git a/cpp/src/arrow/extension/CMakeLists.txt b/cpp/src/arrow/extension/CMakeLists.txt index e66e82c4bcc6..14986b19c163 100644 --- a/cpp/src/arrow/extension/CMakeLists.txt +++ b/cpp/src/arrow/extension/CMakeLists.txt @@ -16,7 +16,8 @@ # under the License. set(CANONICAL_EXTENSION_TESTS bool8_test.cc json_test.cc uuid_test.cc - variant_internal_test.cc variant_builder_test.cc) + variant_internal_test.cc variant_builder_test.cc + variant_shredding_test.cc) if(ARROW_JSON) list(APPEND CANONICAL_EXTENSION_TESTS tensor_extension_array_test.cc opaque_test.cc) diff --git a/cpp/src/arrow/extension/meson.build b/cpp/src/arrow/extension/meson.build index e6362721041c..a3a33886258d 100644 --- a/cpp/src/arrow/extension/meson.build +++ b/cpp/src/arrow/extension/meson.build @@ -16,7 +16,8 @@ # under the License. canonical_extension_tests = ['bool8_test.cc', 'json_test.cc', 'uuid_test.cc', - 'variant_internal_test.cc', 'variant_builder_test.cc'] + 'variant_internal_test.cc', 'variant_builder_test.cc', + 'variant_shredding_test.cc'] if needs_json canonical_extension_tests += [ @@ -44,5 +45,6 @@ install_headers( # variant_internal.h: public API for variant binary encoding/decoding. # "internal" refers to the binary encoding internals, not visibility. 'variant_internal.h', + 'variant_shredding.h', ], ) diff --git a/cpp/src/arrow/extension/parquet_variant.cc b/cpp/src/arrow/extension/parquet_variant.cc index 95aa5a0eb68e..3c09e7d1141a 100644 --- a/cpp/src/arrow/extension/parquet_variant.cc +++ b/cpp/src/arrow/extension/parquet_variant.cc @@ -28,18 +28,20 @@ namespace arrow::extension { VariantExtensionType::VariantExtensionType(const std::shared_ptr& storage_type) : ExtensionType(storage_type) { - // GH-45948: Shredded variants will need to handle an optional shredded_value as - // well as value_ becoming optional. - - // IsSupportedStorageType should have been called already, asserting that both - // metadata and value are present. - if (storage_type->field(0)->name() == "metadata") { - metadata_ = storage_type->field(0); - value_ = storage_type->field(1); - } else { - value_ = storage_type->field(0); - metadata_ = storage_type->field(1); + // Find fields by name (ordering does not matter per spec). + for (int i = 0; i < storage_type->num_fields(); ++i) { + const auto& f = storage_type->field(i); + if (f->name() == "metadata") { + metadata_ = f; + } else if (f->name() == "value") { + value_ = f; + } else if (f->name() == "typed_value") { + typed_value_ = f; + } } + // IsSupportedStorageType() should have been called before construction. + DCHECK_NE(metadata_, nullptr); + DCHECK_NE(value_, nullptr); } bool VariantExtensionType::ExtensionEquals(const ExtensionType& other) const { @@ -71,35 +73,52 @@ bool IsBinaryField(const std::shared_ptr field) { bool VariantExtensionType::IsSupportedStorageType( const std::shared_ptr& storage_type) { - // For now we only supported unshredded variants. Unshredded variant storage - // type should be a struct with a binary metadata and binary value. - // - // GH-45948: In shredded variants, the binary value field can be replaced - // with one or more of the following: object, array, typed_value, and - // variant_value. - if (storage_type->id() == Type::STRUCT) { - if (storage_type->num_fields() == 2) { - // Ordering of metadata and value fields does not matter, as we will assign - // these to the VariantExtensionType's member shared_ptrs in the constructor. - // Here we just need to check that they are both present. - - const auto& field0 = storage_type->field(0); - const auto& field1 = storage_type->field(1); - - bool metadata_and_value_present = - (field0->name() == "metadata" && field1->name() == "value") || - (field1->name() == "metadata" && field0->name() == "value"); - - if (metadata_and_value_present) { - // Both metadata and value must be non-nullable binary types for unshredded - // variants. This will change in GH-46948, when we will require a Visitor - // to traverse the structure of the variant. - return IsBinaryField(field0) && IsBinaryField(field1) && !field0->nullable() && - !field1->nullable(); - } + if (storage_type->id() != Type::STRUCT) { + return false; + } + + // Find fields by name + std::shared_ptr metadata_field; + std::shared_ptr value_field; + std::shared_ptr typed_value_field; + + for (int i = 0; i < storage_type->num_fields(); ++i) { + const auto& f = storage_type->field(i); + if (f->name() == "metadata") { + metadata_field = f; + } else if (f->name() == "value") { + value_field = f; + } else if (f->name() == "typed_value") { + typed_value_field = f; } } + // metadata is always required and must be binary-like + if (!metadata_field || !IsBinaryField(metadata_field)) { + return false; + } + + // Unshredded: required metadata + required value (both binary) + if (value_field && !typed_value_field) { + return IsBinaryField(value_field) && !metadata_field->nullable() && + !value_field->nullable(); + } + + // Shredded: required metadata + optional value + optional typed_value + if (value_field && typed_value_field) { + // metadata must be non-nullable, value must be nullable binary, + // typed_value must be nullable (any type) + return !metadata_field->nullable() && IsBinaryField(value_field) && + value_field->nullable() && typed_value_field->nullable(); + } + + // NOTE: The shredding spec allows leaf schemas where `value` is absent + // (typed_value only, for fully-shredded columns with no residual). We + // reject this case for now because the current shredding implementation + // always produces a `value` column. Supporting value-absent schemas + // requires changes to ShredVariantColumn/ReconstructVariantColumn to + // handle the missing residual path. This can be added in a follow-up + // when Parquet reader integration requires it. return false; } @@ -113,9 +132,12 @@ Result> VariantExtensionType::Make( return std::make_shared(std::move(storage_type)); } -/// NOTE: this is still experimental. GH-45948 will add shredding support, at which point -/// we need to separate this into unshredded_variant and shredded_variant helper -/// functions. +/// \brief Return a VariantExtensionType instance. +/// +/// Supports both unshredded and shredded storage types: +/// - Unshredded: struct{required binary metadata, required binary value} +/// - Shredded: struct{required binary metadata, optional binary value, +/// optional typed_value} std::shared_ptr variant(std::shared_ptr storage_type) { return VariantExtensionType::Make(std::move(storage_type)).ValueOrDie(); } diff --git a/cpp/src/arrow/extension/parquet_variant.h b/cpp/src/arrow/extension/parquet_variant.h index be90923f14e6..efa535a73d49 100644 --- a/cpp/src/arrow/extension/parquet_variant.h +++ b/cpp/src/arrow/extension/parquet_variant.h @@ -40,6 +40,13 @@ class ARROW_EXPORT VariantArray : public ExtensionArray { /// required binary value; /// } /// +/// Shredded variant representation: +/// optional group variant_name (VARIANT) { +/// required binary metadata; +/// optional binary value; +/// optional typed_value; +/// } +/// /// To read more about variant encoding, see the variant encoding spec at /// https://github.com/apache/parquet-format/blob/master/VariantEncoding.md /// @@ -69,10 +76,16 @@ class ARROW_EXPORT VariantExtensionType : public ExtensionType { std::shared_ptr value() const { return value_; } + /// \brief The typed_value field, or nullptr if unshredded. + std::shared_ptr typed_value() const { return typed_value_; } + + /// \brief Whether this variant has a shredded typed_value column. + bool is_shredded() const { return typed_value_ != nullptr; } + private: - // TODO GH-45948 added shredded_value std::shared_ptr metadata_; - std::shared_ptr value_; + std::shared_ptr value_; // nullable when shredded + std::shared_ptr typed_value_; // nullptr if unshredded }; /// \brief Return a VariantExtensionType instance. diff --git a/cpp/src/arrow/extension/variant_builder.cc b/cpp/src/arrow/extension/variant_builder.cc index eb4f6c790909..7fe7abf5b4e6 100644 --- a/cpp/src/arrow/extension/variant_builder.cc +++ b/cpp/src/arrow/extension/variant_builder.cc @@ -333,14 +333,34 @@ Status VariantBuilder::FinishObject(int64_t start, std::vector& fiel return Status::Invalid("VariantBuilder::FinishObject: invalid start position"); } - // Sort fields by key name lexicographically (spec requirement) - std::sort(fields.begin(), fields.end(), - [](const FieldEntry& a, const FieldEntry& b) { return a.key < b.key; }); - - // Check for duplicate keys - for (size_t i = 1; i < fields.size(); ++i) { - if (fields[i].key == fields[i - 1].key) { - return Status::Invalid("VariantBuilder: duplicate key '", fields[i].key, "'"); + // Sort fields by key name lexicographically (spec requirement). + // For allow_duplicates_ mode, use a secondary comparison on offset + // so that among duplicates, the last-inserted value (highest offset) is last. + std::sort(fields.begin(), fields.end(), [](const FieldEntry& a, const FieldEntry& b) { + if (a.key != b.key) return a.key < b.key; + return a.offset < b.offset; // stable ordering for duplicates + }); + + // Handle duplicate keys + if (allow_duplicates_) { + // Last-value-wins compaction: among sorted duplicates (ascending offset), + // the last one has the highest offset = most recently inserted value. + // Reverse-iterate to collect unique keys keeping the last occurrence. + std::vector compacted; + compacted.reserve(fields.size()); + for (auto it = fields.rbegin(); it != fields.rend(); ++it) { + if (compacted.empty() || compacted.back().key != it->key) { + compacted.push_back(std::move(*it)); + } + } + std::reverse(compacted.begin(), compacted.end()); + fields = std::move(compacted); + } else { + // Strict mode: reject duplicates + for (size_t i = 1; i < fields.size(); ++i) { + if (fields[i].key == fields[i - 1].key) { + return Status::Invalid("VariantBuilder: duplicate key '", fields[i].key, "'"); + } } } @@ -405,14 +425,12 @@ Result VariantBuilder::Finish() { total_string_size += static_cast(k.size()); } - // Validate sizes fit within the spec's 4-byte offset limit. - // Note: Go implementation enforces a stricter 128MB limit (metadataMaxSizeLimit). - // We only enforce the spec's 4-byte offset maximum (~4GB), which is the correct - // upper bound per the encoding format. + // Validate sizes fit within the spec's 4-byte offset limit (~4GB). + // This is the correct upper bound per the encoding format. if (total_string_size > static_cast(std::numeric_limits::max())) { - return Status::Invalid( - "VariantBuilder: total dictionary string data (", total_string_size, - " bytes) exceeds maximum representable by 4-byte offsets"); + return Status::Invalid("VariantBuilder: total dictionary string data (", + total_string_size, + " bytes) exceeds maximum representable by 4-byte offsets"); } // Compute the offset_size: must accommodate both the largest string offset @@ -472,4 +490,28 @@ Result VariantBuilder::Finish() { return result; } +Result> VariantBuilder::BuildWithoutMeta() { + if (buffer_.empty()) { + return Status::Invalid("VariantBuilder::BuildWithoutMeta: no value has been written"); + } + std::vector result = std::move(buffer_); + // After std::move, buffer_ is in a valid-but-unspecified state. + // Explicit clear() ensures a deterministic empty state for reuse. + buffer_.clear(); + return result; +} + +void VariantBuilder::UnsafeAppendEncoded(const uint8_t* data, int64_t size) { + // Callers must provide valid, non-empty variant bytes. The DCHECK guards + // catch programming errors in debug builds; in release builds, a zero-size + // append is a no-op (defensive against malformed input flowing through + // reconstruction paths). + DCHECK_NE(data, nullptr); + DCHECK_GT(size, 0); + if (size <= 0) return; + buffer_.insert(buffer_.end(), data, data + size); +} + +void VariantBuilder::SetAllowDuplicates(bool allow) { allow_duplicates_ = allow; } + } // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/extension/variant_builder_test.cc b/cpp/src/arrow/extension/variant_builder_test.cc index 8831add77cc9..f8100e0142f5 100644 --- a/cpp/src/arrow/extension/variant_builder_test.cc +++ b/cpp/src/arrow/extension/variant_builder_test.cc @@ -665,9 +665,9 @@ TEST_F(VariantBuilderSpecialFloatTest, FloatNaN) { ASSERT_OK(b.Float(std::numeric_limits::quiet_NaN())); ASSERT_OK_AND_ASSIGN(auto result, b.Finish()); // Verify it round-trips (NaN != NaN, so just check we get a Float event) - ASSERT_OK_AND_ASSIGN( - auto metadata, - DecodeMetadata(result.metadata.data(), static_cast(result.metadata.size()))); + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(result.metadata.data(), + static_cast(result.metadata.size()))); RecordingVisitor visitor; ASSERT_OK(DecodeVariantValue(metadata, result.value.data(), static_cast(result.value.size()), &visitor)); @@ -692,9 +692,9 @@ TEST_F(VariantBuilderSpecialFloatTest, DoubleNaN) { VariantBuilder b; ASSERT_OK(b.Double(std::numeric_limits::quiet_NaN())); ASSERT_OK_AND_ASSIGN(auto result, b.Finish()); - ASSERT_OK_AND_ASSIGN( - auto metadata, - DecodeMetadata(result.metadata.data(), static_cast(result.metadata.size()))); + ASSERT_OK_AND_ASSIGN(auto metadata, + DecodeMetadata(result.metadata.data(), + static_cast(result.metadata.size()))); RecordingVisitor visitor; ASSERT_OK(DecodeVariantValue(metadata, result.value.data(), static_cast(result.value.size()), &visitor)); @@ -1058,10 +1058,9 @@ TEST_F(VariantBuilderReuseTest, MultipleFinishPreservesDictionary) { ASSERT_OK_AND_ASSIGN(auto encoded2, b.Finish()); // Verify first value decodes correctly - ASSERT_OK_AND_ASSIGN( - auto meta1, - DecodeMetadata(encoded1.metadata.data(), - static_cast(encoded1.metadata.size()))); + ASSERT_OK_AND_ASSIGN(auto meta1, + DecodeMetadata(encoded1.metadata.data(), + static_cast(encoded1.metadata.size()))); RecordingVisitor v1; ASSERT_OK(DecodeVariantValue(meta1, encoded1.value.data(), static_cast(encoded1.value.size()), &v1)); @@ -1069,10 +1068,9 @@ TEST_F(VariantBuilderReuseTest, MultipleFinishPreservesDictionary) { ASSERT_EQ(v1.events[2], "String(\"Alice\")"); // Verify second value decodes correctly - ASSERT_OK_AND_ASSIGN( - auto meta2, - DecodeMetadata(encoded2.metadata.data(), - static_cast(encoded2.metadata.size()))); + ASSERT_OK_AND_ASSIGN(auto meta2, + DecodeMetadata(encoded2.metadata.data(), + static_cast(encoded2.metadata.size()))); RecordingVisitor v2; ASSERT_OK(DecodeVariantValue(meta2, encoded2.value.data(), static_cast(encoded2.value.size()), &v2)); @@ -1107,17 +1105,15 @@ TEST_F(VariantBuilderReuseTest, DictionaryGrowsAcrossFinishCalls) { ASSERT_OK_AND_ASSIGN(auto encoded2, b.Finish()); // First metadata has 1 key - ASSERT_OK_AND_ASSIGN( - auto meta1, - DecodeMetadata(encoded1.metadata.data(), - static_cast(encoded1.metadata.size()))); + ASSERT_OK_AND_ASSIGN(auto meta1, + DecodeMetadata(encoded1.metadata.data(), + static_cast(encoded1.metadata.size()))); ASSERT_EQ(meta1.strings.size(), 1); // Second metadata has 2 keys (dictionary grew) - ASSERT_OK_AND_ASSIGN( - auto meta2, - DecodeMetadata(encoded2.metadata.data(), - static_cast(encoded2.metadata.size()))); + ASSERT_OK_AND_ASSIGN(auto meta2, + DecodeMetadata(encoded2.metadata.data(), + static_cast(encoded2.metadata.size()))); ASSERT_EQ(meta2.strings.size(), 2); // Verify second value decodes correctly @@ -1177,4 +1173,194 @@ TEST_F(VariantBuilderPreExistingBufferTest, ArrayAfterPrimitive) { ASSERT_GT(b.Offset(), prefix_size); } +// =========================================================================== +// GH-45948: BuildWithoutMeta, UnsafeAppendEncoded, SetAllowDuplicates +// =========================================================================== + +class VariantBuilderShreddingTest : public ::testing::Test {}; + +TEST_F(VariantBuilderShreddingTest, BuildWithoutMetaBasic) { + // BuildWithoutMeta should return only the value bytes + VariantBuilder b; + ASSERT_OK(b.Int(42)); + ASSERT_OK_AND_ASSIGN(auto value_only, b.BuildWithoutMeta()); + + // Value should be 2 bytes: header(Int8) + 1-byte value + ASSERT_EQ(value_only.size(), 2); + // The header byte for Int8: basic_type=0, primitive_type=3 shifted left by 2 + uint8_t expected_header = + static_cast(BasicType::kPrimitive) | (static_cast(3) << 2); + ASSERT_EQ(value_only[0], expected_header); + ASSERT_EQ(value_only[1], 42); +} + +TEST_F(VariantBuilderShreddingTest, BuildWithoutMetaPreservesDict) { + // After BuildWithoutMeta, the dictionary should remain available for + // subsequent builds (same behavior as Finish()) + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "key")); + ASSERT_OK(b.String("val")); + ASSERT_OK(b.FinishObject(start, fields)); + ASSERT_OK_AND_ASSIGN(auto value1, b.BuildWithoutMeta()); + + // Build another value using the same key — should reuse dictionary ID + auto start2 = b.Offset(); + std::vector fields2; + fields2.push_back(b.NextField(start2, "key")); + ASSERT_OK(b.String("val2")); + ASSERT_OK(b.FinishObject(start2, fields2)); + + // Now use Finish() to get metadata — should have 1 key (not 2) + ASSERT_OK_AND_ASSIGN(auto result, b.Finish()); + ASSERT_OK_AND_ASSIGN(auto meta, + DecodeMetadata(result.metadata.data(), + static_cast(result.metadata.size()))); + ASSERT_EQ(meta.strings.size(), 1); + ASSERT_EQ(meta.strings[0], "key"); +} + +TEST_F(VariantBuilderShreddingTest, BuildWithoutMetaEmptyBufferError) { + // BuildWithoutMeta on an empty builder should error + VariantBuilder b; + ASSERT_RAISES(Invalid, b.BuildWithoutMeta()); +} + +TEST_F(VariantBuilderShreddingTest, UnsafeAppendEncodedSimple) { + // Build a value, extract its bytes, then append to another builder + VariantBuilder b1; + ASSERT_OK(b1.String("hello")); + ASSERT_OK_AND_ASSIGN(auto encoded1, b1.Finish()); + + // Create a new builder and append the raw value bytes + VariantBuilder b2; + b2.UnsafeAppendEncoded(encoded1.value.data(), + static_cast(encoded1.value.size())); + + // The buffer should now contain the same bytes + ASSERT_OK_AND_ASSIGN(auto result2, b2.Finish()); + ASSERT_EQ(result2.value, encoded1.value); +} + +TEST_F(VariantBuilderShreddingTest, UnsafeAppendEncodedInObject) { + // Use UnsafeAppendEncoded to compose an object with a pre-encoded field value + VariantBuilder value_builder; + ASSERT_OK(value_builder.Int(999)); + ASSERT_OK_AND_ASSIGN(auto pre_encoded, value_builder.Finish()); + + // Now build an object where one field uses the pre-encoded value + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "pre_field")); + b.UnsafeAppendEncoded(pre_encoded.value.data(), + static_cast(pre_encoded.value.size())); + fields.push_back(b.NextField(start, "normal_field")); + ASSERT_OK(b.String("hi")); + ASSERT_OK(b.FinishObject(start, fields)); + + auto events = RoundTrip(b); + // Fields sorted: "normal_field" before "pre_field" + ASSERT_EQ(events[0], "StartObject(2)"); + ASSERT_EQ(events[1], "FieldName(\"normal_field\")"); + ASSERT_EQ(events[2], "String(\"hi\")"); + ASSERT_EQ(events[3], "FieldName(\"pre_field\")"); + ASSERT_EQ(events[4], "Int16(999)"); + ASSERT_EQ(events[5], "EndObject"); +} + +TEST_F(VariantBuilderShreddingTest, SetAllowDuplicatesLastValueWins) { + // With allow_duplicates = true, duplicate keys should keep last value + VariantBuilder b; + b.SetAllowDuplicates(true); + + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "key")); + ASSERT_OK(b.Int(1)); // first value + fields.push_back(b.NextField(start, "key")); + ASSERT_OK(b.Int(2)); // second value (should win) + fields.push_back(b.NextField(start, "other")); + ASSERT_OK(b.Int(3)); + ASSERT_OK(b.FinishObject(start, fields)); + + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "StartObject(2)"); + // Fields sorted: "key" before "other" + ASSERT_EQ(events[1], "FieldName(\"key\")"); + ASSERT_EQ(events[2], "Int8(2)"); // last value wins + ASSERT_EQ(events[3], "FieldName(\"other\")"); + ASSERT_EQ(events[4], "Int8(3)"); + ASSERT_EQ(events[5], "EndObject"); +} + +TEST_F(VariantBuilderShreddingTest, SetAllowDuplicatesDefaultRejects) { + // Default behavior (allow_duplicates = false) still rejects + VariantBuilder b; + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "dup")); + ASSERT_OK(b.Int(1)); + fields.push_back(b.NextField(start, "dup")); + ASSERT_OK(b.Int(2)); + ASSERT_RAISES(Invalid, b.FinishObject(start, fields)); +} + +TEST_F(VariantBuilderShreddingTest, SetAllowDuplicatesMultipleDups) { + // Multiple duplicates of the same key — last one wins + VariantBuilder b; + b.SetAllowDuplicates(true); + + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "x")); + ASSERT_OK(b.Int(1)); + fields.push_back(b.NextField(start, "x")); + ASSERT_OK(b.Int(2)); + fields.push_back(b.NextField(start, "x")); + ASSERT_OK(b.Int(3)); // this should win + ASSERT_OK(b.FinishObject(start, fields)); + + auto events = RoundTrip(b); + ASSERT_EQ(events[0], "StartObject(1)"); + ASSERT_EQ(events[1], "FieldName(\"x\")"); + ASSERT_EQ(events[2], "Int8(3)"); + ASSERT_EQ(events[3], "EndObject"); +} + +TEST_F(VariantBuilderShreddingTest, SetAllowDuplicatesToggle) { + // Can toggle between strict and lenient mode + VariantBuilder b; + + // First object: strict (default) + auto start = b.Offset(); + std::vector fields; + fields.push_back(b.NextField(start, "a")); + ASSERT_OK(b.Int(1)); + ASSERT_OK(b.FinishObject(start, fields)); + auto events1 = RoundTrip(b); + ASSERT_EQ(events1[0], "StartObject(1)"); + + // Switch to lenient + b.SetAllowDuplicates(true); + auto start2 = b.Offset(); + std::vector fields2; + fields2.push_back(b.NextField(start2, "a")); + ASSERT_OK(b.Int(1)); + fields2.push_back(b.NextField(start2, "a")); + ASSERT_OK(b.Int(2)); + ASSERT_OK(b.FinishObject(start2, fields2)); // should not error + + // Switch back to strict + b.SetAllowDuplicates(false); + auto start3 = b.Offset(); + std::vector fields3; + fields3.push_back(b.NextField(start3, "b")); + ASSERT_OK(b.Int(1)); + fields3.push_back(b.NextField(start3, "b")); + ASSERT_OK(b.Int(2)); + ASSERT_RAISES(Invalid, b.FinishObject(start3, fields3)); // should error again +} + } // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/extension/variant_internal.h b/cpp/src/arrow/extension/variant_internal.h index 6a61085069e5..816c0d7cc9a3 100644 --- a/cpp/src/arrow/extension/variant_internal.h +++ b/cpp/src/arrow/extension/variant_internal.h @@ -453,14 +453,44 @@ class ARROW_EXPORT VariantBuilder { void Reset(); /// @} - // TODO GH-45948: Add BuildWithoutMeta() to return raw value bytes without - // metadata, needed for shredded variant encoding. + /// @name Shredding support (GH-45948) + /// @{ - // TODO GH-45948: Add UnsafeAppendEncoded(const uint8_t* data, int64_t size) - // to append pre-encoded variant value bytes for composition/shredding. + /// \brief Produce only the value bytes, without metadata. + /// + /// Used during variant shredding when the metadata is shared at the + /// top-level Variant group. The caller must ensure the metadata + /// dictionary used during construction matches the shared metadata column. + /// + /// \return The raw value buffer (metadata must be obtained separately + /// via Finish() on a separate call, or from the shared column). + Result> BuildWithoutMeta(); + + /// \brief Append pre-encoded variant value bytes directly. + /// + /// Copies raw variant-encoded bytes into the builder's value buffer + /// without any validation or re-encoding. This is used during + /// reconstruction to efficiently copy residual fields from the + /// value column. + /// + /// SAFETY: The caller MUST guarantee that: + /// 1. data points to a valid variant value encoding + /// 2. All field IDs in the encoded data reference keys already + /// present in this builder's metadata dictionary + /// + /// \param[in] data Pointer to pre-encoded variant value bytes + /// \param[in] size Number of bytes to append + void UnsafeAppendEncoded(const uint8_t* data, int64_t size); - // TODO GH-45948: Add SetAllowDuplicates(bool) for duplicate key tolerance - // with last-value-wins semantics (uses ValueSize for compaction). + /// \brief Configure duplicate key handling in FinishObject(). + /// + /// When enabled, duplicate keys are resolved by keeping the last + /// value inserted (last-value-wins / compaction semantics). + /// When disabled (default), duplicate keys produce Status::Invalid. + /// + /// \param[in] allow If true, tolerate duplicate keys + void SetAllowDuplicates(bool allow); + /// @} private: /// \brief Add a key to the dictionary, returning its ID. @@ -474,6 +504,7 @@ class ARROW_EXPORT VariantBuilder { // column-scan workloads where the same field names repeat across rows). // Only used inside AddKey(); not part of the builder's logical state. std::string lookup_buf_; + bool allow_duplicates_ = false; }; } // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/extension/variant_internal_test.cc b/cpp/src/arrow/extension/variant_internal_test.cc index 9cfdb665aa82..33c63021c894 100644 --- a/cpp/src/arrow/extension/variant_internal_test.cc +++ b/cpp/src/arrow/extension/variant_internal_test.cc @@ -316,11 +316,10 @@ TEST_F(VariantMetadataTest, NonMonotonicStringOffsets) { // Header: version=1, offset_size=1 // dict_size=2, offsets=[0, 5, 3] — 3 < 5, non-monotonic // String data: "helloabc" (8 bytes, but offsets claim 3 as last) - uint8_t data[] = { - 0x01, // header: version=1, offset_size=1 - 0x02, // dict_size = 2 - 0x00, 0x05, 0x03, // offsets: [0, 5, 3] — non-monotonic - 'h', 'e', 'l', 'l', 'o', 'a', 'b', 'c'}; + uint8_t data[] = {0x01, // header: version=1, offset_size=1 + 0x02, // dict_size = 2 + 0x00, 0x05, 0x03, // offsets: [0, 5, 3] — non-monotonic + 'h', 'e', 'l', 'l', 'o', 'a', 'b', 'c'}; ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); } @@ -1857,7 +1856,9 @@ TEST_F(VariantArrayNonMonotonicTest, RejectsNonMonotonicOffsets) { uint8_t data[] = { 0x03, // array header: basic_type=3, offset_size=1, is_large=false 0x02, // num_elements = 2 - 0x00, 0x03, 0x01, // offsets: [0, 3, 1] — non-monotonic! + 0x00, + 0x03, + 0x01, // offsets: [0, 3, 1] — non-monotonic! PrimitiveHeader(PrimitiveType::kNull), PrimitiveHeader(PrimitiveType::kNull), PrimitiveHeader(PrimitiveType::kNull), @@ -1894,13 +1895,13 @@ TEST_F(VariantObjectOffsetBoundsTest, FieldOffsetExceedsDataSize) { 0x02, // object header 0x01, // num_fields = 1 0x00, // field_id[0] = 0 - 0x63, 0x02, // offsets: [99, 2] — 99 > total_data_size(2) + 0x63, + 0x02, // offsets: [99, 2] — 99 > total_data_size(2) PrimitiveHeader(PrimitiveType::kNull), PrimitiveHeader(PrimitiveType::kNull), }; RecordingVisitor visitor; - ASSERT_RAISES(Invalid, - DecodeVariantValue(metadata_, data, sizeof(data), &visitor)); + ASSERT_RAISES(Invalid, DecodeVariantValue(metadata_, data, sizeof(data), &visitor)); } // =========================================================================== @@ -2095,11 +2096,10 @@ TEST_F(VariantErrorCaseTest, MetadataStringOffsetExceedsBuffer) { // data_length. // Header: version=1, offset_size=1 // dict_size=1, offsets=[0, 100] — but only 3 bytes of string data - uint8_t data[] = { - 0x01, // header: version=1, offset_size=1 - 0x01, // dict_size = 1 - 0x00, 0x64, // offsets: [0, 100] — 100 exceeds available string data - 'a', 'b', 'c'}; + uint8_t data[] = {0x01, // header: version=1, offset_size=1 + 0x01, // dict_size = 1 + 0x00, 0x64, // offsets: [0, 100] — 100 exceeds available string data + 'a', 'b', 'c'}; ASSERT_RAISES(Invalid, DecodeMetadata(data, sizeof(data))); } @@ -2115,10 +2115,9 @@ TEST_F(VariantErrorCaseTest, FindObjectFieldOnNonObject) { // Calling FindObjectField on an array should produce an error auto data = BuildArray({}); int64_t field_offset = -1, field_size = 0; - ASSERT_RAISES(Invalid, - FindObjectField(empty_metadata_, data.data(), - static_cast(data.size()), "key", - &field_offset, &field_size)); + ASSERT_RAISES(Invalid, FindObjectField(empty_metadata_, data.data(), + static_cast(data.size()), "key", + &field_offset, &field_size)); } // TODO: Add fuzz targets for DecodeMetadata and DecodeVariantValue to exercise diff --git a/cpp/src/arrow/extension/variant_shredding.cc b/cpp/src/arrow/extension/variant_shredding.cc new file mode 100644 index 000000000000..f6476e964c3f --- /dev/null +++ b/cpp/src/arrow/extension/variant_shredding.cc @@ -0,0 +1,2116 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/variant_shredding.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/array/array_binary.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/array/builder_binary.h" +#include "arrow/array/builder_decimal.h" +#include "arrow/array/builder_nested.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/buffer.h" +#include "arrow/extension/variant_internal.h" +#include "arrow/memory_pool.h" +#include "arrow/type.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/decimal.h" +#include "arrow/util/endian.h" +#include "arrow/util/logging_internal.h" + +namespace arrow::extension::variant_internal { + +// Forward declaration +static Result> ShredVariantColumnObject( + const std::shared_ptr& metadata_array, + const std::shared_ptr& value_array, const VariantShreddingSchema& schema); + +static Result> ShredVariantColumnArray( + const std::shared_ptr& metadata_array, + const std::shared_ptr& value_array, const VariantShreddingSchema& schema); + +static Result> ReconstructVariantColumnObject( + const std::shared_ptr& metadata_array, + const std::shared_ptr& value_array, + const std::shared_ptr& typed_value_array, + const VariantShreddingSchema& schema); + +static Result> ReconstructVariantColumnArray( + const std::shared_ptr& metadata_array, + const std::shared_ptr& value_array, + const std::shared_ptr& typed_value_array, + const VariantShreddingSchema& schema); + +// --------------------------------------------------------------------------- +// VariantShreddingSchema implementation +// --------------------------------------------------------------------------- + +VariantShreddingSchema VariantShreddingSchema::Primitive(std::shared_ptr type) { + VariantShreddingSchema schema; + schema.kind_ = Kind::kPrimitive; + schema.type_ = std::move(type); + return schema; +} + +VariantShreddingSchema VariantShreddingSchema::Object( + std::vector> fields) { + VariantShreddingSchema schema; + schema.kind_ = Kind::kObject; + schema.fields_ = std::move(fields); + return schema; +} + +VariantShreddingSchema VariantShreddingSchema::Array( + VariantShreddingSchema element_schema) { + VariantShreddingSchema schema; + schema.kind_ = Kind::kArray; + schema.element_schema_ = + std::make_shared(std::move(element_schema)); + return schema; +} + +std::shared_ptr VariantShreddingSchema::ToArrowType() const { + switch (kind_) { + case Kind::kPrimitive: + return type_; + + case Kind::kObject: { + std::vector> arrow_fields; + arrow_fields.reserve(fields_.size()); + for (const auto& [name, sub_schema] : fields_) { + auto typed_value_type = sub_schema.ToArrowType(); + auto field_struct = struct_({ + field("value", binary(), /*nullable=*/true), + field("typed_value", typed_value_type, /*nullable=*/true), + }); + arrow_fields.push_back(field(name, field_struct, /*nullable=*/false)); + } + return struct_(std::move(arrow_fields)); + } + + case Kind::kArray: { + auto elem_typed_type = element_schema_->ToArrowType(); + auto element_struct = struct_({ + field("value", binary(), /*nullable=*/true), + field("typed_value", elem_typed_type, /*nullable=*/true), + }); + return list(field("element", element_struct, /*nullable=*/false)); + } + } + // All enum values are handled above; this is unreachable. + DCHECK(false) << "Unknown VariantShreddingSchema kind"; + return nullptr; +} + +// --------------------------------------------------------------------------- +// Type compatibility check +// --------------------------------------------------------------------------- + +bool IsVariantCompatibleWithType(const uint8_t* variant_data, int64_t variant_length, + const DataType& target_type) { + if (variant_length < 1) return false; + + uint8_t header = variant_data[0]; + auto basic_type = GetBasicType(header); + + if (basic_type == BasicType::kShortString) { + // Short strings are semantically UTF-8 text (same as kString), not binary data. + // They should NOT match BINARY/LARGE_BINARY targets — only string targets. + // Rust's shredding makes the same distinction: strings → Utf8/LargeUtf8/Utf8View, + // binary → Binary/LargeBinary/BinaryView. + return target_type.id() == Type::STRING || target_type.id() == Type::LARGE_STRING || + target_type.id() == Type::STRING_VIEW; + } + + if (basic_type == BasicType::kObject || basic_type == BasicType::kArray) { + return false; + } + + if (basic_type == BasicType::kPrimitive) { + auto prim_type = static_cast((header >> 2) & 0x3F); + switch (prim_type) { + case PrimitiveType::kNull: + // Per Rust/spec semantics: Variant::Null is NOT shredded into typed + // columns. It is stored as-is in the value column. This distinguishes + // "variant-typed null" (value = 0x00 byte) from "SQL NULL / missing" + // (both value and typed_value are null). + return false; + case PrimitiveType::kTrue: + case PrimitiveType::kFalse: + return target_type.id() == Type::BOOL; + case PrimitiveType::kInt8: + return target_type.id() == Type::INT8 || target_type.id() == Type::INT16 || + target_type.id() == Type::INT32 || target_type.id() == Type::INT64; + case PrimitiveType::kInt16: + return target_type.id() == Type::INT16 || target_type.id() == Type::INT32 || + target_type.id() == Type::INT64; + case PrimitiveType::kInt32: + return target_type.id() == Type::INT32 || target_type.id() == Type::INT64; + case PrimitiveType::kInt64: + return target_type.id() == Type::INT64; + case PrimitiveType::kFloat: + // Note: Float→Double widening means shred(Float)→reconstruct produces Double. + // This is a lossy round-trip for the type tag (value precision is preserved). + return target_type.id() == Type::FLOAT || target_type.id() == Type::DOUBLE; + case PrimitiveType::kDouble: + return target_type.id() == Type::DOUBLE; + case PrimitiveType::kDecimal4: + case PrimitiveType::kDecimal8: + case PrimitiveType::kDecimal16: { + if (target_type.id() != Type::DECIMAL128 && + target_type.id() != Type::DECIMAL256) { + return false; + } + // Verify scale matches: read scale byte from variant data (byte after header). + int32_t min_size = (prim_type == PrimitiveType::kDecimal4) ? 6 + : (prim_type == PrimitiveType::kDecimal8) ? 10 + : (prim_type == PrimitiveType::kDecimal16) ? 18 + : 0; + if (variant_length < min_size) return false; + auto variant_scale = static_cast(variant_data[1]); + if (target_type.id() == Type::DECIMAL128) { + return variant_scale == static_cast(target_type).scale(); + } + return true; // DECIMAL256 — accept any scale for now + // TODO: This is asymmetric with DECIMAL128 which validates scale. + // Consider adding scale matching for DECIMAL256 once usage is + // established. Currently pragmatic since DECIMAL256 is rarely used. + } + case PrimitiveType::kDate: + return target_type.id() == Type::DATE32; + case PrimitiveType::kTimestampMicros: { + if (target_type.id() != Type::TIMESTAMP) return false; + auto& ts = static_cast(target_type); + return ts.unit() == TimeUnit::MICRO && !ts.timezone().empty(); + } + case PrimitiveType::kTimestampMicrosNTZ: { + if (target_type.id() != Type::TIMESTAMP) return false; + auto& ts = static_cast(target_type); + return ts.unit() == TimeUnit::MICRO && ts.timezone().empty(); + } + case PrimitiveType::kTimestampNanos: { + if (target_type.id() != Type::TIMESTAMP) return false; + auto& ts = static_cast(target_type); + return ts.unit() == TimeUnit::NANO && !ts.timezone().empty(); + } + case PrimitiveType::kTimestampNanosNTZ: { + if (target_type.id() != Type::TIMESTAMP) return false; + auto& ts = static_cast(target_type); + return ts.unit() == TimeUnit::NANO && ts.timezone().empty(); + } + case PrimitiveType::kTimeNTZ: { + if (target_type.id() != Type::TIME64) return false; + // The variant spec's kTimeNTZ stores microseconds since midnight. + // Only accept time64(MICRO) targets — a time64(NANO) target would + // cause misinterpretation of the shredded values in typed_value. + auto& t = static_cast(target_type); + return t.unit() == TimeUnit::MICRO; + } + case PrimitiveType::kString: + return target_type.id() == Type::STRING || + target_type.id() == Type::LARGE_STRING || + target_type.id() == Type::STRING_VIEW; + case PrimitiveType::kBinary: + return target_type.id() == Type::BINARY || + target_type.id() == Type::LARGE_BINARY || + target_type.id() == Type::BINARY_VIEW; + case PrimitiveType::kUUID: + return target_type.id() == Type::FIXED_SIZE_BINARY && + static_cast(target_type).byte_width() == 16; + default: + // Unknown or future primitive types are not compatible with any target. + return false; + } + } + + return false; +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +namespace { + +// Supports BINARY, LARGE_BINARY, BINARY_VIEW, and STRING_VIEW arrays. +// Rust supports all of these via GenericByteViewArray / GenericByteArray. +std::string_view GetBinaryValue(const Array& array, int64_t i) { + if (array.type_id() == Type::BINARY) { + auto& bin = static_cast(array); + auto view = bin.GetView(i); + return {reinterpret_cast(view.data()), static_cast(view.size())}; + } + if (array.type_id() == Type::LARGE_BINARY) { + auto& bin = static_cast(array); + auto view = bin.GetView(i); + return {reinterpret_cast(view.data()), static_cast(view.size())}; + } + if (array.type_id() == Type::BINARY_VIEW || array.type_id() == Type::STRING_VIEW) { + auto& bin = static_cast(array); + return bin.GetView(i); + } + // Callers validate input types at public entry points (ShredVariantColumn, + // ReconstructVariantColumn). Reaching here indicates a programming error. + DCHECK(false) << "GetBinaryValue: unsupported array type " << array.type()->ToString(); + return {}; +} + +/// Read a little-endian unsigned integer of nbytes (1-8) from buf. +/// The result is zero-extended to int64_t. Callers that need signed values +/// narrow-cast the result (e.g., static_cast), which reinterprets +/// the low bits as a signed two's-complement value — correct because the +/// shift-based reconstruction isolates exactly the source width. +/// +/// Note: This differs from variant_internal.cc's ReadUnsignedLE() which +/// returns uint32_t (max 4 bytes). This version returns int64_t to support +/// 8-byte reads for timestamp/int64 extraction. +/// +/// This implementation is endian-safe: it reconstructs the value byte-by-byte +/// using shifts, which produces correct results on both little-endian and +/// big-endian architectures (Arrow CI includes s390x big-endian targets). +int64_t ReadLE(const uint8_t* buf, int32_t nbytes) { + uint64_t result = 0; + for (int32_t i = 0; i < nbytes; ++i) { + result |= static_cast(buf[i]) << (i * 8); + } + return static_cast(result); +} + +/// Extract a native int64 value from variant-encoded bytes (handles all int sizes). +/// Returns true if extraction succeeded, false if type is not an integer. +bool ExtractInt64(const uint8_t* data, int64_t length, int64_t* out) { + if (length < 1) return false; + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kPrimitive) return false; + auto prim = static_cast((header >> 2) & 0x3F); + switch (prim) { + case PrimitiveType::kInt8: + if (length < 2) return false; + *out = static_cast(static_cast(data[1])); + return true; + case PrimitiveType::kInt16: + if (length < 3) return false; + // ReadLE returns zero-extended int64_t; narrow to int16_t for sign-extension, + // then widen back to int64_t explicitly. + *out = static_cast(static_cast(ReadLE(data + 1, 2))); + return true; + case PrimitiveType::kInt32: + if (length < 5) return false; + // ReadLE returns zero-extended int64_t; narrow to int32_t for sign-extension, + // then widen back to int64_t explicitly. + *out = static_cast(static_cast(ReadLE(data + 1, 4))); + return true; + case PrimitiveType::kInt64: + if (length < 9) return false; + *out = ReadLE(data + 1, 8); + return true; + default: + return false; + } +} + +/// Extract a boolean from variant-encoded bytes. +bool ExtractBool(const uint8_t* data, int64_t length, bool* out) { + if (length < 1) return false; + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kPrimitive) return false; + auto prim = static_cast((header >> 2) & 0x3F); + if (prim == PrimitiveType::kTrue) { + *out = true; + return true; + } + if (prim == PrimitiveType::kFalse) { + *out = false; + return true; + } + return false; +} + +/// Extract a double from variant-encoded bytes (handles float→double widening and +/// native double). Named "ExtractDoubleOrFloat" to clarify that it accepts both +/// kFloat (widened to double) and kDouble variant types. +bool ExtractDoubleOrFloat(const uint8_t* data, int64_t length, double* out) { + if (length < 1) return false; + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kPrimitive) return false; + auto prim = static_cast((header >> 2) & 0x3F); + if (prim == PrimitiveType::kDouble && length >= 9) { + uint64_t bits; + std::memcpy(&bits, data + 1, 8); + bits = bit_util::FromLittleEndian(bits); + std::memcpy(out, &bits, 8); + return true; + } + if (prim == PrimitiveType::kFloat && length >= 5) { + uint32_t bits; + std::memcpy(&bits, data + 1, 4); + bits = bit_util::FromLittleEndian(bits); + float f; + std::memcpy(&f, &bits, 4); + *out = static_cast(f); + return true; + } + return false; +} + +/// Extract a string_view from variant-encoded bytes (handles short and long strings). +bool ExtractString(const uint8_t* data, int64_t length, std::string_view* out) { + if (length < 1) return false; + uint8_t header = data[0]; + auto basic_type = GetBasicType(header); + if (basic_type == BasicType::kShortString) { + int32_t str_len = (header >> 2) & 0x3F; + if (length < 1 + str_len) return false; + *out = std::string_view(reinterpret_cast(data + 1), str_len); + return true; + } + if (basic_type == BasicType::kPrimitive) { + auto prim = static_cast((header >> 2) & 0x3F); + if (prim == PrimitiveType::kString && length >= 5) { + uint32_t str_len; + std::memcpy(&str_len, data + 1, 4); + str_len = bit_util::FromLittleEndian(str_len); + if (length < 5 + static_cast(str_len)) return false; + *out = std::string_view(reinterpret_cast(data + 5), str_len); + return true; + } + } + return false; +} + +/// Check if the variant value is Variant Null. +bool IsVariantNull(const uint8_t* data, int64_t length) { + if (length < 1) return false; + uint8_t header = data[0]; + return GetBasicType(header) == BasicType::kPrimitive && + static_cast((header >> 2) & 0x3F) == PrimitiveType::kNull; +} + +/// Extract a float from variant-encoded bytes. +bool ExtractFloat(const uint8_t* data, int64_t length, float* out) { + if (length < 1) return false; + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kPrimitive) return false; + auto prim = static_cast((header >> 2) & 0x3F); + if (prim == PrimitiveType::kFloat && length >= 5) { + uint32_t bits; + std::memcpy(&bits, data + 1, 4); + bits = bit_util::FromLittleEndian(bits); + std::memcpy(out, &bits, 4); + return true; + } + return false; +} + +/// Extract a date32 (days since epoch) from variant-encoded bytes. +bool ExtractDate32(const uint8_t* data, int64_t length, int32_t* out) { + if (length < 1) return false; + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kPrimitive) return false; + auto prim = static_cast((header >> 2) & 0x3F); + if (prim == PrimitiveType::kDate && length >= 5) { + *out = static_cast(ReadLE(data + 1, 4)); + return true; + } + return false; +} + +/// Extract a timestamp (micros or nanos) from variant-encoded bytes. +bool ExtractTimestamp(const uint8_t* data, int64_t length, int64_t* out) { + if (length < 1) return false; + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kPrimitive) return false; + auto prim = static_cast((header >> 2) & 0x3F); + if ((prim == PrimitiveType::kTimestampMicros || + prim == PrimitiveType::kTimestampMicrosNTZ || + prim == PrimitiveType::kTimestampNanos || + prim == PrimitiveType::kTimestampNanosNTZ) && + length >= 9) { + *out = static_cast(ReadLE(data + 1, 8)); + return true; + } + return false; +} + +/// Extract binary data from variant-encoded bytes. +bool ExtractBinary(const uint8_t* data, int64_t length, const uint8_t** out_data, + int32_t* out_size) { + if (length < 1) return false; + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kPrimitive) return false; + auto prim = static_cast((header >> 2) & 0x3F); + if (prim == PrimitiveType::kBinary && length >= 5) { + uint32_t bin_len; + std::memcpy(&bin_len, data + 1, 4); + bin_len = bit_util::FromLittleEndian(bin_len); + if (length < 5 + static_cast(bin_len)) return false; + *out_data = data + 5; + *out_size = static_cast(bin_len); + return true; + } + return false; +} + +/// Extract an int32 from variant-encoded bytes (handles int8, int16, int32). +bool ExtractInt32(const uint8_t* data, int64_t length, int32_t* out) { + if (length < 1) return false; + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kPrimitive) return false; + auto prim = static_cast((header >> 2) & 0x3F); + switch (prim) { + case PrimitiveType::kInt8: + if (length < 2) return false; + *out = static_cast(static_cast(data[1])); + return true; + case PrimitiveType::kInt16: + if (length < 3) return false; + *out = static_cast(static_cast(ReadLE(data + 1, 2))); + return true; + case PrimitiveType::kInt32: + if (length < 5) return false; + *out = static_cast(ReadLE(data + 1, 4)); + return true; + default: + return false; + } +} + +/// Extract an int8 from variant-encoded bytes (handles only int8). +bool ExtractInt8(const uint8_t* data, int64_t length, int8_t* out) { + if (length < 1) return false; + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kPrimitive) return false; + auto prim = static_cast((header >> 2) & 0x3F); + if (prim == PrimitiveType::kInt8 && length >= 2) { + *out = static_cast(data[1]); + return true; + } + return false; +} + +/// Extract an int16 from variant-encoded bytes (handles int8, int16). +bool ExtractInt16(const uint8_t* data, int64_t length, int16_t* out) { + if (length < 1) return false; + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kPrimitive) return false; + auto prim = static_cast((header >> 2) & 0x3F); + switch (prim) { + case PrimitiveType::kInt8: + if (length < 2) return false; + *out = static_cast(static_cast(data[1])); + return true; + case PrimitiveType::kInt16: + if (length < 3) return false; + *out = static_cast(ReadLE(data + 1, 2)); + return true; + default: + return false; + } +} + +/// Extract time64 (microseconds since midnight) from variant-encoded bytes. +bool ExtractTime64(const uint8_t* data, int64_t length, int64_t* out) { + if (length < 1) return false; + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kPrimitive) return false; + auto prim = static_cast((header >> 2) & 0x3F); + if (prim == PrimitiveType::kTimeNTZ && length >= 9) { + *out = static_cast(ReadLE(data + 1, 8)); + return true; + } + return false; +} + +/// Extract UUID (16 big-endian bytes) from variant-encoded bytes. +bool ExtractUUID(const uint8_t* data, int64_t length, uint8_t* out16) { + if (length < 1) return false; + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kPrimitive) return false; + auto prim = static_cast((header >> 2) & 0x3F); + if (prim == PrimitiveType::kUUID && length >= 17) { + std::memcpy(out16, data + 1, 16); + return true; + } + return false; +} + +/// Extract a Decimal128 (scale + unscaled value) from variant-encoded bytes. +/// Handles decimal4 (4-byte), decimal8 (8-byte), and decimal16 (16-byte). +/// The output is a 128-bit unscaled integer value suitable for Arrow's Decimal128Type. +bool ExtractDecimal128(const uint8_t* data, int64_t length, int64_t* out_low, + int64_t* out_high, uint8_t* out_scale) { + if (length < 1) return false; + uint8_t header = data[0]; + if (GetBasicType(header) != BasicType::kPrimitive) return false; + auto prim = static_cast((header >> 2) & 0x3F); + switch (prim) { + case PrimitiveType::kDecimal4: + if (length < 6) return false; // 1 header + 1 scale + 4 value + *out_scale = data[1]; + { + int32_t val; + // memcpy + FromLittleEndian is safe here: we copy the full variable width + // (4 bytes into int32), then swap. This differs from ReadLE's byte-shift + // approach which handles partial-width reads. + std::memcpy(&val, data + 2, 4); + val = bit_util::FromLittleEndian(val); + *out_low = static_cast(val); + *out_high = (val < 0) ? -1 : 0; // sign-extend + } + return true; + case PrimitiveType::kDecimal8: + if (length < 10) return false; // 1 header + 1 scale + 8 value + *out_scale = data[1]; + { + int64_t val; + // memcpy + FromLittleEndian is safe here: we copy the full variable width + // (8 bytes into int64), then swap. This differs from ReadLE's byte-shift + // approach which handles partial-width reads. + std::memcpy(&val, data + 2, 8); + val = bit_util::FromLittleEndian(val); + *out_low = val; + *out_high = (val < 0) ? -1 : 0; // sign-extend + } + return true; + case PrimitiveType::kDecimal16: + if (length < 18) return false; // 1 header + 1 scale + 16 value + *out_scale = data[1]; + { + std::memcpy(out_low, data + 2, 8); + std::memcpy(out_high, data + 10, 8); + *out_low = bit_util::FromLittleEndian(*out_low); + *out_high = bit_util::FromLittleEndian(*out_high); + } + return true; + default: + return false; + } +} + +} // namespace + +// --------------------------------------------------------------------------- +// ShredVariantColumnObject — object field-level shredding +// --------------------------------------------------------------------------- + +namespace { + +/// Shred a single object variant value, routing named fields to sub-builders. +/// Produces: for each schema field, the field value (or null if missing). +/// Residual = object with remaining fields not in schema. +/// +/// For fields with Primitive sub-schemas, the output construction phase +/// (in ShredVariantColumnObject) performs recursive native extraction by +/// calling ShredVariantColumn on the per-field BinaryArray. This matches +/// Rust's VariantToShreddedObjectVariantRowBuilder behavior: compatible +/// field values go to typed_value, incompatible remain in value. +/// +/// For fields with Object or Array sub-schemas, field values are stored +/// as variant binary in the "value" sub-column (recursive nested shredding +/// is a potential follow-up optimization). +// TODO GH-45948 follow-up: Recursive shredding for nested Object/Array +// sub-schemas (not just Primitive). +struct ObjectFieldShredder { + ObjectFieldShredder(const VariantShreddingSchema& schema, int64_t num_rows) + : schema(schema), num_rows(num_rows) {} + + const VariantShreddingSchema& schema; + int64_t num_rows; + + // One BinaryBuilder per schema field (stores variant bytes of matching fields). + // For Primitive sub-schemas, the output construction phase re-shreds these + // into {value, typed_value} using ShredVariantColumn. For other sub-schemas, + // these bytes go directly to the "value" sub-column. + std::vector field_value_builders; + + // Residual: object bytes for fields not in schema + BinaryBuilder residual_builder; + + Status Init() { + field_value_builders.resize(schema.fields().size()); + for (auto& b : field_value_builders) { + ARROW_RETURN_NOT_OK(b.Reserve(num_rows)); + } + ARROW_RETURN_NOT_OK(residual_builder.Reserve(num_rows)); + return Status::OK(); + } + + Status AppendNull() { + // Object is missing/null → all field builders get null, residual gets null + for (auto& b : field_value_builders) { + ARROW_RETURN_NOT_OK(b.AppendNull()); + } + ARROW_RETURN_NOT_OK(residual_builder.AppendNull()); + return Status::OK(); + } + + Status AppendNonObject(std::string_view variant_bytes) { + // Value is not an object → residual gets the whole value, fields get null + for (auto& b : field_value_builders) { + ARROW_RETURN_NOT_OK(b.AppendNull()); + } + ARROW_RETURN_NOT_OK(residual_builder.Append( + reinterpret_cast(variant_bytes.data()), variant_bytes.size())); + return Status::OK(); + } + + Status AppendObject(const VariantMetadata& meta, const uint8_t* obj_data, + int64_t obj_length) { + // Determine which fields from the schema are present in this object + const auto& schema_fields = schema.fields(); + + // Get field count + ARROW_ASSIGN_OR_RAISE(auto field_count, GetObjectFieldCount(obj_data, obj_length)); + + // Single pass over object fields: build a name→(index, offset, size) map. + // This eliminates the previous O(s × k) inner marking loop by allowing + // O(1) positional index lookup when marking shredded fields. + // PERF TODO: This allocates an unordered_map per row. For column-scan + // workloads with millions of rows, consider lifting the map to the + // ObjectFieldShredder struct and clearing/reusing it across rows. + // Duplicate keys (spec-invalid): last occurrence wins in map, earlier + // occurrences remain in residual. Matches last-value-wins semantics. + struct FieldInfo { + int32_t index; + int64_t offset; + int64_t size; + }; + std::unordered_map object_field_map; + object_field_map.reserve(field_count); + for (int32_t fi = 0; fi < field_count; ++fi) { + std::string_view fname; + int64_t foff = 0, fsz = 0; + ARROW_RETURN_NOT_OK( + GetObjectFieldAt(meta, obj_data, obj_length, fi, &fname, &foff, &fsz)); + object_field_map[fname] = FieldInfo{fi, foff, fsz}; + } + + // Track which object fields are shredded vs residual + std::vector is_shredded(field_count, false); + + // For each schema field, look it up in the pre-built map. + // Total complexity: O(k) for map construction + O(s) for lookups = O(s + k). + for (size_t sf = 0; sf < schema_fields.size(); ++sf) { + auto it = object_field_map.find(schema_fields[sf].first); + if (it != object_field_map.end()) { + const auto& info = it->second; + ARROW_RETURN_NOT_OK(field_value_builders[sf].Append( + obj_data + info.offset, static_cast(info.size))); + is_shredded[info.index] = true; + } else { + // Not found — null for this field + ARROW_RETURN_NOT_OK(field_value_builders[sf].AppendNull()); + } + } + + // Build residual object with non-shredded fields + bool has_residual = false; + for (int32_t fi = 0; fi < field_count; ++fi) { + if (!is_shredded[fi]) { + has_residual = true; + break; + } + } + + if (has_residual) { + // Build a residual variant object with non-shredded fields. + // PERF: GetObjectFieldAt is called again here for non-shredded fields even + // though object_field_map already stores their offset/size. This is O(1) per + // field (just header arithmetic) so acceptable, but could be avoided by + // iterating object_field_map entries where !is_shredded[info.index]. + VariantBuilder vb(meta); + auto start = vb.Offset(); + std::vector residual_fields; + for (int32_t fi = 0; fi < field_count; ++fi) { + if (!is_shredded[fi]) { + std::string_view fname; + int64_t foff = 0, fsz = 0; + ARROW_RETURN_NOT_OK( + GetObjectFieldAt(meta, obj_data, obj_length, fi, &fname, &foff, &fsz)); + residual_fields.push_back(vb.NextField(start, fname)); + vb.UnsafeAppendEncoded(obj_data + foff, fsz); + } + } + ARROW_RETURN_NOT_OK(vb.FinishObject(start, residual_fields)); + ARROW_ASSIGN_OR_RAISE(auto residual_bytes, vb.BuildWithoutMeta()); + ARROW_RETURN_NOT_OK(residual_builder.Append( + residual_bytes.data(), static_cast(residual_bytes.size()))); + } else { + // All fields were shredded → residual is null + ARROW_RETURN_NOT_OK(residual_builder.AppendNull()); + } + + return Status::OK(); + } +}; + +} // namespace + +Result> ShredVariantColumnObject( + const std::shared_ptr& metadata_array, + const std::shared_ptr& value_array, const VariantShreddingSchema& schema) { + const int64_t num_rows = value_array->length(); + const auto& schema_fields = schema.fields(); + + ObjectFieldShredder shredder(schema, num_rows); + ARROW_RETURN_NOT_OK(shredder.Init()); + + for (int64_t i = 0; i < num_rows; ++i) { + if (value_array->IsNull(i)) { + ARROW_RETURN_NOT_OK(shredder.AppendNull()); + continue; + } + + auto bytes = GetBinaryValue(*value_array, i); + auto* data = reinterpret_cast(bytes.data()); + auto len = static_cast(bytes.size()); + + if (len < 1) { + ARROW_RETURN_NOT_OK(shredder.AppendNull()); + continue; + } + + auto basic_type = GetBasicType(data[0]); + if (basic_type != BasicType::kObject) { + // Not an object → goes entirely to residual value + ARROW_RETURN_NOT_OK(shredder.AppendNonObject(bytes)); + continue; + } + + // Decode metadata for this row + auto meta_bytes = GetBinaryValue(*metadata_array, i); + ARROW_ASSIGN_OR_RAISE( + auto meta, DecodeMetadata(reinterpret_cast(meta_bytes.data()), + static_cast(meta_bytes.size()))); + + ARROW_RETURN_NOT_OK(shredder.AppendObject(meta, data, len)); + } + + // Build output: typed_value is a struct with one field per schema field. + // Each field is a struct {value: binary(nullable), typed_value: (nullable)}. + // For fields with Primitive sub-schemas, we perform native extraction: + // compatible values go to typed_value, incompatible remain in value. + // This matches Rust's recursive shredding behavior. + std::vector> typed_value_columns; + std::vector> typed_value_fields; + + for (size_t sf = 0; sf < schema_fields.size(); ++sf) { + std::shared_ptr field_arr; + ARROW_RETURN_NOT_OK(shredder.field_value_builders[sf].Finish(&field_arr)); + + const auto& sub_schema = schema_fields[sf].second; + + if (sub_schema.kind() == VariantShreddingSchema::Kind::kPrimitive) { + // Recursive native extraction: shred the field values through the + // primitive path. This produces a struct {metadata, value, typed_value} + // from which we extract just {value, typed_value} for the sub-field. + // We use the top-level metadata_array since field values reference + // the same metadata dictionary. + ARROW_ASSIGN_OR_RAISE(auto field_shredded, + ShredVariantColumn(metadata_array, field_arr, sub_schema)); + auto field_value_col = field_shredded->field(1); // "value" (nullable) + auto field_typed_col = field_shredded->field(2); // "typed_value" (nullable) + + // Determine typed_value field type — for TIMESTAMP/TIME64, the builder + // produces Int64Array, so the field declares int64(). + auto typed_field_type = sub_schema.type(); + if (typed_field_type->id() == Type::TIMESTAMP || + typed_field_type->id() == Type::TIME64) { + typed_field_type = int64(); + } + + auto inner_fields = std::vector>{ + field("value", binary(), true), + field("typed_value", typed_field_type, true), + }; + ARROW_ASSIGN_OR_RAISE( + auto field_struct, + StructArray::Make({field_value_col, field_typed_col}, inner_fields)); + typed_value_columns.push_back(field_struct); + typed_value_fields.push_back( + field(schema_fields[sf].first, field_struct->type(), false)); + } else { + // Non-primitive sub-schemas (Object, Array): store field values as + // variant binary in the "value" sub-column. Recursive shredding of + // nested objects/arrays is a potential follow-up optimization. + // + // NOTE: We use NullArray here for the typed_value sub-column. The field + // metadata declares the logical type (from sub_schema.ToArrowType()), + // but the actual array is type null(). This is semantically acceptable + // because the field is always null (no typed extraction for non-primitive + // sub-schemas), so no consumer will attempt to access typed data. The + // declared field type serves only as schema documentation for readers + // inspecting the shredded output structure. + auto null_arr = std::make_shared(num_rows); + auto inner_fields = std::vector>{ + field("value", binary(), true), + field("typed_value", sub_schema.ToArrowType(), true), + }; + ARROW_ASSIGN_OR_RAISE(auto field_struct, + StructArray::Make({field_arr, null_arr}, inner_fields)); + typed_value_columns.push_back(field_struct); + typed_value_fields.push_back( + field(schema_fields[sf].first, field_struct->type(), false)); + } + } + + std::shared_ptr typed_value_struct; + if (!typed_value_fields.empty()) { + ARROW_ASSIGN_OR_RAISE(typed_value_struct, + StructArray::Make(typed_value_columns, typed_value_fields)); + } else { + typed_value_struct = std::make_shared(num_rows); + } + + std::shared_ptr residual_result; + ARROW_RETURN_NOT_OK(shredder.residual_builder.Finish(&residual_result)); + + auto output_fields = std::vector>{ + field("metadata", metadata_array->type(), false), + field("value", binary(), true), + field("typed_value", typed_value_struct->type(), true), + }; + + return StructArray::Make({metadata_array, residual_result, typed_value_struct}, + output_fields); +} + +// --------------------------------------------------------------------------- +// ShredVariantColumnArray — array element-wise shredding +// --------------------------------------------------------------------------- + +Result> ShredVariantColumnArray( + const std::shared_ptr& metadata_array, + const std::shared_ptr& value_array, const VariantShreddingSchema& schema) { + const int64_t num_rows = value_array->length(); + + // For array shredding, typed_value is a ListArray where each element + // is a struct {value: binary, typed_value: }. + // Per the spec: if the variant is an array → typed_value is the list, value is null. + // if the variant is NOT an array → typed_value is null, value has the + // bytes. + // + // Rust parity: elements are recursively shredded according to the element schema, + // producing native typed columns for compatible elements. Incompatible elements + // remain in the per-element value column as variant binary. This enables + // statistics-based predicate pushdown on array element values. + + BinaryBuilder residual_value_builder; + ARROW_RETURN_NOT_OK(residual_value_builder.Reserve(num_rows)); + + // Phase 1: Extract array element bytes into a flat BinaryArray. + // Track list offsets and validity manually (rather than using ListBuilder) + // to avoid double-finish issues with the internal value builder. + BinaryBuilder elem_value_builder; + BinaryBuilder elem_metadata_builder; + std::vector list_offsets; + list_offsets.reserve(num_rows + 1); + list_offsets.push_back(0); + std::vector list_validity(num_rows, false); + + for (int64_t i = 0; i < num_rows; ++i) { + if (value_array->IsNull(i)) { + ARROW_RETURN_NOT_OK(residual_value_builder.AppendNull()); + list_offsets.push_back(list_offsets.back()); + // list_validity[i] remains false (null) + continue; + } + + auto bytes = GetBinaryValue(*value_array, i); + auto* data = reinterpret_cast(bytes.data()); + auto len = static_cast(bytes.size()); + + if (len < 1 || GetBasicType(data[0]) != BasicType::kArray) { + // Not an array → residual value, typed_value null + ARROW_RETURN_NOT_OK(residual_value_builder.Append( + reinterpret_cast(bytes.data()), bytes.size())); + list_offsets.push_back(list_offsets.back()); + // list_validity[i] remains false (null) + continue; + } + + // Is an array → extract elements + ARROW_RETURN_NOT_OK(residual_value_builder.AppendNull()); + list_validity[i] = true; + + // Get the row's metadata for element replication + auto meta_bytes = GetBinaryValue(*metadata_array, i); + + ARROW_ASSIGN_OR_RAISE(auto elem_count, GetArrayElementCount(data, len)); + for (int32_t ei = 0; ei < elem_count; ++ei) { + int64_t elem_offset = 0, elem_size = 0; + ARROW_RETURN_NOT_OK(GetArrayElement(data, len, ei, &elem_offset, &elem_size)); + ARROW_RETURN_NOT_OK( + elem_value_builder.Append(data + elem_offset, static_cast(elem_size))); + ARROW_RETURN_NOT_OK(elem_metadata_builder.Append( + reinterpret_cast(meta_bytes.data()), meta_bytes.size())); + } + list_offsets.push_back(static_cast(list_offsets.back() + elem_count)); + } + + // Phase 2: Recursively shred the flattened element array through the element schema. + std::shared_ptr elem_values_arr; + ARROW_RETURN_NOT_OK(elem_value_builder.Finish(&elem_values_arr)); + + std::shared_ptr elem_metadata_arr; + ARROW_RETURN_NOT_OK(elem_metadata_builder.Finish(&elem_metadata_arr)); + + ARROW_ASSIGN_OR_RAISE( + auto elem_shredded, + ShredVariantColumn(elem_metadata_arr, elem_values_arr, schema.element_schema())); + + // Extract the {value, typed_value} columns from the recursively shredded result. + auto elem_value_col = elem_shredded->field(1); // per-element residual value + auto elem_typed_col = elem_shredded->field(2); // per-element typed_value + + // Determine typed_value field type for the element struct. + auto elem_typed_field_type = schema.element_schema().type(); + if (schema.element_schema().kind() == VariantShreddingSchema::Kind::kPrimitive) { + if (elem_typed_field_type && (elem_typed_field_type->id() == Type::TIMESTAMP || + elem_typed_field_type->id() == Type::TIME64)) { + elem_typed_field_type = int64(); + } + } else { + // For Object/Array element schemas, use the shredded output's actual type + elem_typed_field_type = elem_typed_col->type(); + } + + // Build the element struct array: {value: binary, typed_value: } + auto elem_struct_fields = std::vector>{ + field("value", binary(), true), + field("typed_value", elem_typed_field_type, true), + }; + ARROW_ASSIGN_OR_RAISE( + auto elem_struct_arr, + StructArray::Make({elem_value_col, elem_typed_col}, elem_struct_fields)); + + // Phase 3: Build the ListArray from manually tracked offsets and the element struct. + auto offsets_buf = Buffer::FromVector(std::move(list_offsets)); + + // Build null bitmap for the list + int64_t null_count = 0; + std::shared_ptr null_bitmap; + for (int64_t i = 0; i < num_rows; ++i) { + if (!list_validity[i]) ++null_count; + } + if (null_count > 0) { + ARROW_ASSIGN_OR_RAISE(null_bitmap, AllocateBitmap(num_rows)); + for (int64_t i = 0; i < num_rows; ++i) { + if (list_validity[i]) { + bit_util::SetBit(null_bitmap->mutable_data(), i); + } else { + bit_util::ClearBit(null_bitmap->mutable_data(), i); + } + } + } + + auto typed_value_list = std::make_shared( + list(field("element", elem_struct_arr->type(), false)), num_rows, offsets_buf, + elem_struct_arr, null_bitmap, null_count); + + std::shared_ptr residual_result; + ARROW_RETURN_NOT_OK(residual_value_builder.Finish(&residual_result)); + + auto output_fields = std::vector>{ + field("metadata", metadata_array->type(), false), + field("value", binary(), true), + field("typed_value", typed_value_list->type(), true), + }; + + return StructArray::Make({metadata_array, residual_result, + std::static_pointer_cast(typed_value_list)}, + output_fields); +} + +// --------------------------------------------------------------------------- +// Template helpers for primitive shredding +// --------------------------------------------------------------------------- + +namespace { + +/// Generic primitive shredding loop. For each row: +/// - If input is null → both builders get null +/// - If Variant::Null → value column gets bytes, typed gets null +/// - If extraction succeeds → typed gets value, residual gets null +/// - Otherwise → value column gets bytes, typed gets null +/// +/// This eliminates ~360 lines of per-type copy-paste that previously existed +/// as individual switch cases with identical structure. +template +Status ShredPrimitiveLoop(const Array& value_array, int64_t num_rows, + BinaryBuilder& residual, std::shared_ptr* out, + ExtractFn&& extract) { + BuilderT typed_builder; + ARROW_RETURN_NOT_OK(typed_builder.Reserve(num_rows)); + for (int64_t i = 0; i < num_rows; ++i) { + if (value_array.IsNull(i)) { + ARROW_RETURN_NOT_OK(residual.AppendNull()); + ARROW_RETURN_NOT_OK(typed_builder.AppendNull()); + continue; + } + auto bytes = GetBinaryValue(value_array, i); + auto* data = reinterpret_cast(bytes.data()); + auto len = static_cast(bytes.size()); + // Default-initialized; only read when extract() returns true. + NativeT native_val{}; + if (IsVariantNull(data, len)) { + // Variant::Null goes to value column — distinguishes from SQL NULL + ARROW_RETURN_NOT_OK( + residual.Append(reinterpret_cast(bytes.data()), bytes.size())); + ARROW_RETURN_NOT_OK(typed_builder.AppendNull()); + } else if (extract(data, len, &native_val)) { + ARROW_RETURN_NOT_OK(residual.AppendNull()); + ARROW_RETURN_NOT_OK(typed_builder.Append(native_val)); + } else { + ARROW_RETURN_NOT_OK( + residual.Append(reinterpret_cast(bytes.data()), bytes.size())); + ARROW_RETURN_NOT_OK(typed_builder.AppendNull()); + } + } + return typed_builder.Finish(out); +} + +/// Specialization for BINARY/LARGE_BINARY extraction (uses pointer+size output). +template +Status ShredBinaryLoop(const Array& value_array, int64_t num_rows, + BinaryBuilder& residual, std::shared_ptr* out) { + BuilderT typed_builder; + ARROW_RETURN_NOT_OK(typed_builder.Reserve(num_rows)); + for (int64_t i = 0; i < num_rows; ++i) { + if (value_array.IsNull(i)) { + ARROW_RETURN_NOT_OK(residual.AppendNull()); + ARROW_RETURN_NOT_OK(typed_builder.AppendNull()); + continue; + } + auto bytes = GetBinaryValue(value_array, i); + auto* data = reinterpret_cast(bytes.data()); + auto len = static_cast(bytes.size()); + const uint8_t* bin_data; + int32_t bin_size; + if (IsVariantNull(data, len)) { + ARROW_RETURN_NOT_OK( + residual.Append(reinterpret_cast(bytes.data()), bytes.size())); + ARROW_RETURN_NOT_OK(typed_builder.AppendNull()); + } else if (ExtractBinary(data, len, &bin_data, &bin_size)) { + ARROW_RETURN_NOT_OK(residual.AppendNull()); + ARROW_RETURN_NOT_OK(typed_builder.Append(bin_data, bin_size)); + } else { + ARROW_RETURN_NOT_OK( + residual.Append(reinterpret_cast(bytes.data()), bytes.size())); + ARROW_RETURN_NOT_OK(typed_builder.AppendNull()); + } + } + return typed_builder.Finish(out); +} + +} // namespace + +// --------------------------------------------------------------------------- +// ShredVariantColumn — dispatch by schema kind +// --------------------------------------------------------------------------- + +Result> ShredVariantColumn( + const std::shared_ptr& metadata_array, + const std::shared_ptr& value_array, const VariantShreddingSchema& schema) { + // Validate input array types — GetBinaryValue silently returns empty for + // unsupported types, which would cause subtle data corruption. + if (metadata_array->type_id() != Type::BINARY && + metadata_array->type_id() != Type::LARGE_BINARY && + metadata_array->type_id() != Type::BINARY_VIEW) { + return Status::Invalid( + "ShredVariantColumn: metadata_array must be BINARY, LARGE_BINARY, or " + "BINARY_VIEW, got ", + metadata_array->type()->ToString()); + } + if (value_array->type_id() != Type::BINARY && + value_array->type_id() != Type::LARGE_BINARY && + value_array->type_id() != Type::BINARY_VIEW) { + return Status::Invalid( + "ShredVariantColumn: value_array must be BINARY, LARGE_BINARY, or BINARY_VIEW, " + "got ", + value_array->type()->ToString()); + } + if (metadata_array->length() != value_array->length()) { + return Status::Invalid( + "ShredVariantColumn: metadata_array and value_array length mismatch (", + metadata_array->length(), " vs ", value_array->length(), ")"); + } + + if (schema.kind() == VariantShreddingSchema::Kind::kArray) { + return ShredVariantColumnArray(metadata_array, value_array, schema); + } + + if (schema.kind() == VariantShreddingSchema::Kind::kObject) { + return ShredVariantColumnObject(metadata_array, value_array, schema); + } + + // Primitive shredding + const int64_t num_rows = value_array->length(); + const auto& target_type = *schema.type(); + + // Residual value builder (variant bytes for non-matching rows) + BinaryBuilder residual_value_builder; + ARROW_RETURN_NOT_OK(residual_value_builder.Reserve(num_rows)); + + // NOTE (Rust divergence): Rust's shredding uses arrow::compute::cast() which + // allows cross-type conversions (e.g., Int32→Float64, Float32→Int32). C++ + // only shreds values whose variant type matches the target column type + // directly (with safe widening within the same numeric family, e.g., + // Int8→Int64). This is spec-compliant but less aggressive for predicate + // pushdown. A future cast-based approach could be added as a separate mode. + // + // NOTE (Rust divergence — additional types): Rust additionally supports + // Uint8/16/32/64, Float16, Decimal32, Decimal64, Decimal256, and + // TimestampSecond/Millisecond as shredding targets. These require the + // cast-based mode (variant spec only encodes signed ints, float32/64, + // and timestamp micros/nanos natively). Adding them is straightforward + // once the CastOptions mode is implemented. + std::shared_ptr typed_result; + + switch (target_type.id()) { + case Type::INT64: { + // Note: No IsVariantCompatibleWithType() gatekeeper here because + // ExtractInt64() already accepts only Int8/Int16/Int32/Int64 variants, + // and all int→int64 widening is unconditionally valid per the spec. + // This differs from TIMESTAMP/DECIMAL which need explicit type matching. + auto st = ShredPrimitiveLoop( + *value_array, num_rows, residual_value_builder, &typed_result, + [&](const uint8_t* data, int64_t len, int64_t* out) { + return ExtractInt64(data, len, out); + }); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::DOUBLE: { + auto st = ShredPrimitiveLoop( + *value_array, num_rows, residual_value_builder, &typed_result, + [&](const uint8_t* data, int64_t len, double* out) { + return ExtractDoubleOrFloat(data, len, out); + }); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::BOOL: { + auto st = ShredPrimitiveLoop( + *value_array, num_rows, residual_value_builder, &typed_result, + [&](const uint8_t* data, int64_t len, bool* out) { + return ExtractBool(data, len, out); + }); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::STRING: { + auto st = ShredPrimitiveLoop( + *value_array, num_rows, residual_value_builder, &typed_result, + [&](const uint8_t* data, int64_t len, std::string_view* out) { + return ExtractString(data, len, out); + }); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::LARGE_STRING: { + auto st = ShredPrimitiveLoop( + *value_array, num_rows, residual_value_builder, &typed_result, + [&](const uint8_t* data, int64_t len, std::string_view* out) { + return ExtractString(data, len, out); + }); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::FLOAT: { + auto st = ShredPrimitiveLoop( + *value_array, num_rows, residual_value_builder, &typed_result, + [&](const uint8_t* data, int64_t len, float* out) { + return ExtractFloat(data, len, out); + }); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::DATE32: { + auto st = ShredPrimitiveLoop( + *value_array, num_rows, residual_value_builder, &typed_result, + [&](const uint8_t* data, int64_t len, int32_t* out) { + return ExtractDate32(data, len, out); + }); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::TIMESTAMP: { + // Use Int64 storage for timestamps (Arrow stores timestamps as int64). + // IsVariantCompatibleWithType enforces TimeUnit and timezone matching, + // so only correctly-matching timestamp variants reach the typed column. + auto st = ShredPrimitiveLoop( + *value_array, num_rows, residual_value_builder, &typed_result, + [&](const uint8_t* data, int64_t len, int64_t* out) { + return IsVariantCompatibleWithType(data, len, target_type) && + ExtractTimestamp(data, len, out); + }); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::INT8: { + auto st = ShredPrimitiveLoop( + *value_array, num_rows, residual_value_builder, &typed_result, + [&](const uint8_t* data, int64_t len, int8_t* out) { + return ExtractInt8(data, len, out); + }); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::INT16: { + auto st = ShredPrimitiveLoop( + *value_array, num_rows, residual_value_builder, &typed_result, + [&](const uint8_t* data, int64_t len, int16_t* out) { + return ExtractInt16(data, len, out); + }); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::INT32: { + auto st = ShredPrimitiveLoop( + *value_array, num_rows, residual_value_builder, &typed_result, + [&](const uint8_t* data, int64_t len, int32_t* out) { + return ExtractInt32(data, len, out); + }); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::BINARY: { + auto st = ShredBinaryLoop(*value_array, num_rows, + residual_value_builder, &typed_result); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::LARGE_BINARY: { + auto st = ShredBinaryLoop( + *value_array, num_rows, residual_value_builder, &typed_result); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::STRING_VIEW: { + auto st = ShredPrimitiveLoop( + *value_array, num_rows, residual_value_builder, &typed_result, + [&](const uint8_t* data, int64_t len, std::string_view* out) { + return ExtractString(data, len, out); + }); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::BINARY_VIEW: { + // Note: ShredBinaryLoop calls typed_builder.Append(bin_data, bin_size) where + // bin_size is int32_t. BinaryViewBuilder::Append accepts int64_t — the implicit + // widening from int32_t is safe and produces correct results. + auto st = ShredBinaryLoop(*value_array, num_rows, + residual_value_builder, &typed_result); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::TIME64: { + auto st = ShredPrimitiveLoop( + *value_array, num_rows, residual_value_builder, &typed_result, + [&](const uint8_t* data, int64_t len, int64_t* out) { + return ExtractTime64(data, len, out); + }); + ARROW_RETURN_NOT_OK(st); + break; + } + + case Type::FIXED_SIZE_BINARY: { + // UUID = FixedSizeBinary(16) + auto& fsb_type = static_cast(target_type); + if (fsb_type.byte_width() != 16) { + return Status::NotImplemented( + "ShredVariantColumn: only FixedSizeBinary(16) for UUID is supported"); + } + FixedSizeBinaryBuilder typed_builder(fixed_size_binary(16)); + ARROW_RETURN_NOT_OK(typed_builder.Reserve(num_rows)); + for (int64_t i = 0; i < num_rows; ++i) { + if (value_array->IsNull(i)) { + ARROW_RETURN_NOT_OK(residual_value_builder.AppendNull()); + ARROW_RETURN_NOT_OK(typed_builder.AppendNull()); + continue; + } + auto bytes = GetBinaryValue(*value_array, i); + auto* data = reinterpret_cast(bytes.data()); + auto len = static_cast(bytes.size()); + uint8_t uuid_bytes[16]; + if (IsVariantNull(data, len)) { + ARROW_RETURN_NOT_OK(residual_value_builder.Append( + reinterpret_cast(bytes.data()), bytes.size())); + ARROW_RETURN_NOT_OK(typed_builder.AppendNull()); + } else if (ExtractUUID(data, len, uuid_bytes)) { + ARROW_RETURN_NOT_OK(residual_value_builder.AppendNull()); + ARROW_RETURN_NOT_OK(typed_builder.Append(uuid_bytes)); + } else { + ARROW_RETURN_NOT_OK(residual_value_builder.Append( + reinterpret_cast(bytes.data()), bytes.size())); + ARROW_RETURN_NOT_OK(typed_builder.AppendNull()); + } + } + ARROW_RETURN_NOT_OK(typed_builder.Finish(&typed_result)); + break; + } + + case Type::DECIMAL128: { + // Decimal128 shredding — extracts scale from variant, stores unscaled value. + // IsVariantCompatibleWithType checks both type and scale compatibility. + // Note: Rust also supports Decimal32/Decimal64 as separate shredding targets + // (via VariantDecimal4/VariantDecimal8 types). C++ Arrow only has Decimal128/256, + // so we consolidate all decimal widths into Decimal128 storage. + Decimal128Builder typed_builder(schema.type()); + ARROW_RETURN_NOT_OK(typed_builder.Reserve(num_rows)); + for (int64_t i = 0; i < num_rows; ++i) { + if (value_array->IsNull(i)) { + ARROW_RETURN_NOT_OK(residual_value_builder.AppendNull()); + ARROW_RETURN_NOT_OK(typed_builder.AppendNull()); + continue; + } + auto bytes = GetBinaryValue(*value_array, i); + auto* data = reinterpret_cast(bytes.data()); + auto len = static_cast(bytes.size()); + int64_t low = 0, high = 0; + uint8_t scale = 0; + if (IsVariantNull(data, len)) { + ARROW_RETURN_NOT_OK(residual_value_builder.Append( + reinterpret_cast(bytes.data()), bytes.size())); + ARROW_RETURN_NOT_OK(typed_builder.AppendNull()); + } else if (IsVariantCompatibleWithType(data, len, target_type) && + ExtractDecimal128(data, len, &low, &high, &scale)) { + Decimal128 dec_val(high, static_cast(low)); + ARROW_RETURN_NOT_OK(typed_builder.Append(dec_val)); + ARROW_RETURN_NOT_OK(residual_value_builder.AppendNull()); + } else { + ARROW_RETURN_NOT_OK(residual_value_builder.Append( + reinterpret_cast(bytes.data()), bytes.size())); + ARROW_RETURN_NOT_OK(typed_builder.AppendNull()); + } + } + ARROW_RETURN_NOT_OK(typed_builder.Finish(&typed_result)); + break; + } + + // TODO GH-45948 follow-up: Add FixedSizeList and ListView as array + // shredding targets (Rust supports all list-like types via GenericListArray). + // LargeList is supported in reconstruction but not as a distinct shredding + // output (shredding always produces ListArray with 32-bit offsets). + // TODO GH-45948 follow-up (Rust parity — cast mode): Add support for: + // - Uint8/16/32/64 (variant spec only encodes signed ints; requires cast) + // - Float16 (variant spec encodes Float32/64; requires cast/truncation) + // - Decimal32, Decimal64 (Rust has VariantDecimal4/8 dedicated types) + // - TimestampSecond, TimestampMillisecond (variant spec only has micros/nanos; + // requires unit conversion similar to Rust's CastOptions approach) + // These all require a CastOptions-based extraction mode analogous to Rust's + // `shred_variant_with_options()` which uses arrow::compute::cast(). + + default: + return Status::NotImplemented("ShredVariantColumn: unsupported target type ", + target_type.ToString()); + } + + std::shared_ptr residual_result; + ARROW_RETURN_NOT_OK(residual_value_builder.Finish(&residual_result)); + + // Determine the output field type for typed_value. For most types this is + // schema.type() directly. For TIMESTAMP and TIME64, the builder produces + // Int64Array (physical storage), so declare the field as int64() to match. + // The reconstruction path uses schema.type()->id() to re-encode correctly. + auto typed_field_type = schema.type(); + if (typed_field_type->id() == Type::TIMESTAMP || + typed_field_type->id() == Type::TIME64) { + typed_field_type = int64(); + } + + auto output_fields = std::vector>{ + field("metadata", metadata_array->type(), /*nullable=*/false), + field("value", binary(), /*nullable=*/true), + field("typed_value", typed_field_type, /*nullable=*/true), + }; + + return StructArray::Make({metadata_array, residual_result, typed_result}, + output_fields); +} + +// --------------------------------------------------------------------------- +// ReconstructVariantColumnArray +// --------------------------------------------------------------------------- + +static Result> ReconstructVariantColumnArray( + const std::shared_ptr& metadata_array, + const std::shared_ptr& value_array, + const std::shared_ptr& typed_value_array, + const VariantShreddingSchema& schema) { + const int64_t num_rows = metadata_array->length(); + + // Validate typed_value_array is a list-like array with binary elements. + // Accept LIST, LARGE_LIST, FIXED_SIZE_LIST, LIST_VIEW, and LARGE_LIST_VIEW + // to support all Parquet list-like representations (Rust parity). + if (typed_value_array->type_id() != Type::LIST && + typed_value_array->type_id() != Type::LARGE_LIST && + typed_value_array->type_id() != Type::FIXED_SIZE_LIST && + typed_value_array->type_id() != Type::LIST_VIEW && + typed_value_array->type_id() != Type::LARGE_LIST_VIEW) { + return Status::Invalid( + "ReconstructVariantColumnArray: typed_value_array must be LIST, LARGE_LIST, " + "FIXED_SIZE_LIST, LIST_VIEW, or LARGE_LIST_VIEW, got ", + typed_value_array->type()->ToString()); + } + + BinaryBuilder output_builder; + ARROW_RETURN_NOT_OK(output_builder.Reserve(num_rows)); + + // Generic lambda that handles any list-like type with value_offset(i) + values(). + // Works for LIST, LARGE_LIST, FIXED_SIZE_LIST, LIST_VIEW, and LARGE_LIST_VIEW. + // Elements can be either raw BINARY (legacy format) or STRUCT{value, typed_value} + // (recursively shredded format). We detect which format at runtime. + auto reconstruct_rows_offset_based = [&](auto* list_arr) -> Status { + auto value_type_id = list_arr->value_type()->id(); + + if (value_type_id == Type::BINARY) { + // Legacy format: elements are raw binary variant bytes. + DCHECK_NE(list_arr->values(), nullptr); + const auto* elem_arr = static_cast(list_arr->values().get()); + + for (int64_t i = 0; i < num_rows; ++i) { + bool value_present = value_array && value_array->IsValid(i); + bool typed_present = typed_value_array && typed_value_array->IsValid(i); + + if (!value_present && !typed_present) { + uint8_t null_byte = 0x00; + ARROW_RETURN_NOT_OK(output_builder.Append(&null_byte, 1)); + } else if (value_present && !typed_present) { + auto bytes = GetBinaryValue(*value_array, i); + ARROW_RETURN_NOT_OK(output_builder.Append( + reinterpret_cast(bytes.data()), bytes.size())); + } else if (!value_present && typed_present) { + VariantBuilder vb; + auto start = vb.Offset(); + std::vector offsets; + + auto list_start = list_arr->value_offset(i); + auto list_length = list_arr->value_length(i); + for (decltype(list_length) ei = 0; ei < list_length; ++ei) { + offsets.push_back(vb.NextElement(start)); + auto elem_view = elem_arr->GetView(static_cast(list_start + ei)); + vb.UnsafeAppendEncoded(reinterpret_cast(elem_view.data()), + static_cast(elem_view.size())); + } + + ARROW_RETURN_NOT_OK(vb.FinishArray(start, offsets)); + ARROW_ASSIGN_OR_RAISE(auto arr_bytes, vb.BuildWithoutMeta()); + ARROW_RETURN_NOT_OK(output_builder.Append( + arr_bytes.data(), static_cast(arr_bytes.size()))); + } else { + return Status::Invalid( + "ReconstructVariantColumn: both value and typed_value non-null " + "at row ", + i, " (not valid for array schemas)"); + } + } + } else if (value_type_id == Type::STRUCT) { + // Recursively shredded format: elements are struct{value, typed_value}. + // Reconstruct elements first at the column level, then use the + // reconstructed per-element binary to build variant arrays. + DCHECK_NE(list_arr->values(), nullptr); + const auto* elem_struct = static_cast(list_arr->values().get()); + + if (elem_struct->num_fields() != 2) { + return Status::Invalid( + "ReconstructVariantColumnArray: element struct must have 2 fields " + "(value, typed_value), got ", + elem_struct->num_fields()); + } + auto elem_value_col = elem_struct->field(0); + auto elem_typed_col = elem_struct->field(1); + + // Build a metadata array for elements by replicating row metadata. + int64_t total_elements = elem_struct->length(); + BinaryBuilder elem_meta_builder; + ARROW_RETURN_NOT_OK(elem_meta_builder.Reserve(total_elements)); + for (int64_t i = 0; i < num_rows; ++i) { + if (!typed_value_array->IsValid(i)) continue; + auto meta_bytes = GetBinaryValue(*metadata_array, i); + auto list_length = list_arr->value_length(i); + for (decltype(list_length) ei = 0; ei < list_length; ++ei) { + ARROW_RETURN_NOT_OK(elem_meta_builder.Append( + reinterpret_cast(meta_bytes.data()), meta_bytes.size())); + } + } + std::shared_ptr elem_meta_arr; + ARROW_RETURN_NOT_OK(elem_meta_builder.Finish(&elem_meta_arr)); + + if (elem_meta_arr->length() != total_elements) { + return Status::Invalid( + "ReconstructVariantColumnArray: element metadata count mismatch (", + elem_meta_arr->length(), " vs ", total_elements, " elements)"); + } + + // Recursively reconstruct all elements at once (column-level operation) + ARROW_ASSIGN_OR_RAISE( + auto reconstructed_elements, + ReconstructVariantColumn(elem_meta_arr, elem_value_col, elem_typed_col, + schema.element_schema())); + + // Build variant arrays from the reconstructed element bytes + for (int64_t i = 0; i < num_rows; ++i) { + bool value_present = value_array && value_array->IsValid(i); + bool typed_present = typed_value_array && typed_value_array->IsValid(i); + + if (!value_present && !typed_present) { + uint8_t null_byte = 0x00; + ARROW_RETURN_NOT_OK(output_builder.Append(&null_byte, 1)); + } else if (value_present && !typed_present) { + auto bytes = GetBinaryValue(*value_array, i); + ARROW_RETURN_NOT_OK(output_builder.Append( + reinterpret_cast(bytes.data()), bytes.size())); + } else if (!value_present && typed_present) { + VariantBuilder vb; + auto start = vb.Offset(); + std::vector offsets; + + auto list_start = list_arr->value_offset(i); + auto list_length = list_arr->value_length(i); + for (decltype(list_length) ei = 0; ei < list_length; ++ei) { + offsets.push_back(vb.NextElement(start)); + auto elem_bytes = GetBinaryValue(*reconstructed_elements, + static_cast(list_start + ei)); + vb.UnsafeAppendEncoded(reinterpret_cast(elem_bytes.data()), + static_cast(elem_bytes.size())); + } + + ARROW_RETURN_NOT_OK(vb.FinishArray(start, offsets)); + ARROW_ASSIGN_OR_RAISE(auto arr_bytes, vb.BuildWithoutMeta()); + ARROW_RETURN_NOT_OK(output_builder.Append( + arr_bytes.data(), static_cast(arr_bytes.size()))); + } else { + return Status::Invalid( + "ReconstructVariantColumn: both value and typed_value non-null " + "at row ", + i, " (not valid for array schemas)"); + } + } + } else { + return Status::Invalid( + "ReconstructVariantColumnArray: list elements must be BINARY or " + "STRUCT{value, typed_value}, got ", + list_arr->value_type()->ToString()); + } + return Status::OK(); + }; + + switch (typed_value_array->type_id()) { + case Type::LIST: + ARROW_RETURN_NOT_OK(reconstruct_rows_offset_based( + static_cast(typed_value_array.get()))); + break; + case Type::LARGE_LIST: + ARROW_RETURN_NOT_OK(reconstruct_rows_offset_based( + static_cast(typed_value_array.get()))); + break; + case Type::FIXED_SIZE_LIST: + ARROW_RETURN_NOT_OK(reconstruct_rows_offset_based( + static_cast(typed_value_array.get()))); + break; + case Type::LIST_VIEW: + ARROW_RETURN_NOT_OK(reconstruct_rows_offset_based( + static_cast(typed_value_array.get()))); + break; + case Type::LARGE_LIST_VIEW: + ARROW_RETURN_NOT_OK(reconstruct_rows_offset_based( + static_cast(typed_value_array.get()))); + break; + default: + return Status::Invalid( + "ReconstructVariantColumnArray: unexpected typed_value type ", + typed_value_array->type()->ToString()); + } + + std::shared_ptr result; + ARROW_RETURN_NOT_OK(output_builder.Finish(&result)); + return result; +} + +// --------------------------------------------------------------------------- +// ReconstructVariantColumnObject +// --------------------------------------------------------------------------- + +static Result> ReconstructVariantColumnObject( + const std::shared_ptr& metadata_array, + const std::shared_ptr& value_array, + const std::shared_ptr& typed_value_array, + const VariantShreddingSchema& schema) { + const int64_t num_rows = metadata_array->length(); + const auto& schema_fields = schema.fields(); + + // Validate typed_value_array is the expected StructArray. + if (typed_value_array->type_id() != Type::STRUCT) { + return Status::Invalid( + "ReconstructVariantColumnObject: typed_value_array must be STRUCT, got ", + typed_value_array->type()->ToString()); + } + + BinaryBuilder output_builder; + ARROW_RETURN_NOT_OK(output_builder.Reserve(num_rows)); + + // typed_value should be a StructArray with one field per schema field + const auto* typed_struct = static_cast(typed_value_array.get()); + + // Validate that the typed_value struct has the expected number of fields. + if (typed_struct->num_fields() != static_cast(schema_fields.size())) { + return Status::Invalid("ReconstructVariantColumnObject: typed_value struct has ", + typed_struct->num_fields(), " fields but schema expects ", + schema_fields.size()); + } + + // Pre-compute per-field reconstructed variant arrays for fields with + // primitive sub-schemas. This avoids O(n²) by calling ReconstructVariantColumn + // once per field (column-level) rather than once per row. + std::vector> field_reconstructed_arrays(schema_fields.size()); + for (size_t sf = 0; sf < schema_fields.size(); ++sf) { + const auto& sub_schema = schema_fields[sf].second; + if (sub_schema.kind() == VariantShreddingSchema::Kind::kPrimitive) { + auto field_struct_col = typed_struct->field(static_cast(sf)); + auto* field_struct = static_cast(field_struct_col.get()); + auto field_value_col = field_struct->field(0); // "value" sub-column + auto field_typed_col = field_struct->field(1); // "typed_value" sub-column + + ARROW_ASSIGN_OR_RAISE(field_reconstructed_arrays[sf], + ReconstructVariantColumn(metadata_array, field_value_col, + field_typed_col, sub_schema)); + } + } + + // Cache for metadata reuse: in columnar data, all rows typically share the + // same metadata dictionary. By comparing metadata bytes across rows, we avoid + // redundant DecodeMetadata calls and VariantBuilder dictionary copies. + // Lifetime safety: cached_meta_bytes is a string_view into metadata_array's + // buffer, which is kept alive by the shared_ptr parameter for this function's + // entire duration. Arrow arrays never relocate their backing buffers. + std::string_view cached_meta_bytes; + VariantMetadata cached_meta; + // Cached builder: reused when consecutive rows share metadata. After + // BuildWithoutMeta(), the builder's buffer is cleared but its dictionary + // (dict_, dict_keys_) is preserved — exactly what we need for the next row. + std::unique_ptr cached_builder; + + for (int64_t i = 0; i < num_rows; ++i) { + bool value_present = value_array && value_array->IsValid(i); + bool typed_present = typed_value_array && typed_value_array->IsValid(i); + + if (!value_present && !typed_present) { + // Both null → Variant null (see comment in primitive reconstruction + // about SQL NULL vs variant-null ambiguity) + uint8_t null_byte = 0x00; + ARROW_RETURN_NOT_OK(output_builder.Append(&null_byte, 1)); + continue; + } + + if (value_present && !typed_present) { + // Non-object value stored in residual + auto bytes = GetBinaryValue(*value_array, i); + ARROW_RETURN_NOT_OK(output_builder.Append( + reinterpret_cast(bytes.data()), bytes.size())); + continue; + } + + // typed_present: reconstruct object from shredded fields + residual + auto meta_bytes = GetBinaryValue(*metadata_array, i); + + // Reuse cached metadata and builder if bytes are identical (common case: + // all rows share the same metadata dictionary). This avoids O(n × k) + // dictionary copies from VariantBuilder construction per row. + if (meta_bytes != cached_meta_bytes) { + ARROW_ASSIGN_OR_RAISE( + cached_meta, DecodeMetadata(reinterpret_cast(meta_bytes.data()), + static_cast(meta_bytes.size()))); + cached_meta_bytes = meta_bytes; + cached_builder = std::make_unique(cached_meta); + cached_builder->SetAllowDuplicates(true); + } + + // Reuse the cached builder: BuildWithoutMeta() clears the buffer but + // preserves the dictionary, so NextField() can resolve keys without + // redundant hash map insertions. + VariantBuilder& vb = *cached_builder; + auto start = vb.Offset(); + std::vector fields; + + // Add shredded fields from typed_value struct + for (size_t sf = 0; sf < schema_fields.size(); ++sf) { + const auto& sub_schema = schema_fields[sf].second; + + if (sub_schema.kind() == VariantShreddingSchema::Kind::kPrimitive && + field_reconstructed_arrays[sf]) { + // Primitive sub-schema: use the pre-computed reconstructed array. + // Check the original sub-field struct to determine if the field was + // present (value or typed_value non-null) or absent (both null). + auto field_struct_col = typed_struct->field(static_cast(sf)); + auto* field_struct = static_cast(field_struct_col.get()); + auto field_value_col = field_struct->field(0); + auto field_typed_col = field_struct->field(1); + bool field_present = field_value_col->IsValid(i) || field_typed_col->IsValid(i); + + if (field_present) { + auto& recon_arr = field_reconstructed_arrays[sf]; + auto recon_bytes = GetBinaryValue(*recon_arr, i); + if (!recon_bytes.empty()) { + fields.push_back(vb.NextField(start, schema_fields[sf].first)); + vb.UnsafeAppendEncoded(reinterpret_cast(recon_bytes.data()), + static_cast(recon_bytes.size())); + } + } + } else { + // Non-primitive or no reconstruction available: read from value sub-column + auto field_struct_col = typed_struct->field(static_cast(sf)); + auto* field_struct = static_cast(field_struct_col.get()); + auto field_value_col = field_struct->field(0); // "value" sub-column + + if (field_value_col->IsValid(i)) { + auto field_bytes = GetBinaryValue(*field_value_col, i); + fields.push_back(vb.NextField(start, schema_fields[sf].first)); + vb.UnsafeAppendEncoded(reinterpret_cast(field_bytes.data()), + static_cast(field_bytes.size())); + } + } + // Both null → field is missing from this object, skip it + } + + // Add residual fields (if value is present alongside typed_value) + if (value_present) { + auto residual_bytes = GetBinaryValue(*value_array, i); + auto* residual_data = reinterpret_cast(residual_bytes.data()); + auto residual_len = static_cast(residual_bytes.size()); + + if (residual_len > 0 && GetBasicType(residual_data[0]) == BasicType::kObject) { + ARROW_ASSIGN_OR_RAISE(auto residual_field_count, + GetObjectFieldCount(residual_data, residual_len)); + for (int32_t fi = 0; fi < residual_field_count; ++fi) { + std::string_view fname; + int64_t foff = 0, fsz = 0; + ARROW_RETURN_NOT_OK(GetObjectFieldAt(cached_meta, residual_data, residual_len, + fi, &fname, &foff, &fsz)); + fields.push_back(vb.NextField(start, fname)); + vb.UnsafeAppendEncoded(residual_data + foff, fsz); + } + } + } + + ARROW_RETURN_NOT_OK(vb.FinishObject(start, fields)); + ARROW_ASSIGN_OR_RAISE(auto obj_bytes, vb.BuildWithoutMeta()); + ARROW_RETURN_NOT_OK( + output_builder.Append(obj_bytes.data(), static_cast(obj_bytes.size()))); + } + + std::shared_ptr result; + ARROW_RETURN_NOT_OK(output_builder.Finish(&result)); + return result; +} + +// --------------------------------------------------------------------------- +// ReconstructVariantColumn — dispatch by schema kind +// --------------------------------------------------------------------------- + +Result> ReconstructVariantColumn( + const std::shared_ptr& metadata_array, + const std::shared_ptr& value_array, + const std::shared_ptr& typed_value_array, + const VariantShreddingSchema& schema) { + // Validate array lengths are consistent. + if (metadata_array->length() != value_array->length() || + metadata_array->length() != typed_value_array->length()) { + return Status::Invalid("ReconstructVariantColumn: array length mismatch (metadata=", + metadata_array->length(), ", value=", value_array->length(), + ", typed_value=", typed_value_array->length(), ")"); + } + // Validate input array types — GetBinaryValue silently returns empty for + // unsupported types, which would cause subtle data corruption. + if (metadata_array->type_id() != Type::BINARY && + metadata_array->type_id() != Type::LARGE_BINARY && + metadata_array->type_id() != Type::BINARY_VIEW) { + return Status::Invalid( + "ReconstructVariantColumn: metadata_array must be BINARY, LARGE_BINARY, or " + "BINARY_VIEW, got ", + metadata_array->type()->ToString()); + } + if (value_array->type_id() != Type::BINARY && + value_array->type_id() != Type::LARGE_BINARY && + value_array->type_id() != Type::BINARY_VIEW) { + return Status::Invalid( + "ReconstructVariantColumn: value_array must be BINARY, LARGE_BINARY, or " + "BINARY_VIEW, got ", + value_array->type()->ToString()); + } + + if (schema.kind() == VariantShreddingSchema::Kind::kArray) { + return ReconstructVariantColumnArray(metadata_array, value_array, typed_value_array, + schema); + } + + if (schema.kind() == VariantShreddingSchema::Kind::kObject) { + return ReconstructVariantColumnObject(metadata_array, value_array, typed_value_array, + schema); + } + + // Primitive reconstruction + const int64_t num_rows = metadata_array->length(); + BinaryBuilder output_builder; + ARROW_RETURN_NOT_OK(output_builder.Reserve(num_rows)); + + // We need to re-encode native typed_value back to variant bytes. + // Reused across all rows — safe because primitive encoding never adds + // dictionary keys (only objects do). The dictionary stays empty. + VariantBuilder vb; + + for (int64_t i = 0; i < num_rows; ++i) { + bool value_present = value_array && value_array->IsValid(i); + bool typed_present = typed_value_array && typed_value_array->IsValid(i); + + if (!value_present && !typed_present) { + // Both null → Variant null (single byte: 0x00) + // NOTE: This is ambiguous — it could represent either: + // 1. "SQL NULL / missing" (the row itself is absent), or + // 2. "variant-typed null" (a Variant::Null that was stored in value) + // The Rust implementation disambiguates via a separate NullBuffer on the + // VariantArray struct. Since we don't receive a validity bitmap here, + // we conservatively emit Variant::Null (0x00). Callers that need to + // distinguish SQL NULL from variant-null should check the struct-level + // validity bitmap before calling ReconstructVariantColumn. + // TODO GH-45948 follow-up (Rust parity — NullBuffer): Accept an optional + // validity bitmap to propagate struct-level nulls into the output, matching + // Rust's `unshred_variant()` which returns a separate NullBuffer. + uint8_t null_byte = 0x00; + ARROW_RETURN_NOT_OK(output_builder.Append(&null_byte, 1)); + } else if (value_present && !typed_present) { + // Value present, typed null → use residual value as-is + auto bytes = GetBinaryValue(*value_array, i); + ARROW_RETURN_NOT_OK(output_builder.Append( + reinterpret_cast(bytes.data()), bytes.size())); + } else if (!value_present && typed_present) { + // Typed present, value null → re-encode native value as variant. + // Dispatch on the *schema* type, not the physical array type, because + // some types (e.g., TIMESTAMP, TIME64) are stored as Int64 arrays + // but need to be re-encoded with their specific variant type. + switch (schema.type()->id()) { + case Type::INT64: { + auto& arr = static_cast(*typed_value_array); + // Note: vb.Int() auto-sizes to smallest representation. This means + // Shred(Int64(42))→Reconstruct() produces Int8(42). The *value* is + // preserved but the encoding width may narrow. This matches Rust. + ARROW_RETURN_NOT_OK(vb.Int(arr.Value(i))); + break; + } + case Type::DOUBLE: { + auto& arr = static_cast(*typed_value_array); + ARROW_RETURN_NOT_OK(vb.Double(arr.Value(i))); + break; + } + case Type::FLOAT: { + auto& arr = static_cast(*typed_value_array); + ARROW_RETURN_NOT_OK(vb.Float(arr.Value(i))); + break; + } + case Type::BOOL: { + auto& arr = static_cast(*typed_value_array); + ARROW_RETURN_NOT_OK(vb.Bool(arr.Value(i))); + break; + } + case Type::STRING: { + auto& arr = static_cast(*typed_value_array); + ARROW_RETURN_NOT_OK(vb.String(arr.GetView(i))); + break; + } + case Type::LARGE_STRING: { + auto& arr = static_cast(*typed_value_array); + auto view = arr.GetView(i); + ARROW_RETURN_NOT_OK(vb.String(std::string_view(view.data(), view.size()))); + break; + } + case Type::DATE32: { + auto& arr = static_cast(*typed_value_array); + ARROW_RETURN_NOT_OK(vb.Date(arr.Value(i))); + break; + } + case Type::TIMESTAMP: { + auto& arr = static_cast(*typed_value_array); + // Determine the correct timestamp variant from the schema type. + // The schema carries TimeUnit and timezone, which tells us which + // variant encoding method to use for faithful round-trip. + auto& ts_type = static_cast(*schema.type()); + bool has_tz = !ts_type.timezone().empty(); + if (ts_type.unit() == TimeUnit::NANO) { + if (has_tz) { + ARROW_RETURN_NOT_OK(vb.TimestampNanos(arr.Value(i))); + } else { + ARROW_RETURN_NOT_OK(vb.TimestampNanosNTZ(arr.Value(i))); + } + } else { + // MICRO unit (guaranteed by IsVariantCompatibleWithType which rejects + // SECOND/MILLI/other units during shredding) + if (has_tz) { + ARROW_RETURN_NOT_OK(vb.TimestampMicros(arr.Value(i))); + } else { + ARROW_RETURN_NOT_OK(vb.TimestampMicrosNTZ(arr.Value(i))); + } + } + break; + } + case Type::BINARY: { + auto& arr = static_cast(*typed_value_array); + auto view = arr.GetView(i); + ARROW_RETURN_NOT_OK(vb.Binary( + std::string_view(reinterpret_cast(view.data()), view.size()))); + break; + } + case Type::LARGE_BINARY: { + auto& arr = static_cast(*typed_value_array); + auto view = arr.GetView(i); + ARROW_RETURN_NOT_OK(vb.Binary( + std::string_view(reinterpret_cast(view.data()), view.size()))); + break; + } + case Type::STRING_VIEW: { + auto& arr = static_cast(*typed_value_array); + ARROW_RETURN_NOT_OK(vb.String(arr.GetView(i))); + break; + } + case Type::BINARY_VIEW: { + auto& arr = static_cast(*typed_value_array); + auto view = arr.GetView(i); + ARROW_RETURN_NOT_OK(vb.Binary(view)); + break; + } + case Type::INT8: { + auto& arr = static_cast(*typed_value_array); + ARROW_RETURN_NOT_OK(vb.Int8(arr.Value(i))); + break; + } + case Type::INT16: { + auto& arr = static_cast(*typed_value_array); + ARROW_RETURN_NOT_OK(vb.Int16(arr.Value(i))); + break; + } + case Type::INT32: { + auto& arr = static_cast(*typed_value_array); + ARROW_RETURN_NOT_OK(vb.Int32(arr.Value(i))); + break; + } + case Type::TIME64: { + auto& arr = static_cast(*typed_value_array); + ARROW_RETURN_NOT_OK(vb.TimeNTZ(arr.Value(i))); + break; + } + case Type::FIXED_SIZE_BINARY: { + auto& arr = static_cast(*typed_value_array); + ARROW_RETURN_NOT_OK(vb.UUID(arr.GetValue(i))); + break; + } + case Type::DECIMAL128: { + auto& arr = static_cast(*typed_value_array); + auto& dec_type = static_cast(*typed_value_array->type()); + Decimal128 val(arr.GetValue(i)); + uint8_t scale = static_cast(dec_type.scale()); + // Preserve the smallest encoding width that can represent the value. + // This ensures round-trip byte identity: a Decimal4 variant stays Decimal4 + // after shred→reconstruct rather than being widened to Decimal16. + // + // Use high_bits()/low_bits() accessors which are endian-safe (they + // return numeric values, not raw bytes). This avoids the ToBytes()+ + // FromLittleEndian() pattern which is incorrect on big-endian. + int64_t high_word = val.high_bits(); + int64_t low_word = static_cast(val.low_bits()); + // Check if the value fits in 4 bytes (int32 range) + int32_t as_int32 = static_cast(low_word); + if (high_word == (as_int32 < 0 ? -1 : 0) && + low_word == static_cast(as_int32)) { + // Fits in Decimal4 (4-byte unscaled value) + uint8_t d4_bytes[4]; + int32_t val32 = bit_util::ToLittleEndian(as_int32); + std::memcpy(d4_bytes, &val32, 4); + ARROW_RETURN_NOT_OK(vb.Decimal4(scale, d4_bytes)); + } else if (high_word == (low_word < 0 ? -1 : 0)) { + // Fits in Decimal8 (8-byte unscaled value) + uint8_t d8_bytes[8]; + int64_t val64 = bit_util::ToLittleEndian(low_word); + std::memcpy(d8_bytes, &val64, 8); + ARROW_RETURN_NOT_OK(vb.Decimal8(scale, d8_bytes)); + } else { + // Full Decimal16 — write low and high words in little-endian + uint8_t d16_bytes[16]; + int64_t le_low = bit_util::ToLittleEndian(low_word); + int64_t le_high = bit_util::ToLittleEndian(high_word); + std::memcpy(d16_bytes, &le_low, 8); + std::memcpy(d16_bytes + 8, &le_high, 8); + ARROW_RETURN_NOT_OK(vb.Decimal16(scale, d16_bytes)); + } + break; + } + // TODO GH-45948 follow-up: Add FixedSizeList and ListView + // reconstruction (LargeList is already supported above via the + // array reconstruction path). + // TODO GH-45948 follow-up (Rust parity — cast mode): Add Uint8/16/32/64, + // Float16, Decimal32/64, TimestampSecond/Millisecond reconstruction. + // See shredding switch for the full list. + default: + return Status::NotImplemented( + "ReconstructVariantColumn: unsupported typed_value type ", + typed_value_array->type()->ToString()); + } + // Use BuildWithoutMeta() to avoid reconstructing the metadata dictionary + // on every row — primitives don't reference dictionary keys, so the + // metadata is irrelevant here. This avoids O(n) allocations. + ARROW_ASSIGN_OR_RAISE(auto value_bytes, vb.BuildWithoutMeta()); + ARROW_RETURN_NOT_OK(output_builder.Append( + value_bytes.data(), static_cast(value_bytes.size()))); + } else { + // Both present → partial object (not supported for primitive schemas) + return Status::Invalid( + "ReconstructVariantColumn: both value and typed_value are non-null " + "at row ", + i, " (partial objects not supported for primitive schemas)"); + } + } + + std::shared_ptr result; + ARROW_RETURN_NOT_OK(output_builder.Finish(&result)); + return result; +} + +} // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/extension/variant_shredding.h b/cpp/src/arrow/extension/variant_shredding.h new file mode 100644 index 000000000000..13e1271d2cfc --- /dev/null +++ b/cpp/src/arrow/extension/variant_shredding.h @@ -0,0 +1,199 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include + +#include "arrow/array/array_nested.h" +#include "arrow/result.h" +#include "arrow/status.h" +#include "arrow/type_fwd.h" +#include "arrow/util/visibility.h" + +namespace arrow::extension::variant_internal { + +// --------------------------------------------------------------------------- +// Shredding Schema +// --------------------------------------------------------------------------- + +/// \brief Defines which fields/types should be extracted from a Variant +/// into typed Parquet columns. +/// +/// A shredding schema is a tree that mirrors the structure of the Variant +/// values to be shredded: +/// - Primitive: shred the entire value as a specific Arrow type +/// - Object: shred named fields, each with its own sub-schema +/// - Array: shred array elements with a uniform element schema +/// +/// This is the C++ equivalent of Rust's ShreddedSchemaBuilder. +class ARROW_EXPORT VariantShreddingSchema { + public: + enum class Kind { kPrimitive, kObject, kArray }; + + /// \brief Create a primitive shredding schema. + /// + /// When applied, if the variant value's type is compatible with + /// the target DataType, it is cast and placed in typed_value. + /// Otherwise it remains in the value column as binary variant. + /// + /// \param[in] type The target Arrow DataType for the typed_value column + static VariantShreddingSchema Primitive(std::shared_ptr type); + + /// \brief Create an object shredding schema. + /// + /// When applied to an object variant, named fields are extracted + /// according to their individual sub-schemas. Remaining fields + /// go into the residual value column. + /// + /// \param[in] fields Vector of (field_name, sub_schema) pairs + static VariantShreddingSchema Object( + std::vector> fields); + + /// \brief Create an array shredding schema. + /// + /// When applied to an array variant, each element is shredded + /// according to the element_schema. + /// + /// \param[in] element_schema Schema for array elements + static VariantShreddingSchema Array(VariantShreddingSchema element_schema); + + /// \brief Get the kind of this schema node. + Kind kind() const { return kind_; } + + /// \brief Get the target type (only valid for Primitive kind). + const std::shared_ptr& type() const { return type_; } + + /// \brief Get the object fields (only valid for Object kind). + const std::vector>& fields() const { + return fields_; + } + + /// \brief Get the element schema (only valid for Array kind). + const VariantShreddingSchema& element_schema() const { return *element_schema_; } + + /// \brief Convert this schema to the Arrow DataType for the typed_value column. + /// + /// - Primitive → the target DataType directly + /// - Object → struct type with fields named per the schema + /// - Array → list type with element struct {value?, typed_value?} + /// + /// NOTE: This returns the *logical* type. The actual shredded output for + /// TIMESTAMP and TIME64 schemas uses int64() as the physical field type + /// (since Arrow builds timestamps via Int64Builder). Callers comparing + /// ToArrowType() output against shredded array field types should account + /// for this TIMESTAMP/TIME64 → int64() mapping. + std::shared_ptr ToArrowType() const; + + private: + Kind kind_ = Kind::kPrimitive; + std::shared_ptr type_; // Primitive + std::vector> fields_; // Object + std::shared_ptr element_schema_; // Array +}; + +// --------------------------------------------------------------------------- +// Shredding Operations +// --------------------------------------------------------------------------- + +/// \brief Determine if a Variant primitive type is compatible with a target +/// Arrow DataType for shredding purposes. +/// +/// Compatibility means the variant value can be losslessly represented +/// in the target typed column without falling back to the binary value column. +/// +/// \param[in] variant_data Pointer to the variant value buffer +/// \param[in] variant_length Length of the variant value buffer +/// \param[in] target_type The target Arrow DataType +/// \return true if the variant value type is compatible with target_type +ARROW_EXPORT bool IsVariantCompatibleWithType(const uint8_t* variant_data, + int64_t variant_length, + const DataType& target_type); + +// --------------------------------------------------------------------------- +// Column-level Shredding / Reconstruction +// --------------------------------------------------------------------------- + +/// \brief Shred a column of variant values according to a shredding schema. +/// +/// Takes an unshredded VariantArray (a StructArray with {metadata, value}) +/// and produces a shredded StructArray with {metadata, value?, typed_value?} +/// where matching values are routed to typed_value and non-matching values +/// remain in value. +/// +/// Type compatibility is strict: a variant value is only shredded if its +/// encoded type matches the target type directly (with safe widening within +/// the same family, e.g., Int8→Int64). Non-matching values always fall to +/// the value column safely — no errors are produced for type mismatches. +/// +/// This is the C++ equivalent of Rust's `shred_variant()`. +/// +/// Known Rust parity gaps (planned follow-ups): +/// - Recursive object/array sub-field shredding: Rust recursively shreds +/// nested object and array sub-fields. C++ handles primitive sub-fields +/// natively and recursively shreds array elements. Object/Array sub-schemas +/// in object fields store values as variant binary (recursive nested +/// shredding for those is a potential follow-up). +/// - CastOptions: Rust supports cross-type coercion (e.g., Int32->Float64 via +/// arrow::compute::cast); this function uses strict matching only. +/// - Additional targets: Rust supports FixedSizeList and ListView as +/// shredding output targets. Reconstruction accepts all list-like types. +/// +/// \param[in] metadata_array The shared metadata column (binary array) +/// \param[in] value_array The unshredded value column (binary array) +/// \param[in] schema The shredding schema defining the typed_value target +/// \return A struct with three fields: {metadata, value(nullable), typed_value(nullable)} +ARROW_EXPORT Result> ShredVariantColumn( + const std::shared_ptr& metadata_array, + const std::shared_ptr& value_array, const VariantShreddingSchema& schema); + +/// \brief Reconstruct unshredded variant values from shredded columns. +/// +/// Takes a shredded representation {metadata, value?, typed_value?} and +/// produces a fully-materialized value column where all typed_value entries +/// have been re-encoded as variant binary. +/// +/// This is the C++ equivalent of Rust's `unshred_variant()`. +/// +/// \param[in] metadata_array The shared metadata column (binary array) +/// \param[in] value_array The residual value column (nullable binary array) +/// \param[in] typed_value_array The shredded typed column (nullable) +/// \param[in] schema The shredding schema (needed to interpret typed_value) +/// \return A non-nullable binary array of fully-reconstructed variant values. +/// When both value and typed_value are null for a row, the output +/// contains a Variant::Null byte (0x00). Callers that need to distinguish +/// SQL NULL (missing row) from variant-null must check the original +/// struct-level validity bitmap before calling this function. +/// +/// TODO GH-45948 follow-up (Rust parity — NullBuffer): Accept an optional +/// validity bitmap parameter (or return a separate NullBuffer alongside the +/// data array) to propagate struct-level nulls into the reconstructed output. +/// Rust's `unshred_variant()` returns `(VariantArray, Option)` which +/// cleanly disambiguates SQL NULL from variant-null without requiring callers +/// to retain and cross-reference the original struct validity. This would make +/// the C++ API fully equivalent to Rust's semantics. +ARROW_EXPORT Result> ReconstructVariantColumn( + const std::shared_ptr& metadata_array, + const std::shared_ptr& value_array, + const std::shared_ptr& typed_value_array, + const VariantShreddingSchema& schema); + +} // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/extension/variant_shredding_test.cc b/cpp/src/arrow/extension/variant_shredding_test.cc new file mode 100644 index 000000000000..1c93aa07342a --- /dev/null +++ b/cpp/src/arrow/extension/variant_shredding_test.cc @@ -0,0 +1,2118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/extension/variant_shredding.h" + +#include +#include +#include +#include +#include + +#include "arrow/array/array_binary.h" +#include "arrow/array/array_decimal.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" +#include "arrow/array/builder_binary.h" +#include "arrow/array/builder_nested.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/buffer.h" +#include "arrow/extension/variant_internal.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type.h" +#include "arrow/util/decimal.h" + +namespace arrow::extension::variant_internal { + +namespace { + +// Test-local helper to extract variant bytes from a BinaryArray at index i. +// Named to avoid confusion with BinaryViewArray::GetView(). +std::string_view GetBinaryView(const Array& array, int64_t i) { + if (array.type_id() == Type::BINARY) { + auto& bin = static_cast(array); + auto view = bin.GetView(i); + return {reinterpret_cast(view.data()), static_cast(view.size())}; + } + return {}; +} + +} // namespace + +// =========================================================================== +// VariantShreddingSchema tests +// =========================================================================== + +class VariantShreddingSchemaTest : public ::testing::Test {}; + +TEST_F(VariantShreddingSchemaTest, PrimitiveSchema) { + auto schema = VariantShreddingSchema::Primitive(int64()); + ASSERT_EQ(schema.kind(), VariantShreddingSchema::Kind::kPrimitive); + ASSERT_EQ(schema.type()->id(), Type::INT64); +} + +TEST_F(VariantShreddingSchemaTest, PrimitiveToArrowType) { + auto schema = VariantShreddingSchema::Primitive(utf8()); + auto arrow_type = schema.ToArrowType(); + ASSERT_EQ(arrow_type->id(), Type::STRING); +} + +TEST_F(VariantShreddingSchemaTest, ObjectSchema) { + auto schema = VariantShreddingSchema::Object({ + {"name", VariantShreddingSchema::Primitive(utf8())}, + {"age", VariantShreddingSchema::Primitive(int64())}, + }); + ASSERT_EQ(schema.kind(), VariantShreddingSchema::Kind::kObject); + ASSERT_EQ(schema.fields().size(), 2); + ASSERT_EQ(schema.fields()[0].first, "name"); + ASSERT_EQ(schema.fields()[1].first, "age"); +} + +TEST_F(VariantShreddingSchemaTest, ObjectToArrowType) { + auto schema = VariantShreddingSchema::Object({ + {"event_type", VariantShreddingSchema::Primitive(utf8())}, + {"event_ts", VariantShreddingSchema::Primitive(timestamp(TimeUnit::MICRO, "UTC"))}, + }); + auto arrow_type = schema.ToArrowType(); + ASSERT_EQ(arrow_type->id(), Type::STRUCT); + auto struct_type = std::static_pointer_cast(arrow_type); + ASSERT_EQ(struct_type->num_fields(), 2); + + // Each field should be a struct with {value, typed_value} + auto event_type_field = struct_type->field(0); + ASSERT_EQ(event_type_field->name(), "event_type"); + ASSERT_EQ(event_type_field->type()->id(), Type::STRUCT); + auto inner = std::static_pointer_cast(event_type_field->type()); + ASSERT_EQ(inner->num_fields(), 2); + ASSERT_EQ(inner->field(0)->name(), "value"); + ASSERT_EQ(inner->field(0)->type()->id(), Type::BINARY); + ASSERT_TRUE(inner->field(0)->nullable()); + ASSERT_EQ(inner->field(1)->name(), "typed_value"); + ASSERT_EQ(inner->field(1)->type()->id(), Type::STRING); + ASSERT_TRUE(inner->field(1)->nullable()); +} + +TEST_F(VariantShreddingSchemaTest, ArraySchema) { + auto schema = VariantShreddingSchema::Array(VariantShreddingSchema::Primitive(int32())); + ASSERT_EQ(schema.kind(), VariantShreddingSchema::Kind::kArray); + ASSERT_EQ(schema.element_schema().kind(), VariantShreddingSchema::Kind::kPrimitive); + ASSERT_EQ(schema.element_schema().type()->id(), Type::INT32); +} + +TEST_F(VariantShreddingSchemaTest, ArrayToArrowType) { + auto schema = + VariantShreddingSchema::Array(VariantShreddingSchema::Primitive(float64())); + auto arrow_type = schema.ToArrowType(); + ASSERT_EQ(arrow_type->id(), Type::LIST); + auto list_type = std::static_pointer_cast(arrow_type); + auto elem_type = list_type->value_type(); + ASSERT_EQ(elem_type->id(), Type::STRUCT); +} + +TEST_F(VariantShreddingSchemaTest, NestedObjectSchema) { + // Schema for shredding: {"location": {"lat": double, "lon": double}} + auto schema = VariantShreddingSchema::Object({ + {"location", VariantShreddingSchema::Object({ + {"lat", VariantShreddingSchema::Primitive(float64())}, + {"lon", VariantShreddingSchema::Primitive(float64())}, + })}, + }); + ASSERT_EQ(schema.kind(), VariantShreddingSchema::Kind::kObject); + ASSERT_EQ(schema.fields().size(), 1); + ASSERT_EQ(schema.fields()[0].first, "location"); + ASSERT_EQ(schema.fields()[0].second.kind(), VariantShreddingSchema::Kind::kObject); +} + +// =========================================================================== +// Type compatibility tests +// =========================================================================== + +class VariantTypeCompatibilityTest : public ::testing::Test { + protected: + // Helper to build a variant value and check compatibility + bool CheckCompatibility(std::function build_fn, + const DataType& target) { + VariantBuilder b; + build_fn(b).ok(); + auto result = b.Finish().ValueOrDie(); + return IsVariantCompatibleWithType(result.value.data(), + static_cast(result.value.size()), target); + } +}; + +TEST_F(VariantTypeCompatibilityTest, Int8CompatibleWithInt8) { + ASSERT_TRUE(CheckCompatibility([](VariantBuilder& b) { return b.Int8(42); }, *int8())); +} + +TEST_F(VariantTypeCompatibilityTest, Int8CompatibleWithInt64) { + // Int8 can be widened to Int64 + ASSERT_TRUE(CheckCompatibility([](VariantBuilder& b) { return b.Int8(42); }, *int64())); +} + +TEST_F(VariantTypeCompatibilityTest, Int64NotCompatibleWithInt32) { + // Int64 cannot be narrowed to Int32 + ASSERT_FALSE(CheckCompatibility([](VariantBuilder& b) { return b.Int64(5000000000LL); }, + *int32())); +} + +TEST_F(VariantTypeCompatibilityTest, StringCompatibleWithUtf8) { + ASSERT_TRUE( + CheckCompatibility([](VariantBuilder& b) { return b.String("hello"); }, *utf8())); +} + +TEST_F(VariantTypeCompatibilityTest, ShortStringCompatibleWithUtf8) { + ASSERT_TRUE( + CheckCompatibility([](VariantBuilder& b) { return b.String("hi"); }, *utf8())); +} + +TEST_F(VariantTypeCompatibilityTest, BoolCompatibleWithBool) { + ASSERT_TRUE( + CheckCompatibility([](VariantBuilder& b) { return b.Bool(true); }, *boolean())); +} + +TEST_F(VariantTypeCompatibilityTest, BoolNotCompatibleWithInt32) { + ASSERT_FALSE( + CheckCompatibility([](VariantBuilder& b) { return b.Bool(true); }, *int32())); +} + +TEST_F(VariantTypeCompatibilityTest, DoubleCompatibleWithFloat64) { + ASSERT_TRUE( + CheckCompatibility([](VariantBuilder& b) { return b.Double(3.14); }, *float64())); +} + +TEST_F(VariantTypeCompatibilityTest, FloatCompatibleWithFloat64ViaWidening) { + // Float can be widened to Double — value precision preserved, type tag changes + ASSERT_TRUE( + CheckCompatibility([](VariantBuilder& b) { return b.Float(3.14f); }, *float64())); +} + +TEST_F(VariantTypeCompatibilityTest, FloatCompatibleWithFloat32) { + // Float is directly compatible with its own type + ASSERT_TRUE( + CheckCompatibility([](VariantBuilder& b) { return b.Float(3.14f); }, *float32())); +} + +TEST_F(VariantTypeCompatibilityTest, DoubleNotCompatibleWithFloat32) { + // Double cannot be narrowed to Float + ASSERT_FALSE( + CheckCompatibility([](VariantBuilder& b) { return b.Double(3.14); }, *float32())); +} + +TEST_F(VariantTypeCompatibilityTest, DateCompatibleWithDate32) { + ASSERT_TRUE( + CheckCompatibility([](VariantBuilder& b) { return b.Date(19000); }, *date32())); +} + +TEST_F(VariantTypeCompatibilityTest, TimestampMicrosCompatibleWithTimestamp) { + ASSERT_TRUE(CheckCompatibility( + [](VariantBuilder& b) { return b.TimestampMicros(1654041600000000LL); }, + *timestamp(TimeUnit::MICRO, "UTC"))); +} + +TEST_F(VariantTypeCompatibilityTest, TimestampMicrosNotCompatibleWithNanos) { + // TimestampMicros should NOT be compatible with NANO resolution target + ASSERT_FALSE(CheckCompatibility( + [](VariantBuilder& b) { return b.TimestampMicros(1654041600000000LL); }, + *timestamp(TimeUnit::NANO, "UTC"))); +} + +TEST_F(VariantTypeCompatibilityTest, TimestampMicrosNotCompatibleWithNTZ) { + // TimestampMicros (with timezone) should NOT be compatible with NTZ target + ASSERT_FALSE(CheckCompatibility( + [](VariantBuilder& b) { return b.TimestampMicros(1654041600000000LL); }, + *timestamp(TimeUnit::MICRO))); +} + +TEST_F(VariantTypeCompatibilityTest, TimestampNanosNTZCompatibleWithNanosNTZ) { + ASSERT_TRUE(CheckCompatibility( + [](VariantBuilder& b) { return b.TimestampNanosNTZ(1654041600000000000LL); }, + *timestamp(TimeUnit::NANO))); +} + +TEST_F(VariantTypeCompatibilityTest, DecimalScaleMismatchNotCompatible) { + // Decimal4 with scale=3 should NOT be compatible with Decimal128(10,2) + ASSERT_FALSE(CheckCompatibility( + [](VariantBuilder& b) { + uint8_t bytes[4] = {0x39, 0x30, 0x00, 0x00}; + return b.Decimal4(3, bytes); + }, + *decimal128(10, 2))); +} + +TEST_F(VariantTypeCompatibilityTest, DecimalScaleMatchCompatible) { + // Decimal4 with scale=2 should be compatible with Decimal128(10,2) + ASSERT_TRUE(CheckCompatibility( + [](VariantBuilder& b) { + uint8_t bytes[4] = {0x39, 0x30, 0x00, 0x00}; + return b.Decimal4(2, bytes); + }, + *decimal128(10, 2))); +} + +TEST_F(VariantTypeCompatibilityTest, NullNotCompatibleWithTypedColumns) { + // Per Rust/spec: Variant::Null is NOT compatible with any typed column. + // It remains in the value column to distinguish "variant null" from "missing". + ASSERT_FALSE(CheckCompatibility([](VariantBuilder& b) { return b.Null(); }, *int64())); + ASSERT_FALSE(CheckCompatibility([](VariantBuilder& b) { return b.Null(); }, *utf8())); + ASSERT_FALSE( + CheckCompatibility([](VariantBuilder& b) { return b.Null(); }, *boolean())); +} + +TEST_F(VariantTypeCompatibilityTest, StringNotCompatibleWithInt64) { + ASSERT_FALSE( + CheckCompatibility([](VariantBuilder& b) { return b.String("hello"); }, *int64())); +} + +TEST_F(VariantTypeCompatibilityTest, UUIDCompatibleWithFixedSizeBinary16) { + uint8_t uuid[16] = {}; + ASSERT_TRUE(CheckCompatibility([&uuid](VariantBuilder& b) { return b.UUID(uuid); }, + *fixed_size_binary(16))); +} + +TEST_F(VariantTypeCompatibilityTest, UUIDNotCompatibleWithFixedSizeBinary32) { + uint8_t uuid[16] = {}; + ASSERT_FALSE(CheckCompatibility([&uuid](VariantBuilder& b) { return b.UUID(uuid); }, + *fixed_size_binary(32))); +} + +TEST_F(VariantTypeCompatibilityTest, Time64MicroCompatibleWithTime64Micro) { + ASSERT_TRUE(CheckCompatibility([](VariantBuilder& b) { return b.TimeNTZ(1234567); }, + *time64(TimeUnit::MICRO))); +} + +TEST_F(VariantTypeCompatibilityTest, Time64NanoNotCompatibleWithTime64Micro) { + // The variant spec's kTimeNTZ stores microseconds. A time64(NANO) target + // would cause misinterpretation of the values in the typed_value column. + ASSERT_FALSE(CheckCompatibility([](VariantBuilder& b) { return b.TimeNTZ(1234567); }, + *time64(TimeUnit::NANO))); +} + +// =========================================================================== +// ShredVariantColumn / ReconstructVariantColumn round-trip tests +// =========================================================================== + +class VariantShredRoundTripTest : public ::testing::Test { + protected: + // Helper: build a binary array of encoded variant values. + // Note: Uses .ok()/.ValueOrDie() because ASSERT_OK_AND_ASSIGN cannot be used + // in non-void functions. Test-only; will crash with a descriptive message + // on failure rather than producing a clean test failure. + std::shared_ptr BuildVariantColumn( + const std::vector>& builders) { + // Use a single shared builder to produce consistent metadata + BinaryBuilder array_builder; + for (const auto& build_fn : builders) { + VariantBuilder vb; + build_fn(vb).ok(); + auto encoded = vb.Finish().ValueOrDie(); + array_builder + .Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr result; + array_builder.Finish(&result).ok(); + return result; + } + + // Helper: build a metadata array (all same metadata) + std::shared_ptr BuildMetadataColumn(int64_t num_rows) { + // Empty metadata (no dictionary keys needed for primitives) + VariantBuilder vb; + vb.Null().ok(); + auto encoded = vb.Finish().ValueOrDie(); + + BinaryBuilder builder; + for (int64_t i = 0; i < num_rows; ++i) { + builder + .Append(encoded.metadata.data(), static_cast(encoded.metadata.size())) + .ok(); + } + std::shared_ptr result; + builder.Finish(&result).ok(); + return result; + } +}; + +TEST_F(VariantShredRoundTripTest, PrimitiveInt64AllMatch) { + // All values are integers — should all go to typed_value + auto values = BuildVariantColumn({ + [](VariantBuilder& b) { return b.Int(42); }, + [](VariantBuilder& b) { return b.Int(100); }, + [](VariantBuilder& b) { return b.Int(-7); }, + }); + auto metadata = BuildMetadataColumn(3); + auto schema = VariantShreddingSchema::Primitive(int64()); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + // All values should be in typed_value (value column all null) + auto value_col = shredded->field(1); // "value" + ASSERT_EQ(value_col->null_count(), 3); // all null + + auto typed_col = shredded->field(2); // "typed_value" + ASSERT_EQ(typed_col->type_id(), Type::INT64); + ASSERT_EQ(typed_col->null_count(), 0); // all present + + // Verify native values + auto& int_arr = static_cast(*typed_col); + ASSERT_EQ(int_arr.Value(0), 42); + ASSERT_EQ(int_arr.Value(1), 100); + ASSERT_EQ(int_arr.Value(2), -7); + + // Reconstruct and verify round-trip + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 3); + // Verify the reconstructed variant bytes are valid + for (int64_t i = 0; i < 3; ++i) { + ASSERT_TRUE(reconstructed->IsValid(i)); + } +} + +TEST_F(VariantShredRoundTripTest, PrimitiveMixedTypes) { + // Mix of matching and non-matching types + auto values = BuildVariantColumn({ + [](VariantBuilder& b) { return b.Int(42); }, // matches int64 + [](VariantBuilder& b) { return b.String("hi"); }, // doesn't match + [](VariantBuilder& b) { return b.Int(99); }, // matches + [](VariantBuilder& b) { return b.Bool(true); }, // doesn't match + }); + auto metadata = BuildMetadataColumn(4); + auto schema = VariantShreddingSchema::Primitive(int64()); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // Rows 0,2 → typed_value; Rows 1,3 → value + ASSERT_TRUE(value_col->IsNull(0)); + ASSERT_TRUE(value_col->IsValid(1)); + ASSERT_TRUE(value_col->IsNull(2)); + ASSERT_TRUE(value_col->IsValid(3)); + + ASSERT_TRUE(typed_col->IsValid(0)); + ASSERT_TRUE(typed_col->IsNull(1)); + ASSERT_TRUE(typed_col->IsValid(2)); + ASSERT_TRUE(typed_col->IsNull(3)); + + // Verify native int values + auto& int_arr = static_cast(*typed_col); + ASSERT_EQ(int_arr.Value(0), 42); + ASSERT_EQ(int_arr.Value(2), 99); + + // Round-trip + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 4); +} + +TEST_F(VariantShredRoundTripTest, PrimitiveAllMismatch) { + // No values match the target type — all stay in value column + auto values = BuildVariantColumn({ + [](VariantBuilder& b) { return b.String("a"); }, + [](VariantBuilder& b) { return b.String("b"); }, + }); + auto metadata = BuildMetadataColumn(2); + auto schema = VariantShreddingSchema::Primitive(int64()); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // All in value, none in typed_value + ASSERT_EQ(value_col->null_count(), 0); + ASSERT_EQ(typed_col->null_count(), 2); + + // Round-trip: since typed_value is all null, everything comes from value + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + for (int64_t i = 0; i < 2; ++i) { + ASSERT_EQ(GetBinaryView(*values, i), GetBinaryView(*reconstructed, i)); + } +} + +TEST_F(VariantShredRoundTripTest, NullVariantIsRouted) { + // Per Rust/spec: Variant::Null is NOT shredded into typed columns. + // It goes to the value column as raw bytes (0x00). This distinguishes + // "variant-typed null" from "SQL NULL / missing" (both columns null). + auto values = BuildVariantColumn({ + [](VariantBuilder& b) { return b.Null(); }, + [](VariantBuilder& b) { return b.Int(5); }, + }); + auto metadata = BuildMetadataColumn(2); + auto schema = VariantShreddingSchema::Primitive(int64()); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + // Variant Null → value column has content (0x00 byte), typed_value is null + ASSERT_TRUE(value_col->IsValid(0)); + ASSERT_TRUE(typed_col->IsNull(0)); + // Int(5) → typed_value is present, value is null + ASSERT_TRUE(value_col->IsNull(1)); + ASSERT_TRUE(typed_col->IsValid(1)); +} + +TEST_F(VariantShredRoundTripTest, ZeroRowInput) { + // Empty arrays (zero rows) should be handled gracefully without errors. + auto values = BuildVariantColumn({}); + auto metadata = BuildMetadataColumn(0); + + // Primitive schema with zero rows + auto prim_schema = VariantShreddingSchema::Primitive(int64()); + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, prim_schema)); + ASSERT_EQ(shredded->length(), 0); + ASSERT_EQ(shredded->num_fields(), 3); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + ASSERT_EQ(value_col->length(), 0); + ASSERT_EQ(typed_col->length(), 0); + + // Round-trip on empty arrays + ASSERT_OK_AND_ASSIGN( + auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, prim_schema)); + ASSERT_EQ(reconstructed->length(), 0); + + // Object schema with zero rows + auto obj_schema = VariantShreddingSchema::Object({ + {"a", VariantShreddingSchema::Primitive(int64())}, + }); + ASSERT_OK_AND_ASSIGN(auto obj_shredded, + ShredVariantColumn(metadata, values, obj_schema)); + ASSERT_EQ(obj_shredded->length(), 0); + + // Array schema with zero rows + auto arr_schema = + VariantShreddingSchema::Array(VariantShreddingSchema::Primitive(int64())); + ASSERT_OK_AND_ASSIGN(auto arr_shredded, + ShredVariantColumn(metadata, values, arr_schema)); + ASSERT_EQ(arr_shredded->length(), 0); +} + +// =========================================================================== +// Object shredding round-trip tests +// =========================================================================== + +class VariantShredObjectTest : public ::testing::Test { + protected: + // Helper: build a variant column with object values + struct ObjectRow { + std::vector>> fields; + }; + + std::shared_ptr BuildObjectColumn(const std::vector& rows) { + BinaryBuilder array_builder; + for (const auto& row : rows) { + VariantBuilder vb; + auto start = vb.Offset(); + std::vector fields; + for (const auto& [key, value_fn] : row.fields) { + fields.push_back(vb.NextField(start, key)); + value_fn(vb).ok(); + } + vb.FinishObject(start, fields).ok(); + auto encoded = vb.Finish().ValueOrDie(); + array_builder + .Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr result; + array_builder.Finish(&result).ok(); + return result; + } + + std::shared_ptr BuildMetadataForObjects(const std::vector& rows) { + BinaryBuilder builder; + for (const auto& row : rows) { + VariantBuilder vb; + auto start = vb.Offset(); + std::vector fields; + for (const auto& [key, value_fn] : row.fields) { + fields.push_back(vb.NextField(start, key)); + value_fn(vb).ok(); + } + vb.FinishObject(start, fields).ok(); + auto encoded = vb.Finish().ValueOrDie(); + builder + .Append(encoded.metadata.data(), static_cast(encoded.metadata.size())) + .ok(); + } + std::shared_ptr result; + builder.Finish(&result).ok(); + return result; + } +}; + +TEST_F(VariantShredObjectTest, FullyShredded) { + // Object {"name": "Alice", "age": 30} shredded with schema {name, age} + std::vector rows = { + {{{"name", [](VariantBuilder& b) { return b.String("Alice"); }}, + {"age", [](VariantBuilder& b) { return b.Int(30); }}}}, + {{{"name", [](VariantBuilder& b) { return b.String("Bob"); }}, + {"age", [](VariantBuilder& b) { return b.Int(25); }}}}, + }; + + auto values = BuildObjectColumn(rows); + auto metadata = BuildMetadataForObjects(rows); + + auto schema = VariantShreddingSchema::Object({ + {"name", VariantShreddingSchema::Primitive(utf8())}, + {"age", VariantShreddingSchema::Primitive(int64())}, + }); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + // When fully shredded, residual value should be all null + auto value_col = shredded->field(1); + ASSERT_EQ(value_col->null_count(), 2); + + // typed_value is a struct with "name" and "age" fields + auto typed_col = shredded->field(2); + ASSERT_EQ(typed_col->type_id(), Type::STRUCT); + + // Verify native extraction: each field's struct should have typed_value populated + auto& typed_struct = static_cast(*typed_col); + + // "name" field: sub-struct {value: binary, typed_value: string} + auto name_field_struct = static_cast(typed_struct.field(0).get()); + auto name_value_col = name_field_struct->field(0); // fallback value + auto name_typed_col = name_field_struct->field(1); // native typed_value + + // Both names should be in typed_value (string compatible with utf8 target) + ASSERT_EQ(name_value_col->null_count(), 2); // no fallback needed + ASSERT_EQ(name_typed_col->null_count(), 0); // both present + auto& name_arr = static_cast(*name_typed_col); + ASSERT_EQ(name_arr.GetView(0), "Alice"); + ASSERT_EQ(name_arr.GetView(1), "Bob"); + + // "age" field: sub-struct {value: binary, typed_value: int64} + auto age_field_struct = static_cast(typed_struct.field(1).get()); + auto age_value_col = age_field_struct->field(0); + auto age_typed_col = age_field_struct->field(1); + + ASSERT_EQ(age_value_col->null_count(), 2); // no fallback needed + ASSERT_EQ(age_typed_col->null_count(), 0); // both present + auto& age_arr = static_cast(*age_typed_col); + ASSERT_EQ(age_arr.Value(0), 30); + ASSERT_EQ(age_arr.Value(1), 25); + + // Reconstruct and verify round-trip + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 2); + // Verify the reconstructed values decode correctly + ASSERT_TRUE(reconstructed->IsValid(0)); + ASSERT_TRUE(reconstructed->IsValid(1)); +} + +TEST_F(VariantShredObjectTest, MissingFieldNativeExtraction) { + // Object with missing field — verifies native extraction handles absent fields. + // Row 0: {name: "Alice", age: 30} — both fields present + // Row 1: {name: "Bob"} — "age" field missing + // Row 2: {age: 42} — "name" field missing + std::vector rows = { + {{{"name", [](VariantBuilder& b) { return b.String("Alice"); }}, + {"age", [](VariantBuilder& b) { return b.Int(30); }}}}, + {{{"name", [](VariantBuilder& b) { return b.String("Bob"); }}}}, + {{{"age", [](VariantBuilder& b) { return b.Int(42); }}}}, + }; + + auto values = BuildObjectColumn(rows); + auto metadata = BuildMetadataForObjects(rows); + + auto schema = VariantShreddingSchema::Object({ + {"name", VariantShreddingSchema::Primitive(utf8())}, + {"age", VariantShreddingSchema::Primitive(int64())}, + }); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + auto& typed_struct = static_cast(*typed_col); + + // "name" field: rows 0,1 present, row 2 absent + auto name_struct = static_cast(typed_struct.field(0).get()); + auto name_typed = name_struct->field(1); + ASSERT_TRUE(name_typed->IsValid(0)); // "Alice" + ASSERT_TRUE(name_typed->IsValid(1)); // "Bob" + ASSERT_TRUE(name_typed->IsNull(2)); // missing + auto& name_arr = static_cast(*name_typed); + ASSERT_EQ(name_arr.GetView(0), "Alice"); + ASSERT_EQ(name_arr.GetView(1), "Bob"); + + // "age" field: rows 0,2 present, row 1 absent + auto age_struct = static_cast(typed_struct.field(1).get()); + auto age_typed = age_struct->field(1); + ASSERT_TRUE(age_typed->IsValid(0)); // 30 + ASSERT_TRUE(age_typed->IsNull(1)); // missing + ASSERT_TRUE(age_typed->IsValid(2)); // 42 + auto& age_arr = static_cast(*age_typed); + ASSERT_EQ(age_arr.Value(0), 30); + ASSERT_EQ(age_arr.Value(2), 42); + + // Reconstruct and verify round-trip + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 3); + ASSERT_TRUE(reconstructed->IsValid(0)); + ASSERT_TRUE(reconstructed->IsValid(1)); + ASSERT_TRUE(reconstructed->IsValid(2)); +} + +TEST_F(VariantShredObjectTest, PartiallyShredded) { + // Object {"name": "Alice", "age": 30, "score": 95.5} + // Schema only shreds "name" — "age" and "score" go to residual + std::vector rows = { + {{{"name", [](VariantBuilder& b) { return b.String("Alice"); }}, + {"age", [](VariantBuilder& b) { return b.Int(30); }}, + {"score", [](VariantBuilder& b) { return b.Double(95.5); }}}}, + }; + + auto values = BuildObjectColumn(rows); + auto metadata = BuildMetadataForObjects(rows); + + auto schema = VariantShreddingSchema::Object({ + {"name", VariantShreddingSchema::Primitive(utf8())}, + }); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + // Residual should have content (the "age" and "score" fields) + auto value_col = shredded->field(1); + ASSERT_TRUE(value_col->IsValid(0)); // has residual + + // Reconstruct + auto typed_col = shredded->field(2); + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 1); + ASSERT_TRUE(reconstructed->IsValid(0)); +} + +TEST_F(VariantShredObjectTest, NonObjectFallback) { + // A string value (not an object) with an object schema → goes to residual + BinaryBuilder array_builder; + { + VariantBuilder vb; + vb.String("not an object").ok(); + auto encoded = vb.Finish().ValueOrDie(); + array_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + array_builder.Finish(&values).ok(); + + BinaryBuilder meta_builder; + { + VariantBuilder vb; + vb.String("x").ok(); + auto encoded = vb.Finish().ValueOrDie(); + meta_builder + .Append(encoded.metadata.data(), static_cast(encoded.metadata.size())) + .ok(); + } + std::shared_ptr metadata; + meta_builder.Finish(&metadata).ok(); + + auto schema = VariantShreddingSchema::Object({ + {"name", VariantShreddingSchema::Primitive(utf8())}, + }); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + // Non-object → residual has the value, typed fields are null + auto value_col = shredded->field(1); + ASSERT_TRUE(value_col->IsValid(0)); +} + +// =========================================================================== +// Array shredding round-trip tests +// =========================================================================== + +class VariantShredArrayTest : public ::testing::Test { + protected: + std::shared_ptr BuildMetadata(int64_t num_rows) { + VariantBuilder vb; + vb.Null().ok(); + auto encoded = vb.Finish().ValueOrDie(); + BinaryBuilder builder; + for (int64_t i = 0; i < num_rows; ++i) { + builder + .Append(encoded.metadata.data(), static_cast(encoded.metadata.size())) + .ok(); + } + std::shared_ptr result; + builder.Finish(&result).ok(); + return result; + } +}; + +TEST_F(VariantShredArrayTest, SimpleArrayShred) { + // Build a variant array value: [1, 2, 3] + BinaryBuilder value_builder; + { + VariantBuilder vb; + auto start = vb.Offset(); + std::vector offsets; + offsets.push_back(vb.NextElement(start)); + vb.Int(1).ok(); + offsets.push_back(vb.NextElement(start)); + vb.Int(2).ok(); + offsets.push_back(vb.NextElement(start)); + vb.Int(3).ok(); + vb.FinishArray(start, offsets).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadata(1); + + auto schema = VariantShreddingSchema::Array(VariantShreddingSchema::Primitive(int64())); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + // Array goes to typed_value (a list), value is null + auto value_col = shredded->field(1); + ASSERT_TRUE(value_col->IsNull(0)); + + auto typed_col = shredded->field(2); + ASSERT_EQ(typed_col->type_id(), Type::LIST); + auto& list = static_cast(*typed_col); + ASSERT_EQ(list.value_length(0), 3); // 3 elements + + // With recursive element shredding, elements are struct{value, typed_value} + // where compatible int values go to typed_value and value is null. + auto elem_type = list.value_type(); + ASSERT_EQ(elem_type->id(), Type::STRUCT); + auto* elem_struct_type = static_cast(elem_type.get()); + ASSERT_EQ(elem_struct_type->num_fields(), 2); + ASSERT_EQ(elem_struct_type->field(0)->name(), "value"); + ASSERT_EQ(elem_struct_type->field(1)->name(), "typed_value"); + + // Verify native int64 values in the element typed_value column + auto* elem_struct = static_cast(list.values().get()); + auto elem_value_col = elem_struct->field(0); // per-element residual + auto elem_typed_col = elem_struct->field(1); // per-element typed_value + // All 3 ints should be in typed_value + ASSERT_EQ(elem_value_col->null_count(), 3); + ASSERT_EQ(elem_typed_col->null_count(), 0); + auto& elem_int_arr = static_cast(*elem_typed_col); + ASSERT_EQ(elem_int_arr.Value(0), 1); + ASSERT_EQ(elem_int_arr.Value(1), 2); + ASSERT_EQ(elem_int_arr.Value(2), 3); + + // Round-trip reconstruction + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 1); + ASSERT_TRUE(reconstructed->IsValid(0)); +} + +TEST_F(VariantShredArrayTest, ArrayShredMixedElements) { + // Build a variant array value: [1, "hello", 3] — mixed types with int64 schema. + // Int elements should go to typed_value, string element to per-element value. + BinaryBuilder value_builder; + { + VariantBuilder vb; + auto start = vb.Offset(); + std::vector offsets; + offsets.push_back(vb.NextElement(start)); + vb.Int(1).ok(); + offsets.push_back(vb.NextElement(start)); + vb.String("hello").ok(); + offsets.push_back(vb.NextElement(start)); + vb.Int(3).ok(); + vb.FinishArray(start, offsets).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadata(1); + + auto schema = VariantShreddingSchema::Array(VariantShreddingSchema::Primitive(int64())); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // The array is shredded — residual value is null, typed_value has the list + ASSERT_TRUE(value_col->IsNull(0)); + ASSERT_EQ(typed_col->type_id(), Type::LIST); + + auto& list = static_cast(*typed_col); + auto* elem_struct = static_cast(list.values().get()); + auto elem_value_col = elem_struct->field(0); + auto elem_typed_col = elem_struct->field(1); + + // Element 0 (Int(1)) → typed_value=1, value=null + ASSERT_TRUE(elem_value_col->IsNull(0)); + ASSERT_TRUE(elem_typed_col->IsValid(0)); + auto& elem_ints = static_cast(*elem_typed_col); + ASSERT_EQ(elem_ints.Value(0), 1); + + // Element 1 (String("hello")) → typed_value=null, value=variant bytes + ASSERT_TRUE(elem_value_col->IsValid(1)); + ASSERT_TRUE(elem_typed_col->IsNull(1)); + + // Element 2 (Int(3)) → typed_value=3, value=null + ASSERT_TRUE(elem_value_col->IsNull(2)); + ASSERT_TRUE(elem_typed_col->IsValid(2)); + ASSERT_EQ(elem_ints.Value(2), 3); + + // Round-trip reconstruction + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 1); + ASSERT_TRUE(reconstructed->IsValid(0)); + + // Verify the reconstructed value decodes to an array with the expected elements + auto recon_bytes = GetBinaryView(*reconstructed, 0); + auto* data = reinterpret_cast(recon_bytes.data()); + auto len = static_cast(recon_bytes.size()); + ASSERT_GE(len, 1); + ASSERT_EQ(GetBasicType(data[0]), BasicType::kArray); + ASSERT_OK_AND_ASSIGN(auto elem_count, GetArrayElementCount(data, len)); + ASSERT_EQ(elem_count, 3); +} + +TEST_F(VariantShredArrayTest, NonArrayFallback) { + // A string value with an array schema → goes to residual + BinaryBuilder value_builder; + { + VariantBuilder vb; + vb.String("not an array").ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadata(1); + + auto schema = VariantShreddingSchema::Array(VariantShreddingSchema::Primitive(int64())); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + // Non-array → residual has the value, typed_value is null + auto value_col = shredded->field(1); + ASSERT_TRUE(value_col->IsValid(0)); + + auto typed_col = shredded->field(2); + ASSERT_TRUE(typed_col->IsNull(0)); +} + +// =========================================================================== +// Additional round-trip tests for specific types (Decimal128, UUID, Timestamps) +// =========================================================================== + +class VariantShredTypedRoundTripTest : public ::testing::Test { + protected: + std::shared_ptr BuildMetadataColumn(int64_t num_rows) { + VariantBuilder vb; + vb.Null().ok(); + auto encoded = vb.Finish().ValueOrDie(); + BinaryBuilder builder; + for (int64_t i = 0; i < num_rows; ++i) { + builder + .Append(encoded.metadata.data(), static_cast(encoded.metadata.size())) + .ok(); + } + std::shared_ptr result; + builder.Finish(&result).ok(); + return result; + } +}; + +TEST_F(VariantShredTypedRoundTripTest, Decimal128RoundTrip) { + // Build variant column with decimal values (scale=2) + BinaryBuilder value_builder; + { + VariantBuilder vb; + // Encode 123.45 as decimal4 with scale 2 → unscaled 12345 + uint8_t scale = 2; + int32_t unscaled = 12345; + uint8_t bytes[4]; + std::memcpy(bytes, &unscaled, 4); + vb.Decimal4(scale, bytes).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + { + VariantBuilder vb; + // Encode -999.99 as decimal4 with scale 2 → unscaled -99999 + uint8_t scale = 2; + int32_t unscaled = -99999; + uint8_t bytes[4]; + std::memcpy(bytes, &unscaled, 4); + vb.Decimal4(scale, bytes).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadataColumn(2); + + // Schema: Decimal128(10, 2) + auto schema = VariantShreddingSchema::Primitive(decimal128(10, 2)); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // Both values should match (scale=2 matches schema) + ASSERT_EQ(value_col->null_count(), 2); // all in typed + ASSERT_EQ(typed_col->null_count(), 0); // all present + + // Verify native decimal values + auto& dec_arr = static_cast(*typed_col); + Decimal128 val0(dec_arr.GetValue(0)); + Decimal128 val1(dec_arr.GetValue(1)); + ASSERT_EQ(val0, Decimal128(12345)); + ASSERT_EQ(val1, Decimal128(-99999)); + + // Round-trip + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 2); + ASSERT_TRUE(reconstructed->IsValid(0)); + ASSERT_TRUE(reconstructed->IsValid(1)); +} + +TEST_F(VariantShredTypedRoundTripTest, Decimal128ScaleMismatch) { + // Build variant column with decimal scale=3, but schema wants scale=2 + // Should NOT be shredded (scale mismatch → goes to residual) + BinaryBuilder value_builder; + { + VariantBuilder vb; + uint8_t scale = 3; + int32_t unscaled = 12345; + uint8_t bytes[4]; + std::memcpy(bytes, &unscaled, 4); + vb.Decimal4(scale, bytes).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadataColumn(1); + + auto schema = VariantShreddingSchema::Primitive(decimal128(10, 2)); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // Scale mismatch → value has content, typed is null + ASSERT_TRUE(value_col->IsValid(0)); + ASSERT_TRUE(typed_col->IsNull(0)); +} + +TEST_F(VariantShredTypedRoundTripTest, UUIDRoundTrip) { + // Build variant column with UUID values + BinaryBuilder value_builder; + uint8_t uuid1[16] = {0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, + 0xFE, 0xDC, 0xBA, 0x98, 0x76, 0x54, 0x32, 0x10}; + uint8_t uuid2[16] = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}; + { + VariantBuilder vb; + vb.UUID(uuid1).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + { + VariantBuilder vb; + vb.UUID(uuid2).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadataColumn(2); + + auto schema = VariantShreddingSchema::Primitive(fixed_size_binary(16)); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // Both UUIDs should be in typed_value + ASSERT_EQ(value_col->null_count(), 2); + ASSERT_EQ(typed_col->null_count(), 0); + + // Verify bytes + auto& fsb_arr = static_cast(*typed_col); + ASSERT_EQ(std::memcmp(fsb_arr.GetValue(0), uuid1, 16), 0); + ASSERT_EQ(std::memcmp(fsb_arr.GetValue(1), uuid2, 16), 0); + + // Round-trip + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 2); + ASSERT_TRUE(reconstructed->IsValid(0)); + ASSERT_TRUE(reconstructed->IsValid(1)); +} + +TEST_F(VariantShredTypedRoundTripTest, TimestampMicrosRoundTrip) { + BinaryBuilder value_builder; + { + VariantBuilder vb; + vb.TimestampMicros(1654041600000000LL).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadataColumn(1); + + auto schema = VariantShreddingSchema::Primitive(timestamp(TimeUnit::MICRO, "UTC")); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + ASSERT_EQ(value_col->null_count(), 1); + ASSERT_EQ(typed_col->null_count(), 0); + + // Verify int64 value stored + auto& int_arr = static_cast(*typed_col); + ASSERT_EQ(int_arr.Value(0), 1654041600000000LL); + + // Round-trip: should reconstruct as TimestampMicros (not NTZ or Nanos) + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 1); + ASSERT_TRUE(reconstructed->IsValid(0)); + + // Verify the variant bytes encode as TimestampMicros (header byte check) + auto recon_bytes = GetBinaryView(*reconstructed, 0); + ASSERT_GE(recon_bytes.size(), 1); + uint8_t header = static_cast(recon_bytes[0]); + // PrimitiveType::kTimestampMicros = 12, encoded as (12 << 2) | 0 = 0x30 + ASSERT_EQ(header, 0x30); +} + +TEST_F(VariantShredTypedRoundTripTest, TimestampNanosRoundTrip) { + BinaryBuilder value_builder; + { + VariantBuilder vb; + vb.TimestampNanos(1654041600000000000LL).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadataColumn(1); + + // Schema says NANO resolution with timezone + auto schema = VariantShreddingSchema::Primitive(timestamp(TimeUnit::NANO, "UTC")); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + ASSERT_EQ(value_col->null_count(), 1); + ASSERT_EQ(typed_col->null_count(), 0); + + // Round-trip: should reconstruct as TimestampNanos (not Micros) + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 1); + ASSERT_TRUE(reconstructed->IsValid(0)); + + // Verify the variant bytes encode as TimestampNanos + auto recon_bytes = GetBinaryView(*reconstructed, 0); + ASSERT_GE(recon_bytes.size(), 1); + uint8_t header = static_cast(recon_bytes[0]); + // PrimitiveType::kTimestampNanos = 18, encoded as (18 << 2) | 0 = 0x48 + ASSERT_EQ(header, 0x48); +} + +TEST_F(VariantShredTypedRoundTripTest, TimestampMicrosNTZRoundTrip) { + BinaryBuilder value_builder; + { + VariantBuilder vb; + vb.TimestampMicrosNTZ(1654041600000000LL).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadataColumn(1); + + // No timezone → NTZ + auto schema = VariantShreddingSchema::Primitive(timestamp(TimeUnit::MICRO)); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + ASSERT_EQ(value_col->null_count(), 1); + ASSERT_EQ(typed_col->null_count(), 0); + + // Round-trip: should reconstruct as TimestampMicrosNTZ + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 1); + ASSERT_TRUE(reconstructed->IsValid(0)); + + // Verify: PrimitiveType::kTimestampMicrosNTZ = 13, encoded as (13 << 2) | 0 = 0x34 + auto recon_bytes = GetBinaryView(*reconstructed, 0); + ASSERT_GE(recon_bytes.size(), 1); + uint8_t header = static_cast(recon_bytes[0]); + ASSERT_EQ(header, 0x34); +} + +TEST_F(VariantShredTypedRoundTripTest, TimestampNanosNTZRoundTrip) { + BinaryBuilder value_builder; + { + VariantBuilder vb; + vb.TimestampNanosNTZ(1654041600000000000LL).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadataColumn(1); + + // Nano resolution, no timezone → NanosNTZ + auto schema = VariantShreddingSchema::Primitive(timestamp(TimeUnit::NANO)); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + ASSERT_EQ(value_col->null_count(), 1); + ASSERT_EQ(typed_col->null_count(), 0); + + // Round-trip: should reconstruct as TimestampNanosNTZ + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 1); + ASSERT_TRUE(reconstructed->IsValid(0)); + + // Verify: PrimitiveType::kTimestampNanosNTZ = 19, encoded as (19 << 2) | 0 = 0x4C + auto recon_bytes = GetBinaryView(*reconstructed, 0); + ASSERT_GE(recon_bytes.size(), 1); + uint8_t header = static_cast(recon_bytes[0]); + ASSERT_EQ(header, 0x4C); +} + +TEST_F(VariantShredTypedRoundTripTest, FloatWidenedToDoubleRoundTrip) { + // Float variant shredded into a Double column — exercises Float→Double widening. + // The value precision is preserved (float→double is lossless for the numeric value), + // but the variant type tag changes: Float→Double on reconstruction. + BinaryBuilder value_builder; + { + VariantBuilder vb; + vb.Float(3.14f).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadataColumn(1); + + // Target is Double — Float should be compatible (widening) + auto schema = VariantShreddingSchema::Primitive(float64()); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // Float should be shredded into the Double typed_value column + ASSERT_EQ(value_col->null_count(), 1); // value is null (shredded) + ASSERT_EQ(typed_col->null_count(), 0); // typed_value is present + + // Verify the stored double value matches the original float value + auto& dbl_arr = static_cast(*typed_col); + ASSERT_DOUBLE_EQ(dbl_arr.Value(0), static_cast(3.14f)); + + // Round-trip: reconstructed variant will be Double, not Float + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 1); + ASSERT_TRUE(reconstructed->IsValid(0)); + + // Verify: PrimitiveType::kDouble = 7, encoded as (7 << 2) | 0 = 0x1C + auto recon_bytes = GetBinaryView(*reconstructed, 0); + ASSERT_GE(recon_bytes.size(), 1); + uint8_t header = static_cast(recon_bytes[0]); + ASSERT_EQ(header, 0x1C); // Double, not Float (0x38) +} + +TEST_F(VariantShredTypedRoundTripTest, Int8ShredTargetRoundTrip) { + // Int8 variant shredded into an Int8 column + BinaryBuilder value_builder; + { + VariantBuilder vb; + vb.Int8(42).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + { + // Int16 variant should NOT match Int8 target (no narrowing) + VariantBuilder vb; + vb.Int16(300).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadataColumn(2); + + auto schema = VariantShreddingSchema::Primitive(int8()); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // Row 0 (Int8(42)) → typed_value; Row 1 (Int16(300)) → value (no narrowing) + ASSERT_TRUE(value_col->IsNull(0)); + ASSERT_TRUE(typed_col->IsValid(0)); + ASSERT_TRUE(value_col->IsValid(1)); + ASSERT_TRUE(typed_col->IsNull(1)); + + // Verify native value + auto& int_arr = static_cast(*typed_col); + ASSERT_EQ(int_arr.Value(0), 42); + + // Round-trip + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 2); + ASSERT_TRUE(reconstructed->IsValid(0)); + ASSERT_TRUE(reconstructed->IsValid(1)); +} + +TEST_F(VariantShredTypedRoundTripTest, Int16ShredTargetRoundTrip) { + // Int8 and Int16 variants shredded into an Int16 column + BinaryBuilder value_builder; + { + VariantBuilder vb; + vb.Int8(42).ok(); // Int8 → compatible with Int16 + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + { + VariantBuilder vb; + vb.Int16(300).ok(); // Int16 → compatible with Int16 + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + { + VariantBuilder vb; + vb.Int32(100000).ok(); // Int32 → NOT compatible with Int16 (no narrowing) + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadataColumn(3); + + auto schema = VariantShreddingSchema::Primitive(int16()); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // Rows 0,1 → typed; Row 2 → value + ASSERT_TRUE(value_col->IsNull(0)); + ASSERT_TRUE(typed_col->IsValid(0)); + ASSERT_TRUE(value_col->IsNull(1)); + ASSERT_TRUE(typed_col->IsValid(1)); + ASSERT_TRUE(value_col->IsValid(2)); + ASSERT_TRUE(typed_col->IsNull(2)); + + auto& int_arr = static_cast(*typed_col); + ASSERT_EQ(int_arr.Value(0), 42); + ASSERT_EQ(int_arr.Value(1), 300); + + // Round-trip + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 3); +} + +TEST_F(VariantShredTypedRoundTripTest, LargeStringShredRoundTrip) { + // String variant shredded into a LARGE_STRING column + BinaryBuilder value_builder; + { + VariantBuilder vb; + vb.String("hello world").ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + { + // Non-string should NOT match LARGE_STRING target + VariantBuilder vb; + vb.Int(42).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadataColumn(2); + + auto schema = VariantShreddingSchema::Primitive(large_utf8()); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // Row 0 (String) → typed_value; Row 1 (Int) → value + ASSERT_TRUE(value_col->IsNull(0)); + ASSERT_TRUE(typed_col->IsValid(0)); + ASSERT_TRUE(value_col->IsValid(1)); + ASSERT_TRUE(typed_col->IsNull(1)); + + // Verify native value + auto& str_arr = static_cast(*typed_col); + ASSERT_EQ(str_arr.GetView(0), "hello world"); + + // Round-trip + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 2); + ASSERT_TRUE(reconstructed->IsValid(0)); + ASSERT_TRUE(reconstructed->IsValid(1)); + + // Verify reconstructed string has correct variant header (short string: length in + // header) + auto recon_bytes = GetBinaryView(*reconstructed, 0); + ASSERT_GE(recon_bytes.size(), 1); + uint8_t header = static_cast(recon_bytes[0]); + // "hello world" = 11 bytes, short string header = (11 << 2) | 1 = 0x2D + ASSERT_EQ(header, 0x2D); +} + +TEST_F(VariantShredTypedRoundTripTest, LargeBinaryShredRoundTrip) { + // Binary variant shredded into a LARGE_BINARY column + BinaryBuilder value_builder; + { + VariantBuilder vb; + vb.Binary(std::string_view("\x00\x01\x02\x03", 4)).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + { + // Non-binary should NOT match LARGE_BINARY target + VariantBuilder vb; + vb.String("not binary").ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadataColumn(2); + + auto schema = VariantShreddingSchema::Primitive(large_binary()); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // Row 0 (Binary) → typed_value; Row 1 (String) → value + ASSERT_TRUE(value_col->IsNull(0)); + ASSERT_TRUE(typed_col->IsValid(0)); + ASSERT_TRUE(value_col->IsValid(1)); + ASSERT_TRUE(typed_col->IsNull(1)); + + // Verify native value + auto& bin_arr = static_cast(*typed_col); + auto bin_view = bin_arr.GetView(0); + ASSERT_EQ(bin_view.size(), 4); + ASSERT_EQ(std::memcmp(bin_view.data(), "\x00\x01\x02\x03", 4), 0); + + // Round-trip + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 2); + ASSERT_TRUE(reconstructed->IsValid(0)); + ASSERT_TRUE(reconstructed->IsValid(1)); +} + +// =========================================================================== +// Error handling / invalid input tests +// =========================================================================== + +class VariantShredErrorTest : public ::testing::Test { + protected: + std::shared_ptr BuildMetadataColumn(int64_t num_rows) { + VariantBuilder vb; + vb.Null().ok(); + auto encoded = vb.Finish().ValueOrDie(); + BinaryBuilder builder; + for (int64_t i = 0; i < num_rows; ++i) { + builder + .Append(encoded.metadata.data(), static_cast(encoded.metadata.size())) + .ok(); + } + std::shared_ptr result; + builder.Finish(&result).ok(); + return result; + } +}; + +TEST_F(VariantShredErrorTest, ReconstructBothNonNullPrimitiveSchema) { + // Reconstruction should error if both value and typed_value are non-null + // for a primitive schema (this is an invalid shredded state). + auto metadata = BuildMetadataColumn(1); + auto schema = VariantShreddingSchema::Primitive(int64()); + + // Build a value column with a valid value + BinaryBuilder value_builder; + { + VariantBuilder vb; + vb.Int(42).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr value_col; + value_builder.Finish(&value_col).ok(); + + // Build a typed_value column that is also non-null + Int64Builder typed_builder; + typed_builder.Append(99).ok(); + std::shared_ptr typed_col; + typed_builder.Finish(&typed_col).ok(); + + // Should return Status::Invalid (both non-null is invalid for primitives) + auto result = ReconstructVariantColumn(metadata, value_col, typed_col, schema); + ASSERT_FALSE(result.ok()); + ASSERT_TRUE(result.status().IsInvalid()); +} + +TEST_F(VariantShredErrorTest, ShredUnsupportedTargetType) { + // Shredding with an unsupported target type should return NotImplemented + BinaryBuilder value_builder; + { + VariantBuilder vb; + vb.Int(42).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + auto metadata = BuildMetadataColumn(1); + + // duration() is not a valid shredding target + auto schema = VariantShreddingSchema::Primitive(duration(TimeUnit::MICRO)); + auto result = ShredVariantColumn(metadata, values, schema); + ASSERT_FALSE(result.ok()); + ASSERT_TRUE(result.status().IsNotImplemented()); +} + +TEST_F(VariantShredErrorTest, ShredInvalidMetadataArrayType) { + // metadata_array must be BINARY, LARGE_BINARY, or BINARY_VIEW + Int64Builder int_builder; + int_builder.Append(42).ok(); + std::shared_ptr bad_metadata; + int_builder.Finish(&bad_metadata).ok(); + + BinaryBuilder value_builder; + { + VariantBuilder vb; + vb.Int(1).ok(); + auto encoded = vb.Finish().ValueOrDie(); + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr values; + value_builder.Finish(&values).ok(); + + auto schema = VariantShreddingSchema::Primitive(int64()); + auto result = ShredVariantColumn(bad_metadata, values, schema); + ASSERT_FALSE(result.ok()); + ASSERT_TRUE(result.status().IsInvalid()); +} + +TEST_F(VariantShredErrorTest, ReconstructInvalidValueArrayType) { + // value_array must be BINARY, LARGE_BINARY, or BINARY_VIEW for reconstruction + auto metadata = BuildMetadataColumn(1); + + Int64Builder int_builder; + int_builder.Append(99).ok(); + std::shared_ptr bad_value; + int_builder.Finish(&bad_value).ok(); + + Int64Builder typed_builder; + typed_builder.AppendNull().ok(); + std::shared_ptr typed_col; + typed_builder.Finish(&typed_col).ok(); + + auto schema = VariantShreddingSchema::Primitive(int64()); + auto result = ReconstructVariantColumn(metadata, bad_value, typed_col, schema); + ASSERT_FALSE(result.ok()); + ASSERT_TRUE(result.status().IsInvalid()); +} + +TEST_F(VariantShredErrorTest, ReconstructArrayTypedValueNotList) { + // For array schemas, typed_value must be a LIST or LARGE_LIST + auto metadata = BuildMetadataColumn(1); + + BinaryBuilder value_builder; + value_builder.AppendNull().ok(); + std::shared_ptr value_col; + value_builder.Finish(&value_col).ok(); + + // Pass an Int64Array instead of a ListArray + Int64Builder typed_builder; + typed_builder.Append(42).ok(); + std::shared_ptr typed_col; + typed_builder.Finish(&typed_col).ok(); + + auto schema = VariantShreddingSchema::Array(VariantShreddingSchema::Primitive(int64())); + auto result = ReconstructVariantColumn(metadata, value_col, typed_col, schema); + ASSERT_FALSE(result.ok()); + ASSERT_TRUE(result.status().IsInvalid()); +} + +TEST_F(VariantShredErrorTest, ReconstructObjectTypedValueNotStruct) { + // For object schemas, typed_value must be a StructArray + auto metadata = BuildMetadataColumn(1); + + BinaryBuilder value_builder; + value_builder.AppendNull().ok(); + std::shared_ptr value_col; + value_builder.Finish(&value_col).ok(); + + // Pass an Int64Array instead of a StructArray + Int64Builder typed_builder; + typed_builder.Append(42).ok(); + std::shared_ptr typed_col; + typed_builder.Finish(&typed_col).ok(); + + auto schema = VariantShreddingSchema::Object({ + {"name", VariantShreddingSchema::Primitive(utf8())}, + }); + auto result = ReconstructVariantColumn(metadata, value_col, typed_col, schema); + ASSERT_FALSE(result.ok()); + ASSERT_TRUE(result.status().IsInvalid()); +} + +TEST_F(VariantShredErrorTest, ReconstructObjectFieldCountMismatch) { + // typed_value struct has fewer fields than the schema expects + auto metadata = BuildMetadataColumn(1); + + BinaryBuilder value_builder; + value_builder.AppendNull().ok(); + std::shared_ptr value_col; + value_builder.Finish(&value_col).ok(); + + // Build a struct with only 1 field, but schema expects 2 + BinaryBuilder inner_value_builder; + inner_value_builder.AppendNull().ok(); + std::shared_ptr inner_value; + inner_value_builder.Finish(&inner_value).ok(); + + Int64Builder inner_typed_builder; + inner_typed_builder.AppendNull().ok(); + std::shared_ptr inner_typed; + inner_typed_builder.Finish(&inner_typed).ok(); + + auto inner_fields = std::vector>{ + field("value", binary(), true), + field("typed_value", int64(), true), + }; + ASSERT_OK_AND_ASSIGN(auto single_field_struct, + StructArray::Make({inner_value, inner_typed}, inner_fields)); + + // Wrap into outer struct with only 1 field ("name") + auto outer_fields = std::vector>{ + field("name", single_field_struct->type(), false), + }; + ASSERT_OK_AND_ASSIGN(auto typed_col, + StructArray::Make({single_field_struct}, outer_fields)); + + // Schema expects 2 fields: name + age + auto schema = VariantShreddingSchema::Object({ + {"name", VariantShreddingSchema::Primitive(utf8())}, + {"age", VariantShreddingSchema::Primitive(int64())}, + }); + auto result = ReconstructVariantColumn(metadata, value_col, typed_col, schema); + ASSERT_FALSE(result.ok()); + ASSERT_TRUE(result.status().IsInvalid()); +} + +// =========================================================================== +// StringView / BinaryView shredding tests +// =========================================================================== + +TEST_F(VariantShredRoundTripTest, StringViewShredRoundTrip) { + // String variant values shredded into StringView typed column + auto values = BuildVariantColumn({ + [](VariantBuilder& b) { return b.String("hello"); }, + [](VariantBuilder& b) { return b.Int(42); }, // doesn't match + [](VariantBuilder& b) { return b.String("world"); }, + }); + auto metadata = BuildMetadataColumn(3); + auto schema = VariantShreddingSchema::Primitive(utf8_view()); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // Rows 0,2 match (strings); Row 1 doesn't (int) + ASSERT_TRUE(value_col->IsNull(0)); + ASSERT_TRUE(value_col->IsValid(1)); + ASSERT_TRUE(value_col->IsNull(2)); + + ASSERT_TRUE(typed_col->IsValid(0)); + ASSERT_TRUE(typed_col->IsNull(1)); + ASSERT_TRUE(typed_col->IsValid(2)); + + // Verify typed column is StringViewArray + ASSERT_EQ(typed_col->type_id(), Type::STRING_VIEW); + auto& sv_arr = static_cast(*typed_col); + ASSERT_EQ(sv_arr.GetView(0), "hello"); + ASSERT_EQ(sv_arr.GetView(2), "world"); + + // Round-trip via reconstruction + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 3); + // Verify reconstructed strings have correct short-string header byte + auto recon_bytes0 = GetBinaryView(*reconstructed, 0); + // Short string "hello" (5 chars): header = (5 << 2) | 0x01 = 0x15 + ASSERT_EQ(recon_bytes0.size(), 6); + ASSERT_EQ(static_cast(recon_bytes0[0]), 0x15); +} + +TEST_F(VariantShredRoundTripTest, BinaryViewShredRoundTrip) { + // Binary variant values shredded into BinaryView typed column + auto values = BuildVariantColumn({ + [](VariantBuilder& b) { return b.Binary(std::string_view("\x01\x02\x03", 3)); }, + [](VariantBuilder& b) { return b.String("text"); }, // doesn't match binary + [](VariantBuilder& b) { return b.Binary(std::string_view("\xAA\xBB", 2)); }, + }); + auto metadata = BuildMetadataColumn(3); + auto schema = VariantShreddingSchema::Primitive(binary_view()); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // Rows 0,2 match (binary); Row 1 doesn't (string) + ASSERT_TRUE(value_col->IsNull(0)); + ASSERT_TRUE(value_col->IsValid(1)); + ASSERT_TRUE(value_col->IsNull(2)); + + ASSERT_TRUE(typed_col->IsValid(0)); + ASSERT_TRUE(typed_col->IsNull(1)); + ASSERT_TRUE(typed_col->IsValid(2)); + + // Verify typed column is BinaryViewArray + ASSERT_EQ(typed_col->type_id(), Type::BINARY_VIEW); + auto& bv_arr = static_cast(*typed_col); + auto v0 = bv_arr.GetView(0); + ASSERT_EQ(v0.size(), 3); + ASSERT_EQ(v0[0], '\x01'); + ASSERT_EQ(v0[1], '\x02'); + ASSERT_EQ(v0[2], '\x03'); + + // Round-trip via reconstruction + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 3); +} + +TEST_F(VariantShredRoundTripTest, ShortStringToStringView) { + // Short strings (≤63 bytes) should be compatible with StringView target + auto values = BuildVariantColumn({ + [](VariantBuilder& b) { return b.String("hi"); }, // short string + }); + auto metadata = BuildMetadataColumn(1); + auto schema = VariantShreddingSchema::Primitive(utf8_view()); + + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + + auto typed_col = shredded->field(2); + ASSERT_EQ(typed_col->type_id(), Type::STRING_VIEW); + ASSERT_TRUE(typed_col->IsValid(0)); + auto& sv_arr = static_cast(*typed_col); + ASSERT_EQ(sv_arr.GetView(0), "hi"); +} + +// =========================================================================== +// LargeList reconstruction tests +// =========================================================================== + +TEST_F(VariantShredRoundTripTest, LargeListReconstructRoundTrip) { + // Shred an array, then construct a LargeListArray with equivalent data + // and verify reconstruction works with 64-bit offsets. + auto values = BuildVariantColumn({ + [](VariantBuilder& b) { + auto s = b.Offset(); + std::vector offsets; + offsets.push_back(b.NextElement(s)); + b.Int(10).ok(); + offsets.push_back(b.NextElement(s)); + b.Int(20).ok(); + return b.FinishArray(s, offsets); + }, + }); + auto metadata = BuildMetadataColumn(1); + auto schema = VariantShreddingSchema::Array(VariantShreddingSchema::Primitive(int64())); + + // First shred normally (produces ListArray with struct elements) + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + auto value_col = shredded->field(1); // residual (null for arrays) + auto typed_col = shredded->field(2); // ListArray + + ASSERT_EQ(typed_col->type_id(), Type::LIST); + + // Reconstruct from the normal ListArray (sanity check) + ASSERT_OK_AND_ASSIGN(auto recon1, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(recon1->length(), 1); + + // Now build a LargeListArray with the same element struct data. + // The elements are struct{value: binary, typed_value: int64}. + const auto* list_arr = static_cast(typed_col.get()); + auto elem_struct = list_arr->values(); // StructArray + + // Construct a LargeList with same struct elements using 64-bit offsets. + auto large_offsets = Buffer::FromVector(std::vector{0, 2}); + auto large_list_arr = std::make_shared( + large_list(field("element", elem_struct->type(), false)), 1, large_offsets, + elem_struct); + ASSERT_EQ(large_list_arr->type_id(), Type::LARGE_LIST); + + // Reconstruct from LargeListArray + ASSERT_OK_AND_ASSIGN( + auto recon2, ReconstructVariantColumn(metadata, value_col, large_list_arr, schema)); + ASSERT_EQ(recon2->length(), 1); + + // Verify both reconstructions produce identical output + auto bytes1 = GetBinaryView(*recon1, 0); + auto bytes2 = GetBinaryView(*recon2, 0); + ASSERT_EQ(bytes1, bytes2); +} + +TEST_F(VariantShredRoundTripTest, ReconstructArrayTypedValueLargeListAccepted) { + // Verify that LargeList is accepted alongside List for array reconstruction. + // This tests the validation path. + auto metadata = BuildMetadataColumn(0); + BinaryBuilder vb; + std::shared_ptr value_col; + vb.Finish(&value_col).ok(); + + // Empty LargeList of binary + auto large_list_builder = std::make_shared( + default_memory_pool(), std::make_shared()); + std::shared_ptr typed_col; + large_list_builder->Finish(&typed_col).ok(); + + auto schema = VariantShreddingSchema::Array(VariantShreddingSchema::Primitive(int64())); + + // Should succeed (empty arrays, no actual reconstruction needed) + auto result = ReconstructVariantColumn(metadata, value_col, typed_col, schema); + ASSERT_TRUE(result.ok()); +} + +TEST_F(VariantShredRoundTripTest, ReconstructArrayFixedSizeListAccepted) { + // Verify that FixedSizeList is accepted for array reconstruction. + // Build a FixedSizeList(2) of binary variant bytes and reconstruct. + auto metadata = BuildMetadataColumn(1); + + // value_col is null (array went to typed_value) + BinaryBuilder value_builder; + value_builder.AppendNull().ok(); + std::shared_ptr value_col; + value_builder.Finish(&value_col).ok(); + + // Build element binary values: Int(10), Int(20) + BinaryBuilder elem_builder; + { + VariantBuilder vb; + vb.Int(10).ok(); + auto encoded = vb.Finish().ValueOrDie(); + elem_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + { + VariantBuilder vb; + vb.Int(20).ok(); + auto encoded = vb.Finish().ValueOrDie(); + elem_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr elem_arr; + elem_builder.Finish(&elem_arr).ok(); + + // Build FixedSizeList(2) containing the 2 elements + auto typed_col = std::make_shared( + fixed_size_list(field("item", binary()), 2), 1, elem_arr); + + auto schema = VariantShreddingSchema::Array(VariantShreddingSchema::Primitive(int64())); + + // Should succeed + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 1); + ASSERT_TRUE(reconstructed->IsValid(0)); + + // Verify reconstructed is a variant array with 2 elements + auto recon_bytes = GetBinaryView(*reconstructed, 0); + auto* data = reinterpret_cast(recon_bytes.data()); + auto len = static_cast(recon_bytes.size()); + ASSERT_GE(len, 1); + ASSERT_EQ(GetBasicType(data[0]), BasicType::kArray); + ASSERT_OK_AND_ASSIGN(auto elem_count, GetArrayElementCount(data, len)); + ASSERT_EQ(elem_count, 2); +} + +TEST_F(VariantShredRoundTripTest, ReconstructArrayListViewAccepted) { + // Verify that ListView is accepted for array reconstruction. + auto metadata = BuildMetadataColumn(1); + + // value_col is null + BinaryBuilder value_builder; + value_builder.AppendNull().ok(); + std::shared_ptr value_col; + value_builder.Finish(&value_col).ok(); + + // Build element binary values: Int(5), Int(6), Int(7) + BinaryBuilder elem_builder; + for (int val : {5, 6, 7}) { + VariantBuilder vb; + vb.Int(val).ok(); + auto encoded = vb.Finish().ValueOrDie(); + elem_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + } + std::shared_ptr elem_arr; + elem_builder.Finish(&elem_arr).ok(); + + // Build a ListView array: 1 row pointing to elements [0, 3) (all 3 elements) + // ListView needs offsets buffer + sizes buffer + auto offsets_buf = Buffer::FromVector({0}); + auto sizes_buf = Buffer::FromVector({3}); + auto list_view_type = list_view(field("item", binary())); + auto typed_col = std::make_shared(list_view_type, 1, offsets_buf, + sizes_buf, elem_arr); + + auto schema = VariantShreddingSchema::Array(VariantShreddingSchema::Primitive(int64())); + + // Should succeed + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 1); + ASSERT_TRUE(reconstructed->IsValid(0)); + + // Verify reconstructed is a variant array with 3 elements + auto recon_bytes = GetBinaryView(*reconstructed, 0); + auto* data = reinterpret_cast(recon_bytes.data()); + auto len = static_cast(recon_bytes.size()); + ASSERT_GE(len, 1); + ASSERT_EQ(GetBasicType(data[0]), BasicType::kArray); + ASSERT_OK_AND_ASSIGN(auto elem_count, GetArrayElementCount(data, len)); + ASSERT_EQ(elem_count, 3); +} + +TEST_F(VariantShredRoundTripTest, StringViewMetadataArrayInput) { + // Verify that STRING_VIEW metadata arrays are accepted by ShredVariantColumn. + // Arrow's BinaryViewArray/StringViewArray are valid metadata containers. + VariantBuilder vb; + vb.Int(42).ok(); + auto encoded = vb.Finish().ValueOrDie(); + + // Build a BinaryView metadata array (STRING_VIEW has the same binary layout + // and is accepted by GetBinaryValue) + BinaryViewBuilder meta_builder; + meta_builder + .Append(encoded.metadata.data(), static_cast(encoded.metadata.size())) + .ok(); + std::shared_ptr metadata; + meta_builder.Finish(&metadata).ok(); + ASSERT_EQ(metadata->type_id(), Type::BINARY_VIEW); + + // Build value array as regular binary + BinaryBuilder value_builder; + value_builder.Append(encoded.value.data(), static_cast(encoded.value.size())) + .ok(); + std::shared_ptr values; + value_builder.Finish(&values).ok(); + + auto schema = VariantShreddingSchema::Primitive(int64()); + + // Shredding should work with BinaryView metadata + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + auto typed_col = shredded->field(2); + ASSERT_EQ(typed_col->null_count(), 0); + auto& int_arr = static_cast(*typed_col); + ASSERT_EQ(int_arr.Value(0), 42); +} + +TEST_F(VariantShredRoundTripTest, BinaryViewMetadataReconstructionRoundTrip) { + // Verify that BINARY_VIEW metadata arrays work in reconstruction path. + VariantBuilder vb; + vb.Int(99).ok(); + auto encoded = vb.Finish().ValueOrDie(); + + // Build metadata as BinaryView + BinaryViewBuilder meta_builder; + meta_builder + .Append(encoded.metadata.data(), static_cast(encoded.metadata.size())) + .ok(); + std::shared_ptr metadata; + meta_builder.Finish(&metadata).ok(); + ASSERT_EQ(metadata->type_id(), Type::BINARY_VIEW); + + // value_col is null (value goes to typed) + BinaryBuilder value_builder; + value_builder.AppendNull().ok(); + std::shared_ptr value_col; + value_builder.Finish(&value_col).ok(); + + // typed_value has the int64 + Int64Builder typed_builder; + typed_builder.Append(99).ok(); + std::shared_ptr typed_col; + typed_builder.Finish(&typed_col).ok(); + + auto schema = VariantShreddingSchema::Primitive(int64()); + + // Reconstruction should succeed with BinaryView metadata + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 1); + ASSERT_TRUE(reconstructed->IsValid(0)); +} + +TEST_F(VariantShredRoundTripTest, ObjectShredDifferentMetadataDictionaries) { + // Test object shredding with rows that have different metadata dictionaries. + // This exercises the cached_meta_bytes comparison in ReconstructVariantColumnObject. + auto schema = VariantShreddingSchema::Object({ + {"x", VariantShreddingSchema::Primitive(int64())}, + }); + + // Row 0: object {"x": 10} with metadata containing "x" + VariantBuilder vb1; + auto start1 = vb1.Offset(); + std::vector fields1; + fields1.push_back(vb1.NextField(start1, "x")); + vb1.Int(10).ok(); + vb1.FinishObject(start1, fields1).ok(); + auto encoded1 = vb1.Finish().ValueOrDie(); + + // Row 1: object {"x": 20, "y": 30} with metadata containing "x" AND "y" + VariantBuilder vb2; + auto start2 = vb2.Offset(); + std::vector fields2; + fields2.push_back(vb2.NextField(start2, "x")); + vb2.Int(20).ok(); + fields2.push_back(vb2.NextField(start2, "y")); + vb2.Int(30).ok(); + vb2.FinishObject(start2, fields2).ok(); + auto encoded2 = vb2.Finish().ValueOrDie(); + + // Build metadata array (different metadata per row) + BinaryBuilder meta_builder; + meta_builder + .Append(encoded1.metadata.data(), static_cast(encoded1.metadata.size())) + .ok(); + meta_builder + .Append(encoded2.metadata.data(), static_cast(encoded2.metadata.size())) + .ok(); + std::shared_ptr metadata; + meta_builder.Finish(&metadata).ok(); + + // Build value array + BinaryBuilder value_builder; + value_builder.Append(encoded1.value.data(), static_cast(encoded1.value.size())) + .ok(); + value_builder.Append(encoded2.value.data(), static_cast(encoded2.value.size())) + .ok(); + std::shared_ptr values; + value_builder.Finish(&values).ok(); + + // Shred + ASSERT_OK_AND_ASSIGN(auto shredded, ShredVariantColumn(metadata, values, schema)); + auto value_col = shredded->field(1); + auto typed_col = shredded->field(2); + + // Row 0: "x" is shredded, no residual (single field) + // Row 1: "x" is shredded, "y" goes to residual + ASSERT_TRUE(value_col->IsNull(0)); // no residual for row 0 + ASSERT_TRUE(value_col->IsValid(1)); // residual with "y" for row 1 + + // Reconstruct — must handle different metadata dictionaries correctly + ASSERT_OK_AND_ASSIGN(auto reconstructed, + ReconstructVariantColumn(metadata, value_col, typed_col, schema)); + ASSERT_EQ(reconstructed->length(), 2); + ASSERT_TRUE(reconstructed->IsValid(0)); + ASSERT_TRUE(reconstructed->IsValid(1)); + + // Verify row 1 reconstructed has both fields by checking it's a valid object + auto recon_bytes = GetBinaryView(*reconstructed, 1); + auto* data = reinterpret_cast(recon_bytes.data()); + auto len = static_cast(recon_bytes.size()); + ASSERT_GE(len, 1); + ASSERT_EQ(GetBasicType(data[0]), BasicType::kObject); + ASSERT_OK_AND_ASSIGN(auto field_count, GetObjectFieldCount(data, len)); + ASSERT_EQ(field_count, 2); // Both "x" and "y" reconstructed +} + +} // namespace arrow::extension::variant_internal diff --git a/cpp/src/arrow/meson.build b/cpp/src/arrow/meson.build index d8c81b868fe4..2760896848ed 100644 --- a/cpp/src/arrow/meson.build +++ b/cpp/src/arrow/meson.build @@ -144,6 +144,7 @@ arrow_components = { 'extension/parquet_variant.cc', 'extension/variant_builder.cc', 'extension/variant_internal.cc', + 'extension/variant_shredding.cc', 'extension/uuid.cc', 'pretty_print.cc', 'record_batch.cc',