Skip to content

Commit

Permalink
GH-39574: [Go] Enable PollFlightInfo in Flight RPC (#39575)
Browse files Browse the repository at this point in the history
### Rationale for this change

It's impossible to use the current bindings with PollFlightInfo. Required for apache/arrow-adbc#1457.

### What changes are included in this PR?

Add new methods that expose PollFlightInfo.

### Are these changes tested?

Yes

### Are there any user-facing changes?

Adds new methods.
* Closes: #39574

Authored-by: David Li <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
lidavidm authored and kou committed Aug 30, 2024
1 parent 14b7134 commit e9730c4
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 0 deletions.
92 changes: 92 additions & 0 deletions arrow/flight/flightsql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,17 @@ func flightInfoForCommand(ctx context.Context, cl *Client, cmd proto.Message, op
return cl.getFlightInfo(ctx, desc, opts...)
}

func pollInfoForCommand(ctx context.Context, cl *Client, cmd proto.Message, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
if retryDescriptor != nil {
return cl.Client.PollFlightInfo(ctx, retryDescriptor, opts...)
}
desc, err := descForCommand(cmd)
if err != nil {
return nil, err
}
return cl.Client.PollFlightInfo(ctx, desc, opts...)
}

func schemaForCommand(ctx context.Context, cl *Client, cmd proto.Message, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
desc, err := descForCommand(cmd)
if err != nil {
Expand Down Expand Up @@ -123,6 +134,14 @@ func (c *Client) Execute(ctx context.Context, query string, opts ...grpc.CallOpt
return flightInfoForCommand(ctx, c, &cmd, opts...)
}

// ExecutePoll idempotently starts execution of a query/checks for completion.
// To check for completion, pass the FlightDescriptor from the previous call
// to ExecutePoll as the retryDescriptor.
func (c *Client) ExecutePoll(ctx context.Context, query string, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
cmd := pb.CommandStatementQuery{Query: query}
return pollInfoForCommand(ctx, c, &cmd, retryDescriptor, opts...)
}

// GetExecuteSchema gets the schema of the result set of a query without
// executing the query itself.
func (c *Client) GetExecuteSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
Expand All @@ -136,6 +155,12 @@ func (c *Client) ExecuteSubstrait(ctx context.Context, plan SubstraitPlan, opts
return flightInfoForCommand(ctx, c, &cmd, opts...)
}

func (c *Client) ExecuteSubstraitPoll(ctx context.Context, plan SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
cmd := pb.CommandStatementSubstraitPlan{
Plan: &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version}}
return pollInfoForCommand(ctx, c, &cmd, retryDescriptor, opts...)
}

func (c *Client) GetExecuteSubstraitSchema(ctx context.Context, plan SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
cmd := pb.CommandStatementSubstraitPlan{
Plan: &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version}}
Expand Down Expand Up @@ -606,6 +631,15 @@ func (tx *Txn) Execute(ctx context.Context, query string, opts ...grpc.CallOptio
return flightInfoForCommand(ctx, tx.c, cmd, opts...)
}

func (tx *Txn) ExecutePoll(ctx context.Context, query string, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
if !tx.txn.IsValid() {
return nil, ErrInvalidTxn
}
// The server should encode the transaction into the retry descriptor
cmd := &pb.CommandStatementQuery{Query: query, TransactionId: tx.txn}
return pollInfoForCommand(ctx, tx.c, cmd, retryDescriptor, opts...)
}

func (tx *Txn) ExecuteSubstrait(ctx context.Context, plan SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) {
if !tx.txn.IsValid() {
return nil, ErrInvalidTxn
Expand All @@ -616,6 +650,18 @@ func (tx *Txn) ExecuteSubstrait(ctx context.Context, plan SubstraitPlan, opts ..
return flightInfoForCommand(ctx, tx.c, cmd, opts...)
}

func (tx *Txn) ExecuteSubstraitPoll(ctx context.Context, plan SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
if !tx.txn.IsValid() {
return nil, ErrInvalidTxn
}
// The server should encode the transaction into the retry descriptor
cmd := &pb.CommandStatementSubstraitPlan{
Plan: &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version},
TransactionId: tx.txn,
}
return pollInfoForCommand(ctx, tx.c, cmd, retryDescriptor, opts...)
}

func (tx *Txn) GetExecuteSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) {
if !tx.txn.IsValid() {
return nil, ErrInvalidTxn
Expand Down Expand Up @@ -981,6 +1027,52 @@ func (p *PreparedStatement) Execute(ctx context.Context, opts ...grpc.CallOption
return p.client.getFlightInfo(ctx, desc, opts...)
}

// ExecutePoll executes the prepared statement on the server and returns a PollInfo
// indicating the progress of execution.
//
// Will error if already closed.
func (p *PreparedStatement) ExecutePoll(ctx context.Context, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) {
if p.closed {
return nil, errors.New("arrow/flightsql: prepared statement already closed")
}

cmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: p.handle}

desc := retryDescriptor
var err error

if desc == nil {
desc, err = descForCommand(cmd)
if err != nil {
return nil, err
}
}

if retryDescriptor == nil {
if p.hasBindParameters() {
pstream, err := p.client.Client.DoPut(ctx, opts...)
if err != nil {
return nil, err
}

wr, err := p.writeBindParameters(pstream, desc)
if err != nil {
return nil, err
}
if err = wr.Close(); err != nil {
return nil, err
}
pstream.CloseSend()

// wait for the server to ack the result
if _, err = pstream.Recv(); err != nil && err != io.EOF {
return nil, err
}
}
}
return p.client.Client.PollFlightInfo(ctx, desc, opts...)
}

// ExecuteUpdate executes the prepared statement update query on the server
// and returns the number of rows affected. If SetParameters was called,
// the parameter bindings will be sent with the request to execute.
Expand Down
54 changes: 54 additions & 0 deletions arrow/flight/flightsql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,22 @@ func (BaseServer) RenewFlightEndpoint(context.Context, *flight.RenewFlightEndpoi
return nil, status.Error(codes.Unimplemented, "RenewFlightEndpoint not implemented")
}

func (BaseServer) PollFlightInfo(context.Context, *flight.FlightDescriptor) (*flight.PollInfo, error) {
return nil, status.Error(codes.Unimplemented, "PollFlightInfo not implemented")
}

func (BaseServer) PollFlightInfoStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) {
return nil, status.Error(codes.Unimplemented, "PollFlightInfoStatement not implemented")
}

func (BaseServer) PollFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.PollInfo, error) {
return nil, status.Error(codes.Unimplemented, "PollFlightInfoSubstraitPlan not implemented")
}

func (BaseServer) PollFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) {
return nil, status.Error(codes.Unimplemented, "PollFlightInfoPreparedStatement not implemented")
}

func (BaseServer) EndTransaction(context.Context, ActionEndTransactionRequest) error {
return status.Error(codes.Unimplemented, "EndTransaction not implemented")
}
Expand Down Expand Up @@ -652,6 +668,14 @@ type Server interface {
CancelFlightInfo(context.Context, *flight.CancelFlightInfoRequest) (flight.CancelFlightInfoResult, error)
// RenewFlightEndpoint attempts to extend the expiration of a FlightEndpoint
RenewFlightEndpoint(context.Context, *flight.RenewFlightEndpointRequest) (*flight.FlightEndpoint, error)
// PollFlightInfo is a generic handler for PollFlightInfo requests.
PollFlightInfo(context.Context, *flight.FlightDescriptor) (*flight.PollInfo, error)
// PollFlightInfoStatement handles polling for query execution.
PollFlightInfoStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error)
// PollFlightInfoSubstraitPlan handles polling for query execution.
PollFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.PollInfo, error)
// PollFlightInfoPreparedStatement handles polling for query execution.
PollFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error)

mustEmbedBaseServer()
}
Expand Down Expand Up @@ -729,6 +753,36 @@ func (f *flightSqlServer) GetFlightInfo(ctx context.Context, request *flight.Fli
return nil, status.Error(codes.InvalidArgument, "requested command is invalid")
}

func (f *flightSqlServer) PollFlightInfo(ctx context.Context, request *flight.FlightDescriptor) (*flight.PollInfo, error) {
var (
anycmd anypb.Any
cmd proto.Message
err error
)
// If we can't parse things, be friendly and defer to the server
// implementation. This is especially important for this method since
// the server returns a custom FlightDescriptor for future requests.
if err = proto.Unmarshal(request.Cmd, &anycmd); err != nil {
return f.srv.PollFlightInfo(ctx, request)
}

if cmd, err = anycmd.UnmarshalNew(); err != nil {
return f.srv.PollFlightInfo(ctx, request)
}

switch cmd := cmd.(type) {
case *pb.CommandStatementQuery:
return f.srv.PollFlightInfoStatement(ctx, cmd, request)
case *pb.CommandStatementSubstraitPlan:
return f.srv.PollFlightInfoSubstraitPlan(ctx, &statementSubstraitPlan{cmd}, request)
case *pb.CommandPreparedStatementQuery:
return f.srv.PollFlightInfoPreparedStatement(ctx, cmd, request)
}
// XXX: for now we won't support the other methods

return f.srv.PollFlightInfo(ctx, request)
}

func (f *flightSqlServer) GetSchema(ctx context.Context, request *flight.FlightDescriptor) (*flight.SchemaResult, error) {
var (
anycmd anypb.Any
Expand Down
60 changes: 60 additions & 0 deletions arrow/flight/flightsql/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,36 @@ func (*testServer) GetFlightInfoStatement(ctx context.Context, q flightsql.State
}, nil
}

func (*testServer) PollFlightInfo(ctx context.Context, fd *flight.FlightDescriptor) (*flight.PollInfo, error) {
return &flight.PollInfo{
Info: &flight.FlightInfo{
FlightDescriptor: fd,
Endpoint: []*flight.FlightEndpoint{{
Ticket: &flight.Ticket{Ticket: []byte{}},
}, {
Ticket: &flight.Ticket{Ticket: []byte{}},
}},
},
FlightDescriptor: nil,
}, nil
}

func (*testServer) PollFlightInfoStatement(ctx context.Context, q flightsql.StatementQuery, fd *flight.FlightDescriptor) (*flight.PollInfo, error) {
ticket, err := flightsql.CreateStatementQueryTicket([]byte(q.GetQuery()))
if err != nil {
return nil, err
}
return &flight.PollInfo{
Info: &flight.FlightInfo{
FlightDescriptor: fd,
Endpoint: []*flight.FlightEndpoint{{
Ticket: &flight.Ticket{Ticket: ticket},
}},
},
FlightDescriptor: &flight.FlightDescriptor{Cmd: []byte{}},
}, nil
}

func (*testServer) DoGetStatement(ctx context.Context, ticket flightsql.StatementQueryTicket) (sc *arrow.Schema, cc <-chan flight.StreamChunk, err error) {
handle := string(ticket.GetStatementHandle())
switch handle {
Expand Down Expand Up @@ -189,6 +219,20 @@ func (s *FlightSqlServerSuite) TestExecuteChunkError() {
}
}

func (s *FlightSqlServerSuite) TestExecutePoll() {
poll, err := s.cl.ExecutePoll(context.TODO(), "1", nil)
s.NoError(err)
s.NotNil(poll)
s.NotNil(poll.GetFlightDescriptor())
s.Len(poll.GetInfo().Endpoint, 1)

poll, err = s.cl.ExecutePoll(context.TODO(), "1", poll.GetFlightDescriptor())
s.NoError(err)
s.NotNil(poll)
s.Nil(poll.GetFlightDescriptor())
s.Len(poll.GetInfo().Endpoint, 2)
}

type UnimplementedFlightSqlServerSuite struct {
suite.Suite

Expand Down Expand Up @@ -314,6 +358,22 @@ func (s *UnimplementedFlightSqlServerSuite) TestGetTypeInfo() {
s.Nil(info)
}

func (s *UnimplementedFlightSqlServerSuite) TestPoll() {
poll, err := s.cl.ExecutePoll(context.TODO(), "", nil)
st, ok := status.FromError(err)
s.True(ok)
s.Equal(codes.Unimplemented, st.Code())
s.Equal("PollFlightInfoStatement not implemented", st.Message())
s.Nil(poll)

poll, err = s.cl.ExecuteSubstraitPoll(context.TODO(), flightsql.SubstraitPlan{}, nil)
st, ok = status.FromError(err)
s.True(ok)
s.Equal(codes.Unimplemented, st.Code())
s.Equal("PollFlightInfoSubstraitPlan not implemented", st.Message())
s.Nil(poll)
}

func getTicket(cmd proto.Message) *flight.Ticket {
var anycmd anypb.Any
anycmd.MarshalFrom(cmd)
Expand Down

0 comments on commit e9730c4

Please sign in to comment.