diff --git a/pkg/server/BUILD.bazel b/pkg/server/BUILD.bazel index 0d749648a0b6..eabbcecd8c9e 100644 --- a/pkg/server/BUILD.bazel +++ b/pkg/server/BUILD.bazel @@ -281,6 +281,7 @@ go_library( "//pkg/util/tracing/tracingpb", "//pkg/util/tracing/tracingservicepb", "//pkg/util/tracing/tracingui", + "//pkg/util/uint128", "//pkg/util/uuid", "@com_github_cenkalti_backoff//:backoff", "@com_github_cockroachdb_apd_v3//:apd", diff --git a/pkg/server/status.go b/pkg/server/status.go index 8b0ed83b6204..4f666009efba 100644 --- a/pkg/server/status.go +++ b/pkg/server/status.go @@ -77,6 +77,7 @@ import ( "github.com/cockroachdb/cockroach/pkg/util/syncutil" "github.com/cockroachdb/cockroach/pkg/util/timeutil" "github.com/cockroachdb/cockroach/pkg/util/tracing/tracingpb" + "github.com/cockroachdb/cockroach/pkg/util/uint128" "github.com/cockroachdb/cockroach/pkg/util/uuid" "github.com/cockroachdb/errors" gwruntime "github.com/grpc-ecosystem/grpc-gateway/runtime" @@ -212,30 +213,27 @@ func (b *baseStatusServer) getLocalSessions( showAll := reqUsername.Undefined() showInternal := SQLStatsShowInternal.Get(&b.st.SV) || req.IncludeInternal - // In order to avoid duplicate sessions showing up as both open and closed, - // we lock the session registry to prevent any changes to it while we - // serialize the sessions from the session registry and the closed session - // cache. - b.sessionRegistry.Lock() - sessions := b.sessionRegistry.SerializeAllLocked() - + sessions := b.sessionRegistry.SerializeAll() var closedSessions []serverpb.Session + var closedSessionIDs map[uint128.Uint128]struct{} if !req.ExcludeClosedSessions { closedSessions = b.closedSessionCache.GetSerializedSessions() + closedSessionIDs = make(map[uint128.Uint128]struct{}, len(closedSessions)) + for _, closedSession := range closedSessions { + closedSessionIDs[uint128.FromBytes(closedSession.ID)] = struct{}{} + } } - b.sessionRegistry.Unlock() - - userSessions := make([]serverpb.Session, 0) - sessions = append(sessions, closedSessions...) reqUserNameNormalized := reqUsername.Normalized() - for _, session := range sessions { + + userSessions := make([]serverpb.Session, 0, len(sessions)+len(closedSessions)) + addUserSession := func(session serverpb.Session) { // We filter based on the session name instead of the executor type because we // may want to surface certain internal sessions, such as those executed by // the SQL over HTTP api, as non-internal. if (reqUserNameNormalized != session.Username && !showAll) || (!showInternal && isInternalAppName(session.ApplicationName)) { - continue + return } if !isAdmin && hasViewActivityRedacted && (reqUserNameNormalized != session.Username) { @@ -247,9 +245,22 @@ func (b *baseStatusServer) getLocalSessions( } session.LastActiveQuery = session.LastActiveQueryNoConstants } - userSessions = append(userSessions, session) } + for _, session := range sessions { + // The same session can appear as both open and closed because reading the + // open and closed sessions is not synchronized. Prefer the closed session + // over the open one if the same session appears as both because it was + // closed in between reading the open sessions and reading the closed ones. + _, ok := closedSessionIDs[uint128.FromBytes(session.ID)] + if ok { + continue + } + addUserSession(session) + } + for _, session := range closedSessions { + addUserSession(session) + } sort.Slice(userSessions, func(i, j int) bool { return userSessions[i].Start.Before(userSessions[j].Start) diff --git a/pkg/sql/conn_executor_test.go b/pkg/sql/conn_executor_test.go index 8ff35e6ab657..cce78bbf26fa 100644 --- a/pkg/sql/conn_executor_test.go +++ b/pkg/sql/conn_executor_test.go @@ -635,7 +635,7 @@ func TestQueryProgress(t *testing.T) { // stalled ch as expected. defer func() { select { - case <-stalled: //stalled was closed as expected. + case <-stalled: // stalled was closed as expected. default: panic("expected stalled to have been closed during execution") } diff --git a/pkg/sql/exec_util.go b/pkg/sql/exec_util.go index aeb7b41e9236..7af816dc5bdd 100644 --- a/pkg/sql/exec_util.go +++ b/pkg/sql/exec_util.go @@ -1299,7 +1299,7 @@ type ExecutorConfig struct { // RootMemoryMonitor is the root memory monitor of the entire server. Do not // use this for normal purposes. It is to be used to establish any new - // root-level memory accounts that are not related to a user sessions. + // root-level memory accounts that are not related to a user session. RootMemoryMonitor *mon.BytesMonitor // CompactEngineSpanFunc is used to inform a storage engine of the need to @@ -2062,36 +2062,64 @@ type SessionArgs struct { // SessionRegistry stores a set of all sessions on this node. // Use register() and deregister() to modify this registry. type SessionRegistry struct { - syncutil.Mutex - sessions map[clusterunique.ID]registrySession - sessionsByCancelKey map[pgwirecancel.BackendKeyData]registrySession + mu struct { + syncutil.RWMutex + sessionsByID map[clusterunique.ID]registrySession + sessionsByCancelKey map[pgwirecancel.BackendKeyData]registrySession + } } // NewSessionRegistry creates a new SessionRegistry with an empty set // of sessions. func NewSessionRegistry() *SessionRegistry { - return &SessionRegistry{ - sessions: make(map[clusterunique.ID]registrySession), - sessionsByCancelKey: make(map[pgwirecancel.BackendKeyData]registrySession), + r := SessionRegistry{} + r.mu.sessionsByID = make(map[clusterunique.ID]registrySession) + r.mu.sessionsByCancelKey = make(map[pgwirecancel.BackendKeyData]registrySession) + return &r +} + +func (r *SessionRegistry) getSessionByID(id clusterunique.ID) (registrySession, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + session, ok := r.mu.sessionsByID[id] + return session, ok +} + +func (r *SessionRegistry) getSessionByCancelKey( + cancelKey pgwirecancel.BackendKeyData, +) (registrySession, bool) { + r.mu.RLock() + defer r.mu.RUnlock() + session, ok := r.mu.sessionsByCancelKey[cancelKey] + return session, ok +} + +func (r *SessionRegistry) getSessions() []registrySession { + r.mu.RLock() + defer r.mu.RUnlock() + sessions := make([]registrySession, 0, len(r.mu.sessionsByID)) + for _, session := range r.mu.sessionsByID { + sessions = append(sessions, session) } + return sessions } func (r *SessionRegistry) register( id clusterunique.ID, queryCancelKey pgwirecancel.BackendKeyData, s registrySession, ) { - r.Lock() - defer r.Unlock() - r.sessions[id] = s - r.sessionsByCancelKey[queryCancelKey] = s + r.mu.Lock() + defer r.mu.Unlock() + r.mu.sessionsByID[id] = s + r.mu.sessionsByCancelKey[queryCancelKey] = s } func (r *SessionRegistry) deregister( id clusterunique.ID, queryCancelKey pgwirecancel.BackendKeyData, ) { - r.Lock() - defer r.Unlock() - delete(r.sessions, id) - delete(r.sessionsByCancelKey, queryCancelKey) + r.mu.Lock() + defer r.mu.Unlock() + delete(r.mu.sessionsByID, id) + delete(r.mu.sessionsByCancelKey, queryCancelKey) } type registrySession interface { @@ -2112,10 +2140,7 @@ func (r *SessionRegistry) CancelQuery(queryIDStr string) (bool, error) { return false, errors.Wrapf(err, "query ID %s malformed", queryID) } - r.Lock() - defer r.Unlock() - - for _, session := range r.sessions { + for _, session := range r.getSessions() { if session.cancelQuery(queryID) { return true, nil } @@ -2129,15 +2154,11 @@ func (r *SessionRegistry) CancelQuery(queryIDStr string) (bool, error) { func (r *SessionRegistry) CancelQueryByKey( queryCancelKey pgwirecancel.BackendKeyData, ) (canceled bool, err error) { - r.Lock() - defer r.Unlock() - if session, ok := r.sessionsByCancelKey[queryCancelKey]; ok { - if session.cancelCurrentQueries() { - return true, nil - } - return false, nil + session, ok := r.getSessionByCancelKey(queryCancelKey) + if !ok { + return false, fmt.Errorf("session for cancel key %d not found", queryCancelKey) } - return false, fmt.Errorf("session for cancel key %d not found", queryCancelKey) + return session.cancelCurrentQueries(), nil } // CancelSession looks up the specified session in the session registry and @@ -2150,37 +2171,24 @@ func (r *SessionRegistry) CancelSession( } sessionID := clusterunique.IDFromBytes(sessionIDBytes) - r.Lock() - defer r.Unlock() - - for id, session := range r.sessions { - if id == sessionID { - session.cancelSession() - return &serverpb.CancelSessionResponse{Canceled: true}, nil - } + session, ok := r.getSessionByID(sessionID) + if !ok { + return &serverpb.CancelSessionResponse{ + Error: fmt.Sprintf("session ID %s not found", sessionID), + }, nil } - - return &serverpb.CancelSessionResponse{ - Error: fmt.Sprintf("session ID %s not found", sessionID), - }, nil + session.cancelSession() + return &serverpb.CancelSessionResponse{Canceled: true}, nil } -// SerializeAll returns a slice of all sessions in the registry, converted to serverpb.Sessions. +// SerializeAll returns a slice of all sessions in the registry converted to +// serverpb.Sessions. func (r *SessionRegistry) SerializeAll() []serverpb.Session { - r.Lock() - defer r.Unlock() - - return r.SerializeAllLocked() -} - -// SerializeAllLocked is like SerializeAll but assumes SessionRegistry's mutex is locked. -func (r *SessionRegistry) SerializeAllLocked() []serverpb.Session { - response := make([]serverpb.Session, 0, len(r.sessions)) - - for _, s := range r.sessions { + sessions := r.getSessions() + response := make([]serverpb.Session, 0, len(sessions)) + for _, s := range sessions { response = append(response, s.serialize()) } - return response }