diff --git a/gossip/src/weighted_shuffle.rs b/gossip/src/weighted_shuffle.rs index da7f619ab955c1..4064dc63481bf3 100644 --- a/gossip/src/weighted_shuffle.rs +++ b/gossip/src/weighted_shuffle.rs @@ -1,7 +1,7 @@ //! The `weighted_shuffle` module provides an iterator over shuffled weights. use { - num_traits::CheckedAdd, + num_traits::{CheckedAdd, ConstZero}, rand::{ distributions::uniform::{SampleUniform, UniformSampler}, Rng, @@ -42,30 +42,33 @@ pub struct WeightedShuffle { zeros: Vec, } +impl WeightedShuffle { + const ZERO: T = ::ZERO; +} + impl WeightedShuffle where - T: Copy + Default + PartialOrd + AddAssign + CheckedAdd, + T: Copy + ConstZero + PartialOrd + AddAssign + CheckedAdd, { /// If weights are negative or overflow the total sum /// they are treated as zero. pub fn new(name: &'static str, weights: &[T]) -> Self { - let zero = ::default(); let (num_nodes, size) = get_num_nodes_and_tree_size(weights.len()); debug_assert!(size <= num_nodes); - let mut tree = vec![[zero; FANOUT]; size]; - let mut sum = zero; + let mut tree = vec![[Self::ZERO; FANOUT]; size]; + let mut sum = Self::ZERO; let mut zeros = Vec::default(); let mut num_negative: usize = 0; let mut num_overflow: usize = 0; for (k, &weight) in weights.iter().enumerate() { #[allow(clippy::neg_cmp_op_on_partial_ord)] // weight < zero does not work for NaNs. - if !(weight >= zero) { + if !(weight >= Self::ZERO) { zeros.push(k); num_negative += 1; continue; } - if weight == zero { + if weight == Self::ZERO { zeros.push(k); continue; } @@ -103,7 +106,7 @@ where impl WeightedShuffle where - T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub, + T: Copy + ConstZero + PartialOrd + AddAssign + SubAssign + Sub, { // Removes given weight at index k. fn remove(&mut self, k: usize, weight: T) { @@ -123,8 +126,7 @@ where // Returns smallest index such that sum of weights[..=k] > val, // along with its respective weight. fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) { - let zero = ::default(); - debug_assert!(val >= zero); + debug_assert!(val >= Self::ZERO); debug_assert!(val < self.weight); // Traverse the tree downwards from the root while maintaining the // weight of the subtree which contains the target leaf node. @@ -148,11 +150,10 @@ where } pub fn remove_index(&mut self, k: usize) { - let zero = ::default(); let index = self.num_nodes + k; // leaf node let offset = (index - 1) & BIT_MASK; let index = (index - 1) >> BIT_SHIFT; // parent node - if self.tree[index][offset] == zero { + if self.tree[index][offset] == Self::ZERO { self.remove_zero(k); } else { self.remove(k, self.tree[index][offset]); @@ -168,13 +169,12 @@ where impl WeightedShuffle where - T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub, + T: Copy + ConstZero + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub, { // Equivalent to weighted_shuffle.shuffle(&mut rng).next() pub fn first(&self, rng: &mut R) -> Option { - let zero = ::default(); - if self.weight > zero { - let sample = ::Sampler::sample_single(zero, self.weight, rng); + if self.weight > Self::ZERO { + let sample = ::Sampler::sample_single(Self::ZERO, self.weight, rng); let (index, _weight) = WeightedShuffle::search(self, sample); return Some(index); } @@ -188,13 +188,13 @@ where impl<'a, T: 'a> WeightedShuffle where - T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub, + T: Copy + ConstZero + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub, { pub fn shuffle(mut self, rng: &'a mut R) -> impl Iterator + 'a { std::iter::from_fn(move || { - let zero = ::default(); - if self.weight > zero { - let sample = ::Sampler::sample_single(zero, self.weight, rng); + if self.weight > Self::ZERO { + let sample = + ::Sampler::sample_single(Self::ZERO, self.weight, rng); let (index, weight) = WeightedShuffle::search(&self, sample); self.remove(index, weight); return Some(index);