diff --git a/dialer.go b/dialer.go index 269c9a15..e20c0641 100644 --- a/dialer.go +++ b/dialer.go @@ -24,11 +24,15 @@ import ( "fmt" "io" "net" + "os" "strings" "sync" "sync/atomic" "time" + "cloud.google.com/go/auth" + "cloud.google.com/go/auth/credentials" + "cloud.google.com/go/auth/httptransport" "cloud.google.com/go/cloudsqlconn/debug" "cloud.google.com/go/cloudsqlconn/errtype" "cloud.google.com/go/cloudsqlconn/instance" @@ -36,8 +40,6 @@ import ( "cloud.google.com/go/cloudsqlconn/internal/trace" "github.com/google/uuid" "golang.org/x/net/proxy" - "golang.org/x/oauth2" - "golang.org/x/oauth2/google" "google.golang.org/api/option" sqladmin "google.golang.org/api/sqladmin/v1beta4" ) @@ -50,6 +52,12 @@ const ( // iamLoginScope is the OAuth2 scope used for tokens embedded in the ephemeral // certificate. iamLoginScope = "https://www.googleapis.com/auth/sqlservice.login" + // universeDomainEnvVar is the environment variable for setting the default + // service domain for a given Cloud universe. + universeDomainEnvVar = "GOOGLE_CLOUD_UNIVERSE_DOMAIN" + // defaultUniverseDomain is the default value for universe domain. + // Universe domain is the default service domain for a given Cloud universe. + defaultUniverseDomain = "googleapis.com" ) var ( @@ -117,6 +125,25 @@ type cacheKey struct { name string } +// getClientUniverseDomain returns the default service domain for a given Cloud +// universe, with the following precedence: +// +// 1. A non-empty option.WithUniverseDomain or similar client option. +// 2. A non-empty environment variable GOOGLE_CLOUD_UNIVERSE_DOMAIN. +// 3. The default value "googleapis.com". +// +// This is the universe domain configured for the client, which will be compared +// to the universe domain that is separately configured for the credentials. +func (c *dialerConfig) getClientUniverseDomain() string { + if c.clientUniverseDomain != "" { + return c.clientUniverseDomain + } + if envUD := os.Getenv(universeDomainEnvVar); envUD != "" { + return envUD + } + return defaultUniverseDomain +} + // A Dialer is used to create connections to Cloud SQL instances. // // Use NewDialer to initialize a Dialer. @@ -150,8 +177,8 @@ type Dialer struct { // network. By default, it is golang.org/x/net/proxy#Dial. dialFunc func(cxt context.Context, network, addr string) (net.Conn, error) - // iamTokenSource supplies the OAuth2 token used for IAM DB Authn. - iamTokenSource oauth2.TokenSource + // iamTokenProvider supplies the OAuth2 token used for IAM DB Authn. + iamTokenProvider auth.TokenProvider // resolver converts instance names into DNS names. resolver instance.ConnectionNameResolver @@ -174,12 +201,11 @@ func (nullLogger) Debugf(_ context.Context, _ string, _ ...interface{}) {} // RSA keypair is generated will be faster. func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { cfg := &dialerConfig{ - refreshTimeout: cloudsql.RefreshTimeout, - dialFunc: proxy.Dial, - logger: nullLogger{}, - useragents: []string{userAgent}, - serviceUniverse: "googleapis.com", - failoverPeriod: cloudsql.FailoverPeriod, + refreshTimeout: cloudsql.RefreshTimeout, + dialFunc: proxy.Dial, + logger: nullLogger{}, + useragents: []string{userAgent}, + failoverPeriod: cloudsql.FailoverPeriod, } for _, opt := range opts { opt(cfg) @@ -197,40 +223,41 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { // Add this to the end to make sure it's not overridden cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " "))) - // If callers have not provided a token source, either explicitly with - // WithTokenSource or implicitly with WithCredentialsJSON etc., then use the - // default token source. + // If callers have not provided a credential source, either explicitly with + // WithTokenSource or implicitly with WithCredentialsJSON etc., then use + // default credentials if !cfg.setCredentials { - c, err := google.FindDefaultCredentials(ctx, sqladmin.SqlserviceAdminScope) + c, err := credentials.DetectDefault(&credentials.DetectOptions{ + Scopes: []string{sqladmin.SqlserviceAdminScope}, + }) if err != nil { return nil, fmt.Errorf("failed to create default credentials: %v", err) } - ud, err := c.GetUniverseDomain() - if err != nil { - return nil, fmt.Errorf("failed to get universe domain: %v", err) - } - cfg.credentialsUniverse = ud - cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithTokenSource(c.TokenSource)) - scoped, err := google.DefaultTokenSource(ctx, iamLoginScope) + cfg.authCredentials = c + // create second set of credentials, scoped for IAM AuthN login only + scoped, err := credentials.DetectDefault(&credentials.DetectOptions{ + Scopes: []string{iamLoginScope}, + }) if err != nil { - return nil, fmt.Errorf("failed to create scoped token source: %v", err) + return nil, fmt.Errorf("failed to create scoped credentials: %v", err) } - cfg.iamLoginTokenSource = scoped - } - - if cfg.setUniverseDomain && cfg.setAdminAPIEndpoint { - return nil, errors.New( - "can not use WithAdminAPIEndpoint and WithUniverseDomain Options together, " + - "use WithAdminAPIEndpoint (it already contains the universe domain)", - ) + cfg.iamLoginTokenProvider = scoped.TokenProvider } - if cfg.credentialsUniverse != "" && cfg.serviceUniverse != "" { - if cfg.credentialsUniverse != cfg.serviceUniverse { - return nil, fmt.Errorf( - "the configured service universe domain (%s) does not match the credential universe domain (%s)", - cfg.serviceUniverse, cfg.credentialsUniverse, - ) + // For all credential paths, use auth library's built-in + // httptransport.NewClient + if cfg.authCredentials != nil { + authClient, err := httptransport.NewClient(&httptransport.Options{ + Credentials: cfg.authCredentials, + UniverseDomain: cfg.getClientUniverseDomain(), + }) + if err != nil { + return nil, fmt.Errorf("failed to create auth client: %v", err) + } + // If callers have not provided an HTTPClient explicitly with + // WithHTTPClient, then use auth client + if !cfg.setHTTPClient { + cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithHTTPClient(authClient)) } } @@ -273,7 +300,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) { logger: cfg.logger, defaultDialConfig: dc, dialerID: uuid.New().String(), - iamTokenSource: cfg.iamLoginTokenSource, + iamTokenProvider: cfg.iamLoginTokenProvider, dialFunc: cfg.dialFunc, resolver: r, failoverPeriod: cfg.failoverPeriod, @@ -636,7 +663,7 @@ func (d *Dialer) connectionInfoCache( cn, d.logger, d.sqladmin, rsaKey, - d.refreshTimeout, d.iamTokenSource, + d.refreshTimeout, d.iamTokenProvider, d.dialerID, useIAMAuthNDial, ) } else { @@ -644,7 +671,7 @@ func (d *Dialer) connectionInfoCache( cn, d.logger, d.sqladmin, rsaKey, - d.refreshTimeout, d.iamTokenSource, + d.refreshTimeout, d.iamTokenProvider, d.dialerID, useIAMAuthNDial, ) } diff --git a/dialer_test.go b/dialer_test.go index 9da624d8..4143bac9 100644 --- a/dialer_test.go +++ b/dialer_test.go @@ -280,72 +280,6 @@ func TestSQLServerFailsOnIAMAuthN(t *testing.T) { } } -func TestUniverseDomain(t *testing.T) { - tcs := []struct { - desc string - opts Option - }{ - { - desc: "When universe domain matches GDU", - opts: WithOptions( - WithUniverseDomain("googleapis.com"), - WithCredentialsJSON(fakeServiceAccount("")), - ), - }, - { - desc: "When TPC universe matches TPC credential domain", - opts: WithOptions( - WithUniverseDomain("test-universe.test"), - WithCredentialsJSON(fakeServiceAccount("test-universe.test")), - ), - }, - } - - for _, tc := range tcs { - t.Run(tc.desc, func(t *testing.T) { - _, err := NewDialer(context.Background(), tc.opts) - if err != nil { - t.Fatalf("NewDialer failed with error = %v", err) - } - }) - } -} - -func TestUniverseDomainErrors(t *testing.T) { - tcs := []struct { - desc string - opts Option - }{ - { - desc: "When universe domain does not match ADC credentials from GDU", - opts: WithOptions(WithUniverseDomain("test-universe.test")), - }, - { - desc: "When GDU does not match credential domain", - opts: WithOptions(WithCredentialsJSON( - fakeServiceAccount("test-universe.test"), - )), - }, - { - desc: "WithUniverseDomain used alongside WithAdminAPIEndpoint", - opts: WithOptions( - WithUniverseDomain("googleapis.com"), - WithAdminAPIEndpoint("https://sqladmin.googleapis.com"), - ), - }, - } - - for _, tc := range tcs { - t.Run(tc.desc, func(t *testing.T) { - _, err := NewDialer(context.Background(), tc.opts) - t.Log(err) - if err == nil { - t.Fatalf("Wanted universe domain mismatch, want error, got nil") - } - }) - } -} - func TestDialerWithCustomDialFunc(t *testing.T) { inst := mock.NewFakeCSQLInstance("proj", "region", "inst", mock.WithEngineVersion("SQLSERVER"), diff --git a/go.mod b/go.mod index 22fefcbc..17fde12b 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,8 @@ module cloud.google.com/go/cloudsqlconn go 1.22 require ( + cloud.google.com/go/auth v0.13.0 + cloud.google.com/go/auth/oauth2adapt v0.2.6 github.com/go-sql-driver/mysql v1.8.1 github.com/google/uuid v1.6.0 github.com/jackc/pgx/v4 v4.18.3 @@ -18,8 +20,6 @@ require ( ) require ( - cloud.google.com/go/auth v0.13.0 // indirect - cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect cloud.google.com/go/compute/metadata v0.6.0 // indirect filippo.io/edwards25519 v1.1.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect diff --git a/internal/cloudsql/instance.go b/internal/cloudsql/instance.go index 6a34570e..dea9a335 100644 --- a/internal/cloudsql/instance.go +++ b/internal/cloudsql/instance.go @@ -23,10 +23,10 @@ import ( "sync" "time" + "cloud.google.com/go/auth" "cloud.google.com/go/cloudsqlconn/debug" "cloud.google.com/go/cloudsqlconn/errtype" "cloud.google.com/go/cloudsqlconn/instance" - "golang.org/x/oauth2" "golang.org/x/time/rate" sqladmin "google.golang.org/api/sqladmin/v1beta4" ) @@ -129,7 +129,7 @@ func NewRefreshAheadCache( client *sqladmin.Service, key *rsa.PrivateKey, refreshTimeout time.Duration, - ts oauth2.TokenSource, + tp auth.TokenProvider, dialerID string, useIAMAuthNDial bool, ) *RefreshAheadCache { @@ -142,7 +142,7 @@ func NewRefreshAheadCache( l, client, key, - ts, + tp, dialerID, ), refreshTimeout: refreshTimeout, diff --git a/internal/cloudsql/lazy.go b/internal/cloudsql/lazy.go index 5b65b3b9..d4feed77 100644 --- a/internal/cloudsql/lazy.go +++ b/internal/cloudsql/lazy.go @@ -20,9 +20,9 @@ import ( "sync" "time" + "cloud.google.com/go/auth" "cloud.google.com/go/cloudsqlconn/debug" "cloud.google.com/go/cloudsqlconn/instance" - "golang.org/x/oauth2" sqladmin "google.golang.org/api/sqladmin/v1beta4" ) @@ -45,7 +45,7 @@ func NewLazyRefreshCache( client *sqladmin.Service, key *rsa.PrivateKey, _ time.Duration, - ts oauth2.TokenSource, + tp auth.TokenProvider, dialerID string, useIAMAuthNDial bool, ) *LazyRefreshCache { @@ -56,7 +56,7 @@ func NewLazyRefreshCache( l, client, key, - ts, + tp, dialerID, ), useIAMAuthNDial: useIAMAuthNDial, diff --git a/internal/cloudsql/lazy_test.go b/internal/cloudsql/lazy_test.go index 269e90c9..6a2baec0 100644 --- a/internal/cloudsql/lazy_test.go +++ b/internal/cloudsql/lazy_test.go @@ -20,9 +20,9 @@ import ( "testing" "time" + "cloud.google.com/go/auth" "cloud.google.com/go/cloudsqlconn/instance" "cloud.google.com/go/cloudsqlconn/internal/mock" - "golang.org/x/oauth2" ) func TestLazyRefreshCacheConnectionInfo(t *testing.T) { @@ -95,20 +95,20 @@ func TestLazyRefreshCacheForceRefresh(t *testing.T) { } } -// spyTokenSource is a non-threadsafe spy for tracking token source usage -type spyTokenSource struct { +// spyTokenProvider is a non-threadsafe spy for tracking token provider usage +type spyTokenProvider struct { mu sync.Mutex count int } -func (s *spyTokenSource) Token() (*oauth2.Token, error) { +func (s *spyTokenProvider) Token(context.Context) (*auth.Token, error) { s.mu.Lock() defer s.mu.Unlock() s.count++ - return &oauth2.Token{}, nil + return &auth.Token{}, nil } -func (s *spyTokenSource) callCount() int { +func (s *spyTokenProvider) callCount() int { s.mu.Lock() defer s.mu.Unlock() return s.count @@ -131,7 +131,7 @@ func TestLazyRefreshCacheUpdateRefresh(t *testing.T) { } }() - spy := &spyTokenSource{} + spy := &spyTokenProvider{} c := NewLazyRefreshCache( testInstanceConnName(), nullLogger{}, client, RSAKey, 30*time.Second, spy, "", false, // disable IAM AuthN at first @@ -143,7 +143,7 @@ func TestLazyRefreshCacheUpdateRefresh(t *testing.T) { } if got := spy.callCount(); got != 0 { - t.Fatal("oauth2.TokenSource was called, but should not have been") + t.Fatal("auth.TokenProvider was called, but should not have been") } c.UpdateRefresh(ptr(true)) @@ -153,12 +153,12 @@ func TestLazyRefreshCacheUpdateRefresh(t *testing.T) { t.Fatal(err) } - // Q: Why should the token source be called twice? + // Q: Why should the token provider be called twice? // A: Because the refresh code retrieves a token first (1 call) and then // refreshes it (1 call) for a total of 2 calls. if got, want := spy.callCount(), 2; got != want { t.Fatalf( - "oauth2.TokenSource call count, got = %v, want = %v", + "auth.TokenProvider call count, got = %v, want = %v", got, want, ) } diff --git a/internal/cloudsql/refresh.go b/internal/cloudsql/refresh.go index 152d99a4..9b415e2c 100644 --- a/internal/cloudsql/refresh.go +++ b/internal/cloudsql/refresh.go @@ -22,13 +22,12 @@ import ( "encoding/pem" "fmt" "strings" - "time" + "cloud.google.com/go/auth" "cloud.google.com/go/cloudsqlconn/debug" "cloud.google.com/go/cloudsqlconn/errtype" "cloud.google.com/go/cloudsqlconn/instance" "cloud.google.com/go/cloudsqlconn/internal/trace" - "golang.org/x/oauth2" sqladmin "google.golang.org/api/sqladmin/v1beta4" ) @@ -140,27 +139,6 @@ func fetchMetadata( return m, nil } -var expired = time.Time{}.Add(1) - -// canRefresh determines if the provided token was refreshed or if it still has -// the sentinel expiration, which means the token was provided without a -// refresh token (as with the Cloud SQL Proxy's --token flag) and therefore -// cannot be refreshed. -func canRefresh(t *oauth2.Token) bool { - return t.Expiry.Unix() != expired.Unix() -} - -// refreshToken will retrieve a new token, only if a refresh token is present. -func refreshToken(ts oauth2.TokenSource, tok *oauth2.Token) (*oauth2.Token, error) { - expiredToken := &oauth2.Token{ - AccessToken: tok.AccessToken, - TokenType: tok.TokenType, - RefreshToken: tok.RefreshToken, - Expiry: expired, - } - return oauth2.ReuseTokenSource(expiredToken, ts).Token() -} - // fetchEphemeralCert uses the Cloud SQL Admin API's createEphemeral method to // create a signed TLS certificate that authorized to connect via the Cloud SQL // instance's serverside proxy. The cert if valid for approximately one hour. @@ -169,7 +147,7 @@ func fetchEphemeralCert( client *sqladmin.Service, inst instance.ConnName, key *rsa.PrivateKey, - ts oauth2.TokenSource, + tp auth.TokenProvider, ) (c tls.Certificate, err error) { var end trace.EndSpanFunc ctx, end = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.FetchEphemeralCert") @@ -182,10 +160,10 @@ func fetchEphemeralCert( req := sqladmin.GenerateEphemeralCertRequest{ PublicKey: string(pem.EncodeToMemory(&pem.Block{Bytes: clientPubKey, Type: "RSA PUBLIC KEY"})), } - var tok *oauth2.Token - if ts != nil { + var tok *auth.Token + if tp != nil { var tokErr error - tok, tokErr = ts.Token() + tok, tokErr = tp.Token(ctx) if tokErr != nil { return tls.Certificate{}, errtype.NewRefreshError( "failed to retrieve Oauth2 token", @@ -193,17 +171,15 @@ func fetchEphemeralCert( tokErr, ) } - // Always refresh the token to ensure its expiration is far enough in - // the future. - tok, tokErr = refreshToken(ts, tok) + tok, tokErr = tp.Token(ctx) if tokErr != nil { return tls.Certificate{}, errtype.NewRefreshError( - "failed to refresh Oauth2 token", + "failed to get Oauth2 token", inst.String(), tokErr, ) } - req.AccessToken = tok.AccessToken + req.AccessToken = tok.Value } resp, err := retry50x(ctx, func(ctx2 context.Context) (*sqladmin.GenerateEphemeralCertResponse, error) { return client.Connect.GenerateEphemeralCert( @@ -235,10 +211,10 @@ func fetchEphemeralCert( nil, ) } - if ts != nil { + if tp != nil { // Adjust the certificate's expiration to be the earliest of // the token's expiration or the certificate's expiration. - if canRefresh(tok) && tok.Expiry.Before(clientCert.NotAfter) { + if tok.Expiry.Before(clientCert.NotAfter) { clientCert.NotAfter = tok.Expiry } } @@ -256,7 +232,7 @@ func newAdminAPIClient( l debug.ContextLogger, svc *sqladmin.Service, key *rsa.PrivateKey, - ts oauth2.TokenSource, + tp auth.TokenProvider, dialerID string, ) adminAPIClient { return adminAPIClient{ @@ -264,7 +240,7 @@ func newAdminAPIClient( logger: l, key: key, client: svc, - ts: ts, + tp: tp, } } @@ -277,8 +253,8 @@ type adminAPIClient struct { // key is used to generate the client certificate key *rsa.PrivateKey client *sqladmin.Service - // ts is the TokenSource used for IAM DB AuthN. - ts oauth2.TokenSource + // tp is the TokenProvider used for IAM DB AuthN. + tp auth.TokenProvider } // ConnectionInfo immediately performs a full refresh operation using the Cloud @@ -316,11 +292,11 @@ func (c adminAPIClient) ConnectionInfo( ecC := make(chan ecRes, 1) go func() { defer close(ecC) - var iamTS oauth2.TokenSource + var iamTP auth.TokenProvider if iamAuthNDial { - iamTS = c.ts + iamTP = c.tp } - ec, err := fetchEphemeralCert(ctx, c.client, cn, c.key, iamTS) + ec, err := fetchEphemeralCert(ctx, c.client, cn, c.key, iamTP) ecC <- ecRes{ec, err} }() diff --git a/internal/cloudsql/refresh_test.go b/internal/cloudsql/refresh_test.go index 7d5e75b4..6d8f397e 100644 --- a/internal/cloudsql/refresh_test.go +++ b/internal/cloudsql/refresh_test.go @@ -25,9 +25,9 @@ import ( "testing" "time" + "cloud.google.com/go/auth" "cloud.google.com/go/cloudsqlconn/errtype" "cloud.google.com/go/cloudsqlconn/internal/mock" - "golang.org/x/oauth2" ) const testDialerID = "some-dialer-id" @@ -140,40 +140,6 @@ func TestRefreshForCASInstances(t *testing.T) { } } -// If a caller has provided a static token source that cannot be refreshed -// (e.g., when the Cloud SQL Proxy is invokved with --token), then the -// refresher cannot determine the token's expiration (without additional API -// calls), and so the refresher should use the certificate's expiration instead -// of the token's expiration which is otherwise unset. -func TestRefreshWithStaticTokenSource(t *testing.T) { - cn := testInstanceConnName() - inst := mock.NewFakeCSQLInstance( - cn.Project(), cn.Region(), cn.Name(), - ) - client, cleanup, err := mock.NewSQLAdminService( - context.Background(), - mock.InstanceGetSuccess(inst, 1), - mock.CreateEphemeralSuccess(inst, 1), - ) - if err != nil { - t.Fatalf("failed to create test SQL admin service: %s", err) - } - t.Cleanup(func() { _ = cleanup() }) - - ts := oauth2.StaticTokenSource(&oauth2.Token{AccessToken: "myaccestoken"}) - r := newAdminAPIClient(nullLogger{}, client, RSAKey, ts, testDialerID) - ci, err := r.ConnectionInfo(context.Background(), cn, true) - if err != nil { - t.Fatalf("PerformRefresh unexpectedly failed with error: %v", err) - } - if !ci.Expiration.After(time.Now()) { - t.Fatalf( - "Connection info expiration should be in the future, got = %v", - ci.Expiration, - ) - } -} - func TestRefreshRetries50xResponses(t *testing.T) { cn := testInstanceConnName() inst := mock.NewFakeCSQLInstance(cn.Project(), cn.Region(), cn.Name(), @@ -238,17 +204,17 @@ func TestRefreshFailsFast(t *testing.T) { } type tokenResp struct { - tok *oauth2.Token + tok *auth.Token err error } -type fakeTokenSource struct { +type fakeTokenProvider struct { responses []tokenResp mu sync.Mutex ct int } -func (f *fakeTokenSource) Token() (*oauth2.Token, error) { +func (f *fakeTokenProvider) Token(context.Context) (*auth.Token, error) { f.mu.Lock() defer f.mu.Unlock() resp := f.responses[f.ct] @@ -256,7 +222,7 @@ func (f *fakeTokenSource) Token() (*oauth2.Token, error) { return resp.tok, resp.err } -func (f *fakeTokenSource) count() int { +func (f *fakeTokenProvider) count() int { f.mu.Lock() defer f.mu.Unlock() return f.ct @@ -274,16 +240,16 @@ func TestRefreshAdjustsCertExpiry(t *testing.T) { { desc: "when the token's expiration comes BEFORE the cert", resps: []tokenResp{ - {tok: &oauth2.Token{}}, - {tok: &oauth2.Token{Expiry: t1}}, + {tok: &auth.Token{}}, + {tok: &auth.Token{Expiry: t1}}, }, wantExpiry: t1, }, { desc: "when the token's expiration comes AFTER the cert", resps: []tokenResp{ - {tok: &oauth2.Token{}}, - {tok: &oauth2.Token{Expiry: t2}}, + {tok: &auth.Token{}}, + {tok: &auth.Token{Expiry: t2}}, }, wantExpiry: certExpiry, }, @@ -303,8 +269,8 @@ func TestRefreshAdjustsCertExpiry(t *testing.T) { for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - ts := &fakeTokenSource{responses: tc.resps} - r := newAdminAPIClient(nullLogger{}, client, RSAKey, ts, testDialerID) + tp := &fakeTokenProvider{responses: tc.resps} + r := newAdminAPIClient(nullLogger{}, client, RSAKey, tp, testDialerID) rr, err := r.ConnectionInfo(context.Background(), cn, true) if err != nil { t.Fatalf("want no error, got = %v", err) @@ -330,7 +296,7 @@ func TestRefreshWithIAMAuthErrors(t *testing.T) { { desc: "when refreshing a token fails", resps: []tokenResp{ - {tok: &oauth2.Token{}, err: nil}, + {tok: &auth.Token{}, err: nil}, {tok: nil, err: errors.New("refresh failed")}, }, wantCount: 2, @@ -349,13 +315,13 @@ func TestRefreshWithIAMAuthErrors(t *testing.T) { for _, tc := range tcs { t.Run(tc.desc, func(t *testing.T) { - ts := &fakeTokenSource{responses: tc.resps} - r := newAdminAPIClient(nullLogger{}, client, RSAKey, ts, testDialerID) + tp := &fakeTokenProvider{responses: tc.resps} + r := newAdminAPIClient(nullLogger{}, client, RSAKey, tp, testDialerID) _, err := r.ConnectionInfo(context.Background(), cn, true) if err == nil { t.Fatalf("expected get failed error, got = %v", err) } - if count := ts.count(); count != tc.wantCount { + if count := tp.count(); count != tc.wantCount { t.Fatalf("expected fake token source to be called %v time, got = %v", tc.wantCount, count) } }) diff --git a/options.go b/options.go index a719eca9..1a89303b 100644 --- a/options.go +++ b/options.go @@ -22,12 +22,14 @@ import ( "os" "time" + "cloud.google.com/go/auth" + "cloud.google.com/go/auth/credentials" + "cloud.google.com/go/auth/oauth2adapt" "cloud.google.com/go/cloudsqlconn/debug" "cloud.google.com/go/cloudsqlconn/errtype" "cloud.google.com/go/cloudsqlconn/instance" "cloud.google.com/go/cloudsqlconn/internal/cloudsql" "golang.org/x/oauth2" - "golang.org/x/oauth2/google" apiopt "google.golang.org/api/option" sqladmin "google.golang.org/api/sqladmin/v1beta4" ) @@ -44,13 +46,13 @@ type dialerConfig struct { useIAMAuthN bool logger debug.ContextLogger lazyRefresh bool - iamLoginTokenSource oauth2.TokenSource + clientUniverseDomain string + authCredentials *auth.Credentials + iamLoginTokenProvider auth.TokenProvider useragents []string - credentialsUniverse string - serviceUniverse string setAdminAPIEndpoint bool - setUniverseDomain bool setCredentials bool + setHTTPClient bool setTokenSource bool setIAMAuthNTokenSource bool resolver instance.ConnectionNameResolver @@ -87,26 +89,25 @@ func WithCredentialsFile(filename string) Option { // or refresh token JSON credentials to be used as the basis for authentication. func WithCredentialsJSON(b []byte) Option { return func(d *dialerConfig) { - c, err := google.CredentialsFromJSON(context.Background(), b, sqladmin.SqlserviceAdminScope) + c, err := credentials.DetectDefault(&credentials.DetectOptions{ + Scopes: []string{sqladmin.SqlserviceAdminScope}, + CredentialsJSON: b, + }) if err != nil { d.err = errtype.NewConfigError(err.Error(), "n/a") return } - ud, err := c.GetUniverseDomain() - if err != nil { - d.err = errtype.NewConfigError(err.Error(), "n/a") - return - } - d.credentialsUniverse = ud - d.sqladminOpts = append(d.sqladminOpts, apiopt.WithCredentials(c)) - + d.authCredentials = c // Create another set of credentials scoped to login only - scoped, err := google.CredentialsFromJSON(context.Background(), b, iamLoginScope) + scoped, err := credentials.DetectDefault(&credentials.DetectOptions{ + Scopes: []string{iamLoginScope}, + CredentialsJSON: b, + }) if err != nil { d.err = errtype.NewConfigError(err.Error(), "n/a") return } - d.iamLoginTokenSource = scoped.TokenSource + d.iamLoginTokenProvider = scoped.TokenProvider d.setCredentials = true } } @@ -158,7 +159,7 @@ func WithIAMAuthNTokenSources(apiTS, iamLoginTS oauth2.TokenSource) Option { return func(d *dialerConfig) { d.setIAMAuthNTokenSource = true d.setCredentials = true - d.iamLoginTokenSource = iamLoginTS + d.iamLoginTokenProvider = oauth2adapt.TokenProviderFromTokenSource(iamLoginTS) d.sqladminOpts = append(d.sqladminOpts, apiopt.WithTokenSource(apiTS)) } } @@ -184,6 +185,7 @@ func WithRefreshTimeout(t time.Duration) Option { func WithHTTPClient(client *http.Client) Option { return func(d *dialerConfig) { d.sqladminOpts = append(d.sqladminOpts, apiopt.WithHTTPClient(client)) + d.setHTTPClient = true } } @@ -193,7 +195,6 @@ func WithAdminAPIEndpoint(url string) Option { return func(d *dialerConfig) { d.sqladminOpts = append(d.sqladminOpts, apiopt.WithEndpoint(url)) d.setAdminAPIEndpoint = true - d.serviceUniverse = "" } } @@ -202,8 +203,7 @@ func WithAdminAPIEndpoint(url string) Option { func WithUniverseDomain(ud string) Option { return func(d *dialerConfig) { d.sqladminOpts = append(d.sqladminOpts, apiopt.WithUniverseDomain(ud)) - d.serviceUniverse = ud - d.setUniverseDomain = true + d.clientUniverseDomain = ud } }