diff --git a/pkg/sql/stmtdiagnostics/statement_diagnostics.go b/pkg/sql/stmtdiagnostics/statement_diagnostics.go index 58c978611e30..9741538b53aa 100644 --- a/pkg/sql/stmtdiagnostics/statement_diagnostics.go +++ b/pkg/sql/stmtdiagnostics/statement_diagnostics.go @@ -157,18 +157,21 @@ func (r *Registry) Start(ctx context.Context, stopper *stop.Stopper) { func (r *Registry) poll(ctx context.Context) { var ( - timer timeutil.Timer + timer *timeutil.Timer // We need to store timer.C reference separately because timer.Stop() // (called when polling is disabled) puts timer into the pool and // prohibits further usage of stored timer.C. - timerC = timer.C + timerC <-chan time.Time lastPoll time.Time deadline time.Time pollIntervalChanged = make(chan struct{}, 1) maybeResetTimer = func() { if interval := pollingInterval.Get(&r.st.SV); interval == 0 { // Setting the interval to zero stops the polling. - timer.Stop() + if timer != nil { + timer.Stop() + timer = nil + } // nil out the channel so that it'd block forever in the loop // below (until the polling interval is changed). timerC = nil @@ -176,6 +179,9 @@ func (r *Registry) poll(ctx context.Context) { newDeadline := lastPoll.Add(interval) if deadline.IsZero() || !deadline.Equal(newDeadline) { deadline = newDeadline + if timer == nil { + timer = timeutil.NewTimer() + } timer.Reset(timeutil.Until(deadline)) timerC = timer.C } diff --git a/pkg/util/timeutil/timer.go b/pkg/util/timeutil/timer.go index a090cd5d2d89..d21c3af733b3 100644 --- a/pkg/util/timeutil/timer.go +++ b/pkg/util/timeutil/timer.go @@ -55,11 +55,16 @@ type Timer struct { // the timer has been initialized (via Reset). C <-chan time.Time Read bool + // fromPool indicates whether this Timer came from timerPool. If false, then + // it won't be put into the timerPool on Stop. + fromPool bool } // NewTimer allocates a new timer. func NewTimer() *Timer { - return timerPool.Get().(*Timer) + t := timerPool.Get().(*Timer) + t.fromPool = true + return t } // Reset changes the timer to expire after duration d and returns @@ -102,6 +107,9 @@ func (t *Timer) Stop() bool { timeTimerPool.Put(t.timer) } } + if !t.fromPool { + return res + } *t = Timer{} timerPool.Put(t) return res diff --git a/pkg/util/timeutil/timer_test.go b/pkg/util/timeutil/timer_test.go index 9ef96c9a9ab6..63effbb42ef8 100644 --- a/pkg/util/timeutil/timer_test.go +++ b/pkg/util/timeutil/timer_test.go @@ -11,9 +11,14 @@ package timeutil import ( + "context" "fmt" + "math/rand" + "sync" "testing" "time" + + "github.com/cockroachdb/cockroach/pkg/util/randutil" ) const timeStep = 10 * time.Millisecond @@ -139,3 +144,61 @@ func TestTimerMakesProgressInLoop(t *testing.T) { timer.Read = true } } + +// TestIllegalTimerShare is a regression test for sharing the same Timer between +// multiple users when it was originally allocated on the stack of one of them +// but then later was put into timerPool on Stop() (see #119593). +func TestIllegalTimerShare(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + resetTimer := func(t *Timer, rng *rand.Rand) { + t.Reset(time.Duration(rng.Intn(100)+1) * time.Nanosecond) + } + + var wg sync.WaitGroup + // Simulate a pattern of usage of the stack-allocated Timer that is being + // stopped each time when the Timer fires. + fromStack := func() { + defer wg.Done() + rng, _ := randutil.NewTestRand() + var t Timer + defer t.Stop() + resetTimer(&t, rng) + for { + select { + case <-ctx.Done(): + return + case <-t.C: + t.Read = true + t.Stop() + resetTimer(&t, rng) + } + } + } + // Simulate the most common pattern where the Timer is taken from the + // timerPool, fires repeatedly, and then is stopped in a defer. + fromPool := func() { + defer wg.Done() + rng, _ := randutil.NewTestRand() + t := NewTimer() + defer t.Stop() + resetTimer(t, rng) + for { + select { + case <-ctx.Done(): + return + case <-t.C: + t.Read = true + resetTimer(t, rng) + } + } + } + // Spin up a few goroutines per each access pattern. + wg.Add(6) + for i := 0; i < 3; i++ { + go fromStack() + go fromPool() + } + wg.Wait() +}