Skip to content

Commit

Permalink
removes over-allocation of tree nodes in WeightedShuffle
Browse files Browse the repository at this point in the history
Nodes without children are never accessed and don't need to be
allocated.
  • Loading branch information
behzadnouri committed Dec 15, 2024
1 parent 6505069 commit 92fc812
Showing 1 changed file with 35 additions and 20 deletions.
55 changes: 35 additions & 20 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ const BIT_MASK: usize = FANOUT - 1;
/// non-zero weighted indices.
#[derive(Clone)]
pub struct WeightedShuffle<T> {
// Number of "internal" nodes of the tree.
num_nodes: usize,
// Underlying array implementing the tree.
// Nodes without children are never accessed and don't need to be
// allocated, so tree.len() <= num_nodes.
// tree[i][j] is the sum of all weights in the j'th sub-tree of node i.
tree: Vec<[T; FANOUT - 1]>,
// Current sum of all weights, excluding already sampled ones.
Expand All @@ -43,11 +47,13 @@ where
/// they are treated as zero.
pub fn new(name: &'static str, weights: &[T]) -> Self {
let zero = <T as Default>::default();
let mut tree = vec![[zero; FANOUT - 1]; get_tree_size(weights.len())];
let (num_nodes, size) = get_num_nodes_and_tree_size(weights.len());
debug_assert!(size <= num_nodes);
let mut tree = vec![[zero; FANOUT - 1]; size];
let mut sum = zero;
let mut zeros = Vec::default();
let mut num_negative = 0;
let mut num_overflow = 0;
let mut num_negative: usize = 0;
let mut num_overflow: usize = 0;
for (k, &weight) in weights.iter().enumerate() {
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// weight < zero does not work for NaNs.
Expand All @@ -70,7 +76,7 @@ where
};
// Traverse the tree from the leaf node upwards to the root,
// updating the sub-tree sums along the way.
let mut index = tree.len() + k; // leaf node
let mut index = num_nodes + k; // leaf node
while index != 0 {
let offset = index & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
Expand All @@ -86,6 +92,7 @@ where
datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64));
}
Self {
num_nodes,
tree,
weight: sum,
zeros,
Expand All @@ -103,7 +110,7 @@ where
self.weight -= weight;
// Traverse the tree from the leaf node upwards to the root,
// updating the sub-tree sums along the way.
let mut index = self.tree.len() + k; // leaf node
let mut index = self.num_nodes + k; // leaf node
while index != 0 {
let offset = index & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
Expand All @@ -127,7 +134,7 @@ where
'outer: while index < self.tree.len() {
for (j, &node) in self.tree[index].iter().enumerate() {
if val < node {
// Traverse to the j+1 subtree of self.tree[index].
// Traverse to the j'th subtree of self.tree[index].
weight = node;
index = (index << BIT_SHIFT) + j + 1;
continue 'outer;
Expand All @@ -140,14 +147,14 @@ where
// Traverse to the right-most subtree of self.tree[index].
index = (index << BIT_SHIFT) + FANOUT;
}
(index - self.tree.len(), weight)
(index - self.num_nodes, weight)
}

pub fn remove_index(&mut self, k: usize) {
// Traverse the tree from the leaf node upwards to the root, while
// maintaining the sum of weights of subtrees *not* containing the leaf
// node.
let mut index = self.tree.len() + k; // leaf node
let mut index = self.num_nodes + k; // leaf node
let mut weight = <T as Default>::default(); // zero
while index != 0 {
let offset = index & BIT_MASK;
Expand Down Expand Up @@ -223,16 +230,18 @@ where
}
}

// Maps number of items to the "internal" size of the tree
// Maps number of items to the number of "internal" nodes of the tree
// which "implicitly" holds those items on the leaves.
fn get_tree_size(count: usize) -> usize {
let mut size = if count == 1 { 1 } else { 0 };
let mut nodes = 1;
while nodes < count {
// Nodes without children are never accessed and don't need to be
// allocated, so the second value (tree size) never exceeds the first
// (number of nodes).
fn get_num_nodes_and_tree_size(count: usize) -> (/*num_nodes:*/ usize, /*tree_size:*/ usize) {
let mut size: usize = 0;
let mut nodes: usize = 1;
while nodes * FANOUT < count {
size += nodes;
nodes *= FANOUT;
}
size
(size + nodes, size + (count + FANOUT - 1) / FANOUT)
}

#[cfg(test)]
Expand Down Expand Up @@ -278,19 +287,25 @@ mod tests {
}

#[test]
fn test_get_tree_size() {
assert_eq!(get_tree_size(0), 0);
fn test_get_num_nodes_and_tree_size() {
assert_eq!(get_num_nodes_and_tree_size(0), (1, 0));
for count in 1..=16 {
assert_eq!(get_tree_size(count), 1);
assert_eq!(get_num_nodes_and_tree_size(count), (1, 1));
}
let num_nodes = 1 + 16;
for count in 17..=256 {
assert_eq!(get_tree_size(count), 1 + 16);
let tree_size = 1 + (count + 15) / 16;
assert_eq!(get_num_nodes_and_tree_size(count), (num_nodes, tree_size));
}
let num_nodes = 1 + 16 + 16 * 16;
for count in 257..=4096 {
assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16);
let tree_size = 1 + 16 + (count + 15) / 16;
assert_eq!(get_num_nodes_and_tree_size(count), (num_nodes, tree_size));
}
let num_nodes = 1 + 16 + 16 * 16 + 16 * 16 * 16;
for count in 4097..=65536 {
assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16 + 16 * 16 * 16);
let tree_size = 1 + 16 + 16 * 16 + (count + 15) / 16;
assert_eq!(get_num_nodes_and_tree_size(count), (num_nodes, tree_size));
}
}

Expand Down

0 comments on commit 92fc812

Please sign in to comment.