Skip to content

Commit db0de9f

Browse files
committed
feat: revoke consent by session id. trigger back channel logout.
1 parent d687366 commit db0de9f

File tree

8 files changed

+158
-29
lines changed

8 files changed

+158
-29
lines changed

client/client.go

+9
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,15 @@ type Client struct {
200200
Metadata sqlxx.JSONRawMessage `json:"metadata,omitempty" db:"metadata"`
201201
}
202202

203+
type AuthenticatedClient struct {
204+
ClientID string `json:"client_id" db:"id"`
205+
FrontChannelLogoutURI string `json:"frontchannel_logout_uri,omitempty" db:"frontchannel_logout_uri"`
206+
FrontChannelLogoutSessionRequired bool `json:"frontchannel_logout_session_required,omitempty" db:"frontchannel_logout_session_required"`
207+
BackChannelLogoutURI string `json:"backchannel_logout_uri,omitempty" db:"backchannel_logout_uri"`
208+
BackChannelLogoutSessionRequired bool `json:"backchannel_logout_session_required,omitempty" db:"backchannel_logout_session_required"`
209+
LoginSessionID string `json:"login_session_id,omitempty" db:"login_session_id"`
210+
}
211+
203212
func (Client) TableName() string {
204213
return "hydra_client"
205214
}

consent/handler.go

+28-3
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,9 @@ func (h *Handler) SetRoutes(admin *x.RouterAdmin) {
102102
func (h *Handler) DeleteConsentSession(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
103103
subject := r.URL.Query().Get("subject")
104104
client := r.URL.Query().Get("client")
105+
loginSessionId := r.URL.Query().Get("login_session_id")
106+
triggerBackChannelLogout := r.URL.Query().Get("trigger_backchannel_logout")
107+
105108
allClients := r.URL.Query().Get("all") == "true"
106109
if subject == "" {
107110
h.r.Writer().WriteError(w, r, errorsx.WithStack(fosite.ErrInvalidRequest.WithHint(`Query parameter 'subject' is not defined but should have been.`)))
@@ -110,11 +113,33 @@ func (h *Handler) DeleteConsentSession(w http.ResponseWriter, r *http.Request, p
110113

111114
switch {
112115
case len(client) > 0:
113-
if err := h.r.ConsentManager().RevokeSubjectClientConsentSession(r.Context(), subject, client); err != nil && !errors.Is(err, x.ErrNotFound) {
114-
h.r.Writer().WriteError(w, r, err)
115-
return
116+
if len(loginSessionId) > 0 {
117+
if triggerBackChannelLogout == "true" {
118+
if err := h.r.ConsentStrategy().ExecuteBackChannelLogoutByClientSession(r.Context(), r, subject, client, loginSessionId); err != nil {
119+
h.r.Logger().WithError(err).Warn("Unable to execute back channel logout")
120+
}
121+
}
122+
if err := h.r.ConsentManager().RevokeSubjectClientLoginSessionConsentSession(r.Context(), subject, client, loginSessionId); err != nil && !errors.Is(err, x.ErrNotFound) {
123+
h.r.Writer().WriteError(w, r, err)
124+
return
125+
}
126+
} else {
127+
if triggerBackChannelLogout == "true" {
128+
if err := h.r.ConsentStrategy().ExecuteBackChannelLogoutByClient(r.Context(), r, subject, client); err != nil {
129+
h.r.Logger().WithError(err).Warn("Unable to execute back channel logout")
130+
}
131+
}
132+
if err := h.r.ConsentManager().RevokeSubjectClientConsentSession(r.Context(), subject, client); err != nil && !errors.Is(err, x.ErrNotFound) {
133+
h.r.Writer().WriteError(w, r, err)
134+
return
135+
}
116136
}
117137
case allClients:
138+
if triggerBackChannelLogout == "true" {
139+
if err := h.r.ConsentStrategy().ExecuteBackChannelLogoutBySubject(r.Context(), r, subject); err != nil {
140+
h.r.Logger().WithError(err).Warn("Unable to execute back channel logout")
141+
}
142+
}
118143
if err := h.r.ConsentManager().RevokeSubjectConsentSession(r.Context(), subject); err != nil && !errors.Is(err, x.ErrNotFound) {
119144
h.r.Writer().WriteError(w, r, err)
120145
return

consent/manager.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ type Manager interface {
4343
HandleConsentRequest(ctx context.Context, challenge string, r *HandledConsentRequest) (*ConsentRequest, error)
4444
RevokeSubjectConsentSession(ctx context.Context, user string) error
4545
RevokeSubjectClientConsentSession(ctx context.Context, user, client string) error
46+
RevokeSubjectClientLoginSessionConsentSession(ctx context.Context, user, client, loginSessionId string) error
4647

4748
VerifyAndInvalidateConsentRequest(ctx context.Context, verifier string) (*HandledConsentRequest, error)
4849
FindGrantedAndRememberedConsentRequests(ctx context.Context, client, user string) ([]HandledConsentRequest, error)
@@ -64,8 +65,9 @@ type Manager interface {
6465
CreateForcedObfuscatedLoginSession(ctx context.Context, session *ForcedObfuscatedLoginSession) error
6566
GetForcedObfuscatedLoginSession(ctx context.Context, client, obfuscated string) (*ForcedObfuscatedLoginSession, error)
6667

67-
ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
68-
ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error)
68+
ListUserSessionAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error)
69+
ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error)
70+
ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject string) ([]client.AuthenticatedClient, error)
6971

7072
CreateLogoutRequest(ctx context.Context, request *LogoutRequest) error
7173
GetLogoutRequest(ctx context.Context, challenge string) (*LogoutRequest, error)

consent/manager_test_helpers.go

+7-7
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
736736
}
737737

738738
for _, ls := range sessions {
739-
check := func(t *testing.T, expected map[string][]client.Client, actual []client.Client) {
739+
check := func(t *testing.T, expected map[string][]client.Client, actual []client.AuthenticatedClient) {
740740
es, ok := expected[ls.ID]
741741
if !ok {
742742
require.Len(t, actual, 0)
@@ -747,25 +747,25 @@ func ManagerTests(m Manager, clientManager client.Manager, fositeManager x.Fosit
747747
for _, e := range es {
748748
var found bool
749749
for _, a := range actual {
750-
if e.OutfacingID == a.OutfacingID {
750+
if e.OutfacingID == a.ClientID {
751751
found = true
752752
}
753-
assert.Equal(t, e.OutfacingID, a.OutfacingID)
753+
assert.Equal(t, e.OutfacingID, a.ClientID)
754754
assert.Equal(t, e.FrontChannelLogoutURI, a.FrontChannelLogoutURI)
755755
assert.Equal(t, e.BackChannelLogoutURI, a.BackChannelLogoutURI)
756756
}
757757
require.True(t, found)
758758
}
759759
}
760760

761-
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
762-
actual, err := m.ListUserAuthenticatedClientsWithFrontChannelLogout(context.Background(), ls.Subject, ls.ID)
761+
t.Run(fmt.Sprintf("method=ListUserSessionAuthenticatedClientsWithFrontChannelLogout/session=%s/subject=%s", ls.ID, ls.Subject), func(t *testing.T) {
762+
actual, err := m.ListUserSessionAuthenticatedClientsWithFrontChannelLogout(context.Background(), ls.Subject, ls.ID)
763763
require.NoError(t, err)
764764
check(t, frontChannels, actual)
765765
})
766766

767-
t.Run(fmt.Sprintf("method=ListUserAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) {
768-
actual, err := m.ListUserAuthenticatedClientsWithBackChannelLogout(context.Background(), ls.Subject, ls.ID)
767+
t.Run(fmt.Sprintf("method=ListUserSessionAuthenticatedClientsWithBackChannelLogout/session=%s", ls.ID), func(t *testing.T) {
768+
actual, err := m.ListUserSessionAuthenticatedClientsWithBackChannelLogout(context.Background(), ls.Subject, ls.ID)
769769
require.NoError(t, err)
770770
check(t, backChannels, actual)
771771
})

consent/strategy.go

+5
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
package consent
2222

2323
import (
24+
"context"
2425
"net/http"
2526

2627
"github.com/ory/fosite"
@@ -31,4 +32,8 @@ var _ Strategy = new(DefaultStrategy)
3132
type Strategy interface {
3233
HandleOAuth2AuthorizationRequest(w http.ResponseWriter, r *http.Request, req fosite.AuthorizeRequester) (*HandledConsentRequest, error)
3334
HandleOpenIDConnectLogout(w http.ResponseWriter, r *http.Request) (*LogoutResult, error)
35+
ExecuteBackChannelLogoutBySubject(ctx context.Context, r *http.Request, subject string) error
36+
ExecuteBackChannelLogoutBySession(ctx context.Context, r *http.Request, subject, sid string) error
37+
ExecuteBackChannelLogoutByClient(ctx context.Context, r *http.Request, subject, client string) error
38+
ExecuteBackChannelLogoutByClientSession(ctx context.Context, r *http.Request, subject, client, sid string) error
3439
}

consent/strategy_default.go

+44-8
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ func (s *DefaultStrategy) verifyConsent(w http.ResponseWriter, r *http.Request,
632632
}
633633

634634
func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, subject, sid string) ([]string, error) {
635-
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithFrontChannelLogout(ctx, subject, sid)
635+
clients, err := s.r.ConsentManager().ListUserSessionAuthenticatedClientsWithFrontChannelLogout(ctx, subject, sid)
636636
if err != nil {
637637
return nil, err
638638
}
@@ -653,12 +653,49 @@ func (s *DefaultStrategy) generateFrontChannelLogoutURLs(ctx context.Context, su
653653
return urls, nil
654654
}
655655

656-
func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.Request, subject, sid string) error {
657-
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
656+
func (s *DefaultStrategy) ExecuteBackChannelLogoutBySession(ctx context.Context, r *http.Request, subject, sid string) error {
657+
clients, err := s.r.ConsentManager().ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
658658
if err != nil {
659659
return err
660660
}
661+
return s.executeBackChannelLogout(ctx, r, clients)
662+
}
661663

664+
func (s *DefaultStrategy) ExecuteBackChannelLogoutByClientSession(ctx context.Context, r *http.Request, subject, client, sid string) error {
665+
clients, err := s.r.ConsentManager().ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx, subject, sid)
666+
if err != nil {
667+
return err
668+
}
669+
for i := len(clients) - 1; i >= 0; i-- {
670+
if clients[i].ClientID != client {
671+
clients = append(clients[:i], clients[i+1:]...)
672+
}
673+
}
674+
return s.executeBackChannelLogout(ctx, r, clients)
675+
}
676+
677+
func (s *DefaultStrategy) ExecuteBackChannelLogoutByClient(ctx context.Context, r *http.Request, subject, client string) error {
678+
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject)
679+
if err != nil {
680+
return err
681+
}
682+
for i := len(clients) - 1; i >= 0; i-- {
683+
if clients[i].ClientID != client {
684+
clients = append(clients[:i], clients[i+1:]...)
685+
}
686+
}
687+
return s.executeBackChannelLogout(ctx, r, clients)
688+
}
689+
690+
func (s *DefaultStrategy) ExecuteBackChannelLogoutBySubject(ctx context.Context, r *http.Request, subject string) error {
691+
clients, err := s.r.ConsentManager().ListUserAuthenticatedClientsWithBackChannelLogout(ctx, subject)
692+
if err != nil {
693+
return err
694+
}
695+
return s.executeBackChannelLogout(ctx, r, clients)
696+
}
697+
698+
func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.Request, clients []client.AuthenticatedClient) error {
662699
openIDKeyID, err := s.r.OpenIDJWTStrategy().GetPublicKeyID(ctx)
663700
if err != nil {
664701
return err
@@ -678,22 +715,21 @@ func (s *DefaultStrategy) executeBackChannelLogout(ctx context.Context, r *http.
678715
//
679716
// s.r.ConsentManager().GetForcedObfuscatedLoginSession(context.Background(), subject, <missing>)
680717
// sub := s.obfuscateSubjectIdentifier(c, subject, )
681-
682718
t, _, err := s.r.OpenIDJWTStrategy().Generate(ctx, jwtgo.MapClaims{
683719
"iss": s.c.IssuerURL().String(),
684-
"aud": []string{c.OutfacingID},
720+
"aud": []string{c.ClientID},
685721
"iat": time.Now().UTC().Unix(),
686722
"jti": uuid.New(),
687723
"events": map[string]struct{}{"http://schemas.openid.net/event/backchannel-logout": {}},
688-
"sid": sid,
724+
"sid": c.LoginSessionID,
689725
}, &jwt.Headers{
690726
Extra: map[string]interface{}{"kid": openIDKeyID},
691727
})
692728
if err != nil {
693729
return err
694730
}
695731

696-
tasks = append(tasks, task{url: c.BackChannelLogoutURI, clientID: c.OutfacingID, token: t})
732+
tasks = append(tasks, task{url: c.BackChannelLogoutURI, clientID: c.ClientID, token: t})
697733
}
698734

699735
var wg sync.WaitGroup
@@ -964,7 +1000,7 @@ func (s *DefaultStrategy) completeLogout(w http.ResponseWriter, r *http.Request)
9641000
return nil, err
9651001
}
9661002

967-
if err := s.executeBackChannelLogout(r.Context(), r, lr.Subject, lr.SessionID); err != nil {
1003+
if err := s.ExecuteBackChannelLogoutBySession(r.Context(), r, lr.Subject, lr.SessionID); err != nil {
9681004
return nil, err
9691005
}
9701006

oauth2/oauth2_helper_test.go

+17
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
package oauth2_test
2222

2323
import (
24+
"context"
2425
"net/http"
2526
"time"
2627

@@ -62,3 +63,19 @@ func (c *consentMock) HandleOAuth2AuthorizationRequest(w http.ResponseWriter, r
6263
func (c *consentMock) HandleOpenIDConnectLogout(w http.ResponseWriter, r *http.Request) (*consent.LogoutResult, error) {
6364
panic("not implemented")
6465
}
66+
67+
func (c *consentMock) ExecuteBackChannelLogoutBySession(ctx context.Context, r *http.Request, subject, sid string) error {
68+
panic("not implemented")
69+
}
70+
71+
func (c *consentMock) ExecuteBackChannelLogoutByClientSession(ctx context.Context, r *http.Request, subject, client, sid string) error {
72+
panic("not implemented")
73+
}
74+
75+
func (c *consentMock) ExecuteBackChannelLogoutByClient(ctx context.Context, r *http.Request, subject, client string) error {
76+
panic("not implemented")
77+
}
78+
79+
func (c *consentMock) ExecuteBackChannelLogoutBySubject(ctx context.Context, r *http.Request, subject string) error {
80+
panic("not implemented")
81+
}

persistence/sql/persister_consent.go

+44-9
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"fmt"
77
"time"
88

9+
"github.com/ory/hydra/client"
10+
911
"github.com/ory/x/sqlxx"
1012

1113
"github.com/ory/x/errorsx"
@@ -14,7 +16,6 @@ import (
1416
"github.com/pkg/errors"
1517

1618
"github.com/ory/fosite"
17-
"github.com/ory/hydra/client"
1819
"github.com/ory/hydra/consent"
1920
"github.com/ory/hydra/x"
2021
"github.com/ory/x/sqlcon"
@@ -30,6 +31,10 @@ func (p *Persister) RevokeSubjectClientConsentSession(ctx context.Context, user,
3031
return p.transaction(ctx, p.revokeConsentSession("r.subject = ? AND r.client_id = ?", user, client))
3132
}
3233

34+
func (p *Persister) RevokeSubjectClientLoginSessionConsentSession(ctx context.Context, user, client, loginSessionId string) error {
35+
return p.transaction(ctx, p.revokeConsentSession("r.subject = ? AND r.client_id = ? AND r.login_session_id = ?", user, client, loginSessionId))
36+
}
37+
3338
func (p *Persister) revokeConsentSession(whereStmt string, whereArgs ...interface{}) func(context.Context, *pop.Connection) error {
3439
return func(ctx context.Context, c *pop.Connection) error {
3540
hrs := make([]*consent.HandledConsentRequest, 0)
@@ -363,20 +368,24 @@ func (p *Persister) resolveHandledConsentRequests(ctx context.Context, requests
363368
return result, nil
364369
}
365370

366-
func (p *Persister) ListUserAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) {
367-
return p.listUserAuthenticatedClients(ctx, subject, sid, "front")
371+
func (p *Persister) ListUserSessionAuthenticatedClientsWithFrontChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error) {
372+
return p.listUserSessionAuthenticatedClients(ctx, subject, sid, "front")
368373
}
369374

370-
func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.Client, error) {
371-
return p.listUserAuthenticatedClients(ctx, subject, sid, "back")
375+
func (p *Persister) ListUserSessionAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject, sid string) ([]client.AuthenticatedClient, error) {
376+
return p.listUserSessionAuthenticatedClients(ctx, subject, sid, "back")
372377
}
373378

374-
func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, sid, channel string) ([]client.Client, error) {
375-
var cs []client.Client
376-
return cs, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
379+
func (p *Persister) ListUserAuthenticatedClientsWithBackChannelLogout(ctx context.Context, subject string) ([]client.AuthenticatedClient, error) {
380+
return p.listUserAuthenticatedClients(ctx, subject, "back")
381+
}
382+
383+
func (p *Persister) listUserSessionAuthenticatedClients(ctx context.Context, subject, sid, channel string) ([]client.AuthenticatedClient, error) {
384+
var cs []client.AuthenticatedClient
385+
err := p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
377386
if err := c.RawQuery(
378387
/* #nosec G201 - channel can either be "front" or "back" */
379-
fmt.Sprintf(`SELECT DISTINCT c.* FROM hydra_client as c JOIN hydra_oauth2_consent_request as r ON (c.id = r.client_id) WHERE r.subject=? AND c.%schannel_logout_uri!='' AND c.%schannel_logout_uri IS NOT NULL AND r.login_session_id = ?`,
388+
fmt.Sprintf(`SELECT DISTINCT c.id, c.frontchannel_logout_uri, c.frontchannel_logout_session_required, c.backchannel_logout_uri, c.backchannel_logout_session_required FROM hydra_client as c JOIN hydra_oauth2_consent_request as r ON (c.id = r.client_id) WHERE r.subject=? AND c.%schannel_logout_uri!='' AND c.%schannel_logout_uri IS NOT NULL AND r.login_session_id = ?`,
380389
channel,
381390
channel,
382391
),
@@ -386,6 +395,32 @@ func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, s
386395
return sqlcon.HandleError(err)
387396
}
388397

398+
return nil
399+
})
400+
if err != nil {
401+
return nil, err
402+
}
403+
404+
for i := range cs {
405+
cs[i].LoginSessionID = sid
406+
}
407+
return cs, err
408+
}
409+
410+
func (p *Persister) listUserAuthenticatedClients(ctx context.Context, subject, channel string) ([]client.AuthenticatedClient, error) {
411+
var cs []client.AuthenticatedClient
412+
return cs, p.transaction(ctx, func(ctx context.Context, c *pop.Connection) error {
413+
if err := c.RawQuery(
414+
/* #nosec G201 - channel can either be "front" or "back" */
415+
fmt.Sprintf(`SELECT DISTINCT c.id, c.frontchannel_logout_uri, c.frontchannel_logout_session_required, c.backchannel_logout_uri, c.backchannel_logout_session_required, r.login_session_id FROM hydra_client as c JOIN hydra_oauth2_consent_request as r ON (c.id = r.client_id) WHERE r.subject=? AND c.%schannel_logout_uri!='' AND c.%schannel_logout_uri IS NOT NULL`,
416+
channel,
417+
channel,
418+
),
419+
subject,
420+
).All(&cs); err != nil {
421+
return sqlcon.HandleError(err)
422+
}
423+
389424
return nil
390425
})
391426
}

0 commit comments

Comments
 (0)