Skip to content

Commit 5fd7354

Browse files
committed
feat: revoke consent by session id. trigger back channel logout.
1 parent 924be24 commit 5fd7354

File tree

8 files changed

+178
-32
lines changed

8 files changed

+178
-32
lines changed

client/client.go

+9
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,15 @@ type Client struct {
209209
RegistrationClientURI string `json:"registration_client_uri,omitempty" db:"-"`
210210
}
211211

212+
type AuthenticatedClient struct {
213+
ClientID string `json:"client_id" db:"id"`
214+
FrontChannelLogoutURI string `json:"frontchannel_logout_uri,omitempty" db:"frontchannel_logout_uri"`
215+
FrontChannelLogoutSessionRequired bool `json:"frontchannel_logout_session_required,omitempty" db:"frontchannel_logout_session_required"`
216+
BackChannelLogoutURI string `json:"backchannel_logout_uri,omitempty" db:"backchannel_logout_uri"`
217+
BackChannelLogoutSessionRequired bool `json:"backchannel_logout_session_required,omitempty" db:"backchannel_logout_session_required"`
218+
LoginSessionID string `json:"login_session_id,omitempty" db:"login_session_id"`
219+
}
220+
212221
func (Client) TableName() string {
213222
return "hydra_client"
214223
}

consent/handler.go

+43-6
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ func (h *Handler) SetRoutes(admin *x.RouterAdmin) {
102102
func (h *Handler) DeleteConsentSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
103103
subject := r.URL.Query().Get("subject")
104104
client := r.URL.Query().Get("client")
105+
loginSessionId := r.URL.Query().Get("login_session_id")
106+
triggerBackChannelLogout := r.URL.Query().Get("trigger_backchannel_logout")
107+
105108
allClients := r.URL.Query().Get("all") == "true"
106109
if subject == "" {
107110
h.r.Writer().WriteError(w, r, errorsx.WithStack(fosite.ErrInvalidRequest.WithHint(`Query parameter 'subject' is not defined but should have been.`)))
@@ -110,14 +113,48 @@ func (h *Handler) DeleteConsentSession(w http.ResponseWriter, r *http.Request, p
110113

111114
switch {
112115
case len(client) > 0:
113-
if err := h.r.ConsentManager().RevokeSubjectClientConsentSession(r.Context(), subject, client); err != nil && !errors.Is(err, x.ErrNotFound) {
114-
h.r.Writer().WriteError(w, r, err)
115-
return
116+
if len(loginSessionId) > 0 {
117+
if triggerBackChannelLogout == "true" {
118+
if err := h.r.ConsentStrategy().ExecuteBackChannelLogoutByClientSession(r.Context(), r, subject, client, loginSessionId); err != nil {
119+
h.r.Logger().WithError(err).Warn("Unable to execute back channel logout")
120+
}
121+
}
122+
if err := h.r.ConsentManager().RevokeSubjectClientLoginSessionConsentSession(r.Context(), subject, client, loginSessionId); err != nil && !errors.Is(err, x.ErrNotFound) {
123+
h.r.Writer().WriteError(w, r, err)
124+
return
125+
}
126+
} else {
127+
if triggerBackChannelLogout == "true" {
128+
if err := h.r.ConsentStrategy().ExecuteBackChannelLogoutByClient(r.Context(), r, subject, client); err != nil {
129+
h.r.Logger().WithError(err).Warn("Unable to execute back channel logout")
130+
}
131+
}
132+
if err := h.r.ConsentManager().RevokeSubjectClientConsentSession(r.Context(), subject, client); err != nil && !errors.Is(err, x.ErrNotFound) {
133+
h.r.Writer().WriteError(w, r, err)
134+
return
135+
}
116136
}
117137
case allClients:
118-
if err := h.r.ConsentManager().RevokeSubjectConsentSession(r.Context(), subject); err != nil && !errors.Is(err, x.ErrNotFound) {
119-
h.r.Writer().WriteError(w, r, err)
120-
return
138+
if len(loginSessionId) > 0 {
139+
if triggerBackChannelLogout == "true" {
140+
if err := h.r.ConsentStrategy().ExecuteBackChannelLogoutBySession(r.Context(), r, subject, loginSessionId); err != nil {
141+
h.r.Logger().WithError(err).Warn("Unable to execute back channel logout")
142+
}
143+
}
144+
if err := h.r.ConsentManager().RevokeLoginSessionConsentSession(r.Context(), loginSessionId); err != nil && !errors.Is(err, x.ErrNotFound) {
145+
h.r.Writer().WriteError(w, r, err)
146+
return
147+
}
148+
} else {
149+
if triggerBackChannelLogout == "true" {
150+
if err := h.r.ConsentStrategy().ExecuteBackChannelLogoutBySubject(r.Context(), r, subject); err != nil {
151+
h.r.Logger().WithError(err).Warn("Unable to execute back channel logout")
152+
}
153+
}
154+
if err := h.r.ConsentManager().RevokeSubjectConsentSession(r.Context(), subject); err != nil && !errors.Is(err, x.ErrNotFound) {
155+
h.r.Writer().WriteError(w, r, err)
156+
return
157+
}
121158
}
122159
default:
123160
h.r.Writer().WriteError(w, r, errorsx.WithStack(fosite.ErrInvalidRequest.WithHint(`Query parameter both 'client' and 'all' is not defined but one of them should have been.`)))

consent/manager.go

+5-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,9 @@ type Manager interface {
4242
GetConsentRequest(ctx context.Context, challenge string) (*ConsentRequest, error)
4343
HandleConsentRequest(ctx context.Context, challenge string, r *HandledConsentRequest) (*ConsentRequest, error)
4444
RevokeSubjectConsentSession(ctx context.Context, user string) error
45+
RevokeLoginSessionConsentSession(ctx context.Context, loginSessionId string) error
4546
RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error
47+
RevokeSubjectClientLoginSessionConsentSession(ctx context.Context, user, client, loginSessionId string) error
4648

4749
VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*HandledConsentRequest, error)
4850
FindGrantedAndRememberedConsentRequests(ctx context.Context, client, user string) ([]HandledConsentRequest, error)
@@ -64,8 +66,9 @@ type Manager interface {
6466
CreateForcedObfuscatedLoginSession(ctx context.Context, session *ForcedObfuscatedLoginSession) error
6567
GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedLoginSession, error)
6668

67-
ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
68-
ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
69+
ListUserSessionAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error)
70+
ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error)
71+
ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject string) ([]client.AuthenticatedClient, error)
6972

7073
CreateLogoutRequest(ctx context.Context, request *LogoutRequest) error
7174
GetLogoutRequest(ctx context.Context, challenge string) (*LogoutRequest, error)

consent/manager_test_helpers.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
736736
}
737737

738738
for _, ls := range sessions {
739-
check := func(t *testing.T, expected map[string][]client.Client, actual []client.Client) {
739+
check := func(t *testing.T, expected map[string][]client.Client, actual []client.AuthenticatedClient) {
740740
es, ok := expected[ls.ID]
741741
if !ok {
742742
require.Len(t, actual, 0)
@@ -747,25 +747,25 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
747747
for _, e := range es {
748748
var found bool
749749
for _, a := range actual {
750-
if e.OutfacingID == a.OutfacingID {
750+
if e.OutfacingID == a.ClientID {
751751
found = true
752752
}
753-
assert.Equal(t, e.OutfacingID, a.OutfacingID)
753+
assert.Equal(t, e.OutfacingID, a.ClientID)
754754
assert.Equal(t, e.FrontChannelLogoutURI, a.FrontChannelLogoutURI)
755755
assert.Equal(t, e.BackChannelLogoutURI, a.BackChannelLogoutURI)
756756
}
757757
require.True(t, found)
758758
}
759759
}
760760

761-
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
762-
actual, err := m.ListUserAuthenticatedClientsWithFrontChannelLogout(context.Background(), ls.Subject, ls.ID)
761+
t.Run(fmt.Sprintf("method=ListUserSessionAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
762+
actual, err := m.ListUserSessionAuthenticatedClientsWithFrontChannelLogout(context.Background(), ls.Subject, ls.ID)
763763
require.NoError(t, err)
764764
check(t, frontChannels, actual)
765765
})
766766

767-
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) {
768-
actual, err := m.ListUserAuthenticatedClientsWithBackChannelLogout(context.Background(), ls.Subject, ls.ID)
767+
t.Run(fmt.Sprintf("method=ListUserSessionAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) {
768+
actual, err := m.ListUserSessionAuthenticatedClientsWithBackChannelLogout(context.Background(), ls.Subject, ls.ID)
769769
require.NoError(t, err)
770770
check(t, backChannels, actual)
771771
})

consent/strategy.go

+5
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
package consent
2222

2323
import (
24+
"context"
2425
"net/http"
2526

2627
"github.com/ory/fosite"
@@ -31,4 +32,8 @@ var _ Strategy = new(DefaultStrategy)
3132
type Strategy interface {
3233
HandleOAuth2AuthorizationRequest(w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester) (*HandledConsentRequest, error)
3334
HandleOpenIDConnectLogout(w http.ResponseWriter, r *http.Request) (*LogoutResult, error)
35+
ExecuteBackChannelLogoutBySubject(ctx context.Context, r *http.Request, subject string) error
36+
ExecuteBackChannelLogoutBySession(ctx context.Context, r *http.Request, subject, sid string) error
37+
ExecuteBackChannelLogoutByClient(ctx context.Context, r *http.Request, subject, client string) error
38+
ExecuteBackChannelLogoutByClientSession(ctx context.Context, r *http.Request, subject, client, sid string) error
3439
}

consent/strategy_default.go

+44-8
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ func (s *DefaultStrategy) verifyConsent(w http.ResponseWriter, r *http.Request,
632632
}
633633

634634
func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, subject, sid string) ([]string, error) {
635-
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, subject, sid)
635+
clients, err := s.r.ConsentManager().ListUserSessionAuthenticatedClientsWithFrontChannelLogout(ctx, subject, sid)
636636
if err != nil {
637637
return nil, err
638638
}
@@ -653,12 +653,49 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su
653653
return urls, nil
654654
}
655655

656-
func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.Request, subject, sid string) error {
657-
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
656+
func (s *DefaultStrategy) ExecuteBackChannelLogoutBySession(ctx context.Context, r *http.Request, subject, sid string) error {
657+
clients, err := s.r.ConsentManager().ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
658658
if err != nil {
659659
return err
660660
}
661+
return s.executeBackChannelLogout(ctx, r, clients)
662+
}
661663

664+
func (s *DefaultStrategy) ExecuteBackChannelLogoutByClientSession(ctx context.Context, r *http.Request, subject, client, sid string) error {
665+
clients, err := s.r.ConsentManager().ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
666+
if err != nil {
667+
return err
668+
}
669+
for i := len(clients) - 1; i >= 0; i-- {
670+
if clients[i].ClientID != client {
671+
clients = append(clients[:i], clients[i+1:]...)
672+
}
673+
}
674+
return s.executeBackChannelLogout(ctx, r, clients)
675+
}
676+
677+
func (s *DefaultStrategy) ExecuteBackChannelLogoutByClient(ctx context.Context, r *http.Request, subject, client string) error {
678+
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject)
679+
if err != nil {
680+
return err
681+
}
682+
for i := len(clients) - 1; i >= 0; i-- {
683+
if clients[i].ClientID != client {
684+
clients = append(clients[:i], clients[i+1:]...)
685+
}
686+
}
687+
return s.executeBackChannelLogout(ctx, r, clients)
688+
}
689+
690+
func (s *DefaultStrategy) ExecuteBackChannelLogoutBySubject(ctx context.Context, r *http.Request, subject string) error {
691+
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject)
692+
if err != nil {
693+
return err
694+
}
695+
return s.executeBackChannelLogout(ctx, r, clients)
696+
}
697+
698+
func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.Request, clients []client.AuthenticatedClient) error {
662699
openIDKeyID, err := s.r.OpenIDJWTStrategy().GetPublicKeyID(ctx)
663700
if err != nil {
664701
return err
@@ -678,22 +715,21 @@ func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.
678715
//
679716
// s.r.ConsentManager().GetForcedObfuscatedLoginSession(context.Background(), subject, <missing>)
680717
// sub := s.obfuscateSubjectIdentifier(c, subject, )
681-
682718
t, _, err := s.r.OpenIDJWTStrategy().Generate(ctx, jwtgo.MapClaims{
683719
"iss": s.c.IssuerURL().String(),
684-
"aud": []string{c.OutfacingID},
720+
"aud": []string{c.ClientID},
685721
"iat": time.Now().UTC().Unix(),
686722
"jti": uuid.New(),
687723
"events": map[string]struct{}{"http://schemas.openid.net/event/backchannel-logout": {}},
688-
"sid": sid,
724+
"sid": c.LoginSessionID,
689725
}, &jwt.Headers{
690726
Extra: map[string]interface{}{"kid": openIDKeyID},
691727
})
692728
if err != nil {
693729
return err
694730
}
695731

696-
tasks = append(tasks, task{url: c.BackChannelLogoutURI, clientID: c.OutfacingID, token: t})
732+
tasks = append(tasks, task{url: c.BackChannelLogoutURI, clientID: c.ClientID, token: t})
697733
}
698734

699735
var wg sync.WaitGroup
@@ -964,7 +1000,7 @@ func (s *DefaultStrategy) completeLogout(w http.ResponseWriter, r *http.Request)
9641000
return nil, err
9651001
}
9661002

967-
if err := s.executeBackChannelLogout(r.Context(), r, lr.Subject, lr.SessionID); err != nil {
1003+
if err := s.ExecuteBackChannelLogoutBySession(r.Context(), r, lr.Subject, lr.SessionID); err != nil {
9681004
return nil, err
9691005
}
9701006

oauth2/oauth2_helper_test.go

+17
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
package oauth2_test
2222

2323
import (
24+
"context"
2425
"net/http"
2526
"time"
2627

@@ -62,3 +63,19 @@ func (c *consentMock) HandleOAuth2AuthorizationRequest(w http.ResponseWriter, r
6263
func (c *consentMock) HandleOpenIDConnectLogout(w http.ResponseWriter, r *http.Request) (*consent.LogoutResult, error) {
6364
panic("not implemented")
6465
}
66+
67+
func (c *consentMock) ExecuteBackChannelLogoutBySession(ctx context.Context, r *http.Request, subject, sid string) error {
68+
panic("not implemented")
69+
}
70+
71+
func (c *consentMock) ExecuteBackChannelLogoutByClientSession(ctx context.Context, r *http.Request, subject, client, sid string) error {
72+
panic("not implemented")
73+
}
74+
75+
func (c *consentMock) ExecuteBackChannelLogoutByClient(ctx context.Context, r *http.Request, subject, client string) error {
76+
panic("not implemented")
77+
}
78+
79+
func (c *consentMock) ExecuteBackChannelLogoutBySubject(ctx context.Context, r *http.Request, subject string) error {
80+
panic("not implemented")
81+
}

persistence/sql/persister_consent.go

+48-9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"fmt"
77
"time"
88

9+
"github.com/ory/hydra/client"
10+
911
"github.com/gobuffalo/pop/v6"
1012

1113
"github.com/ory/x/sqlxx"
@@ -15,7 +17,6 @@ import (
1517
"github.com/pkg/errors"
1618

1719
"github.com/ory/fosite"
18-
"github.com/ory/hydra/client"
1920
"github.com/ory/hydra/consent"
2021
"github.com/ory/hydra/x"
2122
"github.com/ory/x/sqlcon"
@@ -27,10 +28,18 @@ func (p *Persister) RevokeSubjectConsentSession(ctx context.Context, user string
2728
return p.transaction(ctx, p.revokeConsentSession("r.subject = ?", user))
2829
}
2930

31+
func (p *Persister) RevokeLoginSessionConsentSession(ctx context.Context, loginSessionId string) error {
32+
return p.transaction(ctx, p.revokeConsentSession("r.login_session_id = ?", loginSessionId))
33+
}
34+
3035
func (p *Persister) RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error {
3136
return p.transaction(ctx, p.revokeConsentSession("r.subject = ? AND r.client_id = ?", user, client))
3237
}
3338

39+
func (p *Persister) RevokeSubjectClientLoginSessionConsentSession(ctx context.Context, user, client, loginSessionId string) error {
40+
return p.transaction(ctx, p.revokeConsentSession("r.subject = ? AND r.client_id = ? AND r.login_session_id = ?", user, client, loginSessionId))
41+
}
42+
3443
func (p *Persister) revokeConsentSession(whereStmt string, whereArgs ...interface{}) func(context.Context, *pop.Connection) error {
3544
return func(ctx context.Context, c *pop.Connection) error {
3645
hrs := make([]*consent.HandledConsentRequest, 0)
@@ -364,20 +373,24 @@ func (p *Persister) resolveHandledConsentRequests(ctx context.Context, requests
364373
return result, nil
365374
}
366375

367-
func (p *Persister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) {
368-
return p.listUserAuthenticatedClients(ctx, subject, sid, "front")
376+
func (p *Persister) ListUserSessionAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error) {
377+
return p.listUserSessionAuthenticatedClients(ctx, subject, sid, "front")
369378
}
370379

371-
func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) {
372-
return p.listUserAuthenticatedClients(ctx, subject, sid, "back")
380+
func (p *Persister) ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error) {
381+
return p.listUserSessionAuthenticatedClients(ctx, subject, sid, "back")
373382
}
374383

375-
func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) ([]client.Client, error) {
376-
var cs []client.Client
377-
return cs, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
384+
func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject string) ([]client.AuthenticatedClient, error) {
385+
return p.listUserAuthenticatedClients(ctx, subject, "back")
386+
}
387+
388+
func (p *Persister) listUserSessionAuthenticatedClients(ctx context.Context, subject, sid, channel string) ([]client.AuthenticatedClient, error) {
389+
var cs []client.AuthenticatedClient
390+
err := p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
378391
if err := c.RawQuery(
379392
/* #nosec G201 - channel can either be "front" or "back" */
380-
fmt.Sprintf(`SELECT DISTINCT c.* FROM hydra_client as c JOIN hydra_oauth2_consent_request as r ON (c.id = r.client_id) WHERE r.subject=? AND c.%schannel_logout_uri!='' AND c.%schannel_logout_uri IS NOT NULL AND r.login_session_id = ?`,
393+
fmt.Sprintf(`SELECT DISTINCT c.id, c.frontchannel_logout_uri, c.frontchannel_logout_session_required, c.backchannel_logout_uri, c.backchannel_logout_session_required FROM hydra_client as c JOIN hydra_oauth2_consent_request as r ON (c.id = r.client_id) WHERE r.subject=? AND c.%schannel_logout_uri!='' AND c.%schannel_logout_uri IS NOT NULL AND r.login_session_id = ?`,
381394
channel,
382395
channel,
383396
),
@@ -387,6 +400,32 @@ func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, s
387400
return sqlcon.HandleError(err)
388401
}
389402

403+
return nil
404+
})
405+
if err != nil {
406+
return nil, err
407+
}
408+
409+
for i := range cs {
410+
cs[i].LoginSessionID = sid
411+
}
412+
return cs, err
413+
}
414+
415+
func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, channel string) ([]client.AuthenticatedClient, error) {
416+
var cs []client.AuthenticatedClient
417+
return cs, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
418+
if err := c.RawQuery(
419+
/* #nosec G201 - channel can either be "front" or "back" */
420+
fmt.Sprintf(`SELECT DISTINCT c.id, c.frontchannel_logout_uri, c.frontchannel_logout_session_required, c.backchannel_logout_uri, c.backchannel_logout_session_required, r.login_session_id FROM hydra_client as c JOIN hydra_oauth2_consent_request as r ON (c.id = r.client_id) WHERE r.subject=? AND c.%schannel_logout_uri!='' AND c.%schannel_logout_uri IS NOT NULL`,
421+
channel,
422+
channel,
423+
),
424+
subject,
425+
).All(&cs); err != nil {
426+
return sqlcon.HandleError(err)
427+
}
428+
390429
return nil
391430
})
392431
}

0 commit comments

Comments
 (0)