diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 69456446f52b..9debc7c00115 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -25,9 +25,9 @@ checksum = "aae1277d39aeec15cb388266ecc24b11c80469deae6067e17a1a7aa9e5c1f234" [[package]] name = "ahash" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42cd52102d3df161c77a887b608d7a4897d7cc112886a9537b738a887a03aaff" +checksum = "d713b3834d76b85304d4d525563c1276e2e30dc97cc67bfb4585a4a29fc2c89f" dependencies = [ "cfg-if", "const-random", @@ -345,9 +345,9 @@ dependencies = [ [[package]] name = "assert_cmd" -version = "2.0.13" +version = "2.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00ad3f3a942eee60335ab4342358c161ee296829e0d16ff42fc1d6cb07815467" +checksum = "ed72493ac66d5804837f480ab3766c72bdfab91a65e565fc54fa9e42db0073a8" dependencies = [ "anstyle", "bstr", @@ -384,7 +384,7 @@ checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -802,9 +802,9 @@ dependencies = [ [[package]] name = "bumpalo" -version = "3.15.0" +version = "3.15.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d32a994c2b3ca201d9b263612a374263f05e7adde37c4707f693dcd375076d1f" +checksum = "c764d619ca78fccbf3069b37bd7af92577f044bb15236036662d79b6559f25b7" [[package]] name = "byteorder" @@ -851,11 +851,10 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.83" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +checksum = "7f9fa1897e4325be0d68d48df6aa1a71ac2ed4d27723887e7754192705350730" dependencies = [ - "jobserver", "libc", ] @@ -1074,7 +1073,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e" dependencies = [ "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -1231,6 +1230,7 @@ name = "datafusion-functions" version = "36.0.0" dependencies = [ "arrow", + "arrow-schema", "base64", "datafusion-common", "datafusion-execution", @@ -1607,7 +1607,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -1952,15 +1952,6 @@ version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" -[[package]] -name = "jobserver" -version = "0.1.28" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab46a6e9526ddef3ae7f787c06f0f2600639ba80ea3eade3d8e670a2230f51d6" -dependencies = [ - "libc", -] - [[package]] name = "js-sys" version = "0.3.68" @@ -2336,7 +2327,7 @@ dependencies = [ "quick-xml", "rand", "reqwest", - "ring 0.17.7", + "ring 0.17.8", "rustls-pemfile 2.1.0", "serde", "serde_json", @@ -2524,7 +2515,7 @@ checksum = "266c042b60c9c76b8d53061e52b2e0d1116abc57cefc8c5cd671619a56ac3690" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -2800,16 +2791,17 @@ dependencies = [ [[package]] name = "ring" -version = "0.17.7" +version = "0.17.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" dependencies = [ "cc", + "cfg-if", "getrandom", "libc", "spin 0.9.8", "untrusted 0.9.0", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -2891,7 +2883,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ "log", - "ring 0.17.7", + "ring 0.17.8", "rustls-webpki", "sct", ] @@ -2939,7 +2931,7 @@ version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ - "ring 0.17.7", + "ring 0.17.8", "untrusted 0.9.0", ] @@ -2974,9 +2966,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.16" +version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" [[package]] name = "same-file" @@ -3008,7 +3000,7 @@ version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ - "ring 0.17.7", + "ring 0.17.8", "untrusted 0.9.0", ] @@ -3037,9 +3029,9 @@ dependencies = [ [[package]] name = "semver" -version = "1.0.21" +version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" [[package]] name = "seq-macro" @@ -3049,29 +3041,29 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.196" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "870026e60fa08c69f064aa766c10f10b1d62db9ccd4d0abb206472bee0ce3b32" +checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.196" +version = "1.0.197" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33c85360c95e7d137454dc81d9a4ed2b8efd8fbe19cee57357b32b9771fccb67" +checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] name = "serde_json" -version = "1.0.113" +version = "1.0.114" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69801b70b1c3dac963ecb03a364ba0ceda9cf60c71cfe475e99864759c8b8a79" +checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0" dependencies = [ "itoa", "ryu", @@ -3199,7 +3191,7 @@ checksum = "01b2e185515564f15375f593fb966b5718bc624ba77fe49fa4616ad619690554" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -3245,7 +3237,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -3258,7 +3250,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -3280,9 +3272,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.49" +version = "2.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "915aea9e586f80826ee59f8453c1101f9d1c4b3964cd2460185ee8e299ada496" +checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb" dependencies = [ "proc-macro2", "quote", @@ -3345,9 +3337,9 @@ checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" [[package]] name = "textwrap" -version = "0.16.0" +version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d" +checksum = "23d434d3f8967a09480fb04132ebe0a3e088c173e6d0ee7897abbdf4eab0f8b9" [[package]] name = "thiserror" @@ -3366,7 +3358,7 @@ checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -3461,7 +3453,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -3558,7 +3550,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -3603,7 +3595,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] @@ -3626,9 +3618,9 @@ checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-normalization" -version = "0.1.22" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c5713f0fc4b5db668a2ac63cdb7bb4469d8c9fed047b1d0292cc7b0ce2ba921" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" dependencies = [ "tinyvec", ] @@ -3757,7 +3749,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", "wasm-bindgen-shared", ] @@ -3791,7 +3783,7 @@ checksum = "642f325be6301eb8107a83d12a8ac6c1e1c54345a7ef1a9261962dfefda09e66" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3831,7 +3823,7 @@ version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" dependencies = [ - "ring 0.17.7", + "ring 0.17.8", "untrusted 0.9.0", ] @@ -4049,7 +4041,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.49", + "syn 2.0.50", ] [[package]] diff --git a/datafusion-examples/examples/to_char.rs b/datafusion-examples/examples/to_char.rs index e99f69fbcd55..ef616d72cc1c 100644 --- a/datafusion-examples/examples/to_char.rs +++ b/datafusion-examples/examples/to_char.rs @@ -125,14 +125,14 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+------------+", - "| t.values |", - "+------------+", - "| 2020-09-01 |", - "| 2020-09-02 |", - "| 2020-09-03 |", - "| 2020-09-04 |", - "+------------+", + "+-----------------------------------+", + "| arrow_cast(t.values,Utf8(\"Utf8\")) |", + "+-----------------------------------+", + "| 2020-09-01 |", + "| 2020-09-02 |", + "| 2020-09-03 |", + "| 2020-09-04 |", + "+-----------------------------------+", ], &result ); @@ -146,11 +146,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+-----------------------------------------------------------------+", - "| to_char(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"%d-%m-%Y %H:%M:%S\")) |", - "+-----------------------------------------------------------------+", - "| 03-08-2023 14:38:50 |", - "+-----------------------------------------------------------------+", + "+-------------------------------------------------------------------------------------------------------------+", + "| to_char(arrow_cast(Utf8(\"2023-08-03 14:38:50Z\"),Utf8(\"Timestamp(Second, None)\")),Utf8(\"%d-%m-%Y %H:%M:%S\")) |", + "+-------------------------------------------------------------------------------------------------------------+", + "| 03-08-2023 14:38:50 |", + "+-------------------------------------------------------------------------------------------------------------+", ], &result ); @@ -165,11 +165,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+---------------------------------------+", - "| to_char(Int64(123456),Utf8(\"pretty\")) |", - "+---------------------------------------+", - "| 1 days 10 hours 17 mins 36 secs |", - "+---------------------------------------+", + "+----------------------------------------------------------------------------+", + "| to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"pretty\")) |", + "+----------------------------------------------------------------------------+", + "| 1 days 10 hours 17 mins 36 secs |", + "+----------------------------------------------------------------------------+", ], &result ); @@ -184,11 +184,11 @@ async fn main() -> Result<()> { assert_batches_eq!( &[ - "+----------------------------------------+", - "| to_char(Int64(123456),Utf8(\"iso8601\")) |", - "+----------------------------------------+", - "| PT123456S |", - "+----------------------------------------+", + "+-----------------------------------------------------------------------------+", + "| to_char(arrow_cast(Int64(123456),Utf8(\"Duration(Second)\")),Utf8(\"iso8601\")) |", + "+-----------------------------------------------------------------------------+", + "| PT123456S |", + "+-----------------------------------------------------------------------------+", ], &result ); diff --git a/datafusion/core/tests/user_defined/user_defined_aggregates.rs b/datafusion/core/tests/user_defined/user_defined_aggregates.rs index 0b29ad10d670..76e743589ff7 100644 --- a/datafusion/core/tests/user_defined/user_defined_aggregates.rs +++ b/datafusion/core/tests/user_defined/user_defined_aggregates.rs @@ -185,11 +185,11 @@ async fn test_udaf_shadows_builtin_fn() { // compute with builtin `sum` aggregator let expected = [ - "+-------------+", - "| SUM(t.time) |", - "+-------------+", - "| 19000 |", - "+-------------+", + "+---------------------------------------+", + "| SUM(arrow_cast(t.time,Utf8(\"Int64\"))) |", + "+---------------------------------------+", + "| 19000 |", + "+---------------------------------------+", ]; assert_batches_eq!(expected, &execute(&ctx, sql).await.unwrap()); diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml index f63f18f955de..2c0565d66490 100644 --- a/datafusion/functions/Cargo.toml +++ b/datafusion/functions/Cargo.toml @@ -46,6 +46,7 @@ path = "src/lib.rs" [dependencies] arrow = { workspace = true } +arrow-schema = { workspace = true } base64 = { version = "0.21", optional = true } datafusion-common = { workspace = true } datafusion-execution = { workspace = true } diff --git a/datafusion/sql/src/expr/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs similarity index 88% rename from datafusion/sql/src/expr/arrow_cast.rs rename to datafusion/functions/src/core/arrow_cast.rs index 9a0d61f41c01..adecc96a3ba6 100644 --- a/datafusion/sql/src/expr/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -19,16 +19,65 @@ //! casting to arbitrary arrow types (rather than SQL types) use std::{fmt::Display, iter::Peekable, str::Chars, sync::Arc}; +use std::any::Any; +use arrow::compute::cast; use arrow_schema::{DataType, Field, IntervalUnit, TimeUnit}; -use datafusion_common::{ - plan_datafusion_err, DFSchema, DataFusionError, Result, ScalarValue, -}; +use datafusion_common::{plan_datafusion_err, DataFusionError, Result, ScalarValue, ExprSchema, plan_err}; -use datafusion_common::plan_err; -use datafusion_expr::{Expr, ExprSchemable}; +use datafusion_expr::{ColumnarValue, Expr, ScalarUDFImpl, Signature, Volatility}; -pub const ARROW_CAST_NAME: &str = "arrow_cast"; +#[derive(Debug)] +pub(super) struct ArrowCastFunc { + signature: Signature, +} + +impl ArrowCastFunc { + pub fn new() -> Self{ + Self { + signature: + Signature::any(2, Volatility::Immutable) + } + } + +} + +impl ScalarUDFImpl for ArrowCastFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "arrow_cast" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + parse_data_type(&arg_types[1].to_string()) + } + + fn return_type_from_exprs(&self, args: &[Expr], _schema: &dyn ExprSchema) -> Result { + if args.len() != 2 { + return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); + } + let arg1 = args.get(1); + + match arg1 { + Some(Expr::Literal(ScalarValue::Utf8(Some(arg1)))) => { + let data_type = parse_data_type(arg1)?; + Ok(data_type) + } + _ => plan_err!("arrow_cast requires its second argument to be a constant string, got {:?}",arg1.unwrap().to_string()) + } + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + create_arrow_cast(args) + } +} /// Create an [`Expr`] that evaluates the `arrow_cast` function /// @@ -52,26 +101,31 @@ pub const ARROW_CAST_NAME: &str = "arrow_cast"; /// select arrow_cast(column_x, 'Float64') /// ``` /// [`BuiltinScalarFunction`]: datafusion_expr::BuiltinScalarFunction -pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result { +fn create_arrow_cast(args: &[ColumnarValue]) -> Result { if args.len() != 2 { return plan_err!("arrow_cast needs 2 arguments, {} provided", args.len()); } - let arg1 = args.pop().unwrap(); - let arg0 = args.pop().unwrap(); - - // arg1 must be a string - let data_type_string = if let Expr::Literal(ScalarValue::Utf8(Some(v))) = arg1 { - v - } else { - return plan_err!( - "arrow_cast requires its second argument to be a constant string, got {arg1}" - ); - }; + let arg1 = &args[1]; + let arg0 = &args[0]; - // do the actual lookup to the appropriate data type - let data_type = parse_data_type(&data_type_string)?; - - arg0.cast_to(&data_type, schema) + match (arg0, arg1) { + (ColumnarValue::Scalar(arg0),ColumnarValue::Scalar(ScalarValue::Utf8(Some(arg1)))) =>{ + let data_type = parse_data_type(arg1)?; + match arg0.cast_to(&data_type) { + Ok(val0) => Ok(ColumnarValue::Scalar(val0)), + Err(error) => plan_err!("{:?}",error.to_string()), + } + } + (ColumnarValue::Array(arg0),ColumnarValue::Scalar(ScalarValue::Utf8(Some(arg1)))) =>{ + let data_type = parse_data_type( arg1)?; + match cast(&arg0,&data_type) { + Ok(val0) => Ok(ColumnarValue::Array(val0)), + Err(error) => plan_err!("{:?}",error.to_string()), + } + } + (ColumnarValue::Scalar(_arg0), ColumnarValue::Scalar(arg1)) => plan_err!("arrow_cast requires its second argument to be a constant string, got {arg1}"), + _ => plan_err!("arrow_cast requires two scalar value as input. got {:?} and {:?}",arg0,arg1) + } } /// Parses `str` into a `DataType`. @@ -80,22 +134,8 @@ pub fn create_arrow_cast(mut args: Vec, schema: &DFSchema) -> Result /// impl, and maintains the invariant that /// `parse_data_type(data_type.to_string()) == data_type` /// -/// Example: -/// ``` -/// # use datafusion_sql::parse_data_type; -/// # use arrow_schema::DataType; -/// let display_value = "Int32"; -/// -/// // "Int32" is the Display value of `DataType` -/// assert_eq!(display_value, &format!("{}", DataType::Int32)); -/// -/// // parse_data_type coverts "Int32" back to `DataType`: -/// let data_type = parse_data_type(display_value).unwrap(); -/// assert_eq!(data_type, DataType::Int32); -/// ``` -/// /// Remove if added to arrow: -pub fn parse_data_type(val: &str) -> Result { +fn parse_data_type(val: &str) -> Result { Parser::new(val).parse() } @@ -647,6 +687,7 @@ impl Display for Token { #[cfg(test)] mod test { + use arrow::array::{ArrayRef, Int64Array, Int8Array}; use arrow_schema::{IntervalUnit, TimeUnit}; use super::*; @@ -847,4 +888,29 @@ mod test { println!(" Ok"); } } + + #[test] + fn test_arrow_cast_scalar() -> Result<()>{ + let input_arg0 = ColumnarValue::Scalar(ScalarValue::Int8(Some(100i8))); + let input_arg1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some("Int64".to_string()))); + let result = create_arrow_cast(&[input_arg0,input_arg1])?; + let result = result.into_array(1).expect("Failed to cast values"); + let expected = Arc::new(Int64Array::from(vec![100])) as ArrayRef; + assert_eq!(expected.as_ref(),result.as_ref()); + + Ok(()) + } + + #[test] + fn test_arrow_cast_array() -> Result<()>{ + let input_arg0 = ColumnarValue::Array(Arc::new(Int8Array::from(vec![100,101,102])) as ArrayRef); + let input_arg1 = ColumnarValue::Scalar(ScalarValue::Utf8(Some("Int64".to_string()))); + let result = create_arrow_cast(&[input_arg0,input_arg1])?; + let result = result.into_array(1).expect("Failed to cast values"); + let expected = Arc::new(Int64Array::from(vec![100,101,102])) as ArrayRef; + + assert_eq!(expected.as_ref(),result.as_ref()); + + Ok(()) + } } diff --git a/datafusion/functions/src/core/mod.rs b/datafusion/functions/src/core/mod.rs index 9aab4bd450d1..690daf966c1e 100644 --- a/datafusion/functions/src/core/mod.rs +++ b/datafusion/functions/src/core/mod.rs @@ -18,12 +18,15 @@ //! "core" DataFusion functions mod nullif; +mod arrow_cast; // create UDFs make_udf_function!(nullif::NullIfFunc, NULLIF, nullif); +make_udf_function!(arrow_cast::ArrowCastFunc, ARROW_CAST, arrow_cast); // Export the functions out of this package, both as expr_fn as well as a list of functions export_functions!( - (nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression.") + (nullif, arg_1 arg_2, "returns NULL if value1 equals value2; otherwise it returns value1. This can be used to perform the inverse operation of the COALESCE expression."), + (arrow_cast, arg_1 arg_2, "returns arg_1 parsed to the `arrow_type` given the second argument. This can be used to cast to a specific `arrow_type`.") ); diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 64b8d6957d2b..25530bd39f6c 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -32,8 +32,6 @@ use sqlparser::ast::{ }; use std::str::FromStr; -use super::arrow_cast::ARROW_CAST_NAME; - impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_function_to_expr( &self, @@ -234,12 +232,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun, args, distinct, filter, order_by, ))); }; - - // Special case arrow_cast (as its type is dependent on its argument value) - if name == ARROW_CAST_NAME { - let args = self.function_args_to_expr(args, schema, planner_context)?; - return super::arrow_cast::create_arrow_cast(args, schema); - } } // Could not find the relevant function, so return an error diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index ecf510da7bce..c62f2a1b61c8 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -pub(crate) mod arrow_cast; mod binary_op; mod function; mod grouping_set; diff --git a/datafusion/sql/src/lib.rs b/datafusion/sql/src/lib.rs index d805f61397e9..23f9a64d7c81 100644 --- a/datafusion/sql/src/lib.rs +++ b/datafusion/sql/src/lib.rs @@ -40,5 +40,4 @@ pub mod utils; mod values; pub use datafusion_common::{ResolvedTableReference, TableReference}; -pub use expr::arrow_cast::parse_data_type; pub use sqlparser; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 55551d1d25a3..0a67e962376c 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2579,7 +2579,7 @@ fn approx_median_window() { fn select_arrow_cast() { let sql = "SELECT arrow_cast(1234, 'Float64'), arrow_cast('foo', 'LargeUtf8')"; let expected = "\ - Projection: CAST(Int64(1234) AS Float64), CAST(Utf8(\"foo\") AS LargeUtf8)\ + Projection: arrow_cast(Int64(1234), Utf8(\"Float64\")), arrow_cast(Utf8(\"foo\"), Utf8(\"LargeUtf8\"))\ \n EmptyRelation"; quick_test(sql, expected); } @@ -2673,11 +2673,17 @@ fn logical_plan_with_dialect_and_options( dialect: &dyn Dialect, options: ParserOptions, ) -> Result { - let context = MockContextProvider::default().with_udf(make_udf( - "nullif", - vec![DataType::Int32, DataType::Int32], - DataType::Int32, - )); + let context = MockContextProvider::default() + .with_udf(make_udf( + "nullif", + vec![DataType::Int32, DataType::Int32], + DataType::Int32, + )) + .with_udf(make_udf( + "arrow_cast", + vec![DataType::Int64, DataType::Utf8], + DataType::Float64, + )); let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); diff --git a/datafusion/sqllogictest/test_files/arrow_typeof.slt b/datafusion/sqllogictest/test_files/arrow_typeof.slt index 8e2a091423da..cdf218c35951 100644 --- a/datafusion/sqllogictest/test_files/arrow_typeof.slt +++ b/datafusion/sqllogictest/test_files/arrow_typeof.slt @@ -95,7 +95,7 @@ SELECT arrow_cast('1', 'Int16') query error Error during planning: arrow_cast needs 2 arguments, 1 provided SELECT arrow_cast('1') -query error Error during planning: arrow_cast requires its second argument to be a constant string, got Int64\(43\) +query error DataFusion error: Error during planning: arrow_cast requires its second argument to be a constant string, got "Int64\(43\)" SELECT arrow_cast('1', 43) query error Error unrecognized word: unknown @@ -315,7 +315,7 @@ select arrow_cast(interval '30 minutes', 'Duration(Second)'); ---- 0 days 0 hours 30 mins 0 secs -query error DataFusion error: Error during planning: Cannot automatically convert Utf8 to Duration\(Second\) +query error DataFusion error: Error during planning: "Arrow error: Cast error: Casting from Utf8 to Duration\(Second\) not supported" select arrow_cast('30 minutes', 'Duration(Second)'); @@ -336,7 +336,7 @@ select arrow_cast(timestamp '2000-01-01T00:00:00Z', 'Timestamp(Nanosecond, Some( ---- 2000-01-01T00:00:00+08:00 -statement error Arrow error: Parser error: Invalid timezone "\+25:00": '\+25:00' is not a valid timezone +statement error DataFusion error: Error during planning: "Arrow error: Parser error: Invalid timezone \\"\+25:00\\": '\+25:00' is not a valid timezone" select arrow_cast(timestamp '2000-01-01T00:00:00', 'Timestamp(Nanosecond, Some( "+25:00" ))'); @@ -405,7 +405,7 @@ select arrow_cast([1], 'FixedSizeList(1, Int64)'); ---- [1] -query error DataFusion error: Arrow error: Cast error: Cannot cast to FixedSizeList\(4\): value at index 0 has length 3 +query error DataFusion error: Error during planning: "Arrow error: Cast error: Cannot cast to FixedSizeList\(4\): value at index 0 has length 3" select arrow_cast(make_array(1, 2, 3), 'FixedSizeList(4, Int64)'); query ? @@ -421,4 +421,4 @@ FixedSizeList(Field { name: "item", data_type: Int64, nullable: true, dict_id: 0 query ? select arrow_cast([1, 2, 3], 'FixedSizeList(3, Int64)'); ---- -[1, 2, 3] \ No newline at end of file +[1, 2, 3] diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt index b61bee670811..22fc744b5098 100644 --- a/datafusion/sqllogictest/test_files/clickbench.slt +++ b/datafusion/sqllogictest/test_files/clickbench.slt @@ -274,5 +274,5 @@ query PI SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000; ---- -query +statement ok drop table hits; \ No newline at end of file diff --git a/datafusion/sqllogictest/test_files/expr.slt b/datafusion/sqllogictest/test_files/expr.slt index 9e4e3aa8185d..9697c32c5f88 100644 --- a/datafusion/sqllogictest/test_files/expr.slt +++ b/datafusion/sqllogictest/test_files/expr.slt @@ -63,7 +63,7 @@ SELECT NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL NULL # test_array_cast_invalid_timezone_will_panic -statement error Parser error: Invalid timezone "Foo": 'Foo' is not a valid timezone +statement error DataFusion error: Error during planning: "Arrow error: Parser error: Invalid timezone \\"Foo\\": 'Foo' is not a valid timezone" SELECT arrow_cast('2021-01-02T03:04:00', 'Timestamp(Nanosecond, Some("Foo"))') # test_array_index