From 58dee99000f04044170dad7bcd2640dee6763344 Mon Sep 17 00:00:00 2001 From: Eric Chiang Date: Tue, 8 Dec 2015 16:25:35 -0800 Subject: [PATCH] oidc: fix race condition in provider config syncing Add a RWMutex around Client's providerConfig field. Syncing the provider config writes to the field, while numerous other actions read from it. Fixes #17 --- oidc/client.go | 67 ++++++++++++++++++++++++++-------------- oidc/client_race_test.go | 61 ++++++++++++++++++++++++++++++++++++ oidc/client_test.go | 37 +++++++++------------- test | 4 ++- 4 files changed, 122 insertions(+), 47 deletions(-) create mode 100644 oidc/client_race_test.go diff --git a/oidc/client.go b/oidc/client.go index 85a7af21..76330237 100644 --- a/oidc/client.go +++ b/oidc/client.go @@ -78,7 +78,7 @@ func NewClient(cfg ClientConfig) (*Client, error) { httpClient: cfg.HTTPClient, scope: cfg.Scope, redirectURL: ru.String(), - providerConfig: cfg.ProviderConfig, + providerConfig: newProviderConfigRepo(cfg.ProviderConfig), keySet: cfg.KeySet, } @@ -96,7 +96,7 @@ func NewClient(cfg ClientConfig) (*Client, error) { type Client struct { httpClient phttp.Client - providerConfig ProviderConfig + providerConfig *providerConfigRepo credentials ClientCredentials redirectURL string scope []string @@ -106,14 +106,39 @@ type Client struct { lastKeySetSync time.Time } +type providerConfigRepo struct { + mu sync.RWMutex + config ProviderConfig // do not access directly, use Get() +} + +func newProviderConfigRepo(pc ProviderConfig) *providerConfigRepo { + return &providerConfigRepo{sync.RWMutex{}, pc} +} + +// returns an error to implement ProviderConfigSetter +func (r *providerConfigRepo) Set(cfg ProviderConfig) error { + r.mu.Lock() + defer r.mu.Unlock() + r.config = cfg + return nil +} + +func (r *providerConfigRepo) Get() ProviderConfig { + r.mu.RLock() + defer r.mu.RUnlock() + return r.config +} + func (c *Client) Healthy() error { now := time.Now().UTC() - if c.providerConfig.Empty() { + cfg := c.providerConfig.Get() + + if cfg.Empty() { return errors.New("oidc client provider config empty") } - if !c.providerConfig.ExpiresAt.IsZero() && c.providerConfig.ExpiresAt.Before(now) { + if !cfg.ExpiresAt.IsZero() && cfg.ExpiresAt.Before(now) { return errors.New("oidc client provider config expired") } @@ -121,7 +146,8 @@ func (c *Client) Healthy() error { } func (c *Client) OAuthClient() (*oauth2.Client, error) { - authMethod, err := c.chooseAuthMethod() + cfg := c.providerConfig.Get() + authMethod, err := chooseAuthMethod(cfg) if err != nil { return nil, err } @@ -129,8 +155,8 @@ func (c *Client) OAuthClient() (*oauth2.Client, error) { ocfg := oauth2.Config{ Credentials: oauth2.ClientCredentials(c.credentials), RedirectURL: c.redirectURL, - AuthURL: c.providerConfig.AuthEndpoint, - TokenURL: c.providerConfig.TokenEndpoint, + AuthURL: cfg.AuthEndpoint, + TokenURL: cfg.TokenEndpoint, Scope: c.scope, AuthMethod: authMethod, } @@ -138,12 +164,12 @@ func (c *Client) OAuthClient() (*oauth2.Client, error) { return oauth2.NewClient(c.httpClient, ocfg) } -func (c *Client) chooseAuthMethod() (string, error) { - if len(c.providerConfig.TokenEndpointAuthMethodsSupported) == 0 { +func chooseAuthMethod(cfg ProviderConfig) (string, error) { + if len(cfg.TokenEndpointAuthMethodsSupported) == 0 { return oauth2.AuthMethodClientSecretBasic, nil } - for _, authMethod := range c.providerConfig.TokenEndpointAuthMethodsSupported { + for _, authMethod := range cfg.TokenEndpointAuthMethodsSupported { if _, ok := supportedAuthMethods[authMethod]; ok { return authMethod, nil } @@ -153,9 +179,8 @@ func (c *Client) chooseAuthMethod() (string, error) { } func (c *Client) SyncProviderConfig(discoveryURL string) chan struct{} { - rp := &providerConfigRepo{c} r := NewHTTPProviderConfigGetter(c.httpClient, discoveryURL) - return NewProviderConfigSyncer(r, rp).Run() + return NewProviderConfigSyncer(r, c.providerConfig).Run() } func (c *Client) maybeSyncKeys() error { @@ -178,7 +203,8 @@ func (c *Client) maybeSyncKeys() error { return nil } - r := NewRemotePublicKeyRepo(c.httpClient, c.providerConfig.KeysEndpoint) + cfg := c.providerConfig.Get() + r := NewRemotePublicKeyRepo(c.httpClient, cfg.KeysEndpoint) w := &clientKeyRepo{client: c} _, err := key.Sync(r, w) c.lastKeySetSync = time.Now().UTC() @@ -186,15 +212,6 @@ func (c *Client) maybeSyncKeys() error { return err } -type providerConfigRepo struct { - client *Client -} - -func (r *providerConfigRepo) Set(cfg ProviderConfig) error { - r.client.providerConfig = cfg - return nil -} - type clientKeyRepo struct { client *Client } @@ -209,7 +226,9 @@ func (r *clientKeyRepo) Set(ks key.KeySet) error { } func (c *Client) ClientCredsToken(scope []string) (jose.JWT, error) { - if !c.providerConfig.SupportsGrantType(oauth2.GrantTypeClientCreds) { + cfg := c.providerConfig.Get() + + if !cfg.SupportsGrantType(oauth2.GrantTypeClientCreds) { return jose.JWT{}, fmt.Errorf("%v grant type is not supported", oauth2.GrantTypeClientCreds) } @@ -280,7 +299,7 @@ func (c *Client) VerifyJWT(jwt jose.JWT) error { } v := NewJWTVerifier( - c.providerConfig.Issuer, + c.providerConfig.Get().Issuer, c.credentials.ID, c.maybeSyncKeys, keysFunc) diff --git a/oidc/client_race_test.go b/oidc/client_race_test.go new file mode 100644 index 00000000..5ba9b8e7 --- /dev/null +++ b/oidc/client_race_test.go @@ -0,0 +1,61 @@ +// This file contains tests which depend on the race detector being enabled. +// +build race + +package oidc + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +type testProvider struct { + baseURL string +} + +func (p *testProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != discoveryConfigPath { + http.NotFound(w, r) + return + } + + cfg := ProviderConfig{ + Issuer: p.baseURL, + } + json.NewEncoder(w).Encode(&cfg) +} + +// If this test fails by triggering the race detector, not by calling t.Error or t.Fatal. +func TestProviderSyncRace(t *testing.T) { + prov := &testProvider{} + + s := httptest.NewServer(prov) + defer s.Close() + prov.baseURL = s.URL + + cliCfg := ClientConfig{ + HTTPClient: http.DefaultClient, + ProviderConfig: ProviderConfig{Issuer: s.URL}, + } + cli, err := NewClient(cliCfg) + if err != nil { + t.Error(err) + return + } + + // SyncProviderConfig beings a goroutine which writes to the client's provider config. + c := cli.SyncProviderConfig(s.URL) + defer func() { + // stop the background process + c <- struct{}{} + }() + + // Iterate serveral times to increase the likelyhood of triggering the race detector. + for i := 0; i < 100; i++ { + time.Sleep(5 * time.Millisecond) + // Creating an OAuth client reads from the provider config. + cli.OAuthClient() + } +} diff --git a/oidc/client_test.go b/oidc/client_test.go index e66688fc..cfa9a47e 100644 --- a/oidc/client_test.go +++ b/oidc/client_test.go @@ -64,47 +64,42 @@ func TestHealthy(t *testing.T) { now := time.Now().UTC() tests := []struct { - c *Client + p ProviderConfig h bool }{ // all ok { - c: &Client{ - providerConfig: ProviderConfig{ - Issuer: "http://example.com", - ExpiresAt: now.Add(time.Hour), - }, + p: ProviderConfig{ + Issuer: "http://example.com", + ExpiresAt: now.Add(time.Hour), }, h: true, }, // zero-value ProviderConfig.ExpiresAt { - c: &Client{ - providerConfig: ProviderConfig{ - Issuer: "http://example.com", - }, + p: ProviderConfig{ + Issuer: "http://example.com", }, h: true, }, // expired ProviderConfig { - c: &Client{ - providerConfig: ProviderConfig{ - Issuer: "http://example.com", - ExpiresAt: now.Add(time.Hour * -1), - }, + p: ProviderConfig{ + Issuer: "http://example.com", + ExpiresAt: now.Add(time.Hour * -1), }, h: false, }, // empty ProviderConfig { - c: &Client{}, + p: ProviderConfig{}, h: false, }, } for i, tt := range tests { - err := tt.c.Healthy() + c := &Client{providerConfig: newProviderConfigRepo(tt.p)} + err := c.Healthy() want := tt.h got := (err == nil) @@ -347,12 +342,10 @@ func TestChooseAuthMethod(t *testing.T) { } for i, tt := range tests { - client := Client{ - providerConfig: ProviderConfig{ - TokenEndpointAuthMethodsSupported: tt.supported, - }, + cfg := ProviderConfig{ + TokenEndpointAuthMethodsSupported: tt.supported, } - got, err := client.chooseAuthMethod() + got, err := chooseAuthMethod(cfg) if tt.err { if err == nil { t.Errorf("case %d: expected non-nil err", i) diff --git a/test b/test index fa16d1ba..acb94c56 100755 --- a/test +++ b/test @@ -12,6 +12,8 @@ # Invoke ./cover for HTML output COVER=${COVER:-"-cover"} +RACE=${RACE:-"-race"} + source ./build TESTABLE="http jose key oauth2 oidc" @@ -37,7 +39,7 @@ split=(${TEST// / }) TEST=${split[@]/#/github.com/coreos/go-oidc/} echo "Running tests..." -go test ${COVER} $@ ${TEST} +go test $RACE ${COVER} $@ ${TEST} echo "Checking gofmt..." fmtRes=$(gofmt -l $FMT)