Skip to content

Commit

Permalink
feat: add DialOption for IAM DB Authentication (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtisvg authored Jun 1, 2022
1 parent cb17568 commit c103acc
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 63 deletions.
37 changes: 25 additions & 12 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,16 +117,13 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
// If callers have not provided a token source, either explicitly with
// WithTokenSource or implicitly with WithCredentialsJSON etc, then use the
// default token source.
if cfg.useIAMAuthN && cfg.tokenSource == nil {
if cfg.tokenSource == nil {
ts, err := google.DefaultTokenSource(ctx, sqladmin.SqlserviceAdminScope)
if err != nil {
return nil, fmt.Errorf("failed to create token source: %v", err)
}
cfg.tokenSource = ts
}
// If IAM Authn is not explicitly enabled, remove the token source.
if !cfg.useIAMAuthN {
cfg.tokenSource = nil
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithTokenSource(ts))
}

if cfg.rsaKey == nil {
Expand All @@ -145,6 +142,9 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
dialCfg := dialCfg{
ipType: cloudsql.PublicIP,
tcpKeepAlive: defaultTCPKeepAlive,
refreshCfg: cloudsql.RefreshCfg{
UseIAMAuthN: cfg.useIAMAuthN,
},
}
for _, opt := range cfg.dialOpts {
opt(&dialCfg)
Expand Down Expand Up @@ -186,7 +186,7 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)

var endInfo trace.EndSpanFunc
ctx, endInfo = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.InstanceInfo")
i, err := d.instance(instance)
i, err := d.instance(instance, &cfg.refreshCfg)
if err != nil {
endInfo(err)
return nil, err
Expand Down Expand Up @@ -240,7 +240,7 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
// corespond to one of the following types for the instance:
// https://cloud.google.com/sql/docs/mysql/admin-api/rest/v1beta4/SqlDatabaseVersion
func (d *Dialer) EngineVersion(ctx context.Context, instance string) (string, error) {
i, err := d.instance(instance)
i, err := d.instance(instance, nil)
if err != nil {
return "", err
}
Expand All @@ -254,7 +254,11 @@ func (d *Dialer) EngineVersion(ctx context.Context, instance string) (string, er
// Warmup starts the background refresh neccesary to connect to the instance. Use Warmup
// to start the refresh process early if you don't know when you'll need to call "Dial".
func (d *Dialer) Warmup(ctx context.Context, instance string, opts ...DialOption) error {
_, err := d.instance(instance)
cfg := d.defaultDialCfg
for _, opt := range opts {
opt(&cfg)
}
_, err := d.instance(instance, &cfg.refreshCfg)
return err
}

Expand Down Expand Up @@ -297,24 +301,33 @@ func (d *Dialer) Close() error {
return nil
}

func (d *Dialer) instance(connName string) (*cloudsql.Instance, error) {
// instance is a helper function for returning the appropriate instance object in a threadsafe way.
// It will create a new instance object, modify the existing one, or leave it unchanged as needed.
func (d *Dialer) instance(connName string, r *cloudsql.RefreshCfg) (*cloudsql.Instance, error) {
// Check instance cache
d.lock.RLock()
i, ok := d.instances[connName]
d.lock.RUnlock()
if !ok {
// If the instance hasn't been creted yet or if the refreshCfg has changed
if !ok || (r != nil && *r != i.RefreshCfg) {
d.lock.Lock()
// Recheck to ensure instance wasn't created between locks
// Recheck to ensure instance wasn't created or changed between locks
i, ok = d.instances[connName]
if !ok {
// Create a new instance
if r == nil {
r = &d.defaultDialCfg.refreshCfg
}
var err error
i, err = cloudsql.NewInstance(connName, d.sqladmin, d.key, d.refreshTimeout, d.iamTokenSource, d.dialerID)
i, err = cloudsql.NewInstance(connName, d.sqladmin, d.key, d.refreshTimeout, d.iamTokenSource, d.dialerID, *r)
if err != nil {
d.lock.Unlock()
return nil, err
}
d.instances[connName] = i
} else if r != nil && *r != i.RefreshCfg {
// Update the instance with the new refresh cfg
i.UpdateRefresh(*r)
}
d.lock.Unlock()
}
Expand Down
195 changes: 171 additions & 24 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,22 @@ import (
"cloud.google.com/go/cloudsqlconn/internal/mock"
)

func testSuccessfulDial(t *testing.T, d *Dialer, ctx context.Context, i string, opts ...DialOption) {
conn, err := d.Dial(ctx, i, opts...)
if err != nil {
t.Fatalf("expected Dial to succeed, but got error: %v", err)
}
defer conn.Close()

data, err := ioutil.ReadAll(conn)
if err != nil {
t.Fatalf("expected ReadAll to succeed, got error %v", err)
}
if string(data) != "my-instance" {
t.Fatalf("expected known response from the server, but got %v", string(data))
}
}

func TestDialerCanConnectToInstance(t *testing.T) {
inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
svc, cleanup, err := mock.NewSQLAdminService(
Expand Down Expand Up @@ -56,19 +72,7 @@ func TestDialerCanConnectToInstance(t *testing.T) {
}
d.sqladmin = svc

conn, err := d.Dial(context.Background(), "my-project:my-region:my-instance", WithPublicIP())
if err != nil {
t.Fatalf("expected Dial to succeed, but got error: %v", err)
}
defer conn.Close()

data, err := ioutil.ReadAll(conn)
if err != nil {
t.Fatalf("expected ReadAll to succeed, got error %v", err)
}
if string(data) != "my-instance" {
t.Fatalf("expected known response from the server, but got %v", string(data))
}
testSuccessfulDial(t, d, context.Background(), "my-project:my-region:my-instance", WithPublicIP())
}

func TestDialWithAdminAPIErrors(t *testing.T) {
Expand Down Expand Up @@ -184,19 +188,19 @@ var fakeServiceAccount = []byte(`{

func TestIAMAuthn(t *testing.T) {
tcs := []struct {
desc string
opts Option
wantTokenSource bool
desc string
opts Option
wantIAMAuthN bool
}{
{
desc: "When Credentials are provided with IAM Authn ENABLED",
opts: WithOptions(WithIAMAuthN(), WithCredentialsJSON(fakeServiceAccount)),
wantTokenSource: true,
desc: "When Credentials are provided with IAM Authn ENABLED",
opts: WithOptions(WithIAMAuthN(), WithCredentialsJSON(fakeServiceAccount)),
wantIAMAuthN: true,
},
{
desc: "When Credentials are provided with IAM Authn DISABLED",
opts: WithCredentialsJSON(fakeServiceAccount),
wantTokenSource: false,
desc: "When Credentials are provided with IAM Authn DISABLED",
opts: WithCredentialsJSON(fakeServiceAccount),
wantIAMAuthN: false,
},
}

Expand All @@ -206,8 +210,8 @@ func TestIAMAuthn(t *testing.T) {
if err != nil {
t.Fatalf("NewDialer failed with error = %v", err)
}
if gotTokenSource := d.iamTokenSource != nil; gotTokenSource != tc.wantTokenSource {
t.Fatalf("want = %v, got = %v", tc.wantTokenSource, gotTokenSource)
if gotIAMAuthN := d.defaultDialCfg.refreshCfg.UseIAMAuthN; gotIAMAuthN != tc.wantIAMAuthN {
t.Fatalf("want = %v, got = %v", tc.wantIAMAuthN, gotIAMAuthN)
}
})
}
Expand Down Expand Up @@ -295,3 +299,146 @@ func TestDialerUserAgent(t *testing.T) {
t.Errorf("embed version mismatched: want %q, got %q", want, userAgent)
}
}

func TestWarmup(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
stop := mock.StartServerProxy(t, inst)
defer stop()
tests := []struct {
desc string
warmupOpts []DialOption
dialOpts []DialOption
expectedCalls []*mock.Request
}{
{
desc: "warmup and dial are the same",
warmupOpts: []DialOption{WithDialIAMAuthN(true)},
dialOpts: []DialOption{WithDialIAMAuthN(true)},
expectedCalls: []*mock.Request{
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
},
},
{
desc: "warmup and dial are different",
warmupOpts: []DialOption{WithDialIAMAuthN(true)},
dialOpts: []DialOption{WithDialIAMAuthN(false)},
expectedCalls: []*mock.Request{
mock.InstanceGetSuccess(inst, 2),
mock.CreateEphemeralSuccess(inst, 2),
},
},
{
desc: "warmup and default dial are different",
warmupOpts: []DialOption{WithDialIAMAuthN(true)},
dialOpts: []DialOption{},
expectedCalls: []*mock.Request{
mock.InstanceGetSuccess(inst, 2),
mock.CreateEphemeralSuccess(inst, 2),
},
},
}

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
svc, cleanup, err := mock.NewSQLAdminService(ctx, test.expectedCalls...)
if err != nil {
t.Fatalf("failed to init SQLAdminService: %v", err)
}
d, err := NewDialer(context.Background(), WithTokenSource(mock.EmptyTokenSource{}))
if err != nil {
t.Fatalf("failed to init Dialer: %v", err)
}
d.sqladmin = svc
defer func() {
if err := cleanup(); err != nil {
t.Fatalf("%v", err)
}
}()

// Warmup once with the "default" options
err = d.Warmup(ctx, "my-project:my-region:my-instance", test.warmupOpts...)
if err != nil {
t.Fatalf("Warmup failed: %v", err)
}
// Call EngineVersion to make sure we block until both API calls are completed.
_, err = d.EngineVersion(ctx, "my-project:my-region:my-instance")
if err != nil {
t.Fatalf("Warmup failed: %v", err)
}
// Dial once with the "dial" options
testSuccessfulDial(t, d, ctx, "my-project:my-region:my-instance", test.dialOpts...)
})
}
}

func TestDialDialerOptsConflicts(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

inst := mock.NewFakeCSQLInstance("my-project", "my-region", "my-instance")
stop := mock.StartServerProxy(t, inst)
defer stop()
tests := []struct {
desc string
dialerOpts []Option
dialOpts []DialOption
expectedCalls []*mock.Request
}{
{
desc: "dialer opts set and dial uses default",
dialerOpts: []Option{WithIAMAuthN()},
dialOpts: []DialOption{},
expectedCalls: []*mock.Request{
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
},
},
{
desc: "dialer and dial opts are the same",
dialerOpts: []Option{WithIAMAuthN()},
dialOpts: []DialOption{WithDialIAMAuthN(true)},
expectedCalls: []*mock.Request{
mock.InstanceGetSuccess(inst, 1),
mock.CreateEphemeralSuccess(inst, 1),
},
},
{
desc: "dialer and dial opts are different",
dialerOpts: []Option{WithIAMAuthN()},
dialOpts: []DialOption{WithDialIAMAuthN(false)},
expectedCalls: []*mock.Request{
mock.InstanceGetSuccess(inst, 2),
mock.CreateEphemeralSuccess(inst, 2),
},
},
}

for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
svc, cleanup, err := mock.NewSQLAdminService(ctx, test.expectedCalls...)
if err != nil {
t.Fatalf("failed to init SQLAdminService: %v", err)
}
d, err := NewDialer(context.Background(), WithTokenSource(mock.EmptyTokenSource{}), WithOptions(test.dialerOpts...))
if err != nil {
t.Fatalf("failed to init Dialer: %v", err)
}
d.sqladmin = svc
defer func() {
if err := cleanup(); err != nil {
t.Fatalf("%v", err)
}
}()

// Dial once with the "default" options
testSuccessfulDial(t, d, ctx, "my-project:my-region:my-instance")

// Dial once with the "dial" options
testSuccessfulDial(t, d, ctx, "my-project:my-region:my-instance", test.dialOpts...)
})
}
}
Loading

0 comments on commit c103acc

Please sign in to comment.