-
Notifications
You must be signed in to change notification settings - Fork 1.5k
/
Copy pathsession.go
210 lines (177 loc) · 6.71 KB
/
session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
// Copyright © 2022 Ory Corp
// SPDX-License-Identifier: Apache-2.0
package oauth2
import (
"bytes"
"context"
"time"
jjson "github.com/go-jose/go-jose/v3/json"
"github.com/mohae/deepcopy"
"github.com/pkg/errors"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt"
"github.com/ory/hydra/v2/driver/config"
"github.com/ory/hydra/v2/flow"
"github.com/ory/x/logrusx"
"github.com/ory/x/stringslice"
)
// Session holds the OAuth2 / OpenID Connect session state used when issuing
// tokens. It embeds fosite's openid.DefaultSession (serialized under the
// "id_token" key) and adds Hydra-specific data consumed by GetJWTClaims,
// GetJWTHeader and GetExtraClaims.
//
// swagger:ignore
type Session struct {
// DefaultSession carries the ID token claims/headers, subject and expiry data.
*openid.DefaultSession `json:"id_token"`
// Extra holds custom claims. GetJWTClaims copies the keys listed in
// AllowedTopLevelClaims to the top level of the JWT and, when
// MirrorTopLevelClaims is set, embeds the whole map under "ext".
Extra map[string]interface{} `json:"extra"`
// KID is emitted as the "kid" JWT header by GetJWTHeader.
KID string `json:"kid"`
// ClientID is always emitted as the "client_id" claim by GetJWTClaims.
ClientID string `json:"client_id"`
// ConsentChallenge associated with this session (not read anywhere in this file).
ConsentChallenge string `json:"consent_challenge"`
// ExcludeNotBeforeClaim makes GetJWTClaims omit the "nbf" claim when true.
ExcludeNotBeforeClaim bool `json:"exclude_not_before_claim"`
// AllowedTopLevelClaims lists keys of Extra that GetJWTClaims may copy to the
// top level of the token; reserved claim names are filtered out there.
AllowedTopLevelClaims []string `json:"allowed_top_level_claims"`
// MirrorTopLevelClaims makes GetJWTClaims also place the full Extra map under "ext".
MirrorTopLevelClaims bool `json:"mirror_top_level_claims"`
// Flow is the associated flow; its `json:"-"` tag keeps it out of serialization.
Flow *flow.Flow `json:"-"`
}
// NewSession returns a Session for subject whose claim settings come from a
// configuration provider built with config.MustNew, a fresh logrusx logger and
// context.Background().
//
// NOTE(review): following the Must* convention, config.MustNew presumably
// panics on failure, so this constructor looks intended for tests/tooling.
func NewSession(subject string) *Session {
	background := context.Background()
	return NewSessionWithCustomClaims(background, config.MustNew(background, logrusx.New("", "")), subject)
}
// NewSessionWithCustomClaims returns a Session for subject with empty ID token
// claims and headers and an empty Extra map. The top-level-claim settings
// (AllowedTopLevelClaims, MirrorTopLevelClaims) are read from p using ctx.
func NewSessionWithCustomClaims(ctx context.Context, p *config.DefaultProvider, subject string) *Session {
	allowed, mirror := p.AllowedTopLevelClaims(ctx), p.MirrorTopLevelClaims(ctx)

	idSession := &openid.DefaultSession{
		Claims:  &jwt.IDTokenClaims{},
		Headers: &jwt.Headers{},
		Subject: subject,
	}

	return &Session{
		DefaultSession:        idSession,
		Extra:                 make(map[string]interface{}),
		AllowedTopLevelClaims: allowed,
		MirrorTopLevelClaims:  mirror,
	}
}
// GetJWTClaims builds the claims of a JWT-formatted token for this session.
//
// Keys from s.Extra that are listed in s.AllowedTopLevelClaims are copied to
// the top level of the token, except reserved claim names, which can never be
// overridden. When s.MirrorTopLevelClaims is set, the complete s.Extra map is
// additionally placed under "ext". "client_id" is always set to s.ClientID,
// and "nbf" equals "iat" unless s.ExcludeNotBeforeClaim is set.
func (s *Session) GetJWTClaims() jwt.JWTClaimsContainer {
	// Claims that are reserved and must never be overridden by custom claims.
	reservedClaims := []string{"iss", "sub", "aud", "exp", "nbf", "iat", "jti", "client_id", "scp", "ext"}

	// Remove reserved names from the configured allow-list. The closure
	// parameter is named `claim` so it does not shadow the receiver `s`.
	allowedClaims := stringslice.Filter(s.AllowedTopLevelClaims, func(claim string) bool {
		return stringslice.Has(reservedClaims, claim)
	})

	// The map is always allocated, so "client_id" can be assigned without a
	// nil check. +2 leaves room for "ext" and "client_id".
	extra := make(map[string]interface{}, len(allowedClaims)+2)
	for _, claim := range allowedClaims {
		if value, ok := s.Extra[claim]; ok {
			extra[claim] = value
		}
	}
	// Mirror the original Extra map (including claims that were reserved or
	// not allowed at the top level) under "ext".
	if s.MirrorTopLevelClaims {
		extra["ext"] = s.Extra
	}
	extra["client_id"] = s.ClientID

	claims := &jwt.JWTClaims{
		Subject:   s.Subject,
		Issuer:    s.DefaultSession.Claims.Issuer,
		Extra:     extra,
		ExpiresAt: s.GetExpiresAt(fosite.AccessToken),
		IssuedAt:  time.Now(),

		// No need to set the audience because that's being done by fosite automatically.
		// Audience: s.Audience,

		// The JTI MUST NOT BE FIXED or refreshing tokens will yield the SAME token
		// JTI: s.JTI,

		// These are set by the DefaultJWTStrategy
		// Scope: s.Scope,

		// Setting these here will cause the token to have the same iat/nbf values always
		// IssuedAt: s.DefaultSession.Claims.IssuedAt,
		// NotBefore: s.DefaultSession.Claims.IssuedAt,
	}
	if !s.ExcludeNotBeforeClaim {
		claims.NotBefore = claims.IssuedAt
	}
	return claims
}
// GetJWTHeader returns the JWT headers for tokens built from this session.
// The only header set is "kid", taken from s.KID.
func (s *Session) GetJWTHeader() *jwt.Headers {
	headers := new(jwt.Headers)
	headers.Extra = map[string]interface{}{"kid": s.KID}
	return headers
}
// Clone returns a deep copy of the session produced by deepcopy.Copy.
// A nil receiver yields a nil (untyped) fosite.Session rather than an
// interface wrapping a nil *Session.
func (s *Session) Clone() fosite.Session {
	if s != nil {
		return deepcopy.Copy(s).(fosite.Session)
	}
	return nil
}
// keyRewrites maps gjson/sjson dot-separated paths that use Go field names
// (e.g. "ClientID", "idToken.Claims.JTI") to the snake_case paths matching the
// current JSON tags. UnmarshalJSON uses it to rewrite such payloads —
// presumably sessions serialized before the JSON tags existed — before decoding.
//
// NOTE(review): there is no entry for MirrorTopLevelClaims; presumably
// intentional if old payloads predate that field — confirm.
var keyRewrites = map[string]string{
// Top-level Session fields.
"Extra": "extra",
"KID": "kid",
"ClientID": "client_id",
"ConsentChallenge": "consent_challenge",
"ExcludeNotBeforeClaim": "exclude_not_before_claim",
"AllowedTopLevelClaims": "allowed_top_level_claims",
// Fields of the embedded openid.DefaultSession (old key "idToken").
"idToken.Headers.Extra": "id_token.headers.extra",
"idToken.ExpiresAt": "id_token.expires_at",
"idToken.Username": "id_token.username",
"idToken.Subject": "id_token.subject",
// ID token claims.
"idToken.Claims.JTI": "id_token.id_token_claims.jti",
"idToken.Claims.Issuer": "id_token.id_token_claims.iss",
"idToken.Claims.Subject": "id_token.id_token_claims.sub",
"idToken.Claims.Audience": "id_token.id_token_claims.aud",
"idToken.Claims.Nonce": "id_token.id_token_claims.nonce",
"idToken.Claims.ExpiresAt": "id_token.id_token_claims.exp",
"idToken.Claims.IssuedAt": "id_token.id_token_claims.iat",
"idToken.Claims.RequestedAt": "id_token.id_token_claims.rat",
"idToken.Claims.AuthTime": "id_token.id_token_claims.auth_time",
"idToken.Claims.AccessTokenHash": "id_token.id_token_claims.at_hash",
"idToken.Claims.AuthenticationContextClassReference": "id_token.id_token_claims.acr",
"idToken.Claims.AuthenticationMethodsReferences": "id_token.id_token_claims.amr",
"idToken.Claims.CodeHash": "id_token.id_token_claims.c_hash",
"idToken.Claims.Extra": "id_token.id_token_claims.ext",
}
// UnmarshalJSON decodes a Session, first rewriting any keys found in
// keyRewrites (Go-field-name paths such as "ClientID" or
// "idToken.Claims.JTI") to their current snake_case JSON paths.
//
// The rewrite runs in three phases:
//  1. copy every old-style value that is present to its new path,
//  2. delete all old-style paths plus the whole "idToken" object if present,
//  3. decode the result with go-jose's JSON decoder.
func (s *Session) UnmarshalJSON(original []byte) (err error) {
	parsed := gjson.ParseBytes(original)
	out := original

	// Phase 1: copy old-style values to their new locations.
	for oldPath, newPath := range keyRewrites {
		value := parsed.Get(oldPath)
		if !value.Exists() {
			continue
		}
		if out, err = sjson.SetRawBytes(out, newPath, []byte(value.Raw)); err != nil {
			return errors.WithStack(err)
		}
	}

	// Phase 2: strip old-style paths so they are not decoded a second time.
	for oldPath := range keyRewrites {
		if out, err = sjson.DeleteBytes(out, oldPath); err != nil {
			return errors.WithStack(err)
		}
	}
	if parsed.Get("idToken").Exists() {
		if out, err = sjson.DeleteBytes(out, "idToken"); err != nil {
			return errors.WithStack(err)
		}
	}

	// Phase 3: decode. UnmarshalIntOrFloat makes numbers decoded into
	// interface{} values become int64 when integral and float64 otherwise.
	// https://github.com/go-jose/go-jose/issues/144
	//
	// plain has Session's fields but no methods, so Decode does not recurse
	// back into this UnmarshalJSON.
	type plain Session
	decoder := jjson.NewDecoder(bytes.NewReader(out))
	decoder.SetNumberType(jjson.UnmarshalIntOrFloat)
	if err = decoder.Decode((*plain)(s)); err != nil {
		return errors.WithStack(err)
	}
	return nil
}
// GetExtraClaims implements ExtraClaimsSession for Session.
//
// It returns the session's Extra map itself (not a copy), so callers may
// modify it in place. A nil Extra map is lazily initialized first; a nil
// receiver yields nil.
func (s *Session) GetExtraClaims() map[string]interface{} {
	switch {
	case s == nil:
		return nil
	case s.Extra == nil:
		s.Extra = map[string]interface{}{}
	}
	return s.Extra
}