Skip to content

Commit

Permalink
refactor(union_extract): new udf api, docs macro, use any signature
Browse files Browse the repository at this point in the history
  • Loading branch information
gstvg committed Jan 20, 2025
1 parent 30940f7 commit fad85ea
Showing 1 changed file with 74 additions and 63 deletions.
137 changes: 74 additions & 63 deletions datafusion/functions/src/core/union_extract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,30 @@ use datafusion_common::cast::as_union_array;
use datafusion_common::{
exec_datafusion_err, exec_err, internal_err, ExprSchema, Result, ScalarValue,
};
use datafusion_expr::{ColumnarValue, Expr};
use datafusion_doc::Documentation;
use datafusion_expr::{ColumnarValue, Expr, ScalarFunctionArgs};
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};

use datafusion_macros::user_doc;

#[user_doc(
doc_section(include = "true", label = "Union Functions"),
description = "Returns the value of the given field when selected, or NULL otherwise.",
syntax_example = "union_extract(union, field_name)",
sql_example = r#"```sql
❯ select union_column, union_extract(union_column, 'a'), union_extract(union_column, 'b') from table_with_union;
+--------------+----------------------------------+----------------------------------+
| union_column | union_extract(union_column, 'a') | union_extract(union_column, 'b') |
+--------------+----------------------------------+----------------------------------+
| {a=1} | 1 | |
| {b=3.0} | | 3.0 |
| {a=4} | 4 | |
| {b=} | | |
| {a=} | | |
+--------------+----------------------------------+----------------------------------+
```"#,
standard_argument(name = "union", prefix = "Union"),
standard_argument(name = "field_name", prefix = "String")
)]
#[derive(Debug)]
pub struct UnionExtractFun {
signature: Signature,
Expand All @@ -38,7 +59,7 @@ impl Default for UnionExtractFun {
impl UnionExtractFun {
pub fn new() -> Self {
Self {
signature: Signature::user_defined(Volatility::Immutable),
signature: Signature::any(2, Volatility::Immutable),
}
}
}
Expand Down Expand Up @@ -93,38 +114,37 @@ impl ScalarUDFImpl for UnionExtractFun {
Ok(field.data_type().clone())
}

fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
let args = args.args;

if args.len() != 2 {
return exec_err!(
"union_extract expects 2 arguments, got {} instead",
args.len()
);
}

let union = &args[0];

let target_name = match &args[1] {
ColumnarValue::Scalar(ScalarValue::Utf8(Some(target_name))) => Ok(target_name),
ColumnarValue::Scalar(ScalarValue::Utf8(None)) => exec_err!("union_extract second argument must be a non-null string literal, got a null instead"),
_ => exec_err!("union_extract second argument must be a non-null string literal, got {} instead", &args[1].data_type()),
};

match union {
match &args[0] {
ColumnarValue::Array(array) => {
let _union_array = as_union_array(&array).map_err(|_| {
let union_array = as_union_array(&array).map_err(|_| {
exec_datafusion_err!(
"union_extract first argument must be a union, got {} instead",
array.data_type()
)
})?;

// Ok(arrow::compute::kernels::union_extract::union_extract(
// &union_array,
// target_name,
// )?)
Ok(ColumnarValue::Array(std::sync::Arc::new(
arrow::array::Int32Array::from(vec![1, 2]),
)))
Ok(ColumnarValue::Array(
arrow::compute::kernels::union_extract::union_extract(
union_array,
target_name?,
)?,
))
}
ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => {
let target_name = target_name?;
Expand All @@ -146,29 +166,8 @@ impl ScalarUDFImpl for UnionExtractFun {
}
}

fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
if arg_types.len() != 2 {
return exec_err!(
"union_extract expects 2 arguments, got {} instead",
arg_types.len()
);
}

if !matches!(arg_types[0], DataType::Union(_, _)) {
return exec_err!(
"union_extract first argument must be a union, got {} instead",
arg_types[0]
);
}

if !matches!(arg_types[1], DataType::Utf8) {
return exec_err!(
"union_extract second argument must be a non-null string literal, got {} instead",
arg_types[1]
);
}

Ok(arg_types.to_vec())
fn documentation(&self) -> Option<&Documentation> {
self.doc()
}
}

Expand All @@ -184,7 +183,7 @@ mod tests {

use arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

use super::UnionExtractFun;

Expand All @@ -201,36 +200,48 @@ mod tests {
],
);

let result = fun.invoke(&[
ColumnarValue::Scalar(ScalarValue::Union(
None,
fields.clone(),
UnionMode::Dense,
)),
ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
])?;
let result = fun.invoke_with_args(ScalarFunctionArgs {
args: vec![
ColumnarValue::Scalar(ScalarValue::Union(
None,
fields.clone(),
UnionMode::Dense,
)),
ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
],
number_rows: 1,
return_type: &DataType::Utf8,
})?;

assert_scalar(result, ScalarValue::Utf8(None));

let result = fun.invoke(&[
ColumnarValue::Scalar(ScalarValue::Union(
Some((3, Box::new(ScalarValue::Int32(Some(42))))),
fields.clone(),
UnionMode::Dense,
)),
ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
])?;
let result = fun.invoke_with_args(ScalarFunctionArgs {
args: vec![
ColumnarValue::Scalar(ScalarValue::Union(
Some((3, Box::new(ScalarValue::Int32(Some(42))))),
fields.clone(),
UnionMode::Dense,
)),
ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
],
number_rows: 1,
return_type: &DataType::Utf8,
})?;

assert_scalar(result, ScalarValue::Utf8(None));

let result = fun.invoke(&[
ColumnarValue::Scalar(ScalarValue::Union(
Some((1, Box::new(ScalarValue::new_utf8("42")))),
fields.clone(),
UnionMode::Dense,
)),
ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
])?;
let result = fun.invoke_with_args(ScalarFunctionArgs {
args: vec![
ColumnarValue::Scalar(ScalarValue::Union(
Some((1, Box::new(ScalarValue::new_utf8("42")))),
fields.clone(),
UnionMode::Dense,
)),
ColumnarValue::Scalar(ScalarValue::new_utf8("str")),
],
number_rows: 1,
return_type: &DataType::Utf8,
})?;

assert_scalar(result, ScalarValue::new_utf8("42"));

Expand Down

0 comments on commit fad85ea

Please sign in to comment.