Skip to content
This repository was archived by the owner on Mar 27, 2023. It is now read-only.

fix potentially deadlock issue when stop timer/ticker #3

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions clock.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ type (
mock *mockClock
real *time.Timer

outC chan<- time.Time
when time.Time
stopped bool
outC chan<- time.Time
when time.Time
stopC chan struct{}
}

// Ticker is created from NewTicker() and sends to C every duration interval
Expand All @@ -35,9 +35,9 @@ type (
mock *mockClock
real *time.Ticker

outC chan<- time.Time
when time.Time
period time.Duration
stopped bool
outC chan<- time.Time
when time.Time
period time.Duration
stopC chan struct{}
}
)
48 changes: 34 additions & 14 deletions mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,15 @@ func sleeperWhen(sleeper interface{}) time.Time {
func notifySleeper(sleeper interface{}, now time.Time, yield bool) {
switch s := sleeper.(type) {
case *Timer:
s.outC <- now
select {
case s.outC <- now:
case <-s.stopC:
}
case *Ticker:
s.outC <- now
select {
case s.outC <- now:
case <-s.stopC:
}
}

// Give timers an opportunity to run - helpful if we're in the middle of a
Expand Down Expand Up @@ -122,10 +128,11 @@ func (c *mockClock) NewTimer(duration time.Duration) *Timer {

outC := make(chan time.Time)
t := &Timer{
C: outC, // user sees receive-channel
outC: outC, // we use it as a send-channel
mock: c,
when: c.now.Add(duration),
C: outC, // user sees receive-channel
outC: outC, // we use it as a send-channel
mock: c,
when: c.now.Add(duration),
stopC: make(chan struct{}),
}
c.insertSleeper(t)

Expand All @@ -144,6 +151,7 @@ func (c *mockClock) NewTicker(duration time.Duration) *Ticker {
mock: c,
when: c.now.Add(duration),
period: duration,
stopC: make(chan struct{}),
}
c.insertSleeper(t)

Expand Down Expand Up @@ -187,13 +195,19 @@ func (c *mockClock) Advance(duration time.Duration) {
switch s := head.(type) {
case *Ticker:
// Requeue ticker
if !s.stopped {
select {
case <-s.stopC:
default:
s.when = c.now.Add(s.period)
c.insertSleeper(s)
}
case *Timer:
// Discard timer, and notify that sleepers changed
s.stopped = true
select {
case <-s.stopC:
default:
close(s.stopC)
}
for _, sc := range c.sleepersChanged {
sc <- len(c.sleepers)
}
Expand All @@ -206,10 +220,13 @@ func stopMockTimer(t *Timer) bool {

c.mutex.Lock()
defer c.mutex.Unlock()
if t.stopped {
return false

select {
case <-t.stopC:
default:
close(t.stopC)
}
t.stopped = true

return c.removeSleeper(t)
}

Expand All @@ -218,10 +235,13 @@ func stopMockTicker(t *Ticker) {

c.mutex.Lock()
defer c.mutex.Unlock()
if t.stopped {
return

select {
case <-t.stopC:
default:
close(t.stopC)
}
t.stopped = true

c.removeSleeper(t)
}

Expand Down
40 changes: 40 additions & 0 deletions mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,3 +194,43 @@ func TestUntil(t *testing.T) {
t.Errorf("Expected -24 hours until from 24 hour time, got %v", u)
}
}

func TestStopTimerNoDeadlock(t *testing.T) {
// Hopefully, the execution order will be:
// t1 := c.NewTimer(time.Second)
// c.Advance(2 * time.Second)
// t1.Stop()
c := NewMock(defaultOpts)
barrier := make(chan struct{})
go func() {
t1 := c.NewTimer(time.Second)
barrier <- struct{}{}
<-barrier
t1.Stop()
}()
<-barrier
go func() {
barrier <- struct{}{}
}()
c.Advance(2 * time.Second)
}

func TestStopTickerNoDeadlock(t *testing.T) {
// Hopefully, the execution order will be:
// t1 := c.NewTicker(time.Second)
// c.Advance(2 * time.Second)
// t1.Stop()
c := NewMock(defaultOpts)
barrier := make(chan struct{})
go func() {
t1 := c.NewTicker(time.Second)
barrier <- struct{}{}
<-barrier
t1.Stop()
}()
<-barrier
go func() {
barrier <- struct{}{}
}()
c.Advance(2 * time.Second)
}