diff --git a/dialer.go b/dialer.go index 7f0bbbd1..c93dbade 100644 --- a/dialer.go +++ b/dialer.go @@ -389,7 +389,7 @@ func (d *Dialer) EngineVersion(ctx context.Context, icn string) (string, error) if err != nil { return "", err } - i := d.connectionInfoCache(ctx, cn, nil) + i := d.connectionInfoCache(ctx, cn, &d.defaultDialConfig.useIAMAuthN) ci, err := i.ConnectionInfo(ctx) if err != nil { return "", err diff --git a/dialer_test.go b/dialer_test.go index 9d508025..051e9b92 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -396,6 +396,38 @@ func TestDialerEngineVersion(t *testing.T) { } } +// When Auto IAM AuthN is enabled, EngineVersion should warm the cache with a +// client certificate with Auto IAM AuthN enabled. +func TestEngineVersionAvoidsDuplicateRefreshWithIAMAuthN(t *testing.T) { + inst := mock.NewFakeCSQLInstance( + "my-project", "my-region", "my-instance", + ) + d := setupDialer(t, setupConfig{ + testInstance: inst, + dialerOptions: []Option{ + WithIAMAuthN(), WithIAMAuthNTokenSources( + mock.EmptyTokenSource{}, + mock.EmptyTokenSource{}, + ), + }, + reqs: []*mock.Request{ + // There should only be two API requests + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + }, + }) + + _, err := d.EngineVersion(context.Background(), inst.String()) + if err != nil { + t.Fatal(err) + } + + testSuccessfulDial( + context.Background(), t, d, + inst.String(), + ) +} + func TestDialerUserAgent(t *testing.T) { data, err := os.ReadFile("version.txt") if err != nil { @@ -420,7 +452,7 @@ func TestWarmup(t *testing.T) { expectedCalls []*mock.Request }{ { - desc: "warmup and dial are the same", + desc: "Warmup and Dial both use IAM AuthN", warmupOpts: []DialOption{WithDialIAMAuthN(true)}, dialOpts: []DialOption{WithDialIAMAuthN(true)}, expectedCalls: []*mock.Request{ @@ -429,7 +461,7 @@ func TestWarmup(t *testing.T) { }, }, { - desc: "warmup and dial are different", + desc: "Warmup uses IAM Authn, Dial does not", warmupOpts: []DialOption{WithDialIAMAuthN(true)}, dialOpts: []DialOption{WithDialIAMAuthN(false)}, expectedCalls: []*mock.Request{ @@ -438,12 +470,12 @@ func TestWarmup(t *testing.T) { }, }, { - desc: "warmup and default dial are different", + desc: "Warmup uses IAM AuthN, Dial uses global setting", warmupOpts: []DialOption{WithDialIAMAuthN(true)}, dialOpts: []DialOption{}, expectedCalls: []*mock.Request{ - mock.InstanceGetSuccess(inst, 2), - mock.CreateEphemeralSuccess(inst, 2), + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), }, }, } @@ -451,6 +483,13 @@ func TestWarmup(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { d := setupDialer(t, setupConfig{ + dialerOptions: []Option{ + WithIAMAuthN(), + WithIAMAuthNTokenSources( + mock.EmptyTokenSource{}, + mock.EmptyTokenSource{}, + ), + }, testInstance: inst, reqs: test.expectedCalls, }) diff --git a/internal/cloudsql/instance.go b/internal/cloudsql/instance.go index 78620dec..bebce43a 100644 --- a/internal/cloudsql/instance.go +++ b/internal/cloudsql/instance.go @@ -393,8 +393,13 @@ func (i *RefreshAheadCache) scheduleRefresh(d time.Duration) *refreshOperation { nil, ) } else { + var useIAMAuthN bool + i.mu.Lock() + useIAMAuthN = i.useIAMAuthNDial + i.mu.Unlock() r.result, r.err = i.r.ConnectionInfo( - ctx, i.connName, i.key, i.useIAMAuthNDial) + ctx, i.connName, i.key, useIAMAuthN, + ) } switch r.err { case nil: diff --git a/internal/cloudsql/refresh.go b/internal/cloudsql/refresh.go index c2056561..f13874f2 100644 --- a/internal/cloudsql/refresh.go +++ b/internal/cloudsql/refresh.go @@ -132,12 +132,23 @@ func fetchMetadata( return m, nil } +var expired = time.Time{}.Add(1) + +// canRefresh determines if the provided token was refreshed or if it still has +// the sentinel expiration, which means the token was provided without a +// refresh token (as with the Cloud SQL Proxy's --token flag) and therefore +// cannot be refreshed. +func canRefresh(t *oauth2.Token) bool { + return t.Expiry.Unix() != expired.Unix() +} + +// refreshToken will retrieve a new token, only if a refresh token is present. func refreshToken(ts oauth2.TokenSource, tok *oauth2.Token) (*oauth2.Token, error) { expiredToken := &oauth2.Token{ AccessToken: tok.AccessToken, TokenType: tok.TokenType, RefreshToken: tok.RefreshToken, - Expiry: time.Time{}.Add(1), // Expired + Expiry: expired, } return oauth2.ReuseTokenSource(expiredToken, ts).Token() } @@ -217,9 +228,9 @@ func fetchEphemeralCert( ) } if ts != nil { - // Adjust the certificate's expiration to be the earliest of the token's - // expiration or the certificate's expiration. - if tok.Expiry.Before(clientCert.NotAfter) { + // Adjust the certificate's expiration to be the earliest of + // the token's expiration or the certificate's expiration. + if canRefresh(tok) && tok.Expiry.Before(clientCert.NotAfter) { clientCert.NotAfter = tok.Expiry } } diff --git a/internal/cloudsql/refresh_test.go b/internal/cloudsql/refresh_test.go index 7959bef4..2daadf5a 100644 --- a/internal/cloudsql/refresh_test.go +++ b/internal/cloudsql/refresh_test.go @@ -96,6 +96,41 @@ func TestRefresh(t *testing.T) { t.Fatalf("expiry mismatch, want = %v, got = %v", wantExpiry, rr.Expiration) } } + +// If a caller has provided a static token source that cannot be refreshed +// (e.g., when the Cloud SQL Proxy is invokved with --token), then the +// refresher cannot determine the token's expiration (without additional API +// calls), and so the refresher should use the certificate's expiration instead +// of the token's expiration which is otherwise unset. +func TestRefreshWithStaticTokenSource(t *testing.T) { + cn := testInstanceConnName() + inst := mock.NewFakeCSQLInstance( + cn.Project(), cn.Region(), cn.Name(), + ) + client, cleanup, err := mock.NewSQLAdminService( + context.Background(), + mock.InstanceGetSuccess(inst, 1), + mock.CreateEphemeralSuccess(inst, 1), + ) + if err != nil { + t.Fatalf("failed to create test SQL admin service: %s", err) + } + t.Cleanup(func() { _ = cleanup() }) + + ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "myaccestoken"}) + r := newRefresher(nullLogger{}, client, ts, testDialerID) + ci, err := r.ConnectionInfo(context.Background(), cn, RSAKey, true) + if err != nil { + t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err) + } + if !ci.Expiration.After(time.Now()) { + t.Fatalf( + "Connection info expiration should be in the future, got = %v", + ci.Expiration, + ) + } +} + func TestRefreshRetries50xResponses(t *testing.T) { cn := testInstanceConnName() inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name(),