Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

removes over-allocation of tree nodes in WeightedShuffle #4126

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions gossip/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,9 @@ name = "crds_gossip_pull"
[[bench]]
name = "crds_shards"

[[bench]]
name = "weighted_shuffle"

[[bin]]
name = "solana-gossip"
path = "src/main.rs"
Expand Down
18 changes: 18 additions & 0 deletions gossip/benches/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,24 @@ fn bench_weighted_shuffle_new(bencher: &mut Bencher) {

#[bench]
fn bench_weighted_shuffle_shuffle(bencher: &mut Bencher) {
let mut seed = [0u8; 32];
let mut rng = rand::thread_rng();
let weights = make_weights(&mut rng);
let weighted_shuffle = WeightedShuffle::new("", &weights);
bencher.iter(|| {
rng.fill(&mut seed[..]);
let mut rng = ChaChaRng::from_seed(seed);
weighted_shuffle
.clone()
.shuffle(&mut rng)
.for_each(|index| {
std::hint::black_box(index);
});
});
}

#[bench]
fn bench_weighted_shuffle_collect(bencher: &mut Bencher) {
let mut seed = [0u8; 32];
let mut rng = rand::thread_rng();
let weights = make_weights(&mut rng);
Expand Down
55 changes: 35 additions & 20 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ const BIT_MASK: usize = FANOUT - 1;
/// non-zero weighted indices.
#[derive(Clone)]
pub struct WeightedShuffle<T> {
// Number of "internal" nodes of the tree.
num_nodes: usize,
// Underlying array implementing the tree.
// Nodes without children are never accessed and don't need to be
// allocated, so tree.len() < num_nodes.
// tree[i][j] is the sum of all weights in the j'th sub-tree of node i.
tree: Vec<[T; FANOUT - 1]>,
// Current sum of all weights, excluding already sampled ones.
Expand All @@ -43,11 +47,13 @@ where
/// they are treated as zero.
pub fn new(name: &'static str, weights: &[T]) -> Self {
let zero = <T as Default>::default();
let mut tree = vec![[zero; FANOUT - 1]; get_tree_size(weights.len())];
let (num_nodes, size) = get_num_nodes_and_tree_size(weights.len());
debug_assert!(size <= num_nodes);
let mut tree = vec![[zero; FANOUT - 1]; size];
let mut sum = zero;
let mut zeros = Vec::default();
let mut num_negative = 0;
let mut num_overflow = 0;
let mut num_negative: usize = 0;
let mut num_overflow: usize = 0;
for (k, &weight) in weights.iter().enumerate() {
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// weight < zero does not work for NaNs.
Expand All @@ -70,7 +76,7 @@ where
};
// Traverse the tree from the leaf node upwards to the root,
// updating the sub-tree sums along the way.
let mut index = tree.len() + k; // leaf node
let mut index = num_nodes + k; // leaf node
while index != 0 {
let offset = index & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
Expand All @@ -86,6 +92,7 @@ where
datapoint_error!("weighted-shuffle-overflow", (name, num_overflow, i64));
}
Self {
num_nodes,
tree,
weight: sum,
zeros,
Expand All @@ -103,7 +110,7 @@ where
self.weight -= weight;
// Traverse the tree from the leaf node upwards to the root,
// updating the sub-tree sums along the way.
let mut index = self.tree.len() + k; // leaf node
let mut index = self.num_nodes + k; // leaf node
while index != 0 {
let offset = index & BIT_MASK;
index = (index - 1) >> BIT_SHIFT; // parent node
Expand All @@ -127,7 +134,7 @@ where
'outer: while index < self.tree.len() {
for (j, &node) in self.tree[index].iter().enumerate() {
if val < node {
// Traverse to the j+1 subtree of self.tree[index].
// Traverse to the j'th subtree of self.tree[index].
weight = node;
index = (index << BIT_SHIFT) + j + 1;
continue 'outer;
Expand All @@ -140,14 +147,14 @@ where
// Traverse to the right-most subtree of self.tree[index].
index = (index << BIT_SHIFT) + FANOUT;
}
(index - self.tree.len(), weight)
(index - self.num_nodes, weight)
}

pub fn remove_index(&mut self, k: usize) {
// Traverse the tree from the leaf node upwards to the root, while
// maintaining the sum of weights of subtrees *not* containing the leaf
// node.
let mut index = self.tree.len() + k; // leaf node
let mut index = self.num_nodes + k; // leaf node
let mut weight = <T as Default>::default(); // zero
while index != 0 {
let offset = index & BIT_MASK;
Expand Down Expand Up @@ -223,16 +230,18 @@ where
}
}

// Maps number of items to the "internal" size of the tree
// Maps number of items to the number of "internal" nodes of the tree
// which "implicitly" holds those items on the leaves.
fn get_tree_size(count: usize) -> usize {
let mut size = if count == 1 { 1 } else { 0 };
let mut nodes = 1;
while nodes < count {
// Nodes without children are never accessed and don't need to be
// allocated, so the tree size is the second smaller number.
fn get_num_nodes_and_tree_size(count: usize) -> (/*num_nodes:*/ usize, /*tree_size:*/ usize) {
let mut size: usize = 0;
let mut nodes: usize = 1;
while nodes * FANOUT < count {
size += nodes;
nodes *= FANOUT;
}
size
(size + nodes, size + (count + FANOUT - 1) / FANOUT)
}

#[cfg(test)]
Expand Down Expand Up @@ -278,19 +287,25 @@ mod tests {
}

#[test]
fn test_get_tree_size() {
assert_eq!(get_tree_size(0), 0);
fn test_get_num_nodes_and_tree_size() {
assert_eq!(get_num_nodes_and_tree_size(0), (1, 0));
for count in 1..=16 {
assert_eq!(get_tree_size(count), 1);
assert_eq!(get_num_nodes_and_tree_size(count), (1, 1));
}
let num_nodes = 1 + 16;
for count in 17..=256 {
assert_eq!(get_tree_size(count), 1 + 16);
let tree_size = 1 + (count + 15) / 16;
assert_eq!(get_num_nodes_and_tree_size(count), (num_nodes, tree_size));
}
let num_nodes = 1 + 16 + 16 * 16;
for count in 257..=4096 {
assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16);
let tree_size = 1 + 16 + (count + 15) / 16;
assert_eq!(get_num_nodes_and_tree_size(count), (num_nodes, tree_size));
}
let num_nodes = 1 + 16 + 16 * 16 + 16 * 16 * 16;
for count in 4097..=65536 {
assert_eq!(get_tree_size(count), 1 + 16 + 16 * 16 + 16 * 16 * 16);
let tree_size = 1 + 16 + 16 * 16 + (count + 15) / 16;
assert_eq!(get_num_nodes_and_tree_size(count), (num_nodes, tree_size));
}
}

Expand Down
Loading