diff --git a/native/core/Cargo.toml b/native/core/Cargo.toml
index 07d4c6cc8f..ba5509583c 100644
--- a/native/core/Cargo.toml
+++ b/native/core/Cargo.toml
@@ -132,3 +132,7 @@ harness = false
 [[bench]]
 name = "parquet_decode"
 harness = false
+
+[[bench]]
+name = "array_element_append"
+harness = false
diff --git a/native/core/benches/array_element_append.rs b/native/core/benches/array_element_append.rs
new file mode 100644
index 0000000000..75fc2bbf76
--- /dev/null
+++ b/native/core/benches/array_element_append.rs
@@ -0,0 +1,272 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Micro-benchmarks for SparkUnsafeArray element iteration.
+//!
+//! This tests the low-level `append_to_builder` function which converts
+//! SparkUnsafeArray elements to Arrow array builders. This is the inner loop
+//! used when processing List/Array columns in JVM shuffle.
+
+use arrow::array::builder::{
+    Date32Builder, Float64Builder, Int32Builder, Int64Builder, TimestampMicrosecondBuilder,
+};
+use arrow::datatypes::{DataType, TimeUnit};
+use comet::execution::shuffle::list::{append_to_builder, SparkUnsafeArray};
+use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
+
+const NUM_ELEMENTS: usize = 10000;
+
+/// Create a SparkUnsafeArray in memory with i32 elements.
+/// Layout:
+/// - 8 bytes: num_elements (i64)
+/// - null bitset: 8 bytes per 64 elements
+/// - element data: 4 bytes per element (i32)
+fn create_spark_unsafe_array_i32(num_elements: usize, with_nulls: bool) -> Vec<u8> {
+    // Header size: 8 (num_elements) + ceil(num_elements/64) * 8 (null bitset)
+    let null_bitset_words = num_elements.div_ceil(64);
+    let header_size = 8 + null_bitset_words * 8;
+    let data_size = num_elements * 4; // i32 = 4 bytes
+    let total_size = header_size + data_size;
+
+    let mut buffer = vec![0u8; total_size];
+
+    // Write num_elements
+    buffer[0..8].copy_from_slice(&(num_elements as i64).to_le_bytes());
+
+    // Write null bitset (set every 10th element as null if with_nulls)
+    if with_nulls {
+        for i in (0..num_elements).step_by(10) {
+            let word_idx = i / 64;
+            let bit_idx = i % 64;
+            let word_offset = 8 + word_idx * 8;
+            let current_word =
+                i64::from_le_bytes(buffer[word_offset..word_offset + 8].try_into().unwrap());
+            let new_word = current_word | (1i64 << bit_idx);
+            buffer[word_offset..word_offset + 8].copy_from_slice(&new_word.to_le_bytes());
+        }
+    }
+
+    // Write element data
+    for i in 0..num_elements {
+        let offset = header_size + i * 4;
+        buffer[offset..offset + 4].copy_from_slice(&(i as i32).to_le_bytes());
+    }
+
+    buffer
+}
+
+/// Create a SparkUnsafeArray in memory with i64 elements.
+fn create_spark_unsafe_array_i64(num_elements: usize, with_nulls: bool) -> Vec<u8> {
+    let null_bitset_words = num_elements.div_ceil(64);
+    let header_size = 8 + null_bitset_words * 8;
+    let data_size = num_elements * 8; // i64 = 8 bytes
+    let total_size = header_size + data_size;
+
+    let mut buffer = vec![0u8; total_size];
+
+    // Write num_elements
+    buffer[0..8].copy_from_slice(&(num_elements as i64).to_le_bytes());
+
+    // Write null bitset
+    if with_nulls {
+        for i in (0..num_elements).step_by(10) {
+            let word_idx = i / 64;
+            let bit_idx = i % 64;
+            let word_offset = 8 + word_idx * 8;
+            let current_word =
+                i64::from_le_bytes(buffer[word_offset..word_offset + 8].try_into().unwrap());
+            let new_word = current_word | (1i64 << bit_idx);
+            buffer[word_offset..word_offset + 8].copy_from_slice(&new_word.to_le_bytes());
+        }
+    }
+
+    // Write element data
+    for i in 0..num_elements {
+        let offset = header_size + i * 8;
+        buffer[offset..offset + 8].copy_from_slice(&(i as i64).to_le_bytes());
+    }
+
+    buffer
+}
+
+/// Create a SparkUnsafeArray in memory with f64 elements.
+fn create_spark_unsafe_array_f64(num_elements: usize, with_nulls: bool) -> Vec<u8> {
+    let null_bitset_words = num_elements.div_ceil(64);
+    let header_size = 8 + null_bitset_words * 8;
+    let data_size = num_elements * 8; // f64 = 8 bytes
+    let total_size = header_size + data_size;
+
+    let mut buffer = vec![0u8; total_size];
+
+    // Write num_elements
+    buffer[0..8].copy_from_slice(&(num_elements as i64).to_le_bytes());
+
+    // Write null bitset
+    if with_nulls {
+        for i in (0..num_elements).step_by(10) {
+            let word_idx = i / 64;
+            let bit_idx = i % 64;
+            let word_offset = 8 + word_idx * 8;
+            let current_word =
+                i64::from_le_bytes(buffer[word_offset..word_offset + 8].try_into().unwrap());
+            let new_word = current_word | (1i64 << bit_idx);
+            buffer[word_offset..word_offset + 8].copy_from_slice(&new_word.to_le_bytes());
+        }
+    }
+
+    // Write element data
+    for i in 0..num_elements {
+        let offset = header_size + i * 8;
+        buffer[offset..offset + 8].copy_from_slice(&(i as f64).to_le_bytes());
+    }
+
+    buffer
+}
+
+fn benchmark_array_conversion(c: &mut Criterion) {
+    let mut group = c.benchmark_group("spark_unsafe_array_to_arrow");
+
+    // Benchmark i32 array conversion
+    for with_nulls in [false, true] {
+        let buffer = create_spark_unsafe_array_i32(NUM_ELEMENTS, with_nulls);
+        let array = SparkUnsafeArray::new(buffer.as_ptr() as i64);
+        let null_str = if with_nulls { "with_nulls" } else { "no_nulls" };
+
+        group.bench_with_input(
+            BenchmarkId::new("i32", null_str),
+            &(&array, &buffer),
+            |b, (array, _buffer)| {
+                b.iter(|| {
+                    let mut builder = Int32Builder::with_capacity(NUM_ELEMENTS);
+                    if with_nulls {
+                        append_to_builder::<true>(&DataType::Int32, &mut builder, array).unwrap();
+                    } else {
+                        append_to_builder::<false>(&DataType::Int32, &mut builder, array).unwrap();
+                    }
+                    builder.finish()
+                });
+            },
+        );
+    }
+
+    // Benchmark i64 array conversion
+    for with_nulls in [false, true] {
+        let buffer = create_spark_unsafe_array_i64(NUM_ELEMENTS, with_nulls);
+        let array = SparkUnsafeArray::new(buffer.as_ptr() as i64);
+        let null_str = if with_nulls { "with_nulls" } else { "no_nulls" };
+
+        group.bench_with_input(
+            BenchmarkId::new("i64", null_str),
+            &(&array, &buffer),
+            |b, (array, _buffer)| {
+                b.iter(|| {
+                    let mut builder = Int64Builder::with_capacity(NUM_ELEMENTS);
+                    if with_nulls {
+                        append_to_builder::<true>(&DataType::Int64, &mut builder, array).unwrap();
+                    } else {
+                        append_to_builder::<false>(&DataType::Int64, &mut builder, array).unwrap();
+                    }
+                    builder.finish()
+                });
+            },
+        );
+    }
+
+    // Benchmark f64 array conversion
+    for with_nulls in [false, true] {
+        let buffer = create_spark_unsafe_array_f64(NUM_ELEMENTS, with_nulls);
+        let array = SparkUnsafeArray::new(buffer.as_ptr() as i64);
+        let null_str = if with_nulls { "with_nulls" } else { "no_nulls" };
+
+        group.bench_with_input(
+            BenchmarkId::new("f64", null_str),
+            &(&array, &buffer),
+            |b, (array, _buffer)| {
+                b.iter(|| {
+                    let mut builder = Float64Builder::with_capacity(NUM_ELEMENTS);
+                    if with_nulls {
+                        append_to_builder::<true>(&DataType::Float64, &mut builder, array).unwrap();
+                    } else {
+                        append_to_builder::<false>(&DataType::Float64, &mut builder, array)
+                            .unwrap();
+                    }
+                    builder.finish()
+                });
+            },
+        );
+    }
+
+    // Benchmark date32 array conversion (same memory layout as i32)
+    for with_nulls in [false, true] {
+        let buffer = create_spark_unsafe_array_i32(NUM_ELEMENTS, with_nulls);
+        let array = SparkUnsafeArray::new(buffer.as_ptr() as i64);
+        let null_str = if with_nulls { "with_nulls" } else { "no_nulls" };
+
+        group.bench_with_input(
+            BenchmarkId::new("date32", null_str),
+            &(&array, &buffer),
+            |b, (array, _buffer)| {
+                b.iter(|| {
+                    let mut builder = Date32Builder::with_capacity(NUM_ELEMENTS);
+                    if with_nulls {
+                        append_to_builder::<true>(&DataType::Date32, &mut builder, array).unwrap();
+                    } else {
+                        append_to_builder::<false>(&DataType::Date32, &mut builder, array).unwrap();
+                    }
+                    builder.finish()
+                });
+            },
+        );
+    }
+
+    // Benchmark timestamp array conversion (same memory layout as i64)
+    for with_nulls in [false, true] {
+        let buffer = create_spark_unsafe_array_i64(NUM_ELEMENTS, with_nulls);
+        let array = SparkUnsafeArray::new(buffer.as_ptr() as i64);
+        let null_str = if with_nulls { "with_nulls" } else { "no_nulls" };
+
+        group.bench_with_input(
+            BenchmarkId::new("timestamp", null_str),
+            &(&array, &buffer),
+            |b, (array, _buffer)| {
+                b.iter(|| {
+                    let mut builder = TimestampMicrosecondBuilder::with_capacity(NUM_ELEMENTS);
+                    let dt = DataType::Timestamp(TimeUnit::Microsecond, None);
+                    if with_nulls {
+                        append_to_builder::<true>(&dt, &mut builder, array).unwrap();
+                    } else {
+                        append_to_builder::<false>(&dt, &mut builder, array).unwrap();
+                    }
+                    builder.finish()
+                });
+            },
+        );
+    }
+
+    group.finish();
+}
+
+fn config() -> Criterion {
+    Criterion::default()
+}
+
+criterion_group! {
+    name = benches;
+    config = config();
+    targets = benchmark_array_conversion
+}
+criterion_main!(benches);
diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs
index e9f2d6523d..aacf36ec72 100644
--- a/native/core/src/execution/jni_api.rs
+++ b/native/core/src/execution/jni_api.rs
@@ -491,7 +491,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
         let physical_plan_time = start.elapsed();
 
         exec_context.plan_creation_time += physical_plan_time;
-        exec_context.root_op = Some(Arc::clone(&root_op));
         exec_context.scans = scans;
 
         if exec_context.explain_native {
@@ -505,6 +504,7 @@
         // so we should always execute partition 0.
let stream = root_op.native_plan.execute(0, task_ctx)?; exec_context.stream = Some(stream); + exec_context.root_op = Some(root_op); } else { // Pull input batches pull_input_batches(exec_context)?; diff --git a/native/core/src/execution/shuffle/list.rs b/native/core/src/execution/shuffle/list.rs index c31244b87d..cb21cc3497 100644 --- a/native/core/src/execution/shuffle/list.rs +++ b/native/core/src/execution/shuffle/list.rs @@ -32,6 +32,52 @@ use arrow::array::{ }; use arrow::datatypes::{DataType, TimeUnit}; +/// Generates bulk append methods for primitive types in SparkUnsafeArray. +/// +/// # Safety invariants for all generated methods: +/// - `element_offset` points to contiguous element data of length `num_elements` +/// - `null_bitset_ptr()` returns a pointer to `ceil(num_elements/64)` i64 words +/// - These invariants are guaranteed by the SparkUnsafeArray layout from the JVM +macro_rules! impl_append_to_builder { + ($method_name:ident, $builder_type:ty, $element_type:ty) => { + pub(crate) fn $method_name(&self, builder: &mut $builder_type) { + let num_elements = self.num_elements; + if num_elements == 0 { + return; + } + + if NULLABLE { + let mut ptr = self.element_offset as *const $element_type; + let null_words = self.null_bitset_ptr(); + for idx in 0..num_elements { + let word_idx = idx >> 6; + let bit_idx = idx & 0x3f; + // SAFETY: word_idx < ceil(num_elements/64) since idx < num_elements + let is_null = unsafe { (*null_words.add(word_idx) & (1i64 << bit_idx)) != 0 }; + + if is_null { + builder.append_null(); + } else { + // SAFETY: ptr is within element data bounds + builder.append_value(unsafe { *ptr }); + } + // SAFETY: ptr stays within bounds, iterating num_elements times + ptr = unsafe { ptr.add(1) }; + } + } else { + // SAFETY: element_offset points to contiguous data of length num_elements + let slice = unsafe { + std::slice::from_raw_parts( + self.element_offset as *const $element_type, + num_elements, + ) + }; + builder.append_slice(slice); + } 
+ } + }; +} + pub struct SparkUnsafeArray { row_addr: i64, num_elements: usize, @@ -39,10 +85,12 @@ pub struct SparkUnsafeArray { } impl SparkUnsafeObject for SparkUnsafeArray { + #[inline] fn get_row_addr(&self) -> i64 { self.row_addr } + #[inline] fn get_element_offset(&self, index: usize, element_size: usize) -> *const u8 { (self.element_offset + (index * element_size) as i64) as *const u8 } @@ -51,7 +99,8 @@ impl SparkUnsafeObject for SparkUnsafeArray { impl SparkUnsafeArray { /// Creates a `SparkUnsafeArray` which points to the given address and size in bytes. pub fn new(addr: i64) -> Self { - // Read the number of elements from the first 8 bytes. + // SAFETY: addr points to valid Spark UnsafeArray data from the JVM. + // The first 8 bytes contain the element count as a little-endian i64. let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr as *const u8, 8) }; let num_elements = i64::from_le_bytes(slice.try_into().unwrap()); @@ -83,6 +132,9 @@ impl SparkUnsafeArray { /// Returns true if the null bit at the given index of the array is set. #[inline] pub(crate) fn is_null_at(&self, index: usize) -> bool { + // SAFETY: row_addr points to valid Spark UnsafeArray data. The null bitset starts + // at offset 8 and contains ceil(num_elements/64) * 8 bytes. The caller ensures + // index < num_elements, so word_offset is within the bitset region. unsafe { let mask: i64 = 1i64 << (index & 0x3f); let word_offset = (self.row_addr + 8 + (((index >> 6) as i64) << 3)) as *const i64; @@ -90,6 +142,132 @@ impl SparkUnsafeArray { (word & mask) != 0 } } + + /// Returns the null bitset pointer (starts at row_addr + 8). 
+    #[inline]
+    fn null_bitset_ptr(&self) -> *const i64 {
+        (self.row_addr + 8) as *const i64
+    }
+
+    impl_append_to_builder!(append_ints_to_builder, Int32Builder, i32);
+    impl_append_to_builder!(append_longs_to_builder, Int64Builder, i64);
+    impl_append_to_builder!(append_shorts_to_builder, Int16Builder, i16);
+    impl_append_to_builder!(append_bytes_to_builder, Int8Builder, i8);
+    impl_append_to_builder!(append_floats_to_builder, Float32Builder, f32);
+    impl_append_to_builder!(append_doubles_to_builder, Float64Builder, f64);
+
+    /// Bulk append boolean values to builder.
+    /// Booleans are stored as 1 byte each in SparkUnsafeArray, requiring special handling.
+    pub(crate) fn append_booleans_to_builder<const NULLABLE: bool>(
+        &self,
+        builder: &mut BooleanBuilder,
+    ) {
+        let num_elements = self.num_elements;
+        if num_elements == 0 {
+            return;
+        }
+
+        let mut ptr = self.element_offset as *const u8;
+
+        if NULLABLE {
+            let null_words = self.null_bitset_ptr();
+            for idx in 0..num_elements {
+                let word_idx = idx >> 6;
+                let bit_idx = idx & 0x3f;
+                // SAFETY: word_idx < ceil(num_elements/64) since idx < num_elements
+                let is_null = unsafe { (*null_words.add(word_idx) & (1i64 << bit_idx)) != 0 };
+
+                if is_null {
+                    builder.append_null();
+                } else {
+                    // SAFETY: ptr is within element data bounds
+                    builder.append_value(unsafe { *ptr != 0 });
+                }
+                // SAFETY: ptr stays within bounds, iterating num_elements times
+                ptr = unsafe { ptr.add(1) };
+            }
+        } else {
+            for _ in 0..num_elements {
+                // SAFETY: ptr is within element data bounds
+                builder.append_value(unsafe { *ptr != 0 });
+                ptr = unsafe { ptr.add(1) };
+            }
+        }
+    }
+
+    /// Bulk append timestamp values to builder (stored as i64 microseconds).
+    pub(crate) fn append_timestamps_to_builder<const NULLABLE: bool>(
+        &self,
+        builder: &mut TimestampMicrosecondBuilder,
+    ) {
+        let num_elements = self.num_elements;
+        if num_elements == 0 {
+            return;
+        }
+
+        if NULLABLE {
+            let mut ptr = self.element_offset as *const i64;
+            let null_words = self.null_bitset_ptr();
+            for idx in 0..num_elements {
+                let word_idx = idx >> 6;
+                let bit_idx = idx & 0x3f;
+                // SAFETY: word_idx < ceil(num_elements/64) since idx < num_elements
+                let is_null = unsafe { (*null_words.add(word_idx) & (1i64 << bit_idx)) != 0 };
+
+                if is_null {
+                    builder.append_null();
+                } else {
+                    // SAFETY: ptr is within element data bounds
+                    builder.append_value(unsafe { *ptr });
+                }
+                // SAFETY: ptr stays within bounds, iterating num_elements times
+                ptr = unsafe { ptr.add(1) };
+            }
+        } else {
+            // SAFETY: element_offset points to contiguous i64 data of length num_elements
+            let slice = unsafe {
+                std::slice::from_raw_parts(self.element_offset as *const i64, num_elements)
+            };
+            builder.append_slice(slice);
+        }
+    }
+
+    /// Bulk append date values to builder (stored as i32 days since epoch).
+    pub(crate) fn append_dates_to_builder<const NULLABLE: bool>(
+        &self,
+        builder: &mut Date32Builder,
+    ) {
+        let num_elements = self.num_elements;
+        if num_elements == 0 {
+            return;
+        }
+
+        if NULLABLE {
+            let mut ptr = self.element_offset as *const i32;
+            let null_words = self.null_bitset_ptr();
+            for idx in 0..num_elements {
+                let word_idx = idx >> 6;
+                let bit_idx = idx & 0x3f;
+                // SAFETY: word_idx < ceil(num_elements/64) since idx < num_elements
+                let is_null = unsafe { (*null_words.add(word_idx) & (1i64 << bit_idx)) != 0 };
+
+                if is_null {
+                    builder.append_null();
+                } else {
+                    // SAFETY: ptr is within element data bounds
+                    builder.append_value(unsafe { *ptr });
+                }
+                // SAFETY: ptr stays within bounds, iterating num_elements times
+                ptr = unsafe { ptr.add(1) };
+            }
+        } else {
+            // SAFETY: element_offset points to contiguous i32 data of length num_elements
+            let slice = unsafe {
+                std::slice::from_raw_parts(self.element_offset as *const i32, num_elements)
+            };
+            builder.append_slice(slice);
+        }
+    }
 }
 
 pub fn append_to_builder<const NULLABLE: bool>(
@@ -112,77 +290,40 @@
     match data_type {
         DataType::Boolean => {
-            add_values!(
-                BooleanBuilder,
-                |builder: &mut BooleanBuilder, values: &SparkUnsafeArray, idx: usize| builder
-                    .append_value(values.get_boolean(idx)),
-                |builder: &mut BooleanBuilder| builder.append_null()
-            );
+            let builder = downcast_builder_ref!(BooleanBuilder, builder);
+            array.append_booleans_to_builder::<NULLABLE>(builder);
         }
         DataType::Int8 => {
-            add_values!(
-                Int8Builder,
-                |builder: &mut Int8Builder, values: &SparkUnsafeArray, idx: usize| builder
-                    .append_value(values.get_byte(idx)),
-                |builder: &mut Int8Builder| builder.append_null()
-            );
+            let builder = downcast_builder_ref!(Int8Builder, builder);
+            array.append_bytes_to_builder::<NULLABLE>(builder);
         }
         DataType::Int16 => {
-            add_values!(
-                Int16Builder,
-                |builder: &mut Int16Builder, values: &SparkUnsafeArray, idx: usize| builder
-                    .append_value(values.get_short(idx)),
-                |builder: &mut Int16Builder| builder.append_null()
-            );
+            let builder = downcast_builder_ref!(Int16Builder, builder);
+            array.append_shorts_to_builder::<NULLABLE>(builder);
        }
         DataType::Int32 => {
-            add_values!(
-                Int32Builder,
-                |builder: &mut Int32Builder, values: &SparkUnsafeArray, idx: usize| builder
-                    .append_value(values.get_int(idx)),
-                |builder: &mut Int32Builder| builder.append_null()
-            );
+            let builder = downcast_builder_ref!(Int32Builder, builder);
+            array.append_ints_to_builder::<NULLABLE>(builder);
         }
         DataType::Int64 => {
-            add_values!(
-                Int64Builder,
-                |builder: &mut Int64Builder, values: &SparkUnsafeArray, idx: usize| builder
-                    .append_value(values.get_long(idx)),
-                |builder: &mut Int64Builder| builder.append_null()
-            );
+            let builder = downcast_builder_ref!(Int64Builder, builder);
+            array.append_longs_to_builder::<NULLABLE>(builder);
         }
         DataType::Float32 => {
-            add_values!(
-                Float32Builder,
-                |builder: &mut Float32Builder, values: &SparkUnsafeArray, idx: usize| builder
-                    .append_value(values.get_float(idx)),
-                |builder: &mut Float32Builder| builder.append_null()
-            );
+            let builder = downcast_builder_ref!(Float32Builder, builder);
+            array.append_floats_to_builder::<NULLABLE>(builder);
         }
         DataType::Float64 => {
-            add_values!(
-                Float64Builder,
-                |builder: &mut Float64Builder, values: &SparkUnsafeArray, idx: usize| builder
-                    .append_value(values.get_double(idx)),
-                |builder: &mut Float64Builder| builder.append_null()
-            );
+            let builder = downcast_builder_ref!(Float64Builder, builder);
+            array.append_doubles_to_builder::<NULLABLE>(builder);
         }
         DataType::Timestamp(TimeUnit::Microsecond, _) => {
-            add_values!(
-                TimestampMicrosecondBuilder,
-                |builder: &mut TimestampMicrosecondBuilder,
-                 values: &SparkUnsafeArray,
-                 idx: usize| builder.append_value(values.get_timestamp(idx)),
-                |builder: &mut TimestampMicrosecondBuilder| builder.append_null()
-            );
+            let builder = downcast_builder_ref!(TimestampMicrosecondBuilder, builder);
+            array.append_timestamps_to_builder::<NULLABLE>(builder);
         }
         DataType::Date32 => {
-            add_values!(
-                Date32Builder,
-                |builder: &mut Date32Builder,
-                 values: &SparkUnsafeArray, idx: usize| builder
-                    .append_value(values.get_date(idx)),
-                |builder: &mut Date32Builder| builder.append_null()
-            );
+            let builder = downcast_builder_ref!(Date32Builder, builder);
+            array.append_dates_to_builder::<NULLABLE>(builder);
         }
         DataType::Binary => {
             add_values!(
diff --git a/native/core/src/execution/shuffle/mod.rs b/native/core/src/execution/shuffle/mod.rs
index e2798df63e..172dc5f942 100644
--- a/native/core/src/execution/shuffle/mod.rs
+++ b/native/core/src/execution/shuffle/mod.rs
@@ -17,7 +17,7 @@
 
 pub(crate) mod codec;
 mod comet_partitioning;
-mod list;
+pub mod list;
 mod map;
 pub mod row;
 mod shuffle_writer;
diff --git a/native/core/src/execution/shuffle/row.rs b/native/core/src/execution/shuffle/row.rs
index 821607ddb9..7a23254256 100644
--- a/native/core/src/execution/shuffle/row.rs
+++ b/native/core/src/execution/shuffle/row.rs
@@ -56,6 +56,19 @@ const NESTED_TYPE_BUILDER_CAPACITY: usize = 100;
 /// A common trait for Spark Unsafe classes that can be used to access the underlying data,
 /// e.g., `UnsafeRow` and `UnsafeArray`. This defines a set of methods that can be used to
 /// access the underlying data with index.
+///
+/// # Safety
+///
+/// Implementations must ensure that:
+/// - `get_row_addr()` returns a valid pointer to JVM-allocated memory
+/// - `get_element_offset()` returns a valid pointer within the row/array data region
+/// - The memory layout follows Spark's UnsafeRow/UnsafeArray format
+/// - The memory remains valid for the lifetime of the object (guaranteed by JVM ownership)
+///
+/// All accessor methods (get_boolean, get_int, etc.) use unsafe pointer operations but are
+/// safe to call as long as:
+/// - The index is within bounds (caller's responsibility)
+/// - The object was constructed from valid Spark UnsafeRow/UnsafeArray data
 pub trait SparkUnsafeObject {
     /// Returns the address of the row.
fn get_row_addr(&self) -> i64; @@ -73,12 +86,14 @@ pub trait SparkUnsafeObject { } /// Returns boolean value at the given index of the object. + #[inline] fn get_boolean(&self, index: usize) -> bool { let addr = self.get_element_offset(index, 1); unsafe { *addr != 0 } } /// Returns byte value at the given index of the object. + #[inline] fn get_byte(&self, index: usize) -> i8 { let addr = self.get_element_offset(index, 1); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 1) }; @@ -86,6 +101,7 @@ pub trait SparkUnsafeObject { } /// Returns short value at the given index of the object. + #[inline] fn get_short(&self, index: usize) -> i16 { let addr = self.get_element_offset(index, 2); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 2) }; @@ -93,6 +109,7 @@ pub trait SparkUnsafeObject { } /// Returns integer value at the given index of the object. + #[inline] fn get_int(&self, index: usize) -> i32 { let addr = self.get_element_offset(index, 4); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; @@ -100,6 +117,7 @@ pub trait SparkUnsafeObject { } /// Returns long value at the given index of the object. + #[inline] fn get_long(&self, index: usize) -> i64 { let addr = self.get_element_offset(index, 8); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; @@ -107,6 +125,7 @@ pub trait SparkUnsafeObject { } /// Returns float value at the given index of the object. + #[inline] fn get_float(&self, index: usize) -> f32 { let addr = self.get_element_offset(index, 4); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) }; @@ -114,6 +133,7 @@ pub trait SparkUnsafeObject { } /// Returns double value at the given index of the object. + #[inline] fn get_double(&self, index: usize) -> f64 { let addr = self.get_element_offset(index, 8); let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) }; @@ -137,6 +157,7 @@ pub trait SparkUnsafeObject { } /// Returns date value at the given index of the object. 
+    #[inline]
     fn get_date(&self, index: usize) -> i32 {
         let addr = self.get_element_offset(index, 4);
         let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 4) };
@@ -144,6 +165,7 @@
     }
 
     /// Returns timestamp value at the given index of the object.
+    #[inline]
     fn get_timestamp(&self, index: usize) -> i64 {
         let addr = self.get_element_offset(index, 8);
         let slice: &[u8] = unsafe { std::slice::from_raw_parts(addr, 8) };
@@ -255,6 +277,9 @@ impl SparkUnsafeRow {
     /// Returns true if the null bit at the given index of the row is set.
     #[inline]
     pub(crate) fn is_null_at(&self, index: usize) -> bool {
+        // SAFETY: row_addr points to valid Spark UnsafeRow data with at least
+        // ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
+        // word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
         unsafe {
             let mask: i64 = 1i64 << (index & 0x3f);
             let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *const i64;
@@ -265,6 +290,10 @@ impl SparkUnsafeRow {
 
     /// Unsets the null bit at the given index of the row, i.e., set the bit to 0 (not null).
     pub fn set_not_null_at(&mut self, index: usize) {
+        // SAFETY: row_addr points to valid Spark UnsafeRow data with at least
+        // ceil(num_fields/64) * 8 bytes of null bitset. The caller ensures index < num_fields.
+        // word_offset is within the bitset region since (index >> 6) << 3 < bitset size.
+        // Writing is safe because we have mutable access and the memory is owned by the JVM.
         unsafe {
             let mask: i64 = 1i64 << (index & 0x3f);
             let word_offset = (self.row_addr + (((index >> 6) as i64) << 3)) as *mut i64;
@@ -275,11 +304,32 @@ impl SparkUnsafeRow {
 }
 
 macro_rules! downcast_builder_ref {
-    ($builder_type:ty, $builder:expr) => {
+    ($builder_type:ty, $builder:expr) => {{
+        let actual_type_id = $builder.as_any().type_id();
         $builder
             .as_any_mut()
             .downcast_mut::<$builder_type>()
-            .expect(stringify!($builder_type))
+            .ok_or_else(|| {
+                CometError::Internal(format!(
+                    "Failed to downcast builder: expected {}, got {:?}",
+                    stringify!($builder_type),
+                    actual_type_id
+                ))
+            })?
+    }};
+}
+
+macro_rules! get_field_builder {
+    ($struct_builder:expr, $builder_type:ty, $idx:expr) => {
+        $struct_builder
+            .field_builder::<$builder_type>($idx)
+            .ok_or_else(|| {
+                CometError::Internal(format!(
+                    "Failed to get field builder at index {}: expected {}",
+                    $idx,
+                    stringify!($builder_type)
+                ))
+            })?
     };
 }
 
@@ -302,7 +352,7 @@ pub(crate) fn append_field(
 /// A macro for generating code of appending value into field builder of Arrow struct builder.
 macro_rules! append_field_to_builder {
     ($builder_type:ty, $accessor:expr) => {{
-        let field_builder = struct_builder.field_builder::<$builder_type>(idx).unwrap();
+        let field_builder = get_field_builder!(struct_builder, $builder_type, idx);
 
         if row.is_null_row() {
             // The row is null.
@@ -375,7 +425,7 @@ pub(crate) fn append_field(
         }
         DataType::Struct(fields) => {
             // Appending value into struct field builder of Arrow struct builder.
-            let field_builder = struct_builder.field_builder::<StructBuilder>(idx).unwrap();
+            let field_builder = get_field_builder!(struct_builder, StructBuilder, idx);
 
             let nested_row = if row.is_null_row() || row.is_null_at(idx) {
                 // The row is null, or the field in the row is null, i.e., a null nested row.
@@ -392,9 +442,11 @@ pub(crate) fn append_field(
             }
         }
         DataType::Map(field, _) => {
-            let field_builder = struct_builder
-                .field_builder::<MapBuilder<Box<dyn ArrayBuilder>, Box<dyn ArrayBuilder>>>(idx)
-                .unwrap();
+            let field_builder = get_field_builder!(
+                struct_builder,
+                MapBuilder<Box<dyn ArrayBuilder>, Box<dyn ArrayBuilder>>,
+                idx
+            );
 
             if row.is_null_row() {
                 // The row is null.
@@ -412,9 +464,8 @@ pub(crate) fn append_field( } } DataType::List(field) => { - let field_builder = struct_builder - .field_builder::<ListBuilder<Box<dyn ArrayBuilder>>>(idx) - .unwrap(); + let field_builder = + get_field_builder!(struct_builder, ListBuilder<Box<dyn ArrayBuilder>>, idx); if row.is_null_row() { // The row is null. @@ -439,7 +490,662 @@ pub(crate) fn append_field( Ok(()) } +/// Appends nested struct fields to the struct builder using field-major order. +/// This is a helper function for processing nested struct fields recursively. +/// +/// Unlike `append_struct_fields_field_major`, this function takes slices of row addresses, +/// sizes, and null flags directly, without needing to navigate from a parent row. +#[allow(clippy::redundant_closure_call)] +fn append_nested_struct_fields_field_major( + row_addresses: &[jlong], + row_sizes: &[jint], + struct_is_null: &[bool], + struct_builder: &mut StructBuilder, + fields: &arrow::datatypes::Fields, +) -> Result<(), CometError> { + let num_rows = row_addresses.len(); + let mut row = SparkUnsafeRow::new_with_num_fields(fields.len()); + + // Helper macro for processing primitive fields + macro_rules! 
process_field { + ($builder_type:ty, $field_idx:expr, $get_value:expr) => {{ + let field_builder = get_field_builder!(struct_builder, $builder_type, $field_idx); + + for row_idx in 0..num_rows { + if struct_is_null[row_idx] { + // Struct is null, field is also null + field_builder.append_null(); + } else { + let row_addr = row_addresses[row_idx]; + let row_size = row_sizes[row_idx]; + row.point_to(row_addr, row_size); + + if row.is_null_at($field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value($get_value(&row, $field_idx)); + } + } + } + }}; + } + + // Process each field across all rows + for (field_idx, field) in fields.iter().enumerate() { + match field.data_type() { + DataType::Boolean => { + process_field!(BooleanBuilder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_boolean(idx)); + } + DataType::Int8 => { + process_field!(Int8Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_byte(idx)); + } + DataType::Int16 => { + process_field!(Int16Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_short(idx)); + } + DataType::Int32 => { + process_field!(Int32Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_int(idx)); + } + DataType::Int64 => { + process_field!(Int64Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_long(idx)); + } + DataType::Float32 => { + process_field!(Float32Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_float(idx)); + } + DataType::Float64 => { + process_field!(Float64Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_double(idx)); + } + DataType::Date32 => { + process_field!(Date32Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_date(idx)); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + process_field!( + TimestampMicrosecondBuilder, + field_idx, + |row: &SparkUnsafeRow, idx| row.get_timestamp(idx) + ); + } + DataType::Binary => { + let field_builder = get_field_builder!(struct_builder, BinaryBuilder, field_idx); + + for row_idx in 
0..num_rows { + if struct_is_null[row_idx] { + field_builder.append_null(); + } else { + let row_addr = row_addresses[row_idx]; + let row_size = row_sizes[row_idx]; + row.point_to(row_addr, row_size); + + if row.is_null_at(field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value(row.get_binary(field_idx)); + } + } + } + } + DataType::Utf8 => { + let field_builder = get_field_builder!(struct_builder, StringBuilder, field_idx); + + for row_idx in 0..num_rows { + if struct_is_null[row_idx] { + field_builder.append_null(); + } else { + let row_addr = row_addresses[row_idx]; + let row_size = row_sizes[row_idx]; + row.point_to(row_addr, row_size); + + if row.is_null_at(field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value(row.get_string(field_idx)); + } + } + } + } + DataType::Decimal128(p, _) => { + let p = *p; + let field_builder = + get_field_builder!(struct_builder, Decimal128Builder, field_idx); + + for row_idx in 0..num_rows { + if struct_is_null[row_idx] { + field_builder.append_null(); + } else { + let row_addr = row_addresses[row_idx]; + let row_size = row_sizes[row_idx]; + row.point_to(row_addr, row_size); + + if row.is_null_at(field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value(row.get_decimal(field_idx, p)); + } + } + } + } + DataType::Struct(nested_fields) => { + let nested_builder = get_field_builder!(struct_builder, StructBuilder, field_idx); + + // Collect nested struct addresses and sizes in one pass, building validity + let mut nested_addresses: Vec = Vec::with_capacity(num_rows); + let mut nested_sizes: Vec = Vec::with_capacity(num_rows); + let mut nested_is_null: Vec = Vec::with_capacity(num_rows); + + for row_idx in 0..num_rows { + if struct_is_null[row_idx] { + // Parent struct is null, nested struct is also null + nested_builder.append_null(); + nested_is_null.push(true); + nested_addresses.push(0); + nested_sizes.push(0); + } else { + let row_addr = 
row_addresses[row_idx]; + let row_size = row_sizes[row_idx]; + row.point_to(row_addr, row_size); + + if row.is_null_at(field_idx) { + nested_builder.append_null(); + nested_is_null.push(true); + nested_addresses.push(0); + nested_sizes.push(0); + } else { + nested_builder.append(true); + nested_is_null.push(false); + // Get nested struct address and size + let nested_row = row.get_struct(field_idx, nested_fields.len()); + nested_addresses.push(nested_row.get_row_addr()); + nested_sizes.push(nested_row.get_row_size()); + } + } + } + + // Recursively process nested struct fields in field-major order + append_nested_struct_fields_field_major( + &nested_addresses, + &nested_sizes, + &nested_is_null, + nested_builder, + nested_fields, + )?; + } + // For list and map, fall back to append_field since they have variable-length elements + dt @ (DataType::List(_) | DataType::Map(_, _)) => { + for row_idx in 0..num_rows { + if struct_is_null[row_idx] { + let null_row = SparkUnsafeRow::default(); + append_field(dt, struct_builder, &null_row, field_idx)?; + } else { + let row_addr = row_addresses[row_idx]; + let row_size = row_sizes[row_idx]; + row.point_to(row_addr, row_size); + append_field(dt, struct_builder, &row, field_idx)?; + } + } + } + _ => { + unreachable!( + "Unsupported data type of struct field: {:?}", + field.data_type() + ) + } + } + } + + Ok(()) +} + +/// Reads row address and size from JVM-provided pointer arrays and points the row to that data. +/// +/// # Safety +/// Caller must ensure row_addresses_ptr and row_sizes_ptr are valid for index i. +/// This is guaranteed when called from append_columns with indices in [row_start, row_end). +macro_rules! 
read_row_at { + ($row:expr, $row_addresses_ptr:expr, $row_sizes_ptr:expr, $i:expr) => {{ + // SAFETY: Caller guarantees pointers are valid for this index (see macro doc) + let row_addr = unsafe { *$row_addresses_ptr.add($i) }; + let row_size = unsafe { *$row_sizes_ptr.add($i) }; + $row.point_to(row_addr, row_size); + }}; +} + +/// Appends a batch of list values to the list builder with a single type dispatch. +/// This moves type dispatch from O(rows) to O(1), significantly improving performance +/// for large batches. +#[allow(clippy::too_many_arguments)] +fn append_list_column_batch( + row_addresses_ptr: *mut jlong, + row_sizes_ptr: *mut jint, + row_start: usize, + row_end: usize, + schema: &[DataType], + column_idx: usize, + element_type: &DataType, + list_builder: &mut ListBuilder<Box<dyn ArrayBuilder>>, +) -> Result<(), CometError> { + let mut row = SparkUnsafeRow::new(schema); + + // Helper macro for primitive element types - gets builder fresh each iteration + // to avoid borrow conflicts with list_builder.append() + macro_rules! 
process_primitive_lists { + ($builder_type:ty, $append_fn:ident) => {{ + for i in row_start..row_end { + read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); + + if row.is_null_at(column_idx) { + list_builder.append_null(); + } else { + let array = row.get_array(column_idx); + // Get values builder fresh each iteration to avoid borrow conflict + let values_builder = list_builder + .values() + .as_any_mut() + .downcast_mut::<$builder_type>() + .expect(stringify!($builder_type)); + array.$append_fn::(values_builder); + list_builder.append(true); + } + } + }}; + } + + match element_type { + DataType::Boolean => { + process_primitive_lists!(BooleanBuilder, append_booleans_to_builder); + } + DataType::Int8 => { + process_primitive_lists!(Int8Builder, append_bytes_to_builder); + } + DataType::Int16 => { + process_primitive_lists!(Int16Builder, append_shorts_to_builder); + } + DataType::Int32 => { + process_primitive_lists!(Int32Builder, append_ints_to_builder); + } + DataType::Int64 => { + process_primitive_lists!(Int64Builder, append_longs_to_builder); + } + DataType::Float32 => { + process_primitive_lists!(Float32Builder, append_floats_to_builder); + } + DataType::Float64 => { + process_primitive_lists!(Float64Builder, append_doubles_to_builder); + } + DataType::Date32 => { + process_primitive_lists!(Date32Builder, append_dates_to_builder); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + process_primitive_lists!(TimestampMicrosecondBuilder, append_timestamps_to_builder); + } + // For complex element types, fall back to per-row dispatch + _ => { + for i in row_start..row_end { + read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); + + if row.is_null_at(column_idx) { + list_builder.append_null(); + } else { + append_list_element(element_type, list_builder, &row.get_array(column_idx))?; + } + } + } + } + + Ok(()) +} + +/// Appends a batch of map values to the map builder with a single type dispatch. 
+/// This moves type dispatch from O(rows × 2) to O(2), improving performance for maps. +#[allow(clippy::too_many_arguments)] +fn append_map_column_batch( + row_addresses_ptr: *mut jlong, + row_sizes_ptr: *mut jint, + row_start: usize, + row_end: usize, + schema: &[DataType], + column_idx: usize, + field: &arrow::datatypes::FieldRef, + map_builder: &mut MapBuilder, Box>, +) -> Result<(), CometError> { + let mut row = SparkUnsafeRow::new(schema); + let (key_field, value_field, _) = get_map_key_value_fields(field)?; + let key_type = key_field.data_type(); + let value_type = value_field.data_type(); + + // Helper macro for processing maps with primitive key/value types + // Uses scoped borrows to avoid borrow checker conflicts + macro_rules! process_primitive_maps { + ($key_builder:ty, $key_append:ident, $val_builder:ty, $val_append:ident) => {{ + for i in row_start..row_end { + read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); + + if row.is_null_at(column_idx) { + map_builder.append(false)?; + } else { + let map = row.get_map(column_idx); + // Process keys in a scope so borrow ends + { + let keys_builder = map_builder + .keys() + .as_any_mut() + .downcast_mut::<$key_builder>() + .expect(stringify!($key_builder)); + map.keys.$key_append::(keys_builder); + } + // Process values in a scope so borrow ends + { + let values_builder = map_builder + .values() + .as_any_mut() + .downcast_mut::<$val_builder>() + .expect(stringify!($val_builder)); + map.values.$val_append::(values_builder); + } + map_builder.append(true)?; + } + } + }}; + } + + // Optimize common map type combinations + match (key_type, value_type) { + // Map + (DataType::Int64, DataType::Int64) => { + process_primitive_maps!( + Int64Builder, + append_longs_to_builder, + Int64Builder, + append_longs_to_builder + ); + } + // Map + (DataType::Int64, DataType::Float64) => { + process_primitive_maps!( + Int64Builder, + append_longs_to_builder, + Float64Builder, + append_doubles_to_builder + ); + } + // Map + 
(DataType::Int32, DataType::Int32) => { + process_primitive_maps!( + Int32Builder, + append_ints_to_builder, + Int32Builder, + append_ints_to_builder + ); + } + // Map + (DataType::Int32, DataType::Int64) => { + process_primitive_maps!( + Int32Builder, + append_ints_to_builder, + Int64Builder, + append_longs_to_builder + ); + } + // For other types, fall back to per-row dispatch + _ => { + for i in row_start..row_end { + read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); + + if row.is_null_at(column_idx) { + map_builder.append(false)?; + } else { + append_map_elements(field, map_builder, &row.get_map(column_idx))?; + } + } + } + } + + Ok(()) +} + +/// Appends struct fields to the struct builder using field-major order. +/// This processes one field at a time across all rows, which moves type dispatch +/// outside the row loop (O(fields) dispatches instead of O(rows × fields)). +#[allow(clippy::redundant_closure_call, clippy::too_many_arguments)] +fn append_struct_fields_field_major( + row_addresses_ptr: *mut jlong, + row_sizes_ptr: *mut jint, + row_start: usize, + row_end: usize, + parent_row: &mut SparkUnsafeRow, + column_idx: usize, + struct_builder: &mut StructBuilder, + fields: &arrow::datatypes::Fields, +) -> Result<(), CometError> { + let num_rows = row_end - row_start; + let num_fields = fields.len(); + + // First pass: Build struct validity and collect which structs are null + // We use a Vec for simplicity; could use a bitset for better memory + let mut struct_is_null = Vec::with_capacity(num_rows); + + for i in row_start..row_end { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + + let is_null = parent_row.is_null_at(column_idx); + struct_is_null.push(is_null); + + if is_null { + struct_builder.append_null(); + } else { + struct_builder.append(true); + } + } + + // Helper macro for processing primitive fields + macro_rules! 
process_field { + ($builder_type:ty, $field_idx:expr, $get_value:expr) => {{ + let field_builder = get_field_builder!(struct_builder, $builder_type, $field_idx); + + for (row_idx, i) in (row_start..row_end).enumerate() { + if struct_is_null[row_idx] { + // Struct is null, field is also null + field_builder.append_null(); + } else { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + let nested_row = parent_row.get_struct(column_idx, num_fields); + + if nested_row.is_null_at($field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value($get_value(&nested_row, $field_idx)); + } + } + } + }}; + } + + // Second pass: Process each field across all rows + for (field_idx, field) in fields.iter().enumerate() { + match field.data_type() { + DataType::Boolean => { + process_field!(BooleanBuilder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_boolean(idx)); + } + DataType::Int8 => { + process_field!(Int8Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_byte(idx)); + } + DataType::Int16 => { + process_field!(Int16Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_short(idx)); + } + DataType::Int32 => { + process_field!(Int32Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_int(idx)); + } + DataType::Int64 => { + process_field!(Int64Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_long(idx)); + } + DataType::Float32 => { + process_field!(Float32Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_float(idx)); + } + DataType::Float64 => { + process_field!(Float64Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_double(idx)); + } + DataType::Date32 => { + process_field!(Date32Builder, field_idx, |row: &SparkUnsafeRow, idx| row + .get_date(idx)); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + process_field!( + TimestampMicrosecondBuilder, + field_idx, + |row: &SparkUnsafeRow, idx| row.get_timestamp(idx) + ); + } + DataType::Binary => { + let field_builder = 
get_field_builder!(struct_builder, BinaryBuilder, field_idx); + + for (row_idx, i) in (row_start..row_end).enumerate() { + if struct_is_null[row_idx] { + field_builder.append_null(); + } else { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + let nested_row = parent_row.get_struct(column_idx, num_fields); + + if nested_row.is_null_at(field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value(nested_row.get_binary(field_idx)); + } + } + } + } + DataType::Utf8 => { + let field_builder = get_field_builder!(struct_builder, StringBuilder, field_idx); + + for (row_idx, i) in (row_start..row_end).enumerate() { + if struct_is_null[row_idx] { + field_builder.append_null(); + } else { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + let nested_row = parent_row.get_struct(column_idx, num_fields); + + if nested_row.is_null_at(field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value(nested_row.get_string(field_idx)); + } + } + } + } + DataType::Decimal128(p, _) => { + let p = *p; + let field_builder = + get_field_builder!(struct_builder, Decimal128Builder, field_idx); + + for (row_idx, i) in (row_start..row_end).enumerate() { + if struct_is_null[row_idx] { + field_builder.append_null(); + } else { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + let nested_row = parent_row.get_struct(column_idx, num_fields); + + if nested_row.is_null_at(field_idx) { + field_builder.append_null(); + } else { + field_builder.append_value(nested_row.get_decimal(field_idx, p)); + } + } + } + } + // For nested structs, apply field-major processing recursively + DataType::Struct(nested_fields) => { + let nested_builder = get_field_builder!(struct_builder, StructBuilder, field_idx); + + // Collect nested struct addresses and sizes in one pass, building validity + let mut nested_addresses: Vec = Vec::with_capacity(num_rows); + let mut nested_sizes: Vec = Vec::with_capacity(num_rows); + let mut 
nested_is_null: Vec = Vec::with_capacity(num_rows); + + for (row_idx, i) in (row_start..row_end).enumerate() { + if struct_is_null[row_idx] { + // Parent struct is null, nested struct is also null + nested_builder.append_null(); + nested_is_null.push(true); + nested_addresses.push(0); + nested_sizes.push(0); + } else { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + let parent_struct = parent_row.get_struct(column_idx, num_fields); + + if parent_struct.is_null_at(field_idx) { + nested_builder.append_null(); + nested_is_null.push(true); + nested_addresses.push(0); + nested_sizes.push(0); + } else { + nested_builder.append(true); + nested_is_null.push(false); + // Get nested struct address and size + let nested_row = + parent_struct.get_struct(field_idx, nested_fields.len()); + nested_addresses.push(nested_row.get_row_addr()); + nested_sizes.push(nested_row.get_row_size()); + } + } + } + + // Recursively process nested struct fields in field-major order + append_nested_struct_fields_field_major( + &nested_addresses, + &nested_sizes, + &nested_is_null, + nested_builder, + nested_fields, + )?; + } + // For list and map, fall back to append_field since they have variable-length elements + dt @ (DataType::List(_) | DataType::Map(_, _)) => { + for (row_idx, i) in (row_start..row_end).enumerate() { + if struct_is_null[row_idx] { + let null_row = SparkUnsafeRow::default(); + append_field(dt, struct_builder, &null_row, field_idx)?; + } else { + read_row_at!(parent_row, row_addresses_ptr, row_sizes_ptr, i); + let nested_row = parent_row.get_struct(column_idx, num_fields); + append_field(dt, struct_builder, &nested_row, field_idx)?; + } + } + } + _ => { + unreachable!( + "Unsupported data type of struct field: {:?}", + field.data_type() + ) + } + } + } + + Ok(()) +} + /// Appends column of top rows to the given array builder. 
+/// +/// # Safety +/// +/// The caller must ensure: +/// - `row_addresses_ptr` points to an array of at least `row_end` jlong values +/// - `row_sizes_ptr` points to an array of at least `row_end` jint values +/// - Each address in `row_addresses_ptr[row_start..row_end]` points to valid Spark UnsafeRow data +/// - The memory remains valid for the duration of this function call +/// +/// These invariants are guaranteed when called from JNI with arrays provided by the JVM. #[allow(clippy::redundant_closure_call, clippy::too_many_arguments)] pub(crate) fn append_columns( row_addresses_ptr: *mut jlong, @@ -461,9 +1167,7 @@ pub(crate) fn append_columns( let mut row = SparkUnsafeRow::new(schema); for i in row_start..row_end { - let row_addr = unsafe { *row_addresses_ptr.add(i) }; - let row_size = unsafe { *row_sizes_ptr.add(i) }; - row.point_to(row_addr, row_size); + read_row_at!(row, row_addresses_ptr, row_sizes_ptr, i); let is_null = row.is_null_at(column_idx); @@ -588,47 +1292,31 @@ pub(crate) fn append_columns( MapBuilder, Box>, builder ); - let mut row = SparkUnsafeRow::new(schema); - - for i in row_start..row_end { - let row_addr = unsafe { *row_addresses_ptr.add(i) }; - let row_size = unsafe { *row_sizes_ptr.add(i) }; - row.point_to(row_addr, row_size); - - let is_null = row.is_null_at(column_idx); - - if is_null { - // The map is null. - // Append a null value to the map builder. - map_builder.append(false)?; - } else { - append_map_elements(field, map_builder, &row.get_map(column_idx))? 
- } - } + // Use batched processing for better performance + append_map_column_batch( + row_addresses_ptr, + row_sizes_ptr, + row_start, + row_end, + schema, + column_idx, + field, + map_builder, + )?; } DataType::List(field) => { let list_builder = downcast_builder_ref!(ListBuilder>, builder); - let mut row = SparkUnsafeRow::new(schema); - - for i in row_start..row_end { - let row_addr = unsafe { *row_addresses_ptr.add(i) }; - let row_size = unsafe { *row_sizes_ptr.add(i) }; - row.point_to(row_addr, row_size); - - let is_null = row.is_null_at(column_idx); - - if is_null { - // The list is null. - // Append a null value to the list builder. - list_builder.append_null(); - } else { - append_list_element( - field.data_type(), - list_builder, - &row.get_array(column_idx), - )? - } - } + // Use batched processing for better performance + append_list_column_batch( + row_addresses_ptr, + row_sizes_ptr, + row_start, + row_end, + schema, + column_idx, + field.data_type(), + list_builder, + )?; } DataType::Struct(fields) => { let struct_builder = builder @@ -637,27 +1325,17 @@ pub(crate) fn append_columns( .expect("StructBuilder"); let mut row = SparkUnsafeRow::new(schema); - for i in row_start..row_end { - let row_addr = unsafe { *row_addresses_ptr.add(i) }; - let row_size = unsafe { *row_sizes_ptr.add(i) }; - row.point_to(row_addr, row_size); - - let is_null = row.is_null_at(column_idx); - - let nested_row = if is_null { - // The struct is null. - // Append a null value to the struct builder and field builders. 
- struct_builder.append_null(); - SparkUnsafeRow::default() - } else { - struct_builder.append(true); - row.get_struct(column_idx, fields.len()) - }; - - for (idx, field) in fields.into_iter().enumerate() { - append_field(field.data_type(), struct_builder, &nested_row, idx)?; - } - } + // Use field-major processing to avoid per-row type dispatch + append_struct_fields_field_major( + row_addresses_ptr, + row_sizes_ptr, + row_start, + row_end, + &mut row, + column_idx, + struct_builder, + fields, + )?; } _ => { unreachable!("Unsupported data type of column: {:?}", dt)