From 9a9beaf4718b59ad9701fc8faf369256dca6c2c0 Mon Sep 17 00:00:00 2001 From: Henning Perl Date: Thu, 20 Feb 2025 12:21:27 +0100 Subject: [PATCH] feat: add context param to policy --- selfservice/strategy/oidc/strategy.go | 9 +++++---- selfservice/strategy/oidc/strategy_login.go | 6 +++--- selfservice/strategy/oidc/strategy_test.go | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/selfservice/strategy/oidc/strategy.go b/selfservice/strategy/oidc/strategy.go index a0bb22eb196f..732b3264cad0 100644 --- a/selfservice/strategy/oidc/strategy.go +++ b/selfservice/strategy/oidc/strategy.go @@ -143,8 +143,9 @@ type Strategy struct { handleUnknownProviderError func(err error) error handleMethodNotAllowedError func(err error) error - conflictingIdentityPolicy func(existingIdentity, newIdentity *identity.Identity, provider Provider, claims *Claims) ConflictingIdentityVerdict + conflictingIdentityPolicy ConflictingIdentityPolicy } +type ConflictingIdentityPolicy func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, provider Provider, claims *Claims) ConflictingIdentityVerdict type AuthCodeContainer struct { FlowID string `json:"flow_id"` @@ -246,14 +247,14 @@ func WithHandleMethodNotAllowedError(handler func(error) error) NewStrategyOpt { // WithOnConflictingIdentity sets a policy handler for deciding what to do when a // new identity conflicts with an existing one during login. -func WithOnConflictingIdentity(handler func(existingIdentity, newIdentity *identity.Identity, provider Provider, claims *Claims) ConflictingIdentityVerdict) NewStrategyOpt { +func WithOnConflictingIdentity(handler ConflictingIdentityPolicy) NewStrategyOpt { return func(s *Strategy) { s.conflictingIdentityPolicy = handler } } // SetOnConflictingIdentity sets a policy handler for deciding what to do when a // new identity conflicts with an existing one during login. This should only be // called in tests. -func (s *Strategy) SetOnConflictingIdentity(t testing.TB, handler func(existingIdentity, newIdentity *identity.Identity, provider Provider, claims *Claims) ConflictingIdentityVerdict) { +func (s *Strategy) SetOnConflictingIdentity(t testing.TB, handler ConflictingIdentityPolicy) { if t == nil { panic("this should only be called in tests") } @@ -774,7 +775,7 @@ func (s *Strategy) NodeGroup() node.UiNodeGroup { return node.OpenIDConnectGroup } -func (s *Strategy) CompletedAuthenticationMethod(ctx context.Context) session.AuthenticationMethod { +func (s *Strategy) CompletedAuthenticationMethod(context.Context) session.AuthenticationMethod { return session.AuthenticationMethod{ Method: s.ID(), AAL: identity.AuthenticatorAssuranceLevel1, diff --git a/selfservice/strategy/oidc/strategy_login.go b/selfservice/strategy/oidc/strategy_login.go index 5ff14cfda524..9dd7f8d4655c 100644 --- a/selfservice/strategy/oidc/strategy_login.go +++ b/selfservice/strategy/oidc/strategy_login.go @@ -138,7 +138,7 @@ func (s *Strategy) handleConflictingIdentity(ctx context.Context, w http.Respons return ConflictingIdentityVerdictReject, nil, nil, nil } - verdict = s.conflictingIdentityPolicy(existingIdentity, newIdentity, provider, claims) + verdict = s.conflictingIdentityPolicy(ctx, existingIdentity, newIdentity, provider, claims) if verdict == ConflictingIdentityVerdictMerge { existingIdentity.SetCredentials(s.ID(), *creds) if err := s.d.PrivilegedIdentityPool().UpdateIdentity(ctx, existingIdentity); err != nil { @@ -392,11 +392,11 @@ func (s *Strategy) PopulateLoginMethodFirstFactor(r *http.Request, f *login.Flow return s.populateMethod(r, f, text.NewInfoLoginWith) } -func (s *Strategy) PopulateLoginMethodSecondFactor(r *http.Request, sr *login.Flow) error { +func (s *Strategy) PopulateLoginMethodSecondFactor(*http.Request, *login.Flow) error { return nil } -func (s *Strategy) PopulateLoginMethodSecondFactorRefresh(r *http.Request, sr *login.Flow) error { +func (s *Strategy) PopulateLoginMethodSecondFactorRefresh(*http.Request, *login.Flow) error { return nil } diff --git a/selfservice/strategy/oidc/strategy_test.go b/selfservice/strategy/oidc/strategy_test.go index 032c5c34a491..4abf3c47fd94 100644 --- a/selfservice/strategy/oidc/strategy_test.go +++ b/selfservice/strategy/oidc/strategy_test.go @@ -1712,7 +1712,7 @@ func TestStrategy(t *testing.T) { scope = []string{"openid"} reg.AllLoginStrategies().MustStrategy("oidc").(*oidc.Strategy).SetOnConflictingIdentity(t, - func(existingIdentity, newIdentity *identity.Identity, _ oidc.Provider, _ *oidc.Claims) oidc.ConflictingIdentityVerdict { + func(ctx context.Context, existingIdentity, newIdentity *identity.Identity, _ oidc.Provider, _ *oidc.Claims) oidc.ConflictingIdentityVerdict { return oidc.ConflictingIdentityVerdictMerge })