Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(functions): add map_insert function #15567

Merged
merged 11 commits into from
Nov 5, 2024
148 changes: 148 additions & 0 deletions src/query/functions/src/scalars/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
use std::collections::HashSet;
use std::hash::Hash;

use databend_common_expression::types::array::ArrayColumn;
use databend_common_expression::types::map::KvColumn;
use databend_common_expression::types::map::KvPair;
use databend_common_expression::types::nullable::NullableDomain;
use databend_common_expression::types::ArgType;
use databend_common_expression::types::ArrayType;
Expand All @@ -27,10 +30,15 @@ use databend_common_expression::types::NullType;
use databend_common_expression::types::NullableType;
use databend_common_expression::types::NumberType;
use databend_common_expression::types::SimpleDomain;
use databend_common_expression::types::ValueType;
use databend_common_expression::vectorize_1_arg;
use databend_common_expression::vectorize_with_builder_2_arg;
use databend_common_expression::vectorize_with_builder_3_arg;
use databend_common_expression::vectorize_with_builder_4_arg;
use databend_common_expression::EvalContext;
use databend_common_expression::FunctionDomain;
use databend_common_expression::FunctionRegistry;
use databend_common_expression::ScalarRef;
use databend_common_expression::Value;
use databend_common_hashtable::StackHashSet;
use siphasher::sip128::Hasher128;
Expand Down Expand Up @@ -244,4 +252,144 @@ pub fn register(registry: &mut FunctionRegistry) {
.any(|(k, _)| k == key)
},
);

// map_insert(<empty map>, key, value): the source map has no entries, so the
// result is a map holding exactly the inserted pair.
registry.register_3_arg_core::<EmptyMapType, GenericType<0>, GenericType<1>, MapType<GenericType<0>, GenericType<1>>, _, _>(
    "map_insert",
    // The only entry in the result is the inserted one, so the result domain
    // is the domain of the inserted key and value.
    |_, _, insert_key_domain, insert_value_domain| {
        FunctionDomain::Domain(Some((
            insert_key_domain.clone(),
            insert_value_domain.clone(),
        )))
    },
    // Evaluate row by row so that column-valued keys/values are handled as
    // well as constants (unwrapping the arguments as scalars would fail on
    // column input).
    vectorize_with_builder_3_arg::<
        EmptyMapType,
        GenericType<0>,
        GenericType<1>,
        MapType<GenericType<0>, GenericType<1>>,
    >(|_, key, value, output, ctx| {
        // Only boolean, string, numeric, decimal and date/datetime keys are accepted.
        let key_type = &ctx.generics[0];
        if !key_type.is_boolean()
            && !key_type.is_string()
            && !key_type.is_numeric()
            && !key_type.is_decimal()
            && !key_type.is_date_or_date_time()
        {
            // Report the error at the row currently being produced; a row is
            // still emitted below so the output keeps one row per input row.
            ctx.set_error(output.len(), format!("map keys can not be {}", key_type));
        }

        let mut b = ArrayType::create_builder(1, ctx.generics);
        b.put_item((key, value));
        b.commit_row();
        output.append_column(&b.build());
    }),
);

// map_insert(<nullable map>, key, value): a NULL source map is treated like a
// map without entries, so the result is never NULL.
registry.register_3_arg_core::<NullableType<MapType<GenericType<0>, GenericType<1>>>, GenericType<0>, GenericType<1>, MapType<GenericType<0>, GenericType<1>>, _, _>(
    "map_insert",
    |_, source_domain, insert_key_domain, insert_value_domain| {
        // Every result row contains the inserted pair, and rows with a
        // non-NULL source additionally keep the source entries. The merged
        // domain covers both kinds of rows; when the source carries no
        // key/value domain at all, only the inserted pair can appear.
        let domain = match source_domain.value.as_deref() {
            Some(Some((key_domain, val_domain))) => (
                key_domain.merge(insert_key_domain),
                val_domain.merge(insert_value_domain),
            ),
            _ => (insert_key_domain.clone(), insert_value_domain.clone()),
        };
        FunctionDomain::Domain(Some(domain))
    },
    vectorize_with_builder_3_arg::<
        NullableType<MapType<GenericType<0>, GenericType<1>>>,
        GenericType<0>,
        GenericType<1>,
        MapType<GenericType<0>, GenericType<1>>,
    >(|source, key, value, output, ctx| {
        match source {
            // Non-NULL source: insert the pair, or replace the value of an existing key.
            Some(source) => {
                output.append_column(&build_new_map(&source, key, value, ctx));
            }
            // NULL source: the result is a map holding just the inserted pair.
            None => {
                let mut b = ArrayType::create_builder(1, ctx.generics);
                b.put_item((key, value));
                b.commit_row();
                output.append_column(&b.build());
            }
        }
    }),
);

// map_insert(map, key, value): insert the pair, replacing the value when the
// key is already present.
registry.register_passthrough_nullable_3_arg(
    "map_insert",
    |_, source_domain, insert_key_domain, insert_value_domain| {
        // The result always contains the inserted pair, so the result domain
        // must cover it even when the source map carries no key/value domain.
        FunctionDomain::Domain(Some(match source_domain {
            Some((key_domain, val_domain)) => (
                key_domain.merge(insert_key_domain),
                val_domain.merge(insert_value_domain),
            ),
            None => (insert_key_domain.clone(), insert_value_domain.clone()),
        }))
    },
    vectorize_with_builder_3_arg::<
        MapType<GenericType<0>, GenericType<1>>,
        GenericType<0>,
        GenericType<1>,
        MapType<GenericType<0>, GenericType<1>>,
    >(|source, key, value, output, ctx| {
        // Default behavior: insert the new key-value pair, and if the key
        // already exists, update its value.
        output.append_column(&build_new_map(&source, key, value, ctx));
    }),
);

// grammar: map_insert(map, insert_key, insert_value, allow_update)
//
// When `insert_key` already exists in `map`, its value is replaced only if
// `allow_update` is true; otherwise the map is returned unchanged. A key that
// is not present yet is always inserted.
registry.register_passthrough_nullable_4_arg(
    "map_insert",
    |_, source_domain, insert_key_domain, insert_value_domain, _| {
        // The result always contains `insert_key` (already present or newly
        // inserted), so the result domain must cover the inserted pair even
        // when the source map carries no key/value domain.
        FunctionDomain::Domain(Some(match source_domain {
            Some((key_domain, val_domain)) => (
                key_domain.merge(insert_key_domain),
                val_domain.merge(insert_value_domain),
            ),
            None => (insert_key_domain.clone(), insert_value_domain.clone()),
        }))
    },
    vectorize_with_builder_4_arg::<
        MapType<GenericType<0>, GenericType<1>>,
        GenericType<0>,
        GenericType<1>,
        BooleanType,
        MapType<GenericType<0>, GenericType<1>>,
    >(|source, key, value, allow_update, output, ctx| {
        let key_exists = source.iter().any(|(k, _)| k == key);
        // Existing key and updates disallowed: emit a copy of the original map.
        if key_exists && !allow_update {
            let mut unchanged = ArrayType::create_builder(source.len(), ctx.generics);
            for (k, v) in source.iter() {
                unchanged.put_item((k, v));
            }
            unchanged.commit_row();
            output.append_column(&unchanged.build());
            return;
        }

        output.append_column(&build_new_map(&source, key, value, ctx));
    }),
);

/// Builds the result of inserting `insert_key => insert_value` into `source`.
///
/// Entries of `source` are copied in their original order. An entry whose key
/// equals `insert_key` keeps its position but gets `insert_value`; if no such
/// entry exists, the new pair is appended after the copied entries. All items
/// are committed as a single row of the returned column.
fn build_new_map(
    source: &KvColumn<GenericType<0>, GenericType<1>>,
    insert_key: ScalarRef,
    insert_value: ScalarRef,
    ctx: &EvalContext,
) -> ArrayColumn<KvPair<GenericType<0>, GenericType<1>>> {
    // Reserve room for every existing entry plus the possibly appended one.
    let mut new_map = ArrayType::create_builder(source.len() + 1, ctx.generics);

    // Single pass: copy the entries and remember whether the key was seen,
    // instead of scanning `source` once up front just to detect a duplicate.
    let mut key_found = false;
    for (k, v) in source.iter() {
        if k == insert_key {
            key_found = true;
            new_map.put_item((k, insert_value.clone()));
        } else {
            new_map.put_item((k, v));
        }
    }
    if !key_found {
        new_map.put_item((insert_key, insert_value));
    }
    new_map.commit_row();

    new_map.build()
}
}
49 changes: 49 additions & 0 deletions src/query/functions/tests/it/scalars/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ fn test_map() {
test_map_size(file);
test_map_cat(file);
test_map_contains_key(file);
test_map_insert(file)
}

fn test_map_cat(file: &mut impl Write) {
Expand Down Expand Up @@ -278,3 +279,51 @@ fn test_map_size(file: &mut impl Write) {
&columns,
);
}

/// Runs the `map_insert` expressions through `run_ast`, writing their results
/// to `file`: first with constant arguments only, then against maps built
/// from string columns.
fn test_map_insert(file: &mut impl Write) {
    // Constant maps: empty source, a new key, and an existing key with the
    // update flag off and on.
    let constant_cases = [
        "map_insert({}, 'k1', 'v1')",
        "map_insert({'k1': 'v1'}, 'k2', 'v2')",
        "map_insert({'k1': 'v1', 'k2': 'v2'}, 'k1', 'v10', false)",
        "map_insert({'k1': 'v1', 'k2': 'v2'}, 'k1', 'v10', true)",
    ];
    for expr in constant_cases {
        run_ast(file, expr, &[]);
    }

    // Key columns are plain strings; value columns carry validity bitmaps,
    // two of which mark one row as invalid.
    let input_columns = [
        ("a_col", StringType::from_data(vec!["a", "b", "c"])),
        ("b_col", StringType::from_data(vec!["d", "e", "f"])),
        ("c_col", StringType::from_data(vec!["x", "y", "z"])),
        (
            "d_col",
            StringType::from_data_with_validity(vec!["v1", "v2", "v3"], vec![true, true, true]),
        ),
        (
            "e_col",
            StringType::from_data_with_validity(vec!["v4", "v5", ""], vec![true, true, false]),
        ),
        (
            "f_col",
            StringType::from_data_with_validity(vec!["v6", "", "v7"], vec![true, false, true]),
        ),
    ];

    // Column-built maps: a key that is not in the key columns ('k1'), then a
    // key that occurs in `a_col` ('a') with the update flag on and off.
    let column_cases = [
        "map_insert(map([a_col, b_col, c_col], [d_col, e_col, f_col]), 'k1', 'v10')",
        "map_insert(map([a_col, b_col, c_col], [d_col, e_col, f_col]), 'a', 'v10', true)",
        "map_insert(map([a_col, b_col, c_col], [d_col, e_col, f_col]), 'a', 'v10', false)",
    ];
    for expr in column_cases {
        run_ast(file, expr, &input_columns);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -2449,6 +2449,9 @@ Functions overloads:
0 map_size(Map(Nothing)) :: UInt8
1 map_size(Map(T0, T1)) :: UInt64
2 map_size(Map(T0, T1) NULL) :: UInt64 NULL
0 map_insert(Map(Nothing), T0, T1) :: Map(Nothing, T1)
1 map_insert(Map(T0, T1), T0, T2) :: Map(T0, T2)
2 map_insert(Map(T0, T1) NULL, T0 NULL, T2 NULL) :: Map(T0, T2) NULL
0 map_values(Map(Nothing)) :: Array(Nothing)
1 map_values(Map(T0, T1)) :: Array(T1)
2 map_values(Map(T0, T1) NULL) :: Array(T1) NULL
Expand Down
Loading
Loading