From b20a7ceea75a4e8437612e82004ca929b8ac5b1e Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 28 Nov 2023 07:55:10 +0900 Subject: [PATCH] Fix --- arrow-cast/src/cast.rs | 91 ++++++++++++++++++++++++++++++++---------- 1 file changed, 69 insertions(+), 22 deletions(-) diff --git a/arrow-cast/src/cast.rs b/arrow-cast/src/cast.rs index b531743b587e..1429fa4dffef 100644 --- a/arrow-cast/src/cast.rs +++ b/arrow-cast/src/cast.rs @@ -2641,22 +2641,27 @@ where ))); } - let integers = parts[0].trim_start_matches('0'); - - let (decimals, negative) = if parts.len() == 2 { - (parts[1], integers.starts_with('-')) + let (negative, first_part) = if parts[0].len() == 0 { + (false, parts[0]) } else { - ("", false) + match parts[0].as_bytes()[0] { + b'-' => (true, parts[0].strip_prefix('-').unwrap()), + b'+' => (false, parts[0].strip_prefix('+').unwrap()), + _ => (false, parts[0]), + } }; - let integers = if negative { - integers.trim_start_matches('-') - } else { - integers - }; + let integers = first_part.trim_start_matches('0'); + let decimals = if parts.len() == 2 { parts[1] } else { "" }; + + if decimals.len() != 0 && !decimals.as_bytes()[0].is_ascii_digit() { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid decimal format: {value_str:?}" + ))); + } // Adjust decimal based on scale - let number_decimals = if decimals.len() > scale { + let mut number_decimals = if decimals.len() > scale { let decimal_number = i256::from_string(decimals).ok_or_else(|| { ArrowError::InvalidArgumentError(format!("Cannot parse decimal format: {value_str}")) })?; @@ -2688,17 +2693,7 @@ where i256::ZERO }; - let integers = if negative { - integers.neg_wrapping() - } else { - integers - }; - - if negative { - format!("{}", integers.sub_wrapping(adjusted)) - } else { - format!("{}", integers.add_wrapping(adjusted)) - } + format!("{}", integers.add_wrapping(adjusted)) } else { let padding = if scale > decimals.len() { scale } else { 0 }; @@ -2706,6 +2701,10 @@ where format!("{integers}{decimals}") }; + if negative { + number_decimals.insert(0, '-'); + } + let value = i256::from_string(number_decimals.as_str()).ok_or_else(|| { ArrowError::InvalidArgumentError(format!( "Cannot convert {} to {}: Overflow", @@ -8266,6 +8265,18 @@ mod tests { assert!(decimal_arr.is_null(12)); assert_eq!("-1.23", decimal_arr.value_as_string(13)); assert_eq!("-1.24", decimal_arr.value_as_string(14)); + assert_eq!("0.00", decimal_arr.value_as_string(15)); + assert_eq!("-123.00", decimal_arr.value_as_string(16)); + assert_eq!("-123.23", decimal_arr.value_as_string(17)); + assert_eq!("-0.12", decimal_arr.value_as_string(18)); + assert_eq!("1.23", decimal_arr.value_as_string(19)); + assert_eq!("1.24", decimal_arr.value_as_string(20)); + assert_eq!("0.00", decimal_arr.value_as_string(21)); + assert_eq!("123.00", decimal_arr.value_as_string(22)); + assert_eq!("123.23", decimal_arr.value_as_string(23)); + assert_eq!("0.12", decimal_arr.value_as_string(24)); + assert!(decimal_arr.is_null(25)); + assert!(decimal_arr.is_null(26)); // Decimal256 let output_type = DataType::Decimal256(76, 3); @@ -8289,6 +8300,18 @@ mod tests { assert!(decimal_arr.is_null(12)); assert_eq!("-1.235", decimal_arr.value_as_string(13)); assert_eq!("-1.236", decimal_arr.value_as_string(14)); + assert_eq!("0.000", decimal_arr.value_as_string(15)); + assert_eq!("-123.000", decimal_arr.value_as_string(16)); + assert_eq!("-123.234", decimal_arr.value_as_string(17)); + assert_eq!("-0.123", decimal_arr.value_as_string(18)); + assert_eq!("1.235", decimal_arr.value_as_string(19)); + assert_eq!("1.236", decimal_arr.value_as_string(20)); + assert_eq!("0.000", decimal_arr.value_as_string(21)); + assert_eq!("123.000", decimal_arr.value_as_string(22)); + assert_eq!("123.234", decimal_arr.value_as_string(23)); + assert_eq!("0.123", decimal_arr.value_as_string(24)); + assert!(decimal_arr.is_null(25)); + assert!(decimal_arr.is_null(26)); } #[test] @@ -8309,6 +8332,18 @@ mod tests { None, Some("-1.23499999"), Some("-1.23599999"), + Some("-0.00001"), + Some("-123"), + Some("-123.234000"), + Some("-000.123"), + Some("+1.23499999"), + Some("+1.23599999"), + Some("+0.00001"), + Some("+123"), + Some("+123.234000"), + Some("+000.123"), + Some("1.-23499999"), + Some("-1.-23499999"), ]); let array = Arc::new(str_array) as ArrayRef; @@ -8333,6 +8368,18 @@ mod tests { None, Some("-1.23499999"), Some("-1.23599999"), + Some("-0.00001"), + Some("-123"), + Some("-123.234000"), + Some("-000.123"), + Some("+1.23499999"), + Some("+1.23599999"), + Some("+0.00001"), + Some("+123"), + Some("+123.234000"), + Some("+000.123"), + Some("1.-23499999"), + Some("-1.-23499999"), ]); let array = Arc::new(str_array) as ArrayRef;