Skip to content

Commit

Permalink
endpointsharding: cast EndpointMap values to *balancerWrapper instead…
Browse files Browse the repository at this point in the history
… of Balancer (#8069)
  • Loading branch information
arjan-bal authored Feb 7, 2025
1 parent 267a09b commit 9afb49d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 7 deletions.
9 changes: 5 additions & 4 deletions balancer/endpointsharding/endpointsharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func (es *endpointSharding) ResolverError(err error) {
}()
children := es.children.Load()
for _, child := range children.Values() {
child.(balancer.Balancer).ResolverError(err)
child.(*balancerWrapper).resolverErrorLocked(err)
}
}

Expand Down Expand Up @@ -349,9 +349,10 @@ func (bw *balancerWrapper) updateClientConnStateLocked(ccs balancer.ClientConnSt
// closeLocked closes the child balancer. Callers must hold the child mutext of
// the parent endpointsharding balancer.
func (bw *balancerWrapper) closeLocked() {
if bw.isClosed {
return
}
bw.child.Close()
bw.isClosed = true
}

func (bw *balancerWrapper) resolverErrorLocked(err error) {
bw.child.ResolverError(err)
}
48 changes: 45 additions & 3 deletions balancer/endpointsharding/endpointsharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,27 @@ package endpointsharding_test
import (
"context"
"encoding/json"
"errors"
"fmt"
"log"
"strings"
"testing"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/backoff"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/endpointsharding"
"google.golang.org/grpc/balancer/pickfirst/pickfirstleaf"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/stub"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/testutils/roundrobin"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
Expand Down Expand Up @@ -125,7 +129,9 @@ func (fp *fakePetiole) UpdateState(state balancer.State) {
// special picker, so it should fallback to the default behavior, which is to
// round_robin amongst the endpoint children that are in the aggregated state.
// It also verifies the petiole has access to the raw child state in case it
// wants to implement a custom picker.
// wants to implement a custom picker. The test sends a resolver error to the
// endpointsharding balancer and verifies an error picker from the children
// is used while making an RPC.
func (s) TestEndpointShardingBasic(t *testing.T) {
backend1 := stubserver.StartTestService(t, nil)
defer backend1.Stop()
Expand All @@ -135,7 +141,7 @@ func (s) TestEndpointShardingBasic(t *testing.T) {
mr := manual.NewBuilderWithScheme("e2e-test")
defer mr.Close()

json := `{"loadBalancingConfig": [{"fake_petiole":{}}]}`
json := fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, fakePetioleName)
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(json)
mr.InitialState(resolver.State{
Endpoints: []resolver.Endpoint{
Expand All @@ -145,7 +151,20 @@ func (s) TestEndpointShardingBasic(t *testing.T) {
ServiceConfig: sc,
})

cc, err := grpc.NewClient(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
dOpts := []grpc.DialOption{
grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()),
// Use a large backoff delay to avoid the error picker being updated
// too quickly.
grpc.WithConnectParams(grpc.ConnectParams{
Backoff: backoff.Config{
BaseDelay: 2 * defaultTestTimeout,
Multiplier: float64(0),
Jitter: float64(0),
MaxDelay: 2 * defaultTestTimeout,
},
}),
}
cc, err := grpc.NewClient(mr.Scheme()+":///", dOpts...)
if err != nil {
log.Fatalf("Failed to create new client: %v", err)
}
Expand All @@ -159,6 +178,29 @@ func (s) TestEndpointShardingBasic(t *testing.T) {
if err = roundrobin.CheckRoundRobinRPCs(ctx, client, []resolver.Address{{Addr: backend1.Address}, {Addr: backend2.Address}}); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}

// Stopping both the backends should make the channel enter
// TransientFailure.
backend1.Stop()
backend2.Stop()
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)

// When the resolver reports an error, the picker should get updated to
// return the resolver error.
mr.ReportError(errors.New("test error"))
testutils.AwaitState(ctx, t, cc, connectivity.TransientFailure)
for ; ctx.Err() == nil; <-time.After(time.Millisecond) {
_, err := client.EmptyCall(ctx, &testpb.Empty{})
if err == nil {
t.Fatalf("EmptyCall succeeded when expected to fail with %q", "test error")
}
if strings.Contains(err.Error(), "test error") {
break
}
}
if ctx.Err() != nil {
t.Fatalf("Context timed out waiting for picker with resolver error.")
}
}

// Tests that endpointsharding doesn't automatically re-connect IDLE children.
Expand Down

0 comments on commit 9afb49d

Please sign in to comment.