From c9340e7960235e00af7931273dbc965f12ee2809 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Tue, 25 Feb 2025 12:36:10 +0100 Subject: [PATCH] fix: ensure that auto_link_credentials markers are being properly overwritten --- identity/handler_test.go | 11 +-- identity/identity.go | 15 ++- internal/client-go/go.sum | 1 + selfservice/strategy/oidc/strategy.go | 3 +- selfservice/strategy/oidc/strategy_login.go | 31 +++++- selfservice/strategy/oidc/strategy_test.go | 103 ++++++++++++++------ 6 files changed, 118 insertions(+), 46 deletions(-) diff --git a/identity/handler_test.go b/identity/handler_test.go index bb0a651d5c8f..3c0a12c55d11 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -820,17 +820,12 @@ func TestHandler(t *testing.T) { }` res := send(t, ts, "POST", "/identities", http.StatusCreated, json.RawMessage(payload)) - stateChangedAt := sqlxx.NullTime(res.Get("state_changed_at").Time()) - - i.Traits = []byte(res.Get("traits").Raw) i.ID = x.ParseUUID(res.Get("id").String()) - i.StateChangedAt = &stateChangedAt - assert.NotEmpty(t, res.Get("id").String()) - i, err := reg.Persister().GetIdentityConfidential(context.Background(), i.ID) - require.NoError(t, err) + identRes := send(t, adminTS, "GET", fmt.Sprintf("/identities/%s?include_credential=oidc", i.ID), http.StatusOK, nil) - require.True(t, gjson.GetBytes(i.Credentials[identity.CredentialsTypeOIDC].Config, "providers.0.use_auto_link").Bool()) + assert.True(t, identRes.Get("credentials.oidc.config.providers.0.use_auto_link").Bool()) + assert.False(t, identRes.Get("credentials.oidc.config.providers.0.organization").Exists()) }) } }) diff --git a/identity/identity.go b/identity/identity.go index d21cadb36ab3..433a7993d7e9 100644 --- a/identity/identity.go +++ b/identity/identity.go @@ -489,9 +489,18 @@ func (i *Identity) WithDeclassifiedCredentials(ctx context.Context, c cipher.Pro return false } - toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.organization", i), v.Get("organization").String()) - if err != nil { - return false + if org := v.Get("organization").String(); org != "" { + toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.organization", i), org) + if err != nil { + return false + } + } + + if useAutoLink := v.Get("use_auto_link").Bool(); useAutoLink { + toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.use_auto_link", i), useAutoLink) + if err != nil { + return false + } } i++ diff --git a/internal/client-go/go.sum b/internal/client-go/go.sum index c966c8ddfd0d..6cc3f5911d11 100644 --- a/internal/client-go/go.sum +++ b/internal/client-go/go.sum @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg= golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw= golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index f1285f800f68..a635f7046667 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -844,7 +844,8 @@ func (s *Strategy) linkCredentials(ctx context.Context, i *identity.Identity, to } else { creds.Identifiers = append(creds.Identifiers, identity.OIDCUniqueID(provider, subject)) conf.Providers = append(conf.Providers, identity.CredentialsOIDCProvider{ - Subject: subject, Provider: provider, + Subject: subject, + Provider: provider, InitialAccessToken: tokens.GetAccessToken(), InitialRefreshToken: tokens.GetRefreshToken(), InitialIDToken: tokens.GetIDToken(), diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 9dd7f8d4655c..1d97d1197e64 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -140,8 +140,35 @@ func (s *Strategy) handleConflictingIdentity(ctx context.Context, w http.Respons verdict = s.conflictingIdentityPolicy(ctx, existingIdentity, newIdentity, provider, claims) if verdict == ConflictingIdentityVerdictMerge { - existingIdentity.SetCredentials(s.ID(), *creds) - if err := s.d.PrivilegedIdentityPool().UpdateIdentity(ctx, existingIdentity); err != nil { + var conf identity.CredentialsOIDC + if err = json.Unmarshal(existingIdentity.Credentials[s.ID()].Config, &conf); err != nil { + return ConflictingIdentityVerdictUnknown, nil, nil, s.HandleError(ctx, w, r, loginFlow, provider.Config().ID, newIdentity.Traits, err) + } + // If there exists a provider in the existing identity for the same provider, we + // need to merge the providers, otherwise we just add the new provider. + var providerWasUpdated bool + newProvider := identity.CredentialsOIDCProvider{ + Subject: claims.Subject, + Provider: provider.Config().ID, + InitialIDToken: token.GetIDToken(), + InitialAccessToken: token.GetAccessToken(), + InitialRefreshToken: token.GetRefreshToken(), + Organization: provider.Config().OrganizationID, + } + for i, p := range conf.Providers { + if p.Provider == newProvider.Provider { + conf.Providers[i] = newProvider + providerWasUpdated = true + break + } + } + if !providerWasUpdated { + conf.Providers = append(conf.Providers, newProvider) + } + if err = existingIdentity.SetCredentialsWithConfig(s.ID(), existingIdentity.Credentials[s.ID()], conf); err != nil { + return ConflictingIdentityVerdictUnknown, nil, nil, s.HandleError(ctx, w, r, loginFlow, provider.Config().ID, newIdentity.Traits, err) + } + if err = s.d.PrivilegedIdentityPool().UpdateIdentity(ctx, existingIdentity); err != nil { return ConflictingIdentityVerdictUnknown, nil, nil, s.HandleError(ctx, w, r, loginFlow, provider.Config().ID, newIdentity.Traits, err) } } diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index d8b194f98480..317425c16581 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -83,24 +83,6 @@ func TestStrategy(t *testing.T) { ts, _ := testhelpers.NewKratosServerWithRouters(t, reg, routerP, routerA) invalid := newOIDCProvider(t, ts, remotePublic, remoteAdmin, "invalid-issuer") - //onConflictingIdentityPolicy := func(existingIdentity, newIdentity *identity.Identity) oidc.ConflictingIdentityVerdict { - // return oidc.ConflictingIdentityVerdictReject - //} - //oidcStrategy := oidc.NewStrategy(reg, oidc.WithOnConflictingIdentity(onConflictingIdentityPolicy)) - // - //reg = reg.WithSelfserviceStrategies(t, []any{ - // password.NewStrategy(reg), - // oidcStrategy, - // profile.NewStrategy(reg), - // code.NewStrategy(reg), - // link.NewStrategy(reg), - // totp.NewStrategy(reg), - // passkey.NewStrategy(reg), - // webauthn.NewStrategy(reg), - // lookup.NewStrategy(reg), - // idfirst.NewStrategy(reg), - //}).(*driver.RegistryDefault) - orgID := uuidx.NewV4() viperSetProviderConfig( t, @@ -1707,36 +1689,93 @@ func TestStrategy(t *testing.T) { }) }) - t.Run("case=should automatically link credential if policy says so", func(t *testing.T) { - subject = "user-in-org@ory.sh" - scope = []string{"openid"} + t.Run("suite=auto link policy", func(t *testing.T) { + + t.Run("case=should automatically link credential if policy says so", func(t *testing.T) { + subject = "user-in-org@ory.sh" + scope = []string{"openid"} - reg.AllLoginStrategies().MustStrategy("oidc").(*oidc.Strategy).SetOnConflictingIdentity(t, - func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, _ oidc.Provider, _ *oidc.Claims) oidc.ConflictingIdentityVerdict { - return oidc.ConflictingIdentityVerdictMerge + reg.AllLoginStrategies().MustStrategy("oidc").(*oidc.Strategy).SetOnConflictingIdentity(t, + func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, _ oidc.Provider, _ *oidc.Claims) oidc.ConflictingIdentityVerdict { + return oidc.ConflictingIdentityVerdictMerge + }) + + var i *identity.Identity + t.Run("step=create identity in org without credentials", func(t *testing.T) { + i = identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) + i.Traits = identity.Traits(`{"subject":"` + subject + `"}`) + i.SetCredentials(identity.CredentialsTypePassword, identity.Credentials{ + Type: identity.CredentialsTypePassword, + Identifiers: []string{subject}, + Config: sqlxx.JSONRawMessage(`{}`), + }) + i.OrganizationID = uuid.NullUUID{orgID, true} + i.VerifiableAddresses = []identity.VerifiableAddress{{Value: subject, Via: "email", Verified: true}} + require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(ctx, i)) }) - var i *identity.Identity - t.Run("step=create identity in org without credentials", func(t *testing.T) { + t.Run("step=log in with OIDC", func(t *testing.T) { + loginFlow := newLoginFlow(t, returnTS.URL, time.Minute, flow.TypeBrowser) + loginFlow.OrganizationID = i.OrganizationID + require.NoError(t, reg.LoginFlowPersister().UpdateLoginFlow(ctx, loginFlow)) + client := testhelpers.NewClientWithCookieJar(t, nil, nil) + + res, body := loginWithOIDC(t, client, loginFlow.ID, "valid") + checkCredentialsLinked(res, body, i.ID, "valid") + }) + }) + + t.Run("case=should remove use_auto_link credential if policy says so", func(t *testing.T) { + subject = "user-with-use-auto-link@ory.sh" + scope = []string{"openid"} + + reg.AllLoginStrategies().MustStrategy("oidc").(*oidc.Strategy).SetOnConflictingIdentity(t, + func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, _ oidc.Provider, _ *oidc.Claims) oidc.ConflictingIdentityVerdict { + return oidc.ConflictingIdentityVerdictMerge + }) + + var i *identity.Identity + //t.Run("step=create identity with use_auto_link", func(t *testing.T) { i = identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) i.Traits = identity.Traits(`{"subject":"` + subject + `"}`) i.SetCredentials(identity.CredentialsTypePassword, identity.Credentials{ - Type: identity.CredentialsTypePassword, Identifiers: []string{subject}, - Config: sqlxx.JSONRawMessage(`{}`), + Type: identity.CredentialsTypePassword, + Identifiers: []string{subject}, + Config: sqlxx.JSONRawMessage(`{}`), + }) + i.SetCredentials(identity.CredentialsTypeOIDC, identity.Credentials{ + Type: identity.CredentialsTypeOIDC, + Identifiers: []string{subject}, + Config: sqlxx.JSONRawMessage(`{"providers": [{ + "subject": "", + "provider": "valid", + "use_auto_link": true +},{ + "subject": "", + "provider": "other", + "use_auto_link": true +}]}`), }) - i.OrganizationID = uuid.NullUUID{orgID, true} i.VerifiableAddresses = []identity.VerifiableAddress{{Value: subject, Via: "email", Verified: true}} require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(ctx, i)) - }) + //}) - t.Run("step=log in with OIDC", func(t *testing.T) { + //t.Run("step=log in with OIDC", func(t *testing.T) { loginFlow := newLoginFlow(t, returnTS.URL, time.Minute, flow.TypeBrowser) - loginFlow.OrganizationID = i.OrganizationID require.NoError(t, reg.LoginFlowPersister().UpdateLoginFlow(ctx, loginFlow)) client := testhelpers.NewClientWithCookieJar(t, nil, nil) res, body := loginWithOIDC(t, client, loginFlow.ID, "valid") checkCredentialsLinked(res, body, i.ID, "valid") + //}) + + //t.Run("step=should remove use_auto_link", func(t *testing.T) { + var err error + i, err = reg.PrivilegedIdentityPool().GetIdentityConfidential(ctx, i.ID) + require.NoError(t, err) + assert.False(t, gjson.GetBytes(i.Credentials["oidc"].Config, "providers.0.use_auto_link").Bool()) + assert.True(t, gjson.GetBytes(i.Credentials["oidc"].Config, "providers.1.use_auto_link").Bool()) + //}) }) })