From a56acc14d12e530687b6064e707ea7cdc5589580 Mon Sep 17 00:00:00 2001 From: Russell Jones Date: Wed, 6 Feb 2019 02:12:40 +0000 Subject: [PATCH 1/3] Moved expires to resource metadata for services.Users. Moved expiry field from spec to metadata for services.Users and updated expiry check to prefer metadata and fallback to spec if not found. Added test coverage. --- lib/auth/github.go | 11 +-- lib/auth/github_test.go | 77 ++++++++++++++++++- lib/auth/helpers.go | 11 ++- lib/auth/oidc.go | 6 +- lib/auth/oidc_test.go | 110 +++++++++++++++++++++++++++ lib/auth/password_test.go | 2 +- lib/auth/saml.go | 18 ++++- lib/auth/saml_test.go | 112 ++++++++++++++++++++++++++++ lib/services/local/services_test.go | 34 ++++++--- lib/services/suite/suite.go | 35 ++++++++- lib/services/user.go | 24 ++++-- lib/web/apiserver_test.go | 1 + 12 files changed, 411 insertions(+), 30 deletions(-) create mode 100644 lib/auth/oidc_test.go create mode 100644 lib/auth/saml_test.go diff --git a/lib/auth/github.go b/lib/auth/github.go index ae21b8d110156..5cbbc1d7572c6 100644 --- a/lib/auth/github.go +++ b/lib/auth/github.go @@ -150,7 +150,8 @@ func (s *AuthServer) validateGithubAuthCallback(q url.Values) (*GithubAuthRespon if err != nil { return nil, trace.Wrap(err) } - err = s.createGithubUser(connector, *claims) + expires := s.clock.Now().UTC().Add(defaults.OAuth2TTL) + err = s.createGithubUser(connector, *claims, expires) if err != nil { return nil, trace.Wrap(err) } @@ -213,7 +214,7 @@ func (s *AuthServer) validateGithubAuthCallback(q url.Values) (*GithubAuthRespon return response, nil } -func (s *AuthServer) createGithubUser(connector services.GithubConnector, claims services.GithubClaims) error { +func (s *AuthServer) createGithubUser(connector services.GithubConnector, claims services.GithubClaims, expires time.Time) error { logins, kubeGroups := connector.MapClaims(claims) if len(logins) == 0 { return trace.BadParameter( @@ -229,11 +230,11 @@ func (s *AuthServer) createGithubUser(connector services.GithubConnector, claims Metadata: services.Metadata{ Name: claims.Username, Namespace: defaults.Namespace, + Expires: &expires, }, Spec: services.UserSpecV2{ - Roles: modules.GetModules().RolesFromLogins(logins), - Traits: modules.GetModules().TraitsFromLogins(logins, kubeGroups), - Expires: s.clock.Now().UTC().Add(defaults.OAuth2TTL), + Roles: modules.GetModules().RolesFromLogins(logins), + Traits: modules.GetModules().TraitsFromLogins(logins, kubeGroups), GithubIdentities: []services.ExternalIdentity{{ ConnectorID: connector.GetName(), Username: claims.Username, diff --git a/lib/auth/github_test.go b/lib/auth/github_test.go index dc48374e1da46..f71d01e0b0b31 100644 --- a/lib/auth/github_test.go +++ b/lib/auth/github_test.go @@ -17,18 +17,56 @@ limitations under the License. package auth import ( + "context" + "fmt" + "time" + + authority "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/backend/lite" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" - check "gopkg.in/check.v1" + "github.com/jonboulle/clockwork" + "gopkg.in/check.v1" ) -type GithubSuite struct{} +type GithubSuite struct { + a *AuthServer + b backend.Backend + c clockwork.FakeClock +} +var _ = fmt.Printf var _ = check.Suite(&GithubSuite{}) func (s *GithubSuite) SetUpSuite(c *check.C) { + var err error + utils.InitLoggerForTests() + + s.c = clockwork.NewFakeClockAt(time.Now()) + + s.b, err = lite.NewWithConfig(context.Background(), lite.Config{ + Path: c.MkDir(), + PollStreamPeriod: 200 * time.Millisecond, + Clock: s.c, + }) + c.Assert(err, check.IsNil) + + clusterName, err := services.NewClusterName(services.ClusterNameSpecV2{ + ClusterName: "me.localhost", + }) + c.Assert(err, check.IsNil) + + authConfig := &InitConfig{ + ClusterName: clusterName, + Backend: s.b, + Authority: authority.New(), + SkipPeriodicOperations: true, + } + s.a, err = NewAuthServer(authConfig) + c.Assert(err, check.IsNil) } func (s *GithubSuite) TestPopulateClaims(c *check.C) { @@ -43,6 +81,41 @@ func (s *GithubSuite) TestPopulateClaims(c *check.C) { }) } +func (s *GithubSuite) TestCreateGithubUser(c *check.C) { + connector := services.NewGithubConnector("github", services.GithubConnectorSpecV3{ + ClientID: "fakeClientID", + ClientSecret: "fakeClientSecret", + RedirectURL: "https://www.example.com", + TeamsToLogins: []services.TeamMapping{ + services.TeamMapping{ + Organization: "fakeOrg", + Team: "fakeTeam", + Logins: []string{"foo"}, + }, + }, + }) + + claims := services.GithubClaims{ + Username: "foo", + OrganizationToTeams: map[string][]string{ + "fakeOrg": []string{"fakeTeam"}, + }, + } + + // Create GitHub user with 1 minute expiry. + err := s.a.createGithubUser(connector, claims, s.c.Now().Add(1*time.Minute)) + c.Assert(err, check.IsNil) + + // Within that 1 minute period the user should still exist. + _, err = s.a.GetUser("foo") + c.Assert(err, check.IsNil) + + // Advance time 2 minutes, the user should be gone. + s.c.Advance(2 * time.Minute) + _, err = s.a.GetUser("foo") + c.Assert(err, check.NotNil) +} + type testGithubAPIClient struct{} func (c *testGithubAPIClient) getUser() (*userResponse, error) { diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index 37d24286ca7a7..5cfbf0b6174fd 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -53,6 +53,8 @@ type TestAuthServerConfig struct { AcceptedUsage []string // CipherSuites is the list of ciphers that the server supports. CipherSuites []uint16 + // Clock is used to control time in tests. + Clock clockwork.FakeClock } // CheckAndSetDefaults checks and sets defaults @@ -63,6 +65,9 @@ func (cfg *TestAuthServerConfig) CheckAndSetDefaults() error { if cfg.Dir == "" { return trace.BadParameter("missing parameter Dir") } + if cfg.Clock == nil { + cfg.Clock = clockwork.NewFakeClockAt(time.Now()) + } if len(cfg.CipherSuites) == 0 { cfg.CipherSuites = utils.DefaultCipherSuites() } @@ -107,7 +112,11 @@ func NewTestAuthServer(cfg TestAuthServerConfig) (*TestAuthServer, error) { TestAuthServerConfig: cfg, } var err error - srv.Backend, err = lite.NewWithConfig(context.TODO(), lite.Config{Path: cfg.Dir, PollStreamPeriod: 100 * time.Millisecond}) + srv.Backend, err = lite.NewWithConfig(context.Background(), lite.Config{ + Path: cfg.Dir, + PollStreamPeriod: 100 * time.Millisecond, + Clock: cfg.Clock, + }) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/oidc.go b/lib/auth/oidc.go index 7611012ee0e52..3d6b9cc97a472 100644 --- a/lib/auth/oidc.go +++ b/lib/auth/oidc.go @@ -340,11 +340,11 @@ func (a *AuthServer) createOIDCUser(connector services.OIDCConnector, ident *oid Metadata: services.Metadata{ Name: ident.Email, Namespace: defaults.Namespace, + Expires: &ident.ExpiresAt, }, Spec: services.UserSpecV2{ - Roles: roles, - Traits: traits, - Expires: ident.ExpiresAt, + Roles: roles, + Traits: traits, OIDCIdentities: []services.ExternalIdentity{ { ConnectorID: connector.GetName(), diff --git a/lib/auth/oidc_test.go b/lib/auth/oidc_test.go new file mode 100644 index 0000000000000..8c95fb8403af6 --- /dev/null +++ b/lib/auth/oidc_test.go @@ -0,0 +1,110 @@ +/* +Copyright 2019 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package auth + +import ( + "context" + "fmt" + "time" + + authority "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/backend/lite" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" + + "github.com/coreos/go-oidc/oidc" + "github.com/jonboulle/clockwork" + "gopkg.in/check.v1" +) + +type OIDCSuite struct { + a *AuthServer + b backend.Backend + c clockwork.FakeClock +} + +var _ = fmt.Printf +var _ = check.Suite(&OIDCSuite{}) + +func (s *OIDCSuite) SetUpSuite(c *check.C) { + var err error + + utils.InitLoggerForTests() + + s.c = clockwork.NewFakeClockAt(time.Now()) + + s.b, err = lite.NewWithConfig(context.Background(), lite.Config{ + Path: c.MkDir(), + PollStreamPeriod: 200 * time.Millisecond, + Clock: s.c, + }) + c.Assert(err, check.IsNil) + + clusterName, err := services.NewClusterName(services.ClusterNameSpecV2{ + ClusterName: "me.localhost", + }) + c.Assert(err, check.IsNil) + + authConfig := &InitConfig{ + ClusterName: clusterName, + Backend: s.b, + Authority: authority.New(), + SkipPeriodicOperations: true, + } + s.a, err = NewAuthServer(authConfig) + c.Assert(err, check.IsNil) +} + +func (s *OIDCSuite) TestCreateOIDCUser(c *check.C) { + connector := services.NewOIDCConnector("oidcService", services.OIDCConnectorSpecV2{ + IssuerURL: "https://www.example.com", + ClientID: "fakeClientID", + ClientSecret: "fakeClientSecret", + RedirectURL: "https://www.example.com/redirect", + Scope: []string{"profile", "email"}, + ClaimsToRoles: []services.ClaimMapping{ + services.ClaimMapping{ + Claim: "email", + Value: "foo@example.com", + Roles: []string{"admin"}, + }, + }, + }) + + ident := &oidc.Identity{ + Email: "foo@example.com", + ExpiresAt: s.c.Now().Add(1 * time.Minute), + } + + claims := map[string]interface{}{ + "email": "foo@example.com", + } + + // Create OIDC user with 1 minute expiry. + err := s.a.createOIDCUser(connector, ident, claims) + c.Assert(err, check.IsNil) + + // Within that 1 minute period the user should still exist. + _, err = s.a.GetUser("foo@example.com") + c.Assert(err, check.IsNil) + + // Advance time 2 minutes, the user should be gone. + s.c.Advance(2 * time.Minute) + _, err = s.a.GetUser("foo@example.com") + c.Assert(err, check.NotNil) +} diff --git a/lib/auth/password_test.go b/lib/auth/password_test.go index c0709f2a95896..ab53b859f336b 100644 --- a/lib/auth/password_test.go +++ b/lib/auth/password_test.go @@ -43,8 +43,8 @@ type PasswordSuite struct { a *AuthServer } -var _ = Suite(&PasswordSuite{}) var _ = fmt.Printf +var _ = Suite(&PasswordSuite{}) func (s *PasswordSuite) SetUpSuite(c *C) { utils.InitLoggerForTests() diff --git a/lib/auth/saml.go b/lib/auth/saml.go index 897a468c6790b..88d676967e5fc 100644 --- a/lib/auth/saml.go +++ b/lib/auth/saml.go @@ -1,3 +1,19 @@ +/* +Copyright 2019 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package auth import ( @@ -120,11 +136,11 @@ func (a *AuthServer) createSAMLUser(connector services.SAMLConnector, assertionI Metadata: services.Metadata{ Name: assertionInfo.NameID, Namespace: defaults.Namespace, + Expires: &expiresAt, }, Spec: services.UserSpecV2{ Roles: roles, Traits: traits, - Expires: expiresAt, SAMLIdentities: []services.ExternalIdentity{{ConnectorID: connector.GetName(), Username: assertionInfo.NameID}}, CreatedBy: services.CreatedBy{ User: services.UserRef{Name: "system"}, diff --git a/lib/auth/saml_test.go b/lib/auth/saml_test.go new file mode 100644 index 0000000000000..6134b92799ac2 --- /dev/null +++ b/lib/auth/saml_test.go @@ -0,0 +1,112 @@ +/* +Copyright 2019 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package auth + +import ( + "context" + "fmt" + "time" + + authority "github.com/gravitational/teleport/lib/auth/testauthority" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/backend/lite" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/utils" + + "github.com/jonboulle/clockwork" + saml2 "github.com/russellhaering/gosaml2" + "github.com/russellhaering/gosaml2/types" + "gopkg.in/check.v1" +) + +type SAMLSuite struct { + a *AuthServer + b backend.Backend + c clockwork.FakeClock +} + +var _ = fmt.Printf +var _ = check.Suite(&SAMLSuite{}) + +func (s *SAMLSuite) SetUpSuite(c *check.C) { + var err error + + utils.InitLoggerForTests() + + s.c = clockwork.NewFakeClockAt(time.Now()) + + s.b, err = lite.NewWithConfig(context.Background(), lite.Config{ + Path: c.MkDir(), + PollStreamPeriod: 200 * time.Millisecond, + Clock: s.c, + }) + c.Assert(err, check.IsNil) + + clusterName, err := services.NewClusterName(services.ClusterNameSpecV2{ + ClusterName: "me.localhost", + }) + c.Assert(err, check.IsNil) + + authConfig := &InitConfig{ + ClusterName: clusterName, + Backend: s.b, + Authority: authority.New(), + SkipPeriodicOperations: true, + } + s.a, err = NewAuthServer(authConfig) + c.Assert(err, check.IsNil) +} + +func (s *SAMLSuite) TestCreateSAMLUser(c *check.C) { + connector := services.NewSAMLConnector("samlService", services.SAMLConnectorSpecV2{ + AssertionConsumerService: "https://www.example.com", + AttributesToRoles: []services.AttributeMapping{ + services.AttributeMapping{ + Name: "groups", + Value: "everyone", + Roles: []string{"admin"}, + }, + }, + }) + + assertionInfo := saml2.AssertionInfo{ + NameID: "foo@example.com", + Values: map[string]types.Attribute{ + "groups": types.Attribute{ + Name: "groups", + Values: []types.AttributeValue{ + types.AttributeValue{ + Value: "everyone", + }, + }, + }, + }, + } + + // Create SAML user with 1 minute expiry. + err := s.a.createSAMLUser(connector, assertionInfo, s.c.Now().Add(1*time.Minute)) + c.Assert(err, check.IsNil) + + // Within that 1 minute period the user should still exist. + _, err = s.a.GetUser("foo@example.com") + c.Assert(err, check.IsNil) + + // Advance time 2 minutes, the user should be gone. + s.c.Advance(2 * time.Minute) + _, err = s.a.GetUser("foo@example.com") + c.Assert(err, check.NotNil) +} diff --git a/lib/services/local/services_test.go b/lib/services/local/services_test.go index f40464b856b66..9955310f8facf 100644 --- a/lib/services/local/services_test.go +++ b/lib/services/local/services_test.go @@ -18,6 +18,7 @@ package local import ( "context" + "fmt" "testing" "time" @@ -26,6 +27,7 @@ import ( "github.com/gravitational/teleport/lib/services/suite" "github.com/gravitational/teleport/lib/utils" + "github.com/jonboulle/clockwork" "gopkg.in/check.v1" ) @@ -36,6 +38,7 @@ type ServicesSuite struct { suite *suite.ServicesTestSuite } +var _ = fmt.Printf var _ = check.Suite(&ServicesSuite{}) func (s *ServicesSuite) SetUpSuite(c *check.C) { @@ -45,18 +48,25 @@ func (s *ServicesSuite) SetUpSuite(c *check.C) { func (s *ServicesSuite) SetUpTest(c *check.C) { var err error - s.bk, err = lite.NewWithConfig(context.TODO(), lite.Config{Path: c.MkDir(), PollStreamPeriod: 200 * time.Millisecond}) + clock := clockwork.NewFakeClockAt(time.Now()) + + s.bk, err = lite.NewWithConfig(context.TODO(), lite.Config{ + Path: c.MkDir(), + PollStreamPeriod: 200 * time.Millisecond, + Clock: clock, + }) c.Assert(err, check.IsNil) - suite := &suite.ServicesTestSuite{} - suite.CAS = NewCAService(s.bk) - suite.PresenceS = NewPresenceService(s.bk) - suite.ProvisioningS = NewProvisioningService(s.bk) - suite.WebS = NewIdentityService(s.bk) - suite.Access = NewAccessService(s.bk) - suite.EventsS = NewEventsService(s.bk) - suite.ChangesC = make(chan interface{}) - s.suite = suite + s.suite = &suite.ServicesTestSuite{ + CAS: NewCAService(s.bk), + PresenceS: NewPresenceService(s.bk), + ProvisioningS: NewProvisioningService(s.bk), + WebS: NewIdentityService(s.bk), + Access: NewAccessService(s.bk), + EventsS: NewEventsService(s.bk), + ChangesC: make(chan interface{}), + Clock: clock, + } } func (s *ServicesSuite) TearDownTest(c *check.C) { @@ -79,6 +89,10 @@ func (s *ServicesSuite) TestUsersCRUD(c *check.C) { s.suite.UsersCRUD(c) } +func (s *ServicesSuite) TestUsersExpiry(c *check.C) { + s.suite.UsersExpiry(c) +} + func (s *ServicesSuite) TestLoginAttempts(c *check.C) { s.suite.LoginAttempts(c) } diff --git a/lib/services/suite/suite.go b/lib/services/suite/suite.go index b2ab2854028e2..78be8527b1c6b 100644 --- a/lib/services/suite/suite.go +++ b/lib/services/suite/suite.go @@ -23,6 +23,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/base64" + "fmt" "sort" "time" @@ -41,6 +42,8 @@ import ( "gopkg.in/check.v1" ) +var _ = fmt.Printf + // NewTestCA returns new test authority with a test key as a public and // signing key func NewTestCA(caType services.CertAuthType, clusterName string, privateKeys ...[]byte) *services.CertAuthorityV2 { @@ -93,6 +96,7 @@ type ServicesTestSuite struct { ConfigS services.ClusterConfiguration EventsS services.Events ChangesC chan interface{} + Clock clockwork.FakeClock } func (s *ServicesTestSuite) collectChanges(c *check.C, expected int) []interface{} { @@ -184,6 +188,33 @@ func (s *ServicesTestSuite) UsersCRUD(c *check.C) { fixtures.ExpectBadParameter(c, err) } +func (s *ServicesTestSuite) UsersExpiry(c *check.C) { + expiresAt := s.Clock.Now().Add(1 * time.Minute) + + err := s.WebS.UpsertUser(&services.UserV2{ + Kind: services.KindUser, + Version: services.V2, + Metadata: services.Metadata{ + Name: "foo", + Namespace: defaults.Namespace, + Expires: &expiresAt, + }, + Spec: services.UserSpecV2{}, + }) + c.Assert(err, check.IsNil) + + // Make sure the user exists. + u, err := s.WebS.GetUser("foo") + c.Assert(err, check.IsNil) + c.Assert(u.GetName(), check.Equals, "foo") + + s.Clock.Advance(2 * time.Minute) + + // Make sure the user is now gone. + u, err = s.WebS.GetUser("foo") + c.Assert(err, check.NotNil) +} + func (s *ServicesTestSuite) LoginAttempts(c *check.C) { user := newUser("user1", []string{"admin", "user"}) c.Assert(s.WebS.UpsertUser(user), check.IsNil) @@ -560,8 +591,8 @@ func (s *ServicesTestSuite) SAMLCRUD(c *check.C) { Namespace: defaults.Namespace, }, Spec: services.SAMLConnectorSpecV2{ - Issuer: "http://example.com", - SSO: "https://example.com/saml/sso", + Issuer: "http://example.com", + SSO: "https://example.com/saml/sso", AssertionConsumerService: "https://localhost/acs", Audience: "https://localhost/aud", ServiceProviderIssuer: "https://localhost/iss", diff --git a/lib/services/user.go b/lib/services/user.go index d5e6204e5e079..9ac51bb9458d2 100644 --- a/lib/services/user.go +++ b/lib/services/user.go @@ -1,3 +1,19 @@ +/* +Copyright 2019 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + package services import ( @@ -402,12 +418,10 @@ func (u *UserV2) Equals(other User) bool { return true } -// Expiry returns expiry time for temporary users +// Expiry returns expiry time for temporary users. Prefer expires from +// metadata, if it does not exist, fall back to expires in spec. func (u *UserV2) Expiry() time.Time { - if u.Metadata.Expires == nil { - return time.Time{} - } - if !u.Metadata.Expires.IsZero() { + if u.Metadata.Expires != nil && !u.Metadata.Expires.IsZero() { return *u.Metadata.Expires } return u.Spec.Expires diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index cd79942812724..1dcfb3b32afd0 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -128,6 +128,7 @@ func (s *WebSuite) SetUpTest(c *C) { authServer, err := auth.NewTestAuthServer(auth.TestAuthServerConfig{ ClusterName: "localhost", Dir: c.MkDir(), + Clock: clockwork.NewFakeClockAt(time.Date(2017, 05, 10, 18, 53, 0, 0, time.UTC)), }) c.Assert(err, IsNil) s.server, err = authServer.NewTestTLSServer() From 9f8b1332c1ceadb518ff70d9f3ac9ad18c546a70 Mon Sep 17 00:00:00 2001 From: Russell Jones Date: Mon, 18 Feb 2019 13:18:54 -0800 Subject: [PATCH 2/3] Fixed TestStreamEvents. --- lib/auth/tls_test.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/lib/auth/tls_test.go b/lib/auth/tls_test.go index 4f4072438aa67..f5aba776ab285 100644 --- a/lib/auth/tls_test.go +++ b/lib/auth/tls_test.go @@ -1917,6 +1917,17 @@ waitLoop: log.Debugf("Skipping stale event %v, latest object version is %v", event.Resource.GetResourceID(), ca.GetResourceID()) continue waitLoop } + + eca, ok := event.Resource.(*services.CertAuthorityV2) + if !ok { + log.Debugf("Wrong resource type, expected *services.CertAuthorityV2 got %T.", event.Resource) + continue waitLoop + } + if eca.GetRotation().State != services.RotationStateInProgress { + log.Debugf("Skipping CA, wrong rotation state: %v.", eca.GetRotation().State) + continue waitLoop + } + fixtures.DeepCompare(c, ca, event.Resource) break waitLoop } From 581356ae38008a97ea5e1161c73f52f9e8909fa4 Mon Sep 17 00:00:00 2001 From: Russell Jones Date: Mon, 18 Feb 2019 17:42:26 -0800 Subject: [PATCH 3/3] Consolidate external identity user creation code. --- lib/auth/github.go | 188 +++++++++++++++++++++++++++++----------- lib/auth/github_test.go | 28 ++---- lib/auth/oidc.go | 133 +++++++++++++++------------- lib/auth/oidc_test.go | 33 ++----- lib/auth/saml.go | 134 +++++++++++++++------------- lib/auth/saml_test.go | 35 ++------ 6 files changed, 303 insertions(+), 248 deletions(-) diff --git a/lib/auth/github.go b/lib/auth/github.go index 5cbbc1d7572c6..02080f5ded612 100644 --- a/lib/auth/github.go +++ b/lib/auth/github.go @@ -150,56 +150,48 @@ func (s *AuthServer) validateGithubAuthCallback(q url.Values) (*GithubAuthRespon if err != nil { return nil, trace.Wrap(err) } - expires := s.clock.Now().UTC().Add(defaults.OAuth2TTL) - err = s.createGithubUser(connector, *claims, expires) + + // Calculate (figure out name, roles, traits, session TTL) of user and + // create the user in the backend. + params, err := s.calculateGithubUser(connector, claims, req) if err != nil { return nil, trace.Wrap(err) } + user, err := s.createGithubUser(params) + if err != nil { + return nil, trace.Wrap(err) + } + + // Auth was successful, return session, certificate, etc. to caller. response := &GithubAuthResponse{ + Req: *req, Identity: services.ExternalIdentity{ - ConnectorID: connector.GetName(), - Username: claims.Username, + ConnectorID: params.connectorName, + Username: params.username, }, - Req: *req, - } - user, err := s.Identity.GetUserByGithubIdentity(response.Identity) - if err != nil { - return nil, trace.Wrap(err) + Username: user.GetName(), } response.Username = user.GetName() - roles, err := services.FetchRoles(user.GetRoles(), s.Access, user.GetTraits()) - if err != nil { - return nil, trace.Wrap(err) - } + + // If the request is coming from a browser, create a web session. if req.CreateWebSession { - session, err := s.NewWebSession(user.GetName()) - if err != nil { - return nil, trace.Wrap(err) - } - sessionTTL := roles.AdjustSessionTTL(defaults.OAuth2TTL) - bearerTTL := utils.MinTTL(BearerTokenTTL, sessionTTL) - session.SetExpiryTime(s.clock.Now().UTC().Add(sessionTTL)) - session.SetBearerTokenExpiryTime(s.clock.Now().UTC().Add(bearerTTL)) - err = s.UpsertWebSession(user.GetName(), session) + session, err := s.createWebSession(user, params.sessionTTL) if err != nil { return nil, trace.Wrap(err) } + response.Session = session } + + // If a public key was provided, sign it and return a certificate. if len(req.PublicKey) != 0 { - certTTL := utils.MinTTL(defaults.OAuth2TTL, req.CertTTL) - certs, err := s.generateUserCert(certRequest{ - user: user, - roles: roles, - ttl: certTTL, - publicKey: req.PublicKey, - compatibility: req.Compatibility, - }) + sshCert, tlsCert, err := s.createSessionCert(user, params.sessionTTL, req.PublicKey, req.Compatibility) if err != nil { return nil, trace.Wrap(err) } - response.Cert = certs.ssh - response.TLSCert = certs.tls + + response.Cert = sshCert + response.TLSCert = tlsCert // Return the host CA for this cluster only. authority, err := s.GetCertAuthority(services.CertAuthID{ @@ -211,61 +203,157 @@ func (s *AuthServer) validateGithubAuthCallback(q url.Values) (*GithubAuthRespon } response.HostSigners = append(response.HostSigners, authority) } + return response, nil } -func (s *AuthServer) createGithubUser(connector services.GithubConnector, claims services.GithubClaims, expires time.Time) error { - logins, kubeGroups := connector.MapClaims(claims) - if len(logins) == 0 { - return trace.BadParameter( +func (s *AuthServer) createWebSession(user services.User, sessionTTL time.Duration) (services.WebSession, error) { + session, err := s.NewWebSession(user.GetName()) + if err != nil { + return nil, trace.Wrap(err) + } + + // Session expiry time is the same as the user expiry time. + session.SetExpiryTime(s.clock.Now().UTC().Add(sessionTTL)) + + // Bearer tokens expire quicker than the overall session time and need to be refreshed. + bearerTTL := utils.MinTTL(BearerTokenTTL, sessionTTL) + session.SetBearerTokenExpiryTime(s.clock.Now().UTC().Add(bearerTTL)) + + err = s.UpsertWebSession(user.GetName(), session) + if err != nil { + return nil, trace.Wrap(err) + } + + return session, nil +} + +func (s *AuthServer) createSessionCert(user services.User, sessionTTL time.Duration, publicKey []byte, compatibility string) ([]byte, []byte, error) { + roles, err := services.FetchRoles(user.GetRoles(), s.Access, user.GetTraits()) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + certs, err := s.generateUserCert(certRequest{ + user: user, + roles: roles, + ttl: sessionTTL, + publicKey: publicKey, + compatibility: compatibility, + }) + if err != nil { + return nil, nil, trace.Wrap(err) + } + + return certs.ssh, certs.tls, nil +} + +// createUserParams is a set of parameters used to create a user for an +// external identity provider. +type createUserParams struct { + // connectorName is the name of the connector for the identity provider. + connectorName string + + // username is the Teleport user name . + username string + + // logins is the list of *nix logins. + logins []string + + // kubeGroups is the list of Kubernetes this user belongs to. + kubeGroups []string + + // roles is the list of roles this user is assigned to. + roles []string + + // traits is the list of traits for this user. + traits map[string][]string + + // sessionTTL is how long this session will last. + sessionTTL time.Duration +} + +func (s *AuthServer) calculateGithubUser(connector services.GithubConnector, claims *services.GithubClaims, request *services.GithubAuthRequest) (*createUserParams, error) { + p := createUserParams{ + connectorName: connector.GetName(), + username: claims.Username, + } + + // Calculate logins, kubegroups, roles, and traits. + p.logins, p.kubeGroups = connector.MapClaims(*claims) + if len(p.logins) == 0 { + return nil, trace.BadParameter( "user %q does not belong to any teams configured in %q connector", claims.Username, connector.GetName()) } + p.roles = modules.GetModules().RolesFromLogins(p.logins) + p.traits = modules.GetModules().TraitsFromLogins(p.logins, p.kubeGroups) + + // Pick smaller for role: session TTL from role or requested TTL. + roles, err := services.FetchRoles(p.roles, s.Access, p.traits) + if err != nil { + return nil, trace.Wrap(err) + } + roleTTL := roles.AdjustSessionTTL(defaults.MaxCertDuration) + p.sessionTTL = utils.MinTTL(roleTTL, request.CertTTL) + + return &p, nil +} + +func (s *AuthServer) createGithubUser(p *createUserParams) (services.User, error) { + log.WithFields(logrus.Fields{trace.Component: "github"}).Debugf( "Generating dynamic identity %v/%v with logins: %v.", - connector.GetName(), claims.Username, logins) + p.connectorName, p.username, p.logins) + + expires := s.GetClock().Now().UTC().Add(p.sessionTTL) + user, err := services.GetUserMarshaler().GenerateUser(&services.UserV2{ Kind: services.KindUser, Version: services.V2, Metadata: services.Metadata{ - Name: claims.Username, + Name: p.username, Namespace: defaults.Namespace, Expires: &expires, }, Spec: services.UserSpecV2{ - Roles: modules.GetModules().RolesFromLogins(logins), - Traits: modules.GetModules().TraitsFromLogins(logins, kubeGroups), + Roles: p.roles, + Traits: p.traits, GithubIdentities: []services.ExternalIdentity{{ - ConnectorID: connector.GetName(), - Username: claims.Username, + ConnectorID: p.connectorName, + Username: p.username, }}, CreatedBy: services.CreatedBy{ User: services.UserRef{Name: "system"}, - Time: time.Now().UTC(), + Time: s.GetClock().Now().UTC(), Connector: &services.ConnectorRef{ Type: teleport.ConnectorGithub, - ID: connector.GetName(), - Identity: claims.Username, + ID: p.connectorName, + Identity: p.username, }, }, }, }) - existingUser, err := s.GetUser(claims.Username) + if err != nil { + return nil, trace.Wrap(err) + } + + existingUser, err := s.GetUser(p.username) if err != nil && !trace.IsNotFound(err) { - return trace.Wrap(err) + return nil, trace.Wrap(err) } if existingUser != nil { ref := user.GetCreatedBy().Connector if !ref.IsSameProvider(existingUser.GetCreatedBy().Connector) { - return trace.AlreadyExists("user %q already exists and is not Github user", + return nil, trace.AlreadyExists("user %q already exists and is not Github user", existingUser.GetName()) } } err = s.UpsertUser(user) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - return nil + return user, nil } // populateGithubClaims retrieves information about user and its team diff --git a/lib/auth/github_test.go b/lib/auth/github_test.go index f71d01e0b0b31..fc20ad8e0fce3 100644 --- a/lib/auth/github_test.go +++ b/lib/auth/github_test.go @@ -82,28 +82,14 @@ func (s *GithubSuite) TestPopulateClaims(c *check.C) { } func (s *GithubSuite) TestCreateGithubUser(c *check.C) { - connector := services.NewGithubConnector("github", services.GithubConnectorSpecV3{ - ClientID: "fakeClientID", - ClientSecret: "fakeClientSecret", - RedirectURL: "https://www.example.com", - TeamsToLogins: []services.TeamMapping{ - services.TeamMapping{ - Organization: "fakeOrg", - Team: "fakeTeam", - Logins: []string{"foo"}, - }, - }, - }) - - claims := services.GithubClaims{ - Username: "foo", - OrganizationToTeams: map[string][]string{ - "fakeOrg": []string{"fakeTeam"}, - }, - } - // Create GitHub user with 1 minute expiry. - err := s.a.createGithubUser(connector, claims, s.c.Now().Add(1*time.Minute)) + _, err := s.a.createGithubUser(&createUserParams{ + connectorName: "github", + username: "foo", + logins: []string{"foo"}, + roles: []string{"admin"}, + sessionTTL: 1 * time.Minute, + }) c.Assert(err, check.IsNil) // Within that 1 minute period the user should still exist. diff --git a/lib/auth/oidc.go b/lib/auth/oidc.go index 3d6b9cc97a472..e8c2578a6f936 100644 --- a/lib/auth/oidc.go +++ b/lib/auth/oidc.go @@ -22,7 +22,6 @@ import ( "io/ioutil" "net/http" "net/url" - "time" "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" @@ -202,66 +201,55 @@ func (a *AuthServer) validateOIDCAuthCallback(q url.Values) (*OIDCAuthResponse, } log.Debugf("OIDC user %q expires at: %v.", ident.Email, ident.ExpiresAt) - response := &OIDCAuthResponse{ - Identity: services.ExternalIdentity{ConnectorID: connector.GetName(), Username: ident.Email}, - Req: *req, + if len(connector.GetClaimsToRoles()) == 0 { + return nil, trace.BadParameter("no claims to roles mapping, check connector documentation") } - log.Debugf("Applying %v OIDC claims to roles mappings.", len(connector.GetClaimsToRoles())) - if len(connector.GetClaimsToRoles()) != 0 { - if err := a.createOIDCUser(connector, ident, claims); err != nil { - return nil, trace.Wrap(err) - } - } - if !req.CheckUser { - return response, nil - } - - user, err := a.Identity.GetUserByOIDCIdentity(services.ExternalIdentity{ - ConnectorID: req.ConnectorID, Username: ident.Email}) + // Calculate (figure out name, roles, traits, session TTL) of user and + // create the user in the backend. + params, err := a.calculateOIDCUser(connector, claims, ident, req) if err != nil { return nil, trace.Wrap(err) } - response.Username = user.GetName() - - var roles services.RoleSet - roles, err = services.FetchRoles(user.GetRoles(), a.Access, user.GetTraits()) + user, err := a.createOIDCUser(params) if err != nil { return nil, trace.Wrap(err) } - sessionTTL := roles.AdjustSessionTTL(utils.ToTTL(a.clock, ident.ExpiresAt)) - bearerTokenTTL := utils.MinTTL(BearerTokenTTL, sessionTTL) + // Auth was successful, return session, certificate, etc. to caller. + response := &OIDCAuthResponse{ + Req: *req, + Identity: services.ExternalIdentity{ + ConnectorID: params.connectorName, + Username: params.username, + }, + Username: user.GetName(), + } + + if !req.CheckUser { + return response, nil + } + + // If the request is coming from a browser, create a web session. if req.CreateWebSession { - sess, err := a.NewWebSession(user.GetName()) + session, err := a.createWebSession(user, params.sessionTTL) if err != nil { return nil, trace.Wrap(err) } - // session will expire based on identity TTL and allowed session TTL - sess.SetExpiryTime(a.clock.Now().UTC().Add(sessionTTL)) - // bearer token will expire based on the expected session renewal - sess.SetBearerTokenExpiryTime(a.clock.Now().UTC().Add(bearerTokenTTL)) - if err := a.UpsertWebSession(user.GetName(), sess); err != nil { - return nil, trace.Wrap(err) - } - response.Session = sess + + response.Session = session } + // If a public key was provided, sign it and return a certificate. if len(req.PublicKey) != 0 { - certTTL := utils.MinTTL(utils.ToTTL(a.clock, ident.ExpiresAt), req.CertTTL) - certs, err := a.generateUserCert(certRequest{ - user: user, - roles: roles, - ttl: certTTL, - publicKey: req.PublicKey, - compatibility: req.Compatibility, - }) + sshCert, tlsCert, err := a.createSessionCert(user, params.sessionTTL, req.PublicKey, req.Compatibility) if err != nil { return nil, trace.Wrap(err) } - response.Cert = certs.ssh - response.TLSCert = certs.tls + + response.Cert = sshCert + response.TLSCert = tlsCert // Return the host CA for this cluster only. authority, err := a.GetCertAuthority(services.CertAuthID{ @@ -273,6 +261,7 @@ func (a *AuthServer) validateOIDCAuthCallback(q url.Values) (*OIDCAuthResponse, } response.HostSigners = append(response.HostSigners, authority) } + return response, nil } @@ -325,52 +314,72 @@ func claimsToTraitMap(claims jose.Claims) map[string][]string { return traits } -func (a *AuthServer) createOIDCUser(connector services.OIDCConnector, ident *oidc.Identity, claims jose.Claims) error { - roles, err := a.buildOIDCRoles(connector, claims) +func (a *AuthServer) calculateOIDCUser(connector services.OIDCConnector, claims jose.Claims, ident *oidc.Identity, request *services.OIDCAuthRequest) (*createUserParams, error) { + var err error + + p := createUserParams{ + connectorName: connector.GetName(), + username: ident.Email, + } + + p.roles, err = a.buildOIDCRoles(connector, claims) + if err != nil { + return nil, trace.Wrap(err) + } + p.traits = claimsToTraitMap(claims) + + // Pick smaller for role: session TTL from role or requested TTL. + roles, err := services.FetchRoles(p.roles, a.Access, p.traits) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } + roleTTL := roles.AdjustSessionTTL(defaults.MaxCertDuration) + p.sessionTTL = utils.MinTTL(roleTTL, request.CertTTL) + + return &p, nil +} - traits := claimsToTraitMap(claims) +func (a *AuthServer) createOIDCUser(p *createUserParams) (services.User, error) { + expires := a.GetClock().Now().UTC().Add(p.sessionTTL) - log.Debugf("Generating dynamic OIDC identity %v/%v with roles: %v.", connector.GetName(), ident.Email, roles) + log.Debugf("Generating dynamic OIDC identity %v/%v with roles: %v.", p.connectorName, p.username, p.roles) user, err := services.GetUserMarshaler().GenerateUser(&services.UserV2{ Kind: services.KindUser, Version: services.V2, Metadata: services.Metadata{ - Name: ident.Email, + Name: p.username, Namespace: defaults.Namespace, - Expires: &ident.ExpiresAt, + Expires: &expires, }, Spec: services.UserSpecV2{ - Roles: roles, - Traits: traits, + Roles: p.roles, + Traits: p.traits, OIDCIdentities: []services.ExternalIdentity{ - { - ConnectorID: connector.GetName(), - Username: ident.Email, + services.ExternalIdentity{ + ConnectorID: p.connectorName, + Username: p.username, }, }, CreatedBy: services.CreatedBy{ User: services.UserRef{Name: "system"}, - Time: time.Now().UTC(), + Time: a.clock.Now().UTC(), Connector: &services.ConnectorRef{ Type: teleport.ConnectorOIDC, - ID: connector.GetName(), - Identity: ident.Email, + ID: p.connectorName, + Identity: p.username, }, }, }, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } // Get the user to check if it already exists or not. - existingUser, err := a.GetUser(ident.Email) + existingUser, err := a.GetUser(p.username) if err != nil { if !trace.IsNotFound(err) { - return trace.Wrap(err) + return nil, trace.Wrap(err) } } @@ -380,7 +389,7 @@ func (a *AuthServer) createOIDCUser(connector services.OIDCConnector, ident *oid // If the exisiting user is a local user, fail and advise how to fix the problem. if connectorRef == nil { - return trace.AlreadyExists("local user with name '%v' already exists. Either change "+ + return nil, trace.AlreadyExists("local user with name '%v' already exists. Either change "+ "email in OIDC identity or remove local user and try again.", existingUser.GetName()) } @@ -391,10 +400,10 @@ func (a *AuthServer) createOIDCUser(connector services.OIDCConnector, ident *oid // Upsert the new user creating or updating whatever is in the database. err = a.UpsertUser(user) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - return nil + return user, nil } // claimsFromIDToken extracts claims from the ID token. diff --git a/lib/auth/oidc_test.go b/lib/auth/oidc_test.go index 8c95fb8403af6..d853de3872275 100644 --- a/lib/auth/oidc_test.go +++ b/lib/auth/oidc_test.go @@ -27,7 +27,6 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" - "github.com/coreos/go-oidc/oidc" "github.com/jonboulle/clockwork" "gopkg.in/check.v1" ) @@ -71,32 +70,14 @@ func (s *OIDCSuite) SetUpSuite(c *check.C) { } func (s *OIDCSuite) TestCreateOIDCUser(c *check.C) { - connector := services.NewOIDCConnector("oidcService", services.OIDCConnectorSpecV2{ - IssuerURL: "https://www.example.com", - ClientID: "fakeClientID", - ClientSecret: "fakeClientSecret", - RedirectURL: "https://www.example.com/redirect", - Scope: []string{"profile", "email"}, - ClaimsToRoles: []services.ClaimMapping{ - services.ClaimMapping{ - Claim: "email", - Value: "foo@example.com", - Roles: []string{"admin"}, - }, - }, - }) - - ident := &oidc.Identity{ - Email: "foo@example.com", - ExpiresAt: s.c.Now().Add(1 * time.Minute), - } - - claims := map[string]interface{}{ - "email": "foo@example.com", - } - // Create OIDC user with 1 minute expiry. - err := s.a.createOIDCUser(connector, ident, claims) + _, err := s.a.createOIDCUser(&createUserParams{ + connectorName: "oidcService", + username: "foo@example.com", + logins: []string{"foo"}, + roles: []string{"admin"}, + sessionTTL: 1 * time.Minute, + }) c.Assert(err, check.IsNil) // Within that 1 minute period the user should still exist. diff --git a/lib/auth/saml.go b/lib/auth/saml.go index 88d676967e5fc..30f1175da35f1 100644 --- a/lib/auth/saml.go +++ b/lib/auth/saml.go @@ -21,7 +21,6 @@ import ( "compress/flate" "encoding/base64" "io/ioutil" - "time" "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" @@ -121,47 +120,75 @@ func assertionsToTraitMap(assertionInfo saml2.AssertionInfo) map[string][]string return traits } -func (a *AuthServer) createSAMLUser(connector services.SAMLConnector, assertionInfo saml2.AssertionInfo, expiresAt time.Time) error { - roles, err := a.buildSAMLRoles(connector, assertionInfo) +func (a *AuthServer) calculateSAMLUser(connector services.SAMLConnector, assertionInfo saml2.AssertionInfo, request *services.SAMLAuthRequest) (*createUserParams, error) { + var err error + + p := createUserParams{ + connectorName: connector.GetName(), + username: assertionInfo.NameID, + } + + p.roles, err = a.buildSAMLRoles(connector, assertionInfo) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } + p.traits = assertionsToTraitMap(assertionInfo) + + // Pick smaller for role: session TTL from role or requested TTL. + roles, err := services.FetchRoles(p.roles, a.Access, p.traits) + if err != nil { + return nil, trace.Wrap(err) + } + roleTTL := roles.AdjustSessionTTL(defaults.MaxCertDuration) + p.sessionTTL = utils.MinTTL(roleTTL, request.CertTTL) + + return &p, nil +} + +func (a *AuthServer) createSAMLUser(p *createUserParams) (services.User, error) { + expires := a.GetClock().Now().UTC().Add(p.sessionTTL) - traits := assertionsToTraitMap(assertionInfo) + log.Debugf("Generating dynamic SAML identity %v/%v with roles: %v.", p.connectorName, p.username, p.roles) - log.Debugf("Generating dynamic SAML identity %v/%v with roles: %v.", connector.GetName(), assertionInfo.NameID, roles) user, err := services.GetUserMarshaler().GenerateUser(&services.UserV2{ Kind: services.KindUser, Version: services.V2, Metadata: services.Metadata{ - Name: assertionInfo.NameID, + Name: p.username, Namespace: defaults.Namespace, - Expires: &expiresAt, + Expires: &expires, }, Spec: services.UserSpecV2{ - Roles: roles, - Traits: traits, - SAMLIdentities: []services.ExternalIdentity{{ConnectorID: connector.GetName(), Username: assertionInfo.NameID}}, + Roles: p.roles, + Traits: p.traits, + SAMLIdentities: []services.ExternalIdentity{ + services.ExternalIdentity{ + ConnectorID: p.connectorName, + Username: p.username, + }, + }, CreatedBy: services.CreatedBy{ - User: services.UserRef{Name: "system"}, - Time: time.Now().UTC(), + User: services.UserRef{ + Name: "system", + }, + Time: a.clock.Now().UTC(), Connector: &services.ConnectorRef{ Type: teleport.ConnectorSAML, - ID: connector.GetName(), - Identity: assertionInfo.NameID, + ID: p.connectorName, + Identity: p.username, }, }, }, }) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } // Get the user to check if it already exists or not. - existingUser, err := a.GetUser(assertionInfo.NameID) + existingUser, err := a.GetUser(p.username) if err != nil { if !trace.IsNotFound(err) { - return trace.Wrap(err) + return nil, trace.Wrap(err) } } @@ -171,7 +198,7 @@ func (a *AuthServer) createSAMLUser(connector services.SAMLConnector, assertionI // If the exisiting user is a local user, fail and advise how to fix the problem. if connectorRef == nil { - return trace.AlreadyExists("local user with name '%v' already exists. Either change "+ + return nil, trace.AlreadyExists("local user with name '%v' already exists. Either change "+ "NameID in assertion or remove local user and try again.", existingUser.GetName()) } @@ -182,10 +209,10 @@ func (a *AuthServer) createSAMLUser(connector services.SAMLConnector, assertionI // Upsert the new user creating or updating whatever is in the database. err = a.UpsertUser(user) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - return nil + return user, nil } func parseSAMLInResponseTo(response string) (string, error) { @@ -309,67 +336,51 @@ func (a *AuthServer) validateSAMLResponse(samlResponse string) (*SAMLAuthRespons } log.Debugf("SAML assertion warnings: %+v.", assertionInfo.WarningInfo) - log.Debugf("Applying %v SAML attribute to roles mappings.", len(connector.GetAttributesToRoles())) if len(connector.GetAttributesToRoles()) == 0 { - return nil, trace.BadParameter("SAML does not support binding to local users") - } - // TODO(klizhentas) use SessionNotOnOrAfter to calculate expiration time - expiresAt := a.clock.Now().Add(defaults.CertDuration) - if err := a.createSAMLUser(connector, *assertionInfo, expiresAt); err != nil { - return nil, trace.Wrap(err) + return nil, trace.BadParameter("no attributes to roles mapping, check connector documentation") } + log.Debugf("Applying %v SAML attribute to roles mappings.", len(connector.GetAttributesToRoles())) - identity := services.ExternalIdentity{ - ConnectorID: request.ConnectorID, - Username: assertionInfo.NameID, + // Calculate (figure out name, roles, traits, session TTL) of user and + // create the user in the backend. + params, err := a.calculateSAMLUser(connector, *assertionInfo, request) + if err != nil { + return nil, trace.Wrap(err) } - user, err := a.Identity.GetUserBySAMLIdentity(identity) + user, err := a.createSAMLUser(params) if err != nil { return nil, trace.Wrap(err) } + + // Auth was successful, return session, certificate, etc. to caller. response := &SAMLAuthResponse{ - Req: *request, - Identity: identity, + Req: *request, + Identity: services.ExternalIdentity{ + ConnectorID: params.connectorName, + Username: params.username, + }, Username: user.GetName(), } - var roles services.RoleSet - roles, err = services.FetchRoles(user.GetRoles(), a.Access, user.GetTraits()) - if err != nil { - return nil, trace.Wrap(err) - } - sessionTTL := roles.AdjustSessionTTL(utils.ToTTL(a.clock, expiresAt)) - bearerTokenTTL := utils.MinTTL(BearerTokenTTL, sessionTTL) - + // If the request is coming from a browser, create a web session. if request.CreateWebSession { - sess, err := a.NewWebSession(user.GetName()) + session, err := a.createWebSession(user, params.sessionTTL) if err != nil { return nil, trace.Wrap(err) } - // session will expire based on identity TTL and allowed session TTL - sess.SetExpiryTime(a.clock.Now().UTC().Add(sessionTTL)) - // bearer token will expire based on the expected session renewal - sess.SetBearerTokenExpiryTime(a.clock.Now().UTC().Add(bearerTokenTTL)) - if err := a.UpsertWebSession(user.GetName(), sess); err != nil { - return nil, trace.Wrap(err) - } - response.Session = sess + + response.Session = session } + // If a public key was provided, sign it and return a certificate. if len(request.PublicKey) != 0 { - certTTL := utils.MinTTL(sessionTTL, request.CertTTL) - certs, err := a.generateUserCert(certRequest{ - user: user, - roles: roles, - ttl: certTTL, - publicKey: request.PublicKey, - compatibility: request.Compatibility, - }) + sshCert, tlsCert, err := a.createSessionCert(user, params.sessionTTL, request.PublicKey, request.Compatibility) if err != nil { return nil, trace.Wrap(err) } - response.Cert = certs.ssh - response.TLSCert = certs.tls + + response.Cert = sshCert + response.TLSCert = tlsCert // Return the host CA for this cluster only. authority, err := a.GetCertAuthority(services.CertAuthID{ @@ -381,5 +392,6 @@ func (a *AuthServer) validateSAMLResponse(samlResponse string) (*SAMLAuthRespons } response.HostSigners = append(response.HostSigners, authority) } + return response, nil } diff --git a/lib/auth/saml_test.go b/lib/auth/saml_test.go index 6134b92799ac2..34192950254ab 100644 --- a/lib/auth/saml_test.go +++ b/lib/auth/saml_test.go @@ -28,8 +28,6 @@ import ( "github.com/gravitational/teleport/lib/utils" "github.com/jonboulle/clockwork" - saml2 "github.com/russellhaering/gosaml2" - "github.com/russellhaering/gosaml2/types" "gopkg.in/check.v1" ) @@ -72,33 +70,14 @@ func (s *SAMLSuite) SetUpSuite(c *check.C) { } func (s *SAMLSuite) TestCreateSAMLUser(c *check.C) { - connector := services.NewSAMLConnector("samlService", services.SAMLConnectorSpecV2{ - AssertionConsumerService: "https://www.example.com", - AttributesToRoles: []services.AttributeMapping{ - services.AttributeMapping{ - Name: "groups", - Value: "everyone", - Roles: []string{"admin"}, - }, - }, - }) - - assertionInfo := saml2.AssertionInfo{ - NameID: "foo@example.com", - Values: map[string]types.Attribute{ - "groups": types.Attribute{ - Name: "groups", - Values: []types.AttributeValue{ - types.AttributeValue{ - Value: "everyone", - }, - }, - }, - }, - } - // Create SAML user with 1 minute expiry. - err := s.a.createSAMLUser(connector, assertionInfo, s.c.Now().Add(1*time.Minute)) + _, err := s.a.createSAMLUser(&createUserParams{ + connectorName: "samlService", + username: "foo@example.com", + logins: []string{"foo"}, + roles: []string{"admin"}, + sessionTTL: 1 * time.Minute, + }) c.Assert(err, check.IsNil) // Within that 1 minute period the user should still exist.