diff --git a/lib/auth/github.go b/lib/auth/github.go index ae21b8d110156..02080f5ded612 100644 --- a/lib/auth/github.go +++ b/lib/auth/github.go @@ -150,55 +150,48 @@ func (s *AuthServer) validateGithubAuthCallback(q url.Values) (*GithubAuthRespon if err != nil { return nil, trace.Wrap(err) } - err = s.createGithubUser(connector, *claims) + + // Calculate (figure out name, roles, traits, session TTL) of user and + // create the user in the backend. + params, err := s.calculateGithubUser(connector, claims, req) if err != nil { return nil, trace.Wrap(err) } + user, err := s.createGithubUser(params) + if err != nil { + return nil, trace.Wrap(err) + } + + // Auth was successful, return session, certificate, etc. to caller. response := &GithubAuthResponse{ + Req: *req, Identity: services.ExternalIdentity{ - ConnectorID: connector.GetName(), - Username: claims.Username, + ConnectorID: params.connectorName, + Username: params.username, }, - Req: *req, - } - user, err := s.Identity.GetUserByGithubIdentity(response.Identity) - if err != nil { - return nil, trace.Wrap(err) + Username: user.GetName(), } response.Username = user.GetName() - roles, err := services.FetchRoles(user.GetRoles(), s.Access, user.GetTraits()) - if err != nil { - return nil, trace.Wrap(err) - } + + // If the request is coming from a browser, create a web session. if req.CreateWebSession { - session, err := s.NewWebSession(user.GetName()) - if err != nil { - return nil, trace.Wrap(err) - } - sessionTTL := roles.AdjustSessionTTL(defaults.OAuth2TTL) - bearerTTL := utils.MinTTL(BearerTokenTTL, sessionTTL) - session.SetExpiryTime(s.clock.Now().UTC().Add(sessionTTL)) - session.SetBearerTokenExpiryTime(s.clock.Now().UTC().Add(bearerTTL)) - err = s.UpsertWebSession(user.GetName(), session) + session, err := s.createWebSession(user, params.sessionTTL) if err != nil { return nil, trace.Wrap(err) } + response.Session = session } + + // If a public key was provided, sign it and return a certificate. if len(req.PublicKey) != 0 { - certTTL := utils.MinTTL(defaults.OAuth2TTL, req.CertTTL) - certs, err := s.generateUserCert(certRequest{ - user: user, - roles: roles, - ttl: certTTL, - publicKey: req.PublicKey, - compatibility: req.Compatibility, - }) + sshCert, tlsCert, err := s.createSessionCert(user, params.sessionTTL, req.PublicKey, req.Compatibility) if err != nil { return nil, trace.Wrap(err) } - response.Cert = certs.ssh - response.TLSCert = certs.tls + + response.Cert = sshCert + response.TLSCert = tlsCert // Return the host CA for this cluster only. authority, err := s.GetCertAuthority(services.CertAuthID{ @@ -210,61 +203,157 @@ func (s *AuthServer) validateGithubAuthCallback(q url.Values) (*GithubAuthRespon } response.HostSigners = append(response.HostSigners, authority) } + return response, nil } -func (s *AuthServer) createGithubUser(connector services.GithubConnector, claims services.GithubClaims) error { - logins, kubeGroups := connector.MapClaims(claims) - if len(logins) == 0 { - return trace.BadParameter( +func (s *AuthServer) createWebSession(user services.User, sessionTTL time.Duration) (services.WebSession, error) { + session, err := s.NewWebSession(user.GetName()) + if err != nil { + return nil, trace.Wrap(err) + } + + // Session expiry time is the same as the user expiry time. + session.SetExpiryTime(s.clock.Now().UTC().Add(sessionTTL)) + + // Bearer tokens expire quicker than the overall session time and need to be refreshed. + bearerTTL := utils.MinTTL(BearerTokenTTL, sessionTTL) + session.SetBearerTokenExpiryTime(s.clock.Now().UTC().Add(bearerTTL)) + + err = s.UpsertWebSession(user.GetName(), session) + if err != nil { + return nil, trace.Wrap(err) + } + + return session, nil +} + +func (s *AuthServer) createSessionCert(user services.User, sessionTTL time.Duration, publicKey []byte, compatibility string) ([]byte, []byte, error) { + roles, err := services.FetchRoles(user.GetRoles(), s.Access, user.GetTraits()) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + certs, err := s.generateUserCert(certRequest{ + user: user, + roles: roles, + ttl: sessionTTL, + publicKey: publicKey, + compatibility: compatibility, + }) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + return certs.ssh, certs.tls, nil +} + +// createUserParams is a set of parameters used to create a user for an +// external identity provider. +type createUserParams struct { + // connectorName is the name of the connector for the identity provider. + connectorName string + + // username is the Teleport user name . + username string + + // logins is the list of *nix logins. + logins []string + + // kubeGroups is the list of Kubernetes this user belongs to. + kubeGroups []string + + // roles is the list of roles this user is assigned to. + roles []string + + // traits is the list of traits for this user. + traits map[string][]string + + // sessionTTL is how long this session will last. + sessionTTL time.Duration +} + +func (s *AuthServer) calculateGithubUser(connector services.GithubConnector, claims *services.GithubClaims, request *services.GithubAuthRequest) (*createUserParams, error) { + p := createUserParams{ + connectorName: connector.GetName(), + username: claims.Username, + } + + // Calculate logins, kubegroups, roles, and traits. + p.logins, p.kubeGroups = connector.MapClaims(*claims) + if len(p.logins) == 0 { + return nil, trace.BadParameter( "user %q does not belong to any teams configured in %q connector", claims.Username, connector.GetName()) } + p.roles = modules.GetModules().RolesFromLogins(p.logins) + p.traits = modules.GetModules().TraitsFromLogins(p.logins, p.kubeGroups) + + // Pick smaller for role: session TTL from role or requested TTL. + roles, err := services.FetchRoles(p.roles, s.Access, p.traits) + if err != nil { + return nil, trace.Wrap(err) + } + roleTTL := roles.AdjustSessionTTL(defaults.MaxCertDuration) + p.sessionTTL = utils.MinTTL(roleTTL, request.CertTTL) + + return &p, nil +} + +func (s *AuthServer) createGithubUser(p *createUserParams) (services.User, error) { + log.WithFields(logrus.Fields{trace.Component: "github"}).Debugf( "Generating dynamic identity %v/%v with logins: %v.", - connector.GetName(), claims.Username, logins) + p.connectorName, p.username, p.logins) + + expires := s.GetClock().Now().UTC().Add(p.sessionTTL) + user, err := services.GetUserMarshaler().GenerateUser(&services.UserV2{ Kind: services.KindUser, Version: services.V2, Metadata: services.Metadata{ - Name: claims.Username, + Name: p.username, Namespace: defaults.Namespace, + Expires: &expires, }, Spec: services.UserSpecV2{ - Roles: modules.GetModules().RolesFromLogins(logins), - Traits: modules.GetModules().TraitsFromLogins(logins, kubeGroups), - Expires: s.clock.Now().UTC().Add(defaults.OAuth2TTL), + Roles: p.roles, + Traits: p.traits, GithubIdentities: []services.ExternalIdentity{{ - ConnectorID: connector.GetName(), - Username: claims.Username, + ConnectorID: p.connectorName, + Username: p.username, }}, CreatedBy: services.CreatedBy{ User: services.UserRef{Name: "system"}, - Time: time.Now().UTC(), + Time: s.GetClock().Now().UTC(), Connector: &services.ConnectorRef{ Type: teleport.ConnectorGithub, - ID: connector.GetName(), - Identity: claims.Username, + ID: p.connectorName, + Identity: p.username, }, }, }, }) - existingUser, err := s.GetUser(claims.Username) + if err != nil { + return nil, trace.Wrap(err) + } + + existingUser, err := s.GetUser(p.username) if err != nil && !trace.IsNotFound(err) { - return trace.Wrap(err) + return nil, trace.Wrap(err) } if existingUser != nil { ref := user.GetCreatedBy().Connector if !ref.IsSameProvider(existingUser.GetCreatedBy().Connector) { - return trace.AlreadyExists("user %q already exists and is not Github user", + return nil, trace.AlreadyExists("user %q already exists and is not Github user", existingUser.GetName()) } } err = s.UpsertUser(user) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - return nil + return user, nil } // populateGithubClaims retrieves information about user and its team diff --git a/lib/auth/github_test.go b/lib/auth/github_test.go index dc48374e1da46..fc20ad8e0fce3 100644 --- a/lib/auth/github_test.go +++ b/lib/auth/github_test.go @@ -17,18 +17,56 @@ limitations under the License. package auth import ( + "context" + "fmt" + "time" + + authority "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/backend/lite" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" - check "gopkg.in/check.v1" + "github.com/jonboulle/clockwork" + "gopkg.in/check.v1" ) -type GithubSuite struct{} +type GithubSuite struct { + a *AuthServer + b backend.Backend + c clockwork.FakeClock +} +var _ = fmt.Printf var _ = check.Suite(&GithubSuite{}) func (s *GithubSuite) SetUpSuite(c *check.C) { + var err error + utils.InitLoggerForTests() + + s.c = clockwork.NewFakeClockAt(time.Now()) + + s.b, err = lite.NewWithConfig(context.Background(), lite.Config{ + Path: c.MkDir(), + PollStreamPeriod: 200 * time.Millisecond, + Clock: s.c, + }) + c.Assert(err, check.IsNil) + + clusterName, err := services.NewClusterName(services.ClusterNameSpecV2{ + ClusterName: "me.localhost", + }) + c.Assert(err, check.IsNil) + + authConfig := &InitConfig{ + ClusterName: clusterName, + Backend: s.b, + Authority: authority.New(), + SkipPeriodicOperations: true, + } + s.a, err = NewAuthServer(authConfig) + c.Assert(err, check.IsNil) } func (s *GithubSuite) TestPopulateClaims(c *check.C) { @@ -43,6 +81,27 @@ func (s *GithubSuite) TestPopulateClaims(c *check.C) { }) } +func (s *GithubSuite) TestCreateGithubUser(c *check.C) { + // Create GitHub user with 1 minute expiry. + _, err := s.a.createGithubUser(&createUserParams{ + connectorName: "github", + username: "foo", + logins: []string{"foo"}, + roles: []string{"admin"}, + sessionTTL: 1 * time.Minute, + }) + c.Assert(err, check.IsNil) + + // Within that 1 minute period the user should still exist. + _, err = s.a.GetUser("foo") + c.Assert(err, check.IsNil) + + // Advance time 2 minutes, the user should be gone. + s.c.Advance(2 * time.Minute) + _, err = s.a.GetUser("foo") + c.Assert(err, check.NotNil) +} + type testGithubAPIClient struct{} func (c *testGithubAPIClient) getUser() (*userResponse, error) { diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index 37d24286ca7a7..5cfbf0b6174fd 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -53,6 +53,8 @@ type TestAuthServerConfig struct { AcceptedUsage []string // CipherSuites is the list of ciphers that the server supports. CipherSuites []uint16 + // Clock is used to control time in tests. + Clock clockwork.FakeClock } // CheckAndSetDefaults checks and sets defaults @@ -63,6 +65,9 @@ func (cfg *TestAuthServerConfig) CheckAndSetDefaults() error { if cfg.Dir == "" { return trace.BadParameter("missing parameter Dir") } + if cfg.Clock == nil { + cfg.Clock = clockwork.NewFakeClockAt(time.Now()) + } if len(cfg.CipherSuites) == 0 { cfg.CipherSuites = utils.DefaultCipherSuites() } @@ -107,7 +112,11 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) { TestAuthServerConfig: cfg, } var err error - srv.Backend, err = lite.NewWithConfig(context.TODO(), lite.Config{Path: cfg.Dir, PollStreamPeriod: 100 * time.Millisecond}) + srv.Backend, err = lite.NewWithConfig(context.Background(), lite.Config{ + Path: cfg.Dir, + PollStreamPeriod: 100 * time.Millisecond, + Clock: cfg.Clock, + }) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/oidc.go b/lib/auth/oidc.go index 7611012ee0e52..e8c2578a6f936 100644 --- a/lib/auth/oidc.go +++ b/lib/auth/oidc.go @@ -22,7 +22,6 @@ import ( "io/ioutil" "net/http" "net/url" - "time" "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" @@ -202,66 +201,55 @@ func (a *AuthServer) validateOIDCAuthCallback(q url.Values) (*OIDCAuthResponse, } log.Debugf("OIDC user %q expires at: %v.", ident.Email, ident.ExpiresAt) - response := &OIDCAuthResponse{ - Identity: services.ExternalIdentity{ConnectorID: connector.GetName(), Username: ident.Email}, - Req: *req, + if len(connector.GetClaimsToRoles()) == 0 { + return nil, trace.BadParameter("no claims to roles mapping, check connector documentation") } - log.Debugf("Applying %v OIDC claims to roles mappings.", len(connector.GetClaimsToRoles())) - if len(connector.GetClaimsToRoles()) != 0 { - if err := a.createOIDCUser(connector, ident, claims); err != nil { - return nil, trace.Wrap(err) - } - } - if !req.CheckUser { - return response, nil - } - - user, err := a.Identity.GetUserByOIDCIdentity(services.ExternalIdentity{ - ConnectorID: req.ConnectorID, Username: ident.Email}) + // Calculate (figure out name, roles, traits, session TTL) of user and + // create the user in the backend. + params, err := a.calculateOIDCUser(connector, claims, ident, req) if err != nil { return nil, trace.Wrap(err) } - response.Username = user.GetName() - - var roles services.RoleSet - roles, err = services.FetchRoles(user.GetRoles(), a.Access, user.GetTraits()) + user, err := a.createOIDCUser(params) if err != nil { return nil, trace.Wrap(err) } - sessionTTL := roles.AdjustSessionTTL(utils.ToTTL(a.clock, ident.ExpiresAt)) - bearerTokenTTL := utils.MinTTL(BearerTokenTTL, sessionTTL) + // Auth was successful, return session, certificate, etc. to caller. + response := &OIDCAuthResponse{ + Req: *req, + Identity: services.ExternalIdentity{ + ConnectorID: params.connectorName, + Username: params.username, + }, + Username: user.GetName(), + } + + if !req.CheckUser { + return response, nil + } + + // If the request is coming from a browser, create a web session. if req.CreateWebSession { - sess, err := a.NewWebSession(user.GetName()) + session, err := a.createWebSession(user, params.sessionTTL) if err != nil { return nil, trace.Wrap(err) } - // session will expire based on identity TTL and allowed session TTL - sess.SetExpiryTime(a.clock.Now().UTC().Add(sessionTTL)) - // bearer token will expire based on the expected session renewal - sess.SetBearerTokenExpiryTime(a.clock.Now().UTC().Add(bearerTokenTTL)) - if err := a.UpsertWebSession(user.GetName(), sess); err != nil { - return nil, trace.Wrap(err) - } - response.Session = sess + + response.Session = session } + // If a public key was provided, sign it and return a certificate. if len(req.PublicKey) != 0 { - certTTL := utils.MinTTL(utils.ToTTL(a.clock, ident.ExpiresAt), req.CertTTL) - certs, err := a.generateUserCert(certRequest{ - user: user, - roles: roles, - ttl: certTTL, - publicKey: req.PublicKey, - compatibility: req.Compatibility, - }) + sshCert, tlsCert, err := a.createSessionCert(user, params.sessionTTL, req.PublicKey, req.Compatibility) if err != nil { return nil, trace.Wrap(err) } - response.Cert = certs.ssh - response.TLSCert = certs.tls + + response.Cert = sshCert + response.TLSCert = tlsCert // Return the host CA for this cluster only. authority, err := a.GetCertAuthority(services.CertAuthID{ @@ -273,6 +261,7 @@ func (a *AuthServer) validateOIDCAuthCallback(q url.Values) (*OIDCAuthResponse, } response.HostSigners = append(response.HostSigners, authority) } + return response, nil } @@ -325,52 +314,72 @@ func claimsToTraitMap(claims jose.Claims) map[string][]string { return traits } -func (a *AuthServer) createOIDCUser(connector services.OIDCConnector, ident *oidc.Identity, claims jose.Claims) error { - roles, err := a.buildOIDCRoles(connector, claims) +func (a *AuthServer) calculateOIDCUser(connector services.OIDCConnector, claims jose.Claims, ident *oidc.Identity, request *services.OIDCAuthRequest) (*createUserParams, error) { + var err error + + p := createUserParams{ + connectorName: connector.GetName(), + username: ident.Email, + } + + p.roles, err = a.buildOIDCRoles(connector, claims) + if err != nil { + return nil, trace.Wrap(err) + } + p.traits = claimsToTraitMap(claims) + + // Pick smaller for role: session TTL from role or requested TTL. + roles, err := services.FetchRoles(p.roles, a.Access, p.traits) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } + roleTTL := roles.AdjustSessionTTL(defaults.MaxCertDuration) + p.sessionTTL = utils.MinTTL(roleTTL, request.CertTTL) + + return &p, nil +} - traits := claimsToTraitMap(claims) +func (a *AuthServer) createOIDCUser(p *createUserParams) (services.User, error) { + expires := a.GetClock().Now().UTC().Add(p.sessionTTL) - log.Debugf("Generating dynamic OIDC identity %v/%v with roles: %v.", connector.GetName(), ident.Email, roles) + log.Debugf("Generating dynamic OIDC identity %v/%v with roles: %v.", p.connectorName, p.username, p.roles) user, err := services.GetUserMarshaler().GenerateUser(&services.UserV2{ Kind: services.KindUser, Version: services.V2, Metadata: services.Metadata{ - Name: ident.Email, + Name: p.username, Namespace: defaults.Namespace, + Expires: &expires, }, Spec: services.UserSpecV2{ - Roles: roles, - Traits: traits, - Expires: ident.ExpiresAt, + Roles: p.roles, + Traits: p.traits, OIDCIdentities: []services.ExternalIdentity{ - { - ConnectorID: connector.GetName(), - Username: ident.Email, + services.ExternalIdentity{ + ConnectorID: p.connectorName, + Username: p.username, }, }, CreatedBy: services.CreatedBy{ User: services.UserRef{Name: "system"}, - Time: time.Now().UTC(), + Time: a.clock.Now().UTC(), Connector: &services.ConnectorRef{ Type: teleport.ConnectorOIDC, - ID: connector.GetName(), - Identity: ident.Email, + ID: p.connectorName, + Identity: p.username, }, }, }, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } // Get the user to check if it already exists or not. - existingUser, err := a.GetUser(ident.Email) + existingUser, err := a.GetUser(p.username) if err != nil { if !trace.IsNotFound(err) { - return trace.Wrap(err) + return nil, trace.Wrap(err) } } @@ -380,7 +389,7 @@ func (a *AuthServer) createOIDCUser(connector services.OIDCConnector, ident *oid // If the exisiting user is a local user, fail and advise how to fix the problem. if connectorRef == nil { - return trace.AlreadyExists("local user with name '%v' already exists. Either change "+ + return nil, trace.AlreadyExists("local user with name '%v' already exists. Either change "+ "email in OIDC identity or remove local user and try again.", existingUser.GetName()) } @@ -391,10 +400,10 @@ func (a *AuthServer) createOIDCUser(connector services.OIDCConnector, ident *oid // Upsert the new user creating or updating whatever is in the database. err = a.UpsertUser(user) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - return nil + return user, nil } // claimsFromIDToken extracts claims from the ID token. diff --git a/lib/auth/oidc_test.go b/lib/auth/oidc_test.go new file mode 100644 index 0000000000000..d853de3872275 --- /dev/null +++ b/lib/auth/oidc_test.go @@ -0,0 +1,91 @@ +/* +Copyright 2019 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package auth + +import ( + "context" + "fmt" + "time" + + authority "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/backend/lite" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" + + "github.com/jonboulle/clockwork" + "gopkg.in/check.v1" +) + +type OIDCSuite struct { + a *AuthServer + b backend.Backend + c clockwork.FakeClock +} + +var _ = fmt.Printf +var _ = check.Suite(&OIDCSuite{}) + +func (s *OIDCSuite) SetUpSuite(c *check.C) { + var err error + + utils.InitLoggerForTests() + + s.c = clockwork.NewFakeClockAt(time.Now()) + + s.b, err = lite.NewWithConfig(context.Background(), lite.Config{ + Path: c.MkDir(), + PollStreamPeriod: 200 * time.Millisecond, + Clock: s.c, + }) + c.Assert(err, check.IsNil) + + clusterName, err := services.NewClusterName(services.ClusterNameSpecV2{ + ClusterName: "me.localhost", + }) + c.Assert(err, check.IsNil) + + authConfig := &InitConfig{ + ClusterName: clusterName, + Backend: s.b, + Authority: authority.New(), + SkipPeriodicOperations: true, + } + s.a, err = NewAuthServer(authConfig) + c.Assert(err, check.IsNil) +} + +func (s *OIDCSuite) TestCreateOIDCUser(c *check.C) { + // Create OIDC user with 1 minute expiry. + _, err := s.a.createOIDCUser(&createUserParams{ + connectorName: "oidcService", + username: "foo@example.com", + logins: []string{"foo"}, + roles: []string{"admin"}, + sessionTTL: 1 * time.Minute, + }) + c.Assert(err, check.IsNil) + + // Within that 1 minute period the user should still exist. + _, err = s.a.GetUser("foo@example.com") + c.Assert(err, check.IsNil) + + // Advance time 2 minutes, the user should be gone. + s.c.Advance(2 * time.Minute) + _, err = s.a.GetUser("foo@example.com") + c.Assert(err, check.NotNil) +} diff --git a/lib/auth/password_test.go b/lib/auth/password_test.go index c0709f2a95896..ab53b859f336b 100644 --- a/lib/auth/password_test.go +++ b/lib/auth/password_test.go @@ -43,8 +43,8 @@ type PasswordSuite struct { a *AuthServer } -var _ = Suite(&PasswordSuite{}) var _ = fmt.Printf +var _ = Suite(&PasswordSuite{}) func (s *PasswordSuite) SetUpSuite(c *C) { utils.InitLoggerForTests() diff --git a/lib/auth/saml.go b/lib/auth/saml.go index 897a468c6790b..30f1175da35f1 100644 --- a/lib/auth/saml.go +++ b/lib/auth/saml.go @@ -1,3 +1,19 @@ +/* +Copyright 2019 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package auth import ( @@ -5,7 +21,6 @@ import ( "compress/flate" "encoding/base64" "io/ioutil" - "time" "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" @@ -105,47 +120,75 @@ func assertionsToTraitMap(assertionInfo saml2.AssertionInfo) map[string][]string return traits } -func (a *AuthServer) createSAMLUser(connector services.SAMLConnector, assertionInfo saml2.AssertionInfo, expiresAt time.Time) error { - roles, err := a.buildSAMLRoles(connector, assertionInfo) +func (a *AuthServer) calculateSAMLUser(connector services.SAMLConnector, assertionInfo saml2.AssertionInfo, request *services.SAMLAuthRequest) (*createUserParams, error) { + var err error + + p := createUserParams{ + connectorName: connector.GetName(), + username: assertionInfo.NameID, + } + + p.roles, err = a.buildSAMLRoles(connector, assertionInfo) + if err != nil { + return nil, trace.Wrap(err) + } + p.traits = assertionsToTraitMap(assertionInfo) + + // Pick smaller for role: session TTL from role or requested TTL. + roles, err := services.FetchRoles(p.roles, a.Access, p.traits) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } + roleTTL := roles.AdjustSessionTTL(defaults.MaxCertDuration) + p.sessionTTL = utils.MinTTL(roleTTL, request.CertTTL) + + return &p, nil +} + +func (a *AuthServer) createSAMLUser(p *createUserParams) (services.User, error) { + expires := a.GetClock().Now().UTC().Add(p.sessionTTL) - traits := assertionsToTraitMap(assertionInfo) + log.Debugf("Generating dynamic SAML identity %v/%v with roles: %v.", p.connectorName, p.username, p.roles) - log.Debugf("Generating dynamic SAML identity %v/%v with roles: %v.", connector.GetName(), assertionInfo.NameID, roles) user, err := services.GetUserMarshaler().GenerateUser(&services.UserV2{ Kind: services.KindUser, Version: services.V2, Metadata: services.Metadata{ - Name: assertionInfo.NameID, + Name: p.username, Namespace: defaults.Namespace, + Expires: &expires, }, Spec: services.UserSpecV2{ - Roles: roles, - Traits: traits, - Expires: expiresAt, - SAMLIdentities: []services.ExternalIdentity{{ConnectorID: connector.GetName(), Username: assertionInfo.NameID}}, + Roles: p.roles, + Traits: p.traits, + SAMLIdentities: []services.ExternalIdentity{ + services.ExternalIdentity{ + ConnectorID: p.connectorName, + Username: p.username, + }, + }, CreatedBy: services.CreatedBy{ - User: services.UserRef{Name: "system"}, - Time: time.Now().UTC(), + User: services.UserRef{ + Name: "system", + }, + Time: a.clock.Now().UTC(), Connector: &services.ConnectorRef{ Type: teleport.ConnectorSAML, - ID: connector.GetName(), - Identity: assertionInfo.NameID, + ID: p.connectorName, + Identity: p.username, }, }, }, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } // Get the user to check if it already exists or not. - existingUser, err := a.GetUser(assertionInfo.NameID) + existingUser, err := a.GetUser(p.username) if err != nil { if !trace.IsNotFound(err) { - return trace.Wrap(err) + return nil, trace.Wrap(err) } } @@ -155,7 +198,7 @@ func (a *AuthServer) createSAMLUser(connector services.SAMLConnector, assertionI // If the exisiting user is a local user, fail and advise how to fix the problem. if connectorRef == nil { - return trace.AlreadyExists("local user with name '%v' already exists. Either change "+ + return nil, trace.AlreadyExists("local user with name '%v' already exists. Either change "+ "NameID in assertion or remove local user and try again.", existingUser.GetName()) } @@ -166,10 +209,10 @@ func (a *AuthServer) createSAMLUser(connector services.SAMLConnector, assertionI // Upsert the new user creating or updating whatever is in the database. err = a.UpsertUser(user) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - return nil + return user, nil } func parseSAMLInResponseTo(response string) (string, error) { @@ -293,67 +336,51 @@ func (a *AuthServer) validateSAMLResponse(samlResponse string) (*SAMLAuthRespons } log.Debugf("SAML assertion warnings: %+v.", assertionInfo.WarningInfo) - log.Debugf("Applying %v SAML attribute to roles mappings.", len(connector.GetAttributesToRoles())) if len(connector.GetAttributesToRoles()) == 0 { - return nil, trace.BadParameter("SAML does not support binding to local users") - } - // TODO(klizhentas) use SessionNotOnOrAfter to calculate expiration time - expiresAt := a.clock.Now().Add(defaults.CertDuration) - if err := a.createSAMLUser(connector, *assertionInfo, expiresAt); err != nil { - return nil, trace.Wrap(err) + return nil, trace.BadParameter("no attributes to roles mapping, check connector documentation") } + log.Debugf("Applying %v SAML attribute to roles mappings.", len(connector.GetAttributesToRoles())) - identity := services.ExternalIdentity{ - ConnectorID: request.ConnectorID, - Username: assertionInfo.NameID, + // Calculate (figure out name, roles, traits, session TTL) of user and + // create the user in the backend. + params, err := a.calculateSAMLUser(connector, *assertionInfo, request) + if err != nil { + return nil, trace.Wrap(err) } - user, err := a.Identity.GetUserBySAMLIdentity(identity) + user, err := a.createSAMLUser(params) if err != nil { return nil, trace.Wrap(err) } + + // Auth was successful, return session, certificate, etc. to caller. response := &SAMLAuthResponse{ - Req: *request, - Identity: identity, + Req: *request, + Identity: services.ExternalIdentity{ + ConnectorID: params.connectorName, + Username: params.username, + }, Username: user.GetName(), } - var roles services.RoleSet - roles, err = services.FetchRoles(user.GetRoles(), a.Access, user.GetTraits()) - if err != nil { - return nil, trace.Wrap(err) - } - sessionTTL := roles.AdjustSessionTTL(utils.ToTTL(a.clock, expiresAt)) - bearerTokenTTL := utils.MinTTL(BearerTokenTTL, sessionTTL) - + // If the request is coming from a browser, create a web session. if request.CreateWebSession { - sess, err := a.NewWebSession(user.GetName()) + session, err := a.createWebSession(user, params.sessionTTL) if err != nil { return nil, trace.Wrap(err) } - // session will expire based on identity TTL and allowed session TTL - sess.SetExpiryTime(a.clock.Now().UTC().Add(sessionTTL)) - // bearer token will expire based on the expected session renewal - sess.SetBearerTokenExpiryTime(a.clock.Now().UTC().Add(bearerTokenTTL)) - if err := a.UpsertWebSession(user.GetName(), sess); err != nil { - return nil, trace.Wrap(err) - } - response.Session = sess + + response.Session = session } + // If a public key was provided, sign it and return a certificate. if len(request.PublicKey) != 0 { - certTTL := utils.MinTTL(sessionTTL, request.CertTTL) - certs, err := a.generateUserCert(certRequest{ - user: user, - roles: roles, - ttl: certTTL, - publicKey: request.PublicKey, - compatibility: request.Compatibility, - }) + sshCert, tlsCert, err := a.createSessionCert(user, params.sessionTTL, request.PublicKey, request.Compatibility) if err != nil { return nil, trace.Wrap(err) } - response.Cert = certs.ssh - response.TLSCert = certs.tls + + response.Cert = sshCert + response.TLSCert = tlsCert // Return the host CA for this cluster only. authority, err := a.GetCertAuthority(services.CertAuthID{ @@ -365,5 +392,6 @@ func (a *AuthServer) validateSAMLResponse(samlResponse string) (*SAMLAuthRespons } response.HostSigners = append(response.HostSigners, authority) } + return response, nil } diff --git a/lib/auth/saml_test.go b/lib/auth/saml_test.go new file mode 100644 index 0000000000000..34192950254ab --- /dev/null +++ b/lib/auth/saml_test.go @@ -0,0 +1,91 @@ +/* +Copyright 2019 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package auth + +import ( + "context" + "fmt" + "time" + + authority "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/backend/lite" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" + + "github.com/jonboulle/clockwork" + "gopkg.in/check.v1" +) + +type SAMLSuite struct { + a *AuthServer + b backend.Backend + c clockwork.FakeClock +} + +var _ = fmt.Printf +var _ = check.Suite(&SAMLSuite{}) + +func (s *SAMLSuite) SetUpSuite(c *check.C) { + var err error + + utils.InitLoggerForTests() + + s.c = clockwork.NewFakeClockAt(time.Now()) + + s.b, err = lite.NewWithConfig(context.Background(), lite.Config{ + Path: c.MkDir(), + PollStreamPeriod: 200 * time.Millisecond, + Clock: s.c, + }) + c.Assert(err, check.IsNil) + + clusterName, err := services.NewClusterName(services.ClusterNameSpecV2{ + ClusterName: "me.localhost", + }) + c.Assert(err, check.IsNil) + + authConfig := &InitConfig{ + ClusterName: clusterName, + Backend: s.b, + Authority: authority.New(), + SkipPeriodicOperations: true, + } + s.a, err = NewAuthServer(authConfig) + c.Assert(err, check.IsNil) +} + +func (s *SAMLSuite) TestCreateSAMLUser(c *check.C) { + // Create SAML user with 1 minute expiry. + _, err := s.a.createSAMLUser(&createUserParams{ + connectorName: "samlService", + username: "foo@example.com", + logins: []string{"foo"}, + roles: []string{"admin"}, + sessionTTL: 1 * time.Minute, + }) + c.Assert(err, check.IsNil) + + // Within that 1 minute period the user should still exist. + _, err = s.a.GetUser("foo@example.com") + c.Assert(err, check.IsNil) + + // Advance time 2 minutes, the user should be gone. + s.c.Advance(2 * time.Minute) + _, err = s.a.GetUser("foo@example.com") + c.Assert(err, check.NotNil) +} diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index 4f4072438aa67..f5aba776ab285 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -1917,6 +1917,17 @@ waitLoop: log.Debugf("Skipping stale event %v, latest object version is %v", event.Resource.GetResourceID(), ca.GetResourceID()) continue waitLoop } + + eca, ok := event.Resource.(*services.CertAuthorityV2) + if !ok { + log.Debugf("Wrong resource type, expected *services.CertAuthorityV2 got %T.", event.Resource) + continue waitLoop + } + if eca.GetRotation().State != services.RotationStateInProgress { + log.Debugf("Skipping CA, wrong rotation state: %v.", eca.GetRotation().State) + continue waitLoop + } + fixtures.DeepCompare(c, ca, event.Resource) break waitLoop } diff --git a/lib/services/local/services_test.go b/lib/services/local/services_test.go index f40464b856b66..9955310f8facf 100644 --- a/lib/services/local/services_test.go +++ b/lib/services/local/services_test.go @@ -18,6 +18,7 @@ package local import ( "context" + "fmt" "testing" "time" @@ -26,6 +27,7 @@ import ( "github.com/gravitational/teleport/lib/services/suite" "github.com/gravitational/teleport/lib/utils" + "github.com/jonboulle/clockwork" "gopkg.in/check.v1" ) @@ -36,6 +38,7 @@ type ServicesSuite struct { suite *suite.ServicesTestSuite } +var _ = fmt.Printf var _ = check.Suite(&ServicesSuite{}) func (s *ServicesSuite) SetUpSuite(c *check.C) { @@ -45,18 +48,25 @@ func (s *ServicesSuite) SetUpSuite(c *check.C) { func (s *ServicesSuite) SetUpTest(c *check.C) { var err error - s.bk, err = lite.NewWithConfig(context.TODO(), lite.Config{Path: c.MkDir(), PollStreamPeriod: 200 * time.Millisecond}) + clock := clockwork.NewFakeClockAt(time.Now()) + + s.bk, err = lite.NewWithConfig(context.TODO(), lite.Config{ + Path: c.MkDir(), + PollStreamPeriod: 200 * time.Millisecond, + Clock: clock, + }) c.Assert(err, check.IsNil) - suite := &suite.ServicesTestSuite{} - suite.CAS = NewCAService(s.bk) - suite.PresenceS = NewPresenceService(s.bk) - suite.ProvisioningS = NewProvisioningService(s.bk) - suite.WebS = NewIdentityService(s.bk) - suite.Access = NewAccessService(s.bk) - suite.EventsS = NewEventsService(s.bk) - suite.ChangesC = make(chan interface{}) - s.suite = suite + s.suite = &suite.ServicesTestSuite{ + CAS: NewCAService(s.bk), + PresenceS: NewPresenceService(s.bk), + ProvisioningS: NewProvisioningService(s.bk), + WebS: NewIdentityService(s.bk), + Access: NewAccessService(s.bk), + EventsS: NewEventsService(s.bk), + ChangesC: make(chan interface{}), + Clock: clock, + } } func (s *ServicesSuite) TearDownTest(c *check.C) { @@ -79,6 +89,10 @@ func (s *ServicesSuite) TestUsersCRUD(c *check.C) { s.suite.UsersCRUD(c) } +func (s *ServicesSuite) TestUsersExpiry(c *check.C) { + s.suite.UsersExpiry(c) +} + func (s *ServicesSuite) TestLoginAttempts(c *check.C) { s.suite.LoginAttempts(c) } diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go index b2ab2854028e2..78be8527b1c6b 100644 --- a/lib/services/suite/suite.go +++ b/lib/services/suite/suite.go @@ -23,6 +23,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/base64" + "fmt" "sort" "time" @@ -41,6 +42,8 @@ import ( "gopkg.in/check.v1" ) +var _ = fmt.Printf + // NewTestCA returns new test authority with a test key as a public and // signing key func NewTestCA(caType services.CertAuthType, clusterName string, privateKeys ...[]byte) *services.CertAuthorityV2 { @@ -93,6 +96,7 @@ type ServicesTestSuite struct { ConfigS services.ClusterConfiguration EventsS services.Events ChangesC chan interface{} + Clock clockwork.FakeClock } func (s *ServicesTestSuite) collectChanges(c *check.C, expected int) []interface{} { @@ -184,6 +188,33 @@ func (s *ServicesTestSuite) UsersCRUD(c *check.C) { fixtures.ExpectBadParameter(c, err) } +func (s *ServicesTestSuite) UsersExpiry(c *check.C) { + expiresAt := s.Clock.Now().Add(1 * time.Minute) + + err := s.WebS.UpsertUser(&services.UserV2{ + Kind: services.KindUser, + Version: services.V2, + Metadata: services.Metadata{ + Name: "foo", + Namespace: defaults.Namespace, + Expires: &expiresAt, + }, + Spec: services.UserSpecV2{}, + }) + c.Assert(err, check.IsNil) + + // Make sure the user exists. + u, err := s.WebS.GetUser("foo") + c.Assert(err, check.IsNil) + c.Assert(u.GetName(), check.Equals, "foo") + + s.Clock.Advance(2 * time.Minute) + + // Make sure the user is now gone. + u, err = s.WebS.GetUser("foo") + c.Assert(err, check.NotNil) +} + func (s *ServicesTestSuite) LoginAttempts(c *check.C) { user := newUser("user1", []string{"admin", "user"}) c.Assert(s.WebS.UpsertUser(user), check.IsNil) @@ -560,8 +591,8 @@ func (s *ServicesTestSuite) SAMLCRUD(c *check.C) { Namespace: defaults.Namespace, }, Spec: services.SAMLConnectorSpecV2{ - Issuer: "http://example.com", - SSO: "https://example.com/saml/sso", + Issuer: "http://example.com", + SSO: "https://example.com/saml/sso", AssertionConsumerService: "https://localhost/acs", Audience: "https://localhost/aud", ServiceProviderIssuer: "https://localhost/iss", diff --git a/lib/services/user.go b/lib/services/user.go index d5e6204e5e079..9ac51bb9458d2 100644 --- a/lib/services/user.go +++ b/lib/services/user.go @@ -1,3 +1,19 @@ +/* +Copyright 2019 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package services import ( @@ -402,12 +418,10 @@ func (u *UserV2) Equals(other User) bool { return true } -// Expiry returns expiry time for temporary users +// Expiry returns expiry time for temporary users. Prefer expires from +// metadata, if it does not exist, fall back to expires in spec. func (u *UserV2) Expiry() time.Time { - if u.Metadata.Expires == nil { - return time.Time{} - } - if !u.Metadata.Expires.IsZero() { + if u.Metadata.Expires != nil && !u.Metadata.Expires.IsZero() { return *u.Metadata.Expires } return u.Spec.Expires diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index cd79942812724..1dcfb3b32afd0 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -128,6 +128,7 @@ func (s *WebSuite) SetUpTest(c *C) { authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ ClusterName: "localhost", Dir: c.MkDir(), + Clock: clockwork.NewFakeClockAt(time.Date(2017, 05, 10, 18, 53, 0, 0, time.UTC)), }) c.Assert(err, IsNil) s.server, err = authServer.NewTestTLSServer()