Skip to content

Commit

Permalink
test(go/adbc/driver/flightsql): test for errors during polling (apach…
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm authored Feb 13, 2024
1 parent 34b0866 commit d462c51
Showing 1 changed file with 58 additions and 4 deletions.
62 changes: 58 additions & 4 deletions go/adbc/driver/flightsql/flightsql_adbc_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,9 @@ func (ts *ExecuteSchemaTests) TestQuery() {
type IncrementalQuery struct {
query string
nextIndex int
// if set, then return an error in the next poll and unset
// for testing the client's error handling
unavailable bool
}

type IncrementalPollTestServer struct {
Expand All @@ -451,6 +454,10 @@ type IncrementalPollTestServer struct {
testCases map[string]IncrementalPollTestCase
}

var unavailableCase = IncrementalPollTestCase{
progress: []int{1, 1},
}

func (srv *IncrementalPollTestServer) PollFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.PollInfo, error) {
srv.mu.Lock()
defer srv.mu.Unlock()
Expand Down Expand Up @@ -478,27 +485,46 @@ func (srv *IncrementalPollTestServer) PollFlightInfo(ctx context.Context, desc *

testCase, ok := srv.testCases[query.query]
if !ok {
return nil, status.Errorf(codes.Unimplemented, fmt.Sprintf("Invalid case %s", query.query))
if query.query == "unavailable" {
testCase = unavailableCase
} else {
return nil, status.Errorf(codes.Unimplemented, fmt.Sprintf("Invalid case %s", query.query))
}
}

if testCase.differentRetryDescriptor && progress != int64(query.nextIndex) {
return nil, status.Errorf(codes.InvalidArgument, fmt.Sprintf("Used wrong retry descriptor, expected %d but got %d", query.nextIndex, progress))
}

if query.unavailable {
query.unavailable = false
return nil, status.Errorf(codes.Unavailable, "Server temporarily unavailable")
}

return srv.MakePollInfo(&testCase, query, queryId)
}

func (srv *IncrementalPollTestServer) PollFlightInfoStatement(ctx context.Context, query flightsql.StatementQuery, desc *flight.FlightDescriptor) (*flight.PollInfo, error) {
srv.mu.Lock()
defer srv.mu.Unlock()

queryId := uuid.New().String()

if query.GetQuery() == "unavailable" {
srv.queries[queryId] = &IncrementalQuery{
query: query.GetQuery(),
nextIndex: 0,
unavailable: true,
}

return srv.MakePollInfo(&unavailableCase, srv.queries[queryId], queryId)
}

testCase, ok := srv.testCases[query.GetQuery()]
if !ok {
return nil, status.Errorf(codes.Unimplemented, fmt.Sprintf("Invalid case %s", query.GetQuery()))
}

srv.mu.Lock()
defer srv.mu.Unlock()

srv.queries[queryId] = &IncrementalQuery{
query: query.GetQuery(),
nextIndex: 0,
Expand Down Expand Up @@ -701,6 +727,34 @@ func (ts *IncrementalPollTests) TestOptionValue() {
ts.Equal(adbc.StatusInvalidArgument, adbcErr.Code)
}

func (ts *IncrementalPollTests) TestUnavailable() {
// An error from the server should not tear down all the state. We
// should be able to retry the request.
ctx := context.Background()
stmt, err := ts.cnxn.NewStatement()
ts.NoError(err)
defer stmt.Close()

ts.NoError(stmt.SetOption(adbc.OptionKeyIncremental, adbc.OptionValueEnabled))

ts.NoError(stmt.SetSqlQuery("unavailable"))
_, partitions, _, err := stmt.ExecutePartitions(ctx)
ts.NoError(err)
ts.Equalf(uint64(1), partitions.NumPartitions, "%#v", partitions)

_, partitions, _, err = stmt.ExecutePartitions(ctx)
ts.ErrorContains(err, "Server temporarily unavailable")
ts.Equal(uint64(0), partitions.NumPartitions)

_, partitions, _, err = stmt.ExecutePartitions(ctx)
ts.NoError(err)
ts.Equalf(uint64(1), partitions.NumPartitions, "%#v", partitions)

_, partitions, _, err = stmt.ExecutePartitions(ctx)
ts.NoError(err)
ts.Equal(uint64(0), partitions.NumPartitions)
}

func (ts *IncrementalPollTests) RunOneTestCase(ctx context.Context, stmt adbc.Statement, name string, testCase *IncrementalPollTestCase) {
opts := stmt.(adbc.GetSetOptions)

Expand Down

0 comments on commit d462c51

Please sign in to comment.