Skip to content

Commit

Permalink
feat(functions): add more args_type check
Browse files Browse the repository at this point in the history
  • Loading branch information
hanxuanliang committed May 24, 2024
1 parent 9b16d06 commit 979fcd6
Showing 1 changed file with 24 additions and 26 deletions.
50 changes: 24 additions & 26 deletions src/query/functions/src/scalars/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,29 @@ pub fn register(registry: &mut FunctionRegistry) {
return None;
}

if !matches!(args_type[0], DataType::Map(_) | DataType::EmptyMap) {
return None;
}

let inner_key_type = match args_type.get(0) {
Some(DataType::Map(m)) => m.as_tuple().map(|tuple| &tuple[0]),
_ => None,
};
let key_match = args_type[1..].iter().all(|arg_type| match inner_key_type {
Some(key_type) => arg_type == key_type,
None => matches!(
arg_type,
DataType::String
| DataType::Number(_)
| DataType::Decimal(_)
| DataType::Date
| DataType::Timestamp
),
});
if !key_match {
return None;
}

Some(Arc::new(Function {
signature: FunctionSignature {
name: "map_pick".to_string(),
Expand Down Expand Up @@ -301,7 +324,7 @@ pub fn register(registry: &mut FunctionRegistry) {
let mut builder: ArrayColumnBuilder<KvPair<GenericType<0>, GenericType<1>>> =
ArrayType::create_builder(
args.len() - 1,
&source_data_type.as_map().unwrap().as_tuple().unwrap(),
source_data_type.as_map().unwrap().as_tuple().unwrap(),
);
for key_arg in args[1..].iter() {
if let Some((k, v)) = source_map
Expand Down Expand Up @@ -336,29 +359,4 @@ pub fn register(registry: &mut FunctionRegistry) {
|_, _, _| FunctionDomain::Full,
|_, _, _| Value::Scalar(()),
);

registry.register_passthrough_nullable_2_arg(
"map_pick",
|_, domain1, domain2| {
FunctionDomain::Domain(match (domain1, domain2) {
(Some(domain1), _) => Some(domain1).cloned(),
(None, _) => None,
})
},
vectorize_with_builder_2_arg::<
MapType<GenericType<0>, GenericType<1>>,
ArrayType<GenericType<0>>,
MapType<GenericType<0>, GenericType<1>>,
>(|map, keys, output_map, ctx| {
let mut picked_map_builder = ArrayType::create_builder(keys.len(), ctx.generics);
for key in keys.iter() {
if let Some((k, v)) = map.iter().find(|(k, _)| k == &key) {
picked_map_builder.put_item((k.clone(), v.clone()));
}
}

picked_map_builder.commit_row();
output_map.append_column(&picked_map_builder.build());
}),
);
}

0 comments on commit 979fcd6

Please sign in to comment.