115 changes: 86 additions & 29 deletions balancer/endpointsharding/endpointsharding.go
@@ -62,14 +62,32 @@
type ChildState struct {
Endpoint resolver.Endpoint
State balancer.State

// Balancer exposes only the ExitIdler interface of the child LB policy.
// Other methods of the child policy are called only by endpointsharding.
Balancer balancer.ExitIdler
}

// NewBalancer returns a load balancing policy that manages homogeneous child
// policies each owning a single endpoint.
// policies each owning a single endpoint. The balancer will automatically call
// ExitIdle on its children if they report IDLE connectivity state.
func NewBalancer(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
return newBlanacer(cc, opts, true)
}

// NewBalancerWithoutAutoReconnect returns a load balancing policy that manages
// homogeneous child policies each owning a single endpoint. The balancer will
// allow children to remain in IDLE state until triggered to exit idle state
// using the ChildState obtained using the endpointsharding picker.
func NewBalancerWithoutAutoReconnect(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer {
return newBlanacer(cc, opts, false)
}

func newBlanacer(cc balancer.ClientConn, opts balancer.BuildOptions, autoReconnect bool) balancer.Balancer {
es := &endpointSharding{
cc: cc,
bOpts: opts,
cc: cc,
bOpts: opts,
enableAutoReconnect: autoReconnect,
}
es.children.Store(resolver.NewEndpointMap())
return es
@@ -79,19 +97,26 @@
// balancer with child config for every unique Endpoint received. It updates the
// child states on any update from parent or child.
type endpointSharding struct {
cc balancer.ClientConn
bOpts balancer.BuildOptions
cc balancer.ClientConn
bOpts balancer.BuildOptions
enableAutoReconnect bool

childMu sync.Mutex // syncs balancer.Balancer calls into children
children atomic.Pointer[resolver.EndpointMap]
closed bool
// childMu synchronizes calls to any single child. It must be held for all
// calls into a child. To avoid deadlocks, do not acquire childMu while
// holding mu.
childMu sync.Mutex
children atomic.Pointer[resolver.EndpointMap] // endpoint -> *balancerWrapper

// inhibitChildUpdates is set during UpdateClientConnState/ResolverError
// calls (calls to children will each produce an update, only want one
// update).
inhibitChildUpdates atomic.Bool

mu sync.Mutex // Sync updateState callouts and childState recent state updates
// mu synchronizes access to the state stored in balancerWrappers in the
// children field. mu must not be held during calls into a child since
// synchronous calls back from the child may require taking mu, causing a
// deadlock. To avoid deadlocks, do not acquire childMu while holding mu.
mu sync.Mutex
}

// UpdateClientConnState creates a child for new endpoints and deletes children
@@ -121,19 +146,24 @@
// update.
continue
}
var bal *balancerWrapper
if child, ok := children.Get(endpoint); ok {
bal = child.(*balancerWrapper)
var childBalancer *balancerWrapper
if val, ok := children.Get(endpoint); ok {
childBalancer = val.(*balancerWrapper)
// Endpoint attributes may have changed, update the stored endpoint.
es.mu.Lock()
childBalancer.childState.Endpoint = endpoint
es.mu.Unlock()
} else {
bal = &balancerWrapper{
childBalancer = &balancerWrapper{
childState: ChildState{Endpoint: endpoint},
ClientConn: es.cc,
es: es,
}
bal.Balancer = gracefulswitch.NewBalancer(bal, es.bOpts)
childBalancer.childState.Balancer = childBalancer
childBalancer.child = gracefulswitch.NewBalancer(childBalancer, es.bOpts)
}
newChildren.Set(endpoint, bal)
if err := bal.UpdateClientConnState(balancer.ClientConnState{
newChildren.Set(endpoint, childBalancer)
if err := childBalancer.updateClientConnStateLocked(balancer.ClientConnState{
BalancerConfig: state.BalancerConfig,
ResolverState: resolver.State{
Endpoints: []resolver.Endpoint{endpoint},
@@ -150,9 +180,8 @@
// Delete old children that are no longer present.
for _, e := range children.Keys() {
child, _ := children.Get(e)
bal := child.(balancer.Balancer)
if _, ok := newChildren.Get(e); !ok {
bal.Close()
child.(*balancerWrapper).closeLocked()
}
}
es.children.Store(newChildren)
@@ -175,8 +204,7 @@
}()
children := es.children.Load()
for _, child := range children.Values() {
bal := child.(balancer.Balancer)
bal.ResolverError(err)
child.(balancer.Balancer).ResolverError(err)

Codecov: added line #L207 in balancer/endpointsharding/endpointsharding.go was not covered by tests.
}
}

@@ -189,10 +217,8 @@
defer es.childMu.Unlock()
children := es.children.Load()
for _, child := range children.Values() {
bal := child.(balancer.Balancer)
bal.Close()
child.(*balancerWrapper).closeLocked()
}
es.closed = true
}

// updateState updates this component's state. It sends the aggregated state,
@@ -288,30 +314,61 @@
// balancerWrapper is a wrapper of a balancer. It ID's a child balancer by
// endpoint, and persists recent child balancer state.
type balancerWrapper struct {
balancer.Balancer // Simply forward balancer.Balancer operations.
// The following fields are initialized at build time and read-only after
// that and therefore do not need to be guarded by a mutex.

// child contains the wrapped balancer. Access its methods only through
// methods on balancerWrapper to ensure proper synchronization.
child balancer.Balancer
balancer.ClientConn // embed to intercept UpdateState, doesn't deal with SubConns

es *endpointSharding

// Access to the following fields is guarded by es.mu.

childState ChildState
isClosed bool
}

func (bw *balancerWrapper) UpdateState(state balancer.State) {
bw.es.mu.Lock()
bw.childState.State = state
bw.es.mu.Unlock()
// When a child balancer says it's IDLE, ping it to exit idle and reconnect.
// TODO: In the future, perhaps make this a knob in configuration.
if ei, ok := bw.Balancer.(balancer.ExitIdler); state.ConnectivityState == connectivity.Idle && ok {
if state.ConnectivityState == connectivity.Idle && bw.es.enableAutoReconnect {
bw.ExitIdle()
}
bw.es.updateState()
}

// ExitIdle pings an IDLE child balancer to exit idle in a new goroutine to
// avoid deadlocks due to synchronous balancer state updates.
func (bw *balancerWrapper) ExitIdle() {
if ei, ok := bw.child.(balancer.ExitIdler); ok {
go func() {
bw.es.childMu.Lock()
Member:

Because we're calling into the child from a goroutine here:

  • All calls into the child need to grab bw.es.childMu to avoid concurrent calls in,
  • That means Close() needs it (and also needed it before you added the method explicitly here, so good that we caught it), and
  • That means we shouldn't embed the child Balancer, since directly calling it is wrong.

Contributor Author:

> That means Close() needs it (and also needed it before you added the method explicitly here, so good that we caught it)

Before this PR, calls to balancerWrapper.Close() and updates to endpointsharding.closed were holding es.childMu:
In endpointsharding.Close():

func (es *endpointSharding) Close() {

In endpointsharding.UpdateClientConnState():

es.childMu.Lock()
defer es.childMu.Unlock()

// Delete old children that are no longer present.
for _, e := range children.Keys() {
child, _ := children.Get(e)
bal := child.(balancer.Balancer)
if _, ok := newChildren.Get(e); !ok {
bal.Close()
}
}

Maybe I'm not understanding this point correctly?

> That means we shouldn't embed the child Balancer, since directly calling it is wrong.

I've removed the embedded balancer.Balancer from balancerWrapper and provided methods named updateClientConnStateLocked and closeLocked that expect the caller to lock es.childMu.

Member:

> Before this PR, calls to balancerWrapper.Close() and updates to endpointsharding.closed were holding es.childMu:

Oh, I see... I missed where it grabbed the lock in Close and thought it was just delegating directly to the wrappers.

This now seems good and makes it explicit what is required.
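
For reference, a minimal sketch of the caller-side contract this thread settles on (illustrative only; the closeAllChildren name is hypothetical, but the body mirrors the shape of endpointSharding.Close in this diff): every call into a child goes through a *Locked method while es.childMu is held, and es.mu is never held across such a call.

```go
// Illustrative sketch of the locking contract, not code from this PR.
func (es *endpointSharding) closeAllChildren() {
	es.childMu.Lock() // serializes all calls into the children
	defer es.childMu.Unlock()
	// es.mu is NOT held here, so a child that synchronously calls back into
	// the parent (e.g. via UpdateState, which takes es.mu) cannot deadlock.
	for _, child := range es.children.Load().Values() {
		child.(*balancerWrapper).closeLocked()
	}
}
```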

if !bw.es.closed {
if !bw.isClosed {
ei.ExitIdle()
}
bw.es.childMu.Unlock()
}()
}
bw.es.updateState()
}

// updateClientConnStateLocked delivers the ClientConnState to the child
// balancer. Callers must hold the child mutex of the parent endpointsharding
// balancer.
func (bw *balancerWrapper) updateClientConnStateLocked(ccs balancer.ClientConnState) error {
return bw.child.UpdateClientConnState(ccs)
}

// closeLocked closes the child balancer. Callers must hold the child mutex of
// the parent endpointsharding balancer.
func (bw *balancerWrapper) closeLocked() {
if bw.isClosed {
return
}

Codecov: added lines #L368 - L369 in balancer/endpointsharding/endpointsharding.go were not covered by tests.
bw.child.Close()
bw.isClosed = true
}

// ParseConfig parses a child config list and returns an LB config to use with
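For orientation, a short sketch (not part of this PR) of how a wrapping policy might combine NewBalancerWithoutAutoReconnect with the ExitIdler surface exposed on ChildState. How the ChildState values are obtained from the endpointsharding picker is not shown in this diff, so the children parameter below is an assumption, as are the wrappingPolicy names.

```go
package example

import (
	"google.golang.org/grpc/balancer"
	"google.golang.org/grpc/balancer/endpointsharding"
	"google.golang.org/grpc/connectivity"
)

// wrappingPolicy owns an endpointsharding balancer that is built without
// automatic reconnection of IDLE children.
type wrappingPolicy struct {
	child balancer.Balancer
}

func newWrappingPolicy(cc balancer.ClientConn, opts balancer.BuildOptions) *wrappingPolicy {
	return &wrappingPolicy{
		child: endpointsharding.NewBalancerWithoutAutoReconnect(cc, opts),
	}
}

// reconnectIdle asks selected IDLE children to reconnect. The ChildState
// values are assumed to come from the endpointsharding picker.
func (p *wrappingPolicy) reconnectIdle(children []endpointsharding.ChildState) {
	for _, cs := range children {
		if cs.State.ConnectivityState == connectivity.Idle {
			// Only the ExitIdler surface of the child is exposed here.
			cs.Balancer.ExitIdle()
		}
	}
}
```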
103 changes: 99 additions & 4 deletions balancer/endpointsharding/endpointsharding_test.go
@@ -23,23 +23,34 @@ import (
"encoding/json"
"fmt"
"log"
"strings"
"testing"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/balancer"
"google.golang.org/grpc/balancer/endpointsharding"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/stub"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/stubserver"
"google.golang.org/grpc/internal/testutils/roundrobin"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/resolver"
"google.golang.org/grpc/resolver/manual"
"google.golang.org/grpc/serviceconfig"
"google.golang.org/grpc/status"

testgrpc "google.golang.org/grpc/interop/grpc_testing"
testpb "google.golang.org/grpc/interop/grpc_testing"
)

var (
defaultTestTimeout = time.Second * 10
defaultTestShortTimeout = time.Millisecond * 10
)

type s struct {
@@ -50,8 +61,6 @@ func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}

var gracefulSwitchPickFirst = endpointsharding.PickFirstConfig

var logger = grpclog.Component("endpoint-sharding-test")

func init() {
@@ -95,7 +104,7 @@ func (fp *fakePetiole) UpdateClientConnState(state balancer.ClientConnState) err
}

return fp.Balancer.UpdateClientConnState(balancer.ClientConnState{
BalancerConfig: gracefulSwitchPickFirst,
BalancerConfig: endpointsharding.PickFirstConfig,
ResolverState: state.ResolverState,
})
}
@@ -143,7 +152,7 @@ func (s) TestEndpointShardingBasic(t *testing.T) {
log.Fatalf("Failed to create new client: %v", err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
client := testgrpc.NewTestServiceClient(cc)
// Assert a round robin distribution between the two spun up backends. This
@@ -153,3 +162,89 @@
t.Fatalf("error in expected round robin: %v", err)
}
}

// Tests that endpointsharding doesn't automatically reconnect IDLE children.
// The test creates one endpoint with two servers and another with a single
// server. The active server in the first endpoint is stopped to make the child
// pickfirst enter IDLE state. The test verifies that the child pickfirst
// doesn't connect to the second address in the same endpoint.
func (s) TestEndpointShardingReconnectDisabled(t *testing.T) {
backend1 := stubserver.StartTestService(t, nil)
defer backend1.Stop()
backend2 := stubserver.StartTestService(t, nil)
defer backend2.Stop()
backend3 := stubserver.StartTestService(t, nil)
defer backend3.Stop()

mr := manual.NewBuilderWithScheme("e2e-test")
defer mr.Close()

name := strings.ReplaceAll(strings.ToLower(t.Name()), "/", "")
bf := stub.BalancerFuncs{
Init: func(bd *stub.BalancerData) {
bd.Data = endpointsharding.NewBalancerWithoutAutoReconnect(bd.ClientConn, bd.BuildOptions)
},
UpdateClientConnState: func(bd *stub.BalancerData, ccs balancer.ClientConnState) error {
return bd.Data.(balancer.Balancer).UpdateClientConnState(balancer.ClientConnState{
BalancerConfig: endpointsharding.PickFirstConfig,
ResolverState: ccs.ResolverState,
})
},
Close: func(bd *stub.BalancerData) {
bd.Data.(balancer.Balancer).Close()
},
}
stub.Register(name, bf)

json := fmt.Sprintf(`{"loadBalancingConfig": [{"%s":{}}]}`, name)
sc := internal.ParseServiceConfig.(func(string) *serviceconfig.ParseResult)(json)
mr.InitialState(resolver.State{
Endpoints: []resolver.Endpoint{
{Addresses: []resolver.Address{{Addr: backend1.Address}, {Addr: backend2.Address}}},
{Addresses: []resolver.Address{{Addr: backend3.Address}}},
},
ServiceConfig: sc,
})

cc, err := grpc.NewClient(mr.Scheme()+":///", grpc.WithResolvers(mr), grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
log.Fatalf("Failed to create new client: %v", err)
}
defer cc.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()
client := testgrpc.NewTestServiceClient(cc)
// Assert a round robin distribution between the first addresses of the two
// endpoints. This requires a poll and eventual consistency as the endpoint
// children do not start in state READY.
if err = roundrobin.CheckRoundRobinRPCs(ctx, client, []resolver.Address{{Addr: backend1.Address}, {Addr: backend3.Address}}); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}

// On closing the first server, the first child balancer should enter
// IDLE. Since endpointsharding is configured not to auto-reconnect, it will
// remain IDLE and will not try to connect to the second backend in the same
// endpoint.
backend1.Stop()
// CheckRoundRobinRPCs waits for all the expected backends to become reachable;
// we call it to ensure the picker no longer sends RPCs to the stopped backend.
if err = roundrobin.CheckRoundRobinRPCs(ctx, client, []resolver.Address{{Addr: backend3.Address}}); err != nil {
t.Fatalf("error in expected round robin: %v", err)
}

// Verify requests go only to backend3 for a short time.
shortCtx, cancel := context.WithTimeout(ctx, defaultTestShortTimeout)
defer cancel()
for ; shortCtx.Err() == nil; <-time.After(time.Millisecond) {
var peer peer.Peer
if _, err := client.EmptyCall(ctx, &testpb.Empty{}, grpc.Peer(&peer)); err != nil {
if status.Code(err) != codes.DeadlineExceeded {
t.Fatalf("EmptyCall() returned unexpected error %v", err)
}
break
}
if got, want := peer.Addr.String(), backend3.Address; got != want {
t.Fatalf("EmptyCall() went to unexpected backend: got %q, want %q", got, want)
}
}
}
Loading