From 255cf3606cb6a75123020f3758c50ce35382a5b5 Mon Sep 17 00:00:00 2001 From: Michael Levin Date: Thu, 19 Dec 2024 00:53:15 -0800 Subject: [PATCH 1/5] Support interval multiplication and division by arbitrary numerics --- arrow-arith/src/numeric.rs | 600 +++++++++++++++++++++++++++++++++++++ arrow-array/src/cast.rs | 4 +- 2 files changed, 602 insertions(+), 2 deletions(-) diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs index b6af40f7d7c2..0f6a3325b021 100644 --- a/arrow-arith/src/numeric.rs +++ b/arrow-arith/src/numeric.rs @@ -230,6 +230,24 @@ fn arithmetic_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result interval_op::(op, l, l_scalar, r, r_scalar), (Interval(DayTime), Interval(DayTime)) => interval_op::(op, l, l_scalar, r, r_scalar), (Interval(MonthDayNano), Interval(MonthDayNano)) => interval_op::(op, l, l_scalar, r, r_scalar), + (Interval(unit), rhs) if rhs.is_numeric() && matches!(op, Op::Mul | Op::MulWrapping) => + match unit { + YearMonth => interval_mul_op::(op, l, l_scalar, r, r_scalar), + DayTime => interval_mul_op::(op, l, l_scalar, r, r_scalar), + MonthDayNano => interval_mul_op::(op, l, l_scalar, r, r_scalar), + }, + (lhs, Interval(unit)) if lhs.is_integer() && matches!(op, Op::Mul | Op::MulWrapping) => + match unit { + YearMonth => interval_mul_op::(op, l, l_scalar, r, r_scalar), + DayTime => interval_mul_op::(op, l, l_scalar, r, r_scalar), + MonthDayNano => interval_mul_op::(op, l, l_scalar, r, r_scalar), + }, + (Interval(unit), rhs) if rhs.is_numeric() && matches!(op, Op::Div) => + match unit { + YearMonth => interval_div_op::(op, l, l_scalar, r, r_scalar), + DayTime => interval_div_op::(op, l, l_scalar, r, r_scalar), + MonthDayNano => interval_div_op::(op, l, l_scalar, r, r_scalar), + }, (Date32, _) => date_op::(op, l, l_scalar, r, r_scalar), (Date64, _) => date_op::(op, l, l_scalar, r, r_scalar), (Decimal128(_, _), Decimal128(_, _)) => decimal_op::(op, l, l_scalar, r, r_scalar), @@ -550,6 +568,10 @@ date!(Date64Type); trait IntervalOp: ArrowPrimitiveType { fn add(left: Self::Native, right: Self::Native) -> Result; fn sub(left: Self::Native, right: Self::Native) -> Result; + fn mul_int(left: Self::Native, right: i32) -> Result; + fn mul_float(left: Self::Native, right: f64) -> Result; + fn div_int(left: Self::Native, right: i32) -> Result; + fn div_float(left: Self::Native, right: f64) -> Result; } impl IntervalOp for IntervalYearMonthType { @@ -560,6 +582,29 @@ impl IntervalOp for IntervalYearMonthType { fn sub(left: Self::Native, right: Self::Native) -> Result { left.sub_checked(right) } + + fn mul_int(left: Self::Native, right: i32) -> Result { + left.mul_checked(right) + } + + fn mul_float(left: Self::Native, right: f64) -> Result { + let result = (left as f64 * right).round() as i32; + Ok(result) + } + + fn div_int(left: Self::Native, right: i32) -> Result { + if right == 0 { + return Err(ArrowError::DivideByZero); + } + Ok((left as f64 / right as f64).round() as i32) + } + + fn div_float(left: Self::Native, right: f64) -> Result { + if right == 0.0 { + return Err(ArrowError::DivideByZero); + } + Ok((left as f64 / right).round() as i32) + } } impl IntervalOp for IntervalDayTimeType { @@ -578,6 +623,68 @@ impl IntervalOp for IntervalDayTimeType { let ms = l_ms.sub_checked(r_ms)?; Ok(Self::make_value(days, ms)) } + + fn mul_int(left: Self::Native, right: i32) -> Result { + let (days, ms) = Self::to_parts(left); + Ok(IntervalDayTimeType::make_value( + days.mul_checked(right)?, + ms.mul_checked(right)?, + )) + } + + fn mul_float(left: Self::Native, right: f64) -> Result { + let (days, ms) = Self::to_parts(left); + + // Calculate total days including fractional part + let total_days = days as f64 * right; + // Split into whole and fractional days + let whole_days = total_days.trunc() as i32; + let frac_days = total_days.fract(); + + // Convert fractional days to milliseconds (24 * 60 * 60 * 1000 = 86_400_000 ms per day) + let frac_ms = (frac_days * 86_400_000.0).round() as i32; + + // Calculate total milliseconds including the fractional days + let total_ms = (ms as f64 * right).round() as i32 + frac_ms; + + Ok(Self::make_value(whole_days, total_ms)) + } + + fn div_int(left: Self::Native, right: i32) -> Result { + if right == 0 { + return Err(ArrowError::DivideByZero); + } + let (days, ms) = Self::to_parts(left); + + // Convert everything to milliseconds to handle remainders + let total_ms = ms as i64 + (days as i64 * 86_400_000); // 24 * 60 * 60 * 1000 + let result_ms = total_ms / right as i64; + + // Convert back to days and milliseconds + let result_days = result_ms / 86_400_000; + let result_ms = result_ms % 86_400_000; + + Ok(Self::make_value(result_days as i32, result_ms as i32)) + } + + fn div_float(left: Self::Native, right: f64) -> Result { + if right == 0.0 { + return Err(ArrowError::DivideByZero); + } + let (days, ms) = Self::to_parts(left); + + // Convert everything to milliseconds to handle remainders + let total_ms = (ms as f64 + (days as f64 * 86_400_000.0)) / right; + + // Convert back to days and milliseconds + let result_days = (total_ms / 86_400_000.0).floor(); + let result_ms = total_ms % 86_400_000.0; + + Ok(Self::make_value( + result_days as i32, + result_ms.round() as i32, + )) + } } impl IntervalOp for IntervalMonthDayNanoType { @@ -598,6 +705,33 @@ impl IntervalOp for IntervalMonthDayNanoType { let nanos = l_nanos.sub_checked(r_nanos)?; Ok(Self::make_value(months, days, nanos)) } + + fn mul_int(left: Self::Native, right: i32) -> Result { + let (months, days, nanos) = Self::to_parts(left); + Ok(Self::make_value( + months.mul_checked(right)?, + days.mul_checked(right)?, + nanos.mul_checked(right as i64)?, + )) + } + + fn mul_float(_left: Self::Native, _right: f64) -> Result { + Err(ArrowError::InvalidArgumentError( + "Floating point multiplication not supported for MonthDayNano intervals".to_string(), + )) + } + + fn div_int(_left: Self::Native, _right: i32) -> Result { + Err(ArrowError::InvalidArgumentError( + "Integer division not supported for MonthDayNano intervals".to_string(), + )) + } + + fn div_float(_left: Self::Native, _right: f64) -> Result { + Err(ArrowError::InvalidArgumentError( + "Floating point division not supported for MonthDayNano intervals".to_string(), + )) + } } /// Perform arithmetic operation on an interval array @@ -621,6 +755,375 @@ fn interval_op( } } +/// Perform multiplication between an interval array and a numeric array +fn interval_mul_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + // Try both orderings to handle either (interval * numeric) or (numeric * interval) + if let Some(l_interval) = l.as_primitive_opt::() { + // Handle numeric multiplication based on data type + match r.data_type() { + DataType::Int8 => multiply_interval(l_interval, l_s, r.as_primitive::(), r_s), + DataType::Int16 => { + multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) + } + DataType::Int32 => { + multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) + } + DataType::Int64 => { + let r_int = r.as_primitive::(); + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_int, + r_s, + T::mul_int( + l_interval, + i32::try_from(r_int).map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Cannot safely convert {} to i32", + r_int + )) + })? + ) + )) + } + DataType::UInt8 => { + multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) + } + DataType::UInt16 => { + multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) + } + DataType::UInt32 => { + multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) + } + DataType::UInt64 => { + let r_int = r.as_primitive::(); + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_int, + r_s, + T::mul_int( + l_interval, + i32::try_from(r_int).map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Cannot safely convert {} to i32", + r_int + )) + })? + ) + )) + } + DataType::Float16 => { + let r_float = r.as_primitive::(); + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_float, + r_s, + T::mul_float(l_interval, r_float.to_f64()) + )) + } + DataType::Float32 => { + let r_float = r.as_primitive::(); + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_float, + r_s, + T::mul_float(l_interval, r_float as f64) + )) + } + DataType::Float64 => { + let r_float = r.as_primitive::(); + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_float, + r_s, + T::mul_float(l_interval, r_float) + )) + } + DataType::Decimal128(_, scale) => { + let r_decimal = r.as_primitive::(); + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_decimal, + r_s, + T::mul_float(l_interval, (r_decimal as f64) / 10f64.powi(*scale as i32)) + )) + } + DataType::Decimal256(_, scale) => { + let r_decimal = r.as_primitive::(); + // Convert i256 to f64, considering scale + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_decimal, + r_s, + T::mul_float( + l_interval, + r_decimal.to_string().parse::().map_err(|_| { + ArrowError::ComputeError("Cannot convert Decimal256 to f64".to_string()) + })? / 10f64.powi(*scale as i32) + ) + )) + } + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid numeric type for interval multiplication: {}", + r.data_type() + ))), + } + } else if let Some(r_interval) = r.as_primitive_opt::() { + // Same logic for the reverse order + match l.data_type() { + DataType::Int32 => { + let l_int = l.as_primitive::(); + Ok(try_op_ref!( + T, + r_interval, + r_s, + l_int, + l_s, + T::mul_int(r_interval, l_int) + )) + } + DataType::Int64 => { + let l_int = l.as_primitive::(); + Ok(try_op_ref!( + T, + r_interval, + r_s, + l_int, + l_s, + T::mul_int( + r_interval, + i32::try_from(l_int).map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Cannot safely convert {} to i32", + l_int + )) + })? + ) + )) + } + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid integer type for interval multiplication: {}", + l.data_type() + ))), + } + } else { + Err(ArrowError::InvalidArgumentError(format!( + "Invalid interval multiplication: {} {op} {}", + l.data_type(), + r.data_type() + ))) + } +} + +fn multiply_interval( + interval: &PrimitiveArray, + interval_is_scalar: bool, + numeric: &PrimitiveArray, + numeric_is_scalar: bool, +) -> Result +where + N::Native: TryInto, + >::Error: std::error::Error, +{ + Ok(try_op_ref!( + T, + interval, + interval_is_scalar, + numeric, + numeric_is_scalar, + T::mul_int( + interval, + numeric.try_into().map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Cannot safely convert {:?} to i32: {}", + numeric, e + )) + })? + ) + )) +} + +fn interval_div_op( + op: Op, + l: &dyn Array, + l_s: bool, + r: &dyn Array, + r_s: bool, +) -> Result { + // Only allow interval / numeric (not numeric / interval) + if let Some(l_interval) = l.as_primitive_opt::() { + // Handle numeric division based on data type + match r.data_type() { + DataType::Int8 => divide_interval(l_interval, l_s, r.as_primitive::(), r_s), + DataType::Int16 => divide_interval(l_interval, l_s, r.as_primitive::(), r_s), + DataType::Int32 => divide_interval(l_interval, l_s, r.as_primitive::(), r_s), + DataType::Int64 => { + let r_int = r.as_primitive::(); + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_int, + r_s, + T::div_int( + l_interval, + i32::try_from(r_int).map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Cannot safely convert {} to i32", + r_int + )) + })? + ) + )) + } + DataType::UInt8 => divide_interval(l_interval, l_s, r.as_primitive::(), r_s), + DataType::UInt16 => { + divide_interval(l_interval, l_s, r.as_primitive::(), r_s) + } + DataType::UInt32 => { + divide_interval(l_interval, l_s, r.as_primitive::(), r_s) + } + DataType::UInt64 => { + let r_int = r.as_primitive::(); + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_int, + r_s, + T::div_int( + l_interval, + i32::try_from(r_int).map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Cannot safely convert {} to i32", + r_int + )) + })? + ) + )) + } + DataType::Float16 => { + let r_float = r.as_primitive::(); + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_float, + r_s, + T::div_float(l_interval, r_float.to_f64()) + )) + } + DataType::Float32 => { + let r_float = r.as_primitive::(); + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_float, + r_s, + T::div_float(l_interval, r_float as f64) + )) + } + DataType::Float64 => { + let r_float = r.as_primitive::(); + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_float, + r_s, + T::div_float(l_interval, r_float) + )) + } + DataType::Decimal128(_, scale) => { + let r_decimal = r.as_primitive::(); + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_decimal, + r_s, + T::div_float(l_interval, (r_decimal as f64) / 10f64.powi(*scale as i32)) + )) + } + DataType::Decimal256(_, scale) => { + let r_decimal = r.as_primitive::(); + // Convert i256 to f64, considering scale + Ok(try_op_ref!( + T, + l_interval, + l_s, + r_decimal, + r_s, + T::div_float( + l_interval, + r_decimal.to_string().parse::().map_err(|_| { + ArrowError::ComputeError("Cannot convert Decimal256 to f64".to_string()) + })? / 10f64.powi(*scale as i32) + ) + )) + } + _ => Err(ArrowError::InvalidArgumentError(format!( + "Invalid numeric type for interval division: {}", + r.data_type() + ))), + } + } else { + Err(ArrowError::InvalidArgumentError(format!( + "Invalid interval division: {} {op} {}", + l.data_type(), + r.data_type() + ))) + } +} + +fn divide_interval( + interval: &PrimitiveArray, + interval_is_scalar: bool, + numeric: &PrimitiveArray, + numeric_is_scalar: bool, +) -> Result +where + N::Native: TryInto, + >::Error: std::error::Error, +{ + Ok(try_op_ref!( + T, + interval, + interval_is_scalar, + numeric, + numeric_is_scalar, + T::div_int( + interval, + numeric.try_into().map_err(|e| { + ArrowError::InvalidArgumentError(format!( + "Cannot safely convert {:?} to i32: {}", + numeric, e + )) + })? + ) + )) +} + fn duration_op( op: Op, l: &dyn Array, @@ -1356,6 +1859,103 @@ mod tests { err, "Arithmetic overflow: Overflow happened on: 2147483647 + 1" ); + + let a = IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(2, 4)]); + let b = PrimitiveArray::::from(vec![5]); + let result = mul(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(11, 8),]) + ); + + let a = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(10, 7200000), // 10 days, 2 hours + ]); + let b = PrimitiveArray::::from(vec![3]); + let result = mul(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(30, 21600000), // 30 days, 6 hours + ]) + ); + + let a = IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(12, 15, 5_000_000_000), // 12 months, 15 days, 5 seconds + ]); + let b = PrimitiveArray::::from(vec![2]); + let result = mul(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalMonthDayNanoArray::from(vec![ + IntervalMonthDayNanoType::make_value(24, 30, 10_000_000_000), // 24 months, 30 days, 10 seconds + ]) + ); + + let a = IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(1, 6)]); // 1 year, 6 months + let b = PrimitiveArray::::from(vec![2.5]); + let result = mul(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(3, 9)]) // 3 years, 9 months = 45 months + ); + + let a = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(5, 3600000), // 5 days, 1 hour + ]); + let b = PrimitiveArray::::from(vec![-2]); + let result = mul(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(-10, -7200000), // -10 days, -2 hours + ]) + ); + + // Test multiplication with Decimal128 + let a = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(5, 3600000), // 5 days, 1 hour + ]); + let b = Decimal128Array::from(vec![25]) + .with_precision_and_scale(4, 1) + .unwrap(); // 2.5 + let result = mul(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(12, 52200000), // 12.5 days, 2.5 hours + ]) + ); + + // Test multiplication with Decimal256 + let a = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(15, 3600000), // 15 days, 1 hour + ]); + let b = Decimal256Array::from(vec![i256::from_i128(15)]) + .with_precision_and_scale(3, 1) + .unwrap(); // 1.5 + let result = mul(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(22, 48600000), // 22.5 days, 1.5 hours + ]) + ); + + // Test division with Decimal256 + let a = IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(15, 3600000), // 15 days, 1 hour + ]); + let b = Decimal256Array::from(vec![i256::from_i128(20)]) + .with_precision_and_scale(3, 1) + .unwrap(); // 2.0 + let result = div(&a, &b).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalDayTimeArray::from(vec![ + IntervalDayTimeType::make_value(7, 45000000), // 7 days, 12.5 hours (half of 15 days, 1 hour) + ]) + ); } fn test_duration_impl>() { diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs index fc657f94c6a6..c464cdf4b811 100644 --- a/arrow-array/src/cast.rs +++ b/arrow-array/src/cast.rs @@ -72,7 +72,7 @@ macro_rules! repeat_pat { /// [`DataType`]: arrow_schema::DataType #[macro_export] macro_rules! downcast_integer { - ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat => $fallback:expr $(,)*)*) => { + ($($data_type:expr),+ => ($m:path $(, $args:tt)*), $($p:pat $( if $guard:expr )? => $fallback:expr $(,)*)*) => { match ($($data_type),+) { $crate::repeat_pat!($crate::cast::__private::DataType::Int8, $($data_type),+) => { $m!($crate::types::Int8Type $(, $args)*) @@ -98,7 +98,7 @@ macro_rules! downcast_integer { $crate::repeat_pat!($crate::cast::__private::DataType::UInt64, $($data_type),+) => { $m!($crate::types::UInt64Type $(, $args)*) } - $($p => $fallback,)* + $($p $( if $guard )?=> $fallback,)* } }; } From be54b4a3c3519449b5068744a3c3c67f78144ada Mon Sep 17 00:00:00 2001 From: Michael Levin Date: Thu, 19 Dec 2024 01:33:09 -0800 Subject: [PATCH 2/5] Make division safer; address PR feedback --- arrow-arith/src/numeric.rs | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs index 0f6a3325b021..efcf2fd6e8bf 100644 --- a/arrow-arith/src/numeric.rs +++ b/arrow-arith/src/numeric.rs @@ -236,7 +236,7 @@ fn arithmetic_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result interval_mul_op::(op, l, l_scalar, r, r_scalar), MonthDayNano => interval_mul_op::(op, l, l_scalar, r, r_scalar), }, - (lhs, Interval(unit)) if lhs.is_integer() && matches!(op, Op::Mul | Op::MulWrapping) => + (lhs, Interval(unit)) if lhs.is_numeric() && matches!(op, Op::Mul | Op::MulWrapping) => match unit { YearMonth => interval_mul_op::(op, l, l_scalar, r, r_scalar), DayTime => interval_mul_op::(op, l, l_scalar, r, r_scalar), @@ -574,6 +574,17 @@ trait IntervalOp: ArrowPrimitiveType { fn div_float(left: Self::Native, right: f64) -> Result; } +/// Helper function to safely convert f64 to i32, checking for overflow and invalid values +fn f64_to_i32(value: f64) -> Result { + if !value.is_finite() || value > i32::MAX as f64 || value < i32::MIN as f64 { + Err(ArrowError::ComputeError( + "Division result out of i32 range".to_string(), + )) + } else { + Ok(value as i32) + } +} + impl IntervalOp for IntervalYearMonthType { fn add(left: Self::Native, right: Self::Native) -> Result { left.add_checked(right) @@ -596,14 +607,18 @@ impl IntervalOp for IntervalYearMonthType { if right == 0 { return Err(ArrowError::DivideByZero); } - Ok((left as f64 / right as f64).round() as i32) + + let result = (left as f64 / right as f64).round(); + f64_to_i32(result) } fn div_float(left: Self::Native, right: f64) -> Result { if right == 0.0 { return Err(ArrowError::DivideByZero); } - Ok((left as f64 / right).round() as i32) + + let result = (left as f64 / right).round(); + f64_to_i32(result) } } @@ -664,7 +679,9 @@ impl IntervalOp for IntervalDayTimeType { let result_days = result_ms / 86_400_000; let result_ms = result_ms % 86_400_000; - Ok(Self::make_value(result_days as i32, result_ms as i32)) + let result_days_i32 = f64_to_i32(result_days as f64)?; + let result_ms_i32 = f64_to_i32(result_ms as f64)?; + Ok(Self::make_value(result_days_i32, result_ms_i32)) } fn div_float(left: Self::Native, right: f64) -> Result { @@ -680,10 +697,10 @@ impl IntervalOp for IntervalDayTimeType { let result_days = (total_ms / 86_400_000.0).floor(); let result_ms = total_ms % 86_400_000.0; - Ok(Self::make_value( - result_days as i32, - result_ms.round() as i32, - )) + let result_days_i32 = f64_to_i32(result_days)?; + let result_ms_i32 = f64_to_i32(result_ms)?; + + Ok(Self::make_value(result_days_i32, result_ms_i32)) } } From 4fc24a542cee408576f2f0655f8c6f595aee3f4d Mon Sep 17 00:00:00 2001 From: Michael Levin Date: Thu, 19 Dec 2024 10:03:47 -0800 Subject: [PATCH 3/5] Simplifying to support only i32 and f64 --- arrow-arith/src/numeric.rs | 335 +++---------------------------------- 1 file changed, 21 insertions(+), 314 deletions(-) diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs index efcf2fd6e8bf..1519afa6f957 100644 --- a/arrow-arith/src/numeric.rs +++ b/arrow-arith/src/numeric.rs @@ -238,9 +238,9 @@ fn arithmetic_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result match unit { - YearMonth => interval_mul_op::(op, l, l_scalar, r, r_scalar), - DayTime => interval_mul_op::(op, l, l_scalar, r, r_scalar), - MonthDayNano => interval_mul_op::(op, l, l_scalar, r, r_scalar), + YearMonth => interval_mul_op::(op, r, r_scalar, l, l_scalar), + DayTime => interval_mul_op::(op, r, r_scalar, l, l_scalar), + MonthDayNano => interval_mul_op::(op, r, r_scalar, l, l_scalar), }, (Interval(unit), rhs) if rhs.is_numeric() && matches!(op, Op::Div) => match unit { @@ -780,84 +780,18 @@ fn interval_mul_op( r: &dyn Array, r_s: bool, ) -> Result { - // Try both orderings to handle either (interval * numeric) or (numeric * interval) + // Assume the interval is the left argument if let Some(l_interval) = l.as_primitive_opt::() { - // Handle numeric multiplication based on data type match r.data_type() { - DataType::Int8 => multiply_interval(l_interval, l_s, r.as_primitive::(), r_s), - DataType::Int16 => { - multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) - } DataType::Int32 => { - multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) - } - DataType::Int64 => { - let r_int = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_int, - r_s, - T::mul_int( - l_interval, - i32::try_from(r_int).map_err(|_| { - ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {} to i32", - r_int - )) - })? - ) - )) - } - DataType::UInt8 => { - multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) - } - DataType::UInt16 => { - multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) - } - DataType::UInt32 => { - multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) - } - DataType::UInt64 => { - let r_int = r.as_primitive::(); + let r_int = r.as_primitive::(); Ok(try_op_ref!( T, l_interval, l_s, r_int, r_s, - T::mul_int( - l_interval, - i32::try_from(r_int).map_err(|_| { - ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {} to i32", - r_int - )) - })? - ) - )) - } - DataType::Float16 => { - let r_float = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_float, - r_s, - T::mul_float(l_interval, r_float.to_f64()) - )) - } - DataType::Float32 => { - let r_float = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_float, - r_s, - T::mul_float(l_interval, r_float as f64) + T::mul_int(l_interval, r_int) )) } DataType::Float64 => { @@ -871,77 +805,11 @@ fn interval_mul_op( T::mul_float(l_interval, r_float) )) } - DataType::Decimal128(_, scale) => { - let r_decimal = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_decimal, - r_s, - T::mul_float(l_interval, (r_decimal as f64) / 10f64.powi(*scale as i32)) - )) - } - DataType::Decimal256(_, scale) => { - let r_decimal = r.as_primitive::(); - // Convert i256 to f64, considering scale - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_decimal, - r_s, - T::mul_float( - l_interval, - r_decimal.to_string().parse::().map_err(|_| { - ArrowError::ComputeError("Cannot convert Decimal256 to f64".to_string()) - })? / 10f64.powi(*scale as i32) - ) - )) - } _ => Err(ArrowError::InvalidArgumentError(format!( "Invalid numeric type for interval multiplication: {}", r.data_type() ))), } - } else if let Some(r_interval) = r.as_primitive_opt::() { - // Same logic for the reverse order - match l.data_type() { - DataType::Int32 => { - let l_int = l.as_primitive::(); - Ok(try_op_ref!( - T, - r_interval, - r_s, - l_int, - l_s, - T::mul_int(r_interval, l_int) - )) - } - DataType::Int64 => { - let l_int = l.as_primitive::(); - Ok(try_op_ref!( - T, - r_interval, - r_s, - l_int, - l_s, - T::mul_int( - r_interval, - i32::try_from(l_int).map_err(|_| { - ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {} to i32", - l_int - )) - })? - ) - )) - } - _ => Err(ArrowError::InvalidArgumentError(format!( - "Invalid integer type for interval multiplication: {}", - l.data_type() - ))), - } } else { Err(ArrowError::InvalidArgumentError(format!( "Invalid interval multiplication: {} {op} {}", @@ -951,34 +819,6 @@ fn interval_mul_op( } } -fn multiply_interval( - interval: &PrimitiveArray, - interval_is_scalar: bool, - numeric: &PrimitiveArray, - numeric_is_scalar: bool, -) -> Result -where - N::Native: TryInto, - >::Error: std::error::Error, -{ - Ok(try_op_ref!( - T, - interval, - interval_is_scalar, - numeric, - numeric_is_scalar, - T::mul_int( - interval, - numeric.try_into().map_err(|e| { - ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {:?} to i32: {}", - numeric, e - )) - })? - ) - )) -} - fn interval_div_op( op: Op, l: &dyn Array, @@ -986,41 +826,10 @@ fn interval_div_op( r: &dyn Array, r_s: bool, ) -> Result { - // Only allow interval / numeric (not numeric / interval) if let Some(l_interval) = l.as_primitive_opt::() { - // Handle numeric division based on data type match r.data_type() { - DataType::Int8 => divide_interval(l_interval, l_s, r.as_primitive::(), r_s), - DataType::Int16 => divide_interval(l_interval, l_s, r.as_primitive::(), r_s), - DataType::Int32 => divide_interval(l_interval, l_s, r.as_primitive::(), r_s), - DataType::Int64 => { - let r_int = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_int, - r_s, - T::div_int( - l_interval, - i32::try_from(r_int).map_err(|_| { - ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {} to i32", - r_int - )) - })? - ) - )) - } - DataType::UInt8 => divide_interval(l_interval, l_s, r.as_primitive::(), r_s), - DataType::UInt16 => { - divide_interval(l_interval, l_s, r.as_primitive::(), r_s) - } - DataType::UInt32 => { - divide_interval(l_interval, l_s, r.as_primitive::(), r_s) - } - DataType::UInt64 => { - let r_int = r.as_primitive::(); + DataType::Int32 => { + let r_int = r.as_primitive::(); Ok(try_op_ref!( T, l_interval, @@ -1029,37 +838,15 @@ fn interval_div_op( r_s, T::div_int( l_interval, - i32::try_from(r_int).map_err(|_| { + r_int.try_into().map_err(|e| { ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {} to i32", - r_int + "Cannot safely convert {:?} to i32: {}", + r_int, e )) })? ) )) } - DataType::Float16 => { - let r_float = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_float, - r_s, - T::div_float(l_interval, r_float.to_f64()) - )) - } - DataType::Float32 => { - let r_float = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_float, - r_s, - T::div_float(l_interval, r_float as f64) - )) - } DataType::Float64 => { let r_float = r.as_primitive::(); Ok(try_op_ref!( @@ -1071,34 +858,6 @@ fn interval_div_op( T::div_float(l_interval, r_float) )) } - DataType::Decimal128(_, scale) => { - let r_decimal = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_decimal, - r_s, - T::div_float(l_interval, (r_decimal as f64) / 10f64.powi(*scale as i32)) - )) - } - DataType::Decimal256(_, scale) => { - let r_decimal = r.as_primitive::(); - // Convert i256 to f64, considering scale - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_decimal, - r_s, - T::div_float( - l_interval, - r_decimal.to_string().parse::().map_err(|_| { - ArrowError::ComputeError("Cannot convert Decimal256 to f64".to_string()) - })? / 10f64.powi(*scale as i32) - ) - )) - } _ => Err(ArrowError::InvalidArgumentError(format!( "Invalid numeric type for interval division: {}", r.data_type() @@ -1113,34 +872,6 @@ fn interval_div_op( } } -fn divide_interval( - interval: &PrimitiveArray, - interval_is_scalar: bool, - numeric: &PrimitiveArray, - numeric_is_scalar: bool, -) -> Result -where - N::Native: TryInto, - >::Error: std::error::Error, -{ - Ok(try_op_ref!( - T, - interval, - interval_is_scalar, - numeric, - numeric_is_scalar, - T::div_int( - interval, - numeric.try_into().map_err(|e| { - ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {:?} to i32: {}", - numeric, e - )) - })? - ) - )) -} - fn duration_op( op: Op, l: &dyn Array, @@ -1877,6 +1608,7 @@ mod tests { "Arithmetic overflow: Overflow happened on: 2147483647 + 1" ); + // Test interval multiplication let a = IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(2, 4)]); let b = PrimitiveArray::::from(vec![5]); let result = mul(&a, &b).unwrap(); @@ -1885,6 +1617,13 @@ mod tests { &IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(11, 8),]) ); + // swap a and b + let result = mul(&b, &a).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(11, 8),]) + ); + let a = IntervalDayTimeArray::from(vec![ IntervalDayTimeType::make_value(10, 7200000), // 10 days, 2 hours ]); @@ -1929,43 +1668,11 @@ mod tests { ]) ); - // Test multiplication with Decimal128 - let a = IntervalDayTimeArray::from(vec![ - IntervalDayTimeType::make_value(5, 3600000), // 5 days, 1 hour - ]); - let b = Decimal128Array::from(vec![25]) - .with_precision_and_scale(4, 1) - .unwrap(); // 2.5 - let result = mul(&a, &b).unwrap(); - assert_eq!( - result.as_ref(), - &IntervalDayTimeArray::from(vec![ - IntervalDayTimeType::make_value(12, 52200000), // 12.5 days, 2.5 hours - ]) - ); - - // Test multiplication with Decimal256 + // Test interval division let a = IntervalDayTimeArray::from(vec![ IntervalDayTimeType::make_value(15, 3600000), // 15 days, 1 hour ]); - let b = Decimal256Array::from(vec![i256::from_i128(15)]) - .with_precision_and_scale(3, 1) - .unwrap(); // 1.5 - let result = mul(&a, &b).unwrap(); - assert_eq!( - result.as_ref(), - &IntervalDayTimeArray::from(vec![ - IntervalDayTimeType::make_value(22, 48600000), // 22.5 days, 1.5 hours - ]) - ); - - // Test division with Decimal256 - let a = IntervalDayTimeArray::from(vec![ - IntervalDayTimeType::make_value(15, 3600000), // 15 days, 1 hour - ]); - let b = Decimal256Array::from(vec![i256::from_i128(20)]) - .with_precision_and_scale(3, 1) - .unwrap(); // 2.0 + let b = PrimitiveArray::::from(vec![2]); let result = div(&a, &b).unwrap(); assert_eq!( result.as_ref(), From 4f7628c6c5551d1a1d98f243ad43b290023c8243 Mon Sep 17 00:00:00 2001 From: Michael Levin Date: Thu, 19 Dec 2024 10:24:58 -0800 Subject: [PATCH 4/5] Tidying up float - int switching --- arrow-arith/src/numeric.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs index 1519afa6f957..0495d6320a34 100644 --- a/arrow-arith/src/numeric.rs +++ b/arrow-arith/src/numeric.rs @@ -599,7 +599,7 @@ impl IntervalOp for IntervalYearMonthType { } fn mul_float(left: Self::Native, right: f64) -> Result { - let result = (left as f64 * right).round() as i32; + let result = (left as f64 * right) as i32; Ok(result) } @@ -608,8 +608,8 @@ impl IntervalOp for IntervalYearMonthType { return Err(ArrowError::DivideByZero); } - let result = (left as f64 / right as f64).round(); - f64_to_i32(result) + let result = left / right; + Ok(result) } fn div_float(left: Self::Native, right: f64) -> Result { @@ -617,7 +617,7 @@ impl IntervalOp for IntervalYearMonthType { return Err(ArrowError::DivideByZero); } - let result = (left as f64 / right).round(); + let result = left as f64 / right; f64_to_i32(result) } } @@ -657,10 +657,10 @@ impl IntervalOp for IntervalDayTimeType { let frac_days = total_days.fract(); // Convert fractional days to milliseconds (24 * 60 * 60 * 1000 = 86_400_000 ms per day) - let frac_ms = (frac_days * 86_400_000.0).round() as i32; + let frac_ms = f64_to_i32(frac_days * 86_400_000.0)?; // Calculate total milliseconds including the fractional days - let total_ms = (ms as f64 * right).round() as i32 + frac_ms; + let total_ms = f64_to_i32(ms as f64 * right)? + frac_ms; Ok(Self::make_value(whole_days, total_ms)) } @@ -676,10 +676,10 @@ impl IntervalOp for IntervalDayTimeType { let result_ms = total_ms / right as i64; // Convert back to days and milliseconds - let result_days = result_ms / 86_400_000; + let result_days = result_ms as f64 / 86_400_000.0; let result_ms = result_ms % 86_400_000; - let result_days_i32 = f64_to_i32(result_days as f64)?; + let result_days_i32 = f64_to_i32(result_days)?; let result_ms_i32 = f64_to_i32(result_ms as f64)?; Ok(Self::make_value(result_days_i32, result_ms_i32)) } From 6030e82c4a0933324ba72359cb6b20fa4c34ee6e Mon Sep 17 00:00:00 2001 From: Michael Levin Date: Thu, 19 Dec 2024 12:14:52 -0800 Subject: [PATCH 5/5] Fixing clippy -- useless conversion --- arrow-arith/src/numeric.rs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs index 0495d6320a34..848c2121343b 100644 --- a/arrow-arith/src/numeric.rs +++ b/arrow-arith/src/numeric.rs @@ -836,15 +836,7 @@ fn interval_div_op( l_s, r_int, r_s, - T::div_int( - l_interval, - r_int.try_into().map_err(|e| { - ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {:?} to i32: {}", - r_int, e - )) - })? - ) + T::div_int(l_interval, r_int) )) } DataType::Float64 => {