diff --git a/arrow-arith/src/numeric.rs b/arrow-arith/src/numeric.rs index efcf2fd6e8bf..1519afa6f957 100644 --- a/arrow-arith/src/numeric.rs +++ b/arrow-arith/src/numeric.rs @@ -238,9 +238,9 @@ fn arithmetic_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result match unit { - YearMonth => interval_mul_op::(op, l, l_scalar, r, r_scalar), - DayTime => interval_mul_op::(op, l, l_scalar, r, r_scalar), - MonthDayNano => interval_mul_op::(op, l, l_scalar, r, r_scalar), + YearMonth => interval_mul_op::(op, r, r_scalar, l, l_scalar), + DayTime => interval_mul_op::(op, r, r_scalar, l, l_scalar), + MonthDayNano => interval_mul_op::(op, r, r_scalar, l, l_scalar), }, (Interval(unit), rhs) if rhs.is_numeric() && matches!(op, Op::Div) => match unit { @@ -780,84 +780,18 @@ fn interval_mul_op( r: &dyn Array, r_s: bool, ) -> Result { - // Try both orderings to handle either (interval * numeric) or (numeric * interval) + // Assume the interval is the left argument if let Some(l_interval) = l.as_primitive_opt::() { - // Handle numeric multiplication based on data type match r.data_type() { - DataType::Int8 => multiply_interval(l_interval, l_s, r.as_primitive::(), r_s), - DataType::Int16 => { - multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) - } DataType::Int32 => { - multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) - } - DataType::Int64 => { - let r_int = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_int, - r_s, - T::mul_int( - l_interval, - i32::try_from(r_int).map_err(|_| { - ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {} to i32", - r_int - )) - })? - ) - )) - } - DataType::UInt8 => { - multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) - } - DataType::UInt16 => { - multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) - } - DataType::UInt32 => { - multiply_interval(l_interval, l_s, r.as_primitive::(), r_s) - } - DataType::UInt64 => { - let r_int = r.as_primitive::(); + let r_int = r.as_primitive::(); Ok(try_op_ref!( T, l_interval, l_s, r_int, r_s, - T::mul_int( - l_interval, - i32::try_from(r_int).map_err(|_| { - ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {} to i32", - r_int - )) - })? - ) - )) - } - DataType::Float16 => { - let r_float = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_float, - r_s, - T::mul_float(l_interval, r_float.to_f64()) - )) - } - DataType::Float32 => { - let r_float = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_float, - r_s, - T::mul_float(l_interval, r_float as f64) + T::mul_int(l_interval, r_int) )) } DataType::Float64 => { @@ -871,77 +805,11 @@ fn interval_mul_op( T::mul_float(l_interval, r_float) )) } - DataType::Decimal128(_, scale) => { - let r_decimal = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_decimal, - r_s, - T::mul_float(l_interval, (r_decimal as f64) / 10f64.powi(*scale as i32)) - )) - } - DataType::Decimal256(_, scale) => { - let r_decimal = r.as_primitive::(); - // Convert i256 to f64, considering scale - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_decimal, - r_s, - T::mul_float( - l_interval, - r_decimal.to_string().parse::().map_err(|_| { - ArrowError::ComputeError("Cannot convert Decimal256 to f64".to_string()) - })? / 10f64.powi(*scale as i32) - ) - )) - } _ => Err(ArrowError::InvalidArgumentError(format!( "Invalid numeric type for interval multiplication: {}", r.data_type() ))), } - } else if let Some(r_interval) = r.as_primitive_opt::() { - // Same logic for the reverse order - match l.data_type() { - DataType::Int32 => { - let l_int = l.as_primitive::(); - Ok(try_op_ref!( - T, - r_interval, - r_s, - l_int, - l_s, - T::mul_int(r_interval, l_int) - )) - } - DataType::Int64 => { - let l_int = l.as_primitive::(); - Ok(try_op_ref!( - T, - r_interval, - r_s, - l_int, - l_s, - T::mul_int( - r_interval, - i32::try_from(l_int).map_err(|_| { - ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {} to i32", - l_int - )) - })? - ) - )) - } - _ => Err(ArrowError::InvalidArgumentError(format!( - "Invalid integer type for interval multiplication: {}", - l.data_type() - ))), - } } else { Err(ArrowError::InvalidArgumentError(format!( "Invalid interval multiplication: {} {op} {}", @@ -951,34 +819,6 @@ fn interval_mul_op( } } -fn multiply_interval( - interval: &PrimitiveArray, - interval_is_scalar: bool, - numeric: &PrimitiveArray, - numeric_is_scalar: bool, -) -> Result -where - N::Native: TryInto, - >::Error: std::error::Error, -{ - Ok(try_op_ref!( - T, - interval, - interval_is_scalar, - numeric, - numeric_is_scalar, - T::mul_int( - interval, - numeric.try_into().map_err(|e| { - ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {:?} to i32: {}", - numeric, e - )) - })? - ) - )) -} - fn interval_div_op( op: Op, l: &dyn Array, @@ -986,41 +826,10 @@ fn interval_div_op( r: &dyn Array, r_s: bool, ) -> Result { - // Only allow interval / numeric (not numeric / interval) if let Some(l_interval) = l.as_primitive_opt::() { - // Handle numeric division based on data type match r.data_type() { - DataType::Int8 => divide_interval(l_interval, l_s, r.as_primitive::(), r_s), - DataType::Int16 => divide_interval(l_interval, l_s, r.as_primitive::(), r_s), - DataType::Int32 => divide_interval(l_interval, l_s, r.as_primitive::(), r_s), - DataType::Int64 => { - let r_int = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_int, - r_s, - T::div_int( - l_interval, - i32::try_from(r_int).map_err(|_| { - ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {} to i32", - r_int - )) - })? - ) - )) - } - DataType::UInt8 => divide_interval(l_interval, l_s, r.as_primitive::(), r_s), - DataType::UInt16 => { - divide_interval(l_interval, l_s, r.as_primitive::(), r_s) - } - DataType::UInt32 => { - divide_interval(l_interval, l_s, r.as_primitive::(), r_s) - } - DataType::UInt64 => { - let r_int = r.as_primitive::(); + DataType::Int32 => { + let r_int = r.as_primitive::(); Ok(try_op_ref!( T, l_interval, @@ -1029,37 +838,15 @@ fn interval_div_op( r_s, T::div_int( l_interval, - i32::try_from(r_int).map_err(|_| { + r_int.try_into().map_err(|e| { ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {} to i32", - r_int + "Cannot safely convert {:?} to i32: {}", + r_int, e )) })? ) )) } - DataType::Float16 => { - let r_float = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_float, - r_s, - T::div_float(l_interval, r_float.to_f64()) - )) - } - DataType::Float32 => { - let r_float = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_float, - r_s, - T::div_float(l_interval, r_float as f64) - )) - } DataType::Float64 => { let r_float = r.as_primitive::(); Ok(try_op_ref!( @@ -1071,34 +858,6 @@ fn interval_div_op( T::div_float(l_interval, r_float) )) } - DataType::Decimal128(_, scale) => { - let r_decimal = r.as_primitive::(); - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_decimal, - r_s, - T::div_float(l_interval, (r_decimal as f64) / 10f64.powi(*scale as i32)) - )) - } - DataType::Decimal256(_, scale) => { - let r_decimal = r.as_primitive::(); - // Convert i256 to f64, considering scale - Ok(try_op_ref!( - T, - l_interval, - l_s, - r_decimal, - r_s, - T::div_float( - l_interval, - r_decimal.to_string().parse::().map_err(|_| { - ArrowError::ComputeError("Cannot convert Decimal256 to f64".to_string()) - })? / 10f64.powi(*scale as i32) - ) - )) - } _ => Err(ArrowError::InvalidArgumentError(format!( "Invalid numeric type for interval division: {}", r.data_type() @@ -1113,34 +872,6 @@ fn interval_div_op( } } -fn divide_interval( - interval: &PrimitiveArray, - interval_is_scalar: bool, - numeric: &PrimitiveArray, - numeric_is_scalar: bool, -) -> Result -where - N::Native: TryInto, - >::Error: std::error::Error, -{ - Ok(try_op_ref!( - T, - interval, - interval_is_scalar, - numeric, - numeric_is_scalar, - T::div_int( - interval, - numeric.try_into().map_err(|e| { - ArrowError::InvalidArgumentError(format!( - "Cannot safely convert {:?} to i32: {}", - numeric, e - )) - })? - ) - )) -} - fn duration_op( op: Op, l: &dyn Array, @@ -1877,6 +1608,7 @@ mod tests { "Arithmetic overflow: Overflow happened on: 2147483647 + 1" ); + // Test interval multiplication let a = IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(2, 4)]); let b = PrimitiveArray::::from(vec![5]); let result = mul(&a, &b).unwrap(); @@ -1885,6 +1617,13 @@ mod tests { &IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(11, 8),]) ); + // swap a and b + let result = mul(&b, &a).unwrap(); + assert_eq!( + result.as_ref(), + &IntervalYearMonthArray::from(vec![IntervalYearMonthType::make_value(11, 8),]) + ); + let a = IntervalDayTimeArray::from(vec![ IntervalDayTimeType::make_value(10, 7200000), // 10 days, 2 hours ]); @@ -1929,43 +1668,11 @@ mod tests { ]) ); - // Test multiplication with Decimal128 - let a = IntervalDayTimeArray::from(vec![ - IntervalDayTimeType::make_value(5, 3600000), // 5 days, 1 hour - ]); - let b = Decimal128Array::from(vec![25]) - .with_precision_and_scale(4, 1) - .unwrap(); // 2.5 - let result = mul(&a, &b).unwrap(); - assert_eq!( - result.as_ref(), - &IntervalDayTimeArray::from(vec![ - IntervalDayTimeType::make_value(12, 52200000), // 12.5 days, 2.5 hours - ]) - ); - - // Test multiplication with Decimal256 + // Test interval division let a = IntervalDayTimeArray::from(vec![ IntervalDayTimeType::make_value(15, 3600000), // 15 days, 1 hour ]); - let b = Decimal256Array::from(vec![i256::from_i128(15)]) - .with_precision_and_scale(3, 1) - .unwrap(); // 1.5 - let result = mul(&a, &b).unwrap(); - assert_eq!( - result.as_ref(), - &IntervalDayTimeArray::from(vec![ - IntervalDayTimeType::make_value(22, 48600000), // 22.5 days, 1.5 hours - ]) - ); - - // Test division with Decimal256 - let a = IntervalDayTimeArray::from(vec![ - IntervalDayTimeType::make_value(15, 3600000), // 15 days, 1 hour - ]); - let b = Decimal256Array::from(vec![i256::from_i128(20)]) - .with_precision_and_scale(3, 1) - .unwrap(); // 2.0 + let b = PrimitiveArray::::from(vec![2]); let result = div(&a, &b).unwrap(); assert_eq!( result.as_ref(),