From ec8ba1368d420fff8e68da44c5b1eb1883ae353a Mon Sep 17 00:00:00 2001 From: Andrew Fitzgerald Date: Tue, 3 Dec 2024 10:18:42 -0600 Subject: [PATCH] use trailing_zeros for threadset iteration (#3871) --- .../thread_aware_account_locks.rs | 53 ++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs b/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs index b279102756eed4..f3e3d3f8d683af 100644 --- a/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs +++ b/core/src/banking_stage/transaction_scheduler/thread_aware_account_locks.rs @@ -421,7 +421,7 @@ impl ThreadSet { #[inline(always)] pub(crate) fn contained_threads_iter(self) -> impl Iterator { - (0..MAX_THREADS).filter(move |thread_id| self.contains(*thread_id)) + ThreadSetIterator(self.0) } #[inline(always)] @@ -430,6 +430,22 @@ impl ThreadSet { } } +struct ThreadSetIterator(u64); + +impl Iterator for ThreadSetIterator { + type Item = ThreadId; + + fn next(&mut self) -> Option { + if self.0 == 0 { + None + } else { + let thread_id = self.0.trailing_zeros() as ThreadId; + self.0 &= self.0 - 1; + Some(thread_id) + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -739,4 +755,39 @@ mod tests { let any_threads = ThreadSet::any(MAX_THREADS); assert_eq!(any_threads.num_threads(), MAX_THREADS as u32); } + + #[test] + fn test_thread_set_iter() { + let mut thread_set = ThreadSet::none(); + assert!(thread_set.contained_threads_iter().next().is_none()); + + thread_set.insert(4); + assert_eq!( + thread_set.contained_threads_iter().collect::>(), + vec![4] + ); + + thread_set.insert(5); + assert_eq!( + thread_set.contained_threads_iter().collect::>(), + vec![4, 5] + ); + thread_set.insert(63); + assert_eq!( + thread_set.contained_threads_iter().collect::>(), + vec![4, 5, 63] + ); + + thread_set.remove(5); + assert_eq!( + thread_set.contained_threads_iter().collect::>(), + vec![4, 63] + ); + + let thread_set = ThreadSet::any(64); + assert_eq!( + thread_set.contained_threads_iter().collect::>(), + (0..64).collect::>() + ); + } }