diff --git a/grpclb/grpclb.go b/grpclb/grpclb.go
index 996d27aeb7e9..d9a1a8b6fb53 100644
--- a/grpclb/grpclb.go
+++ b/grpclb/grpclb.go
@@ -40,6 +40,7 @@ import (
 	"errors"
 	"fmt"
 	"sync"
+	"time"
 
 	"golang.org/x/net/context"
 	"google.golang.org/grpc"
@@ -93,16 +94,17 @@ type addrInfo struct {
 }
 
 type balancer struct {
-	r      naming.Resolver
-	mu     sync.Mutex
-	seq    int // a sequence number to make sure addrCh does not get stale addresses.
-	w      naming.Watcher
-	addrCh chan []grpc.Address
-	rbs    []remoteBalancerInfo
-	addrs  []*addrInfo
-	next   int
-	waitCh chan struct{}
-	done   bool
+	r        naming.Resolver
+	mu       sync.Mutex
+	seq      int // a sequence number to make sure addrCh does not get stale addresses.
+	w        naming.Watcher
+	addrCh   chan []grpc.Address
+	rbs      []remoteBalancerInfo
+	addrs    []*addrInfo
+	next     int
+	waitCh   chan struct{}
+	done     bool
+	expTimer *time.Timer
 }
 
 func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo) error {
@@ -180,14 +182,39 @@ func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo
 	return nil
 }
 
+func (b *balancer) serverListExpire(seq int) {
+	b.mu.Lock()
+	defer b.mu.Unlock()
+	// TODO: gRPC internals do not clear the connections when the server list is stale.
+	// This means RPCs will keep using the existing server list until b receives a new
+	// server list, even though the list has expired. Revisit this behavior later.
+	if b.done || seq < b.seq {
+		return
+	}
+	b.next = 0
+	b.addrs = nil
+	// Ask grpc internals to close all the corresponding connections.
+	b.addrCh <- nil
+}
+
+func convertDuration(d *lbpb.Duration) time.Duration {
+	if d == nil {
+		return 0
+	}
+	return time.Duration(d.Seconds)*time.Second + time.Duration(d.Nanos)*time.Nanosecond
+}
+
 func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
+	if l == nil {
+		return
+	}
 	servers := l.GetServers()
+	expiration := convertDuration(l.GetExpirationInterval())
 	var (
 		sl    []*addrInfo
 		addrs []grpc.Address
 	)
 	for _, s := range servers {
-		// TODO: Support ExpirationInterval
 		md := metadata.Pairs("lb-token", s.LoadBalanceToken)
 		addr := grpc.Address{
 			Addr:     fmt.Sprintf("%s:%d", s.IpAddress, s.Port),
@@ -209,11 +236,20 @@ func (b *balancer) processServerList(l *lbpb.ServerList, seq int) {
 		b.next = 0
 		b.addrs = sl
 		b.addrCh <- addrs
+		if b.expTimer != nil {
+			b.expTimer.Stop()
+			b.expTimer = nil
+		}
+		if expiration > 0 {
+			b.expTimer = time.AfterFunc(expiration, func() {
+				b.serverListExpire(seq)
+			})
+		}
 	}
 	return
 }
 
-func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool) {
+func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient, seq int) (retry bool) {
 	ctx, cancel := context.WithCancel(context.Background())
 	defer cancel()
 	stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false))
@@ -226,8 +262,6 @@ func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool)
 		b.mu.Unlock()
 		return
 	}
-	b.seq++
-	seq := b.seq
 	b.mu.Unlock()
 	initReq := &lbpb.LoadBalanceRequest{
 		LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{
@@ -260,6 +294,14 @@ func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool)
 		if err != nil {
 			break
 		}
+		b.mu.Lock()
+		if b.done || seq < b.seq {
+			b.mu.Unlock()
+			return
+		}
+		b.seq++ // tick when receiving a new list of servers.
+		seq = b.seq
+		b.mu.Unlock()
 		if serverList := reply.GetServerList(); serverList != nil {
 			b.processServerList(serverList, seq)
 		}
@@ -326,10 +368,15 @@ func (b *balancer) Start(target string, config grpc.BalancerConfig) error {
 				grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err)
 				return
 			}
+			b.mu.Lock()
+			b.seq++ // tick when getting a new balancer address
+			seq := b.seq
+			b.next = 0
+			b.mu.Unlock()
 			go func(cc *grpc.ClientConn) {
 				lbc := lbpb.NewLoadBalancerClient(cc)
 				for {
-					if retry := b.callRemoteBalancer(lbc); !retry {
+					if retry := b.callRemoteBalancer(lbc, seq); !retry {
 						cc.Close()
 						return
 					}
@@ -497,6 +544,9 @@ func (b *balancer) Close() error {
 	b.mu.Lock()
 	defer b.mu.Unlock()
 	b.done = true
+	if b.expTimer != nil {
+		b.expTimer.Stop()
+	}
 	if b.waitCh != nil {
 		close(b.waitCh)
 	}
diff --git a/grpclb/grpclb_test.go b/grpclb/grpclb_test.go
index 3215beafc33e..f034b6ba952a 100644
--- a/grpclb/grpclb_test.go
+++ b/grpclb/grpclb_test.go
@@ -162,14 +162,16 @@ func (c *serverNameCheckCreds) OverrideServerName(s string) error {
 }
 
 type remoteBalancer struct {
-	servers *lbpb.ServerList
-	done    chan struct{}
+	sls       []*lbpb.ServerList
+	intervals []time.Duration
+	done      chan struct{}
 }
 
-func newRemoteBalancer(servers *lbpb.ServerList) *remoteBalancer {
+func newRemoteBalancer(sls []*lbpb.ServerList, intervals []time.Duration) *remoteBalancer {
 	return &remoteBalancer{
-		servers: servers,
-		done:    make(chan struct{}),
+		sls:       sls,
+		intervals: intervals,
+		done:      make(chan struct{}),
 	}
 }
 
@@ -186,13 +188,16 @@ func (b *remoteBalancer) BalanceLoad(stream lbpb.LoadBalancer_BalanceLoadServer)
 	if err := stream.Send(resp); err != nil {
 		return err
 	}
-	resp = &lbpb.LoadBalanceResponse{
-		LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
-			ServerList: b.servers,
-		},
-	}
-	if err := stream.Send(resp); err != nil {
-		return err
+	for k, v := range b.sls {
+		time.Sleep(b.intervals[k])
+		resp = &lbpb.LoadBalanceResponse{
+			LoadBalanceResponseType: &lbpb.LoadBalanceResponse_ServerList{
+				ServerList: v,
+			},
+		}
+		if err := stream.Send(resp); err != nil {
+			return err
+		}
 	}
 	<-b.done
 	return nil
@@ -268,7 +273,9 @@ func TestGRPCLB(t *testing.T) {
 	sl := &lbpb.ServerList{
 		Servers: bes,
 	}
-	ls := newRemoteBalancer(sl)
+	sls := []*lbpb.ServerList{sl}
+	intervals := []time.Duration{0}
+	ls := newRemoteBalancer(sls, intervals)
 	lbpb.RegisterLoadBalancerServer(lb, ls)
 	go func() {
 		lb.Serve(lbLis)
@@ -343,7 +350,9 @@ func TestDropRequest(t *testing.T) {
 	sl := &lbpb.ServerList{
 		Servers: bes,
 	}
-	ls := newRemoteBalancer(sl)
+	sls := []*lbpb.ServerList{sl}
+	intervals := []time.Duration{0}
+	ls := newRemoteBalancer(sls, intervals)
 	lbpb.RegisterLoadBalancerServer(lb, ls)
 	go func() {
 		lb.Serve(lbLis)
@@ -413,7 +422,9 @@ func TestDropRequestFailedNonFailFast(t *testing.T) {
 	sl := &lbpb.ServerList{
 		Servers: bes,
 	}
-	ls := newRemoteBalancer(sl)
+	sls := []*lbpb.ServerList{sl}
+	intervals := []time.Duration{0}
+	ls := newRemoteBalancer(sls, intervals)
 	lbpb.RegisterLoadBalancerServer(lb, ls)
 	go func() {
 		lb.Serve(lbLis)
@@ -439,3 +450,86 @@ func TestDropRequestFailedNonFailFast(t *testing.T) {
 	}
 	cc.Close()
 }
+
+func TestServerExpiration(t *testing.T) {
+	// Start a backend.
+	beLis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Failed to listen %v", err)
+	}
+	beAddr := strings.Split(beLis.Addr().String(), ":")
+	bePort, err := strconv.Atoi(beAddr[1])
+	backends := startBackends(t, besn, beLis)
+	defer stopBackends(backends)
+
+	// Start a load balancer.
+	lbLis, err := net.Listen("tcp", "localhost:0")
+	if err != nil {
+		t.Fatalf("Failed to create the listener for the load balancer %v", err)
+	}
+	lbCreds := &serverNameCheckCreds{
+		sn: lbsn,
+	}
+	lb := grpc.NewServer(grpc.Creds(lbCreds))
+	if err != nil {
+		t.Fatalf("Failed to generate the port number %v", err)
+	}
+	be := &lbpb.Server{
+		IpAddress:        []byte(beAddr[0]),
+		Port:             int32(bePort),
+		LoadBalanceToken: lbToken,
+	}
+	var bes []*lbpb.Server
+	bes = append(bes, be)
+	exp := &lbpb.Duration{
+		Seconds: 0,
+		Nanos:   100000000, // 100ms
+	}
+	var sls []*lbpb.ServerList
+	sl := &lbpb.ServerList{
+		Servers:            bes,
+		ExpirationInterval: exp,
+	}
+	sls = append(sls, sl)
+	sl = &lbpb.ServerList{
+		Servers: bes,
+	}
+	sls = append(sls, sl)
+	var intervals []time.Duration
+	intervals = append(intervals, 0)
+	intervals = append(intervals, 500*time.Millisecond)
+	ls := newRemoteBalancer(sls, intervals)
+	lbpb.RegisterLoadBalancerServer(lb, ls)
+	go func() {
+		lb.Serve(lbLis)
+	}()
+	defer func() {
+		ls.stop()
+		lb.Stop()
+	}()
+	creds := serverNameCheckCreds{
+		expected: besn,
+	}
+	ctx, _ := context.WithTimeout(context.Background(), 10*time.Second)
+	cc, err := grpc.DialContext(ctx, besn, grpc.WithBalancer(Balancer(&testNameResolver{
+		addr: lbLis.Addr().String(),
+	})), grpc.WithBlock(), grpc.WithTransportCredentials(&creds))
+	if err != nil {
+		t.Fatalf("Failed to dial to the backend %v", err)
+	}
+	helloC := hwpb.NewGreeterClient(cc)
+	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); err != nil {
+		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
+	}
+	// Sleep long enough for the first server list to expire.
+	time.Sleep(150 * time.Millisecond)
+	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}); grpc.Code(err) != codes.Unavailable {
+		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, %s", helloC, err, codes.Unavailable)
+	}
+	// A non-failfast RPC should succeed after the second server list is received from
+	// the remote load balancer.
+	if _, err := helloC.SayHello(context.Background(), &hwpb.HelloRequest{Name: "grpc"}, grpc.FailFast(false)); err != nil {
+		t.Fatalf("%v.SayHello(_, _) = _, %v, want _, <nil>", helloC, err)
+	}
+	cc.Close()
+}