Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

processlist: Allow for killing the context associated with non-query operations like SetDB and Prepare. #2850

Merged
merged 2 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion processlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,40 @@ func (pl *ProcessList) EndQuery(ctx *sql.Context) {
}
}

// Registers the process and session associated with |ctx| as performing
// a long-running operation that should be able to be canceled with Kill.
//
// This is not used for Query processing --- the process is still in
// CommandSleep, it does not have a QueryPid, etc. Must always be
// bracketed with EndOperation(). Should certainly be used for any
// Handler callbacks which may access the database, like Prepare.
func (pl *ProcessList) BeginOperation(ctx *sql.Context) (*sql.Context, error) {
pl.mu.Lock()
defer pl.mu.Unlock()
id := ctx.Session.ID()
p := pl.procs[id]
if p == nil {
return nil, errors.New("internal error: connection not registered with process list")
}
if p.Kill != nil {
return nil, errors.New("internal error: attempt to begin operation on connection which was already running one")
}
newCtx, cancel := ctx.NewSubContext()
p.Kill = cancel
return newCtx, nil
}

func (pl *ProcessList) EndOperation(ctx *sql.Context) {
pl.mu.Lock()
defer pl.mu.Unlock()
id := ctx.Session.ID()
p := pl.procs[id]
if p != nil && p.Kill != nil {
p.Kill()
p.Kill = nil
}
}

// UpdateTableProgress updates the progress of the table with the given name for the
// process with the given pid.
func (pl *ProcessList) UpdateTableProgress(pid uint64, name string, delta int64) {
Expand Down Expand Up @@ -322,7 +356,11 @@ func (pl *ProcessList) Kill(connID uint32) {

p := pl.procs[connID]
if p != nil && p.Kill != nil {
logrus.Infof("kill query: pid %d", p.QueryPid)
if p.QueryPid != 0 {
logrus.Infof("kill query: pid %d", p.QueryPid)
} else {
logrus.Infof("canceling context: connID %d", connID)
}
p.Kill()
}
}
Expand Down
5 changes: 5 additions & 0 deletions server/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ func (s *SessionManager) SetDB(conn *mysql.Conn, dbName string) error {
defer sql.SessionCommandEnd(sess)

ctx := sql.NewContext(context.Background(), sql.WithSession(sess))
ctx, err = s.processlist.BeginOperation(ctx)
if err != nil {
return err
}
defer s.processlist.EndOperation(ctx)
var db sql.Database
if dbName != "" {
db, err = s.getDbFunc(ctx, dbName)
Expand Down
32 changes: 23 additions & 9 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ func (h *Handler) ConnectionAborted(_ *mysql.Conn, _ string) error {
}

func (h *Handler) ComInitDB(c *mysql.Conn, schemaName string) error {
// SetDB itself handles session and processlist operation lifecycle callbacks.
err := h.sm.SetDB(c, schemaName)
if err != nil {
logrus.WithField("database", schemaName).Errorf("unable to process ComInitDB: %s", err.Error())
Expand All @@ -121,6 +122,11 @@ func (h *Handler) ComPrepare(ctx context.Context, c *mysql.Conn, query string, p
if err != nil {
return nil, err
}
sqlCtx, err = sqlCtx.ProcessList.BeginOperation(sqlCtx)
if err != nil {
return nil, err
}
defer sqlCtx.ProcessList.EndOperation(sqlCtx)
err = sql.SessionCommandBegin(sqlCtx.Session)
if err != nil {
return nil, err
Expand Down Expand Up @@ -166,7 +172,11 @@ func (h *Handler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query str
if err != nil {
return nil, nil, err
}

sqlCtx, err = sqlCtx.ProcessList.BeginOperation(sqlCtx)
if err != nil {
return nil, nil, err
}
defer sqlCtx.ProcessList.EndOperation(sqlCtx)
err = sql.SessionCommandBegin(sqlCtx.Session)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -201,6 +211,11 @@ func (h *Handler) ComBind(ctx context.Context, c *mysql.Conn, query string, pars
if err != nil {
return nil, nil, err
}
sqlCtx, err = sqlCtx.ProcessList.BeginOperation(sqlCtx)
if err != nil {
return nil, nil, err
}
defer sqlCtx.ProcessList.EndOperation(sqlCtx)
err = sql.SessionCommandBegin(sqlCtx.Session)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -395,6 +410,13 @@ func (h *Handler) doQuery(
if err != nil {
return "", err
}
// TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be
// marked done until we're done spooling rows over the wire
sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
if err != nil {
return remainder, err
}
defer sqlCtx.ProcessList.EndQuery(sqlCtx)
err = sql.SessionCommandBegin(sqlCtx.Session)
if err != nil {
return "", err
Expand Down Expand Up @@ -439,14 +461,6 @@ func (h *Handler) doQuery(

sqlCtx.GetLogger().Tracef("beginning execution")

// TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be
// marked done until we're done spooling rows over the wire
sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
if err != nil {
return remainder, err
}
defer sqlCtx.ProcessList.EndQuery(sqlCtx)

var schema sql.Schema
var rowIter sql.RowIter
qFlags.Set(sql.QFlagDeferProjections)
Expand Down
3 changes: 3 additions & 0 deletions server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ func TestServerEventListener(t *testing.T) {
require.Equal(listener.Disconnects, 2)

conn3 := newConn(3)
handler.NewConnection(conn3)
query := "SELECT ?"
_, err = handler.ComPrepare(context.Background(), conn3, query, samplePrepareData)
require.NoError(err)
Expand Down Expand Up @@ -1165,6 +1166,8 @@ func TestHandlerFoundRowsCapabilities(t *testing.T) {
),
}

handler.NewConnection(dummyConn)

tests := []struct {
name string
handler *Handler
Expand Down
15 changes: 15 additions & 0 deletions sql/processlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ type ProcessList interface {
// EndQuery transitions a previously transitioned connection from Command "Query" to Command "Sleep".
EndQuery(ctx *Context)

// BeginOperation registers and returns a SubContext for a
// long-running operation on the conneciton which does not
// change the process's Command state. This SubContext will be
// killed by a call to |Kill|, and unregistered by a call to
// |EndOperation|.
BeginOperation(ctx *Context) (*Context, error)

// EndOperation cancels and deregisters the SubContext which
// BeginOperation registered.
EndOperation(ctx *Context)

// Kill terminates all queries for a given connection id
Kill(connID uint32)

Expand Down Expand Up @@ -166,6 +177,10 @@ func (e EmptyProcessList) BeginQuery(ctx *Context, query string) (*Context, erro
return ctx, nil
}
func (e EmptyProcessList) EndQuery(ctx *Context) {}
func (e EmptyProcessList) BeginOperation(ctx *Context) (*Context, error) {
return ctx, nil
}
func (e EmptyProcessList) EndOperation(ctx *Context) {}

func (e EmptyProcessList) Kill(connID uint32) {}
func (e EmptyProcessList) Done(pid uint64) {}
Expand Down
Loading