diff --git a/retrier.go b/retrier.go index c093a9e..17213ab 100644 --- a/retrier.go +++ b/retrier.go @@ -25,7 +25,7 @@ type Retrier struct { intervalCalculator Strategy strategyType string - manualInterval *time.Duration + nextInterval time.Duration } type jitterRange struct{ min, max time.Duration } @@ -184,7 +184,7 @@ func NewRetrier(opts ...retrierOpt) *Retrier { oldJitter := r.jitter r.jitter = false // Temporarily turn off jitter while we check if the interval is 0 - if r.forever && r.strategyType == constantStrategy && r.NextInterval() == 0 { + if r.forever && r.strategyType == constantStrategy && r.intervalCalculator(r) == 0 { panic("retriers using the constant strategy that run forever must have an interval") } r.jitter = oldJitter // and now set it back to what it was previously @@ -192,7 +192,9 @@ func NewRetrier(opts ...retrierOpt) *Retrier { return r } -// Jitter returns a duration in the interval (0, 1] s if jitter is enabled, or 0 s if it's not +// Jitter returns a duration in the interval in the range [0, r.jitterRange.max - r.jitterRange.min). When no jitter range +// is defined, the default range is [0, 1 second). The jitter is recalculated for each retry. +// If jitter is disabled, this method will always return 0. func (r *Retrier) Jitter() time.Duration { if !r.jitter { return 0 @@ -215,7 +217,7 @@ func (r *Retrier) Break() { // SetNextInterval overrides the strategy for the interval before the next try func (r *Retrier) SetNextInterval(d time.Duration) { - r.manualInterval = &d + r.nextInterval = d } // ShouldGiveUp returns whether the retrier should stop trying do do the thing it's been asked to do @@ -233,14 +235,9 @@ func (r *Retrier) ShouldGiveUp() bool { return r.attemptCount >= r.maxAttempts } -// NextInterval returns the next interval that the retrier will use. Behind the scenes, it calls the function generated -// by either retrier's strategy +// NextInterval returns the length of time that the retrier will wait before the next retry func (r *Retrier) NextInterval() time.Duration { - if r.manualInterval != nil { - return *r.manualInterval - } - - return r.intervalCalculator(r) + return r.nextInterval } func (r *Retrier) String() string { @@ -256,9 +253,8 @@ func (r *Retrier) String() string { return str } - nextInterval := r.NextInterval() - if nextInterval > 0 { - str = str + fmt.Sprintf(" Retrying in %s", nextInterval) + if r.nextInterval > 0 { + str = str + fmt.Sprintf(" Retrying in %s", r.nextInterval) } else { str = str + " Retrying immediately" } @@ -280,21 +276,16 @@ func (r *Retrier) Do(callback func(*Retrier) error) error { // DoWithContext is a context-aware variant of Do. func (r *Retrier) DoWithContext(ctx context.Context, callback func(*Retrier) error) error { for { + // Calculate the next interval before we do work - this way, the calls to r.NextInterval() in the callback will be + // accurate and include the calculated jitter, if present + r.nextInterval = r.intervalCalculator(r) + // Perform the action the user has requested we retry err := callback(r) if err == nil { return nil } - // Calculate the next interval before we increment the attempt count - // In the exponential case, if we didn't do this, we'd skip the first interval - // ie, we would wait 2^1, 2^2, 2^3, ..., 2^n+1 seconds (bad) - // instead of 2^0, 2^1, 2^2, ..., 2^n seconds (good) - nextInterval := r.NextInterval() - - // Reset the manualInterval now that the nextInterval has been acquired. - r.manualInterval = nil - r.MarkAttempt() // If the last callback called r.Break(), or if we've hit our call limit, bail out and return the last error we got @@ -302,7 +293,7 @@ func (r *Retrier) DoWithContext(ctx context.Context, callback func(*Retrier) err return err } - if err := r.sleepOrDone(ctx, nextInterval); err != nil { + if err := r.sleepOrDone(ctx, r.nextInterval); err != nil { return err } } diff --git a/retrier_test.go b/retrier_test.go index 1577f5e..21d4e47 100644 --- a/retrier_test.go +++ b/retrier_test.go @@ -3,6 +3,7 @@ package roko import ( "context" "errors" + "regexp" "testing" "time" @@ -518,6 +519,41 @@ func TestString_WithTryForever(t *testing.T) { }, retryingIns) } +func TestString_WithJitter(t *testing.T) { + t.Parallel() + + insomniac := newInsomniac() + r := NewRetrier( + WithStrategy(Constant(10*time.Second)), + WithJitterRange(-1*time.Second, 10*time.Second), + WithMaxAttempts(5), + WithSleepFunc(insomniac.sleep), + ) + + retryingIns := make([]time.Duration, 0, 5) + durationRE := regexp.MustCompile(`[\d\.]+s`) + err := r.Do(func(_ *Retrier) error { + retryingIn := r.String() + dur := durationRE.FindString(retryingIn) + if dur != "" { + d, err := time.ParseDuration(dur) + if err != nil { + t.Fatalf("failed to parse duration: %s", dur) + } + retryingIns = append(retryingIns, d) + } + + return errDummy + }) + + assert.ErrorIs(t, err, errDummy) + + // Assert that the durations returned by the String() method are the actual lengths of time that the retrier slept + for i, interval := range insomniac.sleepIntervals { + assert.Check(t, cmp.DeepEqual(interval, retryingIns[i], DurationExact())) + } +} + func TestString_WithNoDelay(t *testing.T) { t.Parallel()