Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: ignore case when matching function name #16912

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,7 @@ tower = { version = "0.5.1", features = ["util"] }
tower-service = "0.3.3"
twox-hash = "1.6.3"
typetag = "0.2.3"
unicase = "2.8.0"
unicode-segmentation = "1.10.1"
unindent = "0.2"
url = "2.3.1"
Expand Down
1 change: 1 addition & 0 deletions src/query/functions/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ siphasher = { workspace = true }
strength_reduce = { workspace = true }
stringslice = { workspace = true }
twox-hash = { workspace = true }
unicase = { workspace = true }

[dev-dependencies]
comfy-table = { workspace = true }
Expand Down
81 changes: 44 additions & 37 deletions src/query/functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,17 @@
use aggregates::AggregateFunctionFactory;
use ctor::ctor;
use databend_common_expression::FunctionRegistry;
use unicase::Ascii;

pub mod aggregates;
mod cast_rules;
pub mod scalars;
pub mod srfs;

pub fn is_builtin_function(name: &str) -> bool {
BUILTIN_FUNCTIONS.contains(name)
|| AggregateFunctionFactory::instance().contains(name)
let name = Ascii::new(name);
BUILTIN_FUNCTIONS.contains(name.into_inner())
|| AggregateFunctionFactory::instance().contains(name.into_inner())
|| GENERAL_WINDOW_FUNCTIONS.contains(&name)
|| GENERAL_LAMBDA_FUNCTIONS.contains(&name)
|| GENERAL_SEARCH_FUNCTIONS.contains(&name)
Expand All @@ -44,53 +46,58 @@ pub fn is_builtin_function(name: &str) -> bool {
// The plan of search function, async function and udf contains some arguments defined in meta,
// which may be modified by user at any time. Those functions are not suitable for caching.
pub fn is_cacheable_function(name: &str) -> bool {
BUILTIN_FUNCTIONS.contains(name)
|| AggregateFunctionFactory::instance().contains(name)
let name = Ascii::new(name);
BUILTIN_FUNCTIONS.contains(name.into_inner())
|| AggregateFunctionFactory::instance().contains(name.into_inner())
|| GENERAL_WINDOW_FUNCTIONS.contains(&name)
|| GENERAL_LAMBDA_FUNCTIONS.contains(&name)
}

#[ctor]
pub static BUILTIN_FUNCTIONS: FunctionRegistry = builtin_functions();

pub const ASYNC_FUNCTIONS: [&str; 2] = ["nextval", "dict_get"];
/// Names of the async functions. Each entry is wrapped in `unicase::Ascii`,
/// so membership tests such as `ASYNC_FUNCTIONS.contains(&Ascii::new(name))`
/// compare names without regard to ASCII case.
pub const ASYNC_FUNCTIONS: [Ascii<&str>; 2] = [Ascii::new("nextval"), Ascii::new("dict_get")];

pub const GENERAL_WINDOW_FUNCTIONS: [&str; 13] = [
"row_number",
"rank",
"dense_rank",
"percent_rank",
"lag",
"lead",
"first_value",
"first",
"last_value",
"last",
"nth_value",
"ntile",
"cume_dist",
/// Names of the general window functions (the ones that are only legal
/// together with a window specification, as opposed to aggregate functions
/// used in window syntax).
///
/// Entries are `unicase::Ascii` wrappers, so `contains` checks against this
/// table ignore ASCII case. Use `into_inner()` to get the plain `&str` back,
/// e.g. when listing names for suggestions.
pub const GENERAL_WINDOW_FUNCTIONS: [Ascii<&str>; 13] = [
Ascii::new("row_number"),
Ascii::new("rank"),
Ascii::new("dense_rank"),
Ascii::new("percent_rank"),
Ascii::new("lag"),
Ascii::new("lead"),
Ascii::new("first_value"),
Ascii::new("first"),
Ascii::new("last_value"),
Ascii::new("last"),
Ascii::new("nth_value"),
Ascii::new("ntile"),
Ascii::new("cume_dist"),
];

pub const GENERAL_LAMBDA_FUNCTIONS: [&str; 16] = [
"array_transform",
"array_apply",
"array_map",
"array_filter",
"array_reduce",
"json_array_transform",
"json_array_apply",
"json_array_map",
"json_array_filter",
"json_array_reduce",
"map_filter",
"map_transform_keys",
"map_transform_values",
"json_map_filter",
"json_map_transform_keys",
"json_map_transform_values",
/// Names of the lambda functions, i.e. the functions that must be called
/// with a lambda expression argument.
///
/// Entries are `unicase::Ascii` wrappers, so `contains` checks against this
/// table ignore ASCII case. Use `into_inner()` to get the plain `&str` back.
pub const GENERAL_LAMBDA_FUNCTIONS: [Ascii<&str>; 16] = [
Ascii::new("array_transform"),
Ascii::new("array_apply"),
Ascii::new("array_map"),
Ascii::new("array_filter"),
Ascii::new("array_reduce"),
Ascii::new("json_array_transform"),
Ascii::new("json_array_apply"),
Ascii::new("json_array_map"),
Ascii::new("json_array_filter"),
Ascii::new("json_array_reduce"),
Ascii::new("map_filter"),
Ascii::new("map_transform_keys"),
Ascii::new("map_transform_values"),
Ascii::new("json_map_filter"),
Ascii::new("json_map_transform_keys"),
Ascii::new("json_map_transform_values"),
];

pub const GENERAL_SEARCH_FUNCTIONS: [&str; 3] = ["match", "query", "score"];
/// Names of the search functions (`match`, `query`, `score`).
///
/// Entries are `unicase::Ascii` wrappers, so `contains` checks against this
/// table ignore ASCII case. Code that dispatches on the concrete name with a
/// `match` on string literals still has to lowercase the name itself.
pub const GENERAL_SEARCH_FUNCTIONS: [Ascii<&str>; 3] = [
Ascii::new("match"),
Ascii::new("query"),
Ascii::new("score"),
];

fn builtin_functions() -> FunctionRegistry {
let mut registry = FunctionRegistry::empty();
Expand Down
1 change: 1 addition & 0 deletions src/query/sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ serde = { workspace = true }
sha2 = { workspace = true }
simsearch = { workspace = true }
tokio = { workspace = true }
unicase = { workspace = true }
url = { workspace = true }

[lints]
Expand Down
115 changes: 69 additions & 46 deletions src/query/sql/src/planner/semantic/type_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ use itertools::Itertools;
use jsonb::keypath::KeyPath;
use jsonb::keypath::KeyPaths;
use simsearch::SimSearch;
use unicase::Ascii;

use super::name_resolution::NameResolutionContext;
use super::normalize_identifier;
Expand Down Expand Up @@ -184,7 +185,7 @@ pub struct TypeChecker<'a> {
// This is used to check if there is nested aggregate function.
in_aggregate_function: bool,

// true if current expr is inside an window function.
// true if current expr is inside a window function.
// This is used to allow aggregation function in window's aggregate function.
in_window_function: bool,
forbid_udf: bool,
Expand Down Expand Up @@ -721,8 +722,9 @@ impl<'a> TypeChecker<'a> {
} => {
let func_name = normalize_identifier(name, self.name_resolution_ctx).to_string();
let func_name = func_name.as_str();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what about handling case here?

Suggested change
let func_name = func_name.as_str();
let func_name = func_name.to_lowercase();

Copy link
Author

@notauserx notauserx Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestion.

Sugar and built-in functions are matched ignoring case, which is why I created a separate function and moved the logic there. We could do the same in is_builtin_function; otherwise each caller has to convert the name to lowercase before calling these functions.

Copy link
Collaborator

@andylokandy andylokandy Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I see. Then I'd suggest replacing &str in all_sugar_functions and builtin_functions with the unicase crate: https://crates.io/crates/unicase

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've used the Ascii type from the unicase crate, a case-insensitive wrapper for ASCII strings. Is this assumption correct? If not, I can switch to the UniCase type, which provides a case-insensitive wrapper for general strings.

let uni_case_func_name = Ascii::new(func_name);
if !is_builtin_function(func_name)
&& !Self::all_sugar_functions().contains(&func_name)
&& !Self::all_sugar_functions().contains(&uni_case_func_name)
{
if let Some(udf) = self.resolve_udf(*span, func_name, args)? {
return Ok(udf);
Expand All @@ -732,15 +734,35 @@ impl<'a> TypeChecker<'a> {
.all_function_names()
.into_iter()
.chain(AggregateFunctionFactory::instance().registered_names())
.chain(GENERAL_WINDOW_FUNCTIONS.iter().cloned().map(str::to_string))
.chain(GENERAL_LAMBDA_FUNCTIONS.iter().cloned().map(str::to_string))
.chain(GENERAL_SEARCH_FUNCTIONS.iter().cloned().map(str::to_string))
.chain(ASYNC_FUNCTIONS.iter().cloned().map(str::to_string))
.chain(
GENERAL_WINDOW_FUNCTIONS
.iter()
.cloned()
.map(|ascii| ascii.into_inner().to_string()),
)
.chain(
GENERAL_LAMBDA_FUNCTIONS
.iter()
.cloned()
.map(|ascii| ascii.into_inner().to_string()),
)
.chain(
GENERAL_SEARCH_FUNCTIONS
.iter()
.cloned()
.map(|ascii| ascii.into_inner().to_string()),
)
.chain(
ASYNC_FUNCTIONS
.iter()
.cloned()
.map(|ascii| ascii.into_inner().to_string()),
)
.chain(
Self::all_sugar_functions()
.iter()
.cloned()
.map(str::to_string),
.map(|ascii| ascii.into_inner().to_string()),
);
let mut engine: SimSearch<String> = SimSearch::new();
for func_name in all_funcs {
Expand Down Expand Up @@ -769,15 +791,15 @@ impl<'a> TypeChecker<'a> {
// check window function legal
if window.is_some()
&& !AggregateFunctionFactory::instance().contains(func_name)
&& !GENERAL_WINDOW_FUNCTIONS.contains(&func_name)
&& !GENERAL_WINDOW_FUNCTIONS.contains(&uni_case_func_name)
{
return Err(ErrorCode::SemanticError(
"only window and aggregate functions allowed in window syntax",
)
.set_span(*span));
}
// check lambda function legal
if lambda.is_some() && !GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) {
if lambda.is_some() && !GENERAL_LAMBDA_FUNCTIONS.contains(&uni_case_func_name) {
return Err(ErrorCode::SemanticError(
"only lambda functions allowed in lambda syntax",
)
Expand All @@ -786,7 +808,7 @@ impl<'a> TypeChecker<'a> {

let args: Vec<&Expr> = args.iter().collect();

if GENERAL_WINDOW_FUNCTIONS.contains(&func_name) {
if GENERAL_WINDOW_FUNCTIONS.contains(&uni_case_func_name) {
// general window function
if window.is_none() {
return Err(ErrorCode::SemanticError(format!(
Expand Down Expand Up @@ -852,7 +874,7 @@ impl<'a> TypeChecker<'a> {
// aggregate function
Box::new((new_agg_func.into(), data_type))
}
} else if GENERAL_LAMBDA_FUNCTIONS.contains(&func_name) {
} else if GENERAL_LAMBDA_FUNCTIONS.contains(&uni_case_func_name) {
if lambda.is_none() {
return Err(ErrorCode::SemanticError(format!(
"function {func_name} must have a lambda expression",
Expand All @@ -861,8 +883,8 @@ impl<'a> TypeChecker<'a> {
}
let lambda = lambda.as_ref().unwrap();
self.resolve_lambda_function(*span, func_name, &args, lambda)?
} else if GENERAL_SEARCH_FUNCTIONS.contains(&func_name) {
match func_name {
} else if GENERAL_SEARCH_FUNCTIONS.contains(&uni_case_func_name) {
match func_name.to_lowercase().as_str() {
"score" => self.resolve_score_search_function(*span, func_name, &args)?,
"match" => self.resolve_match_search_function(*span, func_name, &args)?,
"query" => self.resolve_query_search_function(*span, func_name, &args)?,
Expand All @@ -874,7 +896,7 @@ impl<'a> TypeChecker<'a> {
.set_span(*span));
}
}
} else if ASYNC_FUNCTIONS.contains(&func_name) {
} else if ASYNC_FUNCTIONS.contains(&uni_case_func_name) {
self.resolve_async_function(*span, func_name, &args)?
} else if BUILTIN_FUNCTIONS
.get_property(func_name)
Expand Down Expand Up @@ -1436,7 +1458,7 @@ impl<'a> TypeChecker<'a> {
self.in_window_function = false;

// If { IGNORE | RESPECT } NULLS is not specified, the default is RESPECT NULLS
// (i.e. a NULL value will be returned if the expression contains a NULL value and it is the first value in the expression).
// (i.e. a NULL value will be returned if the expression contains a NULL value, and it is the first value in the expression).
let ignore_null = if let Some(ignore_null) = window_ignore_null {
*ignore_null
} else {
Expand Down Expand Up @@ -2081,7 +2103,7 @@ impl<'a> TypeChecker<'a> {
param_count: usize,
span: Span,
) -> Result<()> {
// json lambda functions are casted to array or map, ignored here.
// json lambda functions are cast to array or map, ignored here.
let expected_count = if func_name == "array_reduce" {
2
} else if func_name.starts_with("array") {
Expand Down Expand Up @@ -3121,36 +3143,37 @@ impl<'a> TypeChecker<'a> {
Ok(Box::new((subquery_expr.into(), data_type)))
}

pub fn all_sugar_functions() -> &'static [&'static str] {
&[
"database",
"currentdatabase",
"current_database",
"version",
"user",
"currentuser",
"current_user",
"current_role",
"connection_id",
"timezone",
"nullif",
"ifnull",
"nvl",
"nvl2",
"is_null",
"is_error",
"error_or",
"coalesce",
"last_query_id",
"array_sort",
"array_aggregate",
"to_variant",
"try_to_variant",
"greatest",
"least",
"stream_has_data",
"getvariable",
]
/// Returns the names that the type checker treats as sugar functions.
///
/// The function-call resolution path consults this list, together with
/// `is_builtin_function`, before falling back to UDF resolution, and also
/// includes these names when searching for similarly named functions.
///
/// Entries are `unicase::Ascii` wrappers, so `contains` checks against the
/// returned slice ignore ASCII case. Call `into_inner()` on an entry to get
/// the plain `&str`.
pub fn all_sugar_functions() -> &'static [Ascii<&'static str>] {
static FUNCTIONS: &[Ascii<&'static str>] = &[
Ascii::new("database"),
Ascii::new("currentdatabase"),
Ascii::new("current_database"),
Ascii::new("version"),
Ascii::new("user"),
Ascii::new("currentuser"),
Ascii::new("current_user"),
Ascii::new("current_role"),
Ascii::new("connection_id"),
Ascii::new("timezone"),
Ascii::new("nullif"),
Ascii::new("ifnull"),
Ascii::new("nvl"),
Ascii::new("nvl2"),
Ascii::new("is_null"),
Ascii::new("is_error"),
Ascii::new("error_or"),
Ascii::new("coalesce"),
Ascii::new("last_query_id"),
Ascii::new("array_sort"),
Ascii::new("array_aggregate"),
Ascii::new("to_variant"),
Ascii::new("try_to_variant"),
Ascii::new("greatest"),
Ascii::new("least"),
Ascii::new("stream_has_data"),
Ascii::new("getvariable"),
];
FUNCTIONS
}

fn try_rewrite_sugar_function(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ select * from Student
statement ok
set unquoted_ident_case_sensitive = 1

statement ok
SELECT VERSION()

statement error (?s)1025,.*Unknown table `default`\.`default`\.student \.
INSERT INTO student VALUES(1)

Expand Down