Skip to content

Commit 872685b

Browse files
alexzhangreslveshepelyuk
authored andcommitted
feat: allow refreshing keys once via jwks URL when current jwt kid is not found
1 parent 5c7c48e commit 872685b

File tree

3 files changed

+102
-8
lines changed

3 files changed

+102
-8
lines changed

README.md

+1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ OpaDebugMode | Set the opa response in the http response body when the request i
5151
PayloadFields | The field-name in the JWT payload that are required (e.g. `exp`). Multiple field names may be specified (string array)
5252
Required | Is `Authorization` header with JWT token required for every request.
5353
Keys | Used to validate JWT signature. Multiple keys are supported. Allowed values include certificates, public keys, symmetric keys. In case the value is a valid URL, the plugin will fetch keys from the JWK endpoint.
54+
ForceRefreshKeys | Force fetching keys from JWKS service when the key of current JWT token is not found. If set false, keys will only be refreshed every 15 minutes by default.
5455
Alg | Used to verify which PKI algorithm is used in the JWT.
5556
JwksHeaders | Map used to add headers to a JWKS request (e.g. credentials for a 3rd party JWKS service).
5657
JwtHeaders | Map used to inject JWT payload fields as HTTP request headers.

jwt.go

+53-8
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"net/url"
2424
"strconv"
2525
"strings"
26+
"sync"
2627
"time"
2728
)
2829

@@ -35,6 +36,7 @@ type Config struct {
3536
PayloadFields []string
3637
Required bool
3738
Keys []string
39+
ForceRefreshKeys bool
3840
Alg string
3941
OpaHeaders map[string]string
4042
JwtHeaders map[string]string
@@ -74,6 +76,9 @@ type JwtPlugin struct {
7476
opaHttpStatusField string
7577
jwtCookieKey string
7678
jwtQueryKey string
79+
80+
keysLock sync.RWMutex
81+
forceRefreshCmd chan chan<- struct{}
7782
}
7883

7984
// LogEvent contains a single log entry
@@ -189,20 +194,43 @@ func New(_ context.Context, next http.Handler, config *Config, _ string) (http.H
189194
return nil, err
190195
}
191196
if len(jwtPlugin.jwkEndpoints) > 0 {
197+
if config.ForceRefreshKeys {
198+
jwtPlugin.forceRefreshCmd = make(chan chan<- struct{})
199+
}
192200
go jwtPlugin.BackgroundRefresh()
193201
}
194202
}
195203
return jwtPlugin, nil
196204
}
197205

198206
func (jwtPlugin *JwtPlugin) BackgroundRefresh() {
207+
jwtPlugin.FetchKeys()
199208
for {
200-
jwtPlugin.FetchKeys()
201-
time.Sleep(15 * time.Minute) // 15 min
209+
select {
210+
case keysFetchedChan := <-jwtPlugin.forceRefreshCmd:
211+
jwtPlugin.FetchKeys()
212+
keysFetchedChan <- struct{}{}
213+
case <-time.After(15 * time.Minute):
214+
jwtPlugin.FetchKeys()
215+
}
216+
}
217+
}
218+
219+
func (jwtPlugin *JwtPlugin) forceRefreshKeys() (refreshed bool) {
220+
if jwtPlugin.forceRefreshCmd == nil || len(jwtPlugin.jwkEndpoints) == 0 {
221+
return
202222
}
223+
refreshedCh := make(chan struct{})
224+
jwtPlugin.forceRefreshCmd <- refreshedCh
225+
<-refreshedCh
226+
refreshed = true
227+
return
203228
}
204229

205230
func (jwtPlugin *JwtPlugin) ParseKeys(certificates []string) error {
231+
jwtPlugin.keysLock.Lock()
232+
defer jwtPlugin.keysLock.Unlock()
233+
206234
for _, certificate := range certificates {
207235
if block, rest := pem.Decode([]byte(certificate)); block != nil {
208236
if len(rest) > 0 {
@@ -235,6 +263,7 @@ func (jwtPlugin *JwtPlugin) ParseKeys(certificates []string) error {
235263
func (jwtPlugin *JwtPlugin) FetchKeys() {
236264
logInfo(fmt.Sprintf("FetchKeys - #%d jwkEndpoints to fetch", len(jwtPlugin.jwkEndpoints))).
237265
print()
266+
fetchedKeys := map[string]interface{}{}
238267
for _, u := range jwtPlugin.jwkEndpoints {
239268
req, err := http.NewRequest("GET", u.String(), nil)
240269
if err != nil {
@@ -281,7 +310,7 @@ func (jwtPlugin *JwtPlugin) FetchKeys() {
281310
ptr := new(rsa.PublicKey)
282311
ptr.N = new(big.Int).SetBytes(nBytes)
283312
ptr.E = int(new(big.Int).SetBytes(eBytes).Uint64())
284-
jwtPlugin.keys[key.Kid] = ptr
313+
fetchedKeys[key.Kid] = ptr
285314
}
286315
case "EC":
287316
{
@@ -323,7 +352,7 @@ func (jwtPlugin *JwtPlugin) FetchKeys() {
323352
ptr.Curve = crv
324353
ptr.X = new(big.Int).SetBytes(xBytes)
325354
ptr.Y = new(big.Int).SetBytes(yBytes)
326-
jwtPlugin.keys[key.Kid] = ptr
355+
fetchedKeys[key.Kid] = ptr
327356
}
328357
case "oct":
329358
{
@@ -337,11 +366,18 @@ func (jwtPlugin *JwtPlugin) FetchKeys() {
337366
break
338367
}
339368
}
340-
jwtPlugin.keys[key.Kid] = kBytes
369+
fetchedKeys[key.Kid] = kBytes
341370
}
342371
}
343372
}
344373
}
374+
375+
jwtPlugin.keysLock.Lock()
376+
defer jwtPlugin.keysLock.Unlock()
377+
378+
for k, v := range fetchedKeys {
379+
jwtPlugin.keys[k] = v
380+
}
345381
}
346382

347383
func (jwtPlugin *JwtPlugin) ServeHTTP(rw http.ResponseWriter, request *http.Request) {
@@ -373,7 +409,7 @@ func (jwtPlugin *JwtPlugin) CheckToken(request *http.Request, rw http.ResponseWr
373409
if jwtToken != nil {
374410
sub = fmt.Sprint(jwtToken.Payload["sub"])
375411
// only verify jwt tokens if keys are configured
376-
if len(jwtPlugin.keys) > 0 || len(jwtPlugin.jwkEndpoints) > 0 {
412+
if len(jwtPlugin.getKeysSync()) > 0 || len(jwtPlugin.jwkEndpoints) > 0 {
377413
if err = jwtPlugin.VerifyToken(jwtToken); err != nil {
378414
logError(fmt.Sprintf("Token is invalid - err: %s", err.Error())).
379415
withSub(sub).
@@ -550,6 +586,12 @@ func (jwtPlugin *JwtPlugin) remoteAddr(req *http.Request) Network {
550586
}
551587
}
552588

589+
func (jwtPlugin *JwtPlugin) getKeysSync() map[string]interface{} {
590+
jwtPlugin.keysLock.RLock()
591+
defer jwtPlugin.keysLock.RUnlock()
592+
return jwtPlugin.keys
593+
}
594+
553595
func (jwtPlugin *JwtPlugin) VerifyToken(jwtToken *JWT) error {
554596
for _, h := range jwtToken.Header.Crit {
555597
if _, ok := supportedHeaderNames[h]; !ok {
@@ -564,11 +606,14 @@ func (jwtPlugin *JwtPlugin) VerifyToken(jwtToken *JWT) error {
564606
if jwtPlugin.alg != "" && jwtToken.Header.Alg != jwtPlugin.alg {
565607
return fmt.Errorf("incorrect alg, expected %s got %s", jwtPlugin.alg, jwtToken.Header.Alg)
566608
}
567-
key, ok := jwtPlugin.keys[jwtToken.Header.Kid]
609+
key, ok := jwtPlugin.getKeysSync()[jwtToken.Header.Kid]
610+
if !ok && jwtPlugin.forceRefreshKeys() {
611+
key, ok = jwtPlugin.getKeysSync()[jwtToken.Header.Kid]
612+
}
568613
if ok {
569614
return a.verify(key, a.hash, jwtToken.Plaintext, jwtToken.Signature)
570615
} else {
571-
for _, key := range jwtPlugin.keys {
616+
for _, key := range jwtPlugin.getKeysSync() {
572617
err := a.verify(key, a.hash, jwtToken.Plaintext, jwtToken.Signature)
573618
if err == nil {
574619
return nil

jwt_test.go

+48
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,54 @@ func TestNewJWKEndpoint(t *testing.T) {
618618
}
619619
}
620620

621+
func TestForceRefreshKeys(t *testing.T) {
622+
keys := `{"keys":[{"kty":"oct","kid":"57bd26a0-6209-4a93-a688-f8752be5d191","k":"eW91ci01MTItYml0LXNlY3JldA","alg":"HS512"}]}`
623+
token := "Bearer eyJhbGciOiJIUzUxMiIsInR5cCI6IkpXVCIsImNyaXQiOlsia2lkIl0sImtpZCI6IjU3YmQyNmEwLTYyMDktNGE5My1hNjg4LWY4NzUyYmU1ZDE5MSJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTUxNjIzOTAyMn0.573ixRAw4I4XUFJwJGpv5dHNOGaexX5zTtF0nOQTWuU2_JyZjD-7cuMPxQUHOv8RR0kQrS0uVdo_N1lzTCPFnA"
624+
jwksCalledCounter := 0
625+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
626+
defer func() { jwksCalledCounter++ }()
627+
w.WriteHeader(http.StatusOK)
628+
if jwksCalledCounter == 0 {
629+
fmt.Fprintln(w, `{"keys":[]}`)
630+
return
631+
}
632+
_, _ = fmt.Fprintln(w, keys)
633+
}))
634+
defer ts.Close()
635+
cfg := Config{
636+
Keys: []string{ts.URL},
637+
ForceRefreshKeys: true,
638+
}
639+
ctx := context.Background()
640+
nextCalled := false
641+
next := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { nextCalled = true })
642+
opa, err := New(ctx, next, &cfg, "test-traefik-jwt-plugin")
643+
if err != nil {
644+
t.Fatal(err)
645+
}
646+
647+
recorder := httptest.NewRecorder()
648+
649+
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://localhost", nil)
650+
if err != nil {
651+
t.Fatal(err)
652+
}
653+
req.Header.Add("Authorization", token)
654+
655+
opa.ServeHTTP(recorder, req)
656+
657+
resp := recorder.Result()
658+
if resp.StatusCode != http.StatusOK {
659+
t.Fatalf("Expected status code %d, received %d", http.StatusOK, resp.StatusCode)
660+
}
661+
if !nextCalled {
662+
t.Fatalf("next.ServeHTTP was called: %t, expected: %t", nextCalled, true)
663+
}
664+
if jwksCalledCounter != 2 {
665+
t.Fatalf("jwks was called: %d times, expected: %d", jwksCalledCounter, 2)
666+
}
667+
}
668+
621669
func TestIssue3(t *testing.T) {
622670
cfg := Config{
623671
JwtHeaders: map[string]string{"Subject": "sub", "User": "preferred_username"},

0 commit comments

Comments
 (0)