From 129af1156cc18412e4e6d79aa0d0698d62f7d865 Mon Sep 17 00:00:00 2001 From: Evan Wall Date: Tue, 24 Jan 2023 09:08:07 -0500 Subject: [PATCH] sql: remove redundant session iteration Fixes #95743 Improves session/query cancelation with the following 1) Replaces session scanning by session ID with map lookup. 2) Replaces active query scanning by query ID with map lookup (session containing query to cancel is still scanned for). 3) Does not serialize entire session to get session username or id. Informs #77676 77676 was closed but some test cases incorrectly mentioned that addressing 77676 fixed them. This PR correctly fixes these test cases. Release note: None --- .../serverccl/statusccl/tenant_status_test.go | 65 +++----- pkg/server/status.go | 150 +++++++----------- pkg/sql/conn_executor.go | 29 +++- pkg/sql/conn_executor_exec.go | 4 +- pkg/sql/exec_util.go | 94 ++++------- 5 files changed, 133 insertions(+), 209 deletions(-) diff --git a/pkg/ccl/serverccl/statusccl/tenant_status_test.go b/pkg/ccl/serverccl/statusccl/tenant_status_test.go index 28bb645b3695..5859b8b4015a 100644 --- a/pkg/ccl/serverccl/statusccl/tenant_status_test.go +++ b/pkg/ccl/serverccl/statusccl/tenant_status_test.go @@ -943,30 +943,22 @@ func testTenantStatusCancelSessionErrorMessages(t *testing.T, helper serverccl.T testCases := []struct { sessionID string expectedError string - - // This is a temporary assertion. We should always show the following "not found" error messages, - // regardless of admin status, but our current behavior is slightly broken and will be fixed in #77676. - nonAdminSeesError bool }{ { - sessionID: "", - expectedError: "session ID 00000000000000000000000000000000 not found", - nonAdminSeesError: true, + sessionID: "", + expectedError: "session ID 00000000000000000000000000000000 not found", }, { - sessionID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to. 
- expectedError: "session ID 00000000000000000000000000000001 not found", - nonAdminSeesError: false, + sessionID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to. + expectedError: "session ID 00000000000000000000000000000001 not found", }, { - sessionID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to. - expectedError: "session ID 00000000000000000000000000000002 not found", - nonAdminSeesError: false, + sessionID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to. + expectedError: "session ID 00000000000000000000000000000002 not found", }, { - sessionID: "42", // This query ID claims to have SQL instance ID 42, which does not exist. - expectedError: "session ID 00000000000000000000000000000042 not found", - nonAdminSeesError: true, + sessionID: "42", // This query ID claims to have SQL instance ID 42, which does not exist. + expectedError: "session ID 00000000000000000000000000000042 not found", }, } @@ -982,12 +974,8 @@ func testTenantStatusCancelSessionErrorMessages(t *testing.T, helper serverccl.T err = client.PostJSONChecked("/_status/cancel_session/0", &serverpb.CancelSessionRequest{ SessionID: sessionID.GetBytes(), }, &resp) - if isAdmin || testCase.nonAdminSeesError { - require.NoError(t, err) - require.Equal(t, testCase.expectedError, resp.Error) - } else { - require.Error(t, err) - } + require.NoError(t, err) + require.Equal(t, testCase.expectedError, resp.Error) }) } }) @@ -1072,36 +1060,27 @@ func testTenantStatusCancelQueryErrorMessages(t *testing.T, helper serverccl.Ten testCases := []struct { queryID string expectedError string - - // This is a temporary assertion. We should always show the following "not found" error messages, - // regardless of admin status, but our current behavior is slightly broken and will be fixed in #77676. 
- nonAdminSeesError bool }{ { queryID: "BOGUS_QUERY_ID", expectedError: "query ID 00000000000000000000000000000000 malformed: " + "could not decode BOGUS_QUERY_ID as hex: encoding/hex: invalid byte: U+004F 'O'", - nonAdminSeesError: true, }, { - queryID: "", - expectedError: "query ID 00000000000000000000000000000000 not found", - nonAdminSeesError: true, + queryID: "", + expectedError: "query ID 00000000000000000000000000000000 not found", }, { - queryID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to. - expectedError: "query ID 00000000000000000000000000000001 not found", - nonAdminSeesError: false, + queryID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to. + expectedError: "query ID 00000000000000000000000000000001 not found", }, { - queryID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to. - expectedError: "query ID 00000000000000000000000000000002 not found", - nonAdminSeesError: false, + queryID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to. + expectedError: "query ID 00000000000000000000000000000002 not found", }, { - queryID: "42", // This query ID claims to have SQL instance ID 42, which does not exist. - expectedError: "query ID 00000000000000000000000000000042 not found", - nonAdminSeesError: true, + queryID: "42", // This query ID claims to have SQL instance ID 42, which does not exist. 
+ expectedError: "query ID 00000000000000000000000000000042 not found", }, } @@ -1115,12 +1094,8 @@ func testTenantStatusCancelQueryErrorMessages(t *testing.T, helper serverccl.Ten err := client.PostJSONChecked("/_status/cancel_query/0", &serverpb.CancelQueryRequest{ QueryID: testCase.queryID, }, &resp) - if isAdmin || testCase.nonAdminSeesError { - require.NoError(t, err) - require.Equal(t, testCase.expectedError, resp.Error) - } else { - require.Error(t, err) - } + require.NoError(t, err) + require.Equal(t, testCase.expectedError, resp.Error) }) } }) diff --git a/pkg/server/status.go b/pkg/server/status.go index dfc669db4947..b1e214f6edd0 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -270,90 +270,42 @@ func (b *baseStatusServer) getLocalSessions( return userSessions, nil } -type sessionFinder func(sessions []serverpb.Session) (serverpb.Session, error) - -func findSessionBySessionID(sessionID []byte) sessionFinder { - return func(sessions []serverpb.Session) (serverpb.Session, error) { - var session serverpb.Session - for _, s := range sessions { - if bytes.Equal(sessionID, s.ID) { - session = s - break - } - } - if len(session.ID) == 0 { - return session, fmt.Errorf("session ID %s not found", clusterunique.IDFromBytes(sessionID)) - } - return session, nil - } -} - -func findSessionByQueryID(queryID string) sessionFinder { - return func(sessions []serverpb.Session) (serverpb.Session, error) { - var session serverpb.Session - for _, s := range sessions { - for _, q := range s.ActiveQueries { - if queryID == q.ID { - session = s - break - } - } - } - if len(session.ID) == 0 { - return session, fmt.Errorf("query ID %s not found", queryID) - } - return session, nil - } -} - // checkCancelPrivilege returns nil if the user has the necessary cancel action // privileges for a session. This function returns a proper gRPC error status. 
func (b *baseStatusServer) checkCancelPrivilege( - ctx context.Context, userName username.SQLUsername, findSession sessionFinder, + ctx context.Context, reqUsername username.SQLUsername, sessionUsername username.SQLUsername, ) error { ctx = propagateGatewayMetadata(ctx) ctx = b.AnnotateCtx(ctx) - // reqUser is the user who made the cancellation request. - var reqUser username.SQLUsername - { - sessionUser, isAdmin, err := b.privilegeChecker.getUserAndRole(ctx) - if err != nil { - return serverError(ctx, err) - } - if userName.Undefined() || userName == sessionUser { - reqUser = sessionUser - } else { - // When CANCEL QUERY is run as a SQL statement, sessionUser is always root - // and the user who ran the statement is passed as req.Username. - if !isAdmin { - return errRequiresAdmin - } - reqUser = userName - } + + ctxUsername, isAdmin, err := b.privilegeChecker.getUserAndRole(ctx) + if err != nil { + return serverError(ctx, err) + } + if reqUsername.Undefined() { + reqUsername = ctxUsername + } else if reqUsername != ctxUsername && !isAdmin { + // When CANCEL QUERY is run as a SQL statement, ctxUsername is always root + // and the user who ran the statement is passed as reqUsername. + return errRequiresAdmin } - hasAdmin, err := b.privilegeChecker.hasAdminRole(ctx, reqUser) + hasAdmin, err := b.privilegeChecker.hasAdminRole(ctx, reqUsername) if err != nil { return serverError(ctx, err) } if !hasAdmin { // Check if the user has permission to see the session. - session, err := findSession(b.sessionRegistry.SerializeAll()) - if err != nil { - return serverError(ctx, err) - } - - sessionUser := username.MakeSQLUsernameFromPreNormalizedString(session.Username) - if sessionUser != reqUser { + if sessionUsername != reqUsername { // Must have CANCELQUERY privilege to cancel other users' // sessions/queries. 
- hasCancelQuery, err := b.privilegeChecker.hasGlobalPrivilege(ctx, reqUser, privilege.CANCELQUERY) + hasCancelQuery, err := b.privilegeChecker.hasGlobalPrivilege(ctx, reqUsername, privilege.CANCELQUERY) if err != nil { return serverError(ctx, err) } if !hasCancelQuery { - ok, err := b.privilegeChecker.hasRoleOption(ctx, reqUser, roleoption.CANCELQUERY) + ok, err := b.privilegeChecker.hasRoleOption(ctx, reqUsername, roleoption.CANCELQUERY) if err != nil { return serverError(ctx, err) } @@ -362,7 +314,7 @@ func (b *baseStatusServer) checkCancelPrivilege( } } // Non-admins cannot cancel admins' sessions/queries. - isAdminSession, err := b.privilegeChecker.hasAdminRole(ctx, sessionUser) + isAdminSession, err := b.privilegeChecker.hasAdminRole(ctx, sessionUsername) if err != nil { return serverError(ctx, err) } @@ -3063,7 +3015,13 @@ func (s *statusServer) CancelSession( ctx = propagateGatewayMetadata(ctx) ctx = s.AnnotateCtx(ctx) - sessionID := clusterunique.IDFromBytes(req.SessionID) + sessionIDBytes := req.SessionID + if len(sessionIDBytes) != 16 { + return &serverpb.CancelSessionResponse{ + Error: fmt.Sprintf("session ID %v malformed", sessionIDBytes), + }, nil + } + sessionID := clusterunique.IDFromBytes(sessionIDBytes) nodeID := sessionID.GetNodeID() local := nodeID == int32(s.serverIterator.getID()) if !local { @@ -3071,8 +3029,7 @@ func (s *statusServer) CancelSession( if err != nil { if errors.Is(err, sqlinstance.NonExistentInstanceError) { return &serverpb.CancelSessionResponse{ - Canceled: false, - Error: fmt.Sprintf("session ID %s not found", sessionID), + Error: fmt.Sprintf("session ID %s not found", sessionID), }, nil } return nil, serverError(ctx, err) @@ -3085,17 +3042,21 @@ func (s *statusServer) CancelSession( return nil, status.Errorf(codes.InvalidArgument, err.Error()) } - if err := s.checkCancelPrivilege(ctx, reqUsername, findSessionBySessionID(req.SessionID)); err != nil { + session, ok := s.sessionRegistry.GetSessionByID(sessionID) + if !ok { + 
return &serverpb.CancelSessionResponse{ + Error: fmt.Sprintf("session ID %s not found", sessionID), + }, nil + } + + if err := s.checkCancelPrivilege(ctx, reqUsername, session.BaseSessionUser()); err != nil { // NB: not using serverError() here since the priv checker // already returns a proper gRPC error status. return nil, err } - r, err := s.sessionRegistry.CancelSession(req.SessionID) - if err != nil { - return nil, serverError(ctx, err) - } - return r, nil + session.CancelSession() + return &serverpb.CancelSessionResponse{Canceled: true}, nil } // CancelQuery responds to a query cancellation request, and cancels @@ -3109,8 +3070,7 @@ func (s *statusServer) CancelQuery( queryID, err := clusterunique.IDFromString(req.QueryID) if err != nil { return &serverpb.CancelQueryResponse{ - Canceled: false, - Error: errors.Wrapf(err, "query ID %s malformed", queryID).Error(), + Error: errors.Wrapf(err, "query ID %s malformed", queryID).Error(), }, nil } @@ -3122,8 +3082,7 @@ func (s *statusServer) CancelQuery( if err != nil { if errors.Is(err, sqlinstance.NonExistentInstanceError) { return &serverpb.CancelQueryResponse{ - Canceled: false, - Error: fmt.Sprintf("query ID %s not found", queryID), + Error: fmt.Sprintf("query ID %s not found", queryID), }, nil } return nil, serverError(ctx, err) @@ -3136,18 +3095,23 @@ func (s *statusServer) CancelQuery( return nil, status.Errorf(codes.InvalidArgument, err.Error()) } - if err := s.checkCancelPrivilege(ctx, reqUsername, findSessionByQueryID(req.QueryID)); err != nil { + session, ok := s.sessionRegistry.GetSessionByQueryID(queryID) + if !ok { + return &serverpb.CancelQueryResponse{ + Error: fmt.Sprintf("query ID %s not found", queryID), + }, nil + } + + if err := s.checkCancelPrivilege(ctx, reqUsername, session.BaseSessionUser()); err != nil { // NB: not using serverError() here since the priv checker // already returns a proper gRPC error status. 
return nil, err } - output := &serverpb.CancelQueryResponse{} - output.Canceled, err = s.sessionRegistry.CancelQuery(req.QueryID) - if err != nil { - output.Error = err.Error() - } - return output, nil + isCanceled := session.CancelQuery(queryID) + return &serverpb.CancelQueryResponse{ + Canceled: isCanceled, + }, nil } // CancelQueryByKey responds to a pgwire query cancellation request, and cancels @@ -3184,12 +3148,18 @@ func (s *statusServer) CancelQueryByKey( }() if local { - resp = &serverpb.CancelQueryByKeyResponse{} - resp.Canceled, err = s.sessionRegistry.CancelQueryByKey(req.CancelQueryKey) - if err != nil { - resp.Error = err.Error() + cancelQueryKey := req.CancelQueryKey + session, ok := s.sessionRegistry.GetSessionByCancelKey(cancelQueryKey) + if !ok { + return &serverpb.CancelQueryByKeyResponse{ + Error: fmt.Sprintf("session for cancel key %d not found", cancelQueryKey), + }, nil } - return resp, nil + + isCanceled := session.CancelActiveQueries() + return &serverpb.CancelQueryByKeyResponse{ + Canceled: isCanceled, + }, nil } // This request needs to be forwarded to another node. diff --git a/pkg/sql/conn_executor.go b/pkg/sql/conn_executor.go index 70c8cf522d83..cc3c035ccf3b 100644 --- a/pkg/sql/conn_executor.go +++ b/pkg/sql/conn_executor.go @@ -3137,8 +3137,16 @@ func (ex *connExecutor) initStatementResult( return nil } -// cancelQuery is part of the registrySession interface. -func (ex *connExecutor) cancelQuery(queryID clusterunique.ID) bool { +// hasQuery is part of the RegistrySession interface. +func (ex *connExecutor) hasQuery(queryID clusterunique.ID) bool { + ex.mu.RLock() + defer ex.mu.RUnlock() + _, exists := ex.mu.ActiveQueries[queryID] + return exists +} + +// CancelQuery is part of the RegistrySession interface. 
+func (ex *connExecutor) CancelQuery(queryID clusterunique.ID) bool { ex.mu.Lock() defer ex.mu.Unlock() if queryMeta, exists := ex.mu.ActiveQueries[queryID]; exists { @@ -3148,8 +3156,8 @@ func (ex *connExecutor) cancelQuery(queryID clusterunique.ID) bool { return false } -// cancelCurrentQueries is part of the registrySession interface. -func (ex *connExecutor) cancelCurrentQueries() bool { +// CancelActiveQueries is part of the RegistrySession interface. +func (ex *connExecutor) CancelActiveQueries() bool { ex.mu.Lock() defer ex.mu.Unlock() canceled := false @@ -3160,8 +3168,8 @@ func (ex *connExecutor) cancelCurrentQueries() bool { return canceled } -// cancelSession is part of the registrySession interface. -func (ex *connExecutor) cancelSession() { +// CancelSession is part of the RegistrySession interface. +func (ex *connExecutor) CancelSession() { if ex.onCancelSession == nil { return } @@ -3169,12 +3177,17 @@ func (ex *connExecutor) cancelSession() { ex.onCancelSession() } -// user is part of the registrySession interface. +// user is part of the RegistrySession interface. func (ex *connExecutor) user() username.SQLUsername { return ex.sessionData().User() } -// serialize is part of the registrySession interface. +// BaseSessionUser is part of the RegistrySession interface. +func (ex *connExecutor) BaseSessionUser() username.SQLUsername { + return ex.sessionDataStack.Base().SessionUser() +} + +// serialize is part of the RegistrySession interface. func (ex *connExecutor) serialize() serverpb.Session { ex.mu.RLock() defer ex.mu.RUnlock() diff --git a/pkg/sql/conn_executor_exec.go b/pkg/sql/conn_executor_exec.go index a410f3866f3b..8edb0095a068 100644 --- a/pkg/sql/conn_executor_exec.go +++ b/pkg/sql/conn_executor_exec.go @@ -149,7 +149,7 @@ func (ex *connExecutor) execStmt( // Cancel the session if the idle time exceeds the idle in session timeout. 
ex.mu.IdleInSessionTimeout = timeout{time.AfterFunc( ex.sessionData().IdleInSessionTimeout, - ex.cancelSession, + ex.CancelSession, )} } @@ -162,7 +162,7 @@ func (ex *connExecutor) execStmt( default: ex.mu.IdleInTransactionSessionTimeout = timeout{time.AfterFunc( ex.sessionData().IdleInTransactionSessionTimeout, - ex.cancelSession, + ex.CancelSession, )} } } diff --git a/pkg/sql/exec_util.go b/pkg/sql/exec_util.go index 1e223a6cec5b..cc75ad84f3bb 100644 --- a/pkg/sql/exec_util.go +++ b/pkg/sql/exec_util.go @@ -2068,8 +2068,8 @@ type SessionArgs struct { type SessionRegistry struct { mu struct { syncutil.RWMutex - sessionsByID map[clusterunique.ID]registrySession - sessionsByCancelKey map[pgwirecancel.BackendKeyData]registrySession + sessionsByID map[clusterunique.ID]RegistrySession + sessionsByCancelKey map[pgwirecancel.BackendKeyData]RegistrySession } } @@ -2077,31 +2077,40 @@ type SessionRegistry struct { // of sessions. func NewSessionRegistry() *SessionRegistry { r := SessionRegistry{} - r.mu.sessionsByID = make(map[clusterunique.ID]registrySession) - r.mu.sessionsByCancelKey = make(map[pgwirecancel.BackendKeyData]registrySession) + r.mu.sessionsByID = make(map[clusterunique.ID]RegistrySession) + r.mu.sessionsByCancelKey = make(map[pgwirecancel.BackendKeyData]RegistrySession) return &r } -func (r *SessionRegistry) getSessionByID(id clusterunique.ID) (registrySession, bool) { +func (r *SessionRegistry) GetSessionByID(sessionID clusterunique.ID) (RegistrySession, bool) { r.mu.RLock() defer r.mu.RUnlock() - session, ok := r.mu.sessionsByID[id] + session, ok := r.mu.sessionsByID[sessionID] return session, ok } -func (r *SessionRegistry) getSessionByCancelKey( +func (r *SessionRegistry) GetSessionByQueryID(queryID clusterunique.ID) (RegistrySession, bool) { + for _, session := range r.getSessions() { + if session.hasQuery(queryID) { + return session, true + } + } + return nil, false +} + +func (r *SessionRegistry) GetSessionByCancelKey( cancelKey 
pgwirecancel.BackendKeyData, -) (registrySession, bool) { +) (RegistrySession, bool) { r.mu.RLock() defer r.mu.RUnlock() session, ok := r.mu.sessionsByCancelKey[cancelKey] return session, ok } -func (r *SessionRegistry) getSessions() []registrySession { +func (r *SessionRegistry) getSessions() []RegistrySession { r.mu.RLock() defer r.mu.RUnlock() - sessions := make([]registrySession, 0, len(r.mu.sessionsByID)) + sessions := make([]RegistrySession, 0, len(r.mu.sessionsByID)) for _, session := range r.mu.sessionsByID { sessions = append(sessions, session) } @@ -2109,7 +2118,7 @@ func (r *SessionRegistry) getSessions() []registrySession { } func (r *SessionRegistry) register( - id clusterunique.ID, queryCancelKey pgwirecancel.BackendKeyData, s registrySession, + id clusterunique.ID, queryCancelKey pgwirecancel.BackendKeyData, s RegistrySession, ) { r.mu.Lock() defer r.mu.Unlock() @@ -2126,65 +2135,22 @@ func (r *SessionRegistry) deregister( delete(r.mu.sessionsByCancelKey, queryCancelKey) } -type registrySession interface { +type RegistrySession interface { user() username.SQLUsername - cancelQuery(queryID clusterunique.ID) bool - cancelCurrentQueries() bool - cancelSession() + // BaseSessionUser returns the base session's username. + BaseSessionUser() username.SQLUsername + hasQuery(queryID clusterunique.ID) bool + // CancelQuery cancels the query specified by queryID if it exists. + CancelQuery(queryID clusterunique.ID) bool + // CancelActiveQueries cancels all currently active queries. + CancelActiveQueries() bool + // CancelSession cancels the session. + CancelSession() // serialize serializes a Session into a serverpb.Session // that can be served over RPC. serialize() serverpb.Session } -// CancelQuery looks up the associated query in the session registry and cancels -// it. The caller is responsible for all permission checks. 
-func (r *SessionRegistry) CancelQuery(queryIDStr string) (bool, error) { - queryID, err := clusterunique.IDFromString(queryIDStr) - if err != nil { - return false, errors.Wrapf(err, "query ID %s malformed", queryID) - } - - for _, session := range r.getSessions() { - if session.cancelQuery(queryID) { - return true, nil - } - } - - return false, fmt.Errorf("query ID %s not found", queryID) -} - -// CancelQueryByKey looks up the associated query in the session registry and -// cancels it. -func (r *SessionRegistry) CancelQueryByKey( - queryCancelKey pgwirecancel.BackendKeyData, -) (canceled bool, err error) { - session, ok := r.getSessionByCancelKey(queryCancelKey) - if !ok { - return false, fmt.Errorf("session for cancel key %d not found", queryCancelKey) - } - return session.cancelCurrentQueries(), nil -} - -// CancelSession looks up the specified session in the session registry and -// cancels it. The caller is responsible for all permission checks. -func (r *SessionRegistry) CancelSession( - sessionIDBytes []byte, -) (*serverpb.CancelSessionResponse, error) { - if len(sessionIDBytes) != 16 { - return nil, errors.Errorf("invalid non-16-byte UUID %v", sessionIDBytes) - } - sessionID := clusterunique.IDFromBytes(sessionIDBytes) - - session, ok := r.getSessionByID(sessionID) - if !ok { - return &serverpb.CancelSessionResponse{ - Error: fmt.Sprintf("session ID %s not found", sessionID), - }, nil - } - session.cancelSession() - return &serverpb.CancelSessionResponse{Canceled: true}, nil -} - // SerializeAll returns a slice of all sessions in the registry converted to // serverpb.Sessions. func (r *SessionRegistry) SerializeAll() []serverpb.Session {