diff --git a/set.go b/set.go index 803c8ea..d6b042c 100644 --- a/set.go +++ b/set.go @@ -62,6 +62,10 @@ type Set[T comparable] interface { // are all in the set. Contains(val ...T) bool + // ContainsAny returns whether at least one of the + // given items are in the set. + ContainsAny(val ...T) bool + // Difference returns the difference between this set // and other. The returned set will contain // all elements of this set that are not also diff --git a/set_test.go b/set_test.go index b7d4539..23e8f09 100644 --- a/set_test.go +++ b/set_test.go @@ -318,6 +318,38 @@ func Test_ContainsMultipleUnsafeSet(t *testing.T) { } } +func Test_ContainsAnySet(t *testing.T) { + a := NewSet[int]() + + a.Add(71) + + if !a.ContainsAny(71) { + t.Error("ContainsSet should contain 71") + } + + if !a.ContainsAny(71, 10) { + t.Error("ContainsSet should contain 71 or 10") + } + + a.Remove(71) + + if a.ContainsAny(71) { + t.Error("ContainsSet should not contain 71") + } + + if a.ContainsAny(71, 10) { + t.Error("ContainsSet should not contain 71 or 10") + } + + a.Add(13) + a.Add(7) + a.Add(1) + + if !(a.ContainsAny(13, 17, 10)) { + t.Error("ContainsSet should contain 13, 17, or 10") + } +} + func Test_ClearSet(t *testing.T) { a := makeSetInt([]int{2, 5, 9, 10}) diff --git a/threadsafe.go b/threadsafe.go index 9e3a0ca..067d09a 100644 --- a/threadsafe.go +++ b/threadsafe.go @@ -66,6 +66,14 @@ func (t *threadSafeSet[T]) Contains(v ...T) bool { return ret } +func (t *threadSafeSet[T]) ContainsAny(v ...T) bool { + t.RLock() + ret := t.uss.ContainsAny(v...) + t.RUnlock() + + return ret +} + func (t *threadSafeSet[T]) IsSubset(other Set[T]) bool { o := other.(*threadSafeSet[T]) diff --git a/threadsafe_test.go b/threadsafe_test.go index a74df68..368b603 100644 --- a/threadsafe_test.go +++ b/threadsafe_test.go @@ -172,6 +172,30 @@ func Test_ContainsConcurrent(t *testing.T) { wg.Wait() } +func Test_ContainsAnyConcurrent(t *testing.T) { + runtime.GOMAXPROCS(2) + + s := NewSet[int]() + ints := rand.Perm(N) + integers := make([]int, 0) + for _, v := range ints { + if v%N == 0 { + s.Add(v) + } + integers = append(integers, v) + } + + var wg sync.WaitGroup + for range ints { + wg.Add(1) + go func() { + s.ContainsAny(integers...) + wg.Done() + }() + } + wg.Wait() +} + func Test_DifferenceConcurrent(t *testing.T) { runtime.GOMAXPROCS(2) diff --git a/threadunsafe.go b/threadunsafe.go index e5f4629..da76a41 100644 --- a/threadunsafe.go +++ b/threadunsafe.go @@ -94,6 +94,15 @@ func (s threadUnsafeSet[T]) Contains(v ...T) bool { return true } +func (s threadUnsafeSet[T]) ContainsAny(v ...T) bool { + for _, val := range v { + if _, ok := s[val]; ok { + return true + } + } + return false +} + // private version of Contains for a single element v func (s threadUnsafeSet[T]) contains(v T) (ok bool) { _, ok = s[v]