Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

server/handler.go: Improve some edge cases and error handling in resultForDefaultIter. #2881

Merged
merged 1 commit into from
Mar 6, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 25 additions & 20 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package server
import (
"context"
"encoding/base64"
goerrors "errors"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -609,31 +610,32 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,

// resultForDefaultIter reads batches of rows from the iterator
// and writes results into the callback function.
func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool, buf *sql.ByteBuffer) (r *sqltypes.Result, processedAtLeastOneBatch bool, returnErr error) {
func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool, buf *sql.ByteBuffer) (*sqltypes.Result, bool, error) {
defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End()

eg, ctx := ctx.NewErrgroup()

pan2err := func() {
pan2err := func(err *error) {
if recoveredPanic := recover(); recoveredPanic != nil {
returnErr = fmt.Errorf("handler caught panic: %v", recoveredPanic)
*err = goerrors.Join(*err, fmt.Errorf("handler caught panic: %v", recoveredPanic))
}
}

wg := sync.WaitGroup{}
wg.Add(2)

var r *sqltypes.Result
var processedAtLeastOneBatch bool

// Read rows off the row iterator and send them to the row channel.
iter, projs := GetDeferredProjections(iter)
var rowChan = make(chan sql.Row, 512)
eg.Go(func() error {
defer pan2err()
eg.Go(func() (err error) {
defer pan2err(&err)
defer wg.Done()
defer close(rowChan)
for {
select {
case <-ctx.Done():
return nil
return context.Cause(ctx)
default:
row, err := iter.Next(ctx)
if err == io.EOF {
Expand All @@ -651,9 +653,12 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
}
})

// TODO: poll for closed connections should obviously also run even if
// we're doing something with an OK result or a single row result, etc.
// This should be in the caller.
pollCtx, cancelF := ctx.NewSubContext()
eg.Go(func() error {
defer pan2err()
eg.Go(func() (err error) {
defer pan2err(&err)
return h.pollForClosedConnection(pollCtx, c)
})

Expand All @@ -676,8 +681,8 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s

// Reads rows from the channel, converts them to wire format,
// and calls |callback| to give them to vitess.
eg.Go(func() error {
defer pan2err()
eg.Go(func() (err error) {
defer pan2err(&err)
defer cancelF()
defer wg.Done()
for {
Expand All @@ -695,7 +700,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s

select {
case <-ctx.Done():
return nil
return context.Cause(ctx)
case row, ok := <-rowChan:
if !ok {
return nil
Expand All @@ -716,6 +721,9 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
ctx.GetLogger().Tracef("spooling result row %s", outputRow)
r.Rows = append(r.Rows, outputRow)
r.RowsAffected++
if !timer.Stop() {
<-timer.C
}
case <-timer.C:
// TODO: timer should probably go in its own thread, as rowChan is blocking
if h.readTimeout != 0 {
Expand All @@ -724,17 +732,14 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
return ErrRowTimeout.New()
}
}
if !timer.Stop() {
<-timer.C
}
timer.Reset(waitTime)
}
})

// Close() kills this PID in the process list,
// wait until all rows have be sent over the wire
eg.Go(func() error {
defer pan2err()
eg.Go(func() (err error) {
defer pan2err(&err)
wg.Wait()
return iter.Close(ctx)
})
Expand All @@ -745,9 +750,9 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s
if verboseErrorLogging {
fmt.Printf("Err: %+v", err)
}
returnErr = err
return nil, false, err
}
return
return r, processedAtLeastOneBatch, nil
}

// See https://dev.mysql.com/doc/internals/en/status-flags.html
Expand Down
Loading