Skip to content

Commit

Permalink
Inproxy: broker matcher fixes
Browse files Browse the repository at this point in the history
- TestMatcherMultiQueue exited early without reaching most test cases.

- Since it took a value receiver instead of a pointer receiver, the guard
  against multiple calls to announcementQueueReference.dequeue had no effect,
  leading to potential underflows in the tracked counts.

- Empty compartment queues were left in place by dequeue, leading to
  potentially unbounded memory consumption in the case of high compartment ID
  churn.
  • Loading branch information
rod-hynes committed Sep 26, 2024
1 parent aeda8a5 commit 1b6d01c
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 33 deletions.
29 changes: 22 additions & 7 deletions psiphon/common/inproxy/matcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ func (m *Matcher) addAnnouncementEntry(announcementEntry *announcementEntry) err

func (m *Matcher) removeAnnouncementEntry(aborting bool, announcementEntry *announcementEntry) {

// In the aborting case, the queue isn't already locked. Otherise, assume
// In the aborting case, the queue isn't already locked. Othewise, assume
// it is locked.
if aborting {
m.announcementQueueMutex.Lock()
Expand Down Expand Up @@ -1119,6 +1119,8 @@ type announcementMultiQueue struct {
// announcements, which are used, when matching, to determine when better NAT
// matches may be possible.
type announcementCompartmentQueue struct {
isCommonCompartment bool
compartmentID ID
entries *list.List
unlimitedNATCount int
partiallyLimitedNATCount int
Expand Down Expand Up @@ -1177,20 +1179,24 @@ func (q *announcementMultiQueue) enqueue(announcementEntry *announcementEntry) e
return errors.TraceNew("announcement must specify exactly one compartment ID")
}

isCommonCompartment := true
var compartmentID ID
var compartmentQueues map[ID]*announcementCompartmentQueue
if len(commonCompartmentIDs) > 0 {
compartmentID = commonCompartmentIDs[0]
compartmentQueues = q.commonCompartmentQueues
} else {
isCommonCompartment = false
compartmentID = personalCompartmentIDs[0]
compartmentQueues = q.personalCompartmentQueues
}

compartmentQueue, ok := compartmentQueues[compartmentID]
if !ok {
compartmentQueue = &announcementCompartmentQueue{
entries: list.New(),
isCommonCompartment: isCommonCompartment,
compartmentID: compartmentID,
entries: list.New(),
}
compartmentQueues[compartmentID] = compartmentQueue
}
Expand Down Expand Up @@ -1221,7 +1227,7 @@ func (q *announcementMultiQueue) enqueue(announcementEntry *announcementEntry) e
}

// announcementQueueReference returns false if the item is already dequeued.
func (r announcementQueueReference) dequeue() bool {
func (r *announcementQueueReference) dequeue() bool {

if r.entry == nil {
// Already dequeued.
Expand All @@ -1242,6 +1248,15 @@ func (r announcementQueueReference) dequeue() bool {

r.compartmentQueue.entries.Remove(r.entry)

if r.compartmentQueue.entries.Len() == 0 {
// Remove empty compartment queue.
queues := r.multiQueue.commonCompartmentQueues
if !r.compartmentQueue.isCommonCompartment {
queues = r.multiQueue.personalCompartmentQueues
}
delete(queues, r.compartmentQueue.compartmentID)
}

r.multiQueue.totalEntries -= 1

// Mark as dequeued.
Expand Down Expand Up @@ -1319,14 +1334,14 @@ func (iter *announcementMatchIterator) getNext() *announcementEntry {

// Select the oldest item, by deadline, from all the candidate queue head
// items. This operation is linear in the number of matching compartment
// ID queues, which is currently bounded by This is a linear time
// operation, bounded by the length of matching compartment IDs (no more
// than maxCompartmentIDs, as enforced in
// ID queues, which is currently bounded by the length of matching
// compartment IDs (no more than maxCompartmentIDs, as enforced in
// ClientOfferRequest.ValidateAndGetLogFields).
//
// A potential future enhancement is to add more iterator state to track
// which queue has the next oldest time to select on the following
// getNext call.
// getNext call. Another potential enhancement is to remove fully
// consumed queues from compartmentQueues/compartmentIDs/nextEntries.

var selectedCandidate *announcementEntry
selectedIndex := -1
Expand Down
88 changes: 62 additions & 26 deletions psiphon/common/inproxy/matcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,6 @@ func TestMatcherMultiQueue(t *testing.T) {
if err != nil {
t.Errorf(errors.Trace(err).Error())
}

}

func runTestMatcherMultiQueue() error {
Expand Down Expand Up @@ -725,7 +724,7 @@ func runTestMatcherMultiQueue() error {
otherCommonCompartmentIDs[i%numOtherCompartmentIDs]},
NATType: NATTypeSymmetric,
}}})
if err == nil {
if err != nil {
return errors.Trace(err)
}
err = q.enqueue(&announcementEntry{
Expand All @@ -736,38 +735,41 @@ func runTestMatcherMultiQueue() error {
otherPersonalCompartmentIDs[i%numOtherCompartmentIDs]},
NATType: NATTypeSymmetric,
}}})
if err == nil {
if err != nil {
return errors.Trace(err)
}
}

var matchingCommonCompartmentIDs []ID
numMatchingCompartmentIDs := 2
numMatchingEntries := 2
var expectedMatches []*announcementEntry
for i := 0; i < numMatchingCompartmentIDs; i++ {
commonCompartmentID, _ := MakeID()
matchingCommonCompartmentIDs = append(
matchingCommonCompartmentIDs, commonCompartmentID)
ctx, cancel := context.WithDeadline(
context.Background(), time.Now().Add(time.Duration(i+1)*time.Minute))
defer cancel()
a := &announcementEntry{
ctx: ctx,
announcement: &MatchAnnouncement{
Properties: MatchProperties{
CommonCompartmentIDs: matchingCommonCompartmentIDs[i:i],
NATType: NATTypeNone,
}}}
expectedMatches = append(expectedMatches, a)
err := q.enqueue(a)
if err == nil {
return errors.Trace(err)
for j := 0; j < numMatchingEntries; j++ {
commonCompartmentID, _ := MakeID()
matchingCommonCompartmentIDs = append(
matchingCommonCompartmentIDs, commonCompartmentID)
ctx, cancel := context.WithDeadline(
context.Background(), time.Now().Add(time.Duration(i+1)*time.Minute))
defer cancel()
a := &announcementEntry{
ctx: ctx,
announcement: &MatchAnnouncement{
Properties: MatchProperties{
CommonCompartmentIDs: matchingCommonCompartmentIDs[i : i+1],
NATType: NATTypeNone,
}}}
expectedMatches = append(expectedMatches, a)
err := q.enqueue(a)
if err != nil {
return errors.Trace(err)
}
}
}

// Test: inspect queue state

if q.getLen() != numOtherEntries*2+numMatchingCompartmentIDs {
if q.getLen() != numOtherEntries*2+numMatchingCompartmentIDs*numMatchingEntries {
return errors.TraceNew("unexpected total entries count")
}

Expand All @@ -789,15 +791,17 @@ func runTestMatcherMultiQueue() error {
}

unlimited, partiallyLimited, strictlyLimited := iter.getNATCounts()
if unlimited != numMatchingCompartmentIDs || partiallyLimited != 0 || strictlyLimited != 0 {
if unlimited != numMatchingCompartmentIDs*numMatchingEntries ||
partiallyLimited != 0 ||
strictlyLimited != 0 {
return errors.TraceNew("unexpected NAT counts")
}

match := iter.getNext()
if match == nil {
return errors.TraceNew("unexpected missing match")
}
if match == expectedMatches[0] {
if match != expectedMatches[0] {
return errors.TraceNew("unexpected match")
}

Expand All @@ -811,20 +815,52 @@ func runTestMatcherMultiQueue() error {

iter = q.startMatching(true, matchingCommonCompartmentIDs)

if len(iter.compartmentQueues) != numMatchingCompartmentIDs-1 {
if len(iter.compartmentQueues) != numMatchingCompartmentIDs {
return errors.TraceNew("unexpected iterator state")
}

unlimited, partiallyLimited, strictlyLimited = iter.getNATCounts()
if unlimited != numMatchingCompartmentIDs-1 || partiallyLimited != 0 || strictlyLimited != 0 {
if unlimited != numMatchingEntries*numMatchingCompartmentIDs-1 ||
partiallyLimited != 0 ||
strictlyLimited != 0 {
return errors.TraceNew("unexpected NAT counts")
}

match = iter.getNext()
if match == nil {
return errors.TraceNew("unexpected missing match")
}
if match == expectedMatches[1] {
if match != expectedMatches[1] {
return errors.TraceNew("unexpected match")
}

if !match.queueReference.dequeue() {
return errors.TraceNew("unexpected already dequeued")
}

if len(iter.compartmentQueues) != numMatchingCompartmentIDs {
return errors.TraceNew("unexpected iterator state")
}

// Test: getNext after dequeue

match = iter.getNext()
if match == nil {
return errors.TraceNew("unexpected missing match")
}
if match != expectedMatches[2] {
return errors.TraceNew("unexpected match")
}

if !match.queueReference.dequeue() {
return errors.TraceNew("unexpected already dequeued")
}

match = iter.getNext()
if match == nil {
return errors.TraceNew("unexpected missing match")
}
if match != expectedMatches[3] {
return errors.TraceNew("unexpected match")
}

Expand Down

0 comments on commit 1b6d01c

Please sign in to comment.