Skip to content

Commit

Permalink
fix: ensure that auto_link_credentials markers are being properly ove…
Browse files Browse the repository at this point in the history
…rwritten
  • Loading branch information
hperl committed Feb 25, 2025
1 parent 6eeeaa8 commit c9340e7
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 46 deletions.
11 changes: 3 additions & 8 deletions identity/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -820,17 +820,12 @@ func TestHandler(t *testing.T) {
}`

res := send(t, ts, "POST", "/identities", http.StatusCreated, json.RawMessage(payload))
stateChangedAt := sqlxx.NullTime(res.Get("state_changed_at").Time())

i.Traits = []byte(res.Get("traits").Raw)
i.ID = x.ParseUUID(res.Get("id").String())
i.StateChangedAt = &stateChangedAt
assert.NotEmpty(t, res.Get("id").String())

i, err := reg.Persister().GetIdentityConfidential(context.Background(), i.ID)
require.NoError(t, err)
identRes := send(t, adminTS, "GET", fmt.Sprintf("/identities/%s?include_credential=oidc", i.ID), http.StatusOK, nil)

require.True(t, gjson.GetBytes(i.Credentials[identity.CredentialsTypeOIDC].Config, "providers.0.use_auto_link").Bool())
assert.True(t, identRes.Get("credentials.oidc.config.providers.0.use_auto_link").Bool())
assert.False(t, identRes.Get("credentials.oidc.config.providers.0.organization").Exists())
})
}
})
Expand Down
15 changes: 12 additions & 3 deletions identity/identity.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,9 +489,18 @@ func (i *Identity) WithDeclassifiedCredentials(ctx context.Context, c cipher.Pro
return false
}

toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.organization", i), v.Get("organization").String())
if err != nil {
return false
if org := v.Get("organization").String(); org != "" {
toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.organization", i), org)
if err != nil {
return false
}
}

if useAutoLink := v.Get("use_auto_link").Bool(); useAutoLink {
toPublish.Config, err = sjson.SetBytes(toPublish.Config, fmt.Sprintf("providers.%d.use_auto_link", i), useAutoLink)
if err != nil {
return false
}
}

i++
Expand Down
1 change: 1 addition & 0 deletions internal/client-go/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e h1:bRhVy7zSSasaqNksaRZiA5EEI+Ei4I1nO5Jh72wfHlg=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4 h1:YUO/7uOKsKeq9UokNS62b8FYywz3ker1l1vDZRCRefw=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
Expand Down
3 changes: 2 additions & 1 deletion selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,8 @@ func (s *Strategy) linkCredentials(ctx context.Context, i *identity.Identity, to
} else {
creds.Identifiers = append(creds.Identifiers, identity.OIDCUniqueID(provider, subject))
conf.Providers = append(conf.Providers, identity.CredentialsOIDCProvider{
Subject: subject, Provider: provider,
Subject: subject,
Provider: provider,
InitialAccessToken: tokens.GetAccessToken(),
InitialRefreshToken: tokens.GetRefreshToken(),
InitialIDToken: tokens.GetIDToken(),
Expand Down
31 changes: 29 additions & 2 deletions selfservice/strategy/oidc/strategy_login.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,35 @@ func (s *Strategy) handleConflictingIdentity(ctx context.Context, w http.Respons

verdict = s.conflictingIdentityPolicy(ctx, existingIdentity, newIdentity, provider, claims)
if verdict == ConflictingIdentityVerdictMerge {
existingIdentity.SetCredentials(s.ID(), *creds)
if err := s.d.PrivilegedIdentityPool().UpdateIdentity(ctx, existingIdentity); err != nil {
var conf identity.CredentialsOIDC
if err = json.Unmarshal(existingIdentity.Credentials[s.ID()].Config, &conf); err != nil {
return ConflictingIdentityVerdictUnknown, nil, nil, s.HandleError(ctx, w, r, loginFlow, provider.Config().ID, newIdentity.Traits, err)
}
// If there exists a provider in the existing identity for the same provider, we
// need to merge the providers, otherwise we just add the new provider.
var providerWasUpdated bool
newProvider := identity.CredentialsOIDCProvider{
Subject: claims.Subject,
Provider: provider.Config().ID,
InitialIDToken: token.GetIDToken(),
InitialAccessToken: token.GetAccessToken(),
InitialRefreshToken: token.GetRefreshToken(),
Organization: provider.Config().OrganizationID,
}
for i, p := range conf.Providers {
if p.Provider == newProvider.Provider {
conf.Providers[i] = newProvider
providerWasUpdated = true
break
}
}
if !providerWasUpdated {
conf.Providers = append(conf.Providers, newProvider)
}
if err = existingIdentity.SetCredentialsWithConfig(s.ID(), existingIdentity.Credentials[s.ID()], conf); err != nil {
return ConflictingIdentityVerdictUnknown, nil, nil, s.HandleError(ctx, w, r, loginFlow, provider.Config().ID, newIdentity.Traits, err)
}
if err = s.d.PrivilegedIdentityPool().UpdateIdentity(ctx, existingIdentity); err != nil {
return ConflictingIdentityVerdictUnknown, nil, nil, s.HandleError(ctx, w, r, loginFlow, provider.Config().ID, newIdentity.Traits, err)
}
}
Expand Down
103 changes: 71 additions & 32 deletions selfservice/strategy/oidc/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,24 +83,6 @@ func TestStrategy(t *testing.T) {
ts, _ := testhelpers.NewKratosServerWithRouters(t, reg, routerP, routerA)
invalid := newOIDCProvider(t, ts, remotePublic, remoteAdmin, "invalid-issuer")

//onConflictingIdentityPolicy := func(existingIdentity, newIdentity *identity.Identity) oidc.ConflictingIdentityVerdict {
// return oidc.ConflictingIdentityVerdictReject
//}
//oidcStrategy := oidc.NewStrategy(reg, oidc.WithOnConflictingIdentity(onConflictingIdentityPolicy))
//
//reg = reg.WithSelfserviceStrategies(t, []any{
// password.NewStrategy(reg),
// oidcStrategy,
// profile.NewStrategy(reg),
// code.NewStrategy(reg),
// link.NewStrategy(reg),
// totp.NewStrategy(reg),
// passkey.NewStrategy(reg),
// webauthn.NewStrategy(reg),
// lookup.NewStrategy(reg),
// idfirst.NewStrategy(reg),
//}).(*driver.RegistryDefault)

orgID := uuidx.NewV4()
viperSetProviderConfig(
t,
Expand Down Expand Up @@ -1707,36 +1689,93 @@ func TestStrategy(t *testing.T) {
})
})

t.Run("case=should automatically link credential if policy says so", func(t *testing.T) {
subject = "[email protected]"
scope = []string{"openid"}
t.Run("suite=auto link policy", func(t *testing.T) {

t.Run("case=should automatically link credential if policy says so", func(t *testing.T) {
subject = "[email protected]"
scope = []string{"openid"}

reg.AllLoginStrategies().MustStrategy("oidc").(*oidc.Strategy).SetOnConflictingIdentity(t,
func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, _ oidc.Provider, _ *oidc.Claims) oidc.ConflictingIdentityVerdict {
return oidc.ConflictingIdentityVerdictMerge
reg.AllLoginStrategies().MustStrategy("oidc").(*oidc.Strategy).SetOnConflictingIdentity(t,
func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, _ oidc.Provider, _ *oidc.Claims) oidc.ConflictingIdentityVerdict {
return oidc.ConflictingIdentityVerdictMerge
})

var i *identity.Identity
t.Run("step=create identity in org without credentials", func(t *testing.T) {
i = identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
i.Traits = identity.Traits(`{"subject":"` + subject + `"}`)
i.SetCredentials(identity.CredentialsTypePassword, identity.Credentials{
Type: identity.CredentialsTypePassword,
Identifiers: []string{subject},
Config: sqlxx.JSONRawMessage(`{}`),
})
i.OrganizationID = uuid.NullUUID{orgID, true}
i.VerifiableAddresses = []identity.VerifiableAddress{{Value: subject, Via: "email", Verified: true}}
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(ctx, i))
})

var i *identity.Identity
t.Run("step=create identity in org without credentials", func(t *testing.T) {
t.Run("step=log in with OIDC", func(t *testing.T) {
loginFlow := newLoginFlow(t, returnTS.URL, time.Minute, flow.TypeBrowser)
loginFlow.OrganizationID = i.OrganizationID
require.NoError(t, reg.LoginFlowPersister().UpdateLoginFlow(ctx, loginFlow))
client := testhelpers.NewClientWithCookieJar(t, nil, nil)

res, body := loginWithOIDC(t, client, loginFlow.ID, "valid")
checkCredentialsLinked(res, body, i.ID, "valid")
})
})

t.Run("case=should remove use_auto_link credential if policy says so", func(t *testing.T) {
subject = "[email protected]"
scope = []string{"openid"}

reg.AllLoginStrategies().MustStrategy("oidc").(*oidc.Strategy).SetOnConflictingIdentity(t,
func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, _ oidc.Provider, _ *oidc.Claims) oidc.ConflictingIdentityVerdict {
return oidc.ConflictingIdentityVerdictMerge
})

var i *identity.Identity
//t.Run("step=create identity with use_auto_link", func(t *testing.T) {
i = identity.NewIdentity(config.DefaultIdentityTraitsSchemaID)
i.Traits = identity.Traits(`{"subject":"` + subject + `"}`)
i.SetCredentials(identity.CredentialsTypePassword, identity.Credentials{
Type: identity.CredentialsTypePassword, Identifiers: []string{subject},
Config: sqlxx.JSONRawMessage(`{}`),
Type: identity.CredentialsTypePassword,
Identifiers: []string{subject},
Config: sqlxx.JSONRawMessage(`{}`),
})
i.SetCredentials(identity.CredentialsTypeOIDC, identity.Credentials{
Type: identity.CredentialsTypeOIDC,
Identifiers: []string{subject},
Config: sqlxx.JSONRawMessage(`{"providers": [{
"subject": "",
"provider": "valid",
"use_auto_link": true
},{
"subject": "",
"provider": "other",
"use_auto_link": true
}]}`),
})
i.OrganizationID = uuid.NullUUID{orgID, true}
i.VerifiableAddresses = []identity.VerifiableAddress{{Value: subject, Via: "email", Verified: true}}
require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentity(ctx, i))
})
//})

t.Run("step=log in with OIDC", func(t *testing.T) {
//t.Run("step=log in with OIDC", func(t *testing.T) {
loginFlow := newLoginFlow(t, returnTS.URL, time.Minute, flow.TypeBrowser)
loginFlow.OrganizationID = i.OrganizationID
require.NoError(t, reg.LoginFlowPersister().UpdateLoginFlow(ctx, loginFlow))
client := testhelpers.NewClientWithCookieJar(t, nil, nil)

res, body := loginWithOIDC(t, client, loginFlow.ID, "valid")
checkCredentialsLinked(res, body, i.ID, "valid")
//})

//t.Run("step=should remove use_auto_link", func(t *testing.T) {
var err error
i, err = reg.PrivilegedIdentityPool().GetIdentityConfidential(ctx, i.ID)
require.NoError(t, err)
assert.False(t, gjson.GetBytes(i.Credentials["oidc"].Config, "providers.0.use_auto_link").Bool())
assert.True(t, gjson.GetBytes(i.Credentials["oidc"].Config, "providers.1.use_auto_link").Bool())
//})
})
})

Expand Down

0 comments on commit c9340e7

Please sign in to comment.