From beb36052af2221d7ff238edc4c98c733cac2999d Mon Sep 17 00:00:00 2001 From: Eno Compton Date: Tue, 21 May 2024 11:40:32 -0600 Subject: [PATCH] fix: remove duplicate refresh operations (#806) When callers (like the Cloud SQL Proxy) warm up the background refresh with EngineVersion, the initial refresh operation starts with IAM authentication disabled. Then when a caller connects with IAM authentication, the existing refresh is invalidated (because it doesn't include the client's OAuth2 token). In effect, two refresh operations are completed when the dialer could have just run one initially. This commit ensures calls to EngineVersion respect the dialer's global IAM authentication setting. If IAM authentication is enabled at the dialer level, then EngineVersion will ensure the refresh operation uses IAM authentication, so only one refresh operation occurs between warmup and the first connection. In cases where a dialer is initialized without IAM authentication, but then a call to dial requests IAM authentication, a second refresh is unavoidable. This seems to be an uncommon enough use case that it is an acceptable tradeoff given how IAM authentication is tightly coupled with the client certificate refresh. Separately, when the Cloud SQL Proxy is started with the --token flag in combination with IAM authentication, the underlying token does not have a corresponding refresh token and so cannot be refreshed. As a result, when the double calls occur (as described above), a third refresh attempt is started because the token is missing its expiration field (there is only an AccessToken, with no RefreshToken or Expiry). The dialer sees the missing expiration as an expired client certificate and immediately starts a new refresh. But because the dialer has already consumed two attempts, the rate limiter (2 initial attempts, then 30s/attempt) forces the client to wait 30s before connecting. 
This commit ensures that situation will not happen by using the correct client certificate expiration and not the invalid token expiration. For cases where the Cloud SQL Proxy configures the dialer with the --token flag (and no refresh token), the dialer will always default to using the client certificate's expiration. This means the refresh cycle will fail once the token expires. Fixes #771 --- dialer.go | 2 +- dialer_test.go | 49 +++++++++++++++++++++++++++---- internal/cloudsql/instance.go | 7 ++++- internal/cloudsql/refresh.go | 19 +++++++++--- internal/cloudsql/refresh_test.go | 35 ++++++++++++++++++++++ 5 files changed, 101 insertions(+), 11 deletions(-) diff --git a/dialer.go b/dialer.go index 7f0bbbd1..c93dbade 100644 --- a/dialer.go +++ b/dialer.go @@ -389,7 +389,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error) if err != nil { return "", err } - i := d.connectionInfoCache(ctx, cn, nil) + i := d.connectionInfoCache(ctx, cn, &d.defaultDialConfig.useIAMAuthN) ci, err := i.ConnectionInfo(ctx) if err != nil { return "", err diff --git a/dialer_test.go b/dialer_test.go index 9d508025..051e9b92 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -396,6 +396,38 @@ func TestDialerEngineVersion(t *testing.T) { } } +// When Auto IAM AuthN is enabled, EngineVersion should warm the cache with a +// client certificate with Auto IAM AuthN enabled. 
+func TestEngineVersionAvoidsDuplicateRefreshWithIAMAuthN(t *testing.T) { + inst := mock.NewFakeCSQLInstance( + "my-project", "my-region", "my-instance", + ) + d := setupDialer(t, setupConfig{ + testInstance: inst, + dialerOptions: []Option{ + WithIAMAuthN(), WithIAMAuthNTokenSources( + mock.EmptyTokenSource{}, + mock.EmptyTokenSource{}, + ), + }, + reqs: []*mock.Request{ + // There should only be two API requests + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + }, + }) + + _, err := d.EngineVersion(context.Background(), inst.String()) + if err != nil { + t.Fatal(err) + } + + testSuccessfulDial( + context.Background(), t, d, + inst.String(), + ) +} + func TestDialerUserAgent(t *testing.T) { data, err := os.ReadFile("version.txt") if err != nil { @@ -420,7 +452,7 @@ func TestWarmup(t *testing.T) { expectedCalls []*mock.Request }{ { - desc: "warmup and dial are the same", + desc: "Warmup and Dial both use IAM AuthN", warmupOpts: []DialOption{WithDialIAMAuthN(true)}, dialOpts: []DialOption{WithDialIAMAuthN(true)}, expectedCalls: []*mock.Request{ @@ -429,7 +461,7 @@ func TestWarmup(t *testing.T) { }, }, { - desc: "warmup and dial are different", + desc: "Warmup uses IAM Authn, Dial does not", warmupOpts: []DialOption{WithDialIAMAuthN(true)}, dialOpts: []DialOption{WithDialIAMAuthN(false)}, expectedCalls: []*mock.Request{ @@ -438,12 +470,12 @@ func TestWarmup(t *testing.T) { }, }, { - desc: "warmup and default dial are different", + desc: "Warmup uses IAM AuthN, Dial uses global setting", warmupOpts: []DialOption{WithDialIAMAuthN(true)}, dialOpts: []DialOption{}, expectedCalls: []*mock.Request{ - mock.InstanceGetSuccess(inst, 2), - mock.CreateEphemeralSuccess(inst, 2), + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), }, }, } @@ -451,6 +483,13 @@ func TestWarmup(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { d := setupDialer(t, setupConfig{ + dialerOptions: []Option{ + 
WithIAMAuthN(), + WithIAMAuthNTokenSources( + mock.EmptyTokenSource{}, + mock.EmptyTokenSource{}, + ), + }, testInstance: inst, reqs: test.expectedCalls, }) diff --git a/internal/cloudsql/instance.go b/internal/cloudsql/instance.go index 78620dec..bebce43a 100644 --- a/internal/cloudsql/instance.go +++ b/internal/cloudsql/instance.go @@ -393,8 +393,13 @@ func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation { nil, ) } else { + var useIAMAuthN bool + i.mu.Lock() + useIAMAuthN = i.useIAMAuthNDial + i.mu.Unlock() r.result, r.err = i.r.ConnectionInfo( - ctx, i.connName, i.key, i.useIAMAuthNDial) + ctx, i.connName, i.key, useIAMAuthN, + ) } switch r.err { case nil: diff --git a/internal/cloudsql/refresh.go b/internal/cloudsql/refresh.go index c2056561..f13874f2 100644 --- a/internal/cloudsql/refresh.go +++ b/internal/cloudsql/refresh.go @@ -132,12 +132,23 @@ func fetchMetadata( return m, nil } +var expired = time.Time{}.Add(1) + +// canRefresh determines if the provided token was refreshed or if it still has +// the sentinel expiration, which means the token was provided without a +// refresh token (as with the Cloud SQL Proxy's --token flag) and therefore +// cannot be refreshed. +func canRefresh(t *oauth2.Token) bool { + return t.Expiry.Unix() != expired.Unix() +} + +// refreshToken will retrieve a new token, only if a refresh token is present. func refreshToken(ts oauth2.TokenSource, tok *oauth2.Token) (*oauth2.Token, error) { expiredToken := &oauth2.Token{ AccessToken: tok.AccessToken, TokenType: tok.TokenType, RefreshToken: tok.RefreshToken, - Expiry: time.Time{}.Add(1), // Expired + Expiry: expired, } return oauth2.ReuseTokenSource(expiredToken, ts).Token() } @@ -217,9 +228,9 @@ func fetchEphemeralCert( ) } if ts != nil { - // Adjust the certificate's expiration to be the earliest of the token's - // expiration or the certificate's expiration. 
- if tok.Expiry.Before(clientCert.NotAfter) { + // Adjust the certificate's expiration to be the earliest of + // the token's expiration or the certificate's expiration. + if canRefresh(tok) && tok.Expiry.Before(clientCert.NotAfter) { clientCert.NotAfter = tok.Expiry } } diff --git a/internal/cloudsql/refresh_test.go b/internal/cloudsql/refresh_test.go index 7959bef4..2daadf5a 100644 --- a/internal/cloudsql/refresh_test.go +++ b/internal/cloudsql/refresh_test.go @@ -96,6 +96,41 @@ func TestRefresh(t *testing.T) { t.Fatalf("expiry mismatch, want = %v, got = %v", wantExpiry, rr.Expiration) } } + +// If a caller has provided a static token source that cannot be refreshed +// (e.g., when the Cloud SQL Proxy is invoked with --token), then the +// refresher cannot determine the token's expiration (without additional API +// calls), and so the refresher should use the certificate's expiration instead +// of the token's expiration which is otherwise unset. +func TestRefreshWithStaticTokenSource(t *testing.T) { + cn := testInstanceConnName() + inst := mock.NewFakeCSQLInstance( + cn.Project(), cn.Region(), cn.Name(), + ) + client, cleanup, err := mock.NewSQLAdminService( + context.Background(), + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + ) + if err != nil { + t.Fatalf("failed to create test SQL admin service: %s", err) + } + t.Cleanup(func() { _ = cleanup() }) + + ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "myaccestoken"}) + r := newRefresher(nullLogger{}, client, ts, testDialerID) + ci, err := r.ConnectionInfo(context.Background(), cn, RSAKey, true) + if err != nil { + t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err) + } + if !ci.Expiration.After(time.Now()) { + t.Fatalf( + "Connection info expiration should be in the future, got = %v", + ci.Expiration, + ) + } +} + func TestRefreshRetries50xResponses(t *testing.T) { cn := testInstanceConnName() inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), 
cn.Name(),