Skip to content

Commit

Permalink
add ScalarFunction and InList
Browse files Browse the repository at this point in the history
  • Loading branch information
yyy1000 committed Mar 23, 2024
1 parent 01ff537 commit 05bf92d
Showing 1 changed file with 105 additions and 6 deletions.
111 changes: 105 additions & 6 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,48 @@ impl Unparser<'_> {
match expr {
Expr::InList(InList {
expr,
list: _,
negated: _,
list,
negated,
}) => {
not_impl_err!("Unsupported expression: {expr:?}")
let list_expr = list
.iter()
.map(|e| self.expr_to_sql(e))
.collect::<Result<Vec<_>>>()?;
Ok(ast::Expr::InList {
expr: Box::new(self.expr_to_sql(expr)?),
list: list_expr,
negated: *negated,
})
}
Expr::ScalarFunction(ScalarFunction { .. }) => {
not_impl_err!("Unsupported expression: {expr:?}")
Expr::ScalarFunction(ScalarFunction { func_def, args }) => {
let func_name = func_def.name();

let args = args
.iter()
.map(|e| {
if matches!(e, Expr::Wildcard { qualifier: None }) {
Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard))
} else {
self.expr_to_sql(e).map(|e| {
FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))
})
}
})
.collect::<Result<Vec<_>>>()?;

Ok(ast::Expr::Function(Function {
name: ast::ObjectName(vec![Ident {
value: func_name.to_string(),
quote_style: None,
}]),
args,
filter: None,
null_treatment: None,
over: None,
distinct: false,
special: false,
order_by: vec![],
}))
}
Expr::Between(Between {
expr,
Expand Down Expand Up @@ -526,13 +561,52 @@ impl Unparser<'_> {

#[cfg(test)]
mod tests {
use std::{any::Any, sync::Arc};

use datafusion_common::TableReference;
use datafusion_expr::{col, expr::AggregateFunction, lit};
use datafusion_expr::{
col, expr::AggregateFunction, lit, ColumnarValue, ScalarFunctionDefinition, ScalarUDF, ScalarUDFImpl, Signature, Volatility
};

use crate::unparser::dialect::CustomDialect;

use super::*;

/// Mocked UDF
#[derive(Debug)]
struct DummyUDF {
signature: Signature,
}

impl DummyUDF {
fn new() -> Self {
Self {
signature: Signature::variadic_any(Volatility::Immutable),
}
}
}

impl ScalarUDFImpl for DummyUDF {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"dummy_udf"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
Ok(DataType::Int32)
}

fn invoke(&self, _args: &[ColumnarValue]) -> Result<ColumnarValue> {
unimplemented!("DummyUDF::invoke")
}
}
// See sql::tests for E2E tests.

#[test]
Expand Down Expand Up @@ -561,6 +635,31 @@ mod tests {
}),
r#"CAST("a" AS INTEGER UNSIGNED)"#,
),
(
Expr::InList(InList {
expr: Box::new(col("a")),
list: vec![lit(1), lit(2), lit(3)],
negated: false,
}),
r#""a" IN (1, 2, 3)"#,
),
(
Expr::InList(InList {
expr: Box::new(col("a")),
list: vec![lit(1), lit(2), lit(3)],
negated: true,
}),
r#""a" NOT IN (1, 2, 3)"#,
),
(
Expr::ScalarFunction(ScalarFunction {
func_def: ScalarFunctionDefinition::UDF(Arc::new(
ScalarUDF::new_from_impl(DummyUDF::new()),
)),
args: vec![col("a"), col("b")],
}),
r#"dummy_udf("a", "b")"#,
),
(
Expr::Literal(ScalarValue::Date64(Some(0))),
r#"CAST('1970-01-01 00:00:00' AS DATETIME)"#,
Expand Down

0 comments on commit 05bf92d

Please sign in to comment.