diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index b03324fd8e043..91f8c2507f218 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -273,13 +273,71 @@ func (a *ServerWithRoles) CreateSessionTracker(ctx context.Context, tracker type } +func (a *ServerWithRoles) filterSessionTracker(ctx context.Context, joinerRoles []types.Role, tracker types.SessionTracker) bool { + evaluator := NewSessionAccessEvaluator(tracker.GetHostPolicySets(), tracker.GetSessionKind()) + modes := evaluator.CanJoin(SessionAccessContext{Roles: joinerRoles}) + + if len(modes) == 0 { + return false + } + + // Apply RFD 45 RBAC rules to the session if it's SSH. + // This is a bit of a hack. It converts to the old legacy format + // which we don't have all data for, luckily the fields we don't have aren't made available + // to the RBAC filter anyway. + if tracker.GetKind() == types.KindSSHSession { + ruleCtx := &services.Context{User: a.context.User} + ruleCtx.SSHSession = &session.Session{ + ID: session.ID(tracker.GetSessionID()), + Namespace: apidefaults.Namespace, + Login: tracker.GetLogin(), + Created: tracker.GetCreated(), + LastActive: a.authServer.GetClock().Now(), + ServerID: tracker.GetAddress(), + ServerAddr: tracker.GetAddress(), + ServerHostname: tracker.GetHostname(), + ClusterName: tracker.GetClustername(), + } + + for _, participant := range tracker.GetParticipants() { + // We only need to fill in User here since other fields get discarded anyway. + ruleCtx.SSHSession.Parties = append(ruleCtx.SSHSession.Parties, session.Party{ + User: participant.User, + }) + } + + // Skip past it if there's a deny rule in place blocking access. + if err := a.context.Checker.CheckAccessToRule(ruleCtx, apidefaults.Namespace, types.KindSSHSession, types.VerbList, true /* silent */); err != nil { + return false + } + } + + return true +} + // GetSessionTracker returns the current state of a session tracker for an active session. func (a *ServerWithRoles) GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) { - if err := a.serverAction(); err != nil { + tracker, err := a.authServer.GetSessionTracker(ctx, sessionID) + if err != nil { + return nil, trace.Wrap(err) + } + + if err := a.serverAction(); err == nil { + return tracker, nil + } + + user := a.context.User + joinerRoles, err := services.FetchRoles(user.GetRoles(), a.authServer, user.GetTraits()) + if err != nil { return nil, trace.Wrap(err) } - return a.authServer.GetSessionTracker(ctx, sessionID) + ok := a.filterSessionTracker(ctx, joinerRoles, tracker) + if !ok { + return nil, trace.NotFound("session %v not found", sessionID) + } + + return tracker, nil } // GetActiveSessionTrackers returns a list of active session trackers. @@ -289,18 +347,21 @@ func (a *ServerWithRoles) GetActiveSessionTrackers(ctx context.Context) ([]types return nil, trace.Wrap(err) } - var filteredSessions []types.SessionTracker + if err := a.serverAction(); err == nil { + return sessions, nil + } - for _, session := range sessions { - evaluator := NewSessionAccessEvaluator(session.GetHostPolicySets(), session.GetSessionKind()) - joinerRoles, err := a.authServer.GetRoles(ctx) - if err != nil { - return nil, trace.Wrap(err) - } + var filteredSessions []types.SessionTracker + user := a.context.User + joinerRoles, err := services.FetchRoles(user.GetRoles(), a.authServer, user.GetTraits()) + if err != nil { + return nil, trace.Wrap(err) + } - modes, err := evaluator.CanJoin(SessionAccessContext{Roles: joinerRoles}) - if err == nil || len(modes) > 0 { - filteredSessions = append(filteredSessions, session) + for _, sess := range sessions { + ok := a.filterSessionTracker(ctx, joinerRoles, sess) + if ok { + filteredSessions = append(filteredSessions, sess) } } diff --git a/lib/auth/session_access.go b/lib/auth/session_access.go index a271b7c9cb1b3..d7e846892e2bd 100644 --- a/lib/auth/session_access.go +++ b/lib/auth/session_access.go @@ -171,17 +171,21 @@ func (e *SessionAccessEvaluator) matchesKind(allow []string) bool { return false } -// CanJoin returns the modes a user has access to join a session with. -// If the list is empty, the user doesn't have access to join the session at all. -func (e *SessionAccessEvaluator) CanJoin(user SessionAccessContext) ([]types.SessionParticipantMode, error) { - supported, err := e.supportsSessionAccessControls() - if err != nil { - return nil, trace.Wrap(err) +func HasV5Role(roles []types.Role) bool { + for _, role := range roles { + if role.GetVersion() == types.V5 { + return true + } } + return false +} +// CanJoin returns the modes a user has access to join a session with. +// If the list is empty, the user doesn't have access to join the session at all. +func (e *SessionAccessEvaluator) CanJoin(user SessionAccessContext) []types.SessionParticipantMode { // If we don't support session access controls, return the default mode set that was supported prior to Moderated Sessions. - if !supported { - return preAccessControlsModes(e.kind), nil + if !HasV5Role(user.Roles) { + return preAccessControlsModes(e.kind) } var modes []types.SessionParticipantMode @@ -200,7 +204,7 @@ func (e *SessionAccessEvaluator) CanJoin(user SessionAccessContext) ([]types.Ses } } - return modes, nil + return modes } func SliceContainsMode(s []types.SessionParticipantMode, e types.SessionParticipantMode) bool { diff --git a/lib/auth/session_access_test.go b/lib/auth/session_access_test.go index 1fe56cc008c72..ff70e01cd68fe 100644 --- a/lib/auth/session_access_test.go +++ b/lib/auth/session_access_test.go @@ -297,8 +297,7 @@ func TestSessionAccessJoin(t *testing.T) { t.Run(testCase.name, func(t *testing.T) { policy := testCase.host.GetSessionPolicySet() evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, testCase.sessionKind) - result, err := evaluator.CanJoin(testCase.participant) - require.NoError(t, err) + result := evaluator.CanJoin(testCase.participant) require.Equal(t, testCase.expected, len(result) > 0) }) } diff --git a/lib/client/api.go b/lib/client/api.go index 9403411c10a66..0dec9944dcfc8 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -3597,15 +3597,3 @@ func findActiveDatabases(key *Key) ([]tlsca.RouteToDatabase, error) { } return databases, nil } - -// GetActiveSessions fetches a list of all active sessions tracked by the SessionTracker resource -// that the user has access to. -func (tc *TeleportClient) GetActiveSessions(ctx context.Context) ([]types.SessionTracker, error) { - proxy, err := tc.ConnectToProxy(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - defer proxy.Close() - return proxy.GetActiveSessions(ctx) -} diff --git a/lib/client/client.go b/lib/client/client.go index 4d0c8298f3a88..6b2d62ab11483 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -74,21 +74,6 @@ type NodeClient struct { OnMFA func() } -// GetActiveSessions returns a list of active session trackers. -func (proxy *ProxyClient) GetActiveSessions(ctx context.Context) ([]types.SessionTracker, error) { - auth, err := proxy.ConnectToCurrentCluster(ctx, false) - if err != nil { - return nil, trace.Wrap(err) - } - defer auth.Close() - sessions, err := auth.GetActiveSessionTrackers(ctx) - if err != nil { - return nil, trace.Wrap(err) - } - - return sessions, nil -} - // GetSites returns list of the "sites" (AKA teleport clusters) connected to the proxy // Each site is returned as an instance of its auth server // diff --git a/lib/kube/proxy/sess.go b/lib/kube/proxy/sess.go index 51018ce1207af..af6b7f4d5bed7 100644 --- a/lib/kube/proxy/sess.go +++ b/lib/kube/proxy/sess.go @@ -738,11 +738,7 @@ func (s *session) join(p *party) error { Roles: roles, } - modes, err := s.accessEvaluator.CanJoin(accessContext) - if err != nil { - return trace.Wrap(err) - } - + modes := s.accessEvaluator.CanJoin(accessContext) if !auth.SliceContainsMode(modes, p.Mode) { return trace.AccessDenied("insufficient permissions to join session") } diff --git a/lib/services/local/sessiontracker.go b/lib/services/local/sessiontracker.go index a6f800aab8fa3..ba010f8060331 100644 --- a/lib/services/local/sessiontracker.go +++ b/lib/services/local/sessiontracker.go @@ -145,7 +145,6 @@ func (s *sessionTracker) GetActiveSessionTrackers(ctx context.Context) ([]types. sessions = append(sessions, session) case !after && item.Expires.IsZero(): // Clear item if expiry is not set on the backend. - // We currently don't set the expiry here but we will when #11551 is merged. noExpiry = append(noExpiry, item) default: // If we take this branch, the expiry is set and the backend is responsible for cleaning up the item. diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 92447b728d97c..a0cc64b4dfd06 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -1496,11 +1496,7 @@ func (s *session) join(ch ssh.Channel, ctx *ServerContext, mode types.SessionPar Roles: roles, } - modes, err := s.access.CanJoin(accessContext) - if err != nil { - return nil, trace.Wrap(err) - } - + modes := s.access.CanJoin(accessContext) if !auth.SliceContainsMode(modes, mode) { return nil, trace.AccessDenied("insufficient permissions to join session %v", s.id) } diff --git a/tool/tsh/kube.go b/tool/tsh/kube.go index 33d397592c237..b0cdfde6a74fd 100644 --- a/tool/tsh/kube.go +++ b/tool/tsh/kube.go @@ -105,18 +105,17 @@ func newKubeJoinCommand(parent *kingpin.CmdClause) *kubeJoinCommand { } func (c *kubeJoinCommand) getSessionMeta(ctx context.Context, tc *client.TeleportClient) (types.SessionTracker, error) { - sessions, err := tc.GetActiveSessions(ctx) + proxy, err := tc.ConnectToProxy(ctx) if err != nil { return nil, trace.Wrap(err) } - for _, session := range sessions { - if session.GetSessionID() == c.session { - return session, nil - } + site, err := proxy.ConnectToCurrentCluster(ctx, false) + if err != nil { + return nil, trace.Wrap(err) } - return nil, trace.NotFound("session %q not found", c.session) + return site.GetSessionTracker(ctx, c.session) } func (c *kubeJoinCommand) run(cf *CLIConf) error { @@ -489,7 +488,17 @@ func (c *kubeSessionsCommand) run(cf *CLIConf) error { return trace.Wrap(err) } - sessions, err := tc.GetActiveSessions(cf.Context) + proxy, err := tc.ConnectToProxy(cf.Context) + if err != nil { + return trace.Wrap(err) + } + + site, err := proxy.ConnectToCurrentCluster(cf.Context, true) + if err != nil { + return trace.Wrap(err) + } + + sessions, err := site.GetActiveSessionTrackers(cf.Context) if err != nil { return trace.Wrap(err) }