Skip to content

Commit

Permalink
oidc: fix race condition in provider config syncing
Browse files Browse the repository at this point in the history
Add a RWMutex around Client's providerConfig field. Syncing the
provider config writes to the field, while numerous other actions
read from it.

Fixes coreos#17
  • Loading branch information
ericchiang committed Dec 9, 2015
1 parent 562ae81 commit 58dee99
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 47 deletions.
67 changes: 43 additions & 24 deletions oidc/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ func NewClient(cfg ClientConfig) (*Client, error) {
httpClient: cfg.HTTPClient,
scope: cfg.Scope,
redirectURL: ru.String(),
providerConfig: cfg.ProviderConfig,
providerConfig: newProviderConfigRepo(cfg.ProviderConfig),
keySet: cfg.KeySet,
}

Expand All @@ -96,7 +96,7 @@ func NewClient(cfg ClientConfig) (*Client, error) {

type Client struct {
httpClient phttp.Client
providerConfig ProviderConfig
providerConfig *providerConfigRepo
credentials ClientCredentials
redirectURL string
scope []string
Expand All @@ -106,44 +106,70 @@ type Client struct {
lastKeySetSync time.Time
}

type providerConfigRepo struct {
mu sync.RWMutex
config ProviderConfig // do not access directly, use Get()
}

func newProviderConfigRepo(pc ProviderConfig) *providerConfigRepo {
return &providerConfigRepo{sync.RWMutex{}, pc}
}

// returns an error to implement ProviderConfigSetter
func (r *providerConfigRepo) Set(cfg ProviderConfig) error {
r.mu.Lock()
defer r.mu.Unlock()
r.config = cfg
return nil
}

func (r *providerConfigRepo) Get() ProviderConfig {
r.mu.RLock()
defer r.mu.RUnlock()
return r.config
}

func (c *Client) Healthy() error {
now := time.Now().UTC()

if c.providerConfig.Empty() {
cfg := c.providerConfig.Get()

if cfg.Empty() {
return errors.New("oidc client provider config empty")
}

if !c.providerConfig.ExpiresAt.IsZero() && c.providerConfig.ExpiresAt.Before(now) {
if !cfg.ExpiresAt.IsZero() && cfg.ExpiresAt.Before(now) {
return errors.New("oidc client provider config expired")
}

return nil
}

func (c *Client) OAuthClient() (*oauth2.Client, error) {
authMethod, err := c.chooseAuthMethod()
cfg := c.providerConfig.Get()
authMethod, err := chooseAuthMethod(cfg)
if err != nil {
return nil, err
}

ocfg := oauth2.Config{
Credentials: oauth2.ClientCredentials(c.credentials),
RedirectURL: c.redirectURL,
AuthURL: c.providerConfig.AuthEndpoint,
TokenURL: c.providerConfig.TokenEndpoint,
AuthURL: cfg.AuthEndpoint,
TokenURL: cfg.TokenEndpoint,
Scope: c.scope,
AuthMethod: authMethod,
}

return oauth2.NewClient(c.httpClient, ocfg)
}

func (c *Client) chooseAuthMethod() (string, error) {
if len(c.providerConfig.TokenEndpointAuthMethodsSupported) == 0 {
func chooseAuthMethod(cfg ProviderConfig) (string, error) {
if len(cfg.TokenEndpointAuthMethodsSupported) == 0 {
return oauth2.AuthMethodClientSecretBasic, nil
}

for _, authMethod := range c.providerConfig.TokenEndpointAuthMethodsSupported {
for _, authMethod := range cfg.TokenEndpointAuthMethodsSupported {
if _, ok := supportedAuthMethods[authMethod]; ok {
return authMethod, nil
}
Expand All @@ -153,9 +179,8 @@ func (c *Client) chooseAuthMethod() (string, error) {
}

func (c *Client) SyncProviderConfig(discoveryURL string) chan struct{} {
rp := &providerConfigRepo{c}
r := NewHTTPProviderConfigGetter(c.httpClient, discoveryURL)
return NewProviderConfigSyncer(r, rp).Run()
return NewProviderConfigSyncer(r, c.providerConfig).Run()
}

func (c *Client) maybeSyncKeys() error {
Expand All @@ -178,23 +203,15 @@ func (c *Client) maybeSyncKeys() error {
return nil
}

r := NewRemotePublicKeyRepo(c.httpClient, c.providerConfig.KeysEndpoint)
cfg := c.providerConfig.Get()
r := NewRemotePublicKeyRepo(c.httpClient, cfg.KeysEndpoint)
w := &clientKeyRepo{client: c}
_, err := key.Sync(r, w)
c.lastKeySetSync = time.Now().UTC()

return err
}

type providerConfigRepo struct {
client *Client
}

func (r *providerConfigRepo) Set(cfg ProviderConfig) error {
r.client.providerConfig = cfg
return nil
}

type clientKeyRepo struct {
client *Client
}
Expand All @@ -209,7 +226,9 @@ func (r *clientKeyRepo) Set(ks key.KeySet) error {
}

func (c *Client) ClientCredsToken(scope []string) (jose.JWT, error) {
if !c.providerConfig.SupportsGrantType(oauth2.GrantTypeClientCreds) {
cfg := c.providerConfig.Get()

if !cfg.SupportsGrantType(oauth2.GrantTypeClientCreds) {
return jose.JWT{}, fmt.Errorf("%v grant type is not supported", oauth2.GrantTypeClientCreds)
}

Expand Down Expand Up @@ -280,7 +299,7 @@ func (c *Client) VerifyJWT(jwt jose.JWT) error {
}

v := NewJWTVerifier(
c.providerConfig.Issuer,
c.providerConfig.Get().Issuer,
c.credentials.ID,
c.maybeSyncKeys, keysFunc)

Expand Down
61 changes: 61 additions & 0 deletions oidc/client_race_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// This file contains tests which depend on the race detector being enabled.
// +build race

package oidc

import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
)

type testProvider struct {
baseURL string
}

func (p *testProvider) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != discoveryConfigPath {
http.NotFound(w, r)
return
}

cfg := ProviderConfig{
Issuer: p.baseURL,
}
json.NewEncoder(w).Encode(&cfg)
}

// If this test fails by triggering the race detector, not by calling t.Error or t.Fatal.
func TestProviderSyncRace(t *testing.T) {
prov := &testProvider{}

s := httptest.NewServer(prov)
defer s.Close()
prov.baseURL = s.URL

cliCfg := ClientConfig{
HTTPClient: http.DefaultClient,
ProviderConfig: ProviderConfig{Issuer: s.URL},
}
cli, err := NewClient(cliCfg)
if err != nil {
t.Error(err)
return
}

// SyncProviderConfig beings a goroutine which writes to the client's provider config.
c := cli.SyncProviderConfig(s.URL)
defer func() {
// stop the background process
c <- struct{}{}
}()

// Iterate serveral times to increase the likelyhood of triggering the race detector.
for i := 0; i < 100; i++ {
time.Sleep(5 * time.Millisecond)
// Creating an OAuth client reads from the provider config.
cli.OAuthClient()
}
}
37 changes: 15 additions & 22 deletions oidc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,47 +64,42 @@ func TestHealthy(t *testing.T) {
now := time.Now().UTC()

tests := []struct {
c *Client
p ProviderConfig
h bool
}{
// all ok
{
c: &Client{
providerConfig: ProviderConfig{
Issuer: "http://example.com",
ExpiresAt: now.Add(time.Hour),
},
p: ProviderConfig{
Issuer: "http://example.com",
ExpiresAt: now.Add(time.Hour),
},
h: true,
},
// zero-value ProviderConfig.ExpiresAt
{
c: &Client{
providerConfig: ProviderConfig{
Issuer: "http://example.com",
},
p: ProviderConfig{
Issuer: "http://example.com",
},
h: true,
},
// expired ProviderConfig
{
c: &Client{
providerConfig: ProviderConfig{
Issuer: "http://example.com",
ExpiresAt: now.Add(time.Hour * -1),
},
p: ProviderConfig{
Issuer: "http://example.com",
ExpiresAt: now.Add(time.Hour * -1),
},
h: false,
},
// empty ProviderConfig
{
c: &Client{},
p: ProviderConfig{},
h: false,
},
}

for i, tt := range tests {
err := tt.c.Healthy()
c := &Client{providerConfig: newProviderConfigRepo(tt.p)}
err := c.Healthy()
want := tt.h
got := (err == nil)

Expand Down Expand Up @@ -347,12 +342,10 @@ func TestChooseAuthMethod(t *testing.T) {
}

for i, tt := range tests {
client := Client{
providerConfig: ProviderConfig{
TokenEndpointAuthMethodsSupported: tt.supported,
},
cfg := ProviderConfig{
TokenEndpointAuthMethodsSupported: tt.supported,
}
got, err := client.chooseAuthMethod()
got, err := chooseAuthMethod(cfg)
if tt.err {
if err == nil {
t.Errorf("case %d: expected non-nil err", i)
Expand Down
4 changes: 3 additions & 1 deletion test
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# Invoke ./cover for HTML output
COVER=${COVER:-"-cover"}

RACE=${RACE:-"-race"}

source ./build

TESTABLE="http jose key oauth2 oidc"
Expand All @@ -37,7 +39,7 @@ split=(${TEST// / })
TEST=${split[@]/#/github.com/coreos/go-oidc/}

echo "Running tests..."
go test ${COVER} $@ ${TEST}
go test $RACE ${COVER} $@ ${TEST}

echo "Checking gofmt..."
fmtRes=$(gofmt -l $FMT)
Expand Down

0 comments on commit 58dee99

Please sign in to comment.