Skip to content

Commit

Permalink
Merge pull request #2850 from dolthub/aaron/processlist-kill-on-prepare
Browse files Browse the repository at this point in the history
processlist: Allow for killing the context associated with non-query operations like SetDB and Prepare.
  • Loading branch information
reltuk authored Feb 14, 2025
2 parents b3a4c87 + 6887d52 commit 1e0e5e8
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 10 deletions.
40 changes: 39 additions & 1 deletion processlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,40 @@ func (pl *ProcessList) EndQuery(ctx *sql.Context) {
}
}

// Registers the process and session associated with |ctx| as performing
// a long-running operation that should be able to be canceled with Kill.
//
// This is not used for Query processing --- the process is still in
// CommandSleep, it does not have a QueryPid, etc. Must always be
// bracketed with EndOperation(). Should certainly be used for any
// Handler callbacks which may access the database, like Prepare.
func (pl *ProcessList) BeginOperation(ctx *sql.Context) (*sql.Context, error) {
pl.mu.Lock()
defer pl.mu.Unlock()
id := ctx.Session.ID()
p := pl.procs[id]
if p == nil {
return nil, errors.New("internal error: connection not registered with process list")
}
if p.Kill != nil {
return nil, errors.New("internal error: attempt to begin operation on connection which was already running one")
}
newCtx, cancel := ctx.NewSubContext()
p.Kill = cancel
return newCtx, nil
}

func (pl *ProcessList) EndOperation(ctx *sql.Context) {
pl.mu.Lock()
defer pl.mu.Unlock()
id := ctx.Session.ID()
p := pl.procs[id]
if p != nil && p.Kill != nil {
p.Kill()
p.Kill = nil
}
}

// UpdateTableProgress updates the progress of the table with the given name for the
// process with the given pid.
func (pl *ProcessList) UpdateTableProgress(pid uint64, name string, delta int64) {
Expand Down Expand Up @@ -322,7 +356,11 @@ func (pl *ProcessList) Kill(connID uint32) {

p := pl.procs[connID]
if p != nil && p.Kill != nil {
logrus.Infof("kill query: pid %d", p.QueryPid)
if p.QueryPid != 0 {
logrus.Infof("kill query: pid %d", p.QueryPid)
} else {
logrus.Infof("canceling context: connID %d", connID)
}
p.Kill()
}
}
Expand Down
53 changes: 53 additions & 0 deletions processlist_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,59 @@ func TestKillConnection(t *testing.T) {
require.False(t, killed[2])
}

func TestBeginEndOperation(t *testing.T) {
knownSession := sql.NewBaseSessionWithClientServer("", sql.Client{}, 1)
unknownSession := sql.NewBaseSessionWithClientServer("", sql.Client{}, 2)

pl := NewProcessList()
pl.AddConnection(1, "")

// Begining an operation with an unknown connection returns an error.
ctx := sql.NewContext(context.Background(), sql.WithSession(unknownSession))
_, err := pl.BeginOperation(ctx)
require.Error(t, err)

// Can begin and end operation before connection is ready.
ctx = sql.NewContext(context.Background(), sql.WithSession(knownSession))
subCtx, err := pl.BeginOperation(ctx)
require.NoError(t, err)
pl.EndOperation(subCtx)

// Can begin and end operation across the connection ready boundary.
subCtx, err = pl.BeginOperation(ctx)
require.NoError(t, err)
pl.ConnectionReady(knownSession)
pl.EndOperation(subCtx)

// Ending the operation cancels the subcontext.
subCtx, err = pl.BeginOperation(ctx)
require.NoError(t, err)
done := make(chan struct{})
context.AfterFunc(subCtx, func() {
close(done)
})
pl.EndOperation(subCtx)
<-done

// Kill on the connection cancels the subcontext.
subCtx, err = pl.BeginOperation(ctx)
require.NoError(t, err)
done = make(chan struct{})
context.AfterFunc(subCtx, func() {
close(done)
})
pl.Kill(1)
<-done
pl.EndOperation(subCtx)

// Beginning an operation while one is outstanding errors.
subCtx, err = pl.BeginOperation(ctx)
require.NoError(t, err)
_, err = pl.BeginOperation(ctx)
require.Error(t, err)
pl.EndOperation(subCtx)
}

// TestSlowQueryTracking tests that processes that take longer than @@long_query_time increment the
// Slow_queries status variable.
func TestSlowQueryTracking(t *testing.T) {
Expand Down
5 changes: 5 additions & 0 deletions server/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,11 @@ func (s *SessionManager) SetDB(conn *mysql.Conn, dbName string) error {
defer sql.SessionCommandEnd(sess)

ctx := sql.NewContext(context.Background(), sql.WithSession(sess))
ctx, err = s.processlist.BeginOperation(ctx)
if err != nil {
return err
}
defer s.processlist.EndOperation(ctx)
var db sql.Database
if dbName != "" {
db, err = s.getDbFunc(ctx, dbName)
Expand Down
32 changes: 23 additions & 9 deletions server/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ func (h *Handler) ConnectionAborted(_ *mysql.Conn, _ string) error {
}

func (h *Handler) ComInitDB(c *mysql.Conn, schemaName string) error {
// SetDB itself handles session and processlist operation lifecycle callbacks.
err := h.sm.SetDB(c, schemaName)
if err != nil {
logrus.WithField("database", schemaName).Errorf("unable to process ComInitDB: %s", err.Error())
Expand All @@ -121,6 +122,11 @@ func (h *Handler) ComPrepare(ctx context.Context, c *mysql.Conn, query string, p
if err != nil {
return nil, err
}
sqlCtx, err = sqlCtx.ProcessList.BeginOperation(sqlCtx)
if err != nil {
return nil, err
}
defer sqlCtx.ProcessList.EndOperation(sqlCtx)
err = sql.SessionCommandBegin(sqlCtx.Session)
if err != nil {
return nil, err
Expand Down Expand Up @@ -166,7 +172,11 @@ func (h *Handler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, query str
if err != nil {
return nil, nil, err
}

sqlCtx, err = sqlCtx.ProcessList.BeginOperation(sqlCtx)
if err != nil {
return nil, nil, err
}
defer sqlCtx.ProcessList.EndOperation(sqlCtx)
err = sql.SessionCommandBegin(sqlCtx.Session)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -201,6 +211,11 @@ func (h *Handler) ComBind(ctx context.Context, c *mysql.Conn, query string, pars
if err != nil {
return nil, nil, err
}
sqlCtx, err = sqlCtx.ProcessList.BeginOperation(sqlCtx)
if err != nil {
return nil, nil, err
}
defer sqlCtx.ProcessList.EndOperation(sqlCtx)
err = sql.SessionCommandBegin(sqlCtx.Session)
if err != nil {
return nil, nil, err
Expand Down Expand Up @@ -395,6 +410,13 @@ func (h *Handler) doQuery(
if err != nil {
return "", err
}
// TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be
// marked done until we're done spooling rows over the wire
sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
if err != nil {
return remainder, err
}
defer sqlCtx.ProcessList.EndQuery(sqlCtx)
err = sql.SessionCommandBegin(sqlCtx.Session)
if err != nil {
return "", err
Expand Down Expand Up @@ -439,14 +461,6 @@ func (h *Handler) doQuery(

sqlCtx.GetLogger().Tracef("beginning execution")

// TODO: it would be nice to put this logic in the engine, not the handler, but we don't want the process to be
// marked done until we're done spooling rows over the wire
sqlCtx, err = sqlCtx.ProcessList.BeginQuery(sqlCtx, query)
if err != nil {
return remainder, err
}
defer sqlCtx.ProcessList.EndQuery(sqlCtx)

var schema sql.Schema
var rowIter sql.RowIter
qFlags.Set(sql.QFlagDeferProjections)
Expand Down
3 changes: 3 additions & 0 deletions server/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ func TestServerEventListener(t *testing.T) {
require.Equal(listener.Disconnects, 2)

conn3 := newConn(3)
handler.NewConnection(conn3)
query := "SELECT ?"
_, err = handler.ComPrepare(context.Background(), conn3, query, samplePrepareData)
require.NoError(err)
Expand Down Expand Up @@ -1165,6 +1166,8 @@ func TestHandlerFoundRowsCapabilities(t *testing.T) {
),
}

handler.NewConnection(dummyConn)

tests := []struct {
name string
handler *Handler
Expand Down
15 changes: 15 additions & 0 deletions sql/processlist.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,17 @@ type ProcessList interface {
// EndQuery transitions a previously transitioned connection from Command "Query" to Command "Sleep".
EndQuery(ctx *Context)

// BeginOperation registers and returns a SubContext for a
// long-running operation on the conneciton which does not
// change the process's Command state. This SubContext will be
// killed by a call to |Kill|, and unregistered by a call to
// |EndOperation|.
BeginOperation(ctx *Context) (*Context, error)

// EndOperation cancels and deregisters the SubContext which
// BeginOperation registered.
EndOperation(ctx *Context)

// Kill terminates all queries for a given connection id
Kill(connID uint32)

Expand Down Expand Up @@ -166,6 +177,10 @@ func (e EmptyProcessList) BeginQuery(ctx *Context, query string) (*Context, erro
return ctx, nil
}
func (e EmptyProcessList) EndQuery(ctx *Context) {}
func (e EmptyProcessList) BeginOperation(ctx *Context) (*Context, error) {
return ctx, nil
}
func (e EmptyProcessList) EndOperation(ctx *Context) {}

func (e EmptyProcessList) Kill(connID uint32) {}
func (e EmptyProcessList) Done(pid uint64) {}
Expand Down

0 comments on commit 1e0e5e8

Please sign in to comment.