diff --git a/interceptors/retry/backoff.go b/interceptors/retry/backoff.go index 5b8aaa68..fdc374bf 100644 --- a/interceptors/retry/backoff.go +++ b/interceptors/retry/backoff.go @@ -5,6 +5,7 @@ package retry import ( "context" + "math" "math/rand" "time" ) @@ -24,8 +25,15 @@ func jitterUp(duration time.Duration, jitter float64) time.Duration { return time.Duration(float64(duration) * (1 + multiplier)) } -// exponentBase2 computes 2^(a-1) where a >= 1. If a is 0, the result is 0. +// exponentBase2 computes 2^(a-1) where a >= 1. If a is 0, the result is 1. +// if a is greater than 62, the result is 2^62 to avoid overflowing int64 func exponentBase2(a uint) uint { + if a == 0 { + return 1 + } + if a > 62 { + return 1 << 62 + } return (1 << a) >> 1 } @@ -50,6 +58,31 @@ func BackoffExponential(scalar time.Duration) BackoffFunc { // BackoffExponential does, but adds jitter. func BackoffExponentialWithJitter(scalar time.Duration, jitterFraction float64) BackoffFunc { return func(ctx context.Context, attempt uint) time.Duration { + exp := exponentBase2(attempt) + dur := scalar * time.Duration(exp) + // Check for overflow in duration multiplication + if exp != 0 && dur/scalar != time.Duration(exp) { + return time.Duration(math.MaxInt64) + } return jitterUp(scalar*time.Duration(exponentBase2(attempt)), jitterFraction) } } + +func BackoffExponentialWithJitterBounded(scalar time.Duration, jitterFrac float64, maxBound time.Duration) BackoffFunc { + return func(ctx context.Context, attempt uint) time.Duration { + exp := exponentBase2(attempt) + dur := scalar * time.Duration(exp) + // Check for overflow in duration multiplication + if exp != 0 && dur/scalar != time.Duration(exp) { + return maxBound + } + // Apply random jitter between -jitterFrac and +jitterFrac + jitter := 1 + jitterFrac*(rand.Float64()*2-1) + jitteredDuration := time.Duration(float64(dur) * jitter) + // Check for overflow in jitter multiplication + if float64(dur)*jitter > float64(math.MaxInt64) { + return maxBound + } + return min(jitteredDuration, maxBound) + } +} diff --git a/interceptors/retry/backoff_test.go b/interceptors/retry/backoff_test.go new file mode 100644 index 00000000..866f24c7 --- /dev/null +++ b/interceptors/retry/backoff_test.go @@ -0,0 +1,39 @@ +// Copyright (c) The go-grpc-middleware Authors. +// Licensed under the Apache License 2.0. +package retry + +import ( + "context" + "testing" + "time" +) + +func TestBackoffExponentialWithJitter(t *testing.T) { + scalar := 100 * time.Millisecond + jitterFrac := 0.10 + backoffFunc := BackoffExponentialWithJitter(scalar, jitterFrac) + // use 64 so we are past number of attempts where exponentBase2 would overflow + for i := 0; i < 64; i++ { + waitFor := backoffFunc(nil, uint(i)) + if waitFor < 0 { + t.Errorf("BackoffExponentialWithJitter(%d) = %d; want >= 0", i, waitFor) + } + } +} + +func TestBackoffExponentialWithJitterBounded(t *testing.T) { + scalar := 100 * time.Millisecond + jitterFrac := 0.10 + maxBound := 10 * time.Second + backoff := BackoffExponentialWithJitterBounded(scalar, jitterFrac, maxBound) + // use 64 so we are past number of attempts where exponentBase2 would overflow + for i := 0; i < 64; i++ { + waitFor := backoff(context.Background(), uint(i)) + if waitFor > maxBound { + t.Fatalf("expected dur to be less than %v, got %v for %d", maxBound, waitFor, i) + } + if waitFor < 0 { + t.Fatalf("expected dur to be greater than 0, got %v for %d", waitFor, i) + } + } +}