Skip to content

Commit

Permalink
ProviderQueryManager: support "max" param on FindProvidersAsync
Browse files Browse the repository at this point in the history
This aligns the ProviderQueryManager with the routing.ContentDiscovery interface.
  • Loading branch information
hsanjuan committed Nov 21, 2024
1 parent cc82f9b commit f7da578
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 33 deletions.
6 changes: 4 additions & 2 deletions bitswap/client/internal/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ type SessionPeerManager interface {
// ProviderFinder is used to find providers for a given key
type ProviderFinder interface {
// FindProvidersAsync searches for peers that provide the given CID
FindProvidersAsync(ctx context.Context, k cid.Cid) <-chan peer.AddrInfo
FindProvidersAsync(ctx context.Context, k cid.Cid, max int) <-chan peer.AddrInfo
}

// opType is the kind of operation that is being processed by the event loop
Expand Down Expand Up @@ -410,7 +410,9 @@ func (s *Session) findMorePeers(ctx context.Context, c cid.Cid) {
go func(k cid.Cid) {
ctx, span := internal.StartSpan(ctx, "Session.FindMorePeers")
defer span.End()
for p := range s.providerFinder.FindProvidersAsync(ctx, k) {
// Max is set to -1. This means "use the default limit" in the
// provider query manager.
for p := range s.providerFinder.FindProvidersAsync(ctx, k, -1) {
// When a provider indicates that it has a cid, it's equivalent to
// the providing peer sending a HAVE
span.AddEvent("FoundPeer")
Expand Down
2 changes: 1 addition & 1 deletion bitswap/client/internal/session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func newFakeProviderFinder() *fakeProviderFinder {
}
}

func (fpf *fakeProviderFinder) FindProvidersAsync(ctx context.Context, k cid.Cid) <-chan peer.AddrInfo {
func (fpf *fakeProviderFinder) FindProvidersAsync(ctx context.Context, k cid.Cid, max int) <-chan peer.AddrInfo {
go func() {
select {
case fpf.findMorePeersRequested <- k:
Expand Down
47 changes: 42 additions & 5 deletions routing/providerquerymanager/providerquerymanager.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,16 @@ func (pqm *ProviderQueryManager) setFindProviderTimeout(findProviderTimeout time
pqm.timeoutMutex.Unlock()
}

// FindProvidersAsync finds providers for the given block.
func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, k cid.Cid) <-chan peer.AddrInfo {
// FindProvidersAsync finds providers for the given block. The max parameter
// controls how many will be returned at most. For a provider to be returned,
// we must have successfully connected to it. Setting max to -1 will use the
// configured MaxProviders. Setting max to 0 will return an unbounded number
// of providers.
func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context, k cid.Cid, max int) <-chan peer.AddrInfo {
if max < 0 {
max = pqm.maxProviders
}

inProgressRequestChan := make(chan inProgressRequest)

var span trace.Span
Expand Down Expand Up @@ -203,10 +211,10 @@ func (pqm *ProviderQueryManager) FindProvidersAsync(sessionCtx context.Context,
case receivedInProgressRequest = <-inProgressRequestChan:
}

return pqm.receiveProviders(sessionCtx, k, receivedInProgressRequest, func() { span.End() })
return pqm.receiveProviders(sessionCtx, k, max, receivedInProgressRequest, func() { span.End() })
}

func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k cid.Cid, receivedInProgressRequest inProgressRequest, onCloseFn func()) <-chan peer.AddrInfo {
func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k cid.Cid, max int, receivedInProgressRequest inProgressRequest, onCloseFn func()) <-chan peer.AddrInfo {
// maintains an unbuffered queue for incoming providers for given request for a given session
// essentially, as a provider comes in, for a given CID, we want to immediately broadcast to all
// sessions that queried that CID, without worrying about whether the client code is actually
Expand All @@ -216,6 +224,9 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
receivedProviders := append([]peer.AddrInfo(nil), receivedInProgressRequest.providersSoFar[0:]...)
incomingProviders := receivedInProgressRequest.incoming

// count how many providers we received from our workers etc.
// these providers should be peers we managed to connect to.
total := len(receivedProviders)
go func() {
defer close(returnedProviders)
defer onCloseFn()
Expand All @@ -231,6 +242,21 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
}
return receivedProviders[0]
}

stopWhenMaxReached := func() {
if max > 0 && total >= max {
if incomingProviders != nil {
// drains incomingProviders.
pqm.cancelProviderRequest(sessionCtx, k, incomingProviders)
incomingProviders = nil
}
}
}

// Handle the case when providersSoFar already is more than we
// need.
stopWhenMaxReached()

for len(receivedProviders) > 0 || incomingProviders != nil {
select {
case <-pqm.ctx.Done():
Expand All @@ -245,6 +271,13 @@ func (pqm *ProviderQueryManager) receiveProviders(sessionCtx context.Context, k
incomingProviders = nil
} else {
receivedProviders = append(receivedProviders, provider)
total++
stopWhenMaxReached()
// we do not return, we will loop on
// the case below until
// len(receivedProviders) == 0, which
// means they have all been sent out
// via returnedProviders
}
case outgoingProviders() <- nextProvider():
receivedProviders = receivedProviders[1:]
Expand Down Expand Up @@ -293,7 +326,11 @@ func (pqm *ProviderQueryManager) findProviderWorker() {
pqm.timeoutMutex.RUnlock()
span := trace.SpanFromContext(findProviderCtx)
span.AddEvent("StartFindProvidersAsync")
providers := pqm.router.FindProvidersAsync(findProviderCtx, k, pqm.maxProviders)
// We set count == 0. We will cancel the query
// manually once we have enough. This assumes the
// ContentDiscovery implementation does that, which a
// requirement per the libp2p/core/routing interface.
providers := pqm.router.FindProvidersAsync(findProviderCtx, k, 0)
wg := &sync.WaitGroup{}
for p := range providers {
wg.Add(1)
Expand Down
74 changes: 49 additions & 25 deletions routing/providerquerymanager/providerquerymanager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ type fakeProviderDialer struct {
connectDelay time.Duration
}

type fakeProviderNetwork struct {
type fakeProviderDiscovery struct {
peersFound []peer.ID
delay time.Duration
queriesMadeMutex sync.RWMutex
Expand All @@ -31,7 +31,7 @@ func (fpd *fakeProviderDialer) Connect(context.Context, peer.AddrInfo) error {
return fpd.connectError
}

func (fpn *fakeProviderNetwork) FindProvidersAsync(ctx context.Context, k cid.Cid, max int) <-chan peer.AddrInfo {
func (fpn *fakeProviderDiscovery) FindProvidersAsync(ctx context.Context, k cid.Cid, max int) <-chan peer.AddrInfo {
fpn.queriesMadeMutex.Lock()
fpn.queriesMade++
fpn.liveQueries++
Expand Down Expand Up @@ -70,7 +70,7 @@ func mustNotErr[T any](out T, err error) T {
func TestNormalSimultaneousFetch(t *testing.T) {
peers := random.Peers(10)
fpd := &fakeProviderDialer{}
fpn := &fakeProviderNetwork{
fpn := &fakeProviderDiscovery{
peersFound: peers,
delay: 1 * time.Millisecond,
}
Expand All @@ -81,8 +81,8 @@ func TestNormalSimultaneousFetch(t *testing.T) {

sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0])
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[1])
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[1], 0)

var firstPeersReceived []peer.AddrInfo
for p := range firstRequestChan {
Expand All @@ -108,7 +108,7 @@ func TestNormalSimultaneousFetch(t *testing.T) {
func TestDedupingProviderRequests(t *testing.T) {
peers := random.Peers(10)
fpd := &fakeProviderDialer{}
fpn := &fakeProviderNetwork{
fpn := &fakeProviderDiscovery{
peersFound: peers,
delay: 1 * time.Millisecond,
}
Expand All @@ -119,8 +119,8 @@ func TestDedupingProviderRequests(t *testing.T) {

sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)

var firstPeersReceived []peer.AddrInfo
for p := range firstRequestChan {
Expand Down Expand Up @@ -149,7 +149,7 @@ func TestDedupingProviderRequests(t *testing.T) {
func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) {
peers := random.Peers(10)
fpd := &fakeProviderDialer{}
fpn := &fakeProviderNetwork{
fpn := &fakeProviderDiscovery{
peersFound: peers,
delay: 1 * time.Millisecond,
}
Expand All @@ -162,10 +162,10 @@ func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) {
// first session will cancel before done
firstSessionCtx, firstCancel := context.WithTimeout(ctx, 3*time.Millisecond)
defer firstCancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(firstSessionCtx, key)
firstRequestChan := providerQueryManager.FindProvidersAsync(firstSessionCtx, key, 0)
secondSessionCtx, secondCancel := context.WithTimeout(ctx, 5*time.Second)
defer secondCancel()
secondRequestChan := providerQueryManager.FindProvidersAsync(secondSessionCtx, key)
secondRequestChan := providerQueryManager.FindProvidersAsync(secondSessionCtx, key, 0)

var firstPeersReceived []peer.AddrInfo
for p := range firstRequestChan {
Expand Down Expand Up @@ -194,7 +194,7 @@ func TestCancelOneRequestDoesNotTerminateAnother(t *testing.T) {
func TestCancelManagerExitsGracefully(t *testing.T) {
peers := random.Peers(10)
fpd := &fakeProviderDialer{}
fpn := &fakeProviderNetwork{
fpn := &fakeProviderDiscovery{
peersFound: peers,
delay: 1 * time.Millisecond,
}
Expand All @@ -208,8 +208,8 @@ func TestCancelManagerExitsGracefully(t *testing.T) {

sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)

var firstPeersReceived []peer.AddrInfo
for p := range firstRequestChan {
Expand All @@ -232,7 +232,7 @@ func TestPeersWithConnectionErrorsNotAddedToPeerList(t *testing.T) {
fpd := &fakeProviderDialer{
connectError: errors.New("not able to connect"),
}
fpn := &fakeProviderNetwork{
fpn := &fakeProviderDiscovery{
peersFound: peers,
delay: 1 * time.Millisecond,
}
Expand All @@ -244,8 +244,8 @@ func TestPeersWithConnectionErrorsNotAddedToPeerList(t *testing.T) {

sessionCtx, cancel := context.WithTimeout(ctx, 20*time.Millisecond)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key)
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)
secondRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, key, 0)

var firstPeersReceived []peer.AddrInfo
for p := range firstRequestChan {
Expand All @@ -265,7 +265,7 @@ func TestPeersWithConnectionErrorsNotAddedToPeerList(t *testing.T) {
func TestRateLimitingRequests(t *testing.T) {
peers := random.Peers(10)
fpd := &fakeProviderDialer{}
fpn := &fakeProviderNetwork{
fpn := &fakeProviderDiscovery{
peersFound: peers,
delay: 5 * time.Millisecond,
}
Expand All @@ -280,7 +280,7 @@ func TestRateLimitingRequests(t *testing.T) {
defer cancel()
var requestChannels []<-chan peer.AddrInfo
for i := 0; i < providerQueryManager.maxInProcessRequests+1; i++ {
requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i]))
requestChannels = append(requestChannels, providerQueryManager.FindProvidersAsync(sessionCtx, keys[i], 0))
}
time.Sleep(20 * time.Millisecond)
fpn.queriesMadeMutex.Lock()
Expand All @@ -305,7 +305,7 @@ func TestRateLimitingRequests(t *testing.T) {
func TestFindProviderTimeout(t *testing.T) {
peers := random.Peers(10)
fpd := &fakeProviderDialer{}
fpn := &fakeProviderNetwork{
fpn := &fakeProviderDiscovery{
peersFound: peers,
delay: 10 * time.Millisecond,
}
Expand All @@ -317,7 +317,7 @@ func TestFindProviderTimeout(t *testing.T) {

sessionCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
defer cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0])
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0)
var firstPeersReceived []peer.AddrInfo
for p := range firstRequestChan {
firstPeersReceived = append(firstPeersReceived, p)
Expand All @@ -330,7 +330,7 @@ func TestFindProviderTimeout(t *testing.T) {
func TestFindProviderPreCanceled(t *testing.T) {
peers := random.Peers(10)
fpd := &fakeProviderDialer{}
fpn := &fakeProviderNetwork{
fpn := &fakeProviderDiscovery{
peersFound: peers,
delay: 1 * time.Millisecond,
}
Expand All @@ -342,7 +342,7 @@ func TestFindProviderPreCanceled(t *testing.T) {

sessionCtx, cancel := context.WithCancel(ctx)
cancel()
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0])
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0)
if firstRequestChan == nil {
t.Fatal("expected non-nil channel")
}
Expand All @@ -356,7 +356,7 @@ func TestFindProviderPreCanceled(t *testing.T) {
func TestCancelFindProvidersAfterCompletion(t *testing.T) {
peers := random.Peers(2)
fpd := &fakeProviderDialer{}
fpn := &fakeProviderNetwork{
fpn := &fakeProviderDiscovery{
peersFound: peers,
delay: 1 * time.Millisecond,
}
Expand All @@ -367,7 +367,7 @@ func TestCancelFindProvidersAfterCompletion(t *testing.T) {
keys := random.Cids(1)

sessionCtx, cancel := context.WithCancel(ctx)
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0])
firstRequestChan := providerQueryManager.FindProvidersAsync(sessionCtx, keys[0], 0)
<-firstRequestChan // wait for everything to start.
time.Sleep(10 * time.Millisecond) // wait for the incoming providres to stop.
cancel() // cancel the context.
Expand All @@ -385,3 +385,27 @@ func TestCancelFindProvidersAfterCompletion(t *testing.T) {
}
}
}

func TestLimitedProviders(t *testing.T) {
max := 5
peers := random.Peers(10)
fpd := &fakeProviderDialer{}
fpn := &fakeProviderDiscovery{
peersFound: peers,
delay: 1 * time.Millisecond,
}
ctx := context.Background()
providerQueryManager := mustNotErr(New(ctx, fpd, fpn, WithMaxProviders(max)))
providerQueryManager.Startup()
providerQueryManager.setFindProviderTimeout(100 * time.Millisecond)
keys := random.Cids(1)

providersChan := providerQueryManager.FindProvidersAsync(ctx, keys[0], -1)
total := 0
for range providersChan {
total++
}
if total != max {
t.Fatal("returned more providers than requested")
}
}

0 comments on commit f7da578

Please sign in to comment.