diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs index a043259694c1..aee5163389c1 100644 --- a/arrow-arith/src/aggregate.rs +++ b/arrow-arith/src/aggregate.rs @@ -540,7 +540,7 @@ pub fn min_string_view(array: &StringViewArray) -> Option<&str> { /// Returns the sum of values in the array. /// /// This doesn't detect overflow. Once overflowing, the result will wrap around. -/// For an overflow-checking variant, use `sum_array_checked` instead. +/// For an overflow-checking variant, use [`sum_array_checked`] instead. pub fn sum_array>(array: A) -> Option where T: ArrowNumericType, @@ -567,6 +567,12 @@ where Some(sum) } + DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() { + DataType::Int16 => ree::sum_wrapping::(&array), + DataType::Int32 => ree::sum_wrapping::(&array), + DataType::Int64 => ree::sum_wrapping::(&array), + _ => None, + }, _ => sum::(as_primitive_array(&array)), } } @@ -574,7 +580,9 @@ where /// Returns the sum of values in the array. /// /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `sum_array` instead. +/// use [`sum_array`] instead. +/// Additionally returns an `Err` on run-end-encoded arrays with a provided +/// values type parameter that is incorrect. pub fn sum_array_checked>( array: A, ) -> Result, ArrowError> @@ -603,10 +611,111 @@ where Ok(Some(sum)) } + DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() { + DataType::Int16 => ree::sum_checked::(&array), + DataType::Int32 => ree::sum_checked::(&array), + DataType::Int64 => ree::sum_checked::(&array), + _ => Ok(None), + }, _ => sum_checked::(as_primitive_array(&array)), } } +// Logic for summing run-end-encoded arrays. +mod ree { + use std::convert::Infallible; + + use arrow_array::cast::AsArray; + use arrow_array::types::RunEndIndexType; + use arrow_array::{Array, ArrowNativeTypeOp, ArrowNumericType, PrimitiveArray, TypedRunArray}; + use arrow_buffer::ArrowNativeType; + use arrow_schema::ArrowError; + + /// Downcasts an array to a TypedRunArray. + fn downcast<'a, I: RunEndIndexType, V: ArrowNumericType>( + array: &'a dyn Array, + ) -> Option>> { + let array = array.as_run_opt::()?; + // We only support RunArray wrapping primitive types. + array.downcast::>() + } + + /// Computes the sum (wrapping) of the array values. + pub(super) fn sum_wrapping( + array: &dyn Array, + ) -> Option { + let ree = downcast::(array)?; + let Ok(sum) = fold(ree, |acc, val, len| -> Result { + println!("Adding {:?}x{} to {:?}", val, len, acc); + Ok(acc.add_wrapping(val.mul_wrapping(V::Native::usize_as(len)))) + }); + sum + } + + /// Computes the sum (erroring on overflow) of the array values. + pub(super) fn sum_checked( + array: &dyn Array, + ) -> Result, ArrowError> { + let Some(ree) = downcast::(array) else { + return Err(ArrowError::InvalidArgumentError( + "Input array is not a TypedRunArray<'_, _, PrimitiveArray".to_string(), + )); + }; + fold(ree, |acc, val, len| -> Result { + let Some(len) = V::Native::from_usize(len) else { + return Err(ArrowError::ArithmeticOverflow(format!( + "Cannot convert a run-end index ({:?}) to the value type ({})", + len, + std::any::type_name::() + ))); + }; + acc.add_checked(val.mul_checked(len)?) + }) + } + + /// Folds over the values in a run-end-encoded array. + fn fold<'a, I: RunEndIndexType, V: ArrowNumericType, F, E>( + array: TypedRunArray<'a, I, PrimitiveArray>, + mut f: F, + ) -> Result, E> + where + F: FnMut(V::Native, V::Native, usize) -> Result, + { + let run_ends = array.run_ends(); + let logical_start = run_ends.offset(); + let logical_end = run_ends.offset() + run_ends.len(); + let run_ends = run_ends.sliced_values(); + + let values_slice = array.run_array().values_slice(); + let values = values_slice + .as_any() + .downcast_ref::>() + // Safety: we know the values array is PrimitiveArray. + .unwrap(); + + let mut prev_end = 0; + let mut acc = V::Native::ZERO; + let mut has_non_null_value = false; + + for (run_end, value) in run_ends.zip(values) { + let current_run_end = run_end.as_usize().clamp(logical_start, logical_end); + let run_length = current_run_end - prev_end; + + if let Some(value) = value { + has_non_null_value = true; + acc = f(acc, value, run_length)?; + } + + prev_end = current_run_end; + if current_run_end == logical_end { + break; + } + } + + Ok(if has_non_null_value { Some(acc) } else { None }) + } +} + /// Returns the min of values in the array of `ArrowNumericType` type, or dictionary /// array with value of `ArrowNumericType` type. pub fn min_array>(array: A) -> Option @@ -639,6 +748,20 @@ where { match array.data_type() { DataType::Dictionary(_, _) => min_max_helper::(array, cmp), + DataType::RunEndEncoded(run_ends, _) => { + // We can directly perform min/max on the values child array, as any + // run must have non-zero length. + let array: &dyn Array = &array; + let values = match run_ends.data_type() { + DataType::Int16 => array.as_run_opt::()?.values_slice(), + DataType::Int32 => array.as_run_opt::()?.values_slice(), + DataType::Int64 => array.as_run_opt::()?.values_slice(), + _ => return None, + }; + // We only support RunArray wrapping primitive types. + let values = values.as_any().downcast_ref::>()?; + m(values) + } _ => m(as_primitive_array(&array)), } } @@ -751,7 +874,7 @@ pub fn bool_or(array: &BooleanArray) -> Option { /// Returns `Ok(None)` if the array is empty or only contains null values. /// /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant, -/// use `sum` instead. +/// use [`sum`] instead. pub fn sum_checked(array: &PrimitiveArray) -> Result, ArrowError> where T: ArrowNumericType, @@ -799,7 +922,7 @@ where /// Returns `None` if the array is empty or only contains null values. /// /// This doesn't detect overflow in release mode by default. Once overflowing, the result will -/// wrap around. For an overflow-checking variant, use `sum_checked` instead. +/// wrap around. For an overflow-checking variant, use [`sum_checked`] instead. pub fn sum(array: &PrimitiveArray) -> Option where T::Native: ArrowNativeTypeOp, @@ -1750,4 +1873,170 @@ mod tests { sum_checked(&a).expect_err("overflow should be detected"); sum_array_checked::(&a).expect_err("overflow should be detected"); } + + /// Helper for building a RunArray. + fn make_run_array<'a, I: RunEndIndexType, V: ArrowNumericType, ItemType>( + values: impl IntoIterator, + ) -> RunArray + where + ItemType: Clone + Into> + 'static, + { + let mut builder = arrow_array::builder::PrimitiveRunBuilder::::new(); + for v in values.into_iter() { + builder.append_option((*v).clone().into()); + } + builder.finish() + } + + #[test] + fn test_ree_sum_array_basic() { + let run_array = make_run_array::(&[10, 10, 20, 30, 30, 30]); + let typed_array = run_array.downcast::().unwrap(); + + let result = sum_array::(typed_array); + assert_eq!(result, Some(130)); + + let result = sum_array_checked::(typed_array).unwrap(); + assert_eq!(result, Some(130)); + } + + #[test] + fn test_ree_sum_array_empty() { + let run_array = make_run_array::(&[]); + let typed_array = run_array.downcast::().unwrap(); + + let result = sum_array::(typed_array); + assert_eq!(result, None); + + let result = sum_array_checked::(typed_array).unwrap(); + assert_eq!(result, None); + } + + #[test] + fn test_ree_sum_array_with_nulls() { + let run_array = + make_run_array::(&[Some(10), None, Some(20), None, Some(30)]); + let typed_array = run_array.downcast::().unwrap(); + + let result = sum_array::(typed_array); + assert_eq!(result, Some(60)); + + let result = sum_array_checked::(typed_array).unwrap(); + assert_eq!(result, Some(60)); + } + + #[test] + fn test_ree_sum_array_with_only_nulls() { + let run_array = make_run_array::(&[None, None, None, None, None]); + let typed_array = run_array.downcast::().unwrap(); + + let result = sum_array::(typed_array); + assert_eq!(result, None); + + let result = sum_array_checked::(typed_array).unwrap(); + assert_eq!(result, None); + } + + #[test] + fn test_ree_sum_array_overflow() { + let run_array = make_run_array::(&[126, 2]); + let typed_array = run_array.downcast::().unwrap(); + + // i8 range is -128..=127. 126+2 overflows to -128. + let result = sum_array::(typed_array); + assert_eq!(result, Some(-128)); + + let result = sum_array_checked::(typed_array); + assert!(result.is_err()); + } + + #[test] + fn test_ree_sum_array_sliced() { + let run_array = make_run_array::(&[0, 10, 10, 10, 20, 30, 30, 30]); + // Skip 2 values at the start and 1 at the end. + let sliced = run_array.slice(2, 5); + let typed_array = sliced.downcast::().unwrap(); + + let result = sum_array::(typed_array); + assert_eq!(result, Some(100)); + + let result = sum_array_checked::(typed_array).unwrap(); + assert_eq!(result, Some(100)); + } + + #[test] + fn test_ree_min_max_array_basic() { + let run_array = make_run_array::(&[30, 30, 10, 20, 20]); + let typed_array = run_array.downcast::().unwrap(); + + let result = min_array::(typed_array); + assert_eq!(result, Some(10)); + + let result = max_array::(typed_array); + assert_eq!(result, Some(30)); + } + + #[test] + fn test_ree_min_max_array_empty() { + let run_array = make_run_array::(&[]); + let typed_array = run_array.downcast::().unwrap(); + + let result = min_array::(typed_array); + assert_eq!(result, None); + + let result = max_array::(typed_array); + assert_eq!(result, None); + } + + #[test] + fn test_ree_min_max_array_float() { + let run_array = make_run_array::(&[5.5, 5.5, 2.1, 8.9, 8.9]); + let typed_array = run_array.downcast::().unwrap(); + + let result = min_array::(typed_array); + assert_eq!(result, Some(2.1)); + + let result = max_array::(typed_array); + assert_eq!(result, Some(8.9)); + } + + #[test] + fn test_ree_min_max_array_with_nulls() { + let run_array = make_run_array::(&[None, Some(10)]); + let typed_array = run_array.downcast::().unwrap(); + + let result = min_array::(typed_array); + assert_eq!(result, Some(10)); + + let result = max_array::(typed_array); + assert_eq!(result, Some(10)); + } + + #[test] + fn test_ree_min_max_array_sliced() { + let run_array = make_run_array::(&[0, 30, 30, 10, 20, 20, 100]); + // Skip 1 value at the start and 1 at the end. + let sliced = run_array.slice(1, 5); + let typed_array = sliced.downcast::().unwrap(); + + let result = min_array::(typed_array); + assert_eq!(result, Some(10)); + + let result = max_array::(typed_array); + assert_eq!(result, Some(30)); + } + + #[test] + fn test_ree_min_max_array_sliced_mid_run() { + let run_array = make_run_array::(&[0, 0, 30, 10, 20, 100, 100]); + // Skip 1 value at the start and 1 at the end. + let sliced = run_array.slice(1, 5); + let typed_array = sliced.downcast::().unwrap(); + + let result = min_array::(typed_array); + assert_eq!(result, Some(0)); + + let result = max_array::(typed_array); + assert_eq!(result, Some(100)); + } }