Skip to content

Commit

Permalink
removes offset > 0 branches from WeightedShuffle implementation
Browse files Browse the repository at this point in the history
Changing the offset definition to:

    let offset = (index - 1) & BIT_MASK

avoids all the

    if offset > 0 { /* ... */ }

branches and massively simplifies the implementation.
It also improves performance in the benchmarks:

On master:

    test bench_weighted_shuffle_collect ... bench:     150,789.57 ns/iter (+/- 2,123.07)
    test bench_weighted_shuffle_new     ... bench:      33,039.23 ns/iter (+/- 1,696.02)
    test bench_weighted_shuffle_shuffle ... bench:     140,193.88 ns/iter (+/- 2,739.95)

On the branch:

    test bench_weighted_shuffle_collect ... bench:     145,451.45 ns/iter (+/- 1,307.71)
    test bench_weighted_shuffle_new     ... bench:      30,309.99 ns/iter (+/- 3,577.51)
    test bench_weighted_shuffle_shuffle ... bench:     136,634.04 ns/iter (+/- 3,124.60)
  • Loading branch information
behzadnouri committed Dec 16, 2024
1 parent 47ad9a1 commit b25948f
Showing 1 changed file with 21 additions and 42 deletions.
63 changes: 21 additions & 42 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ use {
// Each internal tree node has FANOUT many child nodes with indices:
// (index << BIT_SHIFT) + 1 ..= (index << BIT_SHIFT) + FANOUT
// Conversely, for each node, the parent node is obtained by:
// (index - 1) >> BIT_SHIFT
// parent: (index - 1) >> BIT_SHIFT
// and the subtree weight is stored at
// offset: (index - 1) & BIT_MASK
// of its parent node.
// Number of index bits consumed per tree level: the parent of a node is
// obtained by shifting (index - 1) right by this many bits.
const BIT_SHIFT: usize = 4;
// Number of child nodes under each internal tree node (1 << 4 = 16).
const FANOUT: usize = 1 << BIT_SHIFT;
// Low-bit mask (FANOUT - 1 = 15) applied to (index - 1) to obtain the node's
// offset, i.e. its slot within the parent node.
const BIT_MASK: usize = FANOUT - 1;
Expand All @@ -32,7 +35,7 @@ pub struct WeightedShuffle<T> {
// Nodes without children are never accessed and don't need to be
// allocated, so tree.len() <= num_nodes.
// tree[i][j] is the sum of all weights in the j'th sub-tree of node i.
tree: Vec<[T; FANOUT - 1]>,
tree: Vec<[T; FANOUT]>,
// Current sum of all weights, excluding already sampled ones.
weight: T,
// Indices of zero weighted entries.
Expand All @@ -49,7 +52,7 @@ where
let zero = <T as Default>::default();
let (num_nodes, size) = get_num_nodes_and_tree_size(weights.len());
debug_assert!(size <= num_nodes);
let mut tree = vec![[zero; FANOUT - 1]; size];
let mut tree = vec![[zero; FANOUT]; size];
let mut sum = zero;
let mut zeros = Vec::default();
let mut num_negative: usize = 0;
Expand Down Expand Up @@ -78,11 +81,9 @@ where
// updating the sub-tree sums along the way.
let mut index = num_nodes + k; // leaf node
while index != 0 {
let offset = index & BIT_MASK;
let offset = (index - 1) & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
if offset > 0 {
tree[index][offset - 1] += weight;
}
tree[index][offset] += weight;
}
}
if num_negative > 0 {
Expand Down Expand Up @@ -112,12 +113,10 @@ where
// updating the sub-tree sums along the way.
let mut index = self.num_nodes + k; // leaf node
while index != 0 {
let offset = index & BIT_MASK;
let offset = (index - 1) & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
if offset > 0 {
debug_assert!(self.tree[index][offset - 1] >= weight);
self.tree[index][offset - 1] -= weight;
}
debug_assert!(self.tree[index][offset] >= weight);
self.tree[index][offset] -= weight;
}
}

Expand All @@ -131,52 +130,32 @@ where
// weight of the subtree which contains the target leaf node.
let mut index = 0; // root
let mut weight = self.weight;
'outer: while index < self.tree.len() {
for (j, &node) in self.tree[index].iter().enumerate() {
while let Some(tree) = self.tree.get(index) {
for (j, &node) in tree.iter().enumerate() {
if val < node {
// Traverse to the j'th subtree of self.tree[index].
weight = node;
index = (index << BIT_SHIFT) + j + 1;
continue 'outer;
break;
} else {
debug_assert!(weight >= node);
weight -= node;
val -= node;
}
}
// Traverse to the right-most subtree of self.tree[index].
index = (index << BIT_SHIFT) + FANOUT;
}
(index - self.num_nodes, weight)
}

pub fn remove_index(&mut self, k: usize) {
// Traverse the tree from the leaf node upwards to the root, while
// maintaining the sum of weights of subtrees *not* containing the leaf
// node.
let mut index = self.num_nodes + k; // leaf node
let mut weight = <T as Default>::default(); // zero
while index != 0 {
let offset = index & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
if offset > 0 {
if self.tree[index][offset - 1] != weight {
self.remove(k, self.tree[index][offset - 1] - weight);
} else {
self.remove_zero(k);
}
return;
}
// The leaf node is in the right-most subtree of self.tree[index].
for &node in &self.tree[index] {
weight += node;
}
}
// The leaf node is the right-most node of the whole tree.
if self.weight != weight {
self.remove(k, self.weight - weight);
} else {
let zero = <T as Default>::default();
let index = self.num_nodes + k; // leaf node
let offset = (index - 1) & BIT_MASK;
let index = (index - 1) >> BIT_SHIFT; // parent node
if self.tree[index][offset] == zero {
self.remove_zero(k);
} else {
self.remove(k, self.tree[index][offset]);
}
}

Expand Down

0 comments on commit b25948f

Please sign in to comment.