Skip to content

Commit

Permalink
fix: prevent repeated context expired errors
Browse files Browse the repository at this point in the history
  • Loading branch information
enocom committed Feb 15, 2023
1 parent f8c584a commit 54616f1
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 71 deletions.
2 changes: 1 addition & 1 deletion dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ type Dialer struct {
// RSA keypair is generated will be faster.
func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
cfg := &dialerConfig{
refreshTimeout: 30 * time.Second,
refreshTimeout: alloydb.RefreshTimeout,
dialFunc: proxy.Dial,
useragents: []string{userAgent},
}
Expand Down
79 changes: 54 additions & 25 deletions internal/alloydb/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,26 @@ import (

"cloud.google.com/go/alloydbconn/errtype"
"cloud.google.com/go/alloydbconn/internal/alloydbapi"
"golang.org/x/time/rate"
)

// the refresh buffer is the amount of time before a refresh's result expires
// that a new refresh operation begins.
const refreshBuffer = 4 * time.Minute
const (
// the refresh buffer is the amount of time before a refresh's result
// expires that a new refresh operation begins.
refreshBuffer = 4 * time.Minute

// refreshInterval is the amount of time between refresh attempts as
// enforced by the rate limiter.
refreshInterval = 30 * time.Second

// RefreshTimeout is the maximum amount of time to wait for a refresh
// cycle to complete. This value should be greater than the
// refreshInterval.
RefreshTimeout = 60 * time.Second

// refreshBurst is the initial burst allowed by the rate limiter.
refreshBurst = 2
)

var (
// Instance URI is in the format:
Expand Down Expand Up @@ -117,7 +132,12 @@ type Instance struct {

instanceURI
key *rsa.PrivateKey
r refresher
// refreshTimeout sets the maximum duration a refresh cycle can run
// for.
refreshTimeout time.Duration
// l controls the rate at which refresh cycles are run.
l *rate.Limiter
r refresher

resultGuard sync.RWMutex
// cur represents the current refreshOperation that will be used to
Expand Down Expand Up @@ -148,17 +168,13 @@ func NewInstance(
}
ctx, cancel := context.WithCancel(context.Background())
i := &Instance{
instanceURI: cn,
key: key,
r: newRefresher(
client,
refreshTimeout,
30*time.Second,
2,
dialerID,
),
ctx: ctx,
cancel: cancel,
instanceURI: cn,
key: key,
l: rate.NewLimiter(rate.Every(refreshInterval), refreshBurst),
r: newRefresher(client, dialerID),
refreshTimeout: refreshTimeout,
ctx: ctx,
cancel: cancel,
}
// For the initial refresh operation, set cur = next so that connection
// requests block until the first refresh is complete.
Expand Down Expand Up @@ -234,20 +250,33 @@ func refreshDuration(now, certExpiry time.Time) time.Duration {

// scheduleRefresh schedules a refresh operation to be triggered after a given
// duration. The returned refreshOperation can be used to either Cancel or Wait
// for the operations result.
// for the operation's result.
func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
res := &refreshOperation{}
res.ready = make(chan struct{})
res.timer = time.AfterFunc(d, func() {
res.result, res.err = i.r.performRefresh(i.ctx, i.instanceURI, i.key)
close(res.ready)
r := &refreshOperation{}
r.ready = make(chan struct{})
r.timer = time.AfterFunc(d, func() {
ctx, cancel := context.WithTimeout(i.ctx, i.refreshTimeout)
defer cancel()

err := i.l.Wait(ctx)
if err != nil {
r.err = errtype.NewDialError(
"context was canceled or expired before refresh completed",
i.instanceURI.String(),
nil,
)
} else {
r.result, r.err = i.r.performRefresh(i.ctx, i.instanceURI, i.key)
}

close(r.ready)

// Once the refresh is complete, update "current" with working
// result and schedule a new refresh
i.resultGuard.Lock()
defer i.resultGuard.Unlock()
// if failed, scheduled the next refresh immediately
if res.err != nil {
if r.err != nil {
i.next = i.scheduleRefresh(0)
// If the latest result is bad, avoid replacing the
// used result while it's still valid and potentially
Expand All @@ -256,13 +285,13 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
// valid are surpressed. We should try to surface
// errors in a more meaningful way.
if !i.cur.isValid() {
i.cur = res
i.cur = r
}
return
}
// Update the current results, and schedule the next refresh in
// the future
i.cur = res
i.cur = r
select {
case <-i.ctx.Done():
// instance has been closed, don't schedule anything
Expand All @@ -272,7 +301,7 @@ func (i *Instance) scheduleRefresh(d time.Duration) *refreshOperation {
t := refreshDuration(time.Now(), i.cur.result.expiry)
i.next = i.scheduleRefresh(t)
})
return res
return r
}

// String returns the instance's URI.
Expand Down
3 changes: 2 additions & 1 deletion internal/alloydb/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"crypto/rand"
"crypto/rsa"
"errors"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -210,7 +211,7 @@ func TestClose(t *testing.T) {
im.Close()

_, _, err = im.ConnectInfo(ctx)
if !errors.Is(err, context.Canceled) {
if !strings.Contains(err.Error(), "context was canceled or expired") {
t.Fatalf("failed to retrieve connect info: %v", err)
}
}
Expand Down
32 changes: 2 additions & 30 deletions internal/alloydb/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"cloud.google.com/go/alloydbconn/errtype"
"cloud.google.com/go/alloydbconn/internal/alloydbapi"
"cloud.google.com/go/alloydbconn/internal/trace"
"golang.org/x/time/rate"
)

type connectInfo struct {
Expand Down Expand Up @@ -196,16 +195,11 @@ func createTLSConfig(inst instanceURI, cc certChain, info connectInfo, k *rsa.Pr
// newRefresher creates a Refresher.
func newRefresher(
client *alloydbapi.Client,
timeout time.Duration,
interval time.Duration,
burst int,
dialerID string,
) refresher {
return refresher{
client: client,
timeout: timeout,
clientLimiter: rate.NewLimiter(rate.Every(interval), burst),
dialerID: dialerID,
client: client,
dialerID: dialerID,
}
}

Expand All @@ -215,14 +209,8 @@ type refresher struct {
// client provides access to the AlloyDB Admin API
client *alloydbapi.Client

// timeout is the maximum amount of time a refresh operation should be allowed to take.
timeout time.Duration

// dialerID is the unique ID of the associated dialer.
dialerID string

// clientLimiter limits the number of refreshes.
clientLimiter *rate.Limiter
}

type refreshResult struct {
Expand All @@ -247,22 +235,6 @@ func (r refresher) performRefresh(ctx context.Context, cn instanceURI, k *rsa.Pr
refreshEnd(err)
}()

ctx, cancel := context.WithTimeout(ctx, r.timeout)
defer cancel()
if ctx.Err() == context.Canceled {
return refreshResult{}, ctx.Err()
}

// avoid refreshing too often to try not to tax the AlloyDB Admin API quotas
err = r.clientLimiter.Wait(ctx)
if err != nil {
return refreshResult{}, errtype.NewDialError(
"refresh was throttled until context expired",
cn.String(),
nil,
)
}

type mdRes struct {
info connectInfo
err error
Expand Down
17 changes: 4 additions & 13 deletions internal/alloydb/refresh_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@ import (
"testing"
"time"

"cloud.google.com/go/alloydbconn/errtype"
"cloud.google.com/go/alloydbconn/internal/alloydbapi"
"cloud.google.com/go/alloydbconn/internal/mock"
"google.golang.org/api/option"
)

const testDialerID = "some-dialer-id"

func TestRefresh(t *testing.T) {
wantIP := "10.0.0.1"
wantExpiry := time.Now().Add(time.Hour).UTC().Round(time.Second)
Expand Down Expand Up @@ -57,7 +58,7 @@ func TestRefresh(t *testing.T) {
if err != nil {
t.Fatalf("admin API client error: %v", err)
}
r := newRefresher(cl, time.Hour, 30*time.Second, 2, "some-id")
r := newRefresher(cl, testDialerID)
res, err := r.performRefresh(context.Background(), cn, RSAKey)
if err != nil {
t.Fatalf("performRefresh unexpectedly failed with error: %v", err)
Expand Down Expand Up @@ -98,7 +99,7 @@ func TestRefreshFailsFast(t *testing.T) {
if err != nil {
t.Fatalf("admin API client error: %v", err)
}
r := newRefresher(cl, time.Hour, 30*time.Second, 1, "some-id")
r := newRefresher(cl, testDialerID)

_, err = r.performRefresh(context.Background(), cn, RSAKey)
if err != nil {
Expand All @@ -112,14 +113,4 @@ func TestRefreshFailsFast(t *testing.T) {
if !errors.Is(err, context.Canceled) {
t.Fatalf("expected context.Canceled error, got = %v", err)
}

// force the rate limiter to throttle with a timed out context
ctx, cancel = context.WithTimeout(context.Background(), time.Millisecond)
defer cancel()
_, err = r.performRefresh(ctx, cn, RSAKey)

var wantErr *errtype.DialError
if !errors.As(err, &wantErr) {
t.Fatalf("when refresh is throttled, want = %T, got = %v", wantErr, err)
}
}
3 changes: 2 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ func WithRSAKey(k *rsa.PrivateKey) Option {
}
}

// WithRefreshTimeout returns an Option that sets a timeout on refresh operations. Defaults to 30s.
// WithRefreshTimeout returns an Option that sets a timeout on refresh
// operations. Defaults to 60s.
func WithRefreshTimeout(t time.Duration) Option {
return func(d *dialerConfig) {
d.refreshTimeout = t
Expand Down

0 comments on commit 54616f1

Please sign in to comment.