Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: use auth DetectDefault over oauth2 FindDefaultCredentials #909

Merged
merged 12 commits into from
Jan 14, 2025
Merged
105 changes: 66 additions & 39 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,22 @@ import (
"fmt"
"io"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"

"cloud.google.com/go/auth"
"cloud.google.com/go/auth/credentials"
"cloud.google.com/go/auth/httptransport"
"cloud.google.com/go/cloudsqlconn/debug"
"cloud.google.com/go/cloudsqlconn/errtype"
"cloud.google.com/go/cloudsqlconn/instance"
"cloud.google.com/go/cloudsqlconn/internal/cloudsql"
"cloud.google.com/go/cloudsqlconn/internal/trace"
"github.com/google/uuid"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)
Expand All @@ -50,6 +52,12 @@ const (
// iamLoginScope is the OAuth2 scope used for tokens embedded in the ephemeral
// certificate.
iamLoginScope = "https://www.googleapis.com/auth/sqlservice.login"
// universeDomainEnvVar is the environment variable for setting the default
// service domain for a given Cloud universe.
universeDomainEnvVar = "GOOGLE_CLOUD_UNIVERSE_DOMAIN"
// defaultUniverseDomain is the default value for universe domain.
// Universe domain is the default service domain for a given Cloud universe.
defaultUniverseDomain = "googleapis.com"
)

var (
Expand Down Expand Up @@ -117,6 +125,25 @@ type cacheKey struct {
name string
}

// getClientUniverseDomain returns the default service domain for a given Cloud
// universe, with the following precedence:
//
// 1. A non-empty option.WithUniverseDomain or similar client option.
// 2. A non-empty environment variable GOOGLE_CLOUD_UNIVERSE_DOMAIN.
// 3. The default value "googleapis.com".
//
// This is the universe domain configured for the client, which will be compared
// to the universe domain that is separately configured for the credentials.
func (cfg *dialerConfig) getClientUniverseDomain() string {
jackwotherspoon marked this conversation as resolved.
Show resolved Hide resolved
if cfg.clientUniverseDomain != "" {
return cfg.clientUniverseDomain
}
if envUD := os.Getenv(universeDomainEnvVar); envUD != "" {
return envUD
}
return defaultUniverseDomain
}

// A Dialer is used to create connections to Cloud SQL instances.
//
// Use NewDialer to initialize a Dialer.
Expand Down Expand Up @@ -150,8 +177,8 @@ type Dialer struct {
// network. By default, it is golang.org/x/net/proxy#Dial.
dialFunc func(cxt context.Context, network, addr string) (net.Conn, error)

// iamTokenSource supplies the OAuth2 token used for IAM DB Authn.
iamTokenSource oauth2.TokenSource
// iamTokenProvider supplies the OAuth2 token used for IAM DB Authn.
iamTokenProvider auth.TokenProvider

// resolver converts instance names into DNS names.
resolver instance.ConnectionNameResolver
Expand All @@ -174,12 +201,11 @@ func (nullLogger) Debugf(_ context.Context, _ string, _ ...interface{}) {}
// RSA keypair is generated will be faster.
func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
cfg := &dialerConfig{
refreshTimeout: cloudsql.RefreshTimeout,
dialFunc: proxy.Dial,
logger: nullLogger{},
useragents: []string{userAgent},
serviceUniverse: "googleapis.com",
failoverPeriod: cloudsql.FailoverPeriod,
refreshTimeout: cloudsql.RefreshTimeout,
dialFunc: proxy.Dial,
logger: nullLogger{},
useragents: []string{userAgent},
failoverPeriod: cloudsql.FailoverPeriod,
}
for _, opt := range opts {
opt(cfg)
Expand All @@ -197,40 +223,41 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
// Add this to the end to make sure it's not overridden
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " ")))

// If callers have not provided a token source, either explicitly with
// WithTokenSource or implicitly with WithCredentialsJSON etc., then use the
// default token source.
// If callers have not provided a credential source, either explicitly with
// WithTokenSource or implicitly with WithCredentialsJSON etc., then use
// default credentials
if !cfg.setCredentials {
c, err := google.FindDefaultCredentials(ctx, sqladmin.SqlserviceAdminScope)
c, err := credentials.DetectDefault(&credentials.DetectOptions{
Scopes: []string{sqladmin.SqlserviceAdminScope},
})
if err != nil {
return nil, fmt.Errorf("failed to create default credentials: %v", err)
}
ud, err := c.GetUniverseDomain()
if err != nil {
return nil, fmt.Errorf("failed to get universe domain: %v", err)
}
cfg.credentialsUniverse = ud
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithTokenSource(c.TokenSource))
scoped, err := google.DefaultTokenSource(ctx, iamLoginScope)
cfg.authCredentials = c
// create second set of credentials, scoped for IAM AuthN login only
scoped, err := credentials.DetectDefault(&credentials.DetectOptions{
Scopes: []string{iamLoginScope},
})
if err != nil {
return nil, fmt.Errorf("failed to create scoped token source: %v", err)
return nil, fmt.Errorf("failed to create scoped credentials: %v", err)
}
cfg.iamLoginTokenSource = scoped
}

if cfg.setUniverseDomain && cfg.setAdminAPIEndpoint {
return nil, errors.New(
"can not use WithAdminAPIEndpoint and WithUniverseDomain Options together, " +
"use WithAdminAPIEndpoint (it already contains the universe domain)",
)
cfg.iamLoginTokenProvider = scoped.TokenProvider
}

if cfg.credentialsUniverse != "" && cfg.serviceUniverse != "" {
if cfg.credentialsUniverse != cfg.serviceUniverse {
return nil, fmt.Errorf(
"the configured service universe domain (%s) does not match the credential universe domain (%s)",
cfg.serviceUniverse, cfg.credentialsUniverse,
)
// For all credential paths, use auth library's built-in
// httptransport.NewClient
if cfg.authCredentials != nil {
authClient, err := httptransport.NewClient(&httptransport.Options{
Credentials: cfg.authCredentials,
UniverseDomain: cfg.getClientUniverseDomain(),
})
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
// If callers have not provided an HTTPClient explicitly with
// WithHTTPClient, then use auth client
if !cfg.setHTTPClient {
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithHTTPClient(authClient))
}
}

Expand Down Expand Up @@ -273,7 +300,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
logger: cfg.logger,
defaultDialConfig: dc,
dialerID: uuid.New().String(),
iamTokenSource: cfg.iamLoginTokenSource,
iamTokenProvider: cfg.iamLoginTokenProvider,
dialFunc: cfg.dialFunc,
resolver: r,
failoverPeriod: cfg.failoverPeriod,
Expand Down Expand Up @@ -636,15 +663,15 @@ func (d *Dialer) connectionInfoCache(
cn,
d.logger,
d.sqladmin, rsaKey,
d.refreshTimeout, d.iamTokenSource,
d.refreshTimeout, d.iamTokenProvider,
d.dialerID, useIAMAuthNDial,
)
} else {
cache = cloudsql.NewRefreshAheadCache(
cn,
d.logger,
d.sqladmin, rsaKey,
d.refreshTimeout, d.iamTokenSource,
d.refreshTimeout, d.iamTokenProvider,
d.dialerID, useIAMAuthNDial,
)
}
Expand Down
66 changes: 0 additions & 66 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,72 +280,6 @@ func TestSQLServerFailsOnIAMAuthN(t *testing.T) {
}
}

func TestUniverseDomain(t *testing.T) {
hessjcg marked this conversation as resolved.
Show resolved Hide resolved
tcs := []struct {
desc string
opts Option
}{
{
desc: "When universe domain matches GDU",
opts: WithOptions(
WithUniverseDomain("googleapis.com"),
WithCredentialsJSON(fakeServiceAccount("")),
),
},
{
desc: "When TPC universe matches TPC credential domain",
opts: WithOptions(
WithUniverseDomain("test-universe.test"),
WithCredentialsJSON(fakeServiceAccount("test-universe.test")),
),
},
}

for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, err := NewDialer(context.Background(), tc.opts)
if err != nil {
t.Fatalf("NewDialer failed with error = %v", err)
}
})
}
}

func TestUniverseDomainErrors(t *testing.T) {
tcs := []struct {
desc string
opts Option
}{
{
desc: "When universe domain does not match ADC credentials from GDU",
opts: WithOptions(WithUniverseDomain("test-universe.test")),
},
{
desc: "When GDU does not match credential domain",
opts: WithOptions(WithCredentialsJSON(
fakeServiceAccount("test-universe.test"),
)),
},
{
desc: "WithUniverseDomain used alongside WithAdminAPIEndpoint",
opts: WithOptions(
WithUniverseDomain("googleapis.com"),
WithAdminAPIEndpoint("https://sqladmin.googleapis.com"),
),
},
}

for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, err := NewDialer(context.Background(), tc.opts)
t.Log(err)
if err == nil {
t.Fatalf("Wanted universe domain mismatch, want error, got nil")
}
})
}
}

func TestDialerWithCustomDialFunc(t *testing.T) {
inst := mock.NewFakeCSQLInstance("proj", "region", "inst",
mock.WithEngineVersion("SQLSERVER"),
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module cloud.google.com/go/cloudsqlconn
go 1.22

require (
cloud.google.com/go/auth v0.13.0
cloud.google.com/go/auth/oauth2adapt v0.2.6
github.com/go-sql-driver/mysql v1.8.1
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v4 v4.18.3
Expand All @@ -18,8 +20,6 @@ require (
)

require (
cloud.google.com/go/auth v0.13.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
cloud.google.com/go/compute/metadata v0.6.0 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
Expand Down
6 changes: 3 additions & 3 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import (
"sync"
"time"

"cloud.google.com/go/auth"
"cloud.google.com/go/cloudsqlconn/debug"
"cloud.google.com/go/cloudsqlconn/errtype"
"cloud.google.com/go/cloudsqlconn/instance"
"golang.org/x/oauth2"
"golang.org/x/time/rate"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)
Expand Down Expand Up @@ -129,7 +129,7 @@ func NewRefreshAheadCache(
client *sqladmin.Service,
key *rsa.PrivateKey,
refreshTimeout time.Duration,
ts oauth2.TokenSource,
tp auth.TokenProvider,
dialerID string,
useIAMAuthNDial bool,
) *RefreshAheadCache {
Expand All @@ -142,7 +142,7 @@ func NewRefreshAheadCache(
l,
client,
key,
ts,
tp,
dialerID,
),
refreshTimeout: refreshTimeout,
Expand Down
6 changes: 3 additions & 3 deletions internal/cloudsql/lazy.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ import (
"sync"
"time"

"cloud.google.com/go/auth"
"cloud.google.com/go/cloudsqlconn/debug"
"cloud.google.com/go/cloudsqlconn/instance"
"golang.org/x/oauth2"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)

Expand All @@ -45,7 +45,7 @@ func NewLazyRefreshCache(
client *sqladmin.Service,
key *rsa.PrivateKey,
_ time.Duration,
ts oauth2.TokenSource,
tp auth.TokenProvider,
dialerID string,
useIAMAuthNDial bool,
) *LazyRefreshCache {
Expand All @@ -56,7 +56,7 @@ func NewLazyRefreshCache(
l,
client,
key,
ts,
tp,
dialerID,
),
useIAMAuthNDial: useIAMAuthNDial,
Expand Down
Loading
Loading