Skip to content

Commit

Permalink
Improve decimal parsing performance
Browse files Browse the repository at this point in the history
  • Loading branch information
spebern committed Mar 13, 2023
1 parent 9ce0ebb commit f094bb5
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 80 deletions.
4 changes: 4 additions & 0 deletions arrow-cast/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,7 @@ half = { version = "2.1", default-features = false }
[[bench]]
name = "parse_timestamp"
harness = false

[[bench]]
name = "parse_decimal"
harness = false
56 changes: 56 additions & 0 deletions arrow-cast/benches/parse_decimal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_array::types::Decimal256Type;
use arrow_cast::parse::parse_decimal;
use criterion::*;

fn criterion_benchmark(c: &mut Criterion) {
let decimals = [
"123.123",
"123.1234",
"123.1",
"123",
"-123.123",
"-123.1234",
"-123.1",
"-123",
"0.0000123",
"12.",
"-12.",
"00.1",
"-00.1",
"12345678912345678.1234",
"-12345678912345678.1234",
"99999999999999999.999",
"-99999999999999999.999",
".123",
"-.123",
"123.",
"-123.",
];

for decimal in decimals {
let d = black_box(decimal);
c.bench_function(d, |b| {
b.iter(|| parse_decimal::<Decimal256Type>(d, 20, 3).unwrap());
});
}
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
153 changes: 73 additions & 80 deletions arrow-cast/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -558,100 +558,93 @@ pub fn parse_decimal<T: DecimalType>(
precision: u8,
scale: i8,
) -> Result<T::Native, ArrowError> {
if !is_valid_decimal(s) {
return Err(ArrowError::ParseError(format!(
"can't parse the string value {s} to decimal"
)));
}
let mut offset = s.len();
let len = s.len();
let mut base = T::Native::usize_as(1);
let scale_usize = usize::from(scale as u8);

// handle the value after the '.' and meet the scale
let delimiter_position = s.find('.');
match delimiter_position {
None => {
// there is no '.'
base = T::Native::usize_as(10).pow_checked(scale as u32)?;
}
Some(mid) => {
// there is the '.'
if len - mid >= scale_usize + 1 {
// If the string value is "123.12345" and the scale is 2, we should just remain '.12' and drop the '345' value.
offset -= len - mid - 1 - scale_usize;
} else {
// If the string value is "123.12" and the scale is 4, we should append '00' to the tail.
base = T::Native::usize_as(10)
.pow_checked((scale_usize + 1 + mid - len) as u32)?;
}
}
};

// each byte is digit、'-' or '.'
let bytes = s.as_bytes();
let mut negative = false;
let mut result = T::Native::usize_as(0);

bytes[0..offset]
.iter()
.rev()
.try_for_each::<_, Result<(), ArrowError>>(|&byte| {
match byte {
b'-' => {
negative = true;
}
b'0'..=b'9' => {
let add =
T::Native::usize_as((byte - b'0') as usize).mul_checked(base)?;
result = result.add_checked(add)?;
base = base.mul_checked(T::Native::usize_as(10))?;
}
// because we have checked the string value
_ => (),
}
Ok(())
})?;

if negative {
result = result.neg_checked()?;
}

match T::validate_decimal_precision(result, precision) {
Ok(_) => Ok(result),
Err(e) => Err(ArrowError::ParseError(format!(
"parse decimal overflow: {e}"
))),
}
}

fn is_valid_decimal(s: &str) -> bool {
let mut seen_dot = false;
let mut seen_digit = false;
let mut seen_sign = false;
let mut negative = false;

for c in s.as_bytes() {
match c {
b'-' | b'+' => {
if seen_digit || seen_dot || seen_sign {
return false;
let mut result = T::Native::usize_as(0);
let mut fractional = 0;
let mut digits = 0;
let base = T::Native::usize_as(10);
let mut bs = s.as_bytes().iter();
while let Some(b) = bs.next() {
match b {
b'0'..=b'9' => {
if seen_dot {
if fractional == scale {
// We have processed and validated the whole part of our decimal (including sign and dot).
// All that is left is to validate the fractional part.
if bs.any(|b| !b.is_ascii_digit()) {
return Err(ArrowError::ParseError(format!(
"can't parse the string value {s} to decimal"
)));
}
break;
}
fractional += 1;
}
seen_sign = true;
digits += 1;
if digits > precision {
return Err(ArrowError::ParseError(
"parse decimal overflow".to_string(),
));
}
result = result.mul_checked(base)?;
result = result.add_checked(T::Native::usize_as((b - b'0') as usize))?;
}
b'.' => {
if seen_dot {
return false;
return Err(ArrowError::ParseError(format!(
"can't parse the string value {s} to decimal"
)));
}
seen_dot = true;
}
b'0'..=b'9' => {
seen_digit = true;
b'-' => {
if seen_sign || digits > 0 || seen_dot {
return Err(ArrowError::ParseError(format!(
"can't parse the string value {s} to decimal"
)));
}
seen_sign = true;
negative = true;
}
b'+' => {
if seen_sign || digits > 0 || seen_dot {
return Err(ArrowError::ParseError(format!(
"can't parse the string value {s} to decimal"
)));
}
seen_sign = true;
}
_ => {
return Err(ArrowError::ParseError(format!(
"can't parse the string value {s} to decimal"
)));
}
_ => return false,
}
}
// Fail on "."
if digits == 0 {
return Err(ArrowError::ParseError(format!(
"can't parse the string value {s} to decimal"
)));
}

seen_digit
if fractional < scale {
let exp = scale - fractional;
if exp as u8 + digits > precision {
return Err(ArrowError::ParseError("parse decimal overflow".to_string()));
}
let mul = base.pow_checked(exp as _)?;
result = result.mul_checked(mul)?;
}

Ok(if negative {
result.neg_checked()?
} else {
result
})
}

pub fn parse_interval_year_month(
Expand Down

0 comments on commit f094bb5

Please sign in to comment.