diff --git a/server/handler.go b/server/handler.go index 67d672777b..35e0718d23 100644 --- a/server/handler.go +++ b/server/handler.go @@ -17,6 +17,7 @@ package server import ( "context" "encoding/base64" + goerrors "errors" "fmt" "io" "net" @@ -609,31 +610,32 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, // resultForDefaultIter reads batches of rows from the iterator // and writes results into the callback function. -func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool, buf *sql.ByteBuffer) (r *sqltypes.Result, processedAtLeastOneBatch bool, returnErr error) { +func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema sql.Schema, iter sql.RowIter, callback func(*sqltypes.Result, bool) error, resultFields []*querypb.Field, more bool, buf *sql.ByteBuffer) (*sqltypes.Result, bool, error) { defer trace.StartRegion(ctx, "Handler.resultForDefaultIter").End() eg, ctx := ctx.NewErrgroup() - - pan2err := func() { + pan2err := func(err *error) { if recoveredPanic := recover(); recoveredPanic != nil { - returnErr = fmt.Errorf("handler caught panic: %v", recoveredPanic) + *err = goerrors.Join(*err, fmt.Errorf("handler caught panic: %v", recoveredPanic)) } } - wg := sync.WaitGroup{} wg.Add(2) + var r *sqltypes.Result + var processedAtLeastOneBatch bool + // Read rows off the row iterator and send them to the row channel. iter, projs := GetDeferredProjections(iter) var rowChan = make(chan sql.Row, 512) - eg.Go(func() error { - defer pan2err() + eg.Go(func() (err error) { + defer pan2err(&err) defer wg.Done() defer close(rowChan) for { select { case <-ctx.Done(): - return nil + return context.Cause(ctx) default: row, err := iter.Next(ctx) if err == io.EOF { @@ -651,9 +653,12 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s } }) + // TODO: poll for closed connections should obviously also run even if + // we're doing something with an OK result or a single row result, etc. + // This should be in the caller. pollCtx, cancelF := ctx.NewSubContext() - eg.Go(func() error { - defer pan2err() + eg.Go(func() (err error) { + defer pan2err(&err) return h.pollForClosedConnection(pollCtx, c) }) @@ -676,8 +681,8 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s // Reads rows from the channel, converts them to wire format, // and calls |callback| to give them to vitess. - eg.Go(func() error { - defer pan2err() + eg.Go(func() (err error) { + defer pan2err(&err) defer cancelF() defer wg.Done() for { @@ -695,7 +700,7 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s select { case <-ctx.Done(): - return nil + return context.Cause(ctx) case row, ok := <-rowChan: if !ok { return nil @@ -716,6 +721,9 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s ctx.GetLogger().Tracef("spooling result row %s", outputRow) r.Rows = append(r.Rows, outputRow) r.RowsAffected++ + if !timer.Stop() { + <-timer.C + } case <-timer.C: // TODO: timer should probably go in its own thread, as rowChan is blocking if h.readTimeout != 0 { @@ -724,17 +732,14 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s return ErrRowTimeout.New() } } - if !timer.Stop() { - <-timer.C - } timer.Reset(waitTime) } }) // Close() kills this PID in the process list, // wait until all rows have be sent over the wire - eg.Go(func() error { - defer pan2err() + eg.Go(func() (err error) { + defer pan2err(&err) wg.Wait() return iter.Close(ctx) }) @@ -745,9 +750,9 @@ func (h *Handler) resultForDefaultIter(ctx *sql.Context, c *mysql.Conn, schema s if verboseErrorLogging { fmt.Printf("Err: %+v", err) } - returnErr = err + return nil, false, err } - return + return r, processedAtLeastOneBatch, nil } // See https://dev.mysql.com/doc/internals/en/status-flags.html