Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
297 changes: 293 additions & 4 deletions arrow-arith/src/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ pub fn min_string_view(array: &StringViewArray) -> Option<&str> {
/// Returns the sum of values in the array.
///
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
/// For an overflow-checking variant, use `sum_array_checked` instead.
/// For an overflow-checking variant, use [`sum_array_checked`] instead.
pub fn sum_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
where
T: ArrowNumericType,
Expand All @@ -567,14 +567,22 @@ where

Some(sum)
}
DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() {
DataType::Int16 => ree::sum_wrapping::<types::Int16Type, T>(&array),
DataType::Int32 => ree::sum_wrapping::<types::Int32Type, T>(&array),
DataType::Int64 => ree::sum_wrapping::<types::Int64Type, T>(&array),
_ => None,
},
_ => sum::<T>(as_primitive_array(&array)),
}
}

/// Returns the sum of values in the array.
///
/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant,
/// use `sum_array` instead.
/// use [`sum_array`] instead.
/// Additionally returns an `Err` on run-end-encoded arrays with a provided
/// values type parameter that is incorrect.
pub fn sum_array_checked<T, A: ArrayAccessor<Item = T::Native>>(
array: A,
) -> Result<Option<T::Native>, ArrowError>
Expand Down Expand Up @@ -603,10 +611,111 @@ where

Ok(Some(sum))
}
DataType::RunEndEncoded(run_ends, _) => match run_ends.data_type() {
DataType::Int16 => ree::sum_checked::<types::Int16Type, T>(&array),
DataType::Int32 => ree::sum_checked::<types::Int32Type, T>(&array),
DataType::Int64 => ree::sum_checked::<types::Int64Type, T>(&array),
_ => Ok(None),
},
_ => sum_checked::<T>(as_primitive_array(&array)),
}
}

// Logic for summing run-end-encoded arrays.
mod ree {
    use std::convert::Infallible;

    use arrow_array::cast::AsArray;
    use arrow_array::types::RunEndIndexType;
    use arrow_array::{Array, ArrowNativeTypeOp, ArrowNumericType, PrimitiveArray, TypedRunArray};
    use arrow_buffer::ArrowNativeType;
    use arrow_schema::ArrowError;

    /// Downcasts an array to a `TypedRunArray` with run-end type `I` and
    /// primitive values of type `V`.
    ///
    /// Returns `None` if `array` is not a run array indexed by `I`, or if its
    /// values child is not a `PrimitiveArray<V>`.
    fn downcast<'a, I: RunEndIndexType, V: ArrowNumericType>(
        array: &'a dyn Array,
    ) -> Option<TypedRunArray<'a, I, PrimitiveArray<V>>> {
        let array = array.as_run_opt::<I>()?;
        // We only support RunArray wrapping primitive types.
        array.downcast::<PrimitiveArray<V>>()
    }

    /// Computes the sum (wrapping on overflow) of the logical values of a
    /// run-end-encoded array.
    ///
    /// Each non-null run contributes `value * run_length`, using wrapping
    /// multiplication and addition.
    ///
    /// Returns `None` if the array cannot be downcast to the requested types,
    /// or if it contains no non-null run.
    pub(super) fn sum_wrapping<I: RunEndIndexType, V: ArrowNumericType>(
        array: &dyn Array,
    ) -> Option<V::Native> {
        let ree = downcast::<I, V>(array)?;
        let folded = fold(ree, |acc, val, len| -> Result<V::Native, Infallible> {
            Ok(acc.add_wrapping(val.mul_wrapping(V::Native::usize_as(len))))
        });
        match folded {
            Ok(sum) => sum,
            // `Infallible` has no values, so this arm can never be reached.
            Err(never) => match never {},
        }
    }

    /// Computes the sum (erroring on overflow) of the logical values of a
    /// run-end-encoded array.
    ///
    /// # Errors
    ///
    /// * `ArrowError::InvalidArgumentError` if the array cannot be downcast to
    ///   a run array indexed by `I` with `PrimitiveArray<V>` values.
    /// * `ArrowError::ArithmeticOverflow` if a run length cannot be
    ///   represented in the value type.
    /// * Whatever error `mul_checked` / `add_checked` report on overflow.
    pub(super) fn sum_checked<I: RunEndIndexType, V: ArrowNumericType>(
        array: &dyn Array,
    ) -> Result<Option<V::Native>, ArrowError> {
        let Some(ree) = downcast::<I, V>(array) else {
            return Err(ArrowError::InvalidArgumentError(
                "Input array is not a TypedRunArray<'_, _, PrimitiveArray<T>>".to_string(),
            ));
        };
        fold(ree, |acc, val, len| -> Result<V::Native, ArrowError> {
            let Some(len) = V::Native::from_usize(len) else {
                return Err(ArrowError::ArithmeticOverflow(format!(
                    "Cannot convert a run-end index ({:?}) to the value type ({})",
                    len,
                    std::any::type_name::<V::Native>()
                )));
            };
            acc.add_checked(val.mul_checked(len)?)
        })
    }

    /// Folds over the runs of a run-end-encoded array.
    ///
    /// For every run whose value is non-null, `f(acc, value, run_length)` is
    /// called and its result becomes the new accumulator (starting from
    /// zero). Runs with a null value are skipped, but still advance the
    /// position so later run lengths stay correct.
    ///
    /// Returns `Ok(None)` when no non-null run was seen, `Ok(Some(acc))`
    /// otherwise, and the first error produced by `f`, if any.
    fn fold<'a, I: RunEndIndexType, V: ArrowNumericType, F, E>(
        array: TypedRunArray<'a, I, PrimitiveArray<V>>,
        mut f: F,
    ) -> Result<Option<V::Native>, E>
    where
        F: FnMut(V::Native, V::Native, usize) -> Result<V::Native, E>,
    {
        let run_ends = array.run_ends();
        let logical_len = run_ends.len();
        // NOTE(review): this relies on `sliced_values()` yielding run ends that
        // are relative to the start of the logical slice (the existing sliced
        // sum test only produces its expected value under that reading) --
        // confirm against the `RunEndBuffer::sliced_values` documentation.
        // Under that contract the run ends must only be capped at the logical
        // length. Clamping them up to the slice *offset* would inflate the
        // first run whenever the offset exceeds the first relative run end.
        let run_ends = run_ends.sliced_values();

        let values_slice = array.run_array().values_slice();
        let values = values_slice
            .as_any()
            .downcast_ref::<PrimitiveArray<V>>()
            .expect("values of a TypedRunArray<_, PrimitiveArray<V>> must be a PrimitiveArray<V>");

        let mut prev_end = 0;
        let mut acc = V::Native::ZERO;
        let mut has_non_null_value = false;

        for (run_end, value) in run_ends.zip(values) {
            let current_run_end = run_end.as_usize().min(logical_len);
            let run_length = current_run_end - prev_end;

            if let Some(value) = value {
                has_non_null_value = true;
                acc = f(acc, value, run_length)?;
            }

            prev_end = current_run_end;
            // Every logical element has been accounted for.
            if current_run_end == logical_len {
                break;
            }
        }

        Ok(if has_non_null_value { Some(acc) } else { None })
    }
}

/// Returns the min of values in the array of `ArrowNumericType` type, or dictionary
/// array with value of `ArrowNumericType` type.
pub fn min_array<T, A: ArrayAccessor<Item = T::Native>>(array: A) -> Option<T::Native>
Expand Down Expand Up @@ -639,6 +748,20 @@ where
{
match array.data_type() {
DataType::Dictionary(_, _) => min_max_helper::<T::Native, _, _>(array, cmp),
DataType::RunEndEncoded(run_ends, _) => {
// We can directly perform min/max on the values child array, as any
// run must have non-zero length.
let array: &dyn Array = &array;
let values = match run_ends.data_type() {
DataType::Int16 => array.as_run_opt::<types::Int16Type>()?.values_slice(),
DataType::Int32 => array.as_run_opt::<types::Int32Type>()?.values_slice(),
DataType::Int64 => array.as_run_opt::<types::Int64Type>()?.values_slice(),
_ => return None,
};
// We only support RunArray wrapping primitive types.
let values = values.as_any().downcast_ref::<PrimitiveArray<T>>()?;
m(values)
}
_ => m(as_primitive_array(&array)),
}
}
Expand Down Expand Up @@ -751,7 +874,7 @@ pub fn bool_or(array: &BooleanArray) -> Option<bool> {
/// Returns `Ok(None)` if the array is empty or only contains null values.
///
/// This detects overflow and returns an `Err` for that. For a non-overflow-checking variant,
/// use `sum` instead.
/// use [`sum`] instead.
pub fn sum_checked<T>(array: &PrimitiveArray<T>) -> Result<Option<T::Native>, ArrowError>
where
T: ArrowNumericType,
Expand Down Expand Up @@ -799,7 +922,7 @@ where
/// Returns `None` if the array is empty or only contains null values.
///
/// This doesn't detect overflow in release mode by default. Once overflowing, the result will
/// wrap around. For an overflow-checking variant, use `sum_checked` instead.
/// wrap around. For an overflow-checking variant, use [`sum_checked`] instead.
pub fn sum<T: ArrowNumericType>(array: &PrimitiveArray<T>) -> Option<T::Native>
where
T::Native: ArrowNativeTypeOp,
Expand Down Expand Up @@ -1750,4 +1873,170 @@ mod tests {
sum_checked(&a).expect_err("overflow should be detected");
sum_array_checked::<Int32Type, _>(&a).expect_err("overflow should be detected");
}

/// Test helper: builds a `RunArray` with run-end type `I` and value type `V`
/// by appending every item of `values` (converted to `Option<V::Native>`) to
/// a `PrimitiveRunBuilder`.
fn make_run_array<'a, I: RunEndIndexType, V: ArrowNumericType, ItemType>(
    values: impl IntoIterator<Item = &'a ItemType>,
) -> RunArray<I>
where
    ItemType: Clone + Into<Option<V::Native>> + 'static,
{
    let mut run_builder = arrow_array::builder::PrimitiveRunBuilder::<I, V>::new();
    values
        .into_iter()
        .cloned()
        .for_each(|item| run_builder.append_option(item.into()));
    run_builder.finish()
}

#[test]
fn test_ree_sum_array_basic() {
    // Logical values: 10 + 10 + 20 + 30 + 30 + 30 = 130.
    let ree = make_run_array::<Int16Type, Int32Type, _>(&[10, 10, 20, 30, 30, 30]);
    let typed = ree.downcast::<Int32Array>().unwrap();

    let wrapping = sum_array::<Int32Type, _>(typed);
    let checked = sum_array_checked::<Int32Type, _>(typed).unwrap();

    assert_eq!(wrapping, Some(130));
    assert_eq!(checked, Some(130));
}

#[test]
fn test_ree_sum_array_empty() {
    // An array without any element has no sum.
    let ree = make_run_array::<Int16Type, Int32Type, i32>(&[]);
    let typed = ree.downcast::<Int32Array>().unwrap();

    assert_eq!(sum_array::<Int32Type, _>(typed), None);
    assert_eq!(sum_array_checked::<Int32Type, _>(typed).unwrap(), None);
}

#[test]
fn test_ree_sum_array_with_nulls() {
    // Nulls are ignored: 10 + 20 + 30 = 60.
    let input = [Some(10), None, Some(20), None, Some(30)];
    let ree = make_run_array::<Int16Type, Int32Type, _>(&input);
    let typed = ree.downcast::<Int32Array>().unwrap();

    assert_eq!(sum_array::<Int32Type, _>(typed), Some(60));
    assert_eq!(sum_array_checked::<Int32Type, _>(typed).unwrap(), Some(60));
}

#[test]
fn test_ree_sum_array_with_only_nulls() {
    // With no non-null value at all, both variants report `None`.
    let ree = make_run_array::<Int16Type, Int16Type, _>(&[None, None, None, None, None]);
    let typed = ree.downcast::<Int16Array>().unwrap();

    let wrapping = sum_array::<Int16Type, _>(typed);
    let checked = sum_array_checked::<Int16Type, _>(typed).unwrap();

    assert_eq!(wrapping, None);
    assert_eq!(checked, None);
}

#[test]
fn test_ree_sum_array_overflow() {
    let ree = make_run_array::<Int16Type, Int8Type, _>(&[126, 2]);
    let typed = ree.downcast::<Int8Array>().unwrap();

    // i8 spans -128..=127, so 126 + 2 wraps around to -128 ...
    assert_eq!(sum_array::<Int8Type, _>(typed), Some(-128));

    // ... while the checked variant must report the overflow.
    assert!(sum_array_checked::<Int8Type, _>(typed).is_err());
}

#[test]
fn test_ree_sum_array_sliced() {
    let full = make_run_array::<Int16Type, UInt8Type, _>(&[0, 10, 10, 10, 20, 30, 30, 30]);
    // Drop the first 2 values and the last one, leaving
    // 10 + 10 + 20 + 30 + 30 = 100.
    let window = full.slice(2, 5);
    let typed = window.downcast::<UInt8Array>().unwrap();

    assert_eq!(sum_array::<UInt8Type, _>(typed), Some(100));
    assert_eq!(sum_array_checked::<UInt8Type, _>(typed).unwrap(), Some(100));
}

#[test]
fn test_ree_min_max_array_basic() {
    let ree = make_run_array::<Int16Type, Int32Type, _>(&[30, 30, 10, 20, 20]);
    let typed = ree.downcast::<Int32Array>().unwrap();

    let smallest = min_array::<Int32Type, _>(typed);
    let largest = max_array::<Int32Type, _>(typed);

    assert_eq!(smallest, Some(10));
    assert_eq!(largest, Some(30));
}

#[test]
fn test_ree_min_max_array_empty() {
    // An array without any element has neither a minimum nor a maximum.
    let ree = make_run_array::<Int16Type, Int32Type, i32>(&[]);
    let typed = ree.downcast::<Int32Array>().unwrap();

    assert_eq!(min_array::<Int32Type, _>(typed), None);
    assert_eq!(max_array::<Int32Type, _>(typed), None);
}

#[test]
fn test_ree_min_max_array_float() {
    // Same check as the integer case, with floating-point values.
    let ree = make_run_array::<Int16Type, Float64Type, _>(&[5.5, 5.5, 2.1, 8.9, 8.9]);
    let typed = ree.downcast::<Float64Array>().unwrap();

    assert_eq!(min_array::<Float64Type, _>(typed), Some(2.1));
    assert_eq!(max_array::<Float64Type, _>(typed), Some(8.9));
}

#[test]
fn test_ree_min_max_array_with_nulls() {
    // The null entry is ignored, so the only candidate is 10.
    let ree = make_run_array::<Int16Type, UInt8Type, _>(&[None, Some(10)]);
    let typed = ree.downcast::<UInt8Array>().unwrap();

    let smallest = min_array::<UInt8Type, _>(typed);
    let largest = max_array::<UInt8Type, _>(typed);

    assert_eq!(smallest, Some(10));
    assert_eq!(largest, Some(10));
}

#[test]
fn test_ree_min_max_array_sliced() {
    let full = make_run_array::<Int16Type, Int32Type, _>(&[0, 30, 30, 10, 20, 20, 100]);
    // Drop 1 value on each side; the remaining values are 30, 30, 10, 20, 20,
    // so neither the leading 0 nor the trailing 100 may be reported.
    let window = full.slice(1, 5);
    let typed = window.downcast::<Int32Array>().unwrap();

    assert_eq!(min_array::<Int32Type, _>(typed), Some(10));
    assert_eq!(max_array::<Int32Type, _>(typed), Some(30));
}

#[test]
fn test_ree_min_max_array_sliced_mid_run() {
    let full = make_run_array::<Int16Type, Int32Type, _>(&[0, 0, 30, 10, 20, 100, 100]);
    // Drop 1 value on each side. One 0 and one 100 remain inside the window
    // (0, 30, 10, 20, 100), so both must still be reported.
    let window = full.slice(1, 5);
    let typed = window.downcast::<Int32Array>().unwrap();

    let smallest = min_array::<Int32Type, _>(typed);
    let largest = max_array::<Int32Type, _>(typed);

    assert_eq!(smallest, Some(0));
    assert_eq!(largest, Some(100));
}
}
Loading