Skip to content

Commit

Permalink
chore: move all cloudsql.Instance references behind methods (#678)
Browse files Browse the repository at this point in the history
This is the first step to extracting an interface for cloudsql.Instance.

Also, prefer full words (cfg vs config) for types, field names, and
compound variable names.
  • Loading branch information
enocom authored Dec 1, 2023
1 parent e41783a commit 69c0728
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 66 deletions.
56 changes: 28 additions & 28 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ type Dialer struct {

sqladmin *sqladmin.Service

// defaultDialCfg holds the constructor level DialOptions, so that it can
// be copied and mutated by the Dial function.
defaultDialCfg dialCfg
// defaultDialConfig holds the constructor level DialOptions, so that it
// can be copied and mutated by the Dial function.
defaultDialConfig dialConfig

// dialerID uniquely identifies a Dialer. Used for monitoring purposes,
// *only* when a client has configured OpenCensus exporters.
Expand Down Expand Up @@ -157,10 +157,10 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
return nil, fmt.Errorf("failed to create sqladmin client: %v", err)
}

dc := dialCfg{
dc := dialConfig{
ipType: cloudsql.PublicIP,
tcpKeepAlive: defaultTCPKeepAlive,
refreshCfg: cloudsql.RefreshCfg{
refreshConfig: cloudsql.RefreshConfig{
UseIAMAuthN: cfg.useIAMAuthN,
},
}
Expand All @@ -172,14 +172,14 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
return nil, err
}
d := &Dialer{
instances: make(map[cloudsql.ConnName]*cloudsql.Instance),
key: cfg.rsaKey,
refreshTimeout: cfg.refreshTimeout,
sqladmin: client,
defaultDialCfg: dc,
dialerID: uuid.New().String(),
iamTokenSource: cfg.iamLoginTokenSource,
dialFunc: cfg.dialFunc,
instances: make(map[cloudsql.ConnName]*cloudsql.Instance),
key: cfg.rsaKey,
refreshTimeout: cfg.refreshTimeout,
sqladmin: client,
defaultDialConfig: dc,
dialerID: uuid.New().String(),
iamTokenSource: cfg.iamLoginTokenSource,
dialFunc: cfg.dialFunc,
}
return d, nil
}
Expand All @@ -202,15 +202,15 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
return nil, err
}

cfg := d.defaultDialCfg
cfg := d.defaultDialConfig
for _, opt := range opts {
opt(&cfg)
}

var endInfo trace.EndSpanFunc
ctx, endInfo = trace.StartSpan(ctx, "cloud.google.com/go/cloudsqlconn/internal.InstanceInfo")
i := d.instance(cn, &cfg.refreshCfg)
addr, tlsCfg, err := i.ConnectInfo(ctx, cfg.ipType)
i := d.instance(cn, &cfg.refreshConfig)
addr, tlsConfig, err := i.ConnectInfo(ctx, cfg.ipType)
if err != nil {
d.removeInstance(i)
endInfo(err)
Expand Down Expand Up @@ -241,7 +241,7 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)
}
}

tlsConn := tls.Client(conn, tlsCfg)
tlsConn := tls.Client(conn, tlsConfig)
if err := tlsConn.HandshakeContext(ctx); err != nil {
// refresh the instance info in case it caused the handshake failure
i.ForceRefresh()
Expand All @@ -251,13 +251,13 @@ func (d *Dialer) Dial(ctx context.Context, instance string, opts ...DialOption)

latency := time.Since(startTime).Milliseconds()
go func() {
n := atomic.AddUint64(&i.OpenConns, 1)
n := atomic.AddUint64(i.OpenConns(), 1)
trace.RecordOpenConnections(ctx, int64(n), d.dialerID, i.String())
trace.RecordDialLatency(ctx, instance, d.dialerID, latency)
}()

return newInstrumentedConn(tlsConn, func() {
n := atomic.AddUint64(&i.OpenConns, ^uint64(0))
n := atomic.AddUint64(i.OpenConns(), ^uint64(0))
trace.RecordOpenConnections(context.Background(), int64(n), d.dialerID, i.String())
}), nil
}
Expand Down Expand Up @@ -285,11 +285,11 @@ func (d *Dialer) Warmup(_ context.Context, instance string, opts ...DialOption)
if err != nil {
return err
}
cfg := d.defaultDialCfg
cfg := d.defaultDialConfig
for _, opt := range opts {
opt(&cfg)
}
_ = d.instance(cn, &cfg.refreshCfg)
_ = d.instance(cn, &cfg.refreshConfig)
return nil
}

Expand Down Expand Up @@ -334,26 +334,26 @@ func (d *Dialer) Close() error {

// instance is a helper function for returning the appropriate instance object in a threadsafe way.
// It will create a new instance object, modify the existing one, or leave it unchanged as needed.
func (d *Dialer) instance(cn cloudsql.ConnName, r *cloudsql.RefreshCfg) *cloudsql.Instance {
func (d *Dialer) instance(cn cloudsql.ConnName, r *cloudsql.RefreshConfig) *cloudsql.Instance {
// Check instance cache
d.lock.RLock()
i, ok := d.instances[cn]
d.lock.RUnlock()
// If the instance hasn't been created yet or if the refreshCfg has changed
if !ok || (r != nil && *r != i.RefreshCfg) {
// If the instance hasn't been created yet or if the refreshConfig has changed
if !ok || (r != nil && *r != i.RefreshConfig()) {
d.lock.Lock()
// Recheck to ensure instance wasn't created or changed between locks
i, ok = d.instances[cn]
if !ok {
// Create a new instance
if r == nil {
r = &d.defaultDialCfg.refreshCfg
r = &d.defaultDialConfig.refreshConfig
}
i = cloudsql.NewInstance(cn, d.sqladmin, d.key,
d.refreshTimeout, d.iamTokenSource, d.dialerID, *r)
d.instances[cn] = i
} else if r != nil && *r != i.RefreshCfg {
// Update the instance with the new refresh cfg
} else if r != nil && *r != i.RefreshConfig() {
// Update the instance with the new refresh config
i.UpdateRefresh(*r)
}
d.lock.Unlock()
Expand All @@ -366,5 +366,5 @@ func (d *Dialer) removeInstance(i *cloudsql.Instance) {
defer d.lock.Unlock()
// Stop all background refreshes
i.Close()
delete(d.instances, i.ConnName)
delete(d.instances, i.ConnName())
}
2 changes: 1 addition & 1 deletion dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ func TestIAMAuthn(t *testing.T) {
if err != nil {
t.Fatalf("NewDialer failed with error = %v", err)
}
if gotIAMAuthN := d.defaultDialCfg.refreshCfg.UseIAMAuthN; gotIAMAuthN != tc.wantIAMAuthN {
if gotIAMAuthN := d.defaultDialConfig.refreshConfig.UseIAMAuthN; gotIAMAuthN != tc.wantIAMAuthN {
t.Fatalf("want = %v, got = %v", tc.wantIAMAuthN, gotIAMAuthN)
}
})
Expand Down
49 changes: 33 additions & 16 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@ func (r *refreshOperation) isValid() bool {
}
}

// RefreshCfg is a collection of attributes that trigger new refresh operations.
type RefreshCfg struct {
// RefreshConfig is a collection of attributes that trigger new refresh operations.
type RefreshConfig struct {
UseIAMAuthN bool
}

Expand All @@ -130,11 +130,11 @@ type RefreshCfg struct {
// the required information approximately 4 minutes before the previous
// certificate expires (every ~56 minutes).
type Instance struct {
// OpenConns is the number of open connections to the instance.
OpenConns uint64
// openConns is the number of open connections to the instance.
openConns uint64

ConnName
key *rsa.PrivateKey
connName ConnName
key *rsa.PrivateKey

// refreshTimeout sets the maximum duration a refresh cycle can run
// for.
Expand All @@ -143,8 +143,8 @@ type Instance struct {
l *rate.Limiter
r refresher

refreshLock sync.RWMutex
RefreshCfg RefreshCfg
refreshLock sync.RWMutex
refreshConfig RefreshConfig
// cur represents the current refreshOperation that will be used to
// create connections. If a valid complete refreshOperation isn't
// available it's possible for cur to be equal to next.
Expand All @@ -167,11 +167,11 @@ func NewInstance(
refreshTimeout time.Duration,
ts oauth2.TokenSource,
dialerID string,
r RefreshCfg,
r RefreshConfig,
) *Instance {
ctx, cancel := context.WithCancel(context.Background())
i := &Instance{
ConnName: cn,
connName: cn,
key: key,
l: rate.NewLimiter(rate.Every(refreshInterval), refreshBurst),
r: newRefresher(
Expand All @@ -180,7 +180,7 @@ func NewInstance(
dialerID,
),
refreshTimeout: refreshTimeout,
RefreshCfg: r,
refreshConfig: r,
ctx: ctx,
cancel: cancel,
}
Expand All @@ -193,6 +193,23 @@ func NewInstance(
return i
}

// OpenConns returns a pointer to the number of open connections to
// faciliate changing the value using atomics.
func (i *Instance) OpenConns() *uint64 {
return &i.openConns
}

// ConnName returns the instance connection name associated with this Instance.
func (i *Instance) ConnName() ConnName {
return i.connName
}

// RefreshConfig returns the refresh configuration associated with this
// instance.
func (i *Instance) RefreshConfig() RefreshConfig {
return i.refreshConfig
}

// Close closes the instance; it stops the refresh cycle and prevents it from
// making additional calls to the Cloud SQL Admin API.
func (i *Instance) Close() {
Expand Down Expand Up @@ -245,14 +262,14 @@ func (i *Instance) InstanceEngineVersion(ctx context.Context) (string, error) {

// UpdateRefresh cancels all existing refresh attempts and schedules new
// attempts with the provided config.
func (i *Instance) UpdateRefresh(cfg RefreshCfg) {
func (i *Instance) UpdateRefresh(c RefreshConfig) {
i.refreshLock.Lock()
defer i.refreshLock.Unlock()
// Cancel any pending refreshes
i.cur.cancel()
i.next.cancel()
// update the refresh config as needed
i.RefreshCfg = cfg
i.refreshConfig = c
// reschedule a new refresh immediately
i.cur = i.scheduleRefresh(0)
i.next = i.cur
Expand Down Expand Up @@ -328,12 +345,12 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
if err != nil {
r.err = errtype.NewDialError(
"context was canceled or expired before refresh completed",
i.ConnName.String(),
i.connName.String(),
nil,
)
} else {
r.result, r.err = i.r.performRefresh(
ctx, i.ConnName, i.key, i.RefreshCfg.UseIAMAuthN)
ctx, i.connName, i.key, i.refreshConfig.UseIAMAuthN)
}

close(r.ready)
Expand Down Expand Up @@ -376,5 +393,5 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {

// String returns the instance's connection name.
func (i *Instance) String() string {
return i.ConnName.String()
return i.connName.String()
}
12 changes: 6 additions & 6 deletions internal/cloudsql/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func TestInstanceEngineVersion(t *testing.T) {
t.Fatalf("%v", err)
}
}()
i := NewInstance(testInstanceConnName(), client, RSAKey, 30*time.Second, nil, "", RefreshCfg{})
i := NewInstance(testInstanceConnName(), client, RSAKey, 30*time.Second, nil, "", RefreshConfig{})
if err != nil {
t.Fatalf("failed to init instance: %v", err)
}
Expand Down Expand Up @@ -129,7 +129,7 @@ func TestConnectInfo(t *testing.T) {
}
}()

i := NewInstance(testInstanceConnName(), client, RSAKey, 30*time.Second, nil, "", RefreshCfg{})
i := NewInstance(testInstanceConnName(), client, RSAKey, 30*time.Second, nil, "", RefreshConfig{})

gotAddr, gotTLSCfg, err := i.ConnectInfo(ctx, PublicIP)
if err != nil {
Expand Down Expand Up @@ -194,7 +194,7 @@ func TestConnectInfoAutoIP(t *testing.T) {
}
}()

i := NewInstance(testInstanceConnName(), client, RSAKey, 30*time.Second, nil, "", RefreshCfg{})
i := NewInstance(testInstanceConnName(), client, RSAKey, 30*time.Second, nil, "", RefreshConfig{})
if err != nil {
t.Fatalf("failed to create mock instance: %v", err)
}
Expand Down Expand Up @@ -223,7 +223,7 @@ func TestConnectInfoErrors(t *testing.T) {
defer cleanup()

// Use a timeout that should fail instantly
i := NewInstance(testInstanceConnName(), client, RSAKey, 0, nil, "", RefreshCfg{})
i := NewInstance(testInstanceConnName(), client, RSAKey, 0, nil, "", RefreshConfig{})

_, _, err = i.ConnectInfo(ctx, PublicIP)
var wantErr *errtype.DialError
Expand All @@ -248,7 +248,7 @@ func TestClose(t *testing.T) {
defer cleanup()

// Set up an instance and then close it immediately
i := NewInstance(testInstanceConnName(), client, RSAKey, 30, nil, "", RefreshCfg{})
i := NewInstance(testInstanceConnName(), client, RSAKey, 30, nil, "", RefreshConfig{})
i.Close()

_, _, err = i.ConnectInfo(ctx, PublicIP)
Expand Down Expand Up @@ -311,7 +311,7 @@ func TestContextCancelled(t *testing.T) {
defer cleanup()

// Set up an instance and then close it immediately
i := NewInstance(testInstanceConnName(), client, RSAKey, 30, nil, "", RefreshCfg{})
i := NewInstance(testInstanceConnName(), client, RSAKey, 30, nil, "", RefreshConfig{})
if err != nil {
t.Fatalf("failed to initialize Instance: %v", err)
}
Expand Down
Loading

0 comments on commit 69c0728

Please sign in to comment.