diff --git a/go/arrow/flight/flightsql/client.go b/go/arrow/flight/flightsql/client.go index c0c7e2cf20a28..89784b483b01b 100644 --- a/go/arrow/flight/flightsql/client.go +++ b/go/arrow/flight/flightsql/client.go @@ -82,6 +82,17 @@ func flightInfoForCommand(ctx context.Context, cl *Client, cmd proto.Message, op return cl.getFlightInfo(ctx, desc, opts...) } +func pollInfoForCommand(ctx context.Context, cl *Client, cmd proto.Message, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { + if retryDescriptor != nil { + return cl.Client.PollFlightInfo(ctx, retryDescriptor, opts...) + } + desc, err := descForCommand(cmd) + if err != nil { + return nil, err + } + return cl.Client.PollFlightInfo(ctx, desc, opts...) +} + func schemaForCommand(ctx context.Context, cl *Client, cmd proto.Message, opts ...grpc.CallOption) (*flight.SchemaResult, error) { desc, err := descForCommand(cmd) if err != nil { @@ -123,6 +134,14 @@ func (c *Client) Execute(ctx context.Context, query string, opts ...grpc.CallOpt return flightInfoForCommand(ctx, c, &cmd, opts...) } +// ExecutePoll idempotently starts execution of a query/checks for completion. +// To check for completion, pass the FlightDescriptor from the previous call +// to ExecutePoll as the retryDescriptor. +func (c *Client) ExecutePoll(ctx context.Context, query string, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { + cmd := pb.CommandStatementQuery{Query: query} + return pollInfoForCommand(ctx, c, &cmd, retryDescriptor, opts...) +} + // GetExecuteSchema gets the schema of the result set of a query without // executing the query itself. func (c *Client) GetExecuteSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) { @@ -136,6 +155,12 @@ func (c *Client) ExecuteSubstrait(ctx context.Context, plan SubstraitPlan, opts return flightInfoForCommand(ctx, c, &cmd, opts...) } +func (c *Client) ExecuteSubstraitPoll(ctx context.Context, plan SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { + cmd := pb.CommandStatementSubstraitPlan{ + Plan: &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version}} + return pollInfoForCommand(ctx, c, &cmd, retryDescriptor, opts...) +} + func (c *Client) GetExecuteSubstraitSchema(ctx context.Context, plan SubstraitPlan, opts ...grpc.CallOption) (*flight.SchemaResult, error) { cmd := pb.CommandStatementSubstraitPlan{ Plan: &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version}} @@ -606,6 +631,15 @@ func (tx *Txn) Execute(ctx context.Context, query string, opts ...grpc.CallOptio return flightInfoForCommand(ctx, tx.c, cmd, opts...) } +func (tx *Txn) ExecutePoll(ctx context.Context, query string, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { + if !tx.txn.IsValid() { + return nil, ErrInvalidTxn + } + // The server should encode the transaction into the retry descriptor + cmd := &pb.CommandStatementQuery{Query: query, TransactionId: tx.txn} + return pollInfoForCommand(ctx, tx.c, cmd, retryDescriptor, opts...) +} + func (tx *Txn) ExecuteSubstrait(ctx context.Context, plan SubstraitPlan, opts ...grpc.CallOption) (*flight.FlightInfo, error) { if !tx.txn.IsValid() { return nil, ErrInvalidTxn @@ -616,6 +650,18 @@ func (tx *Txn) ExecuteSubstrait(ctx context.Context, plan SubstraitPlan, opts .. return flightInfoForCommand(ctx, tx.c, cmd, opts...) } +func (tx *Txn) ExecuteSubstraitPoll(ctx context.Context, plan SubstraitPlan, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { + if !tx.txn.IsValid() { + return nil, ErrInvalidTxn + } + // The server should encode the transaction into the retry descriptor + cmd := &pb.CommandStatementSubstraitPlan{ + Plan: &pb.SubstraitPlan{Plan: plan.Plan, Version: plan.Version}, + TransactionId: tx.txn, + } + return pollInfoForCommand(ctx, tx.c, cmd, retryDescriptor, opts...) +} + func (tx *Txn) GetExecuteSchema(ctx context.Context, query string, opts ...grpc.CallOption) (*flight.SchemaResult, error) { if !tx.txn.IsValid() { return nil, ErrInvalidTxn @@ -981,6 +1027,52 @@ func (p *PreparedStatement) Execute(ctx context.Context, opts ...grpc.CallOption return p.client.getFlightInfo(ctx, desc, opts...) } +// ExecutePoll executes the prepared statement on the server and returns a PollInfo +// indicating the progress of execution. +// +// Will error if already closed. +func (p *PreparedStatement) ExecutePoll(ctx context.Context, retryDescriptor *flight.FlightDescriptor, opts ...grpc.CallOption) (*flight.PollInfo, error) { + if p.closed { + return nil, errors.New("arrow/flightsql: prepared statement already closed") + } + + cmd := &pb.CommandPreparedStatementQuery{PreparedStatementHandle: p.handle} + + desc := retryDescriptor + var err error + + if desc == nil { + desc, err = descForCommand(cmd) + if err != nil { + return nil, err + } + } + + if retryDescriptor == nil { + if p.hasBindParameters() { + pstream, err := p.client.Client.DoPut(ctx, opts...) + if err != nil { + return nil, err + } + + wr, err := p.writeBindParameters(pstream, desc) + if err != nil { + return nil, err + } + if err = wr.Close(); err != nil { + return nil, err + } + pstream.CloseSend() + + // wait for the server to ack the result + if _, err = pstream.Recv(); err != nil && err != io.EOF { + return nil, err + } + } + } + return p.client.Client.PollFlightInfo(ctx, desc, opts...) +} + // ExecuteUpdate executes the prepared statement update query on the server // and returns the number of rows affected. If SetParameters was called, // the parameter bindings will be sent with the request to execute. diff --git a/go/arrow/flight/flightsql/server.go b/go/arrow/flight/flightsql/server.go index 5b1764707c298..2ec02e2829962 100644 --- a/go/arrow/flight/flightsql/server.go +++ b/go/arrow/flight/flightsql/server.go @@ -524,6 +524,22 @@ func (BaseServer) RenewFlightEndpoint(context.Context, *flight.RenewFlightEndpoi return nil, status.Error(codes.Unimplemented, "RenewFlightEndpoint not implemented") } +func (BaseServer) PollFlightInfo(context.Context, *flight.FlightDescriptor) (*flight.PollInfo, error) { + return nil, status.Error(codes.Unimplemented, "PollFlightInfo not implemented") +} + +func (BaseServer) PollFlightInfoStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) { + return nil, status.Error(codes.Unimplemented, "PollFlightInfoStatement not implemented") +} + +func (BaseServer) PollFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.PollInfo, error) { + return nil, status.Error(codes.Unimplemented, "PollFlightInfoSubstraitPlan not implemented") +} + +func (BaseServer) PollFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) { + return nil, status.Error(codes.Unimplemented, "PollFlightInfoPreparedStatement not implemented") +} + func (BaseServer) EndTransaction(context.Context, ActionEndTransactionRequest) error { return status.Error(codes.Unimplemented, "EndTransaction not implemented") } @@ -652,6 +668,14 @@ type Server interface { CancelFlightInfo(context.Context, *flight.CancelFlightInfoRequest) (flight.CancelFlightInfoResult, error) // RenewFlightEndpoint attempts to extend the expiration of a FlightEndpoint RenewFlightEndpoint(context.Context, *flight.RenewFlightEndpointRequest) (*flight.FlightEndpoint, error) + // PollFlightInfo is a generic handler for PollFlightInfo requests. + PollFlightInfo(context.Context, *flight.FlightDescriptor) (*flight.PollInfo, error) + // PollFlightInfoStatement handles polling for query execution. + PollFlightInfoStatement(context.Context, StatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) + // PollFlightInfoSubstraitPlan handles polling for query execution. + PollFlightInfoSubstraitPlan(context.Context, StatementSubstraitPlan, *flight.FlightDescriptor) (*flight.PollInfo, error) + // PollFlightInfoPreparedStatement handles polling for query execution. + PollFlightInfoPreparedStatement(context.Context, PreparedStatementQuery, *flight.FlightDescriptor) (*flight.PollInfo, error) mustEmbedBaseServer() } @@ -729,6 +753,36 @@ func (f *flightSqlServer) GetFlightInfo(ctx context.Context, request *flight.Fli return nil, status.Error(codes.InvalidArgument, "requested command is invalid") } +func (f *flightSqlServer) PollFlightInfo(ctx context.Context, request *flight.FlightDescriptor) (*flight.PollInfo, error) { + var ( + anycmd anypb.Any + cmd proto.Message + err error + ) + // If we can't parse things, be friendly and defer to the server + // implementation. This is especially important for this method since + // the server returns a custom FlightDescriptor for future requests. + if err = proto.Unmarshal(request.Cmd, &anycmd); err != nil { + return f.srv.PollFlightInfo(ctx, request) + } + + if cmd, err = anycmd.UnmarshalNew(); err != nil { + return f.srv.PollFlightInfo(ctx, request) + } + + switch cmd := cmd.(type) { + case *pb.CommandStatementQuery: + return f.srv.PollFlightInfoStatement(ctx, cmd, request) + case *pb.CommandStatementSubstraitPlan: + return f.srv.PollFlightInfoSubstraitPlan(ctx, &statementSubstraitPlan{cmd}, request) + case *pb.CommandPreparedStatementQuery: + return f.srv.PollFlightInfoPreparedStatement(ctx, cmd, request) + } + // XXX: for now we won't support the other methods + + return f.srv.PollFlightInfo(ctx, request) +} + func (f *flightSqlServer) GetSchema(ctx context.Context, request *flight.FlightDescriptor) (*flight.SchemaResult, error) { var ( anycmd anypb.Any diff --git a/go/arrow/flight/flightsql/server_test.go b/go/arrow/flight/flightsql/server_test.go index e444da4aaf4a2..956a1714c671c 100644 --- a/go/arrow/flight/flightsql/server_test.go +++ b/go/arrow/flight/flightsql/server_test.go @@ -56,6 +56,36 @@ func (*testServer) GetFlightInfoStatement(ctx context.Context, q flightsql.State }, nil } +func (*testServer) PollFlightInfo(ctx context.Context, fd *flight.FlightDescriptor) (*flight.PollInfo, error) { + return &flight.PollInfo{ + Info: &flight.FlightInfo{ + FlightDescriptor: fd, + Endpoint: []*flight.FlightEndpoint{{ + Ticket: &flight.Ticket{Ticket: []byte{}}, + }, { + Ticket: &flight.Ticket{Ticket: []byte{}}, + }}, + }, + FlightDescriptor: nil, + }, nil +} + +func (*testServer) PollFlightInfoStatement(ctx context.Context, q flightsql.StatementQuery, fd *flight.FlightDescriptor) (*flight.PollInfo, error) { + ticket, err := flightsql.CreateStatementQueryTicket([]byte(q.GetQuery())) + if err != nil { + return nil, err + } + return &flight.PollInfo{ + Info: &flight.FlightInfo{ + FlightDescriptor: fd, + Endpoint: []*flight.FlightEndpoint{{ + Ticket: &flight.Ticket{Ticket: ticket}, + }}, + }, + FlightDescriptor: &flight.FlightDescriptor{Cmd: []byte{}}, + }, nil +} + func (*testServer) DoGetStatement(ctx context.Context, ticket flightsql.StatementQueryTicket) (sc *arrow.Schema, cc <-chan flight.StreamChunk, err error) { handle := string(ticket.GetStatementHandle()) switch handle { @@ -189,6 +219,20 @@ func (s *FlightSqlServerSuite) TestExecuteChunkError() { } } +func (s *FlightSqlServerSuite) TestExecutePoll() { + poll, err := s.cl.ExecutePoll(context.TODO(), "1", nil) + s.NoError(err) + s.NotNil(poll) + s.NotNil(poll.GetFlightDescriptor()) + s.Len(poll.GetInfo().Endpoint, 1) + + poll, err = s.cl.ExecutePoll(context.TODO(), "1", poll.GetFlightDescriptor()) + s.NoError(err) + s.NotNil(poll) + s.Nil(poll.GetFlightDescriptor()) + s.Len(poll.GetInfo().Endpoint, 2) +} + type UnimplementedFlightSqlServerSuite struct { suite.Suite @@ -314,6 +358,22 @@ func (s *UnimplementedFlightSqlServerSuite) TestGetTypeInfo() { s.Nil(info) } +func (s *UnimplementedFlightSqlServerSuite) TestPoll() { + poll, err := s.cl.ExecutePoll(context.TODO(), "", nil) + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.Unimplemented, st.Code()) + s.Equal("PollFlightInfoStatement not implemented", st.Message()) + s.Nil(poll) + + poll, err = s.cl.ExecuteSubstraitPoll(context.TODO(), flightsql.SubstraitPlan{}, nil) + st, ok = status.FromError(err) + s.True(ok) + s.Equal(codes.Unimplemented, st.Code()) + s.Equal("PollFlightInfoSubstraitPlan not implemented", st.Message()) + s.Nil(poll) +} + func getTicket(cmd proto.Message) *flight.Ticket { var anycmd anypb.Any anycmd.MarshalFrom(cmd)