Skip to content

Commit

Permalink
Remove call to proto.Clone() in http2Server.WriteStatus. (#2842)
Browse files Browse the repository at this point in the history
* Expose a method from the internal package to get to the raw
  StatusProto wrapped by the status error, and use it from
  http2Server.WriteStatus().
* Add a helper method in internal/testutils to compare two status errors
  and update test code to use that instead of reflect.DeepEqual()
  • Loading branch information
easwars authored Jun 10, 2019
1 parent b681a11 commit a5396fd
Show file tree
Hide file tree
Showing 10 changed files with 143 additions and 15 deletions.
5 changes: 5 additions & 0 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ var (
// ParseServiceConfig is a function to parse JSON service configs into
// opaque data structures.
ParseServiceConfig func(sc string) (interface{}, error)
// StatusRawProto is exported by status/status.go. This func returns a
// pointer to the wrapped Status proto for a given status.Status without a
// call to proto.Clone(). The returned Status proto should not be mutated by
// the caller.
StatusRawProto interface{} // func (*status.Status) *spb.Status
)

// HealthChecker defines the signature of the client-side LB channel health checking function.
Expand Down
38 changes: 38 additions & 0 deletions internal/testutils/status_equal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
*
* Copyright 2019 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package testutils

import (
"github.com/golang/protobuf/proto"
"google.golang.org/grpc/status"
)

// StatusErrEqual returns true iff both err1 and err2 wrap status.Status errors
// and their underlying status protos are equal.
func StatusErrEqual(err1, err2 error) bool {
status1, ok := status.FromError(err1)
if !ok {
return false
}
status2, ok := status.FromError(err2)
if !ok {
return false
}
return proto.Equal(status1.Proto(), status2.Proto())
}
57 changes: 57 additions & 0 deletions internal/testutils/status_equal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
*
* Copyright 2019 gRPC authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

package testutils

import (
"testing"

anypb "github.com/golang/protobuf/ptypes/any"
spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

var statusErr = status.ErrorProto(&spb.Status{
Code: int32(codes.DataLoss),
Message: "error for testing",
Details: []*anypb.Any{{
TypeUrl: "url",
Value: []byte{6, 0, 0, 6, 1, 3},
}},
})

func TestStatusErrEqual(t *testing.T) {
tests := []struct {
name string
err1 error
err2 error
wantEqual bool
}{
{"nil errors", nil, nil, true},
{"equal OK status", status.New(codes.OK, "").Err(), status.New(codes.OK, "").Err(), true},
{"equal status errors", statusErr, statusErr, true},
{"different status errors", statusErr, status.New(codes.OK, "").Err(), false},
}

for _, test := range tests {
if gotEqual := StatusErrEqual(test.err1, test.err2); gotEqual != test.wantEqual {
t.Errorf("%v: StatusErrEqual(%v, %v) = %v, want %v", test.name, test.err1, test.err2, gotEqual, test.wantEqual)
}
}
}
7 changes: 6 additions & 1 deletion internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ import (
"golang.org/x/net/http2"
"golang.org/x/net/http2/hpack"

spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/channelz"
"google.golang.org/grpc/internal/grpcrand"
"google.golang.org/grpc/keepalive"
Expand All @@ -55,6 +57,9 @@ var (
// ErrHeaderListSizeLimitViolation indicates that the header list size is larger
// than the limit set by peer.
ErrHeaderListSizeLimitViolation = errors.New("transport: trying to send header list size larger than the limit set by peer")
// statusRawProto is a function to get to the raw status proto wrapped in a
// status.Status without a proto.Clone().
statusRawProto = internal.StatusRawProto.(func(*status.Status) *spb.Status)
)

// http2Server implements the ServerTransport interface with HTTP2.
Expand Down Expand Up @@ -817,7 +822,7 @@ func (t *http2Server) WriteStatus(s *Stream, st *status.Status) error {
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-status", Value: strconv.Itoa(int(st.Code()))})
headerFields = append(headerFields, hpack.HeaderField{Name: "grpc-message", Value: encodeGrpcMessage(st.Message())})

if p := st.Proto(); p != nil && len(p.Details) > 0 {
if p := statusRawProto(st); p != nil && len(p.Details) > 0 {
stBytes, err := proto.Marshal(p)
if err != nil {
// TODO: return error instead, when callers are able to handle it.
Expand Down
4 changes: 2 additions & 2 deletions internal/transport/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (
"io"
"math"
"net"
"reflect"
"runtime"
"strconv"
"strings"
Expand All @@ -40,6 +39,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal/leakcheck"
"google.golang.org/grpc/internal/syscall"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
)
Expand Down Expand Up @@ -1690,7 +1690,7 @@ func TestEncodingRequiredStatus(t *testing.T) {
if _, err := s.trReader.(*transportReader).Read(p); err != io.EOF {
t.Fatalf("Read got error %v, want %v", err, io.EOF)
}
if !reflect.DeepEqual(s.Status(), encodingTestStatus) {
if !testutils.StatusErrEqual(s.Status().Err(), encodingTestStatus.Err()) {
t.Fatalf("stream with status %v, want %v", s.Status(), encodingTestStatus)
}
ct.Close()
Expand Down
7 changes: 4 additions & 3 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
protoenc "google.golang.org/grpc/encoding/proto"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/transport"
"google.golang.org/grpc/status"
perfpb "google.golang.org/grpc/test/codec_perf"
Expand Down Expand Up @@ -182,10 +183,10 @@ func (s) TestToRPCErr(t *testing.T) {
} {
err := toRPCErr(test.errIn)
if _, ok := status.FromError(err); !ok {
t.Fatalf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, status.Error(codes.Unknown, ""))
t.Errorf("toRPCErr{%v} returned type %T, want %T", test.errIn, err, status.Error)
}
if !reflect.DeepEqual(err, test.errOut) {
t.Fatalf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut)
if !testutils.StatusErrEqual(err, test.errOut) {
t.Errorf("toRPCErr{%v} = %v \nwant %v", test.errIn, err, test.errOut)
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions status/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,15 @@ import (
"github.com/golang/protobuf/ptypes"
spb "google.golang.org/genproto/googleapis/rpc/status"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/internal"
)

func init() {
internal.StatusRawProto = statusRawProto
}

func statusRawProto(s *Status) *spb.Status { return s.s }

// statusError is an alias of a status proto. It implements error and Status,
// and a nil statusError should never be returned by this package.
type statusError spb.Status
Expand Down
18 changes: 16 additions & 2 deletions status/status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,25 @@ import (
"google.golang.org/grpc/codes"
)

// errEqual is essentially a copy of testutils.StatusErrEqual(), to avoid a
// cyclic dependency.
func errEqual(err1, err2 error) bool {
status1, ok := FromError(err1)
if !ok {
return false
}
status2, ok := FromError(err2)
if !ok {
return false
}
return proto.Equal(status1.Proto(), status2.Proto())
}

func TestErrorsWithSameParameters(t *testing.T) {
const description = "some description"
e1 := Errorf(codes.AlreadyExists, description)
e2 := Errorf(codes.AlreadyExists, description)
if e1 == e2 || !reflect.DeepEqual(e1, e2) {
if e1 == e2 || !errEqual(e1, e2) {
t.Fatalf("Errors should be equivalent but unique - e1: %v, %v e2: %p, %v", e1.(*statusError), e1, e2.(*statusError), e2)
}
}
Expand Down Expand Up @@ -156,7 +170,7 @@ func TestFromErrorImplementsInterface(t *testing.T) {
t.Fatalf("FromError(%v) = %v, %v; want <Code()=%s, Message()=%q, Err()!=nil>, true", err, s, ok, code, message)
}
pd := s.Proto().GetDetails()
if len(pd) != 1 || !reflect.DeepEqual(pd[0], details[0]) {
if len(pd) != 1 || !proto.Equal(pd[0], details[0]) {
t.Fatalf("s.Proto.GetDetails() = %v; want <Details()=%s>", pd, details)
}
}
Expand Down
5 changes: 3 additions & 2 deletions test/balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/grpclog"
"google.golang.org/grpc/internal/balancerload"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/resolver"
testpb "google.golang.org/grpc/test/grpc_testing"
Expand Down Expand Up @@ -162,14 +163,14 @@ func testDoneInfo(t *testing.T, e env) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
wantErr := detailedError
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(err, wantErr) {
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !testutils.StatusErrEqual(err, wantErr) {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr)
}
if _, err := tc.UnaryCall(ctx, &testpb.SimpleRequest{}); err != nil {
t.Fatalf("TestService.UnaryCall(%v, _, _, _) = _, %v; want _, <nil>", ctx, err)
}

if len(b.doneInfo) < 1 || !reflect.DeepEqual(b.doneInfo[0].Err, wantErr) {
if len(b.doneInfo) < 1 || !testutils.StatusErrEqual(b.doneInfo[0].Err, wantErr) {
t.Fatalf("b.doneInfo = %v; want b.doneInfo[0].Err = %v", b.doneInfo, wantErr)
}
if len(b.doneInfo) < 2 || !reflect.DeepEqual(b.doneInfo[1].Trailer, testTrailerMetadata) {
Expand Down
10 changes: 5 additions & 5 deletions test/end2end_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2496,7 +2496,7 @@ func testHealthCheckOnFailure(t *testing.T, e env) {

cc := te.clientConn()
wantErr := status.Error(codes.DeadlineExceeded, "context deadline exceeded")
if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); !reflect.DeepEqual(err, wantErr) {
if _, err := healthCheck(0*time.Second, cc, "grpc.health.v1.Health"); !testutils.StatusErrEqual(err, wantErr) {
t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %s", err, codes.DeadlineExceeded)
}
awaitNewConnLogOutput()
Expand All @@ -2517,7 +2517,7 @@ func testHealthCheckOff(t *testing.T, e env) {
te.startServer(&testServer{security: e.security})
defer te.tearDown()
want := status.Error(codes.Unimplemented, "unknown service grpc.health.v1.Health")
if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !reflect.DeepEqual(err, want) {
if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !testutils.StatusErrEqual(err, want) {
t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want)
}
}
Expand Down Expand Up @@ -2791,7 +2791,7 @@ func testUnknownHandler(t *testing.T, e env, unknownHandler grpc.StreamHandler)
te.startServer(&testServer{security: e.security})
defer te.tearDown()
want := status.Error(codes.Unauthenticated, "user unauthenticated")
if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !reflect.DeepEqual(err, want) {
if _, err := healthCheck(1*time.Second, te.clientConn(), ""); !testutils.StatusErrEqual(err, want) {
t.Fatalf("Health/Check(_, _) = _, %v, want _, %v", err, want)
}
}
Expand All @@ -2818,7 +2818,7 @@ func testHealthCheckServingStatus(t *testing.T, e env) {
t.Fatalf("Got the serving status %v, want SERVING", out.Status)
}
wantErr := status.Error(codes.NotFound, "unknown service")
if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); !reflect.DeepEqual(err, wantErr) {
if _, err := healthCheck(1*time.Second, cc, "grpc.health.v1.Health"); !testutils.StatusErrEqual(err, wantErr) {
t.Fatalf("Health/Check(_, _) = _, %v, want _, error code %s", err, codes.NotFound)
}
hs.SetServingStatus("grpc.health.v1.Health", healthpb.HealthCheckResponse_SERVING)
Expand Down Expand Up @@ -2886,7 +2886,7 @@ func testFailedEmptyUnary(t *testing.T, e env) {

ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
wantErr := detailedError
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !reflect.DeepEqual(err, wantErr) {
if _, err := tc.EmptyCall(ctx, &testpb.Empty{}); !testutils.StatusErrEqual(err, wantErr) {
t.Fatalf("TestService/EmptyCall(_, _) = _, %v, want _, %v", err, wantErr)
}
}
Expand Down

0 comments on commit a5396fd

Please sign in to comment.