Skip to content

Commit

Permalink
race-audit (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
theganyo authored Jun 4, 2021
1 parent 1f5ba3f commit 65bc176
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
3 changes: 2 additions & 1 deletion quota/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,11 @@ func (b *bucket) sync() error {
b.request.Weight -= r.Weight // same window, keep accumulated Weight
}
b.result = &quotaResult
log.Debugf("quota synced: %#v", quotaResult)
b.lock.Unlock()

prometheusBucketSynced.With(b.prometheusLabels).SetToCurrentTime()

log.Debugf("quota synced: %#v", quotaResult)
return nil

default:
Expand All @@ -210,6 +210,7 @@ func (b *bucket) needToSync() bool {
return b.request.Weight > 0 || b.now().After(b.synced.Add(b.refreshAfter))
}

// does not lock b.lock! lock before calling.
func (b *bucket) windowExpired() bool {
if b.result != nil {
return b.now().After(time.Unix(b.result.ExpiryTime, 0))
Expand Down
23 changes: 17 additions & 6 deletions quota/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ func TestQuota(t *testing.T) {
}

serverResult := Result{}
ts := testServer(&serverResult, time.Now, nil)
ts, _ := testServer(&serverResult, time.Now, nil)

context := authtest.NewContext(ts.URL)
authContext := &auth.Context{
Expand Down Expand Up @@ -142,7 +142,7 @@ func TestSync(t *testing.T) {

fakeTime := newClock()
serverResult := Result{}
ts := testServer(&serverResult, fakeTime.now, nil)
ts, resultLock := testServer(&serverResult, fakeTime.now, nil)
defer ts.Close()

context := authtest.NewContext(ts.URL)
Expand Down Expand Up @@ -185,15 +185,19 @@ func TestSync(t *testing.T) {
b.refreshAfter = time.Hour
b.lock.Unlock()

resultLock.Lock()
serverResult.ExpiryTime /= 1000 // convert back to seconds for comparison
resultLock.Unlock()

b.lock.RLock()
if b.request.Weight != 0 {
t.Errorf("pending request weight got: %d, want: %d", b.request.Weight, 0)
}
resultLock.Lock()
if !reflect.DeepEqual(*b.result, serverResult) {
t.Errorf("result got: %#v, want: %#v", *b.result, serverResult)
}
resultLock.Unlock()
if b.synced != m.now() {
t.Errorf("synced got: %#v, want: %#v", b.synced, m.now())
}
Expand All @@ -220,15 +224,19 @@ func TestSync(t *testing.T) {
t.Errorf("should not have received error on sync: %v", err)
}

resultLock.Lock()
serverResult.ExpiryTime /= 1000 // convert back to seconds for comparison
resultLock.Unlock()

b.lock.Lock()
if b.request.Weight != 0 {
t.Errorf("pending request weight got: %d, want: %d", b.request.Weight, 0)
}
resultLock.Lock()
if !reflect.DeepEqual(*b.result, serverResult) {
t.Errorf("result got: %#v, want: %#v", *b.result, serverResult)
}
resultLock.Unlock()
if b.synced != m.now() {
t.Errorf("synced got: %#v, want: %#v", b.synced, m.now())
}
Expand All @@ -250,7 +258,7 @@ func TestDisconnected(t *testing.T) {
send: 404,
}
serverResult := Result{}
ts := testServer(&serverResult, fakeTime.now, errC)
ts, _ := testServer(&serverResult, fakeTime.now, errC)
ts.Close()

context := authtest.NewContext(ts.URL)
Expand Down Expand Up @@ -346,7 +354,7 @@ func TestWindowExpired(t *testing.T) {
send: 200,
}
serverResult := Result{}
ts := testServer(&serverResult, fakeTime.now, errC)
ts, _ := testServer(&serverResult, fakeTime.now, errC)
defer ts.Close()

context := authtest.NewContext(ts.URL)
Expand Down Expand Up @@ -452,15 +460,18 @@ type errControl struct {
send int
}

func testServer(serverResult *Result, now func() time.Time, errC *errControl) *httptest.Server {
func testServer(serverResult *Result, now func() time.Time, errC *errControl) (*httptest.Server, *sync.Mutex) {

resultLock := &sync.Mutex{}
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errC != nil && errC.send != 200 {
w.WriteHeader(errC.send)
_, _ = w.Write([]byte("error"))
return
}

resultLock.Lock()
defer resultLock.Unlock()
req := Request{}
_ = json.NewDecoder(r.Body).Decode(&req)
serverResult.Allowed = req.Allow
Expand All @@ -473,7 +484,7 @@ func testServer(serverResult *Result, now func() time.Time, errC *errControl) *h
serverResult.ExpiryTime = now().Unix() * 1000 // milliseconds needed
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(serverResult)
}))
})), resultLock
}

// ignores if no matching quota bucket
Expand Down

0 comments on commit 65bc176

Please sign in to comment.