diff --git a/Cargo.lock b/Cargo.lock index e3301c4633db..a438e43023c3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3415,6 +3415,7 @@ dependencies = [ "strength_reduce", "stringslice", "twox-hash", + "unicase", ] [[package]] @@ -4059,6 +4060,7 @@ dependencies = [ "sha2", "simsearch", "tokio", + "unicase", "url", ] @@ -14944,6 +14946,12 @@ dependencies = [ "version_check", ] +[[package]] +name = "unicase" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e51b68083f157f853b6379db119d1c1be0e6e4dec98101079dec41f6f5cf6df" + [[package]] name = "unicode-bidi" version = "0.3.15" diff --git a/Cargo.toml b/Cargo.toml index 87be2fb47c36..b32d0f5b968d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -484,6 +484,7 @@ tower = { version = "0.5.1", features = ["util"] } tower-service = "0.3.3" twox-hash = "1.6.3" typetag = "0.2.3" +unicase = "2.8.0" unicode-segmentation = "1.10.1" unindent = "0.2" url = "2.3.1" diff --git a/src/query/functions/Cargo.toml b/src/query/functions/Cargo.toml index 60fcc19f1fb1..c02c7ad6a898 100644 --- a/src/query/functions/Cargo.toml +++ b/src/query/functions/Cargo.toml @@ -61,6 +61,7 @@ siphasher = { workspace = true } strength_reduce = { workspace = true } stringslice = { workspace = true } twox-hash = { workspace = true } +unicase = { workspace = true } [dev-dependencies] comfy-table = { workspace = true } diff --git a/src/query/functions/src/lib.rs b/src/query/functions/src/lib.rs index 0710fdcedd42..e32b6895b666 100644 --- a/src/query/functions/src/lib.rs +++ b/src/query/functions/src/lib.rs @@ -26,6 +26,7 @@ use aggregates::AggregateFunctionFactory; use ctor::ctor; use databend_common_expression::FunctionRegistry; +use unicase::Ascii; pub mod aggregates; mod cast_rules; @@ -33,8 +34,9 @@ pub mod scalars; pub mod srfs; pub fn is_builtin_function(name: &str) -> bool { - BUILTIN_FUNCTIONS.contains(name) - || AggregateFunctionFactory::instance().contains(name) + let name = Ascii::new(name); + BUILTIN_FUNCTIONS.contains(name.into_inner()) + || AggregateFunctionFactory::instance().contains(name.into_inner()) || GENERAL_WINDOW_FUNCTIONS.contains(&name) || GENERAL_LAMBDA_FUNCTIONS.contains(&name) || GENERAL_SEARCH_FUNCTIONS.contains(&name) @@ -44,8 +46,9 @@ pub fn is_builtin_function(name: &str) -> bool { // The plan of search function, async function and udf contains some arguments defined in meta, // which may be modified by user at any time. Those functions are not not suitable for caching. pub fn is_cacheable_function(name: &str) -> bool { - BUILTIN_FUNCTIONS.contains(name) - || AggregateFunctionFactory::instance().contains(name) + let name = Ascii::new(name); + BUILTIN_FUNCTIONS.contains(name.into_inner()) + || AggregateFunctionFactory::instance().contains(name.into_inner()) || GENERAL_WINDOW_FUNCTIONS.contains(&name) || GENERAL_LAMBDA_FUNCTIONS.contains(&name) } @@ -53,44 +56,48 @@ pub fn is_cacheable_function(name: &str) -> bool { #[ctor] pub static BUILTIN_FUNCTIONS: FunctionRegistry = builtin_functions(); -pub const ASYNC_FUNCTIONS: [&str; 2] = ["nextval", "dict_get"]; +pub const ASYNC_FUNCTIONS: [Ascii<&str>; 2] = [Ascii::new("nextval"), Ascii::new("dict_get")]; -pub const GENERAL_WINDOW_FUNCTIONS: [&str; 13] = [ - "row_number", - "rank", - "dense_rank", - "percent_rank", - "lag", - "lead", - "first_value", - "first", - "last_value", - "last", - "nth_value", - "ntile", - "cume_dist", +pub const GENERAL_WINDOW_FUNCTIONS: [Ascii<&str>; 13] = [ + Ascii::new("row_number"), + Ascii::new("rank"), + Ascii::new("dense_rank"), + Ascii::new("percent_rank"), + Ascii::new("lag"), + Ascii::new("lead"), + Ascii::new("first_value"), + Ascii::new("first"), + Ascii::new("last_value"), + Ascii::new("last"), + Ascii::new("nth_value"), + Ascii::new("ntile"), + Ascii::new("cume_dist"), ]; -pub const GENERAL_LAMBDA_FUNCTIONS: [&str; 16] = [ - "array_transform", - "array_apply", - "array_map", - "array_filter", - "array_reduce", - "json_array_transform", - "json_array_apply", - "json_array_map", - "json_array_filter", - "json_array_reduce", - "map_filter", - "map_transform_keys", - "map_transform_values", - "json_map_filter", - "json_map_transform_keys", - "json_map_transform_values", +pub const GENERAL_LAMBDA_FUNCTIONS: [Ascii<&str>; 16] = [ + Ascii::new("array_transform"), + Ascii::new("array_apply"), + Ascii::new("array_map"), + Ascii::new("array_filter"), + Ascii::new("array_reduce"), + Ascii::new("json_array_transform"), + Ascii::new("json_array_apply"), + Ascii::new("json_array_map"), + Ascii::new("json_array_filter"), + Ascii::new("json_array_reduce"), + Ascii::new("map_filter"), + Ascii::new("map_transform_keys"), + Ascii::new("map_transform_values"), + Ascii::new("json_map_filter"), + Ascii::new("json_map_transform_keys"), + Ascii::new("json_map_transform_values"), ]; -pub const GENERAL_SEARCH_FUNCTIONS: [&str; 3] = ["match", "query", "score"]; +pub const GENERAL_SEARCH_FUNCTIONS: [Ascii<&str>; 3] = [ + Ascii::new("match"), + Ascii::new("query"), + Ascii::new("score"), +]; fn builtin_functions() -> FunctionRegistry { let mut registry = FunctionRegistry::empty(); diff --git a/src/query/sql/Cargo.toml b/src/query/sql/Cargo.toml index 1f017a82dac1..6bd8beef2b18 100644 --- a/src/query/sql/Cargo.toml +++ b/src/query/sql/Cargo.toml @@ -74,6 +74,7 @@ serde = { workspace = true } sha2 = { workspace = true } simsearch = { workspace = true } tokio = { workspace = true } +unicase = { workspace = true } url = { workspace = true } [lints] diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs index 38350db16026..ad3bf18426de 100644 --- a/src/query/sql/src/planner/semantic/type_check.rs +++ b/src/query/sql/src/planner/semantic/type_check.rs @@ -106,6 +106,7 @@ use itertools::Itertools; use jsonb::keypath::KeyPath; use jsonb::keypath::KeyPaths; use simsearch::SimSearch; +use unicase::Ascii; use super::name_resolution::NameResolutionContext; use super::normalize_identifier; @@ -184,7 +185,7 @@ pub struct TypeChecker<'a> { // This is used to check if there is nested aggregate function. in_aggregate_function: bool, - // true if current expr is inside an window function. + // true if current expr is inside a window function. // This is used to allow aggregation function in window's aggregate function. in_window_function: bool, forbid_udf: bool, @@ -721,8 +722,9 @@ impl<'a> TypeChecker<'a> { } => { let func_name = normalize_identifier(name, self.name_resolution_ctx).to_string(); let func_name = func_name.as_str(); + let uni_case_func_name = Ascii::new(func_name); if !is_builtin_function(func_name) - && !Self::all_sugar_functions().contains(&func_name) + && !Self::all_sugar_functions().contains(&uni_case_func_name) { if let Some(udf) = self.resolve_udf(*span, func_name, args)? { return Ok(udf); @@ -732,15 +734,35 @@ impl<'a> TypeChecker<'a> { .all_function_names() .into_iter() .chain(AggregateFunctionFactory::instance().registered_names()) - .chain(GENERAL_WINDOW_FUNCTIONS.iter().cloned().map(str::to_string)) - .chain(GENERAL_LAMBDA_FUNCTIONS.iter().cloned().map(str::to_string)) - .chain(GENERAL_SEARCH_FUNCTIONS.iter().cloned().map(str::to_string)) - .chain(ASYNC_FUNCTIONS.iter().cloned().map(str::to_string)) + .chain( + GENERAL_WINDOW_FUNCTIONS + .iter() + .cloned() + .map(|ascii| ascii.into_inner().to_string()), + ) + .chain( + GENERAL_LAMBDA_FUNCTIONS + .iter() + .cloned() + .map(|ascii| ascii.into_inner().to_string()), + ) + .chain( + GENERAL_SEARCH_FUNCTIONS + .iter() + .cloned() + .map(|ascii| ascii.into_inner().to_string()), + ) + .chain( + ASYNC_FUNCTIONS + .iter() + .cloned() + .map(|ascii| ascii.into_inner().to_string()), + ) .chain( Self::all_sugar_functions() .iter() .cloned() - .map(str::to_string), + .map(|ascii| ascii.into_inner().to_string()), ); let mut engine: SimSearch = SimSearch::new(); for func_name in all_funcs { @@ -769,7 +791,7 @@ impl<'a> TypeChecker<'a> { // check window function legal if window.is_some() && !AggregateFunctionFactory::instance().contains(func_name) - && !GENERAL_WINDOW_FUNCTIONS.contains(&func_name) + && !GENERAL_WINDOW_FUNCTIONS.contains(&uni_case_func_name) { return Err(ErrorCode::SemanticError( "only window and aggregate functions allowed in window syntax", @@ -777,7 +799,7 @@ impl<'a> TypeChecker<'a> { .set_span(*span)); } // check lambda function legal - if lambda.is_some() && !GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) { + if lambda.is_some() && !GENERAL_LAMBDA_FUNCTIONS.contains(&uni_case_func_name) { return Err(ErrorCode::SemanticError( "only lambda functions allowed in lambda syntax", ) @@ -786,7 +808,7 @@ impl<'a> TypeChecker<'a> { let args: Vec<&Expr> = args.iter().collect(); - if GENERAL_WINDOW_FUNCTIONS.contains(&func_name) { + if GENERAL_WINDOW_FUNCTIONS.contains(&uni_case_func_name) { // general window function if window.is_none() { return Err(ErrorCode::SemanticError(format!( @@ -852,7 +874,7 @@ impl<'a> TypeChecker<'a> { // aggregate function Box::new((new_agg_func.into(), data_type)) } - } else if GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) { + } else if GENERAL_LAMBDA_FUNCTIONS.contains(&uni_case_func_name) { if lambda.is_none() { return Err(ErrorCode::SemanticError(format!( "function {func_name} must have a lambda expression", @@ -861,8 +883,8 @@ impl<'a> TypeChecker<'a> { } let lambda = lambda.as_ref().unwrap(); self.resolve_lambda_function(*span, func_name, &args, lambda)? - } else if GENERAL_SEARCH_FUNCTIONS.contains(&func_name) { - match func_name { + } else if GENERAL_SEARCH_FUNCTIONS.contains(&uni_case_func_name) { + match func_name.to_lowercase().as_str() { "score" => self.resolve_score_search_function(*span, func_name, &args)?, "match" => self.resolve_match_search_function(*span, func_name, &args)?, "query" => self.resolve_query_search_function(*span, func_name, &args)?, @@ -874,7 +896,7 @@ impl<'a> TypeChecker<'a> { .set_span(*span)); } } - } else if ASYNC_FUNCTIONS.contains(&func_name) { + } else if ASYNC_FUNCTIONS.contains(&uni_case_func_name) { self.resolve_async_function(*span, func_name, &args)? } else if BUILTIN_FUNCTIONS .get_property(func_name) @@ -1436,7 +1458,7 @@ impl<'a> TypeChecker<'a> { self.in_window_function = false; // If { IGNORE | RESPECT } NULLS is not specified, the default is RESPECT NULLS - // (i.e. a NULL value will be returned if the expression contains a NULL value and it is the first value in the expression). + // (i.e. a NULL value will be returned if the expression contains a NULL value, and it is the first value in the expression). let ignore_null = if let Some(ignore_null) = window_ignore_null { *ignore_null } else { @@ -2081,7 +2103,7 @@ impl<'a> TypeChecker<'a> { param_count: usize, span: Span, ) -> Result<()> { - // json lambda functions are casted to array or map, ignored here. + // json lambda functions are cast to array or map, ignored here. let expected_count = if func_name == "array_reduce" { 2 } else if func_name.starts_with("array") { @@ -3121,36 +3143,37 @@ impl<'a> TypeChecker<'a> { Ok(Box::new((subquery_expr.into(), data_type))) } - pub fn all_sugar_functions() -> &'static [&'static str] { - &[ - "database", - "currentdatabase", - "current_database", - "version", - "user", - "currentuser", - "current_user", - "current_role", - "connection_id", - "timezone", - "nullif", - "ifnull", - "nvl", - "nvl2", - "is_null", - "is_error", - "error_or", - "coalesce", - "last_query_id", - "array_sort", - "array_aggregate", - "to_variant", - "try_to_variant", - "greatest", - "least", - "stream_has_data", - "getvariable", - ] + pub fn all_sugar_functions() -> &'static [Ascii<&'static str>] { + static FUNCTIONS: &[Ascii<&'static str>] = &[ + Ascii::new("database"), + Ascii::new("currentdatabase"), + Ascii::new("current_database"), + Ascii::new("version"), + Ascii::new("user"), + Ascii::new("currentuser"), + Ascii::new("current_user"), + Ascii::new("current_role"), + Ascii::new("connection_id"), + Ascii::new("timezone"), + Ascii::new("nullif"), + Ascii::new("ifnull"), + Ascii::new("nvl"), + Ascii::new("nvl2"), + Ascii::new("is_null"), + Ascii::new("is_error"), + Ascii::new("error_or"), + Ascii::new("coalesce"), + Ascii::new("last_query_id"), + Ascii::new("array_sort"), + Ascii::new("array_aggregate"), + Ascii::new("to_variant"), + Ascii::new("try_to_variant"), + Ascii::new("greatest"), + Ascii::new("least"), + Ascii::new("stream_has_data"), + Ascii::new("getvariable"), + ]; + FUNCTIONS } fn try_rewrite_sugar_function( diff --git a/tests/sqllogictests/suites/query/case_sensitivity/name_hit.test b/tests/sqllogictests/suites/query/case_sensitivity/name_hit.test index 8a8426a8f442..1c94b63b9687 100644 --- a/tests/sqllogictests/suites/query/case_sensitivity/name_hit.test +++ b/tests/sqllogictests/suites/query/case_sensitivity/name_hit.test @@ -28,6 +28,9 @@ select * from Student statement ok set unquoted_ident_case_sensitive = 1 +statement ok +SELECT VERSION() + statement error (?s)1025,.*Unknown table `default`\.`default`\.student \. INSERT INTO student VALUES(1)