Skip to content

Commit

Permalink
fix: return return_to code if already authenticated
Browse files Browse the repository at this point in the history
  • Loading branch information
hperl committed Feb 2, 2025
1 parent 11705a5 commit aa4fce4
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 21 deletions.
19 changes: 11 additions & 8 deletions selfservice/strategy/oidc/strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -365,21 +365,24 @@ func (s *Strategy) alreadyAuthenticated(ctx context.Context, w http.ResponseWrit
if _, ok := f.(*settings.Flow); ok {
// ignore this if it's a settings flow
} else if !isForced(f) {
if flowID, ok := registrationOrLoginFlowID(f); ok {
if _, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, flowID); hasCode {
err := s.d.SessionTokenExchangePersister().UpdateSessionOnExchanger(ctx, flowID, sess.ID)
if err != nil {
return false, err
}
}
}
returnTo := s.d.Config().SelfServiceBrowserDefaultReturnTo(ctx)
if redirecter, ok := f.(flow.FlowWithRedirect); ok {
r, err := x.SecureRedirectTo(r, returnTo, redirecter.SecureRedirectToOpts(ctx, s.d)...)
if err == nil {
returnTo = r
}
}
if flowID, ok := registrationOrLoginFlowID(f); ok {
if codes, hasCode, _ := s.d.SessionTokenExchangePersister().CodeForFlow(ctx, flowID); hasCode {
err := s.d.SessionTokenExchangePersister().UpdateSessionOnExchanger(ctx, flowID, sess.ID)
if err != nil {
return false, err

Check warning on line 379 in selfservice/strategy/oidc/strategy.go

View check run for this annotation

Codecov / codecov/patch

selfservice/strategy/oidc/strategy.go#L379

Added line #L379 was not covered by tests
}
q := returnTo.Query()
q.Set("code", codes.ReturnToCode)
returnTo.RawQuery = q.Encode()
}
}
http.Redirect(w, r, returnTo.String(), http.StatusSeeOther)
return true, nil
}
Expand Down
46 changes: 33 additions & 13 deletions selfservice/strategy/oidc/strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,8 @@ func TestStrategy(t *testing.T) {
return res, body
}

makeAPICodeFlowRequest := func(t *testing.T, provider, action string) (returnToURL *url.URL) {
res, err := testhelpers.NewDebugClient(t).Post(action, "application/json", strings.NewReader(fmt.Sprintf(`{
makeAPICodeFlowRequest := func(t *testing.T, provider, action string, cookieJar *cookiejar.Jar) (returnToURL *url.URL) {
res, err := http.Post(action, "application/json", strings.NewReader(fmt.Sprintf(`{
"method": "oidc",
"provider": %q
}`, provider)))
Expand All @@ -212,7 +212,7 @@ func TestStrategy(t *testing.T) {
var changeLocation flow.BrowserLocationChangeRequiredError
require.NoError(t, json.NewDecoder(res.Body).Decode(&changeLocation))

res, err = testhelpers.NewClientWithCookieJar(t, nil, nil).Get(changeLocation.RedirectBrowserTo)
res, err = testhelpers.NewClientWithCookieJar(t, cookieJar, nil).Get(changeLocation.RedirectBrowserTo)
require.NoError(t, err)

returnToURL = res.Request.URL
Expand Down Expand Up @@ -839,12 +839,12 @@ func TestStrategy(t *testing.T) {
t.Run("suite=API with session token exchange code", func(t *testing.T) {
scope = []string{"openid"}

loginOrRegister := func(t *testing.T, flowID uuid.UUID, code string) {
loginOrRegister := func(t *testing.T, flowID uuid.UUID, code string, cookieJar *cookiejar.Jar) {
_, err := exchangeCodeForToken(t, sessiontokenexchange.Codes{InitCode: code})
require.Error(t, err)

action := assertFormValues(t, flowID, "valid")
returnToURL := makeAPICodeFlowRequest(t, "valid", action)
returnToURL := makeAPICodeFlowRequest(t, "valid", action, cookieJar)
returnToCode := returnToURL.Query().Get("code")
assert.NotEmpty(t, code, "code query param was empty in the return_to URL")

Expand All @@ -857,18 +857,18 @@ func TestStrategy(t *testing.T) {
assert.NotEmpty(t, codeResponse.Token)
assert.Equal(t, subject, gjson.GetBytes(codeResponse.Session.Identity.Traits, "subject").String())
}
performRegistration := func(t *testing.T) {
performRegistration := func(t *testing.T, cookieJar *cookiejar.Jar) {
f := newAPIRegistrationFlow(t, returnTS.URL+"?return_session_token_exchange_code=true&return_to=/app_code", 1*time.Minute)
loginOrRegister(t, f.ID, f.SessionTokenExchangeCode)
loginOrRegister(t, f.ID, f.SessionTokenExchangeCode, cookieJar)
}
performLogin := func(t *testing.T) {
performLogin := func(t *testing.T, cookieJar *cookiejar.Jar) {
f := newAPILoginFlow(t, returnTS.URL+"?return_session_token_exchange_code=true&return_to=/app_code", 1*time.Minute)
loginOrRegister(t, f.ID, f.SessionTokenExchangeCode)
loginOrRegister(t, f.ID, f.SessionTokenExchangeCode, cookieJar)
}

for _, tc := range []struct {
name string
first, then func(*testing.T)
first, then func(*testing.T, *cookiejar.Jar)
}{{
name: "login-twice",
first: performLogin, then: performLogin,
Expand All @@ -884,10 +884,30 @@ func TestStrategy(t *testing.T) {
}} {
t.Run("case="+tc.name, func(t *testing.T) {
subject = tc.name + "[email protected]"
tc.first(t)
tc.then(t)
tc.first(t, nil)
tc.then(t, nil)
})
}

t.Run("case=should return exchange code even if already authenticated", func(t *testing.T) {
subject = "[email protected]"
jar := x.Must(cookiejar.New(nil))

t.Run("step=register and create a session", func(t *testing.T) {
returnTo := "/foo"
r := newBrowserLoginFlow(t, fmt.Sprintf("%s?return_to=%s", returnTS.URL, returnTo), time.Minute)
action := assertFormValues(t, r.ID, "valid")

res, body := makeRequestWithCookieJar(t, "valid", action, url.Values{}, jar, nil)
assert.True(t, strings.HasSuffix(res.Request.URL.String(), returnTo))
assertIdentity(t, res, body)
})

t.Run("step=perform login and get exchange code", func(t *testing.T) {
performLogin(t, jar)
})
})

t.Run("case=should use redirect_to URL on failure", func(t *testing.T) {
ctx := context.Background()
subject = "[email protected]"
Expand All @@ -905,7 +925,7 @@ func TestStrategy(t *testing.T) {
require.Error(t, err)

action := assertFormValues(t, f.ID, "valid")
returnToURL := makeAPICodeFlowRequest(t, "valid", action)
returnToURL := makeAPICodeFlowRequest(t, "valid", action, nil)
returnedFlow := returnToURL.Query().Get("flow")

require.NotEmpty(t, returnedFlow, "flow query param was empty in the return_to URL")
Expand Down

0 comments on commit aa4fce4

Please sign in to comment.