Skip to content

Commit

Permalink
Merge branch 'main' into ir_fix_ffi
Browse files Browse the repository at this point in the history
  • Loading branch information
longbinlai authored Nov 14, 2023
2 parents 9c0b81a + d0b1cf1 commit 0d04305
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 1 deletion.
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
//
//! Copyright 2022 Alibaba Group Holding Limited.
//!
//! Licensed under the Apache License, Version 2.0 (the "License");
//! you may not use this file except in compliance with the License.
//! You may obtain a copy of the License at
//!
//! http://www.apache.org/licenses/LICENSE-2.0
//!
//! Unless required by applicable law or agreed to in writing, software
//! distributed under the License is distributed on an "AS IS" BASIS,
//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//! See the License for the specific language governing permissions and
//! limitations under the License.
use std::collections::HashMap;
use std::fmt::Debug;

use crate::api::function::FnResult;
use crate::api::Key;
Expand All @@ -14,4 +30,12 @@ pub trait FoldByKey<K: Data + Key, V: Data> {
I: Data,
F: FnMut(I, V) -> FnResult<I> + Send + 'static,
B: Fn() -> F + Send + 'static;

/// Folds the values of every key into an accumulator and returns the
/// per-key accumulators as a single `HashMap<K, I>` item.
///
/// * `init` - the starting accumulator; `I: Clone` lets an implementation
///   hand an independent copy to each key.
/// * `builder` - factory for the fold function `F`; `F` takes the current
///   accumulator and the next value and returns the updated accumulator
///   (or an error through `FnResult`).
///
/// Returns a `BuildJobError` if the operator cannot be added to the job.
///
/// NOTE(review): judging by the name, the fold is meant to stay local to
/// the partition that owns each key, so every partition yields its own
/// map -- confirm against the implementation.
fn fold_partition_by_key<I, B, F>(
self, init: I, builder: B,
) -> Result<SingleItem<HashMap<K, I>>, BuildJobError>
where
I: Clone + Send + Sync + Debug + 'static,
F: FnMut(I, V) -> FnResult<I> + Send + 'static,
B: Fn() -> F + Send + 'static;
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,20 @@
//
//! Copyright 2022 Alibaba Group Holding Limited.
//!
//! Licensed under the Apache License, Version 2.0 (the "License");
//! you may not use this file except in compliance with the License.
//! You may obtain a copy of the License at
//!
//! http://www.apache.org/licenses/LICENSE-2.0
//!
//! Unless required by applicable law or agreed to in writing, software
//! distributed under the License is distributed on an "AS IS" BASIS,
//! WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//! See the License for the specific language governing permissions and
//! limitations under the License.
use std::collections::HashMap;
use std::fmt::Debug;

use ahash::AHashMap;

Expand Down Expand Up @@ -58,4 +74,51 @@ impl<K: Data + Key, V: Data> FoldByKey<K, V> for Stream<Pair<K, V>> {
}
})
}

/// Folds the values of every key into an accumulator and emits the
/// accumulators as one `HashMap<K, I>` per scope tag.
///
/// The stream is first routed with `partition_by_key()`, then a unary
/// operator keeps, for each scope tag, a map from key to
/// `(accumulator, fold function)`:
///
/// * the first time a key is seen, its accumulator is a clone of `init`
///   and its fold function is a fresh `builder()` result;
/// * every following value of that key is fed through that function;
/// * when the last batch of a tag arrives, the accumulators are moved
///   into a `HashMap` and given to the output as a single item.
///
/// An error returned by the fold function aborts the batch and is
/// propagated to the caller of the operator.
///
/// NOTE(review): presumably `partition_by_key` sends all pairs of a key to
/// the same worker, so each worker emits a map with its own keys only --
/// confirm against `partition_by_key`.
fn fold_partition_by_key<I, B, F>(
    self, init: I, builder: B,
) -> Result<SingleItem<HashMap<K, I>>, BuildJobError>
where
    I: Clone + Send + Sync + Debug + 'static,
    F: FnMut(I, V) -> FnResult<I> + Send + 'static,
    B: Fn() -> F + Send + 'static,
{
    let s = self
        .partition_by_key()
        .unary("fold_partition_by_key", |info| {
            // One key -> (accumulator, fold fn) map per scope tag.
            let mut ttm = TidyTagMap::new(info.scope_level);
            move |input, output| {
                let result = input.for_each_batch(|dataset| {
                    let group = ttm.get_mut_or_else(&dataset.tag, AHashMap::<K, (Option<I>, F)>::new);
                    for item in dataset.drain() {
                        let (k, v) = item.take();
                        let (seed, func) = group
                            .entry(k)
                            .or_insert_with(|| (Some(init.clone()), builder()));
                        // The accumulator is moved out while folding; if `func`
                        // fails the slot stays empty and the error is returned.
                        let acc = seed.take().expect("fold seed lost");
                        *seed = Some((*func)(acc, v)?);
                    }

                    if dataset.is_last() {
                        // Leave an empty map behind so `retain` below drops the tag.
                        // todo: reuse group map;
                        let map: HashMap<K, I> = std::mem::take(group)
                            .into_iter()
                            .map(|(k, (seed, _))| (k, seed.unwrap_or_else(|| init.clone())))
                            .collect();
                        output
                            .new_session(&dataset.tag)?
                            .give(Single(map))?;
                    }

                    Ok(())
                });

                // Forget tags whose result has already been emitted.
                ttm.retain(|_, map| !map.is_empty());
                result
            }
        })?;
    Ok(SingleItem::new(s))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,50 @@ fn fold_by_key_test() {
assert_eq!(*cnt_3, (0..num * 2).filter(|x| x % 4 == 3).count() as u32);
}

/// Runs `fold_partition_by_key` on two workers, counting how many of the
/// numbers `0..2 * num` fall into each residue class modulo 4.
///
/// Each worker yields its own map; every key must show up in one of the
/// two maps with the count of its own residue class.
#[test]
fn fold_partition_by_key_test() {
    let mut conf = JobConf::new("fold_partition_by_key");
    conf.set_workers(2);
    let num = 1000u32;
    let mut result = pegasus::run(conf, || {
        let index = pegasus::get_current_worker().index;
        // Worker `i` feeds the range [i * num, (i + 1) * num).
        let src = index * num..(index + 1) * num;
        move |input, output| {
            input
                .input_from(src)?
                .key_by(|x| Ok((x % 4, x)))?
                .fold_partition_by_key(0u32, || |a, _| Ok(a + 1))?
                .sink_into(output)
        }
    })
    .expect("submit job failure:");
    let groups_1 = result.next().unwrap().unwrap();
    let groups_2 = result.next().unwrap().unwrap();
    println!("groups 1: {:?}\n groups 2: {:?}", groups_1, groups_2);

    assert_eq!(groups_1.len(), 2);
    for key in 0..4u32 {
        // Compare against the residue class of *this* key, not always class 0.
        let expected = (0..num * 2).filter(|x| x % 4 == key).count() as u32;
        // A key absent from both maps is a failure, not a silent pass.
        let actual = groups_1
            .get(&key)
            .or_else(|| groups_2.get(&key))
            .unwrap_or_else(|| panic!("key {} is missing from both partitions", key));
        assert_eq!(*actual, expected, "unexpected fold count for key {}", key);
    }
}

#[test]
fn fold_partition_test() {
let mut conf = JobConf::new("fold_partition_test");
Expand Down
2 changes: 1 addition & 1 deletion interactive_engine/executor/ir/runtime/src/assembly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ impl<P: PartitionInfo, C: ClusterInfo> IRJobAssembly<P, C> {
let group_map = group.gen_group_map()?;
stream = stream
.key_by(move |record| group_key.get_kv(record))?
.fold_by_key(group_accum, || {
.fold_partition_by_key(group_accum, || {
|mut accumulator, next| {
accumulator.accum(next)?;
Ok(accumulator)
Expand Down

0 comments on commit 0d04305

Please sign in to comment.