Deadlock demo in acquireShards #5824

Closed · wants to merge 4 commits
1 change: 1 addition & 0 deletions common/membership/hashring.go
@@ -150,6 +150,7 @@ func (r *ring) Stop() {
 	r.subscribers.keys = make(map[string]chan<- *ChangedEvent)
 	close(r.shutdownCh)
 
+	// TODO: this can deadlock for 1m
 	if success := common.AwaitWaitGroup(&r.shutdownWG, time.Minute); !success {
 		r.logger.Warn("service resolver timed out on shutdown.")
 	}
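The TODO above refers to this bounded wait: if whatever goroutines r.shutdownWG tracks never finish, Stop() stalls here for the full time.Minute before it merely logs the warning. common.AwaitWaitGroup presumably follows the usual "WaitGroup with timeout" pattern; below is a minimal sketch of that pattern, assuming the signature used in the call above (the real helper in the common package may differ):

package common

import (
	"sync"
	"time"
)

// AwaitWaitGroup waits on wg for at most timeout and reports whether the
// wait finished in time. Sketch only, not the actual implementation.
func AwaitWaitGroup(wg *sync.WaitGroup, timeout time.Duration) bool {
	doneCh := make(chan struct{})
	go func() {
		wg.Wait() // blocks forever if some goroutine never calls Done()
		close(doneCh)
	}()

	select {
	case <-doneCh:
		return true
	case <-time.After(timeout):
		return false // give up; the waiter goroutine above is leaked
	}
}

Note that the timeout only bounds how long Stop() blocks; it does not unblock whatever is still holding the WaitGroup.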
12 changes: 5 additions & 7 deletions service/history/shard/controller.go
@@ -384,7 +384,8 @@ func (c *controller) acquireShards() {
 	defer sw.Stop()
 
 	concurrency := common.MaxInt(c.config.AcquireShardConcurrency(), 1)
-	shardActionCh := make(chan int, concurrency)
+	numShards := c.config.NumberOfShards
+	shardActionCh := make(chan int, numShards)
 	var wg sync.WaitGroup
 	wg.Add(concurrency)
 	// Spawn workers that would lookup and add/remove shards concurrently.
@@ -411,14 +412,11 @@
 		}()
 	}
 	// Submit tasks to the channel.
-	for shardID := 0; shardID < c.config.NumberOfShards; shardID++ {
-		shardActionCh <- shardID
-		if c.isShuttingDown() {
-			return
-		}
+	for shardID := 0; shardID < numShards; shardID++ {
+		shardActionCh <- shardID // must be non-blocking
 	}
 	close(shardActionCh)
-	// Wait until all shards are processed.
+	// Wait until all shards are processed or have shut down
 	wg.Wait()
 
 	c.metricsScope.UpdateGauge(metrics.NumShardsGauge, float64(c.NumShards()))
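Why the old submission loop could hang: as the demo test below sketches, each worker returns the next time it receives from shardActionCh once shutdown has been observed, so the channel stops being drained. With a buffer of only concurrency, the producer's next send on shardActionCh can then block forever, and since the old c.isShuttingDown() check sat after that send, it was never reached. Sizing the buffer to NumberOfShards makes every send complete immediately, so the loop always reaches close(shardActionCh), and wg.Wait() returns once the workers drain the channel or bail out on shutdown. A small, self-contained illustration of the channel property the fix relies on (illustrative only, not the production code):

package main

import "fmt"

func main() {
	const numShards = 16

	// A buffered channel only blocks the sender once its buffer is full.
	// With capacity equal to the total number of jobs, the submit loop can
	// never block, regardless of what (or whether) anything is receiving.
	jobs := make(chan int, numShards)

	for shardID := 0; shardID < numShards; shardID++ {
		jobs <- shardID // never blocks: at most numShards values are ever sent
	}
	close(jobs)

	// Consumers can drain the channel whenever they like; range stops at close.
	for shardID := range jobs {
		fmt.Println("processed shard", shardID)
	}
}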
146 changes: 146 additions & 0 deletions service/history/shard/controller_test.go
@@ -23,7 +23,9 @@ package shard
import (
	"errors"
	"fmt"
	"runtime"
	"sync"
	"sync/atomic"
	"testing"
	"time"
@@ -67,6 +69,150 @@
	}
)

func TestDeadlock(t *testing.T) {
	useFix := false
	/*
		`try` is essentially what acquireShards() does:

			c := make(chan T, buffer) // "small" without fix, == len(shards) with fix
			for range some {
				go func() {
					for range c {
						if stop { return }
						process()
					}
				}()
			}
			for range shards {
				c <- ...
				if stop { return } // removed with fix
			}

		plus some logging, and triggering a "stop" in the middle.
	*/
	try := func(logfn func(...interface{})) {
		var wg sync.WaitGroup
		wg.Add(10)
		var c chan struct{}
		numShards := 16000
		if useFix {
			c = make(chan struct{}, numShards)
		} else {
			// original flawed code
			c = make(chan struct{}, 10)
		}
		stop := &atomic.Bool{}

		for i := 0; i < 10; i++ {
			go func() {
				defer wg.Done()
				for range c {
					if stop.Load() {
						logfn("worker returning")
						return
					}
					logfn("processed")
					runtime.Gosched()
				}
			}()
		}

		// using realistic values. occurs at lower values too, but perf is fine.
		for i := 0; i < numShards; i++ {
			logfn("pushing", i)
			c <- struct{}{}
			if i == numShards/2 {
				logfn("stopping")
				go func() {
					time.Sleep(time.Microsecond)
					stop.Store(true)
				}()
			}
			// unnecessary with fixed job-publishing
			if !useFix {
				if stop.Load() {
					logfn("breaking")
					break
				}
			}
		}
		logfn("closing")
		close(c)

		wg.Wait()
	}

	// this runs `try` and watches for deadlocks, up to some duration.
	run := func(logfn func(...interface{}), timeout time.Duration) time.Duration {
		done := make(chan struct{})
		start := time.Now()
		go func() {
			defer close(done)
			try(logfn)
		}()

		select {
		case <-done:
			logfn("done")
			return time.Since(start)
		case <-time.After(timeout):
			t.Errorf("took too long")
			return 0
		}
	}

	logfn := func(...interface{}) {} // noop by default
	// logfn = t.Log // if desired, but probably only run one attempt to avoid extreme logspam

	if _, ok := t.Deadline(); !ok {
		t.Fatal("no deadline, skipping")
	}

	// keep trying batches until fail or timeout
	for {
		attempts := 100
		timeout := 10 * time.Second // needs to be higher for more concurrent attempts
		var wg sync.WaitGroup
		wg.Add(attempts)
		var avg, count float64
		var successes, fails int
		var longest time.Duration
		var mut sync.Mutex

		// run some tests in parallel to increase the odds of triggering the issue,
		// and keep track of avg/max runtime to make sure the timeout is reasonably
		// larger than random latency could exceed.
		for i := 0; i < attempts; i++ {
			go func() {
				defer wg.Done()
				result := run(logfn, timeout)
				mut.Lock()
				if result > 0 {
					if result > longest {
						longest = result
					}
					avg = (avg*count + float64(result)) / (count + 1)
					count++
					successes++
				} else {
					fails++
				}
				mut.Unlock()
			}()
		}
		wg.Wait()
		dead, _ := t.Deadline()
		t.Logf("time remaining: %v, avg runtime: %v, longest: %v, successes: %v, fails: %v", time.Until(dead), time.Duration(avg), longest, successes, fails)
		if t.Failed() {
			break
		}
		if time.Until(dead) < 10*time.Second {
			t.Log("giving up early")
			break
		}
	}
}
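To reproduce locally, running just this test with a generous timeout should work; these are standard go test flags, and the package path is taken from the diff header above:

	go test -run TestDeadlock -timeout 2m -count 1 -v ./service/history/shard/

With useFix left at false, at least one of the parallel attempts is expected to eventually trip the "took too long" failure; flipping useFix to true exercises the patched submission logic, and the batches should then finish well within the timeout.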

func TestControllerSuite(t *testing.T) {
s := new(controllerSuite)
suite.Run(t, s)