diff --git a/src/query/functions/src/scalars/map.rs b/src/query/functions/src/scalars/map.rs index 1c0c723d18ca..27ffaa153f23 100644 --- a/src/query/functions/src/scalars/map.rs +++ b/src/query/functions/src/scalars/map.rs @@ -268,21 +268,35 @@ pub fn register(registry: &mut FunctionRegistry) { return None; } - let inner_key_type = match args_type.get(0) { + let inner_key_type = match args_type.first() { Some(DataType::Map(m)) => m.as_tuple().map(|tuple| &tuple[0]), _ => None, }; - let key_match = args_type[1..].iter().all(|arg_type| match inner_key_type { - Some(key_type) => arg_type == key_type, - None => matches!( - arg_type, - DataType::String - | DataType::Number(_) - | DataType::Decimal(_) - | DataType::Date - | DataType::Timestamp - ), - }); + let key_match = match args_type.len() { + 2 => args_type.get(1).map_or(false, |t| match t { + DataType::Array(_) => inner_key_type.map_or(false, |key_type| { + t.as_array() + .map_or(false, |array| array.as_ref() == key_type) + }), + DataType::EmptyArray => false, + _ => false, + }), + _ => args_type.iter().skip(1).all(|arg_type| { + inner_key_type.map_or_else( + || { + matches!( + arg_type, + DataType::String + | DataType::Number(_) + | DataType::Decimal(_) + | DataType::Date + | DataType::Timestamp + ) + }, + |key_type| arg_type == key_type, + ) + }), + }; if !key_match { return None; } @@ -312,9 +326,18 @@ pub fn register(registry: &mut FunctionRegistry) { }; let source_map = match &args[0] { - ValueRef::Scalar(ScalarRef::Map(s)) => { - KvPair::, GenericType<1>>::try_downcast_column(s).unwrap() - } + ValueRef::Scalar(s) => match s { + ScalarRef::Map(cols) => { + KvPair::, GenericType<1>>::try_downcast_column(cols).unwrap() + } + ScalarRef::EmptyMap => { + KvPair::, GenericType<1>>::try_downcast_column( + &Column::EmptyMap { len: 0 }, + ) + .unwrap() + } + _ => unreachable!(), + }, ValueRef::Column(Column::Map(c)) => { KvPair::, GenericType<1>>::try_downcast_column(&c.values).unwrap() } @@ -326,11 +349,17 @@ pub fn register(registry: &mut FunctionRegistry) { args.len() - 1, source_data_type.as_map().unwrap().as_tuple().unwrap(), ); - for key_arg in args[1..].iter() { - if let Some((k, v)) = source_map + let select_keys = match &args[1] { + ValueRef::Scalar(ScalarRef::Array(arr)) if args.len() == 2 => { + arr.iter().collect::>() + } + _ => args[1..] .iter() - .find(|(k, _)| k == key_arg.as_scalar().unwrap()) - { + .map(|arg| arg.as_scalar().unwrap().clone()) + .collect::>(), + }; + for key_arg in select_keys { + if let Some((k, v)) = source_map.iter().find(|(k, _)| k == &key_arg) { builder.put_item((k.clone(), v.clone())); } } diff --git a/src/query/functions/tests/it/scalars/map.rs b/src/query/functions/tests/it/scalars/map.rs index 20408c9d3450..0d152a37b431 100644 --- a/src/query/functions/tests/it/scalars/map.rs +++ b/src/query/functions/tests/it/scalars/map.rs @@ -282,6 +282,7 @@ fn test_map_size(file: &mut impl Write) { fn test_map_pick(file: &mut impl Write) { run_ast(file, "map_pick({'a':1,'b':2,'c':3}, 'a', 'b')", &[]); + run_ast(file, "map_pick({'a':1,'b':2,'c':3}, ['a', 'b'])", &[]); let columns = [ ("a_col", StringType::from_data(vec!["a", "b", "c"])), diff --git a/src/query/functions/tests/it/scalars/testdata/map.txt b/src/query/functions/tests/it/scalars/testdata/map.txt index 3c782159a0c5..2d69d39a0905 100644 --- a/src/query/functions/tests/it/scalars/testdata/map.txt +++ b/src/query/functions/tests/it/scalars/testdata/map.txt @@ -626,6 +626,15 @@ output domain : {[{"a"..="b"}], [{1..=2}]} output : {'a':1, 'b':2} +ast : map_pick({'a':1,'b':2,'c':3}, ['a', 'b']) +raw expr : map_pick(map(array('a', 'b', 'c'), array(1, 2, 3)), array('a', 'b')) +checked expr : map_pick(map(array("a", "b", "c"), array(1_u8, 2_u8, 3_u8)), array("a", "b")) +optimized expr : {"a":1_u8, "b":2_u8} +output type : Map(String, UInt8) +output domain : {[{"a"..="b"}], [{1..=2}]} +output : {'a':1, 'b':2} + + ast : map_pick(map([a_col, b_col, c_col], [d_col, e_col, f_col]), 'a', 'b') raw expr : map_pick(map(array(a_col::String, b_col::String, c_col::String), array(d_col::String NULL, e_col::String NULL, f_col::String NULL)), 'a', 'b') checked expr : map_pick(map(array(a_col, b_col, c_col), array(d_col, e_col, f_col)), "a", "b")