-
Notifications
You must be signed in to change notification settings - Fork 21
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e955f56
commit 9668783
Showing
4 changed files
with
323 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
package admission | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"sync" | ||
|
||
"github.com/google/uuid" | ||
orderedmap "github.com/wk8/go-ordered-map/v2" | ||
) | ||
|
||
type boundedQueue struct { | ||
maxLimitBytes int64 | ||
maxLimitWaiters int64 | ||
currentBytes int64 | ||
currentWaiters int64 | ||
mtx sync.Mutex | ||
// waiters waiters | ||
waiters *orderedmap.OrderedMap[uuid.UUID, waiter] | ||
} | ||
|
||
type waiter struct { | ||
readyCh chan struct{} | ||
pendingBytes int64 | ||
ID uuid.UUID | ||
} | ||
|
||
func NewBoundedQueue(maxLimitBytes, maxLimitWaiters int64) *boundedQueue { | ||
return &boundedQueue{ | ||
maxLimitBytes: maxLimitBytes, | ||
maxLimitWaiters: maxLimitWaiters, | ||
currentBytes: int64(0), | ||
currentWaiters: int64(0), | ||
waiters: orderedmap.New[uuid.UUID, waiter](), | ||
} | ||
} | ||
|
||
func (bq *boundedQueue) admit(pendingBytes int64) (bool, error) { | ||
bq.mtx.Lock() | ||
defer bq.mtx.Unlock() | ||
|
||
if pendingBytes > bq.maxLimitBytes { // will never succeed | ||
return false, fmt.Errorf("rejecting request, request size larger than configured limit") | ||
} | ||
|
||
if bq.currentBytes + pendingBytes <= bq.maxLimitBytes { // no need to wait to admit | ||
bq.currentBytes += pendingBytes | ||
return true, nil | ||
} | ||
|
||
// since we were unable to admit, check if we can wait. | ||
if bq.currentWaiters + 1 > bq.maxLimitWaiters { // too many waiters | ||
return false, fmt.Errorf("rejecting request, too many waiters") | ||
} | ||
|
||
// if we got to this point we need to wait to acquire bytes, so update currentWaiters before releasing mutex. | ||
bq.currentWaiters += 1 | ||
return false, nil | ||
} | ||
|
||
func (bq *boundedQueue) Acquire(ctx context.Context, pendingBytes int64) error { | ||
success, err := bq.admit(pendingBytes) | ||
if err != nil || success { | ||
return err | ||
} | ||
|
||
// otherwise we need to wait for bytes to be released | ||
curWaiter := waiter{ | ||
pendingBytes: pendingBytes, | ||
readyCh: make(chan struct{}), | ||
ID: uuid.New(), | ||
} | ||
|
||
bq.mtx.Lock() | ||
_, dupped := bq.waiters.Set(curWaiter.ID, curWaiter) | ||
if dupped { | ||
panic("duplicate keys found") | ||
} | ||
|
||
bq.mtx.Unlock() | ||
|
||
select { | ||
case <-curWaiter.readyCh: | ||
return nil | ||
case <-ctx.Done(): | ||
// canceled before acquired so remove waiter. | ||
bq.mtx.Lock() | ||
defer bq.mtx.Unlock() | ||
|
||
_, found := bq.waiters.Delete(curWaiter.ID) | ||
if !found { | ||
panic("deleting key that doesn't exist") | ||
} | ||
|
||
bq.currentWaiters -= 1 | ||
return fmt.Errorf("context canceled: %w ", ctx.Err()) | ||
} | ||
} | ||
|
||
func (bq *boundedQueue) Release(pendingBytes int64) { | ||
bq.mtx.Lock() | ||
defer bq.mtx.Unlock() | ||
|
||
bq.currentBytes -= pendingBytes | ||
|
||
for { | ||
if bq.waiters.Len() == 0 { | ||
return | ||
} | ||
next := bq.waiters.Oldest() | ||
nextWaiter := next.Value | ||
nextKey := next.Key | ||
if bq.currentBytes + nextWaiter.pendingBytes <= bq.maxLimitBytes { | ||
bq.currentBytes += nextWaiter.pendingBytes | ||
bq.currentWaiters -= 1 | ||
close(nextWaiter.readyCh) | ||
_, found := bq.waiters.Delete(nextKey) | ||
if !found { | ||
panic("deleting key that doesn't exist") | ||
} | ||
|
||
} else { | ||
break | ||
} | ||
} | ||
} | ||
|
||
func (bq *boundedQueue) TryAcquire(pendingBytes int64) bool { | ||
bq.mtx.Lock() | ||
defer bq.mtx.Unlock() | ||
if bq.currentBytes + pendingBytes <= bq.maxLimitBytes { | ||
bq.currentBytes += pendingBytes | ||
return true | ||
} | ||
return false | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,169 @@ | ||
package admission | ||
|
||
|
||
import ( | ||
"context" | ||
"sync" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
"go.uber.org/multierr" | ||
) | ||
|
||
func min(x, y int64) int64 { | ||
if x <= y { | ||
return x | ||
} | ||
return y | ||
} | ||
|
||
func abs(x int64) int64 { | ||
if x < 0 { | ||
return -x | ||
} | ||
return x | ||
} | ||
func TestAcquireSimpleNoWaiters(t *testing.T) { | ||
maxLimitBytes := 1000 | ||
maxLimitWaiters := 10 | ||
numRequests := 40 | ||
requestSize := 21 | ||
|
||
|
||
bq := NewBoundedQueue(int64(maxLimitBytes), int64(maxLimitWaiters)) | ||
|
||
ctx, _ := context.WithTimeout(context.Background(), 10 * time.Second) | ||
for i := 0; i < numRequests; i++ { | ||
go func() { | ||
err := bq.Acquire(ctx, int64(requestSize)) | ||
assert.NoError(t, err) | ||
}() | ||
} | ||
|
||
require.Never(t, func() bool { | ||
return bq.waiters.Len() > 0 | ||
}, 2*time.Second, 10*time.Millisecond) | ||
|
||
for i := 0; i < int(numRequests); i++ { | ||
bq.Release(int64(requestSize)) | ||
assert.Equal(t, int64(0), bq.currentWaiters) | ||
} | ||
} | ||
|
||
func TestAcquireBoundedWithWaiters(t *testing.T) { | ||
tests := []struct{ | ||
name string | ||
maxLimitBytes int64 | ||
maxLimitWaiters int64 | ||
numRequests int64 | ||
requestSize int64 | ||
timeout time.Duration | ||
}{ | ||
{ | ||
name: "below max waiters above max bytes", | ||
maxLimitBytes: 1000, | ||
maxLimitWaiters: 100, | ||
numRequests: 100, | ||
requestSize: 21, | ||
timeout: 5 * time.Second, | ||
}, | ||
{ | ||
name: "above max waiters above max bytes", | ||
maxLimitBytes: 1000, | ||
maxLimitWaiters: 100, | ||
numRequests: 200, | ||
requestSize: 21, | ||
timeout: 5 * time.Second, | ||
}, | ||
} | ||
for _, tt := range tests { | ||
t.Run(tt.name, func(t *testing.T) { | ||
bq := NewBoundedQueue(tt.maxLimitBytes, tt.maxLimitWaiters) | ||
var blockedRequests int64 | ||
numReqsUntilBlocked := tt.maxLimitBytes / tt.requestSize | ||
requestsAboveLimit := abs(tt.numRequests - numReqsUntilBlocked) | ||
tooManyWaiters := requestsAboveLimit > tt.maxLimitWaiters | ||
|
||
// There should never be more blocked requests than maxLimitWaiters. | ||
blockedRequests = min(tt.maxLimitWaiters, requestsAboveLimit) | ||
|
||
ctx, _ := context.WithTimeout(context.Background(), tt.timeout) | ||
var errs error | ||
for i := 0; i < int(tt.numRequests); i++ { | ||
go func() { | ||
err := bq.Acquire(ctx, tt.requestSize) | ||
errs = multierr.Append(errs, err) | ||
}() | ||
} | ||
|
||
require.Eventually(t, func() bool { | ||
return bq.waiters.Len() == int(blockedRequests) | ||
}, 3*time.Second, 10*time.Millisecond) | ||
|
||
|
||
bq.Release(tt.requestSize) | ||
assert.Equal(t, bq.waiters.Len(), int(blockedRequests)-1) | ||
|
||
for i := 0; i < int(tt.numRequests)-1; i++ { | ||
bq.Release(tt.requestSize) | ||
} | ||
|
||
if tooManyWaiters { | ||
assert.ErrorContains(t, errs, "rejecting request, too many waiters") | ||
} else { | ||
assert.NoError(t, errs) | ||
} | ||
|
||
// confirm all bytes were released by acquiring maxLimitBytes. | ||
assert.True(t, bq.TryAcquire(tt.maxLimitBytes)) | ||
}) | ||
} | ||
} | ||
|
||
func TestAcquireContextCanceled(t *testing.T) { | ||
maxLimitBytes := 1000 | ||
maxLimitWaiters := 100 | ||
numRequests := 100 | ||
requestSize := 21 | ||
numReqsUntilBlocked := maxLimitBytes / requestSize | ||
requestsAboveLimit := abs(int64(numRequests) - int64(numReqsUntilBlocked)) | ||
|
||
blockedRequests := min(int64(maxLimitWaiters), int64(requestsAboveLimit)) | ||
|
||
bq := NewBoundedQueue(int64(maxLimitBytes), int64(maxLimitWaiters)) | ||
|
||
ctx, cancel := context.WithTimeout(context.Background(), 10 * time.Second) | ||
var errs error | ||
var wg sync.WaitGroup | ||
for i := 0; i < numRequests; i++ { | ||
wg.Add(1) | ||
go func() { | ||
err := bq.Acquire(ctx, int64(requestSize)) | ||
errs = multierr.Append(errs, err) | ||
wg.Done() | ||
}() | ||
} | ||
|
||
// Wait until all calls to Acquire() happen and we have the expected number of waiters. | ||
require.Eventually(t, func() bool { | ||
return bq.waiters.Len() == int(blockedRequests) | ||
}, 3*time.Second, 10*time.Millisecond) | ||
|
||
cancel() | ||
// assert.Equal(t, len(bq.waiters.keys), int(blockedRequests)) | ||
// time.Sleep(10 * time.Second) | ||
wg.Wait() | ||
assert.ErrorContains(t, errs, "context canceled") | ||
|
||
// Now all waiters should have returned and been removed. | ||
// time.Sleep(3 * time.Second) | ||
assert.Equal(t, 0, bq.waiters.Len()) | ||
|
||
for i := 0; i < int(numRequests); i++ { | ||
bq.Release(int64(requestSize)) | ||
assert.Equal(t, int64(0), bq.currentWaiters) | ||
} | ||
assert.True(t, bq.TryAcquire(int64(maxLimitBytes))) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters