Skip to content

Commit

Permalink
[v9] Check for unimplemented error during stream receive in Client.Ge…
Browse files Browse the repository at this point in the history
…tAccessRequests (#13490)

* test + fix

* Update api/client/client_test.go

Co-authored-by: Alan Parra <[email protected]>

* fix error handling

* remove first check

* Backport NewClient

Co-authored-by: Alan Parra <[email protected]>
  • Loading branch information
xacrimon and codingllama authored Jun 14, 2022
1 parent a823d2d commit 1a22034
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 11 deletions.
16 changes: 9 additions & 7 deletions api/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -748,21 +748,23 @@ func (c *Client) GetBotUsers(ctx context.Context) ([]types.User, error) {
func (c *Client) GetAccessRequests(ctx context.Context, filter types.AccessRequestFilter) ([]types.AccessRequest, error) {
stream, err := c.grpc.GetAccessRequestsV2(ctx, &filter, c.callOpts...)
if err != nil {
err := trail.FromGRPC(err)
if trace.IsNotImplemented(err) {
return c.getAccessRequestsLegacy(ctx, filter)
}

return nil, err
return nil, trail.FromGRPC(err)
}

var reqs []types.AccessRequest
for {
req, err := stream.Recv()
if err == io.EOF {
break
}

if err != nil {
return nil, trail.FromGRPC(err)
err := trail.FromGRPC(err)
if trace.IsNotImplemented(err) {
return c.getAccessRequestsLegacy(ctx, filter)
}

return nil, err
}
reqs = append(reqs, req)
}
Expand Down
71 changes: 67 additions & 4 deletions api/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@ import (

// mockServer mocks an Auth Server.
type mockServer struct {
addr string
grpc *grpc.Server
*proto.UnimplementedAuthServiceServer
}

func newMockServer() *mockServer {
func newMockServer(addr string) *mockServer {
m := &mockServer{
addr: addr,
grpc: grpc.NewServer(),
UnimplementedAuthServiceServer: &proto.UnimplementedAuthServiceServer{},
}
Expand All @@ -59,10 +61,24 @@ func startMockServer(t *testing.T) string {
l, err := net.Listen("tcp", "")
require.NoError(t, err)
t.Cleanup(func() { require.NoError(t, l.Close()) })
go newMockServer().grpc.Serve(l)
go newMockServer(l.Addr().String()).grpc.Serve(l)
return l.Addr().String()
}

func (m *mockServer) NewClient(ctx context.Context) (*Client, error) {
cfg := Config{
Addrs: []string{m.addr},
Credentials: []Credentials{
&mockInsecureTLSCredentials{},
},
DialOpts: []grpc.DialOption{
grpc.WithTransportCredentials(insecure.NewCredentials()),
},
}

return New(ctx, cfg)
}

func (m *mockServer) Ping(ctx context.Context, req *proto.PingRequest) (*proto.PingResponse, error) {
return &proto.PingResponse{}, nil
}
Expand Down Expand Up @@ -386,7 +402,7 @@ func TestNewDialBackground(t *testing.T) {
require.Error(t, err)

// Start the server and wait for the client connection to be ready.
go newMockServer().grpc.Serve(l)
go newMockServer(l.Addr().String()).grpc.Serve(l)
require.NoError(t, clt.waitForConnectionReady(ctx))

// requests to the server should succeed.
Expand Down Expand Up @@ -424,7 +440,7 @@ func TestWaitForConnectionReady(t *testing.T) {
require.Error(t, clt.waitForConnectionReady(cancelCtx))

// WaitForConnectionReady should return nil if the server is open to connections.
go newMockServer().grpc.Serve(l)
go newMockServer(l.Addr().String()).grpc.Serve(l)
require.NoError(t, clt.waitForConnectionReady(ctx))

// WaitForConnectionReady should return an error if the grpc connection is closed.
Expand Down Expand Up @@ -713,3 +729,50 @@ func TestSetOIDCRedirectURLBackwardsCompatibility(t *testing.T) {
require.Equal(t, 1, len(connectorsResp[0].GetRedirectURLs()))
require.Equal(t, "one.example.com", connectorsResp[0].GetRedirectURLs()[0])
}

type mockAccessRequestServer struct {
*mockServer
}

func (g *mockAccessRequestServer) GetAccessRequests(ctx context.Context, f *types.AccessRequestFilter) (*proto.AccessRequests, error) {
req, err := types.NewAccessRequest("foo", "bob", "admin")
if err != nil {
return nil, trace.Wrap(err)
}

return &proto.AccessRequests{
AccessRequests: []*types.AccessRequestV3{req.(*types.AccessRequestV3)},
}, nil
}

// TestAccessRequestDowngrade tests that the client will downgrade to the non stream API for fetching access requests
// if the stream API is not available.
func TestAccessRequestDowngrade(t *testing.T) {
ctx := context.Background()
l, err := net.Listen("tcp", "")
require.NoError(t, err)

m := &mockAccessRequestServer{
&mockServer{
addr: l.Addr().String(),
grpc: grpc.NewServer(),
UnimplementedAuthServiceServer: &proto.UnimplementedAuthServiceServer{},
},
}
proto.RegisterAuthServiceServer(m.grpc, m)
t.Cleanup(m.grpc.Stop)

remoteErr := make(chan error)
go func() {
remoteErr <- m.grpc.Serve(l)
}()

clt, err := m.NewClient(ctx)
require.NoError(t, err)

items, err := clt.GetAccessRequests(ctx, types.AccessRequestFilter{})
require.NoError(t, err)
require.Len(t, items, 1)
m.grpc.Stop()
require.NoError(t, <-remoteErr)
}

0 comments on commit 1a22034

Please sign in to comment.