diff --git a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs index 4ab63e532c..b65c507320 100644 --- a/native/spark-expr/src/bitwise_funcs/bitwise_count.rs +++ b/native/spark-expr/src/bitwise_funcs/bitwise_count.rs @@ -16,7 +16,7 @@ // under the License. use arrow::{array::*, datatypes::DataType}; -use datafusion::common::{exec_err, internal_datafusion_err, internal_err, Result}; +use datafusion::common::{exec_err, internal_datafusion_err, Result}; use datafusion::logical_expr::{ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion::{error::DataFusionError, logical_expr::ColumnarValue}; use std::any::Any; @@ -99,15 +99,38 @@ pub fn spark_bit_count(args: [ColumnarValue; 1]) -> Result { DataType::Int16 => compute_op!(array, Int16Array), DataType::Int32 => compute_op!(array, Int32Array), DataType::Int64 => compute_op!(array, Int64Array), - _ => exec_err!("bit_count can't be evaluated because the expression's type is {:?}, not signed int", array.data_type()), + _ => exec_err!("bit_count can't be evaluated because the array's type is {:?}, not signed int/boolean", array.data_type()), }; result.map(ColumnarValue::Array) } - [ColumnarValue::Scalar(_)] => internal_err!("shouldn't go to bitwise count scalar path"), + [ColumnarValue::Scalar(scalar)] => { + use datafusion::common::ScalarValue; + let result = match scalar { + ScalarValue::Int8(Some(v)) => bit_count(v as i64), + ScalarValue::Int16(Some(v)) => bit_count(v as i64), + ScalarValue::Int32(Some(v)) => bit_count(v as i64), + ScalarValue::Int64(Some(v)) => bit_count(v), + ScalarValue::Boolean(Some(v)) => bit_count(if v { 1 } else { 0 }), + ScalarValue::Int8(None) + | ScalarValue::Int16(None) + | ScalarValue::Int32(None) + | ScalarValue::Int64(None) + | ScalarValue::Boolean(None) => { + return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None))) + } + _ => { + return exec_err!( + "bit_count can't be evaluated because the scalar's type is {:?}, not signed int/boolean", + scalar.data_type() + ) + } + }; + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(result)))) + } } } -// Here’s the equivalent Rust implementation of the bitCount function (similar to Apache Spark's bitCount for LongType) +// Here’s the equivalent Rust implementation of the bitCount function (similar to Java's bitCount for LongType) fn bit_count(i: i64) -> i32 { let mut u = i as u64; u = u - ((u >> 1) & 0x5555555555555555); @@ -121,7 +144,7 @@ fn bit_count(i: i64) -> i32 { #[cfg(test)] mod tests { - use datafusion::common::{cast::as_int32_array, Result}; + use datafusion::common::{cast::as_int32_array, Result, ScalarValue}; use super::*; @@ -133,8 +156,18 @@ mod tests { Some(12345), Some(89), Some(-3456), + Some(i32::MIN), + Some(i32::MAX), ]))); - let expected = &Int32Array::from(vec![Some(1), None, Some(6), Some(4), Some(54)]); + let expected = &Int32Array::from(vec![ + Some(1), + None, + Some(6), + Some(4), + Some(54), + Some(33), + Some(31), + ]); let ColumnarValue::Array(result) = spark_bit_count([args])? else { unreachable!() @@ -145,4 +178,16 @@ mod tests { Ok(()) } + + #[test] + fn bitwise_count_scalar() { + let args = ColumnarValue::Scalar(ScalarValue::Int64(Some(i64::MAX))); + + match spark_bit_count([args]) { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(actual)))) => { + assert_eq!(actual, 63) + } + _ => unreachable!(), + } + } } diff --git a/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql b/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql index 640aa1e990..74a971f368 100644 --- a/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql +++ b/spark/src/test/resources/sql-tests/expressions/bitwise/bitwise.sql @@ -73,7 +73,7 @@ SELECT bit_get(i, pos) FROM test_bit_get query SELECT 1111 & 2, 1111 | 2, 1111 ^ 2 -query ignore(https://github.com/apache/datafusion-comet/issues/3341) +query SELECT bit_count(0), bit_count(7), bit_count(-1) query spark_answer_only