Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Support for e notation using existing parse_decimal in string to decimal conversion #6905

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
33 changes: 11 additions & 22 deletions arrow-cast/src/cast/decimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use crate::cast::*;
use crate::parse::*;

/// A utility trait that provides checked conversions between
/// decimal types inspired by [`NumCast`]
Expand Down Expand Up @@ -230,6 +231,7 @@ where
)?))
}

#[allow(dead_code)]
Copy link
Contributor Author

@himadripal himadripal Dec 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails in clippy, hence added #[allow(dead_code)], there is no use, if required we can remove it and cover existing tests with parse_decimal.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should remove this and port the tests, to ensure we aren't losing test coverage / accidentally changing behaviour

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

/// Parses given string to specified decimal native (i128/i256) based on given
/// scale. Returns an `Err` if it cannot parse given string.
pub(crate) fn parse_string_to_decimal_native<T: DecimalType>(
Expand Down Expand Up @@ -342,10 +344,9 @@ where
&'a S: StringArrayType<'a>,
{
if cast_options.safe {
let iter = from.iter().map(|v| {
v.and_then(|v| parse_string_to_decimal_native::<T>(v, scale as usize).ok())
.and_then(|v| T::is_valid_decimal_precision(v, precision).then_some(v))
});
let iter = from
.iter()
.map(|v| v.and_then(|v| parse_decimal::<T>(v, precision, scale).ok()));
// Benefit:
// 20% performance improvement
// Soundness:
Expand All @@ -359,15 +360,12 @@ where
.iter()
.map(|v| {
v.map(|v| {
parse_string_to_decimal_native::<T>(v, scale as usize)
.map_err(|_| {
ArrowError::CastError(format!(
"Cannot cast string '{}' to value of {:?} type",
v,
T::DATA_TYPE,
))
})
.and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v))
parse_decimal::<T>(v, precision, scale).map_err(|_| {
ArrowError::CastError(format!(
"Cannot cast string '{}' to decimal type of precision {} and scale {}",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

T:DATA_TYPE shows default Decimal(38,10) or Decimal256(76,..) in the error message, hiding the precision and scale provided for cast.

v, precision, scale
))
})
})
.transpose()
})
Expand Down Expand Up @@ -629,15 +627,6 @@ mod tests {

#[test]
fn test_parse_string_to_decimal_native() -> Result<(), ArrowError> {
assert_eq!(
parse_string_to_decimal_native::<Decimal128Type>("0", 0)?,
0_i128
);
assert_eq!(
parse_string_to_decimal_native::<Decimal128Type>("0", 5)?,
0_i128
);

assert_eq!(
parse_string_to_decimal_native::<Decimal128Type>("123", 0)?,
123_i128
Expand Down
39 changes: 31 additions & 8 deletions arrow-cast/src/cast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2501,6 +2501,7 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::parse::parse_decimal;
use arrow_buffer::{Buffer, IntervalDayTime, NullBuffer};
use chrono::NaiveDate;
use half::f16;
Expand Down Expand Up @@ -3843,6 +3844,22 @@ mod tests {
}
}
}
#[test]
fn test_cast_with_options_utf8_to_decimal() {
let array = StringArray::from(vec!["4e7"]);
let result = cast_with_options(
&array,
&DataType::Decimal128(10, 2),
&CastOptions {
safe: false,
format_options: FormatOptions::default(),
},
)
.unwrap();
let output_array = result.as_any().downcast_ref::<Decimal128Array>();
let result_128 = parse_decimal::<Decimal128Type>("40000000", 10, 2);
assert_eq!(output_array.unwrap().value(0), result_128.unwrap());
}

#[test]
fn test_cast_utf8_to_bool() {
Expand Down Expand Up @@ -8832,16 +8849,16 @@ mod tests {
format_options: FormatOptions::default(),
};
let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err();
assert!(casted_err
.to_string()
.contains("Cannot cast string '4.4.5' to value of Decimal128(38, 10) type"));
assert!(casted_err.to_string().contains(
"Cast error: Cannot cast string '4.4.5' to decimal type of precision 38 and scale 2"
));

let str_array = StringArray::from(vec![". 0.123"]);
let array = Arc::new(str_array) as ArrayRef;
let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err();
assert!(casted_err
.to_string()
.contains("Cannot cast string '. 0.123' to value of Decimal128(38, 10) type"));
assert!(casted_err.to_string().contains(
"Cast error: Cannot cast string '. 0.123' to decimal type of precision 38 and scale 2"
));
}

fn test_cast_string_to_decimal128_overflow(overflow_array: ArrayRef) {
Expand Down Expand Up @@ -8885,7 +8902,10 @@ mod tests {
format_options: FormatOptions::default(),
},
);
assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal128 of precision 10. Max is 9999999999", err.unwrap_err().to_string());
assert_eq!(
"Cast error: Cannot cast string '1000' to decimal type of precision 10 and scale 8",
err.unwrap_err().to_string()
);
}

#[test]
Expand Down Expand Up @@ -8968,7 +8988,10 @@ mod tests {
format_options: FormatOptions::default(),
},
);
assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal256 of precision 10. Max is 9999999999", err.unwrap_err().to_string());
assert_eq!(
"Cast error: Cannot cast string '1000' to decimal type of precision 10 and scale 8",
err.unwrap_err().to_string()
);
}

#[test]
Expand Down
66 changes: 46 additions & 20 deletions arrow-cast/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,16 @@ fn parse_e_notation<T: DecimalType>(
}

if exp < 0 {
result = result.div_wrapping(base.pow_wrapping(-exp as _));
let result_with_scale = result.div_wrapping(base.pow_wrapping(-exp as _));
let result_with_one_scale_above =
result.div_wrapping(base.pow_wrapping(-exp.add_wrapping(1) as _));
let rounding_digit =
result_with_one_scale_above.sub_wrapping(result_with_scale.mul_wrapping(base));
if rounding_digit >= T::Native::usize_as(5) {
result = result_with_scale.add_wrapping(T::Native::usize_as(1));
} else {
result = result_with_scale;
}
} else {
result = result.mul_wrapping(base.pow_wrapping(exp as _));
}
Expand All @@ -842,6 +851,7 @@ pub fn parse_decimal<T: DecimalType>(
let mut result = T::Native::usize_as(0);
let mut fractionals: i8 = 0;
let mut digits: u8 = 0;
let mut rounding_digit = -1; // to store digit after the scale for rounding
let base = T::Native::usize_as(10);

let bs = s.as_bytes();
Expand Down Expand Up @@ -871,6 +881,13 @@ pub fn parse_decimal<T: DecimalType>(
// Ignore leading zeros.
continue;
}
if fractionals == scale && scale != 0 {
// Capture the rounding digit once
if rounding_digit < 0 {
rounding_digit = (b - b'0') as i8;
}
continue;
}
digits += 1;
result = result.mul_wrapping(base);
result = result.add_wrapping(T::Native::usize_as((b - b'0') as usize));
Expand Down Expand Up @@ -902,11 +919,14 @@ pub fn parse_decimal<T: DecimalType>(
"can't parse the string value {s} to decimal"
)));
}
if fractionals == scale && scale != 0 {
// We have processed all the digits that we need. All that
// is left is to validate that the rest of the string contains
// valid digits.
continue;
if fractionals == scale {
// Capture the rounding digit once
if rounding_digit < 0 {
rounding_digit = (b - b'0') as i8;
}
if scale != 0 {
continue;
}
}
fractionals += 1;
digits += 1;
Expand Down Expand Up @@ -966,6 +986,14 @@ pub fn parse_decimal<T: DecimalType>(
"parse decimal overflow ({s})"
)));
}
//handle scale = 0 , scale down by fractional digits
if scale == 0 {
result = result.div_wrapping(base.pow_wrapping(fractionals as u32))
}
//add one if >=5
if rounding_digit >= 5 {
result = result.add_wrapping(T::Native::usize_as(1));
}
}

Ok(if negative {
Expand Down Expand Up @@ -2544,9 +2572,20 @@ mod tests {
assert_eq!(i256::from_i128(i), result_256.unwrap());
}

let tests_with_varying_scale = [
// ("123.4567891", 12345679_i128, 5),
("123.4567891", 123_i128, 0),
];
for (str, e, scale) in tests_with_varying_scale {
let result_128_a = parse_decimal::<Decimal128Type>(str, 20, scale);
assert_eq!(result_128_a.unwrap(), e);
}

let e_notation_tests = [
("1.23e3", "1230.0", 2),
("5.6714e+2", "567.14", 4),
("4e+5", "400000", 4),
("4e7", "40000000", 2),
("5.6714e-2", "0.056714", 4),
("5.6714e-2", "0.056714", 3),
("5.6741214125e2", "567.41214125", 4),
Expand All @@ -2565,21 +2604,8 @@ mod tests {
("4749.3e-5", "0.047493", 10),
("4749.3e+5", "474930000", 10),
("4749.3e-5", "0.047493", 1),
("4749.3e+5", "474930000", 1),
("0E-8", "0", 10),
("0E+6", "0", 10),
("1E-8", "0.00000001", 10),
("12E+6", "12000000", 10),
("12E-6", "0.000012", 10),
("0.1e-6", "0.0000001", 10),
("0.1e+6", "100000", 10),
("0.12e-6", "0.00000012", 10),
("0.12e+6", "120000", 10),
("000000000001e0", "000000000001", 3),
("000001.1034567002e0", "000001.1034567002", 3),
("1.234e16", "12340000000000000", 0),
("123.4e16", "1234000000000000000", 0),
];

for (e, d, scale) in e_notation_tests {
let result_128_e = parse_decimal::<Decimal128Type>(e, 20, scale);
let result_128_d = parse_decimal::<Decimal128Type>(d, 20, scale);
Expand Down
2 changes: 1 addition & 1 deletion arrow-csv/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1284,7 +1284,7 @@ mod tests {
assert_eq!("53.002666", lat.value_as_string(1));
assert_eq!("52.412811", lat.value_as_string(2));
assert_eq!("51.481583", lat.value_as_string(3));
assert_eq!("12.123456", lat.value_as_string(4));
assert_eq!("12.123457", lat.value_as_string(4));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we can see this is a breaking change to the rounding behaviour

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also to note, previous behavior was not correct.

12.12345678 cast to `Decimal128(38, 6)` =  12.123457

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It truncated rather than rounding, they're both valid behaviours, changing this is a breaking change

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an argument for accepting the breaking change to use rounding since it would be consistent with how we cast floating point to decimal. However, do we want to consider adding a parameter to choose between truncation and rounding?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I personally wouldn't characterize this a breaking change, though I can see how others might.

In my opinion, adding a parameter to choose between the behaviors would be the safest thing (aka a field to CastOptions that defaults to the old, rounding, behavior) for https://docs.rs/arrow/latest/arrow/compute/kernels/cast/fn.cast_with_options.html

Maybe @liukun4515 who added much of the initial decimal support in arrow-rs has time to offer historical perspective on rounding vs truncation during casting?

assert_eq!("50.760000", lat.value_as_string(5));
assert_eq!("0.123000", lat.value_as_string(6));
assert_eq!("123.000000", lat.value_as_string(7));
Expand Down
4 changes: 2 additions & 2 deletions arrow-json/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1127,7 +1127,7 @@ mod tests {
assert!(col1.is_null(5));
assert_eq!(
col1.values(),
&[100, 200, 204, 1103420, 0, 0].map(T::Native::usize_as)
&[100, 200, 205, 1103420, 0, 0].map(T::Native::usize_as)
);

let col2 = batches[0].column(1).as_primitive::<T>();
Expand All @@ -1147,7 +1147,7 @@ mod tests {
assert!(col3.is_null(5));
assert_eq!(
col3.values(),
&[3830, 12345, 0, 0, 0, 0].map(T::Native::usize_as)
&[3830, 12346, 0, 0, 0, 0].map(T::Native::usize_as)
);
}

Expand Down
Loading