-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
413 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
Additional synchronization primitives: | ||
* generic thread-safe map; | ||
* safer waitgroup; | ||
* singleflight (duplicate call suppression). | ||
|
||
**Thread-safe map** | ||
```go | ||
Len() int | ||
Get(key K) (V, bool) | ||
Set(key K, value V) | ||
SetIf(key K, cond func(value V, exists bool) bool, valfunc func(prev V) V) (value V, ok bool) | ||
Delete(key K) | ||
DeleteIf(key K, cond func(value V) bool) bool | ||
Clear() | ||
ForEach(fun func(key K, value V) bool) bool | ||
``` | ||
|
||
**Safer waitgroup** | ||
```go | ||
Go(fun func()) | ||
Wait() | ||
``` | ||
|
||
**Singleflight** | ||
```go | ||
Do(key K, fun func() V) V | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
module github.com/Zamony/go/par | ||
|
||
go 1.20 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
package par | ||
|
||
import "sync" | ||
|
||
// Map is a thread-safe map. N is a size hint for the map. | ||
type Map[K comparable, V any] struct { | ||
data map[K]V | ||
mu sync.RWMutex // protects data | ||
N int | ||
} | ||
|
||
// Len returns number of elements in the map. | ||
func (m *Map[K, V]) Len() int { | ||
m.mu.RLock() | ||
n := len(m.data) | ||
m.mu.RUnlock() | ||
return n | ||
} | ||
|
||
func (m *Map[K, V]) initOnce() { | ||
switch { | ||
case m.data != nil: | ||
case m.N > 0: | ||
m.data = make(map[K]V, m.N) | ||
default: | ||
m.data = make(map[K]V) | ||
} | ||
} | ||
|
||
// Set sets the value by the given key. | ||
func (m *Map[K, V]) Set(key K, value V) { | ||
m.mu.Lock() | ||
m.initOnce() | ||
m.data[key] = value | ||
m.mu.Unlock() | ||
} | ||
|
||
func (m *Map[K, V]) canSet(key K, cond func(value V, exists bool) bool) (V, bool) { | ||
m.mu.RLock() | ||
defer m.mu.RUnlock() | ||
|
||
value, ok := m.data[key] | ||
condOk := cond(value, ok) | ||
return value, condOk | ||
} | ||
|
||
// SetIf conditionally sets the value by the given key. | ||
// Condition function must be pure. | ||
// Returns final value and condition result. | ||
func (m *Map[K, V]) SetIf(key K, cond func(value V, exists bool) bool, valfunc func(prev V) V) (value V, ok bool) { | ||
value, ok = m.canSet(key, cond) | ||
if !ok { | ||
return value, false | ||
} | ||
|
||
m.mu.Lock() | ||
defer m.mu.Unlock() | ||
value, ok = m.data[key] | ||
if !cond(value, ok) { | ||
return value, false | ||
} | ||
|
||
m.initOnce() | ||
value = valfunc(value) | ||
m.data[key] = value | ||
return value, true | ||
} | ||
|
||
// Get gets value y the given key. | ||
func (m *Map[K, V]) Get(key K) (V, bool) { | ||
m.mu.RLock() | ||
value, ok := m.data[key] | ||
m.mu.RUnlock() | ||
return value, ok | ||
} | ||
|
||
// Delete deletes the value by the given key. | ||
// If the key doesn't exist does nothing. | ||
func (m *Map[K, V]) Delete(key K) { | ||
m.mu.Lock() | ||
delete(m.data, key) | ||
m.mu.Unlock() | ||
} | ||
|
||
func (m *Map[K, V]) canDelete(key K, cond func(value V) bool) bool { | ||
m.mu.RLock() | ||
defer m.mu.RUnlock() | ||
|
||
value, ok := m.data[key] | ||
return ok && cond(value) | ||
} | ||
|
||
// DeleteIf conditionally deletes the value by the given key. | ||
// Condition function must be pure. | ||
// Returns true if the value was deleted. | ||
func (m *Map[K, V]) DeleteIf(key K, cond func(value V) bool) bool { | ||
if !m.canDelete(key, cond) { | ||
return false | ||
} | ||
|
||
m.mu.Lock() | ||
defer m.mu.Unlock() | ||
value, ok := m.data[key] | ||
if !ok || !cond(value) { | ||
return false | ||
} | ||
|
||
delete(m.data, key) | ||
return true | ||
} | ||
|
||
// Clear clears the map. | ||
func (m *Map[K, V]) Clear() { | ||
m.mu.Lock() | ||
m.data = nil | ||
m.mu.Unlock() | ||
} | ||
|
||
// ForEach iterates over map and calls provided function for each key and value. | ||
// Iteration is aborted after provided function returns false. | ||
// Returns false if iteration was aborted. | ||
func (m *Map[K, V]) ForEach(fun func(key K, value V) bool) bool { | ||
m.mu.RLock() | ||
defer m.mu.RUnlock() | ||
|
||
for k, v := range m.data { | ||
if !fun(k, v) { | ||
return false | ||
} | ||
} | ||
|
||
return true | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
package par_test | ||
|
||
import ( | ||
"reflect" | ||
"testing" | ||
|
||
"github.com/Zamony/go/par" | ||
) | ||
|
||
func TestMapSetGet(t *testing.T) { | ||
t.Parallel() | ||
|
||
var m par.Map[string, int] | ||
m.Set("a", 1) | ||
|
||
value, ok := m.SetIf("b", func(value int, exists bool) bool { | ||
return true | ||
}, func(int) int { | ||
return 2 | ||
}) | ||
equal(t, value, 2) | ||
equal(t, ok, true) | ||
|
||
value, ok = m.SetIf("b", func(value int, exists bool) bool { | ||
return !exists | ||
}, func(int) int { | ||
return 3 | ||
}) | ||
equal(t, value, 2) | ||
equal(t, ok, false) | ||
|
||
m.Set("c", 4) | ||
value, ok = m.SetIf("c", func(value int, exists bool) bool { | ||
return value == 4 | ||
}, func(int) int { | ||
return 5 | ||
}) | ||
equal(t, value, 5) | ||
equal(t, ok, true) | ||
equal(t, m.Len(), 3) | ||
|
||
value, ok = m.Get("a") | ||
equal(t, ok, true) | ||
equal(t, value, 1) | ||
|
||
value, ok = m.Get("b") | ||
equal(t, ok, true) | ||
equal(t, value, 2) | ||
|
||
value, ok = m.Get("c") | ||
equal(t, ok, true) | ||
equal(t, value, 5) | ||
|
||
_, ok = m.Get("d") | ||
equal(t, ok, false) | ||
} | ||
|
||
func TestMapSetDelete(t *testing.T) { | ||
t.Parallel() | ||
|
||
var m par.Map[string, int] | ||
m.Set("a", 1) | ||
m.Set("b", 2) | ||
m.Set("c", 3) | ||
m.Delete("a") | ||
equal(t, m.DeleteIf("b", func(value int) bool { | ||
return value == 2 | ||
}), true) | ||
equal(t, m.DeleteIf("c", func(value int) bool { | ||
return value == 100500 | ||
}), false) | ||
equal(t, m.Len(), 1) | ||
|
||
_, ok := m.Get("a") | ||
equal(t, ok, false) | ||
|
||
_, ok = m.Get("b") | ||
equal(t, ok, false) | ||
|
||
value, ok := m.Get("c") | ||
equal(t, ok, true) | ||
equal(t, value, 3) | ||
} | ||
|
||
func TestMapForEach(t *testing.T) { | ||
t.Parallel() | ||
|
||
var m par.Map[string, int] | ||
m.Set("a", 1) | ||
m.Set("b", 2) | ||
m.Set("c", 3) | ||
|
||
mit := map[string]int{} | ||
completed := m.ForEach(func(key string, value int) bool { | ||
mit[key] = value | ||
return true | ||
}) | ||
equal(t, completed, true) | ||
equal(t, mit, map[string]int{ | ||
"a": 1, | ||
"b": 2, | ||
"c": 3, | ||
}) | ||
|
||
m.Clear() | ||
equal(t, m.Len(), 0) | ||
} | ||
|
||
func TestMapConcurrent(t *testing.T) { | ||
t.Parallel() | ||
|
||
const key = "a" | ||
var m par.Map[string, int] | ||
var wg par.WaitGroup | ||
defer wg.Wait() | ||
|
||
for i := 1; i <= 10; i++ { | ||
i := i | ||
wg.Go(func() { | ||
m.Set(key, i) | ||
}) | ||
wg.Go(func() { | ||
m.SetIf(key, func(int, bool) bool { | ||
return true | ||
}, func(v int) int { | ||
return v + 1 | ||
}) | ||
}) | ||
wg.Go(func() { | ||
m.Get(key) | ||
}) | ||
wg.Go(func() { | ||
m.Len() | ||
}) | ||
wg.Go(func() { | ||
m.Delete(key) | ||
}) | ||
wg.Go(func() { | ||
m.DeleteIf(key, func(int) bool { | ||
return true | ||
}) | ||
}) | ||
wg.Go(func() { | ||
m.ForEach(func(string, int) bool { | ||
return true | ||
}) | ||
}) | ||
} | ||
} | ||
|
||
func equal(t *testing.T, got, want any) { | ||
t.Helper() | ||
|
||
if !reflect.DeepEqual(got, want) { | ||
t.Errorf("Not equal (-want, +got):\n- %+v\n+ %+v\n", want, got) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package par | ||
|
||
import ( | ||
"sync/atomic" | ||
) | ||
|
||
type flightResult[T any] struct { | ||
Result T | ||
Done chan struct{} | ||
Waiters int64 | ||
} | ||
|
||
// Singleflight suppresses duplicate function calls. | ||
type Singleflight[K comparable, V any] struct { | ||
flights Map[K, *flightResult[V]] | ||
} | ||
|
||
// Do executes and returns the results of the given function, making | ||
// sure that only one execution is in-flight for a given key at a | ||
// time. If a duplicate comes in, the duplicate caller waits for the | ||
// original to complete and receives the same results. | ||
func (f *Singleflight[K, V]) Do(key K, fun func() V) V { | ||
flight, isPrimary := f.flights.SetIf(key, func(_ *flightResult[V], exists bool) bool { | ||
return !exists | ||
}, func(*flightResult[V]) *flightResult[V] { | ||
return &flightResult[V]{Waiters: 1, Done: make(chan struct{})} | ||
}) | ||
if !isPrimary { | ||
atomic.AddInt64(&flight.Waiters, 1) | ||
} else { | ||
flight.Result = fun() | ||
close(flight.Done) | ||
} | ||
|
||
<-flight.Done | ||
atomic.AddInt64(&flight.Waiters, -1) | ||
f.flights.DeleteIf(key, func(value *flightResult[V]) bool { | ||
return atomic.LoadInt64(&value.Waiters) == 0 | ||
}) | ||
|
||
return flight.Result | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
package par_test | ||
|
||
import ( | ||
"sync/atomic" | ||
"testing" | ||
"time" | ||
|
||
"github.com/Zamony/go/par" | ||
) | ||
|
||
func TestSingleFlight(t *testing.T) { | ||
t.Parallel() | ||
|
||
var ncalls int64 | ||
var wg par.WaitGroup | ||
single := par.Singleflight[string, int64]{} | ||
for i := 0; i < 5; i++ { | ||
wg.Go(func() { | ||
got := single.Do("a", func() int64 { | ||
time.Sleep(10 * time.Millisecond) | ||
return atomic.AddInt64(&ncalls, 1) | ||
}) | ||
equal(t, got, int64(1)) | ||
}) | ||
} | ||
|
||
wg.Wait() | ||
} |
Oops, something went wrong.