diff --git a/oidc/client.go b/oidc/client.go index 85a7af21..76330237 100644 --- a/oidc/client.go +++ b/oidc/client.go @@ -78,7 +78,7 @@ func NewClient(cfg ClientConfig) (*Client, error) { httpClient: cfg.HTTPClient, scope: cfg.Scope, redirectURL: ru.String(), - providerConfig: cfg.ProviderConfig, + providerConfig: newProviderConfigRepo(cfg.ProviderConfig), keySet: cfg.KeySet, } @@ -96,7 +96,7 @@ func NewClient(cfg ClientConfig) (*Client, error) { type Client struct { httpClient phttp.Client - providerConfig ProviderConfig + providerConfig *providerConfigRepo credentials ClientCredentials redirectURL string scope []string @@ -106,14 +106,39 @@ type Client struct { lastKeySetSync time.Time } +type providerConfigRepo struct { + mu sync.RWMutex + config ProviderConfig // do not access directly, use Get() +} + +func newProviderConfigRepo(pc ProviderConfig) *providerConfigRepo { + return &providerConfigRepo{sync.RWMutex{}, pc} +} + +// returns an error to implement ProviderConfigSetter +func (r *providerConfigRepo) Set(cfg ProviderConfig) error { + r.mu.Lock() + defer r.mu.Unlock() + r.config = cfg + return nil +} + +func (r *providerConfigRepo) Get() ProviderConfig { + r.mu.RLock() + defer r.mu.RUnlock() + return r.config +} + func (c *Client) Healthy() error { now := time.Now().UTC() - if c.providerConfig.Empty() { + cfg := c.providerConfig.Get() + + if cfg.Empty() { return errors.New("oidc client provider config empty") } - if !c.providerConfig.ExpiresAt.IsZero() && c.providerConfig.ExpiresAt.Before(now) { + if !cfg.ExpiresAt.IsZero() && cfg.ExpiresAt.Before(now) { return errors.New("oidc client provider config expired") } @@ -121,7 +146,8 @@ func (c *Client) Healthy() error { } func (c *Client) OAuthClient() (*oauth2.Client, error) { - authMethod, err := c.chooseAuthMethod() + cfg := c.providerConfig.Get() + authMethod, err := chooseAuthMethod(cfg) if err != nil { return nil, err } @@ -129,8 +155,8 @@ func (c *Client) OAuthClient() (*oauth2.Client, error) { ocfg := oauth2.Config{ Credentials: oauth2.ClientCredentials(c.credentials), RedirectURL: c.redirectURL, - AuthURL: c.providerConfig.AuthEndpoint, - TokenURL: c.providerConfig.TokenEndpoint, + AuthURL: cfg.AuthEndpoint, + TokenURL: cfg.TokenEndpoint, Scope: c.scope, AuthMethod: authMethod, } @@ -138,12 +164,12 @@ func (c *Client) OAuthClient() (*oauth2.Client, error) { return oauth2.NewClient(c.httpClient, ocfg) } -func (c *Client) chooseAuthMethod() (string, error) { - if len(c.providerConfig.TokenEndpointAuthMethodsSupported) == 0 { +func chooseAuthMethod(cfg ProviderConfig) (string, error) { + if len(cfg.TokenEndpointAuthMethodsSupported) == 0 { return oauth2.AuthMethodClientSecretBasic, nil } - for _, authMethod := range c.providerConfig.TokenEndpointAuthMethodsSupported { + for _, authMethod := range cfg.TokenEndpointAuthMethodsSupported { if _, ok := supportedAuthMethods[authMethod]; ok { return authMethod, nil } @@ -153,9 +179,8 @@ func (c *Client) chooseAuthMethod() (string, error) { } func (c *Client) SyncProviderConfig(discoveryURL string) chan struct{} { - rp := &providerConfigRepo{c} r := NewHTTPProviderConfigGetter(c.httpClient, discoveryURL) - return NewProviderConfigSyncer(r, rp).Run() + return NewProviderConfigSyncer(r, c.providerConfig).Run() } func (c *Client) maybeSyncKeys() error { @@ -178,7 +203,8 @@ func (c *Client) maybeSyncKeys() error { return nil } - r := NewRemotePublicKeyRepo(c.httpClient, c.providerConfig.KeysEndpoint) + cfg := c.providerConfig.Get() + r := NewRemotePublicKeyRepo(c.httpClient, cfg.KeysEndpoint) w := &clientKeyRepo{client: c} _, err := key.Sync(r, w) c.lastKeySetSync = time.Now().UTC() @@ -186,15 +212,6 @@ func (c *Client) maybeSyncKeys() error { return err } -type providerConfigRepo struct { - client *Client -} - -func (r *providerConfigRepo) Set(cfg ProviderConfig) error { - r.client.providerConfig = cfg - return nil -} - type clientKeyRepo struct { client *Client } @@ -209,7 +226,9 @@ func (r *clientKeyRepo) Set(ks key.KeySet) error { } func (c *Client) ClientCredsToken(scope []string) (jose.JWT, error) { - if !c.providerConfig.SupportsGrantType(oauth2.GrantTypeClientCreds) { + cfg := c.providerConfig.Get() + + if !cfg.SupportsGrantType(oauth2.GrantTypeClientCreds) { return jose.JWT{}, fmt.Errorf("%v grant type is not supported", oauth2.GrantTypeClientCreds) } @@ -280,7 +299,7 @@ func (c *Client) VerifyJWT(jwt jose.JWT) error { } v := NewJWTVerifier( - c.providerConfig.Issuer, + c.providerConfig.Get().Issuer, c.credentials.ID, c.maybeSyncKeys, keysFunc) diff --git a/oidc/client_race_test.go b/oidc/client_race_test.go new file mode 100644 index 00000000..0993f3ab --- /dev/null +++ b/oidc/client_race_test.go @@ -0,0 +1,70 @@ +// This file contains tests which depend on the race detector being enabled. +// +build race + +package oidc + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +type testProvider struct { + baseURL string +} + +func (p *testProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != discoveryConfigPath { + http.NotFound(w, r) + return + } + + cfg := ProviderConfig{ + Issuer: p.baseURL, + } + json.NewEncoder(w).Encode(&cfg) +} + +// This test fails by triggering the race detector, not by calling t.Error or t.Fatal. +func TestProviderSyncRace(t *testing.T) { + + prov := &testProvider{} + + s := httptest.NewServer(prov) + defer s.Close() + prov.baseURL = s.URL + + prevValue := minimumProviderConfigSyncInterval + defer func() { minimumProviderConfigSyncInterval = prevValue }() + + // Reduce the sync interval to increase the write frequencey. + minimumProviderConfigSyncInterval = 5 * time.Millisecond + + cliCfg := ClientConfig{ + HTTPClient: http.DefaultClient, + ProviderConfig: ProviderConfig{ + Issuer: s.URL, + ExpiresAt: time.Now().Add(time.Minute), // Must expire to trigger frequent syncs. + }, + } + cli, err := NewClient(cliCfg) + if err != nil { + t.Error(err) + return + } + + // SyncProviderConfig beings a goroutine which writes to the client's provider config. + c := cli.SyncProviderConfig(s.URL) + defer func() { + // stop the background process + c <- struct{}{} + }() + + for i := 0; i < 10; i++ { + time.Sleep(5 * time.Millisecond) + // Creating an OAuth client reads from the provider config. + cli.OAuthClient() + } +} diff --git a/oidc/client_test.go b/oidc/client_test.go index e66688fc..cfa9a47e 100644 --- a/oidc/client_test.go +++ b/oidc/client_test.go @@ -64,47 +64,42 @@ func TestHealthy(t *testing.T) { now := time.Now().UTC() tests := []struct { - c *Client + p ProviderConfig h bool }{ // all ok { - c: &Client{ - providerConfig: ProviderConfig{ - Issuer: "http://example.com", - ExpiresAt: now.Add(time.Hour), - }, + p: ProviderConfig{ + Issuer: "http://example.com", + ExpiresAt: now.Add(time.Hour), }, h: true, }, // zero-value ProviderConfig.ExpiresAt { - c: &Client{ - providerConfig: ProviderConfig{ - Issuer: "http://example.com", - }, + p: ProviderConfig{ + Issuer: "http://example.com", }, h: true, }, // expired ProviderConfig { - c: &Client{ - providerConfig: ProviderConfig{ - Issuer: "http://example.com", - ExpiresAt: now.Add(time.Hour * -1), - }, + p: ProviderConfig{ + Issuer: "http://example.com", + ExpiresAt: now.Add(time.Hour * -1), }, h: false, }, // empty ProviderConfig { - c: &Client{}, + p: ProviderConfig{}, h: false, }, } for i, tt := range tests { - err := tt.c.Healthy() + c := &Client{providerConfig: newProviderConfigRepo(tt.p)} + err := c.Healthy() want := tt.h got := (err == nil) @@ -347,12 +342,10 @@ func TestChooseAuthMethod(t *testing.T) { } for i, tt := range tests { - client := Client{ - providerConfig: ProviderConfig{ - TokenEndpointAuthMethodsSupported: tt.supported, - }, + cfg := ProviderConfig{ + TokenEndpointAuthMethodsSupported: tt.supported, } - got, err := client.chooseAuthMethod() + got, err := chooseAuthMethod(cfg) if tt.err { if err == nil { t.Errorf("case %d: expected non-nil err", i) diff --git a/oidc/provider.go b/oidc/provider.go index bfdeef3f..f911e39e 100644 --- a/oidc/provider.go +++ b/oidc/provider.go @@ -25,6 +25,9 @@ const ( discoveryConfigPath = "/.well-known/openid-configuration" ) +// internally configurable for tests +var minimumProviderConfigSyncInterval = MinimumProviderConfigSyncInterval + type ProviderConfig struct { Issuer string `json:"issuer"` AuthEndpoint string `json:"authorization_endpoint"` @@ -172,8 +175,8 @@ func nextSyncAfter(exp time.Time, clock clockwork.Clock) time.Duration { t := exp.Sub(clock.Now()) / 2 if t > MaximumProviderConfigSyncInterval { t = MaximumProviderConfigSyncInterval - } else if t < MinimumProviderConfigSyncInterval { - t = MinimumProviderConfigSyncInterval + } else if t < minimumProviderConfigSyncInterval { + t = minimumProviderConfigSyncInterval } return t diff --git a/test b/test index fa16d1ba..acb94c56 100755 --- a/test +++ b/test @@ -12,6 +12,8 @@ # Invoke ./cover for HTML output COVER=${COVER:-"-cover"} +RACE=${RACE:-"-race"} + source ./build TESTABLE="http jose key oauth2 oidc" @@ -37,7 +39,7 @@ split=(${TEST// / }) TEST=${split[@]/#/github.com/coreos/go-oidc/} echo "Running tests..." -go test ${COVER} $@ ${TEST} +go test $RACE ${COVER} $@ ${TEST} echo "Checking gofmt..." fmtRes=$(gofmt -l $FMT)