Improve ScalarUDFImpl docs #14248

Merged
merged 2 commits into from
Jan 23, 2025
39 changes: 26 additions & 13 deletions datafusion/expr/src/udf.rs
@@ -193,6 +193,9 @@ impl ScalarUDF {
 self.inner.return_type_from_exprs(args, schema, arg_types)
 }
 
+/// Return the datatype this function returns given the input argument types.
+///
+/// See [`ScalarUDFImpl::return_type_from_args`] for more details.
 pub fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
 self.inner.return_type_from_args(args)
 }
@@ -433,7 +436,6 @@ impl ReturnInfo {
 /// # use datafusion_expr::{col, ColumnarValue, Documentation, ScalarFunctionArgs, Signature, Volatility};
 /// # use datafusion_expr::{ScalarUDFImpl, ScalarUDF};
 /// # use datafusion_expr::scalar_doc_sections::DOC_SECTION_MATH;
-///
 /// /// This struct for a simple UDF that adds one to an int32
 /// #[derive(Debug)]
 /// struct AddOne {
@@ -494,7 +496,12 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
 /// Returns this function's name
 fn name(&self) -> &str;
 
-/// Returns the user-defined display name of the UDF given the arguments
+/// Returns the user-defined display name of function, given the arguments
+///
+/// This can be used to customize the output column name generated by this
+/// function.
+///
+/// Defaults to `name(args[0], args[1], ...)`
 fn display_name(&self, args: &[Expr]) -> Result<String> {
 let names: Vec<String> = args.iter().map(ToString::to_string).collect();
 // TODO: join with ", " to standardize the formatting of Vec<Expr>, <https://github.com/apache/datafusion/issues/10364>
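The default and an override of `display_name` can be illustrated with a minimal, self-contained sketch. It does not use the DataFusion crate: the `DisplayNamed` trait, the `String` arguments standing in for `Expr`, the `Renamed` type, and the `","` separator are all assumptions made for illustration only.

```rust
/// Hypothetical stand-in for the real trait; arguments are plain strings
/// instead of DataFusion `Expr` values.
trait DisplayNamed {
    fn name(&self) -> &str;

    /// Default: `name(arg0,arg1,...)`. The "," separator is an assumption
    /// of this sketch, not a statement about DataFusion's formatting.
    fn display_name(&self, args: &[String]) -> String {
        format!("{}({})", self.name(), args.join(","))
    }
}

/// Uses the default display name.
struct AddOne;

impl DisplayNamed for AddOne {
    fn name(&self) -> &str {
        "add_one"
    }
}

/// Overrides the display name to customize the output column name.
struct Renamed;

impl DisplayNamed for Renamed {
    fn name(&self) -> &str {
        "add_one"
    }

    fn display_name(&self, args: &[String]) -> String {
        format!("{} plus one", args.join(", "))
    }
}

fn main() {
    let args = vec!["a".to_string()];
    println!("{}", AddOne.display_name(&args));
    println!("{}", Renamed.display_name(&args));
}
```

In this model the function keeps the same `name` in both cases; only the generated column label differs.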
@@ -522,7 +529,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
 /// # Notes
 ///
 /// If you provide an implementation for [`Self::return_type_from_args`],
-/// DataFusion will not call `return_type` (this function). In this case it
+/// DataFusion will not call `return_type` (this function). In such cases
 /// is recommended to return [`DataFusionError::Internal`].
 ///
 /// [`DataFusionError::Internal`]: datafusion_common::DataFusionError::Internal
@@ -538,18 +545,24 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
 self.return_type(arg_types)
 }
 
-/// What [`DataType`] will be returned by this function, given the
-/// arguments?
-///
-/// Note most UDFs should implement [`Self::return_type`] and not this
-/// function. The output type for most functions only depends on the types
-/// of their inputs (e.g. `sqrt(f32)` is always `f32`).
+/// What type will be returned by this function, given the arguments?
 ///
 /// By default, this function calls [`Self::return_type`] with the
 /// types of each argument.
 ///
-/// This method can be overridden for functions that return different
-/// *types* based on the *values* of their arguments.
+/// # Notes
+///
+/// Most UDFs should implement [`Self::return_type`] and not this
+/// function as the output type for most functions only depends on the types
+/// of their inputs (e.g. `sqrt(f32)` is always `f32`).
+///
+/// This function can be used for more advanced cases such as:
+///
+/// 1. specifying nullability
+/// 2. return types based on the **values** of the arguments (rather than
+/// their **types**.
+///
+/// # Output Type based on Values
 ///
 /// For example, the following two function calls get the same argument
 /// types (something and a `Utf8` string) but return different types based
@@ -558,9 +571,9 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
 /// * `arrow_cast(x, 'Int16')` --> `Int16`
 /// * `arrow_cast(x, 'Float32')` --> `Float32`
 ///
-/// # Notes:
+/// # Requirements
 ///
-/// This function must consistently return the same type for the same
+/// This function **must** consistently return the same type for the same
 /// logical input even if the input is simplified (e.g. it must return the same
 /// value for `('foo' | 'bar')` as it does for ('foobar').
 fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result<ReturnInfo> {
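The relationship between `return_type` and `return_type_from_args` described in these doc comments can be sketched with a small model that does not depend on DataFusion. Everything here is a simplified assumption: the `Udf` trait, the `Args` and `Info` structs, the three-variant `DataType` enum, and `String` errors are hypothetical stand-ins for `ScalarUDFImpl`, `ReturnTypeArgs`, `ReturnInfo`, and `DataFusionError`. The nullable default is also an assumption of the sketch.

```rust
#[derive(Debug, Clone, PartialEq)]
enum DataType {
    Int16,
    Float32,
    Utf8,
}

/// Hypothetical argument bundle: the type of each argument and, where the
/// argument is a string literal, its value.
struct Args<'a> {
    arg_types: &'a [DataType],
    string_literals: &'a [Option<&'a str>],
}

/// Hypothetical result: the return type plus its nullability.
#[derive(Debug, PartialEq)]
struct Info {
    return_type: DataType,
    nullable: bool,
}

trait Udf {
    /// Type-only path: sufficient for most functions.
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType, String>;

    /// By default, delegates to `return_type` with the argument types.
    fn return_type_from_args(&self, args: Args<'_>) -> Result<Info, String> {
        let return_type = self.return_type(args.arg_types)?;
        Ok(Info { return_type, nullable: true })
    }
}

/// Output type depends only on the input type, so only `return_type` is implemented.
struct Sqrt;

impl Udf for Sqrt {
    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType, String> {
        arg_types.first().cloned().ok_or_else(|| "expected one argument".to_string())
    }
}

/// Output type depends on the *value* of the second argument.
struct ArrowCast;

impl Udf for ArrowCast {
    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType, String> {
        // Never called in this model once `return_type_from_args` is overridden.
        Err("internal error: return_type_from_args should be called instead".to_string())
    }

    fn return_type_from_args(&self, args: Args<'_>) -> Result<Info, String> {
        let target = args
            .string_literals
            .get(1)
            .copied()
            .flatten()
            .ok_or_else(|| "second argument must be a string literal".to_string())?;
        let return_type = match target {
            "Int16" => DataType::Int16,
            "Float32" => DataType::Float32,
            other => return Err(format!("unsupported type name: {other}")),
        };
        Ok(Info { return_type, nullable: true })
    }
}

fn main() {
    let types = [DataType::Utf8, DataType::Utf8];
    for name in ["Int16", "Float32"] {
        let literals = [None, Some(name)];
        let args = Args { arg_types: &types, string_literals: &literals };
        println!("{:?}", ArrowCast.return_type_from_args(args));
    }
    let args = Args { arg_types: &[DataType::Float32], string_literals: &[None] };
    println!("{:?}", Sqrt.return_type_from_args(args));
}
```

Both `ArrowCast` calls receive identical argument types and differ only in the literal value, which is why the type-only `return_type` cannot express them.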