Skip to content

Commit

Permalink
grpc_retry backoff overflow (#747)
Browse files Browse the repository at this point in the history
* grpc_retry backoff overflow

* Add bounds to exponentBase2 and use it instead of math.Exp2

* Add a few tests

* Add copyright to backoff_test
  • Loading branch information
JacobSMoller authored Feb 11, 2025
1 parent ed865db commit d75a1b8
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 1 deletion.
35 changes: 34 additions & 1 deletion interceptors/retry/backoff.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package retry

import (
"context"
"math"
"math/rand"
"time"
)
Expand All @@ -24,8 +25,15 @@ func jitterUp(duration time.Duration, jitter float64) time.Duration {
return time.Duration(float64(duration) * (1 + multiplier))
}

// exponentBase2 computes 2^(a-1) where a >= 1. If a is 0, the result is 0.
// exponentBase2 computes 2^(a-1) where a >= 1. If a is 0, the result is 1.
// if a is greater than 62, the result is 2^62 to avoid overflowing int64
func exponentBase2(a uint) uint {
if a == 0 {
return 1
}
if a > 62 {
return 1 << 62
}
return (1 << a) >> 1
}

Expand All @@ -50,6 +58,31 @@ func BackoffExponential(scalar time.Duration) BackoffFunc {
// BackoffExponential does, but adds jitter.
func BackoffExponentialWithJitter(scalar time.Duration, jitterFraction float64) BackoffFunc {
return func(ctx context.Context, attempt uint) time.Duration {
exp := exponentBase2(attempt)
dur := scalar * time.Duration(exp)
// Check for overflow in duration multiplication
if exp != 0 && dur/scalar != time.Duration(exp) {
return time.Duration(math.MaxInt64)
}
return jitterUp(scalar*time.Duration(exponentBase2(attempt)), jitterFraction)
}
}

func BackoffExponentialWithJitterBounded(scalar time.Duration, jitterFrac float64, maxBound time.Duration) BackoffFunc {
return func(ctx context.Context, attempt uint) time.Duration {
exp := exponentBase2(attempt)
dur := scalar * time.Duration(exp)
// Check for overflow in duration multiplication
if exp != 0 && dur/scalar != time.Duration(exp) {
return maxBound
}
// Apply random jitter between -jitterFrac and +jitterFrac
jitter := 1 + jitterFrac*(rand.Float64()*2-1)
jitteredDuration := time.Duration(float64(dur) * jitter)
// Check for overflow in jitter multiplication
if float64(dur)*jitter > float64(math.MaxInt64) {
return maxBound
}
return min(jitteredDuration, maxBound)
}
}
39 changes: 39 additions & 0 deletions interceptors/retry/backoff_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) The go-grpc-middleware Authors.
// Licensed under the Apache License 2.0.
package retry

import (
"context"
"testing"
"time"
)

func TestBackoffExponentialWithJitter(t *testing.T) {
scalar := 100 * time.Millisecond
jitterFrac := 0.10
backoffFunc := BackoffExponentialWithJitter(scalar, jitterFrac)
// use 64 so we are past number of attempts where exponentBase2 would overflow
for i := 0; i < 64; i++ {
waitFor := backoffFunc(nil, uint(i))
if waitFor < 0 {
t.Errorf("BackoffExponentialWithJitter(%d) = %d; want >= 0", i, waitFor)
}
}
}

func TestBackoffExponentialWithJitterBounded(t *testing.T) {
scalar := 100 * time.Millisecond
jitterFrac := 0.10
maxBound := 10 * time.Second
backoff := BackoffExponentialWithJitterBounded(scalar, jitterFrac, maxBound)
// use 64 so we are past number of attempts where exponentBase2 would overflow
for i := 0; i < 64; i++ {
waitFor := backoff(context.Background(), uint(i))
if waitFor > maxBound {
t.Fatalf("expected dur to be less than %v, got %v for %d", maxBound, waitFor, i)
}
if waitFor < 0 {
t.Fatalf("expected dur to be greater than 0, got %v for %d", waitFor, i)
}
}
}

0 comments on commit d75a1b8

Please sign in to comment.