Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(functions): add new function: map_pick #15573

Merged
merged 10 commits into from
Nov 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 189 additions & 50 deletions src/query/functions/src/scalars/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@ pub fn register(registry: &mut FunctionRegistry) {
vectorize_with_builder_2_arg::<ArrayType<GenericType<0>>, ArrayType<GenericType<1>>, MapType<GenericType<0>, GenericType<1>>>(
|keys, vals, output, ctx| {
let key_type = &ctx.generics[0];
if !key_type.is_boolean()
&& !key_type.is_string()
&& !key_type.is_numeric()
&& !key_type.is_decimal()
&& !key_type.is_date_or_date_time() {
if !check_valid_map_key_type(key_type) {
ctx.set_error(output.len(), format!("map keys can not be {}", key_type));
} else if keys.len() != vals.len() {
ctx.set_error(output.len(), format!(
Expand Down Expand Up @@ -241,43 +237,7 @@ pub fn register(registry: &mut FunctionRegistry) {
);

registry.register_function_factory("map_delete", |_, args_type| {
if args_type.len() < 2 {
return None;
}

let map_key_type = match args_type[0].remove_nullable() {
DataType::Map(box DataType::Tuple(type_tuple)) if type_tuple.len() == 2 => {
Some(type_tuple[0].clone())
}
DataType::EmptyMap => None,
_ => return None,
};

if let Some(map_key_type) = map_key_type {
for arg_type in args_type.iter().skip(1) {
if arg_type != &map_key_type {
return None;
}
}
} else {
let key_type = &args_type[1];
if !key_type.is_boolean()
&& !key_type.is_string()
&& !key_type.is_numeric()
&& !key_type.is_decimal()
&& !key_type.is_date_or_date_time()
{
return None;
}
for arg_type in args_type.iter().skip(2) {
if arg_type != key_type {
return None;
}
}
}

let return_type = args_type[0].clone();

let return_type = check_map_arg_types(args_type)?;
Some(Arc::new(Function {
signature: FunctionSignature {
name: "map_delete".to_string(),
Expand All @@ -297,31 +257,47 @@ pub fn register(registry: &mut FunctionRegistry) {
let mut output_map_builder =
ColumnBuilder::with_capacity(&return_type, input_length.unwrap_or(1));

let mut delete_key_list = HashSet::new();
for idx in 0..(input_length.unwrap_or(1)) {
let input_map_sref = match &args[0] {
let input_map = match &args[0] {
ValueRef::Scalar(map) => map.clone(),
ValueRef::Column(map) => unsafe { map.index_unchecked(idx) },
};

match &input_map_sref {
match &input_map {
ScalarRef::Null | ScalarRef::EmptyMap => {
output_map_builder.push_default();
}
ScalarRef::Map(col) => {
let mut delete_key_list = HashSet::new();

delete_key_list.clear();
for input_key_item in args.iter().skip(1) {
let input_key = match &input_key_item {
ValueRef::Scalar(scalar) => scalar.clone(),
ValueRef::Column(col) => unsafe {
col.index_unchecked(idx)
},
};

delete_key_list.insert(input_key.to_owned());
match input_key {
ScalarRef::EmptyArray | ScalarRef::Null => {}
ScalarRef::Array(arr_col) => {
for arr_key in arr_col.iter() {
if arr_key == ScalarRef::Null {
continue;
}
delete_key_list.insert(arr_key.to_owned());
}
}
_ => {
delete_key_list.insert(input_key.to_owned());
}
}
}
if delete_key_list.is_empty() {
output_map_builder.push(input_map);
continue;
}

let inner_builder_type = match input_map_sref.infer_data_type() {
let inner_builder_type = match input_map.infer_data_type() {
DataType::Map(box typ) => typ,
_ => unreachable!(),
};
Expand All @@ -330,7 +306,7 @@ pub fn register(registry: &mut FunctionRegistry) {
ColumnBuilder::with_capacity(&inner_builder_type, col.len());

let input_map: KvColumn<AnyType, AnyType> =
MapType::try_downcast_scalar(&input_map_sref).unwrap();
MapType::try_downcast_scalar(&input_map).unwrap();

input_map.iter().for_each(|(map_key, map_value)| {
if !delete_key_list.contains(&map_key.to_owned()) {
Expand Down Expand Up @@ -371,4 +347,167 @@ pub fn register(registry: &mut FunctionRegistry) {
.any(|(k, _)| k == key)
},
);

// Registers the `map_pick` function factory.
//
// The factory only yields a function when `check_map_arg_types` accepts the argument types
// (the `?` below turns a rejection into `None`). The produced function keeps, for every input
// map, only the entries whose key appears among the key arguments; keys may be passed either
// as individual arguments or as elements of an array argument.
registry.register_function_factory("map_pick", |_, args_type: &[DataType]| {
// On success this is a clone of the first argument's type; it is used to build the output.
let return_type = check_map_arg_types(args_type)?;
Some(Arc::new(Function {
signature: FunctionSignature {
name: "map_pick".to_string(),
args_type: args_type.to_vec(),
return_type: args_type[0].clone(),
},
eval: FunctionEval::Scalar {
// The output domain is reported as the domain of the input map argument.
calc_domain: Box::new(|_, args_domain| {
FunctionDomain::Domain(args_domain[0].clone())
}),
eval: Box::new(move |args, _ctx| {
// Length of the first column argument found; `None` when every argument is a scalar.
let input_length = args.iter().find_map(|arg| match arg {
ValueRef::Column(col) => Some(col.len()),
_ => None,
});

let mut output_map_builder =
ColumnBuilder::with_capacity(&return_type, input_length.unwrap_or(1));

// Allocated once and cleared for each map row instead of being rebuilt per row.
let mut pick_key_list = HashSet::new();
// An all-scalar call is evaluated as a single row.
for idx in 0..(input_length.unwrap_or(1)) {
// A scalar map is reused for every row; a column map is read at the current row.
// NOTE(review): the unchecked index assumes every column argument has
// `input_length` rows — confirm against the caller's contract.
let input_map = match &args[0] {
ValueRef::Scalar(map) => map.clone(),
ValueRef::Column(map) => unsafe { map.index_unchecked(idx) },
};

match &input_map {
// A NULL or empty map has nothing to pick from: emit the builder's default value.
ScalarRef::Null | ScalarRef::EmptyMap => {
output_map_builder.push_default();
}
ScalarRef::Map(col) => {
// Collect this row's keys from every argument after the map.
pick_key_list.clear();
for input_key_item in args.iter().skip(1) {
let input_key = match &input_key_item {
ValueRef::Scalar(scalar) => scalar.clone(),
ValueRef::Column(col) => unsafe {
col.index_unchecked(idx)
},
};
match input_key {
// NULL keys and empty arrays contribute no keys.
ScalarRef::EmptyArray | ScalarRef::Null => {}
// An array argument contributes each of its non-NULL elements.
ScalarRef::Array(arr_col) => {
for arr_key in arr_col.iter() {
if arr_key == ScalarRef::Null {
continue;
}
pick_key_list.insert(arr_key.to_owned());
}
}
// Any other value is treated as a single key.
_ => {
pick_key_list.insert(input_key.to_owned());
}
}
}
// No keys requested for this row: emit the default value and skip filtering.
if pick_key_list.is_empty() {
output_map_builder.push_default();
continue;
}

// The type wrapped by `Map` is the element type of the key/value column.
let inner_builder_type = match input_map.infer_data_type() {
DataType::Map(box typ) => typ,
_ => unreachable!(),
};

let mut filtered_kv_builder =
ColumnBuilder::with_capacity(&inner_builder_type, col.len());

// Shadows the scalar with a key/value view; the `unwrap` relies on the
// `ScalarRef::Map` arm matched above.
let input_map: KvColumn<AnyType, AnyType> =
MapType::try_downcast_scalar(&input_map).unwrap();

// Copy over only the entries whose key was requested, in the map's iteration order.
input_map.iter().for_each(|(map_key, map_value)| {
if pick_key_list.contains(&map_key.to_owned()) {
filtered_kv_builder.push(ScalarRef::Tuple(vec![
map_key.clone(),
map_value.clone(),
]));
}
});
output_map_builder
.push(ScalarRef::Map(filtered_kv_builder.build()));
}
// NOTE(review): presumably unreachable because `check_map_arg_types` only accepts
// map-typed first arguments — confirm.
_ => unreachable!(),
}
}

// Return a column when any argument was a column, otherwise a single scalar.
match input_length {
Some(_) => Value::Column(output_map_builder.build()),
None => Value::Scalar(output_map_builder.build_scalar()),
}
}),
},
}))
});
}

/// Validates the argument types of the map key-selection functions.
///
/// Accepted shapes:
/// 1. The first argument is a `Map` of two-field tuples or an `EmptyMap` (nullability ignored).
/// 2. The second argument may be an `Array` or `EmptyArray` of keys; it must then be the only
///    key argument.
/// 3. Otherwise every argument after the first is checked as an individual key.
///
/// A key type (with nullability stripped) is accepted when it is `Null`, when it equals the
/// map's key type, or, for an `EmptyMap`, when `check_valid_map_key_type` accepts it.
///
/// Returns a clone of the first argument's type on success and `None` otherwise.
fn check_map_arg_types(args_type: &[DataType]) -> Option<DataType> {
    if args_type.len() < 2 {
        return None;
    }

    // `None` stands for an `EmptyMap`: there is no concrete key type to compare against.
    let expected_key = match args_type[0].remove_nullable() {
        DataType::EmptyMap => None,
        DataType::Map(box DataType::Tuple(fields)) if fields.len() == 2 => {
            Some(fields[0].clone())
        }
        _ => return None,
    };

    // Single rule applied to every candidate key type, whatever form the keys were passed in.
    let key_is_acceptable = |candidate: &DataType| -> bool {
        if *candidate == DataType::Null {
            return true;
        }
        match &expected_key {
            Some(expected) => candidate == expected,
            None => check_valid_map_key_type(candidate),
        }
    };

    let keys_ok = match args_type[1].remove_nullable() {
        // An array of keys must be the sole key argument.
        DataType::Array(box element_type) => {
            args_type.len() == 2 && key_is_acceptable(&element_type.remove_nullable())
        }
        DataType::EmptyArray => args_type.len() == 2,
        // Variadic form: every argument after the map is a key.
        _ => args_type[1..]
            .iter()
            .all(|arg| key_is_acceptable(&arg.remove_nullable())),
    };

    keys_ok.then(|| args_type[0].clone())
}

/// Reports whether `key_type` is usable as a map key, i.e. whether at least one of the
/// boolean, string, numeric, decimal or date/date-time type predicates holds for it.
fn check_valid_map_key_type(key_type: &DataType) -> bool {
    let rejected = !key_type.is_boolean()
        && !key_type.is_string()
        && !key_type.is_numeric()
        && !key_type.is_decimal()
        && !key_type.is_date_or_date_time();
    !rejected
}
38 changes: 38 additions & 0 deletions src/query/functions/tests/it/scalars/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ fn test_map() {
test_map_cat(file);
test_map_delete(file);
test_map_contains_key(file);
test_map_pick(file);
}

fn test_map_cat(file: &mut impl Write) {
Expand Down Expand Up @@ -296,6 +297,11 @@ fn test_map_delete(file: &mut impl Write) {
"map_delete({'k1': 'v1', 'k2': 'v2', 'k3': 'v3', 'k4': 'v4'}, 'k3', 'k2')",
&[],
);
run_ast(
file,
"map_delete({'k1': 'v1', 'k2': 'v2', 'k3': 'v3', 'k4': 'v4'}, ['k3', 'k2'])",
&[],
);

// Deleting keys from a nested map
let columns = [
Expand Down Expand Up @@ -381,3 +387,35 @@ fn test_map_delete(file: &mut impl Write) {
&columns,
);
}

/// Runs the `map_pick` expressions through `run_ast`: first literal maps with variadic keys,
/// array keys and empty inputs, then a map assembled from string columns.
fn test_map_pick(file: &mut impl Write) {
    // Literal cases need no input columns; they are evaluated in the listed order.
    let literal_cases = [
        "map_pick({'a':1,'b':2,'c':3}, 'a', 'b')",
        "map_pick({'a':1,'b':2,'c':3}, ['a', 'b'])",
        "map_pick({'a':1,'b':2,'c':3}, [])",
        "map_pick({1:'a',2:'b',3:'c'}, 1, 3)",
        "map_pick({}, 'a', 'b')",
        "map_pick({}, [])",
    ];
    for sql in literal_cases {
        run_ast(file, sql, &[]);
    }

    // Key columns are plain strings; value columns carry validity bitmaps, two of them with a
    // NULL entry.
    let key_a = StringType::from_data(vec!["a", "b", "c"]);
    let key_b = StringType::from_data(vec!["d", "e", "f"]);
    let key_c = StringType::from_data(vec!["x", "y", "z"]);
    let value_d =
        StringType::from_data_with_validity(vec!["v1", "v2", "v3"], vec![true, true, true]);
    let value_e =
        StringType::from_data_with_validity(vec!["v4", "v5", ""], vec![true, true, false]);
    let value_f =
        StringType::from_data_with_validity(vec!["v6", "", "v7"], vec![true, false, true]);

    let columns = [
        ("a_col", key_a),
        ("b_col", key_b),
        ("c_col", key_c),
        ("d_col", value_d),
        ("e_col", value_e),
        ("f_col", value_f),
    ];
    run_ast(
        file,
        "map_pick(map([a_col, b_col, c_col], [d_col, e_col, f_col]), 'a', 'b')",
        &columns,
    );
}
Original file line number Diff line number Diff line change
Expand Up @@ -2547,6 +2547,7 @@ Functions overloads:
0 map_keys(Map(Nothing)) :: Array(Nothing)
1 map_keys(Map(T0, T1)) :: Array(T0)
2 map_keys(Map(T0, T1) NULL) :: Array(T0) NULL
0 map_pick FACTORY
0 map_size(Map(Nothing)) :: UInt8
1 map_size(Map(T0, T1)) :: UInt64
2 map_size(Map(T0, T1) NULL) :: UInt64 NULL
Expand Down
Loading
Loading