diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index 9d607001d95a39..da7f619ab955c1 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -12,7 +12,10 @@ use { // Each internal tree node has FANOUT many child nodes with indices: // (index << BIT_SHIFT) + 1 ..= (index << BIT_SHIFT) + FANOUT // Conversely, for each node, the parent node is obtained by: -// (index - 1) >> BIT_SHIFT +// parent: (index - 1) >> BIT_SHIFT +// and the subtree weight is stored at +// offset: (index - 1) & BIT_MASK +// of its parent node. const BIT_SHIFT: usize = 4; const FANOUT: usize = 1 << BIT_SHIFT; const BIT_MASK: usize = FANOUT - 1; @@ -32,7 +35,7 @@ pub struct WeightedShuffle { // Nodes without children are never accessed and don't need to be // allocated, so tree.len() < num_nodes. // tree[i][j] is the sum of all weights in the j'th sub-tree of node i. - tree: Vec<[T; FANOUT - 1]>, + tree: Vec<[T; FANOUT]>, // Current sum of all weights, excluding already sampled ones. weight: T, // Indices of zero weighted entries. @@ -49,7 +52,7 @@ where let zero = ::default(); let (num_nodes, size) = get_num_nodes_and_tree_size(weights.len()); debug_assert!(size <= num_nodes); - let mut tree = vec![[zero; FANOUT - 1]; size]; + let mut tree = vec![[zero; FANOUT]; size]; let mut sum = zero; let mut zeros = Vec::default(); let mut num_negative: usize = 0; @@ -78,11 +81,9 @@ where // updating the sub-tree sums along the way. let mut index = num_nodes + k; // leaf node while index != 0 { - let offset = index & BIT_MASK; + let offset = (index - 1) & BIT_MASK; index = (index - 1) >> BIT_SHIFT; // parent node - if offset > 0 { - tree[index][offset - 1] += weight; - } + tree[index][offset] += weight; } } if num_negative > 0 { @@ -112,12 +113,10 @@ where // updating the sub-tree sums along the way. let mut index = self.num_nodes + k; // leaf node while index != 0 { - let offset = index & BIT_MASK; + let offset = (index - 1) & BIT_MASK; index = (index - 1) >> BIT_SHIFT; // parent node - if offset > 0 { - debug_assert!(self.tree[index][offset - 1] >= weight); - self.tree[index][offset - 1] -= weight; - } + debug_assert!(self.tree[index][offset] >= weight); + self.tree[index][offset] -= weight; } } @@ -131,52 +130,32 @@ where // weight of the subtree which contains the target leaf node. let mut index = 0; // root let mut weight = self.weight; - 'outer: while index < self.tree.len() { - for (j, &node) in self.tree[index].iter().enumerate() { + while let Some(tree) = self.tree.get(index) { + for (j, &node) in tree.iter().enumerate() { if val < node { // Traverse to the j'th subtree of self.tree[index]. weight = node; index = (index << BIT_SHIFT) + j + 1; - continue 'outer; + break; } else { debug_assert!(weight >= node); weight -= node; val -= node; } } - // Traverse to the right-most subtree of self.tree[index]. - index = (index << BIT_SHIFT) + FANOUT; } (index - self.num_nodes, weight) } pub fn remove_index(&mut self, k: usize) { - // Traverse the tree from the leaf node upwards to the root, while - // maintaining the sum of weights of subtrees *not* containing the leaf - // node. - let mut index = self.num_nodes + k; // leaf node - let mut weight = ::default(); // zero - while index != 0 { - let offset = index & BIT_MASK; - index = (index - 1) >> BIT_SHIFT; // parent node - if offset > 0 { - if self.tree[index][offset - 1] != weight { - self.remove(k, self.tree[index][offset - 1] - weight); - } else { - self.remove_zero(k); - } - return; - } - // The leaf node is in the right-most subtree of self.tree[index]. - for &node in &self.tree[index] { - weight += node; - } - } - // The leaf node is the right-most node of the whole tree. - if self.weight != weight { - self.remove(k, self.weight - weight); - } else { + let zero = ::default(); + let index = self.num_nodes + k; // leaf node + let offset = (index - 1) & BIT_MASK; + let index = (index - 1) >> BIT_SHIFT; // parent node + if self.tree[index][offset] == zero { self.remove_zero(k); + } else { + self.remove(k, self.tree[index][offset]); } }