Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v9] Backport #12584 #12832

Merged
merged 2 commits into from
May 24, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 73 additions & 12 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,71 @@ func (a *ServerWithRoles) CreateSessionTracker(ctx context.Context, tracker type

}

func (a *ServerWithRoles) filterSessionTracker(ctx context.Context, joinerRoles []types.Role, tracker types.SessionTracker) bool {
evaluator := NewSessionAccessEvaluator(tracker.GetHostPolicySets(), tracker.GetSessionKind())
modes := evaluator.CanJoin(SessionAccessContext{Roles: joinerRoles})

if len(modes) == 0 {
return false
}

// Apply RFD 45 RBAC rules to the session if it's SSH.
// This is a bit of a hack. It converts to the old legacy format
// which we don't have all data for, luckily the fields we don't have aren't made available
// to the RBAC filter anyway.
if tracker.GetKind() == types.KindSSHSession {
ruleCtx := &services.Context{User: a.context.User}
ruleCtx.SSHSession = &session.Session{
ID: session.ID(tracker.GetSessionID()),
Namespace: apidefaults.Namespace,
Login: tracker.GetLogin(),
Created: tracker.GetCreated(),
LastActive: a.authServer.GetClock().Now(),
ServerID: tracker.GetAddress(),
ServerAddr: tracker.GetAddress(),
ServerHostname: tracker.GetHostname(),
ClusterName: tracker.GetClustername(),
}

for _, participant := range tracker.GetParticipants() {
// We only need to fill in User here since other fields get discarded anyway.
ruleCtx.SSHSession.Parties = append(ruleCtx.SSHSession.Parties, session.Party{
User: participant.User,
})
}

// Skip past it if there's a deny rule in place blocking access.
if err := a.context.Checker.CheckAccessToRule(ruleCtx, apidefaults.Namespace, types.KindSSHSession, types.VerbList, true /* silent */); err != nil {
return false
}
}

return true
}

// GetSessionTracker returns the current state of a session tracker for an active session.
func (a *ServerWithRoles) GetSessionTracker(ctx context.Context, sessionID string) (types.SessionTracker, error) {
if err := a.serverAction(); err != nil {
tracker, err := a.authServer.GetSessionTracker(ctx, sessionID)
if err != nil {
return nil, trace.Wrap(err)
}

if err := a.serverAction(); err == nil {
return tracker, nil
}

user := a.context.User
joinerRoles, err := services.FetchRoles(user.GetRoles(), a.authServer, user.GetTraits())
if err != nil {
return nil, trace.Wrap(err)
}

return a.authServer.GetSessionTracker(ctx, sessionID)
ok := a.filterSessionTracker(ctx, joinerRoles, tracker)
if !ok {
return nil, trace.NotFound("session %v not found", sessionID)
}

return tracker, nil
}

// GetActiveSessionTrackers returns a list of active session trackers.
Expand All @@ -289,18 +347,21 @@ func (a *ServerWithRoles) GetActiveSessionTrackers(ctx context.Context) ([]types
return nil, trace.Wrap(err)
}

var filteredSessions []types.SessionTracker
if err := a.serverAction(); err == nil {
return sessions, nil
}

for _, session := range sessions {
evaluator := NewSessionAccessEvaluator(session.GetHostPolicySets(), session.GetSessionKind())
joinerRoles, err := a.authServer.GetRoles(ctx)
if err != nil {
return nil, trace.Wrap(err)
}
var filteredSessions []types.SessionTracker
user := a.context.User
joinerRoles, err := services.FetchRoles(user.GetRoles(), a.authServer, user.GetTraits())
if err != nil {
return nil, trace.Wrap(err)
}

modes, err := evaluator.CanJoin(SessionAccessContext{Roles: joinerRoles})
if err == nil || len(modes) > 0 {
filteredSessions = append(filteredSessions, session)
for _, sess := range sessions {
ok := a.filterSessionTracker(ctx, joinerRoles, sess)
if ok {
filteredSessions = append(filteredSessions, sess)
}
}

Expand Down
22 changes: 13 additions & 9 deletions lib/auth/session_access.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,17 +171,21 @@ func (e *SessionAccessEvaluator) matchesKind(allow []string) bool {
return false
}

// CanJoin returns the modes a user has access to join a session with.
// If the list is empty, the user doesn't have access to join the session at all.
func (e *SessionAccessEvaluator) CanJoin(user SessionAccessContext) ([]types.SessionParticipantMode, error) {
supported, err := e.supportsSessionAccessControls()
if err != nil {
return nil, trace.Wrap(err)
func HasV5Role(roles []types.Role) bool {
for _, role := range roles {
if role.GetVersion() == types.V5 {
return true
}
}
return false
}

// CanJoin returns the modes a user has access to join a session with.
// If the list is empty, the user doesn't have access to join the session at all.
func (e *SessionAccessEvaluator) CanJoin(user SessionAccessContext) []types.SessionParticipantMode {
// If we don't support session access controls, return the default mode set that was supported prior to Moderated Sessions.
if !supported {
return preAccessControlsModes(e.kind), nil
if !HasV5Role(user.Roles) {
return preAccessControlsModes(e.kind)
}

var modes []types.SessionParticipantMode
Expand All @@ -200,7 +204,7 @@ func (e *SessionAccessEvaluator) CanJoin(user SessionAccessContext) ([]types.Ses
}
}

return modes, nil
return modes
}

func SliceContainsMode(s []types.SessionParticipantMode, e types.SessionParticipantMode) bool {
Expand Down
3 changes: 1 addition & 2 deletions lib/auth/session_access_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,8 +297,7 @@ func TestSessionAccessJoin(t *testing.T) {
t.Run(testCase.name, func(t *testing.T) {
policy := testCase.host.GetSessionPolicySet()
evaluator := NewSessionAccessEvaluator([]*types.SessionTrackerPolicySet{&policy}, testCase.sessionKind)
result, err := evaluator.CanJoin(testCase.participant)
require.NoError(t, err)
result := evaluator.CanJoin(testCase.participant)
require.Equal(t, testCase.expected, len(result) > 0)
})
}
Expand Down
12 changes: 0 additions & 12 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -3596,15 +3596,3 @@ func findActiveDatabases(key *Key) ([]tlsca.RouteToDatabase, error) {
}
return databases, nil
}

// GetActiveSessions fetches a list of all active sessions tracked by the SessionTracker resource
// that the user has access to.
func (tc *TeleportClient) GetActiveSessions(ctx context.Context) ([]types.SessionTracker, error) {
proxy, err := tc.ConnectToProxy(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

defer proxy.Close()
return proxy.GetActiveSessions(ctx)
}
15 changes: 0 additions & 15 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,6 @@ type NodeClient struct {
OnMFA func()
}

// GetActiveSessions returns a list of active session trackers.
func (proxy *ProxyClient) GetActiveSessions(ctx context.Context) ([]types.SessionTracker, error) {
auth, err := proxy.ConnectToCurrentCluster(ctx, false)
if err != nil {
return nil, trace.Wrap(err)
}
defer auth.Close()
sessions, err := auth.GetActiveSessionTrackers(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

return sessions, nil
}

// GetSites returns list of the "sites" (AKA teleport clusters) connected to the proxy
// Each site is returned as an instance of its auth server
//
Expand Down
6 changes: 1 addition & 5 deletions lib/kube/proxy/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -738,11 +738,7 @@ func (s *session) join(p *party) error {
Roles: roles,
}

modes, err := s.accessEvaluator.CanJoin(accessContext)
if err != nil {
return trace.Wrap(err)
}

modes := s.accessEvaluator.CanJoin(accessContext)
if !auth.SliceContainsMode(modes, p.Mode) {
return trace.AccessDenied("insufficient permissions to join session")
}
Expand Down
1 change: 0 additions & 1 deletion lib/services/local/sessiontracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ func (s *sessionTracker) GetActiveSessionTrackers(ctx context.Context) ([]types.
sessions = append(sessions, session)
case !after && item.Expires.IsZero():
// Clear item if expiry is not set on the backend.
// We currently don't set the expiry here but we will when #11551 is merged.
noExpiry = append(noExpiry, item)
default:
// If we take this branch, the expiry is set and the backend is responsible for cleaning up the item.
Expand Down
6 changes: 1 addition & 5 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -1496,11 +1496,7 @@ func (s *session) join(ch ssh.Channel, ctx *ServerContext, mode types.SessionPar
Roles: roles,
}

modes, err := s.access.CanJoin(accessContext)
if err != nil {
return nil, trace.Wrap(err)
}

modes := s.access.CanJoin(accessContext)
if !auth.SliceContainsMode(modes, mode) {
return nil, trace.AccessDenied("insufficient permissions to join session %v", s.id)
}
Expand Down
23 changes: 16 additions & 7 deletions tool/tsh/kube.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,18 +105,17 @@ func newKubeJoinCommand(parent *kingpin.CmdClause) *kubeJoinCommand {
}

func (c *kubeJoinCommand) getSessionMeta(ctx context.Context, tc *client.TeleportClient) (types.SessionTracker, error) {
sessions, err := tc.GetActiveSessions(ctx)
proxy, err := tc.ConnectToProxy(ctx)
if err != nil {
return nil, trace.Wrap(err)
}

for _, session := range sessions {
if session.GetSessionID() == c.session {
return session, nil
}
site, err := proxy.ConnectToCurrentCluster(ctx, false)
if err != nil {
return nil, trace.Wrap(err)
}

return nil, trace.NotFound("session %q not found", c.session)
return site.GetSessionTracker(ctx, c.session)
}

func (c *kubeJoinCommand) run(cf *CLIConf) error {
Expand Down Expand Up @@ -489,7 +488,17 @@ func (c *kubeSessionsCommand) run(cf *CLIConf) error {
return trace.Wrap(err)
}

sessions, err := tc.GetActiveSessions(cf.Context)
proxy, err := tc.ConnectToProxy(cf.Context)
if err != nil {
return trace.Wrap(err)
}

site, err := proxy.ConnectToCurrentCluster(cf.Context, true)
if err != nil {
return trace.Wrap(err)
}

sessions, err := site.GetActiveSessionTrackers(cf.Context)
if err != nil {
return trace.Wrap(err)
}
Expand Down