Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: remove duplicate refresh operations #806

Merged
merged 1 commit into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error)
if err != nil {
return "", err
}
i := d.connectionInfoCache(ctx, cn, nil)
i := d.connectionInfoCache(ctx, cn, &d.defaultDialConfig.useIAMAuthN)
ci, err := i.ConnectionInfo(ctx)
if err != nil {
return "", err
Expand Down
49 changes: 44 additions & 5 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,38 @@ func TestDialerEngineVersion(t *testing.T) {
}
}

// When Auto IAM AuthN is enabled, EngineVersion should warm the cache with a
// client certificate with Auto IAM AuthN enabled.
func TestEngineVersionAvoidsDuplicateRefreshWithIAMAuthN(t *testing.T) {
inst := mock.NewFakeCSQLInstance(
"my-project", "my-region", "my-instance",
)
d := setupDialer(t, setupConfig{
testInstance: inst,
dialerOptions: []Option{
WithIAMAuthN(), WithIAMAuthNTokenSources(
mock.EmptyTokenSource{},
mock.EmptyTokenSource{},
),
},
reqs: []*mock.Request{
// There should only be two API requests
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
},
})

_, err := d.EngineVersion(context.Background(), inst.String())
if err != nil {
t.Fatal(err)
}

testSuccessfulDial(
context.Background(), t, d,
inst.String(),
)
}

func TestDialerUserAgent(t *testing.T) {
data, err := os.ReadFile("version.txt")
if err != nil {
Expand All @@ -420,7 +452,7 @@ func TestWarmup(t *testing.T) {
expectedCalls []*mock.Request
}{
{
desc: "warmup and dial are the same",
desc: "Warmup and Dial both use IAM AuthN",
warmupOpts: []DialOption{WithDialIAMAuthN(true)},
dialOpts: []DialOption{WithDialIAMAuthN(true)},
expectedCalls: []*mock.Request{
Expand All @@ -429,7 +461,7 @@ func TestWarmup(t *testing.T) {
},
},
{
desc: "warmup and dial are different",
desc: "Warmup uses IAM Authn, Dial does not",
warmupOpts: []DialOption{WithDialIAMAuthN(true)},
dialOpts: []DialOption{WithDialIAMAuthN(false)},
expectedCalls: []*mock.Request{
Expand All @@ -438,19 +470,26 @@ func TestWarmup(t *testing.T) {
},
},
{
desc: "warmup and default dial are different",
desc: "Warmup uses IAM AuthN, Dial uses global setting",
warmupOpts: []DialOption{WithDialIAMAuthN(true)},
dialOpts: []DialOption{},
expectedCalls: []*mock.Request{
mock.InstanceGetSuccess(inst, 2),
mock.CreateEphemeralSuccess(inst, 2),
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
},
},
}

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
d := setupDialer(t, setupConfig{
dialerOptions: []Option{
WithIAMAuthN(),
WithIAMAuthNTokenSources(
mock.EmptyTokenSource{},
mock.EmptyTokenSource{},
),
},
testInstance: inst,
reqs: test.expectedCalls,
})
Expand Down
7 changes: 6 additions & 1 deletion internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,13 @@ func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation {
nil,
)
} else {
var useIAMAuthN bool
i.mu.Lock()
useIAMAuthN = i.useIAMAuthNDial
i.mu.Unlock()
Comment on lines +397 to +399
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need a lock around this?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

During my testing, this triggered a race condition between the read here and the write in UpdateRefresh. I wrapped it with a lock to prevent any data races in the case that a refresh is scheduled at the same time a refresh was being updated.

r.result, r.err = i.r.ConnectionInfo(
ctx, i.connName, i.key, i.useIAMAuthNDial)
ctx, i.connName, i.key, useIAMAuthN,
)
}
switch r.err {
case nil:
Expand Down
19 changes: 15 additions & 4 deletions internal/cloudsql/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,23 @@ func fetchMetadata(
return m, nil
}

var expired = time.Time{}.Add(1)

// canRefresh determines if the provided token was refreshed or if it still has
// the sentinel expiration, which means the token was provided without a
// refresh token (as with the Cloud SQL Proxy's --token flag) and therefore
// cannot be refreshed.
func canRefresh(t *oauth2.Token) bool {
return t.Expiry.Unix() != expired.Unix()
}

// refreshToken will retrieve a new token, only if a refresh token is present.
func refreshToken(ts oauth2.TokenSource, tok *oauth2.Token) (*oauth2.Token, error) {
expiredToken := &oauth2.Token{
AccessToken: tok.AccessToken,
TokenType: tok.TokenType,
RefreshToken: tok.RefreshToken,
Expiry: time.Time{}.Add(1), // Expired
Expiry: expired,
}
return oauth2.ReuseTokenSource(expiredToken, ts).Token()
}
Expand Down Expand Up @@ -217,9 +228,9 @@ func fetchEphemeralCert(
)
}
if ts != nil {
// Adjust the certificate's expiration to be the earliest of the token's
// expiration or the certificate's expiration.
if tok.Expiry.Before(clientCert.NotAfter) {
// Adjust the certificate's expiration to be the earliest of
// the token's expiration or the certificate's expiration.
if canRefresh(tok) && tok.Expiry.Before(clientCert.NotAfter) {
clientCert.NotAfter = tok.Expiry
}
}
Expand Down
35 changes: 35 additions & 0 deletions internal/cloudsql/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,41 @@ func TestRefresh(t *testing.T) {
t.Fatalf("expiry mismatch, want = %v, got = %v", wantExpiry, rr.Expiration)
}
}

// If a caller has provided a static token source that cannot be refreshed
// (e.g., when the Cloud SQL Proxy is invokved with --token), then the
// refresher cannot determine the token's expiration (without additional API
// calls), and so the refresher should use the certificate's expiration instead
// of the token's expiration which is otherwise unset.
func TestRefreshWithStaticTokenSource(t *testing.T) {
cn := testInstanceConnName()
inst := mock.NewFakeCSQLInstance(
cn.Project(), cn.Region(), cn.Name(),
)
client, cleanup, err := mock.NewSQLAdminService(
context.Background(),
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
)
if err != nil {
t.Fatalf("failed to create test SQL admin service: %s", err)
}
t.Cleanup(func() { _ = cleanup() })

ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "myaccestoken"})
r := newRefresher(nullLogger{}, client, ts, testDialerID)
ci, err := r.ConnectionInfo(context.Background(), cn, RSAKey, true)
if err != nil {
t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err)
}
if !ci.Expiration.After(time.Now()) {
t.Fatalf(
"Connection info expiration should be in the future, got = %v",
ci.Expiration,
)
}
}

func TestRefreshRetries50xResponses(t *testing.T) {
cn := testInstanceConnName()
inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name(),
Expand Down
Loading