Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
241 changes: 230 additions & 11 deletions datafusion/functions/src/math/round.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ use arrow::datatypes::DataType::{
};
use arrow::datatypes::{
ArrowNativeTypeOp, DataType, Decimal32Type, Decimal64Type, Decimal128Type,
Decimal256Type, DecimalType, Float32Type, Float64Type, Int32Type,
Decimal256Type, DecimalType, Float32Type, Float64Type, Int8Type, Int16Type,
Int32Type, Int64Type, UInt8Type, UInt16Type, UInt32Type, UInt64Type,
};
use arrow::datatypes::{Field, FieldRef};
use arrow::error::ArrowError;
Expand Down Expand Up @@ -180,6 +181,7 @@ impl Default for RoundFunc {
impl RoundFunc {
pub fn new() -> Self {
let decimal = Coercion::new_exact(TypeSignatureClass::Decimal);
let integer = Coercion::new_exact(TypeSignatureClass::Integer);
let decimal_places = Coercion::new_implicit(
TypeSignatureClass::Native(logical_int32()),
vec![TypeSignatureClass::Integer],
Expand All @@ -199,6 +201,11 @@ impl RoundFunc {
decimal_places.clone(),
]),
TypeSignature::Coercible(vec![decimal]),
TypeSignature::Coercible(vec![
integer.clone(),
decimal_places.clone(),
]),
TypeSignature::Coercible(vec![integer]),
TypeSignature::Coercible(vec![
float32.clone(),
decimal_places.clone(),
Expand Down Expand Up @@ -266,6 +273,7 @@ impl ScalarUDFImpl for RoundFunc {
>(
*precision, *scale, decimal_places
)?,
dt if dt.is_integer() => dt.clone(),
_ => Float64,
};

Expand Down Expand Up @@ -298,14 +306,7 @@ impl ScalarUDFImpl for RoundFunc {
.cast_to(args.return_type(), None);
}

let dp = if let ScalarValue::Int32(Some(dp)) = dp_scalar {
*dp
} else {
return internal_err!(
"Unexpected datatype for decimal_places: {}",
dp_scalar.data_type()
);
};
let dp = decimal_places_from_scalar(dp_scalar)?;

match (value_scalar, args.return_type()) {
(ScalarValue::Float32(Some(v)), _) => {
Expand Down Expand Up @@ -416,6 +417,10 @@ impl ScalarUDFImpl for RoundFunc {
);
Ok(ColumnarValue::Scalar(scalar))
}
(value_scalar, return_type) if return_type.is_integer() => {
let rounded = round_integer_scalar(value_scalar, dp)?;
Ok(ColumnarValue::Scalar(rounded))
}
(ScalarValue::Null, _) => ColumnarValue::Scalar(ScalarValue::Null)
.cast_to(args.return_type(), None),
(value_scalar, return_type) => {
Expand Down Expand Up @@ -456,6 +461,17 @@ impl ScalarUDFImpl for RoundFunc {
}
}

// Expands to the integer-array rounding expression used by the per-type match arms.
//
// Arguments:
// * `$value_array`    - expression holding the value array; `.as_ref()` is called on it.
// * `$decimal_places` - expression passed through unchanged as the second operand.
// * `$integer_type`   - Arrow primitive type used for both the first and third type
//                       parameters of `calculate_binary_math` (presumably input and
//                       output element types; confirm against its signature).
// * `$round_fn`       - identifier of the per-element rounding function to apply.
//
// The second type parameter is fixed to `Int32Type`. The expansion contains `?`, so it
// can only be used inside a function whose return type accepts the propagated error.
macro_rules! round_integer_array {
($value_array:expr, $decimal_places:expr, $integer_type:ty, $round_fn:ident) => {{
let result = calculate_binary_math::<$integer_type, Int32Type, $integer_type, _>(
$value_array.as_ref(),
$decimal_places,
$round_fn,
)?;
// Cast to whatever type the surrounding expression expects
// (NOTE(review): presumably `ArrayRef`, like the other match arms; confirm).
result as _
}};
}

fn round_columnar(
value: &ColumnarValue,
decimal_places: &ColumnarValue,
Expand Down Expand Up @@ -620,6 +636,30 @@ fn round_columnar(
)?;
result as _
}
(DataType::Int8, DataType::Int8) => {
round_integer_array!(value_array, decimal_places, Int8Type, round_i8)
}
(DataType::Int16, DataType::Int16) => {
round_integer_array!(value_array, decimal_places, Int16Type, round_i16)
}
(DataType::Int32, DataType::Int32) => {
round_integer_array!(value_array, decimal_places, Int32Type, round_i32)
}
(DataType::Int64, DataType::Int64) => {
round_integer_array!(value_array, decimal_places, Int64Type, round_i64)
}
(DataType::UInt8, DataType::UInt8) => {
round_integer_array!(value_array, decimal_places, UInt8Type, round_u8)
}
(DataType::UInt16, DataType::UInt16) => {
round_integer_array!(value_array, decimal_places, UInt16Type, round_u16)
}
(DataType::UInt32, DataType::UInt32) => {
round_integer_array!(value_array, decimal_places, UInt32Type, round_u32)
}
(DataType::UInt64, DataType::UInt64) => {
round_integer_array!(value_array, decimal_places, UInt64Type, round_u64)
}
(other, _) => exec_err!("Unsupported data type {other:?} for function round")?,
};

Expand All @@ -630,6 +670,140 @@ fn round_columnar(
}
}

/// Rounds a signed integer to the nearest multiple of `10^(-decimal_places)`,
/// with ties going away from zero.
///
/// * `value` - the integer to round, widened to `i128` by the caller.
/// * `decimal_places` - the scale; zero and positive scales leave `value` unchanged.
/// * `min_value` / `max_value` - inclusive bounds the rounded result must stay within.
///
/// # Errors
///
/// Returns `ArrowError::ComputeError` when the rounded value overflows `i128`
/// or falls outside `min_value..=max_value`.
fn round_signed_integer(
    value: i128,
    decimal_places: i32,
    min_value: i128,
    max_value: i128,
) -> Result<i128, ArrowError> {
    let overflow = || ArrowError::ComputeError("Overflow while rounding integer".into());

    // An integer has no fractional digits, so only a negative scale can change it.
    if value == 0 || decimal_places >= 0 {
        return Ok(value);
    }

    // If 10^|decimal_places| does not fit in i128, half of it already exceeds
    // every i128 magnitude, so the nearest multiple is zero.
    let step = match 10_i128.checked_pow(decimal_places.unsigned_abs()) {
        Some(step) => step,
        None => return Ok(0),
    };
    let half = step / 2;

    // `%` truncates toward zero, so the remainder carries the sign of `value`.
    let adjustment = match value % step {
        r if r >= half => 1,
        r if r <= -half => -1,
        _ => 0,
    };

    (value / step + adjustment)
        .checked_mul(step)
        .filter(|rounded| (min_value..=max_value).contains(rounded))
        .ok_or_else(overflow)
}

/// Rounds an unsigned integer to the nearest multiple of `10^(-decimal_places)`,
/// with ties going up.
///
/// * `value` - the integer to round, widened to `u128` by the caller.
/// * `decimal_places` - the scale; zero and positive scales leave `value` unchanged.
/// * `max_value` - inclusive upper bound the rounded result must not exceed.
///
/// # Errors
///
/// Returns `ArrowError::ComputeError` when the rounded value overflows `u128`
/// or is greater than `max_value`.
fn round_unsigned_integer(
    value: u128,
    decimal_places: i32,
    max_value: u128,
) -> Result<u128, ArrowError> {
    // An integer has no fractional digits, so only a negative scale can change it.
    if value == 0 || decimal_places >= 0 {
        return Ok(value);
    }

    // If 10^|decimal_places| does not fit in u128, half of it already exceeds
    // every u128 value, so the nearest multiple is zero.
    let step = match 10_u128.checked_pow(decimal_places.unsigned_abs()) {
        Some(step) => step,
        None => return Ok(0),
    };

    let round_up = value % step >= step / 2;
    let multiples = value / step + u128::from(round_up);

    match multiples.checked_mul(step) {
        Some(rounded) if rounded <= max_value => Ok(rounded),
        _ => Err(ArrowError::ComputeError(
            "Overflow while rounding integer".into(),
        )),
    }
}

/// Generates `$name(value, decimal_places)`, a fixed-width wrapper around
/// `round_signed_integer` for the signed integer type `$int`.
///
/// The generated function widens `value` to `i128`, passes the type's own
/// `MIN`/`MAX` as bounds, and narrows the result back to `$int`.
macro_rules! round_signed_integer_fn {
    ($name:ident, $int:ty) => {
        fn $name(value: $int, decimal_places: i32) -> Result<$int, ArrowError> {
            let (lower, upper) = (i128::from(<$int>::MIN), i128::from(<$int>::MAX));
            // The helper rejects results outside `lower..=upper`, so the
            // narrowing cast below never truncates.
            round_signed_integer(i128::from(value), decimal_places, lower, upper)
                .map(|rounded| rounded as $int)
        }
    };
}

/// Generates `$name(value, decimal_places)`, a fixed-width wrapper around
/// `round_unsigned_integer` for the unsigned integer type `$int`.
///
/// The generated function widens `value` to `u128`, passes the type's own
/// `MAX` as the upper bound, and narrows the result back to `$int`.
macro_rules! round_unsigned_integer_fn {
    ($name:ident, $int:ty) => {
        fn $name(value: $int, decimal_places: i32) -> Result<$int, ArrowError> {
            let upper = u128::from(<$int>::MAX);
            // The helper rejects results above `upper`, so the narrowing cast
            // below never truncates.
            round_unsigned_integer(u128::from(value), decimal_places, upper)
                .map(|rounded| rounded as $int)
        }
    };
}

// Instantiate one `fn round_<ty>(value, decimal_places) -> Result<<ty>, ArrowError>`
// per supported integer width: signed types go through `round_signed_integer_fn!`,
// unsigned types through `round_unsigned_integer_fn!`.
round_signed_integer_fn!(round_i8, i8);
round_signed_integer_fn!(round_i16, i16);
round_signed_integer_fn!(round_i32, i32);
round_signed_integer_fn!(round_i64, i64);
round_unsigned_integer_fn!(round_u8, u8);
round_unsigned_integer_fn!(round_u16, u16);
round_unsigned_integer_fn!(round_u32, u32);
round_unsigned_integer_fn!(round_u64, u64);

/// Rounds a non-null integer scalar to `decimal_places`, keeping its integer type.
///
/// Each signed and unsigned 8/16/32/64-bit variant is dispatched to its
/// matching `round_*` function and re-wrapped in the same `ScalarValue` variant.
///
/// # Errors
///
/// Propagates any error from the per-type rounding function, and returns an
/// internal error for every other input (including integer variants holding `None`).
fn round_integer_scalar(value: &ScalarValue, decimal_places: i32) -> Result<ScalarValue> {
    let rounded = match value {
        ScalarValue::Int8(Some(v)) => ScalarValue::Int8(Some(round_i8(*v, decimal_places)?)),
        ScalarValue::Int16(Some(v)) => {
            ScalarValue::Int16(Some(round_i16(*v, decimal_places)?))
        }
        ScalarValue::Int32(Some(v)) => {
            ScalarValue::Int32(Some(round_i32(*v, decimal_places)?))
        }
        ScalarValue::Int64(Some(v)) => {
            ScalarValue::Int64(Some(round_i64(*v, decimal_places)?))
        }
        ScalarValue::UInt8(Some(v)) => {
            ScalarValue::UInt8(Some(round_u8(*v, decimal_places)?))
        }
        ScalarValue::UInt16(Some(v)) => {
            ScalarValue::UInt16(Some(round_u16(*v, decimal_places)?))
        }
        ScalarValue::UInt32(Some(v)) => {
            ScalarValue::UInt32(Some(round_u32(*v, decimal_places)?))
        }
        ScalarValue::UInt64(Some(v)) => {
            ScalarValue::UInt64(Some(round_u64(*v, decimal_places)?))
        }
        unsupported => {
            return internal_err!(
                "Unexpected datatype for integer round(value, decimal_places): {}",
                unsupported.data_type()
            );
        }
    };
    Ok(rounded)
}

fn round_float<T>(value: T, decimal_places: i32) -> Result<T, ArrowError>
where
T: num_traits::Float,
Expand Down Expand Up @@ -727,10 +901,12 @@ fn round_decimal_or_zero<V: ArrowNativeTypeOp>(
mod test {
use std::sync::Arc;

use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array};
use arrow::array::{ArrayRef, Float32Array, Float64Array, Int64Array, UInt64Array};
use datafusion_common::DataFusionError;
use datafusion_common::ScalarValue;
use datafusion_common::cast::{as_float32_array, as_float64_array};
use datafusion_common::cast::{
as_float32_array, as_float64_array, as_int64_array, as_uint64_array,
};
use datafusion_expr::ColumnarValue;

fn round_arrays(
Expand Down Expand Up @@ -825,6 +1001,49 @@ mod test {
assert_eq!(floats, &expected);
}

#[test]
fn test_round_large_int64_nonnegative_scale_noop() {
    // 2^53 + 1 cannot be represented exactly as an f64, so the assertion below
    // only holds if the value is returned without passing through Float64.
    let big = i64::pow(2, 53) + 1;
    let values: ArrayRef = Arc::new(Int64Array::from(vec![big; 2]));
    let decimal_places: ArrayRef = Arc::new(Int64Array::from(vec![0, 2]));

    let result = round_arrays(values, Some(decimal_places))
        .expect("failed to initialize function round");
    let integers =
        as_int64_array(&result).expect("failed to initialize function round");

    // Scales of 0 and 2 must both leave the integer untouched.
    assert_eq!(integers, &Int64Array::from(vec![big, big]));
}

#[test]
fn test_round_integer_column_nonnegative_scale_noop() {
    // 2^53 + 1 cannot be represented exactly as an f64, so the assertion below
    // only holds if the value is returned without passing through Float64.
    let big = i64::pow(2, 53) + 1;
    let values: ArrayRef = Arc::new(Int64Array::from(vec![big]));
    let decimal_places: ArrayRef = Arc::new(Int64Array::from(vec![2]));

    let result = round_arrays(values, Some(decimal_places))
        .expect("failed to initialize function round");
    let integers =
        as_int64_array(&result).expect("failed to initialize function round");

    assert_eq!(integers, &Int64Array::from(vec![big]));
}

#[test]
fn test_round_uint64_max_noop() {
    // No decimal_places argument is passed; u64::MAX must come back unchanged
    // and still typed as UInt64.
    let values: ArrayRef = Arc::new(UInt64Array::from(vec![u64::MAX]));

    let result =
        round_arrays(values, None).expect("failed to initialize function round");
    let integers =
        as_uint64_array(&result).expect("failed to initialize function round");

    assert_eq!(integers, &UInt64Array::from(vec![u64::MAX]));
}

#[test]
fn test_round_f32_cast_fail() {
let args: Vec<ArrayRef> = vec![
Expand Down