Skip to content

Commit

Permalink
[web] Add ability to switchback to default roles/expiry (#6639)
Browse files Browse the repository at this point in the history
- Preserve login time with WebSession when user first creates a web session to derive
"default" expiry when user wants to switch back
- Change the signature of ExtendWebSession to accept a
NewWebSessionRequest struct that contains session information
- Create renewSessionRequest object to read from web request for endpoint renewSession
- Endpoint now also returns SessionExpires time that is used as countdown in UI
  • Loading branch information
kimlisa authored Apr 29, 2021
1 parent 8357e75 commit 73f40b3
Show file tree
Hide file tree
Showing 17 changed files with 957 additions and 770 deletions.
16 changes: 16 additions & 0 deletions api/types/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,10 @@ type WebSession interface {
GetBearerTokenExpiryTime() time.Time
// GetExpiryTime - absolute time when web session expires
GetExpiryTime() time.Time
// GetLoginTime returns the time this user recently logged in.
GetLoginTime() time.Time
// SetLoginTime sets when this user logged in.
SetLoginTime(time.Time)
// WithoutSecrets returns copy of the web session but without private keys
WithoutSecrets() WebSession
// CheckAndSetDefaults checks and set default values for any missing fields.
Expand Down Expand Up @@ -252,6 +256,16 @@ func (ws *WebSessionV2) GetExpiryTime() time.Time {
return ws.Spec.Expires
}

// GetLoginTime returns the time this user recently logged in.
func (ws *WebSessionV2) GetLoginTime() time.Time {
return ws.Spec.LoginTime
}

// SetLoginTime sets when this user logged in.
func (ws *WebSessionV2) SetLoginTime(loginTime time.Time) {
ws.Spec.LoginTime = loginTime
}

// GetAppSessionRequest contains the parameters to request an application
// web session.
type GetAppSessionRequest struct {
Expand Down Expand Up @@ -490,6 +504,8 @@ type NewWebSessionRequest struct {
// SessionTTL optionally specifies the session time-to-live.
// If left unspecified, the default certificate duration is used.
SessionTTL time.Duration
// LoginTime is the time that this user recently logged in.
LoginTime time.Time
}

// Check validates the request.
Expand Down
1,365 changes: 704 additions & 661 deletions api/types/types.pb.go

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions api/types/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1627,6 +1627,12 @@ message WebSessionSpecV2 {
(gogoproto.nullable) = false,
(gogoproto.jsontag) = "expires"
];
// LoginTime is the time this user recently logged in.
google.protobuf.Timestamp LoginTime = 8 [
(gogoproto.stdtime) = true,
(gogoproto.nullable) = false,
(gogoproto.jsontag) = "login_time"
];
}

// WebSessionFilter encodes cache watch parameters for filtering web sessions.
Expand Down
26 changes: 20 additions & 6 deletions lib/auth/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -760,25 +760,39 @@ func (s *APIServer) u2fSignRequest(auth ClientI, w http.ResponseWriter, r *http.
return u2fSignReq, nil
}

type createWebSessionReq struct {
PrevSessionID string `json:"prev_session_id"`
type WebSessionReq struct {
// User is the user name associated with the session id.
User string `json:"user"`
// PrevSessionID is the id of current session.
PrevSessionID string `json:"prev_session_id"`
// AccessRequestID is an optional field that holds the id of an approved access request.
AccessRequestID string `json:"access_request_id"`
// Switchback is a flag to indicate if user is wanting to switchback from an assumed role
// back to their default role.
Switchback bool `json:"switchback"`
}

func (s *APIServer) createWebSession(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) {
var req *createWebSessionReq
var req WebSessionReq
if err := httplib.ReadJSON(r, &req); err != nil {
return nil, trace.Wrap(err)
}
user := p.ByName("user")

// DELETE IN 8.0: proxy v5 sends request with no user field.
// And since proxy v6, request will come with user field set, so grabbing user
// by param is not required.
if req.User == "" {
req.User = p.ByName("user")
}

if req.PrevSessionID != "" {
sess, err := auth.ExtendWebSession(user, req.PrevSessionID, req.AccessRequestID)
sess, err := auth.ExtendWebSession(req)
if err != nil {
return nil, trace.Wrap(err)
}
return sess, nil
}
sess, err := auth.CreateWebSession(user)
sess, err := auth.CreateWebSession(req.User)
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
76 changes: 59 additions & 17 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -952,13 +952,18 @@ func (a *Server) CheckU2FSignResponse(ctx context.Context, user string, response
return a.checkU2F(ctx, user, *response, a.Identity)
}

// ExtendWebSession creates a new web session for a user based on a valid previous session.
// Additional roles are appended to initial roles if there is an approved access request.
// The new session expiration time will not exceed the expiration time of the old session.
func (a *Server) ExtendWebSession(user, prevSessionID, accessRequestID string, identity tlsca.Identity) (services.WebSession, error) {
// ExtendWebSession creates a new web session for a user based on a valid previous (current) session.
//
// If there is an approved access request, additional roles are appended to the roles that were
// extracted from identity. The new session expiration time will not exceed the expiration time
// of the previous session.
//
// If there is a switchback request, the roles will switchback to user's default roles and
// the expiration time is derived from users recently logged in time.
func (a *Server) ExtendWebSession(req WebSessionReq, identity tlsca.Identity) (services.WebSession, error) {
prevSession, err := a.GetWebSession(context.TODO(), types.GetWebSessionRequest{
User: user,
SessionID: prevSessionID,
User: req.User,
SessionID: req.PrevSessionID,
})
if err != nil {
return nil, trace.Wrap(err)
Expand All @@ -977,8 +982,8 @@ func (a *Server) ExtendWebSession(user, prevSessionID, accessRequestID string, i
return nil, trace.Wrap(err)
}

if accessRequestID != "" {
newRoles, requestExpiry, err := a.getRolesAndExpiryFromAccessRequest(user, accessRequestID)
if req.AccessRequestID != "" {
newRoles, requestExpiry, err := a.getRolesAndExpiryFromAccessRequest(req.User, req.AccessRequestID)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -992,9 +997,33 @@ func (a *Server) ExtendWebSession(user, prevSessionID, accessRequestID string, i
}
}

if req.Switchback {
if prevSession.GetLoginTime().IsZero() {
return nil, trace.BadParameter("Unable to switchback, log in time was not recorded.")
}

// Get default/static roles.
user, err := a.GetUser(req.User, false)
if err != nil {
return nil, trace.Wrap(err, "failed to switchback")
}

// Calculate expiry time.
roleSet, err := services.FetchRoles(user.GetRoles(), a.Access, user.GetTraits())
if err != nil {
return nil, trace.Wrap(err)
}

sessionTTL := roleSet.AdjustSessionTTL(defaults.CertDuration)

// Set default roles and expiration.
expiresAt = prevSession.GetLoginTime().UTC().Add(sessionTTL)
roles = user.GetRoles()
}

sessionTTL := utils.ToTTL(a.clock, expiresAt)
sess, err := a.NewWebSession(types.NewWebSessionRequest{
User: user,
User: req.User,
Roles: roles,
Traits: traits,
SessionTTL: sessionTTL,
Expand All @@ -1003,7 +1032,10 @@ func (a *Server) ExtendWebSession(user, prevSessionID, accessRequestID string, i
return nil, trace.Wrap(err)
}

if err := a.upsertWebSession(context.TODO(), user, sess); err != nil {
// Keep preserving the login time.
sess.SetLoginTime(prevSession.GetLoginTime())

if err := a.upsertWebSession(context.TODO(), req.User, sess); err != nil {
return nil, trace.Wrap(err)
}

Expand Down Expand Up @@ -1058,9 +1090,10 @@ func (a *Server) CreateWebSession(user string) (services.WebSession, error) {
return nil, trace.Wrap(err)
}
sess, err := a.NewWebSession(types.NewWebSessionRequest{
User: user,
Roles: u.GetRoles(),
Traits: u.GetTraits(),
User: user,
Roles: u.GetRoles(),
Traits: u.GetTraits(),
LoginTime: a.clock.Now().UTC(),
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -1610,15 +1643,24 @@ func (a *Server) NewWebSession(req types.NewWebSessionRequest) (services.WebSess
return nil, trace.Wrap(err)
}
bearerTokenTTL := utils.MinTTL(sessionTTL, BearerTokenTTL)
return services.NewWebSession(token, services.KindWebSession, services.KindWebSession, services.WebSessionSpecV2{

startTime := a.clock.Now()
if !req.LoginTime.IsZero() {
startTime = req.LoginTime
}

sessionSpec := services.WebSessionSpecV2{
User: req.User,
Priv: priv,
Pub: certs.ssh,
TLSCert: certs.tls,
Expires: a.clock.Now().UTC().Add(sessionTTL),
Expires: startTime.UTC().Add(sessionTTL),
BearerToken: bearerToken,
BearerTokenExpires: a.clock.Now().UTC().Add(bearerTokenTTL),
}), nil
BearerTokenExpires: startTime.UTC().Add(bearerTokenTTL),
LoginTime: req.LoginTime,
}

return services.NewWebSession(token, services.KindWebSession, services.KindWebSession, sessionSpec), nil
}

// GetWebSessionInfo returns the web session specified with sessionID for the given user.
Expand Down
6 changes: 3 additions & 3 deletions lib/auth/auth_with_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -825,11 +825,11 @@ func (a *ServerWithRoles) CreateWebSession(user string) (services.WebSession, er
// ExtendWebSession creates a new web session for a user based on a valid previous session.
// Additional roles are appended to initial roles if there is an approved access request.
// The new session expiration time will not exceed the expiration time of the old session.
func (a *ServerWithRoles) ExtendWebSession(user, prevSessionID, accessRequestID string) (services.WebSession, error) {
if err := a.currentUserAction(user); err != nil {
func (a *ServerWithRoles) ExtendWebSession(req WebSessionReq) (services.WebSession, error) {
if err := a.currentUserAction(req.User); err != nil {
return nil, trace.Wrap(err)
}
return a.authServer.ExtendWebSession(user, prevSessionID, accessRequestID, a.context.Identity.GetIdentity())
return a.authServer.ExtendWebSession(req, a.context.Identity.GetIdentity())
}

// GetWebSessionInfo returns the web session for the given user specified with sid.
Expand Down
13 changes: 4 additions & 9 deletions lib/auth/clt.go
Original file line number Diff line number Diff line change
Expand Up @@ -1170,14 +1170,9 @@ func (c *Client) GetMFAAuthenticateChallenge(user string, password []byte) (*MFA

// ExtendWebSession creates a new web session for a user based on another
// valid web session
func (c *Client) ExtendWebSession(user string, prevSessionID string, accessRequestID string) (services.WebSession, error) {
func (c *Client) ExtendWebSession(req WebSessionReq) (services.WebSession, error) {
out, err := c.PostJSON(
c.Endpoint("users", user, "web", "sessions"),
createWebSessionReq{
PrevSessionID: prevSessionID,
AccessRequestID: accessRequestID,
},
)
c.Endpoint("users", req.User, "web", "sessions"), req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -1188,7 +1183,7 @@ func (c *Client) ExtendWebSession(user string, prevSessionID string, accessReque
func (c *Client) CreateWebSession(user string) (services.WebSession, error) {
out, err := c.PostJSON(
c.Endpoint("users", user, "web", "sessions"),
createWebSessionReq{},
WebSessionReq{User: user},
)
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -2442,7 +2437,7 @@ type WebService interface {
GetWebSessionInfo(ctx context.Context, user, sessionID string) (types.WebSession, error)
// ExtendWebSession creates a new web session for a user based on another
// valid web session
ExtendWebSession(user, prevSessionID, accessRequestID string) (types.WebSession, error)
ExtendWebSession(req WebSessionReq) (types.WebSession, error)
// CreateWebSession creates a new web session for a user
CreateWebSession(user string) (types.WebSession, error)

Expand Down
44 changes: 1 addition & 43 deletions lib/auth/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,7 @@ func (a *Server) validateGithubAuthCallback(q url.Values) (*githubAuthResponse,
Roles: user.GetRoles(),
Traits: user.GetTraits(),
SessionTTL: params.sessionTTL,
LoginTime: a.clock.Now().UTC(),
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -308,49 +309,6 @@ func (a *Server) validateGithubAuthCallback(q url.Values) (*githubAuthResponse,
return re, nil
}

func (a *Server) createWebSession(ctx context.Context, req types.NewWebSessionRequest) (services.WebSession, error) {
// It's safe to extract the roles and traits directly from services.User
// because this occurs during the user creation process and services.User
// is not fetched from the backend.
session, err := a.NewWebSession(req)
if err != nil {
return nil, trace.Wrap(err)
}

err = a.upsertWebSession(ctx, req.User, session)
if err != nil {
return nil, trace.Wrap(err)
}

return session, nil
}

func (a *Server) createSessionCert(user services.User, sessionTTL time.Duration, publicKey []byte, compatibility, routeToCluster, kubernetesCluster string) ([]byte, []byte, error) {
// It's safe to extract the roles and traits directly from services.User
// because this occurs during the user creation process and services.User
// is not fetched from the backend.
checker, err := services.FetchRoles(user.GetRoles(), a.Access, user.GetTraits())
if err != nil {
return nil, nil, trace.Wrap(err)
}

certs, err := a.generateUserCert(certRequest{
user: user,
ttl: sessionTTL,
publicKey: publicKey,
compatibility: compatibility,
checker: checker,
traits: user.GetTraits(),
routeToCluster: routeToCluster,
kubernetesCluster: kubernetesCluster,
})
if err != nil {
return nil, nil, trace.Wrap(err)
}

return certs.ssh, certs.tls, nil
}

// createUserParams is a set of parameters used to create a user for an
// external identity provider.
type createUserParams struct {
Expand Down
7 changes: 4 additions & 3 deletions lib/auth/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,9 +424,10 @@ func (s *Server) createUserWebSession(ctx context.Context, user services.User) (
// It's safe to extract the roles and traits directly from services.User as this method
// is only used for local accounts.
return s.createWebSession(ctx, types.NewWebSessionRequest{
User: user.GetName(),
Roles: user.GetRoles(),
Traits: user.GetTraits(),
User: user.GetName(),
Roles: user.GetRoles(),
Traits: user.GetTraits(),
LoginTime: s.clock.Now().UTC(),
})
}

Expand Down
1 change: 1 addition & 0 deletions lib/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,7 @@ func (a *Server) validateOIDCAuthCallback(q url.Values) (*oidcAuthResponse, erro
Roles: user.GetRoles(),
Traits: user.GetTraits(),
SessionTTL: params.sessionTTL,
LoginTime: a.clock.Now().UTC(),
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
1 change: 1 addition & 0 deletions lib/auth/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ func (a *Server) validateSAMLResponse(samlResponse string) (*samlAuthResponse, e
Roles: user.GetRoles(),
Traits: user.GetTraits(),
SessionTTL: params.sessionTTL,
LoginTime: a.clock.Now().UTC(),
})
if err != nil {
return re, trace.Wrap(err)
Expand Down
Loading

0 comments on commit 73f40b3

Please sign in to comment.