Skip to content

Commit

Permalink
uses ConstZero instead of Default for zero of type T
Browse files Browse the repository at this point in the history
  • Loading branch information
behzadnouri committed Dec 18, 2024
1 parent 766901a commit 163a439
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions gossip/src/weighted_shuffle.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! The `weighted_shuffle` module provides an iterator over shuffled weights.
use {
num_traits::CheckedAdd,
num_traits::{CheckedAdd, ConstZero},
rand::{
distributions::uniform::{SampleUniform, UniformSampler},
Rng,
Expand Down Expand Up @@ -42,30 +42,33 @@ pub struct WeightedShuffle<T> {
zeros: Vec<usize>,
}

impl<T: ConstZero> WeightedShuffle<T> {
const ZERO: T = <T as ConstZero>::ZERO;
}

impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + CheckedAdd,
T: Copy + ConstZero + PartialOrd + AddAssign + CheckedAdd,
{
/// If weights are negative or overflow the total sum
/// they are treated as zero.
pub fn new(name: &'static str, weights: &[T]) -> Self {
let zero = <T as Default>::default();
let (num_nodes, size) = get_num_nodes_and_tree_size(weights.len());
debug_assert!(size <= num_nodes);
let mut tree = vec![[zero; FANOUT]; size];
let mut sum = zero;
let mut tree = vec![[Self::ZERO; FANOUT]; size];
let mut sum = Self::ZERO;
let mut zeros = Vec::default();
let mut num_negative: usize = 0;
let mut num_overflow: usize = 0;
for (k, &weight) in weights.iter().enumerate() {
#[allow(clippy::neg_cmp_op_on_partial_ord)]
// weight < zero does not work for NaNs.
if !(weight >= zero) {
if !(weight >= Self::ZERO) {
zeros.push(k);
num_negative += 1;
continue;
}
if weight == zero {
if weight == Self::ZERO {
zeros.push(k);
continue;
}
Expand Down Expand Up @@ -103,7 +106,7 @@ where

impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + SubAssign + Sub<Output = T>,
T: Copy + ConstZero + PartialOrd + AddAssign + SubAssign + Sub<Output = T>,
{
// Removes given weight at index k.
fn remove(&mut self, k: usize, weight: T) {
Expand All @@ -123,8 +126,7 @@ where
// Returns smallest index such that sum of weights[..=k] > val,
// along with its respective weight.
fn search(&self, mut val: T) -> (/*index:*/ usize, /*weight:*/ T) {
let zero = <T as Default>::default();
debug_assert!(val >= zero);
debug_assert!(val >= Self::ZERO);
debug_assert!(val < self.weight);
// Traverse the tree downwards from the root while maintaining the
// weight of the subtree which contains the target leaf node.
Expand All @@ -148,11 +150,10 @@ where
}

pub fn remove_index(&mut self, k: usize) {
let zero = <T as Default>::default();
let index = self.num_nodes + k; // leaf node
let offset = (index - 1) & BIT_MASK;
let index = (index - 1) >> BIT_SHIFT; // parent node
if self.tree[index][offset] == zero {
if self.tree[index][offset] == Self::ZERO {
self.remove_zero(k);
} else {
self.remove(k, self.tree[index][offset]);
Expand All @@ -168,13 +169,12 @@ where

impl<T> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>,
T: Copy + ConstZero + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>,
{
// Equivalent to weighted_shuffle.shuffle(&mut rng).next()
pub fn first<R: Rng>(&self, rng: &mut R) -> Option<usize> {
let zero = <T as Default>::default();
if self.weight > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.weight, rng);
if self.weight > Self::ZERO {
let sample = <T as SampleUniform>::Sampler::sample_single(Self::ZERO, self.weight, rng);
let (index, _weight) = WeightedShuffle::search(self, sample);
return Some(index);
}
Expand All @@ -188,13 +188,13 @@ where

impl<'a, T: 'a> WeightedShuffle<T>
where
T: Copy + Default + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>,
T: Copy + ConstZero + PartialOrd + AddAssign + SampleUniform + SubAssign + Sub<Output = T>,
{
pub fn shuffle<R: Rng>(mut self, rng: &'a mut R) -> impl Iterator<Item = usize> + 'a {
std::iter::from_fn(move || {
let zero = <T as Default>::default();
if self.weight > zero {
let sample = <T as SampleUniform>::Sampler::sample_single(zero, self.weight, rng);
if self.weight > Self::ZERO {
let sample =
<T as SampleUniform>::Sampler::sample_single(Self::ZERO, self.weight, rng);
let (index, weight) = WeightedShuffle::search(&self, sample);
self.remove(index, weight);
return Some(index);
Expand Down

0 comments on commit 163a439

Please sign in to comment.