Skip to content

Commit

Permalink
Add gRPC stream error interceptor (#4019)
Browse files Browse the repository at this point in the history
* Add gRPC stream error interceptor & UT
  • Loading branch information
wxing1292 authored Mar 14, 2023
1 parent ed47a16 commit ddd0911
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 0 deletions.
117 changes: 117 additions & 0 deletions common/rpc/interceptor/stream_error.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package interceptor

import (
"context"
"io"

"github.com/gogo/status"
"go.temporal.io/api/serviceerror"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)

type (
ClientStreamErrorInterceptor struct {
grpc.ClientStream
}
)

var _ grpc.ClientStream = (*ClientStreamErrorInterceptor)(nil)

func NewClientStreamErrorInterceptor(
clientStream grpc.ClientStream,
) *ClientStreamErrorInterceptor {
return &ClientStreamErrorInterceptor{
ClientStream: clientStream,
}
}

func (c *ClientStreamErrorInterceptor) CloseSend() error {
return errorConvert(c.ClientStream.CloseSend())
}

func (c *ClientStreamErrorInterceptor) SendMsg(m interface{}) error {
return errorConvert(c.ClientStream.SendMsg(m))
}

func (c *ClientStreamErrorInterceptor) RecvMsg(m interface{}) error {
return errorConvert(c.ClientStream.RecvMsg(m))
}

func StreamErrorInterceptor(
ctx context.Context,
desc *grpc.StreamDesc,
cc *grpc.ClientConn,
method string,
streamer grpc.Streamer,
opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
clientStream, err := streamer(ctx, desc, cc, method, opts...)
if err != nil {
return nil, errorConvert(err)
}
return NewClientStreamErrorInterceptor(clientStream), nil
}

func errorConvert(err error) error {
switch err {
case nil:
return nil
case io.EOF:
return io.EOF
default:
return FromStatus(status.Convert(err))
}
}

// FromStatus converts gogo gRPC Status to service error.
func FromStatus(st *status.Status) error {
if st == nil {
return nil
}

switch st.Code() {
case codes.OK:
return nil
case codes.DeadlineExceeded:
return serviceerror.NewDeadlineExceeded(st.Message())
case codes.Canceled:
return serviceerror.NewCanceled(st.Message())
case codes.InvalidArgument:
return serviceerror.NewInvalidArgument(st.Message())
case codes.FailedPrecondition:
return serviceerror.NewFailedPrecondition(st.Message())
case codes.Unavailable:
return serviceerror.NewUnavailable(st.Message())
case codes.Internal:
return serviceerror.NewInternal(st.Message())
case codes.Unknown:
return serviceerror.NewInternal(st.Message())
default:
return serviceerror.NewInternal(st.Message())
}
}
75 changes: 75 additions & 0 deletions common/rpc/interceptor/stream_error_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// The MIT License
//
// Copyright (c) 2020 Temporal Technologies Inc. All rights reserved.
//
// Copyright (c) 2020 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

package interceptor

import (
"io"
"testing"

"github.com/gogo/status"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"go.temporal.io/api/serviceerror"
"google.golang.org/grpc/codes"
)

type (
streamErrorSuite struct {
*require.Assertions
suite.Suite
}
)

func TestStreamErrorSuite(t *testing.T) {
s := new(streamErrorSuite)
suite.Run(t, s)
}

func (s *streamErrorSuite) SetupSuite() {
s.Assertions = require.New(s.T())
}

func (s *streamErrorSuite) TearDownSuite() {
}

func (s *streamErrorSuite) SetupTest() {
}

func (s *streamErrorSuite) TearDownTest() {
}

func (s *streamErrorSuite) TestErrorConversion() {
s.Equal(nil, errorConvert(nil))
s.Equal(io.EOF, errorConvert(io.EOF))

s.IsType(nil, errorConvert(status.Error(codes.OK, "")))
s.IsType(&serviceerror.DeadlineExceeded{}, errorConvert(status.Error(codes.DeadlineExceeded, "")))
s.IsType(&serviceerror.Canceled{}, errorConvert(status.Error(codes.Canceled, "")))
s.IsType(&serviceerror.InvalidArgument{}, errorConvert(status.Error(codes.InvalidArgument, "")))
s.IsType(&serviceerror.FailedPrecondition{}, errorConvert(status.Error(codes.FailedPrecondition, "")))
s.IsType(&serviceerror.Unavailable{}, errorConvert(status.Error(codes.Unavailable, "")))
s.IsType(&serviceerror.Internal{}, errorConvert(status.Error(codes.Internal, "")))
s.IsType(&serviceerror.Internal{}, errorConvert(status.Error(codes.Unknown, "")))
}

0 comments on commit ddd0911

Please sign in to comment.