Skip to content

Commit

Permalink
adding safe support to to_date and to_timestamp functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Omega359 committed Sep 21, 2024
1 parent f2159e6 commit 07a0a90
Show file tree
Hide file tree
Showing 4 changed files with 337 additions and 84 deletions.
9 changes: 8 additions & 1 deletion datafusion/core/src/datasource/listing/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,14 @@ impl ListingTable {
// Add the partition columns to the file schema
let mut builder = SchemaBuilder::from(file_schema.as_ref().to_owned());
for (part_col_name, part_col_type) in &options.table_partition_cols {
builder.push(Field::new(part_col_name, part_col_type.clone(), false));
// only add the partition if it is not already in the file_schema
if !file_schema
.fields
.iter()
.any(|f| f.name().eq_ignore_ascii_case(part_col_name))
{
builder.push(Field::new(part_col_name, part_col_type.clone(), false));
}
}

let table = Self {
Expand Down
59 changes: 47 additions & 12 deletions datafusion/functions/src/datetime/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ pub(crate) fn handle<'a, O, F, S>(
args: &'a [ColumnarValue],
op: F,
name: &str,
safe: bool,
) -> Result<ColumnarValue>
where
O: ArrowPrimitiveType,
Expand All @@ -191,14 +192,25 @@ where
match &args[0] {
ColumnarValue::Array(a) => match a.data_type() {
DataType::Utf8 | DataType::LargeUtf8 => Ok(ColumnarValue::Array(Arc::new(
unary_string_to_primitive_function::<i32, O, _>(&[a.as_ref()], op, name)?,
unary_string_to_primitive_function::<i32, O, _>(
&[a.as_ref()],
op,
name,
safe,
)?,
))),
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => {
let result = a.as_ref().map(|x| (op)(x)).transpose()?;
Ok(ColumnarValue::Scalar(S::scalar(result)))
let result = a.as_ref().map(|x| op(x)).transpose();
if let Ok(v) = result {
Ok(ColumnarValue::Scalar(S::scalar(v)))
} else if safe {
Ok(ColumnarValue::Scalar(S::scalar(None)))
} else {
Err(result.err().unwrap())
}
}
other => exec_err!("Unsupported data type {other:?} for function {name}"),
},
Expand All @@ -213,6 +225,7 @@ pub(crate) fn handle_multiple<'a, O, F, S, M>(
op: F,
op2: M,
name: &str,
safe: bool,
) -> Result<ColumnarValue>
where
O: ArrowPrimitiveType,
Expand Down Expand Up @@ -244,7 +257,9 @@ where
}

Ok(ColumnarValue::Array(Arc::new(
strings_to_primitive_function::<i32, O, _, _>(args, op, op2, name)?,
strings_to_primitive_function::<i32, O, _, _>(
args, op, op2, name, safe,
)?,
)))
}
other => {
Expand All @@ -254,12 +269,13 @@ where
// if the first argument is a scalar utf8 all arguments are expected to be scalar utf8
ColumnarValue::Scalar(scalar) => match scalar {
ScalarValue::Utf8(a) | ScalarValue::LargeUtf8(a) => {
let mut val: Option<Result<ColumnarValue>> = None;
let mut err: Option<DataFusionError> = None;

let a = a.as_ref();
// ASK: Why do we trust `a` to be non-null at this point?
let a = unwrap_or_internal_err!(a);

let mut ret = None;

for (pos, v) in args.iter().enumerate().skip(1) {
let ColumnarValue::Scalar(
ScalarValue::Utf8(x) | ScalarValue::LargeUtf8(x),
Expand All @@ -271,17 +287,26 @@ where
if let Some(s) = x {
match op(a.as_str(), s.as_str()) {
Ok(r) => {
ret = Some(Ok(ColumnarValue::Scalar(S::scalar(Some(
val = Some(Ok(ColumnarValue::Scalar(S::scalar(Some(
op2(r),
)))));
break;
}
Err(e) => ret = Some(Err(e)),
Err(e) => err = Some(e),
}
}
}

unwrap_or_internal_err!(ret)
if let Some(v) = val {
v
} else if safe {
Ok(ColumnarValue::Scalar(S::scalar(None)))
} else {
match err {
Some(e) => Err(e),
None => Ok(ColumnarValue::Scalar(S::scalar(None))),
}
}
}
other => {
exec_err!("Unsupported data type {other:?} for function {name}")
Expand All @@ -300,12 +325,13 @@ where
/// This function errors iff:
/// * the number of arguments is not > 1 or
/// * the array arguments are not castable to a `GenericStringArray` or
/// * the function `op` errors for all input
/// * the function `op` errors for all input and safe is false
pub(crate) fn strings_to_primitive_function<'a, T, O, F, F2>(
args: &'a [ColumnarValue],
op: F,
op2: F2,
name: &str,
safe: bool,
) -> Result<PrimitiveArray<O>>
where
O: ArrowPrimitiveType,
Expand Down Expand Up @@ -375,6 +401,7 @@ where
};

val.transpose()
.or_else(|e| if safe { Ok(None) } else { Err(e) })
})
.collect()
}
Expand All @@ -386,11 +413,12 @@ where
/// This function errors iff:
/// * the number of arguments is not 1 or
/// * the first argument is not castable to a `GenericStringArray` or
/// * the function `op` errors
/// * the function `op` errors and safe is false
fn unary_string_to_primitive_function<'a, T, O, F>(
args: &[&'a dyn Array],
op: F,
name: &str,
safe: bool,
) -> Result<PrimitiveArray<O>>
where
O: ArrowPrimitiveType,
Expand All @@ -408,5 +436,12 @@ where
let array = as_generic_string_array::<T>(args[0])?;

// first map is the iterator, second is for the `Option<_>`
array.iter().map(|x| x.map(&op).transpose()).collect()
array
.iter()
.map(|x| {
x.map(&op)
.transpose()
.or_else(|e| if safe { Ok(None) } else { Err(e) })
})
.collect()
}
21 changes: 18 additions & 3 deletions datafusion/functions/src/datetime/to_date.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,21 @@

use std::any::Any;

use crate::datetime::common::*;
use arrow::compute::CastOptions;
use arrow::datatypes::DataType;
use arrow::datatypes::DataType::Date32;
use arrow::error::ArrowError::ParseError;
use arrow::{array::types::Date32Type, compute::kernels::cast_utils::Parser};

use crate::datetime::common::*;
use datafusion_common::error::DataFusionError;
use datafusion_common::{arrow_err, exec_err, internal_datafusion_err, Result};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};

/// Implementation of the `to_date` scalar function.
///
/// The `safe` flag is forwarded to the shared string-parsing helpers and also
/// selects which cast options are used when an argument is cast directly to
/// `Date32`.
#[derive(Debug)]
pub struct ToDateFunc {
/// Argument signature; the constructors build it with
/// `Signature::variadic_any(Volatility::Immutable)`.
signature: Signature,
/// How cast or parsing failures are handled: return NULL when `true`
/// (safe mode), return an error when `false`.
safe: bool,
}

impl Default for ToDateFunc {
Expand All @@ -42,6 +44,14 @@ impl ToDateFunc {
/// Creates a `to_date` function in non-safe mode: cast and parsing failures
/// are returned as errors rather than NULL.
pub fn new() -> Self {
    // Accepts any number of arguments of any type; the function is immutable.
    let signature = Signature::variadic_any(Volatility::Immutable);
    Self {
        signature,
        safe: false,
    }
}

/// Creates a `to_date` function with an explicit failure-handling mode.
///
/// When `safe` is `true`, cast or parsing failures yield NULL; when it is
/// `false`, they are returned as errors.
pub fn new_with_safe(safe: bool) -> Self {
    // Same signature as `new`: any number of arguments of any type, immutable.
    let signature = Signature::variadic_any(Volatility::Immutable);
    Self { signature, safe }
}

Expand All @@ -57,6 +67,7 @@ impl ToDateFunc {
)),
},
"to_date",
self.safe,
),
2.. => handle_multiple::<Date32Type, _, Date32Type, _>(
args,
Expand All @@ -71,6 +82,7 @@ impl ToDateFunc {
},
|n| n,
"to_date",
self.safe,
),
0 => exec_err!("Unsupported 0 argument count for function to_date"),
}
Expand Down Expand Up @@ -110,7 +122,10 @@ impl ScalarUDFImpl for ToDateFunc {
| DataType::Null
| DataType::Float64
| DataType::Date32
| DataType::Date64 => args[0].cast_to(&DataType::Date32, None),
| DataType::Date64 => match self.safe {
true => args[0].cast_to(&DataType::Date32, Some(&CastOptions::default())),
false => args[0].cast_to(&DataType::Date32, None),
},
DataType::Utf8 => self.to_date(args),
other => {
exec_err!("Unsupported data type {:?} for function to_date", other)
Expand Down
Loading

0 comments on commit 07a0a90

Please sign in to comment.