Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: remove redundant session iteration #95745

Merged
merged 1 commit into from
Jan 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 20 additions & 45 deletions pkg/ccl/serverccl/statusccl/tenant_status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -943,30 +943,22 @@ func testTenantStatusCancelSessionErrorMessages(t *testing.T, helper serverccl.T
testCases := []struct {
sessionID string
expectedError string

// This is a temporary assertion. We should always show the following "not found" error messages,
// regardless of admin status, but our current behavior is slightly broken and will be fixed in #77676.
nonAdminSeesError bool
}{
{
sessionID: "",
expectedError: "session ID 00000000000000000000000000000000 not found",
nonAdminSeesError: true,
sessionID: "",
expectedError: "session ID 00000000000000000000000000000000 not found",
},
{
sessionID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to.
expectedError: "session ID 00000000000000000000000000000001 not found",
nonAdminSeesError: false,
sessionID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to.
expectedError: "session ID 00000000000000000000000000000001 not found",
},
{
sessionID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to.
expectedError: "session ID 00000000000000000000000000000002 not found",
nonAdminSeesError: false,
sessionID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to.
expectedError: "session ID 00000000000000000000000000000002 not found",
},
{
sessionID: "42", // This query ID claims to have SQL instance ID 42, which does not exist.
expectedError: "session ID 00000000000000000000000000000042 not found",
nonAdminSeesError: true,
sessionID: "42", // This query ID claims to have SQL instance ID 42, which does not exist.
expectedError: "session ID 00000000000000000000000000000042 not found",
},
}

Expand All @@ -982,12 +974,8 @@ func testTenantStatusCancelSessionErrorMessages(t *testing.T, helper serverccl.T
err = client.PostJSONChecked("/_status/cancel_session/0", &serverpb.CancelSessionRequest{
SessionID: sessionID.GetBytes(),
}, &resp)
if isAdmin || testCase.nonAdminSeesError {
require.NoError(t, err)
require.Equal(t, testCase.expectedError, resp.Error)
} else {
require.Error(t, err)
}
require.NoError(t, err)
require.Equal(t, testCase.expectedError, resp.Error)
})
}
})
Expand Down Expand Up @@ -1072,36 +1060,27 @@ func testTenantStatusCancelQueryErrorMessages(t *testing.T, helper serverccl.Ten
testCases := []struct {
queryID string
expectedError string

// This is a temporary assertion. We should always show the following "not found" error messages,
// regardless of admin status, but our current behavior is slightly broken and will be fixed in #77676.
nonAdminSeesError bool
}{
{
queryID: "BOGUS_QUERY_ID",
expectedError: "query ID 00000000000000000000000000000000 malformed: " +
"could not decode BOGUS_QUERY_ID as hex: encoding/hex: invalid byte: U+004F 'O'",
nonAdminSeesError: true,
},
{
queryID: "",
expectedError: "query ID 00000000000000000000000000000000 not found",
nonAdminSeesError: true,
queryID: "",
expectedError: "query ID 00000000000000000000000000000000 not found",
},
{
queryID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to.
expectedError: "query ID 00000000000000000000000000000001 not found",
nonAdminSeesError: false,
queryID: "01", // This query ID claims to have SQL instance ID 1, different from the one we're talking to.
expectedError: "query ID 00000000000000000000000000000001 not found",
},
{
queryID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to.
expectedError: "query ID 00000000000000000000000000000002 not found",
nonAdminSeesError: false,
queryID: "02", // This query ID claims to have SQL instance ID 2, the instance we're talking to.
expectedError: "query ID 00000000000000000000000000000002 not found",
},
{
queryID: "42", // This query ID claims to have SQL instance ID 42, which does not exist.
expectedError: "query ID 00000000000000000000000000000042 not found",
nonAdminSeesError: true,
queryID: "42", // This query ID claims to have SQL instance ID 42, which does not exist.
expectedError: "query ID 00000000000000000000000000000042 not found",
},
}

Expand All @@ -1115,12 +1094,8 @@ func testTenantStatusCancelQueryErrorMessages(t *testing.T, helper serverccl.Ten
err := client.PostJSONChecked("/_status/cancel_query/0", &serverpb.CancelQueryRequest{
QueryID: testCase.queryID,
}, &resp)
if isAdmin || testCase.nonAdminSeesError {
require.NoError(t, err)
require.Equal(t, testCase.expectedError, resp.Error)
} else {
require.Error(t, err)
}
require.NoError(t, err)
require.Equal(t, testCase.expectedError, resp.Error)
})
}
})
Expand Down
150 changes: 60 additions & 90 deletions pkg/server/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,90 +270,42 @@ func (b *baseStatusServer) getLocalSessions(
return userSessions, nil
}

type sessionFinder func(sessions []serverpb.Session) (serverpb.Session, error)

func findSessionBySessionID(sessionID []byte) sessionFinder {
return func(sessions []serverpb.Session) (serverpb.Session, error) {
var session serverpb.Session
for _, s := range sessions {
if bytes.Equal(sessionID, s.ID) {
session = s
break
}
}
if len(session.ID) == 0 {
return session, fmt.Errorf("session ID %s not found", clusterunique.IDFromBytes(sessionID))
}
return session, nil
}
}

func findSessionByQueryID(queryID string) sessionFinder {
return func(sessions []serverpb.Session) (serverpb.Session, error) {
var session serverpb.Session
for _, s := range sessions {
for _, q := range s.ActiveQueries {
if queryID == q.ID {
session = s
break
}
}
}
if len(session.ID) == 0 {
return session, fmt.Errorf("query ID %s not found", queryID)
}
return session, nil
}
}

// checkCancelPrivilege returns nil if the user has the necessary cancel action
// privileges for a session. This function returns a proper gRPC error status.
func (b *baseStatusServer) checkCancelPrivilege(
ctx context.Context, userName username.SQLUsername, findSession sessionFinder,
ctx context.Context, reqUsername username.SQLUsername, sessionUsername username.SQLUsername,
) error {
ctx = propagateGatewayMetadata(ctx)
ctx = b.AnnotateCtx(ctx)
// reqUser is the user who made the cancellation request.
var reqUser username.SQLUsername
{
sessionUser, isAdmin, err := b.privilegeChecker.getUserAndRole(ctx)
if err != nil {
return serverError(ctx, err)
}
if userName.Undefined() || userName == sessionUser {
reqUser = sessionUser
} else {
// When CANCEL QUERY is run as a SQL statement, sessionUser is always root
// and the user who ran the statement is passed as req.Username.
if !isAdmin {
return errRequiresAdmin
}
reqUser = userName
}

ctxUsername, isAdmin, err := b.privilegeChecker.getUserAndRole(ctx)
if err != nil {
return serverError(ctx, err)
}
if reqUsername.Undefined() {
reqUsername = ctxUsername
} else if reqUsername != ctxUsername && !isAdmin {
// When CANCEL QUERY is run as a SQL statement, sessionUser is always root
// and the user who ran the statement is passed as req.Username.
return errRequiresAdmin
}

hasAdmin, err := b.privilegeChecker.hasAdminRole(ctx, reqUser)
hasAdmin, err := b.privilegeChecker.hasAdminRole(ctx, reqUsername)
if err != nil {
return serverError(ctx, err)
}

if !hasAdmin {
// Check if the user has permission to see the session.
session, err := findSession(b.sessionRegistry.SerializeAll())
if err != nil {
return serverError(ctx, err)
}

sessionUser := username.MakeSQLUsernameFromPreNormalizedString(session.Username)
if sessionUser != reqUser {
if sessionUsername != reqUsername {
// Must have CANCELQUERY privilege to cancel other users'
// sessions/queries.
hasCancelQuery, err := b.privilegeChecker.hasGlobalPrivilege(ctx, reqUser, privilege.CANCELQUERY)
hasCancelQuery, err := b.privilegeChecker.hasGlobalPrivilege(ctx, reqUsername, privilege.CANCELQUERY)
if err != nil {
return serverError(ctx, err)
}
if !hasCancelQuery {
ok, err := b.privilegeChecker.hasRoleOption(ctx, reqUser, roleoption.CANCELQUERY)
ok, err := b.privilegeChecker.hasRoleOption(ctx, reqUsername, roleoption.CANCELQUERY)
if err != nil {
return serverError(ctx, err)
}
Expand All @@ -362,7 +314,7 @@ func (b *baseStatusServer) checkCancelPrivilege(
}
}
// Non-admins cannot cancel admins' sessions/queries.
isAdminSession, err := b.privilegeChecker.hasAdminRole(ctx, sessionUser)
isAdminSession, err := b.privilegeChecker.hasAdminRole(ctx, sessionUsername)
if err != nil {
return serverError(ctx, err)
}
Expand Down Expand Up @@ -3063,16 +3015,21 @@ func (s *statusServer) CancelSession(
ctx = propagateGatewayMetadata(ctx)
ctx = s.AnnotateCtx(ctx)

sessionID := clusterunique.IDFromBytes(req.SessionID)
sessionIDBytes := req.SessionID
if len(sessionIDBytes) != 16 {
return &serverpb.CancelSessionResponse{
Error: fmt.Sprintf("session ID %v malformed", sessionIDBytes),
}, nil
}
sessionID := clusterunique.IDFromBytes(sessionIDBytes)
nodeID := sessionID.GetNodeID()
local := nodeID == int32(s.serverIterator.getID())
if !local {
status, err := s.dialNode(ctx, roachpb.NodeID(nodeID))
if err != nil {
if errors.Is(err, sqlinstance.NonExistentInstanceError) {
return &serverpb.CancelSessionResponse{
Canceled: false,
Error: fmt.Sprintf("session ID %s not found", sessionID),
Error: fmt.Sprintf("session ID %s not found", sessionID),
}, nil
}
return nil, serverError(ctx, err)
Expand All @@ -3085,17 +3042,21 @@ func (s *statusServer) CancelSession(
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}

if err := s.checkCancelPrivilege(ctx, reqUsername, findSessionBySessionID(req.SessionID)); err != nil {
session, ok := s.sessionRegistry.GetSessionByID(sessionID)
if !ok {
return &serverpb.CancelSessionResponse{
Error: fmt.Sprintf("session ID %s not found", sessionID),
}, nil
}

if err := s.checkCancelPrivilege(ctx, reqUsername, session.BaseSessionUser()); err != nil {
// NB: not using serverError() here since the priv checker
// already returns a proper gRPC error status.
return nil, err
}

r, err := s.sessionRegistry.CancelSession(req.SessionID)
if err != nil {
return nil, serverError(ctx, err)
}
return r, nil
session.CancelSession()
return &serverpb.CancelSessionResponse{Canceled: true}, nil
}

// CancelQuery responds to a query cancellation request, and cancels
Expand All @@ -3109,8 +3070,7 @@ func (s *statusServer) CancelQuery(
queryID, err := clusterunique.IDFromString(req.QueryID)
if err != nil {
return &serverpb.CancelQueryResponse{
Canceled: false,
Error: errors.Wrapf(err, "query ID %s malformed", queryID).Error(),
Error: errors.Wrapf(err, "query ID %s malformed", queryID).Error(),
}, nil
}

Expand All @@ -3122,8 +3082,7 @@ func (s *statusServer) CancelQuery(
if err != nil {
if errors.Is(err, sqlinstance.NonExistentInstanceError) {
return &serverpb.CancelQueryResponse{
Canceled: false,
Error: fmt.Sprintf("query ID %s not found", queryID),
Error: fmt.Sprintf("query ID %s not found", queryID),
}, nil
}
return nil, serverError(ctx, err)
Expand All @@ -3136,18 +3095,23 @@ func (s *statusServer) CancelQuery(
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}

if err := s.checkCancelPrivilege(ctx, reqUsername, findSessionByQueryID(req.QueryID)); err != nil {
session, ok := s.sessionRegistry.GetSessionByQueryID(queryID)
if !ok {
return &serverpb.CancelQueryResponse{
Error: fmt.Sprintf("query ID %s not found", queryID),
}, nil
}

if err := s.checkCancelPrivilege(ctx, reqUsername, session.BaseSessionUser()); err != nil {
// NB: not using serverError() here since the priv checker
// already returns a proper gRPC error status.
return nil, err
}

output := &serverpb.CancelQueryResponse{}
output.Canceled, err = s.sessionRegistry.CancelQuery(req.QueryID)
if err != nil {
output.Error = err.Error()
}
return output, nil
isCanceled := session.CancelQuery(queryID)
return &serverpb.CancelQueryResponse{
Canceled: isCanceled,
}, nil
}

// CancelQueryByKey responds to a pgwire query cancellation request, and cancels
Expand Down Expand Up @@ -3184,12 +3148,18 @@ func (s *statusServer) CancelQueryByKey(
}()

if local {
resp = &serverpb.CancelQueryByKeyResponse{}
resp.Canceled, err = s.sessionRegistry.CancelQueryByKey(req.CancelQueryKey)
if err != nil {
resp.Error = err.Error()
cancelQueryKey := req.CancelQueryKey
session, ok := s.sessionRegistry.GetSessionByCancelKey(cancelQueryKey)
if !ok {
return &serverpb.CancelQueryByKeyResponse{
Error: fmt.Sprintf("session for cancel key %d not found", cancelQueryKey),
}, nil
}
return resp, nil

isCanceled := session.CancelActiveQueries()
return &serverpb.CancelQueryByKeyResponse{
Canceled: isCanceled,
}, nil
}

// This request needs to be forwarded to another node.
Expand Down
Loading