From e377d8ba11618e07877b9ba0b9b3a8386ea7ab28 Mon Sep 17 00:00:00 2001 From: Bernhard Specht Date: Mon, 13 Mar 2023 21:10:13 +0100 Subject: [PATCH] Improve decimal parsing performance --- arrow-cast/Cargo.toml | 4 + arrow-cast/benches/parse_decimal.rs | 56 ++++++++++ arrow-cast/src/parse.rs | 153 +++++++++++++--------------- 3 files changed, 133 insertions(+), 80 deletions(-) create mode 100644 arrow-cast/benches/parse_decimal.rs diff --git a/arrow-cast/Cargo.toml b/arrow-cast/Cargo.toml index 235dca135e5a..15386ed5c8ac 100644 --- a/arrow-cast/Cargo.toml +++ b/arrow-cast/Cargo.toml @@ -63,3 +63,7 @@ half = { version = "2.1", default-features = false } [[bench]] name = "parse_timestamp" harness = false + +[[bench]] +name = "parse_decimal" +harness = false diff --git a/arrow-cast/benches/parse_decimal.rs b/arrow-cast/benches/parse_decimal.rs new file mode 100644 index 000000000000..5682859dd25a --- /dev/null +++ b/arrow-cast/benches/parse_decimal.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_array::types::Decimal256Type; +use arrow_cast::parse::parse_decimal; +use criterion::*; + +fn criterion_benchmark(c: &mut Criterion) { + let decimals = [ + "123.123", + "123.1234", + "123.1", + "123", + "-123.123", + "-123.1234", + "-123.1", + "-123", + "0.0000123", + "12.", + "-12.", + "00.1", + "-00.1", + "12345678912345678.1234", + "-12345678912345678.1234", + "99999999999999999.999", + "-99999999999999999.999", + ".123", + "-.123", + "123.", + "-123.", + ]; + + for decimal in decimals { + let d = black_box(decimal); + c.bench_function(d, |b| { + b.iter(|| parse_decimal::<Decimal256Type>(d, 20, 3).unwrap()); + }); + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs index 38fb4fc29934..e30891da9036 100644 --- a/arrow-cast/src/parse.rs +++ b/arrow-cast/src/parse.rs @@ -558,100 +558,93 @@ pub fn parse_decimal<T: DecimalType>( precision: u8, scale: i8, ) -> Result<T::Native, ArrowError> { - if !is_valid_decimal(s) { - return Err(ArrowError::ParseError(format!( - "can't parse the string value {s} to decimal" - ))); - } - let mut offset = s.len(); - let len = s.len(); - let mut base = T::Native::usize_as(1); - let scale_usize = usize::from(scale as u8); - - // handle the value after the '.' and meet the scale - let delimiter_position = s.find('.'); - match delimiter_position { - None => { - // there is no '.' - base = T::Native::usize_as(10).pow_checked(scale as u32)?; - } - Some(mid) => { - // there is the '.' - if len - mid >= scale_usize + 1 { - // If the string value is "123.12345" and the scale is 2, we should just remain '.12' and drop the '345' value. - offset -= len - mid - 1 - scale_usize; - } else { - // If the string value is "123.12" and the scale is 4, we should append '00' to the tail. - base = T::Native::usize_as(10) - .pow_checked((scale_usize + 1 + mid - len) as u32)?; - } - } - }; - - // each byte is digit、'-' or '.' 
- let bytes = s.as_bytes(); - let mut negative = false; - let mut result = T::Native::usize_as(0); - - bytes[0..offset] - .iter() - .rev() - .try_for_each::<_, Result<(), ArrowError>>(|&byte| { - match byte { - b'-' => { - negative = true; - } - b'0'..=b'9' => { - let add = - T::Native::usize_as((byte - b'0') as usize).mul_checked(base)?; - result = result.add_checked(add)?; - base = base.mul_checked(T::Native::usize_as(10))?; - } - // because we have checked the string value - _ => (), - } - Ok(()) - })?; - - if negative { - result = result.neg_checked()?; - } - - match T::validate_decimal_precision(result, precision) { - Ok(_) => Ok(result), - Err(e) => Err(ArrowError::ParseError(format!( - "parse decimal overflow: {e}" - ))), - } -} - -fn is_valid_decimal(s: &str) -> bool { let mut seen_dot = false; - let mut seen_digit = false; let mut seen_sign = false; + let mut negative = false; - for c in s.as_bytes() { - match c { - b'-' | b'+' => { - if seen_digit || seen_dot || seen_sign { - return false; + let mut result = T::Native::usize_as(0); + let mut fractionals = 0; + let mut digits = 0; + let base = T::Native::usize_as(10); + let mut bs = s.as_bytes().iter(); + while let Some(b) = bs.next() { + match b { + b'0'..=b'9' => { + if seen_dot { + if fractionals == scale { + // We have processed and validated the whole part of our decimal (including sign and dot). + // All that is left is to validate the fractional part. + if bs.any(|b| !b.is_ascii_digit()) { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + break; + } + fractionals += 1; } - seen_sign = true; + digits += 1; + if digits > precision { + return Err(ArrowError::ParseError( + "parse decimal overflow".to_string(), + )); + } + result = result.mul_checked(base)?; + result = result.add_checked(T::Native::usize_as((b - b'0') as usize))?; } b'.' 
=> { if seen_dot { - return false; + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); } seen_dot = true; } - b'0'..=b'9' => { - seen_digit = true; + b'-' => { + if seen_sign || digits > 0 || seen_dot { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + seen_sign = true; + negative = true; + } + b'+' => { + if seen_sign || digits > 0 || seen_dot { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } + seen_sign = true; + } + _ => { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); } - _ => return false, } } + // Fail on "." + if digits == 0 { + return Err(ArrowError::ParseError(format!( + "can't parse the string value {s} to decimal" + ))); + } - seen_digit + if fractionals < scale { + let exp = scale - fractionals; + if exp as u8 + digits > precision { + return Err(ArrowError::ParseError("parse decimal overflow".to_string())); + } + let mul = base.pow_checked(exp as _)?; + result = result.mul_checked(mul)?; + } + + Ok(if negative { + result.neg_checked()? + } else { + result + }) } pub fn parse_interval_year_month(