Skip to content

Commit

Permalink
Improve interval parsing (#6211)
Browse files Browse the repository at this point in the history
* improve interval parsing

* rename

* cleanup

* fix formatting

* make IntervalParseConfig public

* add debug to IntervalParseConfig

* fmt
  • Loading branch information
samuelcolvin authored Aug 12, 2024
1 parent fe03d39 commit a693f0f
Showing 1 changed file with 179 additions and 62 deletions.
241 changes: 179 additions & 62 deletions arrow-cast/src/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -994,28 +994,47 @@ pub fn parse_interval_day_time(
Ok(IntervalDayTimeType::make_value(days, millis))
}

/// Parses `value` as an interval and returns it as an `IntervalMonthDayNanoType` native value.
///
/// The text is parsed with `Interval::parse`; the result is converted to months, days and
/// nanoseconds and then packed with `IntervalMonthDayNanoType::make_value`.
///
/// # Errors
/// Propagates any error returned by `Interval::parse`.
// NOTE(review): this span is a diff hunk -- two signature lines appear back to back and the
// `let config = ...` line shadows the `config` parameter. Presumably those are the removed
// pre-change lines; confirm against the file itself.
pub fn parse_interval_month_day_nano(
pub fn parse_interval_month_day_nano_config(
value: &str,
config: IntervalParseConfig,
) -> Result<<IntervalMonthDayNanoType as ArrowPrimitiveType>::Native, ArrowError> {
let config = IntervalParseConfig::new(IntervalUnit::Month);
let interval = Interval::parse(value, &config)?;

let (months, days, nanos) = interval.to_month_day_nanos();

Ok(IntervalMonthDayNanoType::make_value(months, days, nanos))
}

/// Parses `value` into an `IntervalMonthDayNanoType` native value, using
/// `IntervalUnit::Month` as the default unit.
///
/// This is a convenience wrapper: it builds an `IntervalParseConfig` with a
/// month default and forwards to `parse_interval_month_day_nano_config`,
/// returning that function's result unchanged.
pub fn parse_interval_month_day_nano(
    value: &str,
) -> Result<<IntervalMonthDayNanoType as ArrowPrimitiveType>::Native, ArrowError> {
    let month_default = IntervalParseConfig::new(IntervalUnit::Month);
    parse_interval_month_day_nano_config(value, month_default)
}

// Number of nanoseconds in each larger time unit; every constant is derived
// from the one defined just before it.
const NANOS_PER_MILLIS: i64 = 1_000_000;
const NANOS_PER_SECOND: i64 = 1_000 * NANOS_PER_MILLIS;
const NANOS_PER_MINUTE: i64 = 60 * NANOS_PER_SECOND;
const NANOS_PER_HOUR: i64 = 60 * NANOS_PER_MINUTE;
// Compiled only in test builds.
#[cfg(test)]
const NANOS_PER_DAY: i64 = 24 * NANOS_PER_HOUR;

/// Configuration for parsing interval strings.
#[derive(Debug, Clone)]
pub struct IntervalParseConfig {
/// The default unit to use if none is specified
/// e.g. `INTERVAL 1` represents `INTERVAL 1 SECOND` when default_unit = IntervalUnit::Second
default_unit: IntervalUnit,
}

impl IntervalParseConfig {
pub fn new(default_unit: IntervalUnit) -> Self {
Self { default_unit }
}
}

#[rustfmt::skip]
#[derive(Clone, Copy)]
#[derive(Debug, Clone, Copy)]
#[repr(u16)]
enum IntervalUnit {
pub enum IntervalUnit {
Century = 0b_0000_0000_0001,
Decade = 0b_0000_0000_0010,
Year = 0b_0000_0000_0100,
Expand All @@ -1030,30 +1049,50 @@ enum IntervalUnit {
Nanosecond = 0b_1000_0000_0000,
}

/// Logic for parsing interval unit strings
///
/// See <https://github.com/postgres/postgres/blob/2caa85f4aae689e6f6721d7363b4c66a2a6417d6/src/backend/utils/adt/datetime.c#L189>
/// for a list of unit names supported by PostgreSQL which we try to match here.
impl FromStr for IntervalUnit {
type Err = ArrowError;

/// Maps a unit name to its `IntervalUnit` variant, or returns an error naming the
/// unrecognized text.
fn from_str(s: &str) -> Result<Self, ArrowError> {
// Lowercase first so unit names match regardless of the caller's casing.
match s.to_lowercase().as_str() {
"century" | "centuries" => Ok(Self::Century),
"decade" | "decades" => Ok(Self::Decade),
"year" | "years" => Ok(Self::Year),
"month" | "months" => Ok(Self::Month),
"week" | "weeks" => Ok(Self::Week),
"day" | "days" => Ok(Self::Day),
"hour" | "hours" => Ok(Self::Hour),
"minute" | "minutes" => Ok(Self::Minute),
"second" | "seconds" => Ok(Self::Second),
"millisecond" | "milliseconds" => Ok(Self::Millisecond),
"microsecond" | "microseconds" => Ok(Self::Microsecond),
"c" | "cent" | "cents" | "century" | "centuries" => Ok(Self::Century),
"dec" | "decs" | "decade" | "decades" => Ok(Self::Decade),
"y" | "yr" | "yrs" | "year" | "years" => Ok(Self::Year),
"mon" | "mons" | "month" | "months" => Ok(Self::Month),
"w" | "week" | "weeks" => Ok(Self::Week),
"d" | "day" | "days" => Ok(Self::Day),
"h" | "hr" | "hrs" | "hour" | "hours" => Ok(Self::Hour),
"m" | "min" | "mins" | "minute" | "minutes" => Ok(Self::Minute),
"s" | "sec" | "secs" | "second" | "seconds" => Ok(Self::Second),
"ms" | "msec" | "msecs" | "msecond" | "mseconds" | "millisecond" | "milliseconds" => {
Ok(Self::Millisecond)
}
"us" | "usec" | "usecs" | "usecond" | "useconds" | "microsecond" | "microseconds" => {
Ok(Self::Microsecond)
}
"nanosecond" | "nanoseconds" => Ok(Self::Nanosecond),
// Anything else is rejected; the original (un-lowercased) text is echoed in the message.
_ => Err(ArrowError::NotYetImplemented(format!(
_ => Err(ArrowError::InvalidArgumentError(format!(
"Unknown interval type: {s}"
))),
}
}
}

impl IntervalUnit {
    /// Resolves the unit for one interval component.
    ///
    /// When `s` holds a unit name it is parsed through the `FromStr` impl and
    /// that result (including any error) is returned as-is. When `s` is `None`,
    /// `config.default_unit` is returned.
    fn from_str_or_config(
        s: Option<&str>,
        config: &IntervalParseConfig,
    ) -> Result<Self, ArrowError> {
        let Some(unit_name) = s else {
            return Ok(config.default_unit);
        };
        unit_name.parse()
    }
}

pub type MonthDayNano = (i32, i32, i64);

/// Chosen based on the number of decimal digits in 1 week in nanoseconds
Expand Down Expand Up @@ -1352,68 +1391,35 @@ impl Interval {
}
}

struct IntervalParseConfig {
/// The default unit to use if none is specified
/// e.g. `INTERVAL 1` represents `INTERVAL 1 SECOND` when default_unit = IntervalUnit::Second
default_unit: IntervalUnit,
}

impl IntervalParseConfig {
fn new(default_unit: IntervalUnit) -> Self {
Self { default_unit }
}
}

/// parse the string into a vector of interval components i.e. (amount, unit) tuples
fn parse_interval_components(
value: &str,
config: &IntervalParseConfig,
) -> Result<Vec<(IntervalAmount, IntervalUnit)>, ArrowError> {
let parts = value.split_whitespace();

let raw_amounts = parts.clone().step_by(2);
let raw_units = parts.skip(1).step_by(2);

// parse amounts
let (amounts, invalid_amounts) = raw_amounts
.map(IntervalAmount::from_str)
.partition::<Vec<_>, _>(Result::is_ok);

// invalid amounts?
if !invalid_amounts.is_empty() {
return Err(ArrowError::ParseError(format!(
"Invalid input syntax for type interval: {value:?}"
)));
}
let raw_pairs = split_interval_components(value);

// parse units
let (units, invalid_units): (Vec<_>, Vec<_>) = raw_units
.clone()
.map(IntervalUnit::from_str)
.partition(Result::is_ok);

// invalid units?
if !invalid_units.is_empty() {
// parse amounts and units
let Ok(pairs): Result<Vec<(IntervalAmount, IntervalUnit)>, ArrowError> = raw_pairs
.iter()
.map(|(a, u)| Ok((a.parse()?, IntervalUnit::from_str_or_config(*u, config)?)))
.collect()
else {
return Err(ArrowError::ParseError(format!(
"Invalid input syntax for type interval: {value:?}"
)));
}
};

// collect parsed results
let amounts = amounts.into_iter().map(Result::unwrap).collect::<Vec<_>>();
let units = units.into_iter().map(Result::unwrap).collect::<Vec<_>>();

// if only an amount is specified, use the default unit
if amounts.len() == 1 && units.is_empty() {
return Ok(vec![(amounts[0], config.default_unit)]);
};
let (amounts, units): (Vec<_>, Vec<_>) = pairs.into_iter().unzip();

// duplicate units?
let mut observed_interval_types = 0;
for (unit, raw_unit) in units.iter().zip(raw_units) {
for (unit, (_, raw_unit)) in units.iter().zip(raw_pairs) {
if observed_interval_types & (*unit as u16) != 0 {
return Err(ArrowError::ParseError(format!(
"Invalid input syntax for type interval: {value:?}. Repeated type '{raw_unit}'",
"Invalid input syntax for type interval: {:?}. Repeated type '{}'",
value,
raw_unit.unwrap_or_default(),
)));
}

Expand All @@ -1425,6 +1431,33 @@ fn parse_interval_components(
Ok(result.collect::<Vec<_>>())
}

/// Split an interval string into `(amount, unit)` pairs.
///
/// Pairs are separated by whitespace; within a pair the amount and unit may be
/// written with or without whitespace between them (`"1 day"` or `"1day"`).
/// Leading, trailing and repeated whitespace is ignored.
///
/// A trailing amount that has no unit is returned with `None` as its unit.
/// Nothing is validated here: the returned slices are raw text, and an input
/// with no words at all yields a single empty amount.
///
/// This should match the behavior of PostgreSQL's interval parser.
fn split_interval_components(value: &str) -> Vec<(&str, Option<&str>)> {
    let mut result = vec![];
    // `split_whitespace` never yields empty words, so extra spaces between or
    // around components cannot be mistaken for an (empty) amount or unit.
    let mut words = value.split_whitespace();
    while let Some(word) = words.next() {
        if let Some(split_word_at) = word.find(not_interval_amount) {
            // Amount and unit are glued together, e.g. "1day".
            let (amount, unit) = word.split_at(split_word_at);
            result.push((amount, Some(unit)));
        } else if let Some(unit) = words.next() {
            // The word is purely numeric; the following word is its unit.
            result.push((word, Some(unit)));
        } else {
            // Purely numeric last word: no unit was given.
            result.push((word, None));
            break;
        }
    }
    if result.is_empty() {
        // Blank input: report one empty amount, exactly as an empty string
        // always did, instead of returning no components at all.
        result.push(("", None));
    }
    result
}

/// Returns `true` when `c` is NOT part of an interval's numeric amount, i.e.
/// when it is anything other than an ASCII digit, `.` or `-`.
fn not_interval_amount(c: char) -> bool {
    !matches!(c, '0'..='9' | '.' | '-')
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -2202,6 +2235,78 @@ mod tests {
)
.unwrap(),
);

// no units
assert_eq!(
Interval::new(1, 0, 0),
Interval::parse("1", &config).unwrap()
);
assert_eq!(
Interval::new(42, 0, 0),
Interval::parse("42", &config).unwrap()
);
assert_eq!(
Interval::new(0, 0, 42_000_000_000),
Interval::parse("42", &IntervalParseConfig::new(IntervalUnit::Second)).unwrap()
);

// shorter units
assert_eq!(
Interval::new(1, 0, 0),
Interval::parse("1 mon", &config).unwrap()
);
assert_eq!(
Interval::new(1, 0, 0),
Interval::parse("1 mons", &config).unwrap()
);
assert_eq!(
Interval::new(0, 0, 1_000_000),
Interval::parse("1 ms", &config).unwrap()
);
assert_eq!(
Interval::new(0, 0, 1_000),
Interval::parse("1 us", &config).unwrap()
);

// no space
assert_eq!(
Interval::new(0, 0, 1_000),
Interval::parse("1us", &config).unwrap()
);
assert_eq!(
Interval::new(0, 0, NANOS_PER_SECOND),
Interval::parse("1s", &config).unwrap()
);
assert_eq!(
Interval::new(1, 2, 10_864_000_000_000),
Interval::parse("1mon 2days 3hr 1min 4sec", &config).unwrap()
);

assert_eq!(
Interval::new(
-13i32,
-8i32,
-NANOS_PER_HOUR
- NANOS_PER_MINUTE
- NANOS_PER_SECOND
- (1.11_f64 * NANOS_PER_MILLIS as f64) as i64
),
Interval::parse(
"-1year -1month -1week -1day -1 hour -1 minute -1 second -1.11millisecond",
&config
)
.unwrap(),
);

assert_eq!(
Interval::parse("1h s", &config).unwrap_err().to_string(),
r#"Parser error: Invalid input syntax for type interval: "1h s""#
);

assert_eq!(
Interval::parse("1XX", &config).unwrap_err().to_string(),
r#"Parser error: Invalid input syntax for type interval: "1XX""#
);
}

#[test]
Expand Down Expand Up @@ -2625,4 +2730,16 @@ mod tests {
assert_eq!(TimestampNanosecondType::parse(""), None);
assert_eq!(Date32Type::parse(""), None);
}

#[test]
fn test_parse_interval_month_day_nano_config() {
    // A bare amount parsed with a `Second` default unit must end up entirely
    // in the nanoseconds field.
    let seconds_default = IntervalParseConfig::new(IntervalUnit::Second);
    let parsed = parse_interval_month_day_nano_config("1", seconds_default).unwrap();
    assert_eq!(parsed.months, 0);
    assert_eq!(parsed.days, 0);
    assert_eq!(parsed.nanoseconds, NANOS_PER_SECOND);
}
}

0 comments on commit a693f0f

Please sign in to comment.