diff --git a/MAINTAINERS.md b/MAINTAINERS.md index a9ac5c8c046c..093c82b3afe8 100644 --- a/MAINTAINERS.md +++ b/MAINTAINERS.md @@ -8,20 +8,20 @@ See [CONTRIBUTING.md](https://github.com/grpc/grpc-community/blob/master/CONTRIB for general contribution guidelines. ## Maintainers (in alphabetical order) -- [canguler](https://github.com/canguler), Google Inc. -- [cesarghali](https://github.com/cesarghali), Google Inc. -- [dfawley](https://github.com/dfawley), Google Inc. -- [easwars](https://github.com/easwars), Google Inc. -- [jadekler](https://github.com/jadekler), Google Inc. -- [menghanl](https://github.com/menghanl), Google Inc. -- [srini100](https://github.com/srini100), Google Inc. +- [canguler](https://github.com/canguler), Google LLC +- [cesarghali](https://github.com/cesarghali), Google LLC +- [dfawley](https://github.com/dfawley), Google LLC +- [easwars](https://github.com/easwars), Google LLC +- [jadekler](https://github.com/jadekler), Google LLC +- [menghanl](https://github.com/menghanl), Google LLC +- [srini100](https://github.com/srini100), Google LLC ## Emeritus Maintainers (in alphabetical order) -- [adelez](https://github.com/adelez), Google Inc. -- [iamqizhao](https://github.com/iamqizhao), Google Inc. -- [jtattermusch](https://github.com/jtattermusch), Google Inc. -- [lyuxuan](https://github.com/lyuxuan), Google Inc. -- [makmukhi](https://github.com/makmukhi), Google Inc. -- [matt-kwong](https://github.com/matt-kwong), Google Inc. -- [nicolasnoble](https://github.com/nicolasnoble), Google Inc. -- [yongni](https://github.com/yongni), Google Inc. +- [adelez](https://github.com/adelez), Google LLC +- [iamqizhao](https://github.com/iamqizhao), Google LLC +- [jtattermusch](https://github.com/jtattermusch), Google LLC +- [lyuxuan](https://github.com/lyuxuan), Google LLC +- [makmukhi](https://github.com/makmukhi), Google LLC +- [matt-kwong](https://github.com/matt-kwong), Google LLC +- [nicolasnoble](https://github.com/nicolasnoble), Google LLC +- [yongni](https://github.com/yongni), Google LLC diff --git a/benchmark/benchresult/main.go b/benchmark/benchresult/main.go index ec27a830640d..2dab58cadc53 100644 --- a/benchmark/benchresult/main.go +++ b/benchmark/benchresult/main.go @@ -76,8 +76,8 @@ func compareTwoMap(m1, m2 map[string]stats.BenchResults) { changes += intChange("TotalOps", v1.Data.TotalOps, v2.Data.TotalOps) changes += intChange("SendOps", v1.Data.SendOps, v2.Data.SendOps) changes += intChange("RecvOps", v1.Data.RecvOps, v2.Data.RecvOps) - changes += intChange("Bytes/op", v1.Data.AllocedBytes, v2.Data.AllocedBytes) - changes += intChange("Allocs/op", v1.Data.Allocs, v2.Data.Allocs) + changes += floatChange("Bytes/op", v1.Data.AllocedBytes, v2.Data.AllocedBytes) + changes += floatChange("Allocs/op", v1.Data.Allocs, v2.Data.Allocs) changes += floatChange("ReqT/op", v1.Data.ReqT, v2.Data.ReqT) changes += floatChange("RespT/op", v1.Data.RespT, v2.Data.RespT) changes += timeChange("50th-Lat", v1.Data.Fiftieth, v2.Data.Fiftieth) @@ -93,9 +93,16 @@ func compareBenchmark(file1, file2 string) { compareTwoMap(createMap(file1), createMap(file2)) } -func printline(benchName, total, send, recv, allocB, allocN, reqT, respT, ltc50, ltc90, l99, lAvg interface{}) { - fmt.Printf("%-80v%12v%12v%12v%12v%12v%18v%18v%12v%12v%12v%12v\n", - benchName, total, send, recv, allocB, allocN, reqT, respT, ltc50, ltc90, l99, lAvg) +func printHeader() { + fmt.Printf("%-80s%12s%12s%12s%18s%18s%18s%18s%12s%12s%12s%12s\n", + "Name", "TotalOps", "SendOps", "RecvOps", "Bytes/op (B)", "Allocs/op (#)", + 
"RequestT", "ResponseT", "L-50", "L-90", "L-99", "L-Avg") +} + +func printline(benchName string, d stats.RunData) { + fmt.Printf("%-80s%12d%12d%12d%18.2f%18.2f%18.2f%18.2f%12v%12v%12v%12v\n", + benchName, d.TotalOps, d.SendOps, d.RecvOps, d.AllocedBytes, d.Allocs, + d.ReqT, d.RespT, d.Fiftieth, d.Ninetieth, d.NinetyNinth, d.Average) } func formatBenchmark(fileName string) { @@ -122,12 +129,9 @@ func formatBenchmark(fileName string) { wantFeatures[i] = !wantFeatures[i] } - printline("Name", "TotalOps", "SendOps", "RecvOps", "Alloc (B)", "Alloc (#)", - "RequestT", "ResponseT", "L-50", "L-90", "L-99", "L-Avg") + printHeader() for _, r := range results { - d := r.Data - printline(r.RunMode+r.Features.PrintableName(wantFeatures), d.TotalOps, d.SendOps, d.RecvOps, - d.AllocedBytes, d.Allocs, d.ReqT, d.RespT, d.Fiftieth, d.Ninetieth, d.NinetyNinth, d.Average) + printline(r.RunMode+r.Features.PrintableName(wantFeatures), r.Data) } } diff --git a/benchmark/stats/stats.go b/benchmark/stats/stats.go index 70972cb845ca..6829cd211401 100644 --- a/benchmark/stats/stats.go +++ b/benchmark/stats/stats.go @@ -200,9 +200,9 @@ type RunData struct { // run. Only makes sense for unconstrained workloads. RecvOps uint64 // AllocedBytes is the average memory allocation in bytes per operation. - AllocedBytes uint64 + AllocedBytes float64 // Allocs is the average number of memory allocations per operation. - Allocs uint64 + Allocs float64 // ReqT is the average request throughput associated with this run. ReqT float64 // RespT is the average response throughput associated with this run. @@ -275,8 +275,8 @@ func (s *Stats) EndRun(count uint64) { r := &s.results[len(s.results)-1] r.Data = RunData{ TotalOps: count, - AllocedBytes: s.stopMS.TotalAlloc - s.startMS.TotalAlloc, - Allocs: s.stopMS.Mallocs - s.startMS.Mallocs, + AllocedBytes: float64(s.stopMS.TotalAlloc-s.startMS.TotalAlloc) / float64(count), + Allocs: float64(s.stopMS.Mallocs-s.startMS.Mallocs) / float64(count), ReqT: float64(count) * float64(r.Features.ReqSizeBytes) * 8 / r.Features.BenchTime.Seconds(), RespT: float64(count) * float64(r.Features.RespSizeBytes) * 8 / r.Features.BenchTime.Seconds(), } @@ -296,8 +296,8 @@ func (s *Stats) EndUnconstrainedRun(req uint64, resp uint64) { r.Data = RunData{ SendOps: req, RecvOps: resp, - AllocedBytes: (s.stopMS.TotalAlloc - s.startMS.TotalAlloc) / ((req + resp) / 2), - Allocs: (s.stopMS.Mallocs - s.startMS.Mallocs) / ((req + resp) / 2), + AllocedBytes: float64(s.stopMS.TotalAlloc-s.startMS.TotalAlloc) / float64((req+resp)/2), + Allocs: float64(s.stopMS.Mallocs-s.startMS.Mallocs) / float64((req+resp)/2), ReqT: float64(req) * float64(r.Features.ReqSizeBytes) * 8 / r.Features.BenchTime.Seconds(), RespT: float64(resp) * float64(r.Features.RespSizeBytes) * 8 / r.Features.BenchTime.Seconds(), } diff --git a/credentials/alts/utils.go b/credentials/alts/utils.go index 4ed27c605b6b..f13aeef1c471 100644 --- a/credentials/alts/utils.go +++ b/credentials/alts/utils.go @@ -83,6 +83,9 @@ var ( // running on GCP. 
func isRunningOnGCP() bool { manufacturer, err := readManufacturer() + if os.IsNotExist(err) { + return false + } if err != nil { log.Fatalf("failure to read manufacturer information: %v", err) } diff --git a/credentials/alts/utils_test.go b/credentials/alts/utils_test.go index 3c7e43db14a0..8935c5fbec84 100644 --- a/credentials/alts/utils_test.go +++ b/credentials/alts/utils_test.go @@ -21,6 +21,7 @@ package alts import ( "context" "io" + "os" "strings" "testing" @@ -28,6 +29,34 @@ import ( "google.golang.org/grpc/peer" ) +func setupManufacturerReader(testOS string, reader func() (io.Reader, error)) func() { + tmpOS := runningOS + tmpReader := manufacturerReader + + // Set test OS and reader function. + runningOS = testOS + manufacturerReader = reader + return func() { + runningOS = tmpOS + manufacturerReader = tmpReader + } + +} + +func setup(testOS string, testReader io.Reader) func() { + reader := func() (io.Reader, error) { + return testReader, nil + } + return setupManufacturerReader(testOS, reader) +} + +func setupError(testOS string, err error) func() { + reader := func() (io.Reader, error) { + return nil, err + } + return setupManufacturerReader(testOS, reader) +} + func TestIsRunningOnGCP(t *testing.T) { for _, tc := range []struct { description string @@ -53,20 +82,12 @@ func TestIsRunningOnGCP(t *testing.T) { } } -func setup(testOS string, testReader io.Reader) func() { - tmpOS := runningOS - tmpReader := manufacturerReader - - // Set test OS and reader function. - runningOS = testOS - manufacturerReader = func() (io.Reader, error) { - return testReader, nil - } - - return func() { - runningOS = tmpOS - manufacturerReader = tmpReader +func TestIsRunningOnGCPNoProductNameFile(t *testing.T) { + reverseFunc := setupError("linux", os.ErrNotExist) + if isRunningOnGCP() { + t.Errorf("ErrNotExist: isRunningOnGCP()=true, want false") } + reverseFunc() } func TestAuthInfoFromContext(t *testing.T) { diff --git a/examples/helloworld/greeter_client/main.go b/examples/helloworld/greeter_client/main.go index 4330b9e51fc9..f908170c78e5 100644 --- a/examples/helloworld/greeter_client/main.go +++ b/examples/helloworld/greeter_client/main.go @@ -54,5 +54,5 @@ func main() { if err != nil { log.Fatalf("could not greet: %v", err) } - log.Printf("Greeting: %s", r.Message) + log.Printf("Greeting: %s", r.GetMessage()) } diff --git a/examples/helloworld/greeter_server/main.go b/examples/helloworld/greeter_server/main.go index e99fb26a3146..eac864548346 100644 --- a/examples/helloworld/greeter_server/main.go +++ b/examples/helloworld/greeter_server/main.go @@ -39,8 +39,8 @@ type server struct{} // SayHello implements helloworld.GreeterServer func (s *server) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) { - log.Printf("Received: %v", in.Name) - return &pb.HelloReply{Message: "Hello " + in.Name}, nil + log.Printf("Received: %v", in.GetName()) + return &pb.HelloReply{Message: "Hello " + in.GetName()}, nil } func main() { diff --git a/internal/transport/controlbuf.go b/internal/transport/controlbuf.go index b8e0aa4db275..ddee20b6bef2 100644 --- a/internal/transport/controlbuf.go +++ b/internal/transport/controlbuf.go @@ -107,8 +107,8 @@ func (*registerStream) isTransportResponseFrame() bool { return false } type headerFrame struct { streamID uint32 hf []hpack.HeaderField - endStream bool // Valid on server side. - initStream func(uint32) (bool, error) // Used only on the client side. + endStream bool // Valid on server side. 
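Back to the ALTS change above: isRunningOnGCP now treats a missing manufacturer file as "not running on GCP" rather than a fatal error. A standalone sketch of that pattern follows; the file path and the product-name check are assumptions for illustration, since readManufacturer itself is not part of this diff.

package main

import (
	"fmt"
	"io/ioutil"
	"os"
	"strings"
)

// productNamePath is an assumption for illustration only; the real location
// is inside readManufacturer, which this diff does not show.
const productNamePath = "/sys/class/dmi/id/product_name"

func runningOnGCP() (bool, error) {
	data, err := ioutil.ReadFile(productNamePath)
	if os.IsNotExist(err) {
		// A missing manufacturer file simply means we are not on GCP; it is
		// not a fatal condition.
		return false, nil
	}
	if err != nil {
		return false, err
	}
	name := strings.TrimSpace(string(data))
	return strings.HasPrefix(name, "Google"), nil
}

func main() {
	onGCP, err := runningOnGCP()
	if err != nil {
		fmt.Println("failed to read manufacturer:", err)
		return
	}
	fmt.Println("running on GCP:", onGCP)
}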
+ initStream func(uint32) error // Used only on the client side. onWrite func() wq *writeQuota // write quota for the stream created. cleanup *cleanupStream // Valid on the server side. @@ -637,21 +637,17 @@ func (l *loopyWriter) headerHandler(h *headerFrame) error { func (l *loopyWriter) originateStream(str *outStream) error { hdr := str.itl.dequeue().(*headerFrame) - sendPing, err := hdr.initStream(str.id) - if err != nil { + if err := hdr.initStream(str.id); err != nil { if err == ErrConnClosing { return err } // Other errors(errStreamDrain) need not close transport. return nil } - if err = l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil { + if err := l.writeHeader(str.id, hdr.endStream, hdr.hf, hdr.onWrite); err != nil { return err } l.estdStreams[str.id] = str - if sendPing { - return l.pingHandler(&ping{data: [8]byte{}}) - } return nil } diff --git a/internal/transport/http2_client.go b/internal/transport/http2_client.go index 39ab5c075a61..8b2be08733df 100644 --- a/internal/transport/http2_client.go +++ b/internal/transport/http2_client.go @@ -62,8 +62,6 @@ type http2Client struct { // goAway is closed to notify the upper layer (i.e., addrConn.transportMonitor) // that the server sent GoAway on this transport. goAway chan struct{} - // awakenKeepalive is used to wake up keepalive when after it has gone dormant. - awakenKeepalive chan struct{} framer *framer // controlBuf delivers all the control related tasks (e.g., window @@ -77,11 +75,9 @@ type http2Client struct { perRPCCreds []credentials.PerRPCCredentials - // Boolean to keep track of reading activity on transport. - // 1 is true and 0 is false. - activity uint32 // Accessed atomically. kp keepalive.ClientParameters keepaliveEnabled bool + lr lastRead statsHandler stats.Handler @@ -110,6 +106,16 @@ type http2Client struct { // goAwayReason records the http2.ErrCode and debug data received with the // GoAway frame. goAwayReason GoAwayReason + // A condition variable used to signal when the keepalive goroutine should + // go dormant. The condition for dormancy is based on the number of active + // streams and the `PermitWithoutStream` keepalive client parameter. And + // since the number of active streams is guarded by the above mutex, we use + // the same for this condition variable as well. + kpDormancyCond *sync.Cond + // A boolean to track whether the keepalive goroutine is dormant or not. + // This is checked before attempting to signal the above condition + // variable. + kpDormant bool // Fields below are for channelz metric collection. channelzID int64 // channelz unique identification number @@ -121,6 +127,16 @@ type http2Client struct { bufferPool *bufferPool } +type lastRead struct { + // Stores the Unix time in nanoseconds. This time cannot be directly embedded + // in the http2Client struct because this field is accessed using functions + // from the atomic package. And on 32-bit machines, it is the caller's + // responsibility to arrange for 64-bit alignment of this field. + timeNano int64 + // Channel to keep track of read activity on the transport. 
+ ch chan struct{} +} + func dial(ctx context.Context, fn func(context.Context, string) (net.Conn, error), addr string) (net.Conn, error) { if fn != nil { return fn(ctx, addr) @@ -232,7 +248,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne readerDone: make(chan struct{}), writerDone: make(chan struct{}), goAway: make(chan struct{}), - awakenKeepalive: make(chan struct{}, 1), framer: newFramer(conn, writeBufSize, readBufSize, maxHeaderListSize), fc: &trInFlow{limit: uint32(icwz)}, scheme: scheme, @@ -252,6 +267,7 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne onClose: onClose, keepaliveEnabled: keepaliveEnabled, bufferPool: newBufferPool(), + lr: lastRead{ch: make(chan struct{}, 1)}, } t.controlBuf = newControlBuffer(t.ctxDone) if opts.InitialWindowSize >= defaultWindowSize { @@ -264,9 +280,6 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne updateFlowControl: t.updateFlowControl, } } - // Make sure awakenKeepalive can't be written upon. - // keepalive routine will make it writable, if need be. - t.awakenKeepalive <- struct{}{} if t.statsHandler != nil { t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{ RemoteAddr: t.remoteAddr, @@ -281,6 +294,8 @@ func newHTTP2Client(connectCtx, ctx context.Context, addr TargetInfo, opts Conne t.channelzID = channelz.RegisterNormalSocket(t, opts.ChannelzParentID, fmt.Sprintf("%s -> %s", t.localAddr, t.remoteAddr)) } if t.keepaliveEnabled { + t.kpDormancyCond = sync.NewCond(&t.mu) + go t.activityMonitor() go t.keepalive() } // Start the reader goroutine for incoming message. Each transport has @@ -564,7 +579,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea hdr := &headerFrame{ hf: headerFields, endStream: false, - initStream: func(id uint32) (bool, error) { + initStream: func(id uint32) error { t.mu.Lock() if state := t.state; state != reachable { t.mu.Unlock() @@ -574,29 +589,19 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (_ *Strea err = ErrConnClosing } cleanup(err) - return false, err + return err } t.activeStreams[id] = s if channelz.IsOn() { atomic.AddInt64(&t.czData.streamsStarted, 1) atomic.StoreInt64(&t.czData.lastStreamCreatedTime, time.Now().UnixNano()) } - var sendPing bool - // If the number of active streams change from 0 to 1, then check if keepalive - // has gone dormant. If so, wake it up. - if len(t.activeStreams) == 1 && t.keepaliveEnabled { - select { - case t.awakenKeepalive <- struct{}{}: - sendPing = true - // Fill the awakenKeepalive channel again as this channel must be - // kept non-writable except at the point that the keepalive() - // goroutine is waiting either to be awaken or shutdown. - t.awakenKeepalive <- struct{}{} - default: - } + // If the keepalive goroutine has gone dormant, wake it up. + if t.kpDormant { + t.kpDormancyCond.Signal() } t.mu.Unlock() - return sendPing, nil + return nil }, onOrphaned: cleanup, wq: s.wq, @@ -778,6 +783,11 @@ func (t *http2Client) Close() error { t.state = closing streams := t.activeStreams t.activeStreams = nil + if t.kpDormant { + // If the keepalive goroutine is blocked on this condition variable, we + // should unblock it so that the goroutine eventually exits. + t.kpDormancyCond.Signal() + } t.mu.Unlock() t.controlBuf.finish() t.cancel() @@ -1233,7 +1243,7 @@ func (t *http2Client) reader() { } t.conn.SetReadDeadline(time.Time{}) // reset deadline once we get the settings frame (we didn't time out, yay!) 
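The lastRead struct added above keeps the read timestamp in a field that is only touched through the sync/atomic package; as its comment notes, 64-bit atomics must be 64-bit aligned on 32-bit platforms, which is why the int64 comes first. A small sketch of the store/load pattern, with simplified names:

package main

import (
	"fmt"
	"sync/atomic"
	"time"
)

// readTracker mirrors the idea behind lastRead: the atomically accessed int64
// is kept first so it is easy to keep 64-bit aligned, as sync/atomic requires
// on 32-bit platforms.
type readTracker struct {
	timeNano int64
	ch       chan struct{}
}

func main() {
	rt := &readTracker{ch: make(chan struct{}, 1)}

	// The monitor goroutine records activity whenever the channel fires.
	go func() {
		for range rt.ch {
			atomic.StoreInt64(&rt.timeNano, time.Now().UnixNano())
		}
	}()

	rt.ch <- struct{}{} // a read happened on the transport
	time.Sleep(10 * time.Millisecond)

	last := time.Unix(0, atomic.LoadInt64(&rt.timeNano))
	fmt.Println("last read activity at:", last)
}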
if t.keepaliveEnabled { - atomic.CompareAndSwapUint32(&t.activity, 0, 1) + t.lr.ch <- struct{}{} } sf, ok := frame.(*http2.SettingsFrame) if !ok { @@ -1248,7 +1258,10 @@ func (t *http2Client) reader() { t.controlBuf.throttle() frame, err := t.framer.fr.ReadFrame() if t.keepaliveEnabled { - atomic.CompareAndSwapUint32(&t.activity, 0, 1) + select { + case t.lr.ch <- struct{}{}: + default: + } } if err != nil { // Abort an active stream if the http2.Framer returns a @@ -1292,56 +1305,97 @@ func (t *http2Client) reader() { } } -// keepalive running in a separate goroutune makes sure the connection is alive by sending pings. +// activityMonitory reads from the activity channel (which is written to, when +// there is a read), and updates the lastRead.timeNano atomic. +func (t *http2Client) activityMonitor() { + for { + select { + case <-t.lr.ch: + atomic.StoreInt64(&t.lr.timeNano, time.Now().UnixNano()) + case <-t.ctx.Done(): + return + } + } +} + +func minTime(a, b time.Duration) time.Duration { + if a < b { + return a + } + return b +} + +// keepalive running in a separate goroutune makes sure the connection is alive +// by sending pings. func (t *http2Client) keepalive() { p := &ping{data: [8]byte{}} + // True iff a ping has been sent, and no data has been received since then. + outstandingPing := false + // Amount of time remaining before which we should receive an ACK for the + // last sent ping. + timeoutLeft := time.Duration(0) + // UnixNanos recorded before we go block on the timer. This is required to + // check for read activity since then. + prevNano := time.Now().UTC().UnixNano() timer := time.NewTimer(t.kp.Time) for { select { case <-timer.C: - if atomic.CompareAndSwapUint32(&t.activity, 1, 0) { - timer.Reset(t.kp.Time) + if lastRead := atomic.LoadInt64(&t.lr.timeNano); lastRead > prevNano { + // Read activity since the last time we were here. + outstandingPing = false + prevNano = time.Now().UTC().UnixNano() + // Timer should fire at kp.Time seconds from lastRead time. + timer.Reset(time.Duration(lastRead) + t.kp.Time - time.Duration(prevNano)) continue } - // Check if keepalive should go dormant. + if outstandingPing && timeoutLeft <= 0 { + t.Close() + return + } t.mu.Lock() - if len(t.activeStreams) < 1 && !t.kp.PermitWithoutStream { - // Make awakenKeepalive writable. - <-t.awakenKeepalive - t.mu.Unlock() - select { - case <-t.awakenKeepalive: - // If the control gets here a ping has been sent - // need to reset the timer with keepalive.Timeout. - case <-t.ctx.Done(): - return - } - } else { + if t.state == closing { + // If the transport is closing, we should exit from the + // keepalive goroutine here. If not, we could have a race + // between the call to Signal() from Close() and the call to + // Wait() here, whereby the keepalive goroutine ends up + // blocking on the condition variable which will never be + // signalled again. t.mu.Unlock() + return + } + if len(t.activeStreams) < 1 && !t.kp.PermitWithoutStream { + // If a ping was sent out previously (because there were active + // streams at that point) which wasn't acked and it's timeout + // hadn't fired, but we got here and are about to go dormant, + // we should make sure that we unconditionally send a ping once + // we awaken. + outstandingPing = false + t.kpDormant = true + t.kpDormancyCond.Wait() + } + t.kpDormant = false + t.mu.Unlock() + + // We get here either because we were dormant and a new stream was + // created which unblocked the Wait() call, or because the + // keepalive timer expired. 
In both cases, we need to send a ping. + if !outstandingPing { if channelz.IsOn() { atomic.AddInt64(&t.czData.kpCount, 1) } - // Send ping. t.controlBuf.put(p) + timeoutLeft = t.kp.Timeout + outstandingPing = true } - - // By the time control gets here a ping has been sent one way or the other. - timer.Reset(t.kp.Timeout) - select { - case <-timer.C: - if atomic.CompareAndSwapUint32(&t.activity, 1, 0) { - timer.Reset(t.kp.Time) - continue - } - infof("transport: closing client transport due to idleness.") - t.Close() - return - case <-t.ctx.Done(): - if !timer.Stop() { - <-timer.C - } - return - } + // The amount of time to sleep here is the minimum of kp.Time and + // timeoutLeft. This will ensure that we wait only for kp.Time + // before sending out the next ping (for cases where the ping is + // acked). + sleepDuration := minTime(t.kp.Time, timeoutLeft) + timeoutLeft -= sleepDuration + prevNano = time.Now().UTC().UnixNano() + timer.Reset(sleepDuration) case <-t.ctx.Done(): if !timer.Stop() { <-timer.C diff --git a/internal/transport/transport_test.go b/internal/transport/transport_test.go index a3a34319300f..341e56009ce8 100644 --- a/internal/transport/transport_test.go +++ b/internal/transport/transport_test.go @@ -462,6 +462,8 @@ func setUpWithNoPingServer(t *testing.T, copts ConnectOptions, done chan net.Con // TestInflightStreamClosing ensures that closing in-flight stream // sends status error to concurrent stream reader. func TestInflightStreamClosing(t *testing.T) { + t.Parallel() + serverConfig := &ServerConfig{} server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, ConnectOptions{}) defer cancel() @@ -501,6 +503,8 @@ func TestInflightStreamClosing(t *testing.T) { // An idle client is one who doesn't make any RPC calls for a duration of // MaxConnectionIdle time. func TestMaxConnectionIdle(t *testing.T) { + t.Parallel() + serverConfig := &ServerConfig{ KeepaliveParams: keepalive.ServerParameters{ MaxConnectionIdle: 2 * time.Second, @@ -529,6 +533,8 @@ func TestMaxConnectionIdle(t *testing.T) { // TestMaxConenctionIdleNegative tests that a server will not send GoAway to a non-idle(busy) client. func TestMaxConnectionIdleNegative(t *testing.T) { + t.Parallel() + serverConfig := &ServerConfig{ KeepaliveParams: keepalive.ServerParameters{ MaxConnectionIdle: 2 * time.Second, @@ -556,6 +562,8 @@ func TestMaxConnectionIdleNegative(t *testing.T) { // TestMaxConnectionAge tests that a server will send GoAway after a duration of MaxConnectionAge. func TestMaxConnectionAge(t *testing.T) { + t.Parallel() + serverConfig := &ServerConfig{ KeepaliveParams: keepalive.ServerParameters{ MaxConnectionAge: 2 * time.Second, @@ -588,6 +596,8 @@ const ( // TestKeepaliveServer tests that a server closes connection with a client that doesn't respond to keepalive pings. func TestKeepaliveServer(t *testing.T) { + t.Parallel() + serverConfig := &ServerConfig{ KeepaliveParams: keepalive.ServerParameters{ Time: 2 * time.Second, @@ -632,6 +642,8 @@ func TestKeepaliveServer(t *testing.T) { // TestKeepaliveServerNegative tests that a server doesn't close connection with a client that responds to keepalive pings. 
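The keepalive rewrite above replaces the awakenKeepalive channel with a condition variable (kpDormancyCond/kpDormant) guarded by the transport mutex. A condensed, self-contained sketch of that dormancy pattern, with names simplified and the keepalive timer reduced to a sleep:

package main

import (
	"fmt"
	"sync"
	"time"
)

type fakeTransport struct {
	mu             sync.Mutex
	kpDormancyCond *sync.Cond
	kpDormant      bool
	activeStreams  int
	closing        bool
}

func newFakeTransport() *fakeTransport {
	t := &fakeTransport{}
	t.kpDormancyCond = sync.NewCond(&t.mu)
	return t
}

func (t *fakeTransport) keepalive() {
	for {
		t.mu.Lock()
		// Go dormant while there are no active streams; Wait releases the
		// mutex while blocked and reacquires it before returning.
		for t.activeStreams == 0 && !t.closing {
			t.kpDormant = true
			t.kpDormancyCond.Wait()
		}
		t.kpDormant = false
		if t.closing {
			t.mu.Unlock()
			return
		}
		t.mu.Unlock()

		fmt.Println("sending keepalive ping")
		time.Sleep(50 * time.Millisecond) // stand-in for the kp.Time timer
	}
}

// newStream wakes the keepalive goroutine if it has gone dormant, the same
// way initStream signals kpDormancyCond in the patch.
func (t *fakeTransport) newStream() {
	t.mu.Lock()
	t.activeStreams++
	if t.kpDormant {
		t.kpDormancyCond.Signal()
	}
	t.mu.Unlock()
}

// close also signals the condition variable, so a dormant keepalive goroutine
// can observe closing and exit instead of blocking forever.
func (t *fakeTransport) close() {
	t.mu.Lock()
	t.closing = true
	if t.kpDormant {
		t.kpDormancyCond.Signal()
	}
	t.mu.Unlock()
}

func main() {
	t := newFakeTransport()
	go t.keepalive()

	time.Sleep(20 * time.Millisecond) // keepalive is dormant: no streams yet
	t.newStream()                     // wakes it up; pings start
	time.Sleep(120 * time.Millisecond)
	t.close()
	time.Sleep(20 * time.Millisecond)
}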
func TestKeepaliveServerNegative(t *testing.T) { + t.Parallel() + serverConfig := &ServerConfig{ KeepaliveParams: keepalive.ServerParameters{ Time: 2 * time.Second, @@ -653,6 +665,8 @@ func TestKeepaliveServerNegative(t *testing.T) { } func TestKeepaliveClientClosesIdleTransport(t *testing.T) { + t.Parallel() + done := make(chan net.Conn, 1) tr, cancel := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{ Time: 2 * time.Second, // Keepalive time = 2 sec. @@ -677,6 +691,8 @@ func TestKeepaliveClientClosesIdleTransport(t *testing.T) { } func TestKeepaliveClientStaysHealthyOnIdleTransport(t *testing.T) { + t.Parallel() + done := make(chan net.Conn, 1) tr, cancel := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{ Time: 2 * time.Second, // Keepalive time = 2 sec. @@ -700,6 +716,8 @@ func TestKeepaliveClientStaysHealthyOnIdleTransport(t *testing.T) { } func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { + t.Parallel() + done := make(chan net.Conn, 1) tr, cancel := setUpWithNoPingServer(t, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{ Time: 2 * time.Second, // Keepalive time = 2 sec. @@ -728,6 +746,8 @@ func TestKeepaliveClientClosesWithActiveStreams(t *testing.T) { } func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { + t.Parallel() + s, tr, cancel := setUpWithOptions(t, 0, &ServerConfig{MaxStreams: math.MaxUint32}, normal, ConnectOptions{KeepaliveParams: keepalive.ClientParameters{ Time: 2 * time.Second, // Keepalive time = 2 sec. Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. @@ -747,16 +767,18 @@ func TestKeepaliveClientStaysHealthyWithResponsiveServer(t *testing.T) { } func TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { + t.Parallel() + serverConfig := &ServerConfig{ KeepalivePolicy: keepalive.EnforcementPolicy{ - MinTime: 2 * time.Second, + MinTime: 5 * time.Second, }, } clientOptions := ConnectOptions{ KeepaliveParams: keepalive.ClientParameters{ - Time: 50 * time.Millisecond, - Timeout: 1 * time.Second, - PermitWithoutStream: true, + Time: 2 * time.Second, // Keepalive time = 2 sec. + Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. + PermitWithoutStream: true, // Run keepalive even with no RPCs. }, } server, client, cancel := setUpWithOptions(t, 0, serverConfig, normal, clientOptions) @@ -782,15 +804,17 @@ func TestKeepaliveServerEnforcementWithAbusiveClientNoRPC(t *testing.T) { } func TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { + t.Parallel() + serverConfig := &ServerConfig{ KeepalivePolicy: keepalive.EnforcementPolicy{ - MinTime: 2 * time.Second, + MinTime: 5 * time.Second, }, } clientOptions := ConnectOptions{ KeepaliveParams: keepalive.ClientParameters{ - Time: 50 * time.Millisecond, - Timeout: 1 * time.Second, + Time: 2 * time.Second, // Keepalive time = 2 sec. + Timeout: 1 * time.Second, // Keepalive timeout = 1 sec. 
}, } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) @@ -819,16 +843,18 @@ func TestKeepaliveServerEnforcementWithAbusiveClientWithRPC(t *testing.T) { } func TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) { + t.Parallel() + serverConfig := &ServerConfig{ KeepalivePolicy: keepalive.EnforcementPolicy{ - MinTime: 100 * time.Millisecond, + MinTime: 1 * time.Second, PermitWithoutStream: true, }, } clientOptions := ConnectOptions{ KeepaliveParams: keepalive.ClientParameters{ - Time: 101 * time.Millisecond, - Timeout: 1 * time.Second, + Time: 2 * time.Second, + Timeout: 5 * time.Second, PermitWithoutStream: true, }, } @@ -838,7 +864,7 @@ func TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) { defer client.Close() // Give keepalive enough time. - time.Sleep(3 * time.Second) + time.Sleep(10 * time.Second) // Assert that connection is healthy. client.mu.Lock() defer client.mu.Unlock() @@ -848,15 +874,17 @@ func TestKeepaliveServerEnforcementWithObeyingClientNoRPC(t *testing.T) { } func TestKeepaliveServerEnforcementWithObeyingClientWithRPC(t *testing.T) { + t.Parallel() + serverConfig := &ServerConfig{ KeepalivePolicy: keepalive.EnforcementPolicy{ - MinTime: 100 * time.Millisecond, + MinTime: 1 * time.Second, }, } clientOptions := ConnectOptions{ KeepaliveParams: keepalive.ClientParameters{ - Time: 101 * time.Millisecond, - Timeout: 1 * time.Second, + Time: 2 * time.Second, + Timeout: 5 * time.Second, }, } server, client, cancel := setUpWithOptions(t, 0, serverConfig, suspended, clientOptions) @@ -869,7 +897,7 @@ func TestKeepaliveServerEnforcementWithObeyingClientWithRPC(t *testing.T) { } // Give keepalive enough time. - time.Sleep(3 * time.Second) + time.Sleep(10 * time.Second) // Assert that connection is healthy. 
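For reference, the second-scale values used by the adjusted enforcement tests map onto the public API roughly as follows; this is only a sketch, not part of the patch, and the target address is a placeholder:

package main

import (
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/keepalive"
)

// A client configured the way the adjusted tests are: ping every 2s, expect
// an ack within 1s, and keep pinging even when no RPCs are in flight.
func dialWithKeepalive() (*grpc.ClientConn, error) {
	return grpc.Dial("localhost:50051",
		grpc.WithInsecure(),
		grpc.WithKeepaliveParams(keepalive.ClientParameters{
			Time:                2 * time.Second,
			Timeout:             1 * time.Second,
			PermitWithoutStream: true,
		}),
	)
}

// A server that, like the updated abusive-client test config, treats pings
// more frequent than every 5s as abusive.
func newServerWithPolicy() *grpc.Server {
	return grpc.NewServer(
		grpc.KeepaliveEnforcementPolicy(keepalive.EnforcementPolicy{
			MinTime: 5 * time.Second,
		}),
	)
}

func main() {
	_ = newServerWithPolicy()
	if cc, err := dialWithKeepalive(); err == nil {
		cc.Close()
	}
}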
client.mu.Lock() defer client.mu.Unlock() @@ -951,18 +979,35 @@ func performOneRPC(ct ClientTransport) { func TestClientMix(t *testing.T) { s, ct, cancel := setUp(t, 0, math.MaxUint32, normal) defer cancel() + done := make(chan struct{}) + go func(s *server) { - time.Sleep(5 * time.Second) + select { + case <-done: + case <-time.After(5 * time.Second): + } s.stop() }(s) + go func(ct ClientTransport) { - <-ct.Error() + select { + case <-done: + case <-ct.Error(): + } ct.Close() }(ct) + + var wg sync.WaitGroup for i := 0; i < 1000; i++ { time.Sleep(10 * time.Millisecond) - go performOneRPC(ct) + wg.Add(1) + go func() { + performOneRPC(ct) + wg.Done() + }() } + wg.Wait() + close(done) } func TestLargeMessage(t *testing.T) { diff --git a/vet.sh b/vet.sh index 661e1e1de9b6..2bdfbc8b87e5 100755 --- a/vet.sh +++ b/vet.sh @@ -111,6 +111,7 @@ google.golang.org/grpc/balancer.go:SA1019 google.golang.org/grpc/balancer/grpclb/grpclb_remote_balancer.go:SA1019 google.golang.org/grpc/balancer/roundrobin/roundrobin_test.go:SA1019 google.golang.org/grpc/xds/internal/balancer/edsbalancer/balancergroup.go:SA1019 +google.golang.org/grpc/xds/internal/resolver/xds_resolver.go:SA1019 google.golang.org/grpc/xds/internal/balancer/xds.go:SA1019 google.golang.org/grpc/xds/internal/balancer/xds_client.go:SA1019 google.golang.org/grpc/balancer_conn_wrappers.go:SA1019 diff --git a/xds/experimental/xds_experimental.go b/xds/experimental/xds_experimental.go index 7477ea3d94c5..ff722ad54909 100644 --- a/xds/experimental/xds_experimental.go +++ b/xds/experimental/xds_experimental.go @@ -24,9 +24,12 @@ package experimental import ( "google.golang.org/grpc/balancer" + "google.golang.org/grpc/resolver" xdsbalancer "google.golang.org/grpc/xds/internal/balancer" + xdsresolver "google.golang.org/grpc/xds/internal/resolver" ) func init() { + resolver.Register(xdsresolver.NewBuilder()) balancer.Register(xdsbalancer.NewBalancerBuilder()) } diff --git a/xds/internal/balancer/xds.go b/xds/internal/balancer/xds.go index 4da30ae0025c..3f80d4111604 100644 --- a/xds/internal/balancer/xds.go +++ b/xds/internal/balancer/xds.go @@ -33,6 +33,7 @@ import ( "google.golang.org/grpc/grpclog" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" + xdsinternal "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/balancer/edsbalancer" "google.golang.org/grpc/xds/internal/balancer/lrs" cdspb "google.golang.org/grpc/xds/internal/proto/envoy/api/v2/cds" @@ -89,7 +90,7 @@ func (b *xdsBalancerBuilder) Name() string { } func (b *xdsBalancerBuilder) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { - var cfg xdsConfig + var cfg xdsinternal.LBConfig if err := json.Unmarshal(c, &cfg); err != nil { return nil, fmt.Errorf("unable to unmarshal balancer config %s into xds config", string(c)) } @@ -130,15 +131,15 @@ type xdsBalancer struct { timer *time.Timer noSubConnAlert <-chan struct{} - client *client // may change when passed a different service config - config *xdsConfig // may change when passed a different service config + client *client // may change when passed a different service config + config *xdsinternal.LBConfig // may change when passed a different service config xdsLB edsBalancerInterface fallbackLB balancer.Balancer fallbackInitData *resolver.State // may change when HandleResolved address is called loadStore lrs.Store } -func (x *xdsBalancer) startNewXDSClient(u *xdsConfig) { +func (x *xdsBalancer) startNewXDSClient(u *xdsinternal.LBConfig) { // If the xdsBalancer is in 
startup stage, then we need to apply the startup timeout for the first // xdsClient to get a response from the traffic director. if x.startup { @@ -237,7 +238,7 @@ func (x *xdsBalancer) handleGRPCUpdate(update interface{}) { } } case *balancer.ClientConnState: - cfg, _ := u.BalancerConfig.(*xdsConfig) + cfg, _ := u.BalancerConfig.(*xdsinternal.LBConfig) if cfg == nil { // service config parsing failed. should never happen. return @@ -497,16 +498,16 @@ func (x *xdsBalancer) cancelFallbackAndSwitchEDSBalancerIfNecessary() { } } -func (x *xdsBalancer) buildFallBackBalancer(c *xdsConfig) { +func (x *xdsBalancer) buildFallBackBalancer(c *xdsinternal.LBConfig) { if c.FallBackPolicy == nil { - x.buildFallBackBalancer(&xdsConfig{ - FallBackPolicy: &loadBalancingConfig{ + x.buildFallBackBalancer(&xdsinternal.LBConfig{ + FallBackPolicy: &xdsinternal.LoadBalancingConfig{ Name: "round_robin", }, }) return } - // builder will always be non-nil, since when parse JSON into xdsConfig, we check whether the specified + // builder will always be non-nil, since when parse JSON into xdsinternal.LBConfig, we check whether the specified // balancer is registered or not. builder := balancer.Get(c.FallBackPolicy.Name) @@ -566,77 +567,3 @@ func createDrainedTimer() *time.Timer { } return timer } - -type xdsConfig struct { - serviceconfig.LoadBalancingConfig - BalancerName string - ChildPolicy *loadBalancingConfig - FallBackPolicy *loadBalancingConfig -} - -// When unmarshalling json to xdsConfig, we iterate through the childPolicy/fallbackPolicy lists -// and select the first LB policy which has been registered to be stored in the returned xdsConfig. -func (p *xdsConfig) UnmarshalJSON(data []byte) error { - var val map[string]json.RawMessage - if err := json.Unmarshal(data, &val); err != nil { - return err - } - for k, v := range val { - switch k { - case "balancerName": - if err := json.Unmarshal(v, &p.BalancerName); err != nil { - return err - } - case "childPolicy": - var lbcfgs []*loadBalancingConfig - if err := json.Unmarshal(v, &lbcfgs); err != nil { - return err - } - for _, lbcfg := range lbcfgs { - if balancer.Get(lbcfg.Name) != nil { - p.ChildPolicy = lbcfg - break - } - } - case "fallbackPolicy": - var lbcfgs []*loadBalancingConfig - if err := json.Unmarshal(v, &lbcfgs); err != nil { - return err - } - for _, lbcfg := range lbcfgs { - if balancer.Get(lbcfg.Name) != nil { - p.FallBackPolicy = lbcfg - break - } - } - } - } - return nil -} - -func (p *xdsConfig) MarshalJSON() ([]byte, error) { - return nil, nil -} - -type loadBalancingConfig struct { - Name string - Config json.RawMessage -} - -func (l *loadBalancingConfig) MarshalJSON() ([]byte, error) { - m := make(map[string]json.RawMessage) - m[l.Name] = l.Config - return json.Marshal(m) -} - -func (l *loadBalancingConfig) UnmarshalJSON(data []byte) error { - var cfg map[string]json.RawMessage - if err := json.Unmarshal(data, &cfg); err != nil { - return err - } - for name, config := range cfg { - l.Name = name - l.Config = config - } - return nil -} diff --git a/xds/internal/balancer/xds_lrs_test.go b/xds/internal/balancer/xds_lrs_test.go index 6fe852769ad3..77db3708247c 100644 --- a/xds/internal/balancer/xds_lrs_test.go +++ b/xds/internal/balancer/xds_lrs_test.go @@ -33,6 +33,7 @@ import ( "google.golang.org/grpc/resolver" "google.golang.org/grpc/status" "google.golang.org/grpc/xds/internal" + xdsinternal "google.golang.org/grpc/xds/internal" basepb "google.golang.org/grpc/xds/internal/proto/envoy/api/v2/core/base" lrsgrpc 
"google.golang.org/grpc/xds/internal/proto/envoy/service/load_stats/v2/lrs" lrspb "google.golang.org/grpc/xds/internal/proto/envoy/service/load_stats/v2/lrs" @@ -112,9 +113,9 @@ func (s) TestXdsLoadReporting(t *testing.T) { Nanos: intervalNano, } - cfg := &xdsConfig{ + cfg := &xdsinternal.LBConfig{ BalancerName: addr, - ChildPolicy: &loadBalancingConfig{Name: fakeBalancerA}, // Set this to skip cds. + ChildPolicy: &xdsinternal.LoadBalancingConfig{Name: fakeBalancerA}, // Set this to skip cds. } lb.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: cfg}) td.sendResp(&response{resp: testEDSRespWithoutEndpoints}) diff --git a/xds/internal/balancer/xds_test.go b/xds/internal/balancer/xds_test.go index 85bd56a16d13..5b344aac7127 100644 --- a/xds/internal/balancer/xds_test.go +++ b/xds/internal/balancer/xds_test.go @@ -31,6 +31,7 @@ import ( "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/leakcheck" "google.golang.org/grpc/resolver" + xdsinternal "google.golang.org/grpc/xds/internal" "google.golang.org/grpc/xds/internal/balancer/lrs" discoverypb "google.golang.org/grpc/xds/internal/proto/envoy/api/v2/discovery" edspb "google.golang.org/grpc/xds/internal/proto/envoy/api/v2/eds" @@ -62,10 +63,10 @@ const ( var ( testBalancerNameFooBar = "foo.bar" - testLBConfigFooBar = &xdsConfig{ + testLBConfigFooBar = &xdsinternal.LBConfig{ BalancerName: testBalancerNameFooBar, - ChildPolicy: &loadBalancingConfig{Name: fakeBalancerA}, - FallBackPolicy: &loadBalancingConfig{Name: fakeBalancerA}, + ChildPolicy: &xdsinternal.LoadBalancingConfig{Name: fakeBalancerA}, + FallBackPolicy: &xdsinternal.LoadBalancingConfig{Name: fakeBalancerA}, } specialAddrForBalancerA = resolver.Address{Addr: "this.is.balancer.A"} @@ -178,8 +179,8 @@ type scStateChange struct { type fakeEDSBalancer struct { cc balancer.ClientConn edsChan chan *edspb.ClusterLoadAssignment - childPolicy chan *loadBalancingConfig - fallbackPolicy chan *loadBalancingConfig + childPolicy chan *xdsinternal.LoadBalancingConfig + fallbackPolicy chan *xdsinternal.LoadBalancingConfig subconnStateChange chan *scStateChange loadStore lrs.Store } @@ -199,7 +200,7 @@ func (f *fakeEDSBalancer) HandleEDSResponse(edsResp *edspb.ClusterLoadAssignment } func (f *fakeEDSBalancer) HandleChildPolicy(name string, config json.RawMessage) { - f.childPolicy <- &loadBalancingConfig{ + f.childPolicy <- &xdsinternal.LoadBalancingConfig{ Name: name, Config: config, } @@ -209,8 +210,8 @@ func newFakeEDSBalancer(cc balancer.ClientConn, loadStore lrs.Store) edsBalancer lb := &fakeEDSBalancer{ cc: cc, edsChan: make(chan *edspb.ClusterLoadAssignment, 10), - childPolicy: make(chan *loadBalancingConfig, 10), - fallbackPolicy: make(chan *loadBalancingConfig, 10), + childPolicy: make(chan *xdsinternal.LoadBalancingConfig, 10), + fallbackPolicy: make(chan *xdsinternal.LoadBalancingConfig, 10), subconnStateChange: make(chan *scStateChange, 10), loadStore: loadStore, } @@ -308,10 +309,10 @@ func (s) TestXdsBalanceHandleBalancerConfigBalancerNameUpdate(t *testing.T) { for i := 0; i < 2; i++ { addr, td, _, cleanup := setupServer(t) cleanups = append(cleanups, cleanup) - workingLBConfig := &xdsConfig{ + workingLBConfig := &xdsinternal.LBConfig{ BalancerName: addr, - ChildPolicy: &loadBalancingConfig{Name: fakeBalancerA}, - FallBackPolicy: &loadBalancingConfig{Name: fakeBalancerA}, + ChildPolicy: &xdsinternal.LoadBalancingConfig{Name: fakeBalancerA}, + FallBackPolicy: &xdsinternal.LoadBalancingConfig{Name: fakeBalancerA}, } 
lb.UpdateClientConnState(balancer.ClientConnState{ ResolverState: resolver.State{Addresses: addrs}, @@ -364,39 +365,39 @@ func (s) TestXdsBalanceHandleBalancerConfigChildPolicyUpdate(t *testing.T) { } }() for _, test := range []struct { - cfg *xdsConfig + cfg *xdsinternal.LBConfig responseToSend *discoverypb.DiscoveryResponse - expectedChildPolicy *loadBalancingConfig + expectedChildPolicy *xdsinternal.LoadBalancingConfig }{ { - cfg: &xdsConfig{ - ChildPolicy: &loadBalancingConfig{ + cfg: &xdsinternal.LBConfig{ + ChildPolicy: &xdsinternal.LoadBalancingConfig{ Name: fakeBalancerA, Config: json.RawMessage("{}"), }, }, responseToSend: testEDSRespWithoutEndpoints, - expectedChildPolicy: &loadBalancingConfig{ + expectedChildPolicy: &xdsinternal.LoadBalancingConfig{ Name: string(fakeBalancerA), Config: json.RawMessage(`{}`), }, }, { - cfg: &xdsConfig{ - ChildPolicy: &loadBalancingConfig{ + cfg: &xdsinternal.LBConfig{ + ChildPolicy: &xdsinternal.LoadBalancingConfig{ Name: fakeBalancerB, Config: json.RawMessage("{}"), }, }, - expectedChildPolicy: &loadBalancingConfig{ + expectedChildPolicy: &xdsinternal.LoadBalancingConfig{ Name: string(fakeBalancerB), Config: json.RawMessage(`{}`), }, }, { - cfg: &xdsConfig{}, + cfg: &xdsinternal.LBConfig{}, responseToSend: testCDSResp, - expectedChildPolicy: &loadBalancingConfig{ + expectedChildPolicy: &xdsinternal.LoadBalancingConfig{ Name: "ROUND_ROBIN", }, }, @@ -449,16 +450,16 @@ func (s) TestXdsBalanceHandleBalancerConfigFallBackUpdate(t *testing.T) { addr, td, _, cleanup := setupServer(t) - cfg := xdsConfig{ + cfg := xdsinternal.LBConfig{ BalancerName: addr, - ChildPolicy: &loadBalancingConfig{Name: fakeBalancerA}, - FallBackPolicy: &loadBalancingConfig{Name: fakeBalancerA}, + ChildPolicy: &xdsinternal.LoadBalancingConfig{Name: fakeBalancerA}, + FallBackPolicy: &xdsinternal.LoadBalancingConfig{Name: fakeBalancerA}, } lb.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: &cfg}) addrs := []resolver.Address{{Addr: "1.1.1.1:10001"}, {Addr: "2.2.2.2:10002"}, {Addr: "3.3.3.3:10003"}} cfg2 := cfg - cfg2.FallBackPolicy = &loadBalancingConfig{Name: fakeBalancerB} + cfg2.FallBackPolicy = &xdsinternal.LoadBalancingConfig{Name: fakeBalancerB} lb.UpdateClientConnState(balancer.ClientConnState{ ResolverState: resolver.State{Addresses: addrs}, BalancerConfig: &cfg2, @@ -490,7 +491,7 @@ func (s) TestXdsBalanceHandleBalancerConfigFallBackUpdate(t *testing.T) { } cfg3 := cfg - cfg3.FallBackPolicy = &loadBalancingConfig{Name: fakeBalancerA} + cfg3.FallBackPolicy = &xdsinternal.LoadBalancingConfig{Name: fakeBalancerA} lb.UpdateClientConnState(balancer.ClientConnState{ ResolverState: resolver.State{Addresses: addrs}, BalancerConfig: &cfg3, @@ -524,10 +525,10 @@ func (s) TestXdsBalancerHandlerSubConnStateChange(t *testing.T) { addr, td, _, cleanup := setupServer(t) defer cleanup() - cfg := &xdsConfig{ + cfg := &xdsinternal.LBConfig{ BalancerName: addr, - ChildPolicy: &loadBalancingConfig{Name: fakeBalancerA}, - FallBackPolicy: &loadBalancingConfig{Name: fakeBalancerA}, + ChildPolicy: &xdsinternal.LoadBalancingConfig{Name: fakeBalancerA}, + FallBackPolicy: &xdsinternal.LoadBalancingConfig{Name: fakeBalancerA}, } lb.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: cfg}) @@ -602,10 +603,10 @@ func (s) TestXdsBalancerFallBackSignalFromEdsBalancer(t *testing.T) { addr, td, _, cleanup := setupServer(t) defer cleanup() - cfg := &xdsConfig{ + cfg := &xdsinternal.LBConfig{ BalancerName: addr, - ChildPolicy: &loadBalancingConfig{Name: fakeBalancerA}, - 
FallBackPolicy: &loadBalancingConfig{Name: fakeBalancerA}, + ChildPolicy: &xdsinternal.LoadBalancingConfig{Name: fakeBalancerA}, + FallBackPolicy: &xdsinternal.LoadBalancingConfig{Name: fakeBalancerA}, } lb.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: cfg}) @@ -673,12 +674,12 @@ func (s) TestXdsBalancerConfigParsingSelectingLBPolicy(t *testing.T) { if err != nil { t.Fatalf("unable to unmarshal balancer config into xds config: %v", err) } - xdsCfg := cfg.(*xdsConfig) - wantChildPolicy := &loadBalancingConfig{Name: string(fakeBalancerA), Config: json.RawMessage(`{}`)} + xdsCfg := cfg.(*xdsinternal.LBConfig) + wantChildPolicy := &xdsinternal.LoadBalancingConfig{Name: string(fakeBalancerA), Config: json.RawMessage(`{}`)} if !reflect.DeepEqual(xdsCfg.ChildPolicy, wantChildPolicy) { t.Fatalf("got child policy %v, want %v", xdsCfg.ChildPolicy, wantChildPolicy) } - wantFallbackPolicy := &loadBalancingConfig{Name: string(fakeBalancerB), Config: json.RawMessage(`{}`)} + wantFallbackPolicy := &xdsinternal.LoadBalancingConfig{Name: string(fakeBalancerB), Config: json.RawMessage(`{}`)} if !reflect.DeepEqual(xdsCfg.FallBackPolicy, wantFallbackPolicy) { t.Fatalf("got fallback policy %v, want %v", xdsCfg.FallBackPolicy, wantFallbackPolicy) } @@ -688,18 +689,18 @@ func (s) TestXdsLoadbalancingConfigParsing(t *testing.T) { tests := []struct { name string s string - want *xdsConfig + want *xdsinternal.LBConfig }{ { name: "empty", s: "{}", - want: &xdsConfig{}, + want: &xdsinternal.LBConfig{}, }, { name: "success1", s: `{"childPolicy":[{"pick_first":{}}]}`, - want: &xdsConfig{ - ChildPolicy: &loadBalancingConfig{ + want: &xdsinternal.LBConfig{ + ChildPolicy: &xdsinternal.LoadBalancingConfig{ Name: "pick_first", Config: json.RawMessage(`{}`), }, @@ -708,8 +709,8 @@ func (s) TestXdsLoadbalancingConfigParsing(t *testing.T) { { name: "success2", s: `{"childPolicy":[{"round_robin":{}},{"pick_first":{}}]}`, - want: &xdsConfig{ - ChildPolicy: &loadBalancingConfig{ + want: &xdsinternal.LBConfig{ + ChildPolicy: &xdsinternal.LoadBalancingConfig{ Name: "round_robin", Config: json.RawMessage(`{}`), }, @@ -718,7 +719,7 @@ func (s) TestXdsLoadbalancingConfigParsing(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - var cfg xdsConfig + var cfg xdsinternal.LBConfig if err := json.Unmarshal([]byte(tt.s), &cfg); err != nil || !reflect.DeepEqual(&cfg, tt.want) { t.Errorf("test name: %s, parseFullServiceConfig() = %+v, err: %v, want %+v, ", tt.name, cfg, err, tt.want) } diff --git a/xds/internal/internal.go b/xds/internal/internal.go index 7403e3f20505..85717fd0c010 100644 --- a/xds/internal/internal.go +++ b/xds/internal/internal.go @@ -18,8 +18,11 @@ package internal import ( + "encoding/json" "fmt" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/serviceconfig" basepb "google.golang.org/grpc/xds/internal/proto/envoy/api/v2/core/base" ) @@ -48,3 +51,88 @@ func (lamk Locality) ToProto() *basepb.Locality { SubZone: lamk.SubZone, } } + +// LBConfig represents the loadBalancingConfig section of the service config +// for xDS balancers. +type LBConfig struct { + serviceconfig.LoadBalancingConfig + // BalancerName represents the load balancer to use. + BalancerName string + // ChildPolicy represents the load balancing config for the child policy. + ChildPolicy *LoadBalancingConfig + // FallBackPolicy represents the load balancing config for the fallback. 
+ FallBackPolicy *LoadBalancingConfig +} + +// UnmarshalJSON parses the JSON-encoded byte slice in data and stores it in l. +// When unmarshalling, we iterate through the childPolicy/fallbackPolicy lists +// and select the first LB policy which has been registered. +func (l *LBConfig) UnmarshalJSON(data []byte) error { + var val map[string]json.RawMessage + if err := json.Unmarshal(data, &val); err != nil { + return err + } + for k, v := range val { + switch k { + case "balancerName": + if err := json.Unmarshal(v, &l.BalancerName); err != nil { + return err + } + case "childPolicy": + var lbcfgs []*LoadBalancingConfig + if err := json.Unmarshal(v, &lbcfgs); err != nil { + return err + } + for _, lbcfg := range lbcfgs { + if balancer.Get(lbcfg.Name) != nil { + l.ChildPolicy = lbcfg + break + } + } + case "fallbackPolicy": + var lbcfgs []*LoadBalancingConfig + if err := json.Unmarshal(v, &lbcfgs); err != nil { + return err + } + for _, lbcfg := range lbcfgs { + if balancer.Get(lbcfg.Name) != nil { + l.FallBackPolicy = lbcfg + break + } + } + } + } + return nil +} + +// MarshalJSON returns a JSON enconding of l. +func (l *LBConfig) MarshalJSON() ([]byte, error) { + return nil, nil +} + +// LoadBalancingConfig represents a single load balancing config, +// stored in JSON format. +type LoadBalancingConfig struct { + Name string + Config json.RawMessage +} + +// MarshalJSON returns a JSON enconding of l. +func (l *LoadBalancingConfig) MarshalJSON() ([]byte, error) { + m := make(map[string]json.RawMessage) + m[l.Name] = l.Config + return json.Marshal(m) +} + +// UnmarshalJSON parses the JSON-encoded byte slice in data and stores it in l. +func (l *LoadBalancingConfig) UnmarshalJSON(data []byte) error { + var cfg map[string]json.RawMessage + if err := json.Unmarshal(data, &cfg); err != nil { + return err + } + for name, config := range cfg { + l.Name = name + l.Config = config + } + return nil +} diff --git a/xds/internal/resolver/xds_resolver.go b/xds/internal/resolver/xds_resolver.go new file mode 100644 index 000000000000..a2e87738e888 --- /dev/null +++ b/xds/internal/resolver/xds_resolver.go @@ -0,0 +1,103 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package resolver implements the xds resolver. +// +// At this point, the resolver is named xds-experimental, and doesn't do very +// much at all, except for returning a hard-coded service config which selects +// the xds_experimental balancer. +package resolver + +import ( + "fmt" + "sync" + + "google.golang.org/grpc" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" +) + +const ( + // The JSON form of the hard-coded service config which picks the + // xds_experimental balancer with round_robin as the child policy. 
+ jsonSC = `{ + "loadBalancingConfig":[ + { + "xds_experimental":{ + "childPolicy":[ + { + "round_robin": {} + } + ] + } + } + ] + }` + // xDS balancer name is xds_experimental while resolver scheme is + // xds-experimental since "_" is not a valid character in the URL. + xdsScheme = "xds-experimental" +) + +var ( + parseOnce sync.Once + parsedSC serviceconfig.Config +) + +// NewBuilder creates a new implementation of the resolver.Builder interface +// for the xDS resolver. +func NewBuilder() resolver.Builder { + return &xdsBuilder{} +} + +type xdsBuilder struct{} + +// Build helps implement the resolver.Builder interface. +func (b *xdsBuilder) Build(t resolver.Target, cc resolver.ClientConn, o resolver.BuildOption) (resolver.Resolver, error) { + parseOnce.Do(func() { + // The xds balancer must have been registered at this point for the service + // config to be parsed properly. + psc, err := internal.ParseServiceConfig(jsonSC) + if err != nil { + panic(fmt.Sprintf("service config %s parsing failed: %v", jsonSC, err)) + } + + var ok bool + if parsedSC, ok = psc.(*grpc.ServiceConfig); !ok { + panic(fmt.Sprintf("service config type is [%T], want [grpc.ServiceConfig]", psc)) + } + }) + + // We return a resolver which bacically does nothing. The hard-coded service + // config returned here picks the xds balancer. + cc.UpdateState(resolver.State{ServiceConfig: parsedSC}) + return &xdsResolver{}, nil +} + +// Name helps implement the resolver.Builder interface. +func (*xdsBuilder) Scheme() string { + return xdsScheme +} + +type xdsResolver struct{} + +// ResolveNow is a no-op at this point. +func (*xdsResolver) ResolveNow(o resolver.ResolveNowOption) {} + +// Close is a no-op at this point. +func (*xdsResolver) Close() {} diff --git a/xds/internal/resolver/xds_resolver_test.go b/xds/internal/resolver/xds_resolver_test.go new file mode 100644 index 000000000000..bfcbfa1404e7 --- /dev/null +++ b/xds/internal/resolver/xds_resolver_test.go @@ -0,0 +1,196 @@ +/* + * + * Copyright 2019 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package resolver + +import ( + "encoding/json" + "fmt" + "reflect" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/resolver" + "google.golang.org/grpc/serviceconfig" + xdsinternal "google.golang.org/grpc/xds/internal" +) + +// This is initialized at init time. +var fbb *fakeBalancerBuilder + +// We register a fake balancer builder and the actual xds_resolver here. We use +// the fake balancer builder to verify the service config pushed by the +// resolver. +func init() { + resolver.Register(NewBuilder()) + fbb = &fakeBalancerBuilder{ + wantLBConfig: &wrappedLBConfig{lbCfg: json.RawMessage(`{ + "childPolicy":[ + { + "round_robin": {} + } + ] + }`)}, + errCh: make(chan error), + } + balancer.Register(fbb) +} + +// testClientConn is a fake implemetation of resolver.ClientConn. 
All is does +// is to store the state received from the resolver locally and close the +// provided done channel. +type testClientConn struct { + done chan struct{} + gotState resolver.State +} + +func (t *testClientConn) UpdateState(s resolver.State) { + t.gotState = s + close(t.done) +} + +func (*testClientConn) NewAddress(addresses []resolver.Address) { panic("unimplemented") } +func (*testClientConn) NewServiceConfig(serviceConfig string) { panic("unimplemented") } + +// TestXDSRsolverSchemeAndAddresses creates a new xds resolver, verifies that +// it returns an empty address list and the appropriate xds-experimental +// scheme. +func TestXDSRsolverSchemeAndAddresses(t *testing.T) { + b := NewBuilder() + wantScheme := "xds-experimental" + if b.Scheme() != wantScheme { + t.Fatalf("got scheme %s, want %s", b.Scheme(), wantScheme) + } + + tcc := &testClientConn{done: make(chan struct{})} + r, err := b.Build(resolver.Target{}, tcc, resolver.BuildOption{}) + if err != nil { + t.Fatalf("xdsBuilder.Build() failed with error: %v", err) + } + defer r.Close() + + <-tcc.done + if len(tcc.gotState.Addresses) != 0 { + t.Fatalf("got address list from resolver %v, want empty list", tcc.gotState.Addresses) + } +} + +// fakeBalancer is used to verify that the xds_resolver returns the expected +// serice config. +type fakeBalancer struct { + wantLBConfig *wrappedLBConfig + errCh chan error +} + +func (*fakeBalancer) HandleSubConnStateChange(_ balancer.SubConn, _ connectivity.State) { + panic("unimplemented") +} +func (*fakeBalancer) HandleResolvedAddrs(_ []resolver.Address, _ error) { + panic("unimplemented") +} + +// UpdateClientConnState verifies that the received LBConfig matches the +// provided one, and if not, sends an error on the provided channel. +func (f *fakeBalancer) UpdateClientConnState(ccs balancer.ClientConnState) { + gotLBConfig, ok := ccs.BalancerConfig.(*wrappedLBConfig) + if !ok { + f.errCh <- fmt.Errorf("in fakeBalancer got lbConfig of type %T, want %T", ccs.BalancerConfig, &wrappedLBConfig{}) + return + } + + var gotCfg, wantCfg xdsinternal.LBConfig + if err := wantCfg.UnmarshalJSON(f.wantLBConfig.lbCfg); err != nil { + f.errCh <- fmt.Errorf("unable to unmarshal balancer config %s into xds config", string(f.wantLBConfig.lbCfg)) + return + } + if err := gotCfg.UnmarshalJSON(gotLBConfig.lbCfg); err != nil { + f.errCh <- fmt.Errorf("unable to unmarshal balancer config %s into xds config", string(gotLBConfig.lbCfg)) + return + } + if !reflect.DeepEqual(gotCfg, wantCfg) { + f.errCh <- fmt.Errorf("in fakeBalancer got lbConfig %v, want %v", gotCfg, wantCfg) + return + } + + f.errCh <- nil +} + +func (*fakeBalancer) UpdateSubConnState(_ balancer.SubConn, _ balancer.SubConnState) { + panic("unimplemented") +} + +func (*fakeBalancer) Close() {} + +// fakeBalancerBuilder builds a fake balancer and also provides a ParseConfig +// method (which doesn't really the parse config, but just stores it as is). +type fakeBalancerBuilder struct { + wantLBConfig *wrappedLBConfig + errCh chan error +} + +func (f *fakeBalancerBuilder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + return &fakeBalancer{f.wantLBConfig, f.errCh} +} + +func (f *fakeBalancerBuilder) Name() string { + return "xds_experimental" +} + +func (f *fakeBalancerBuilder) ParseConfig(c json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + return &wrappedLBConfig{lbCfg: c}, nil +} + +// wrappedLBConfig simply wraps the provided LB config with a +// serviceconfig.LoadBalancingConfig interface. 
+type wrappedLBConfig struct { + serviceconfig.LoadBalancingConfig + lbCfg json.RawMessage +} + +// TestXDSRsolverServiceConfig verifies that the xds_resolver returns the +// expected service config. +// +// The following sequence of events happen in this test: +// * The xds_experimental balancer (fake) and resolver builders are initialized +// at init time. +// * We dial a dummy address here with the xds-experimental scheme. This should +// pick the xds_resolver, which should return the hard-coded service config, +// which should reach the fake balancer that we registered (because the +// service config asks for the xds balancer). +// * In the fake balancer, we verify that we receive the expected LB config. +func TestXDSRsolverServiceConfig(t *testing.T) { + xdsAddr := fmt.Sprintf("%s:///dummy", xdsScheme) + cc, err := grpc.Dial(xdsAddr, grpc.WithInsecure()) + if err != nil { + t.Fatalf("grpc.Dial(%s) failed with error: %v", xdsAddr, err) + } + defer cc.Close() + + timer := time.NewTimer(5 * time.Second) + select { + case <-timer.C: + t.Fatal("timed out waiting for service config to reach balancer") + case err := <-fbb.errCh: + if err != nil { + t.Error(err) + } + } +}
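Taken together, the new resolver and the init registration in xds/experimental let an application opt in with a side-effect import and an xds-experimental target, roughly as sketched below; the target name is a placeholder.

package main

import (
	"log"

	"google.golang.org/grpc"
	// Importing this package registers the xds-experimental resolver and the
	// xds_experimental balancer via its init function.
	_ "google.golang.org/grpc/xds/experimental"
)

func main() {
	// The scheme selects the xds resolver, which pushes the hard-coded
	// service config shown above; that config in turn selects the
	// xds_experimental balancer with round_robin as its child policy.
	cc, err := grpc.Dial("xds-experimental:///some-service", grpc.WithInsecure())
	if err != nil {
		log.Fatalf("grpc.Dial failed: %v", err)
	}
	defer cc.Close()
	// ... create stubs on cc and issue RPCs as usual.
}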