@@ -32,6 +32,7 @@ import (
32
32
"path"
33
33
"regexp"
34
34
"strings"
35
+ "sync"
35
36
"time"
36
37
37
38
gooidc "github.com/coreos/go-oidc/v3/oidc"
@@ -69,16 +70,23 @@ type OIDCStateStorage interface {
69
70
}
70
71
71
72
type Cache struct {
72
- OidcState map [ string ] * OIDCState
73
+ OidcState sync. Map
73
74
}
74
75
75
76
func (c * Cache ) GetOIDCState (key string ) (* OIDCState , error ) {
76
- state := c .OidcState [key ]
77
+ value , exists := c .OidcState .Load (key )
78
+ if ! exists {
79
+ return nil , ErrCacheMiss
80
+ }
81
+ state , ok := value .(* OIDCState )
82
+ if ! ok || state == nil {
83
+ return nil , ErrInvalidState
84
+ }
77
85
return state , nil
78
86
}
79
87
80
88
func (c * Cache ) SetOIDCState (key string , state * OIDCState ) error {
81
- c .OidcState [ key ] = state
89
+ c .OidcState . Store ( key , state )
82
90
return nil
83
91
}
84
92
@@ -287,12 +295,15 @@ func (a *ClientApp) generateAppState(returnURL string) string {
287
295
}
288
296
289
297
var ErrCacheMiss = errors .New ("cache: key is missing" )
298
+ var ErrInvalidState = errors .New ("invalid app state" )
290
299
291
300
func (a * ClientApp ) verifyAppState (state string ) (* OIDCState , error ) {
292
301
res , err := a .cache .GetOIDCState (state )
293
302
if err != nil {
294
- if err == ErrCacheMiss {
303
+ if errors . Is ( err , ErrCacheMiss ) {
295
304
return nil , fmt .Errorf ("unknown app state %s" , state )
305
+ } else if errors .Is (err , ErrInvalidState ) {
306
+ return nil , fmt .Errorf ("invalid app state %s" , state )
296
307
} else {
297
308
return nil , fmt .Errorf ("failed to verify app state %s: %v" , state , err )
298
309
}
0 commit comments