diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index ba82ca9040c7..4bf4a3a24e0b 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -16,6 +16,7 @@ // under the License. use crate::cast::*; +use crate::parse::*; /// A utility trait that provides checked conversions between /// decimal types inspired by [`NumCast`] @@ -230,106 +231,6 @@ where )?)) } -/// Parses given string to specified decimal native (i128/i256) based on given -/// scale. Returns an `Err` if it cannot parse given string. -pub(crate) fn parse_string_to_decimal_native( - value_str: &str, - scale: usize, -) -> Result -where - T::Native: DecimalCast + ArrowNativeTypeOp, -{ - let value_str = value_str.trim(); - let parts: Vec<&str> = value_str.split('.').collect(); - if parts.len() > 2 { - return Err(ArrowError::InvalidArgumentError(format!( - "Invalid decimal format: {value_str:?}" - ))); - } - - let (negative, first_part) = if parts[0].is_empty() { - (false, parts[0]) - } else { - match parts[0].as_bytes()[0] { - b'-' => (true, &parts[0][1..]), - b'+' => (false, &parts[0][1..]), - _ => (false, parts[0]), - } - }; - - let integers = first_part; - let decimals = if parts.len() == 2 { parts[1] } else { "" }; - - if !integers.is_empty() && !integers.as_bytes()[0].is_ascii_digit() { - return Err(ArrowError::InvalidArgumentError(format!( - "Invalid decimal format: {value_str:?}" - ))); - } - - if !decimals.is_empty() && !decimals.as_bytes()[0].is_ascii_digit() { - return Err(ArrowError::InvalidArgumentError(format!( - "Invalid decimal format: {value_str:?}" - ))); - } - - // Adjust decimal based on scale - let mut number_decimals = if decimals.len() > scale { - let decimal_number = i256::from_string(decimals).ok_or_else(|| { - ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}")) - })?; - - let div = i256::from_i128(10_i128).pow_checked((decimals.len() - scale) as u32)?; - - let half = div.div_wrapping(i256::from_i128(2)); - let half_neg = half.neg_wrapping(); - - let d = decimal_number.div_wrapping(div); - let r = decimal_number.mod_wrapping(div); - - // Round result - let adjusted = match decimal_number >= i256::ZERO { - true if r >= half => d.add_wrapping(i256::ONE), - false if r <= half_neg => d.sub_wrapping(i256::ONE), - _ => d, - }; - - let integers = if !integers.is_empty() { - i256::from_string(integers) - .ok_or_else(|| { - ArrowError::InvalidArgumentError(format!( - "Cannot parse decimal format: {value_str}" - )) - }) - .map(|v| v.mul_wrapping(i256::from_i128(10_i128).pow_wrapping(scale as u32)))? - } else { - i256::ZERO - }; - - format!("{}", integers.add_wrapping(adjusted)) - } else { - let padding = if scale > decimals.len() { scale } else { 0 }; - - let decimals = format!("{decimals:0( from: &'a S, precision: u8, @@ -342,10 +243,9 @@ where &'a S: StringArrayType<'a>, { if cast_options.safe { - let iter = from.iter().map(|v| { - v.and_then(|v| parse_string_to_decimal_native::(v, scale as usize).ok()) - .and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v)) - }); + let iter = from + .iter() + .map(|v| v.and_then(|v| parse_decimal::(v, precision, scale).ok())); // Benefit: // 20% performance improvement // Soundness: @@ -359,15 +259,12 @@ where .iter() .map(|v| { v.map(|v| { - parse_string_to_decimal_native::(v, scale as usize) - .map_err(|_| { - ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, - T::DATA_TYPE, - )) - }) - .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v)) + parse_decimal::(v, precision, scale).map_err(|_| { + ArrowError::CastError(format!( + "Cannot cast string '{}' to decimal type of precision {} and scale {}", + v, precision, scale + )) + }) }) .transpose() }) @@ -622,48 +519,3 @@ where let array = array.unary::<_, T>(op); Ok(Arc::new(array)) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> { - assert_eq!( - parse_string_to_decimal_native::("0", 0)?, - 0_i128 - ); - assert_eq!( - parse_string_to_decimal_native::("0", 5)?, - 0_i128 - ); - - assert_eq!( - parse_string_to_decimal_native::("123", 0)?, - 123_i128 - ); - assert_eq!( - parse_string_to_decimal_native::("123", 5)?, - 12300000_i128 - ); - - assert_eq!( - parse_string_to_decimal_native::("123.45", 0)?, - 123_i128 - ); - assert_eq!( - parse_string_to_decimal_native::("123.45", 5)?, - 12345000_i128 - ); - - assert_eq!( - parse_string_to_decimal_native::("123.4567891", 0)?, - 123_i128 - ); - assert_eq!( - parse_string_to_decimal_native::("123.4567891", 5)?, - 12345679_i128 - ); - Ok(()) - } -} diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index ba470635c6cd..92908824f651 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -2501,6 +2501,7 @@ where #[cfg(test)] mod tests { use super::*; + use crate::parse::parse_decimal; use arrow_buffer::{Buffer, IntervalDayTime, NullBuffer}; use chrono::NaiveDate; use half::f16; @@ -3843,6 +3844,22 @@ mod tests { } } } + #[test] + fn test_cast_with_options_utf8_to_decimal() { + let array = StringArray::from(vec!["4e7"]); + let result = cast_with_options( + &array, + &DataType::Decimal128(10, 2), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ) + .unwrap(); + let output_array = result.as_any().downcast_ref::(); + let result_128 = parse_decimal::("40000000", 10, 2); + assert_eq!(output_array.unwrap().value(0), result_128.unwrap()); + } #[test] fn test_cast_utf8_to_bool() { @@ -8481,99 +8498,6 @@ mod tests { ); } - #[test] - fn test_parse_string_to_decimal() { - assert_eq!( - Decimal128Type::format_decimal( - parse_string_to_decimal_native::("123.45", 2).unwrap(), - 38, - 2, - ), - "123.45" - ); - assert_eq!( - Decimal128Type::format_decimal( - parse_string_to_decimal_native::("12345", 2).unwrap(), - 38, - 2, - ), - "12345.00" - ); - assert_eq!( - Decimal128Type::format_decimal( - parse_string_to_decimal_native::("0.12345", 2).unwrap(), - 38, - 2, - ), - "0.12" - ); - assert_eq!( - Decimal128Type::format_decimal( - parse_string_to_decimal_native::(".12345", 2).unwrap(), - 38, - 2, - ), - "0.12" - ); - assert_eq!( - Decimal128Type::format_decimal( - parse_string_to_decimal_native::(".1265", 2).unwrap(), - 38, - 2, - ), - "0.13" - ); - assert_eq!( - Decimal128Type::format_decimal( - parse_string_to_decimal_native::(".1265", 2).unwrap(), - 38, - 2, - ), - "0.13" - ); - - assert_eq!( - Decimal256Type::format_decimal( - parse_string_to_decimal_native::("123.45", 3).unwrap(), - 38, - 3, - ), - "123.450" - ); - assert_eq!( - Decimal256Type::format_decimal( - parse_string_to_decimal_native::("12345", 3).unwrap(), - 38, - 3, - ), - "12345.000" - ); - assert_eq!( - Decimal256Type::format_decimal( - parse_string_to_decimal_native::("0.12345", 3).unwrap(), - 38, - 3, - ), - "0.123" - ); - assert_eq!( - Decimal256Type::format_decimal( - parse_string_to_decimal_native::(".12345", 3).unwrap(), - 38, - 3, - ), - "0.123" - ); - assert_eq!( - Decimal256Type::format_decimal( - parse_string_to_decimal_native::(".1265", 3).unwrap(), - 38, - 3, - ), - "0.127" - ); - } - fn test_cast_string_to_decimal(array: ArrayRef) { // Decimal128 let output_type = DataType::Decimal128(38, 2); @@ -8832,16 +8756,16 @@ mod tests { format_options: FormatOptions::default(), }; let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); - assert!(casted_err - .to_string() - .contains("Cannot cast string '4.4.5' to value of Decimal128(38, 10) type")); + assert!(casted_err.to_string().contains( + "Cast error: Cannot cast string '4.4.5' to decimal type of precision 38 and scale 2" + )); let str_array = StringArray::from(vec![". 0.123"]); let array = Arc::new(str_array) as ArrayRef; let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); - assert!(casted_err - .to_string() - .contains("Cannot cast string '. 0.123' to value of Decimal128(38, 10) type")); + assert!(casted_err.to_string().contains( + "Cast error: Cannot cast string '. 0.123' to decimal type of precision 38 and scale 2" + )); } fn test_cast_string_to_decimal128_overflow(overflow_array: ArrayRef) { @@ -8885,7 +8809,10 @@ mod tests { format_options: FormatOptions::default(), }, ); - assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal128 of precision 10. Max is 9999999999", err.unwrap_err().to_string()); + assert_eq!( + "Cast error: Cannot cast string '1000' to decimal type of precision 10 and scale 8", + err.unwrap_err().to_string() + ); } #[test] @@ -8968,7 +8895,10 @@ mod tests { format_options: FormatOptions::default(), }, ); - assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal256 of precision 10. Max is 9999999999", err.unwrap_err().to_string()); + assert_eq!( + "Cast error: Cannot cast string '1000' to decimal type of precision 10 and scale 8", + err.unwrap_err().to_string() + ); } #[test] diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs index f4c4639c1c08..750b0336b98f 100644 --- a/arrow-cast/src/parse.rs +++ b/arrow-cast/src/parse.rs @@ -824,7 +824,16 @@ fn parse_e_notation( } if exp < 0 { - result = result.div_wrapping(base.pow_wrapping(-exp as _)); + let result_with_scale = result.div_wrapping(base.pow_wrapping(-exp as _)); + let result_with_one_scale_above = + result.div_wrapping(base.pow_wrapping(-exp.add_wrapping(1) as _)); + let rounding_digit = + result_with_one_scale_above.sub_wrapping(result_with_scale.mul_wrapping(base)); + if rounding_digit >= T::Native::usize_as(5) { + result = result_with_scale.add_wrapping(T::Native::usize_as(1)); + } else { + result = result_with_scale; + } } else { result = result.mul_wrapping(base.pow_wrapping(exp as _)); } @@ -842,6 +851,7 @@ pub fn parse_decimal( let mut result = T::Native::usize_as(0); let mut fractionals: i8 = 0; let mut digits: u8 = 0; + let mut rounding_digit = -1; // to store digit after the scale for rounding let base = T::Native::usize_as(10); let bs = s.as_bytes(); @@ -871,6 +881,13 @@ pub fn parse_decimal( // Ignore leading zeros. continue; } + if fractionals == scale && scale != 0 { + // Capture the rounding digit once + if rounding_digit < 0 { + rounding_digit = (b - b'0') as i8; + } + continue; + } digits += 1; result = result.mul_wrapping(base); result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize)); @@ -902,11 +919,14 @@ pub fn parse_decimal( "can't parse the string value {s} to decimal" ))); } - if fractionals == scale && scale != 0 { - // We have processed all the digits that we need. All that - // is left is to validate that the rest of the string contains - // valid digits. - continue; + if fractionals == scale { + // Capture the rounding digit once + if rounding_digit < 0 { + rounding_digit = (b - b'0') as i8; + } + if scale != 0 { + continue; + } } fractionals += 1; digits += 1; @@ -966,6 +986,14 @@ pub fn parse_decimal( "parse decimal overflow ({s})" ))); } + //handle scale = 0 , scale down by fractional digits + if scale == 0 { + result = result.div_wrapping(base.pow_wrapping(fractionals as u32)) + } + //add one if >=5 + if rounding_digit >= 5 { + result = result.add_wrapping(T::Native::usize_as(1)); + } } Ok(if negative { @@ -2544,9 +2572,21 @@ mod tests { assert_eq!(i256::from_i128(i), result_256.unwrap()); } + let tests_with_varying_scale = [ + ("123.4567891", 12345679_i128, 5), + ("123.4567891", 123_i128, 0), + ("123.45", 12345000_i128, 5), + ]; + for (str, e, scale) in tests_with_varying_scale { + let result_128_a = parse_decimal::(str, 20, scale); + assert_eq!(result_128_a.unwrap(), e); + } + let e_notation_tests = [ ("1.23e3", "1230.0", 2), ("5.6714e+2", "567.14", 4), + ("4e+5", "400000", 4), + ("4e7", "40000000", 2), ("5.6714e-2", "0.056714", 4), ("5.6714e-2", "0.056714", 3), ("5.6741214125e2", "567.41214125", 4), @@ -2565,21 +2605,8 @@ mod tests { ("4749.3e-5", "0.047493", 10), ("4749.3e+5", "474930000", 10), ("4749.3e-5", "0.047493", 1), - ("4749.3e+5", "474930000", 1), - ("0E-8", "0", 10), - ("0E+6", "0", 10), - ("1E-8", "0.00000001", 10), - ("12E+6", "12000000", 10), - ("12E-6", "0.000012", 10), - ("0.1e-6", "0.0000001", 10), - ("0.1e+6", "100000", 10), - ("0.12e-6", "0.00000012", 10), - ("0.12e+6", "120000", 10), - ("000000000001e0", "000000000001", 3), - ("000001.1034567002e0", "000001.1034567002", 3), - ("1.234e16", "12340000000000000", 0), - ("123.4e16", "1234000000000000000", 0), ]; + for (e, d, scale) in e_notation_tests { let result_128_e = parse_decimal::(e, 20, scale); let result_128_d = parse_decimal::(d, 20, scale); @@ -2588,6 +2615,27 @@ mod tests { let result_256_d = parse_decimal::(d, 20, scale); assert_eq!(result_256_e.unwrap(), result_256_d.unwrap()); } + + let test_decimal_format_check = [ + ("123.45", "123.45", 2), + ("12345", "12345", 2), + ("0.12345", "0.12", 2), + (".12345", "0.12", 2), + ("123.45", "123.450", 3), + ("12345", "12345.000", 3), + ("0.12345", "0.123", 3), + (".1265", ".127", 3), + ]; + + for (e, d, scale) in test_decimal_format_check { + let result_128_e = parse_decimal::(e, 38, scale); + let result_128_d = parse_decimal::(d, 38, scale); + assert_eq!(result_128_e.unwrap(), result_128_d.unwrap()); + let result_256_e = parse_decimal::(e, 38, scale); + let result_256_d = parse_decimal::(d, 38, scale); + assert_eq!(result_256_e.unwrap(), result_256_d.unwrap()); + } + let can_not_parse_tests = [ "123,123", ".", diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index d3d518316397..76cc4841244b 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -1284,7 +1284,7 @@ mod tests { assert_eq!("53.002666", lat.value_as_string(1)); assert_eq!("52.412811", lat.value_as_string(2)); assert_eq!("51.481583", lat.value_as_string(3)); - assert_eq!("12.123456", lat.value_as_string(4)); + assert_eq!("12.123457", lat.value_as_string(4)); assert_eq!("50.760000", lat.value_as_string(5)); assert_eq!("0.123000", lat.value_as_string(6)); assert_eq!("123.000000", lat.value_as_string(7)); diff --git a/arrow-json/src/reader/mod.rs b/arrow-json/src/reader/mod.rs index f857e8813c7e..4c9a7cf00776 100644 --- a/arrow-json/src/reader/mod.rs +++ b/arrow-json/src/reader/mod.rs @@ -1127,7 +1127,7 @@ mod tests { assert!(col1.is_null(5)); assert_eq!( col1.values(), - &[100, 200, 204, 1103420, 0, 0].map(T::Native::usize_as) + &[100, 200, 205, 1103420, 0, 0].map(T::Native::usize_as) ); let col2 = batches[0].column(1).as_primitive::(); @@ -1147,7 +1147,7 @@ mod tests { assert!(col3.is_null(5)); assert_eq!( col3.values(), - &[3830, 12345, 0, 0, 0, 0].map(T::Native::usize_as) + &[3830, 12346, 0, 0, 0, 0].map(T::Native::usize_as) ); }