From c725a4032368f5e6d6a0d90cb2dbcd52645507ed Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Thu, 22 Jun 2023 12:43:03 -0400 Subject: [PATCH] fix(generator): can stuck on .Get --- cmd/gemini/generators.go | 7 +- cmd/gemini/root.go | 24 +-- pkg/generators/generator.go | 86 +++------ pkg/generators/generator_test.go | 9 +- pkg/generators/partition.go | 61 +++++- pkg/jobs/jobs.go | 22 +-- pkg/stop/flag.go | 174 +++++++++++++++-- pkg/stop/flag_test.go | 316 ++++++++++++++++++++++++++++++- 8 files changed, 568 insertions(+), 131 deletions(-) diff --git a/cmd/gemini/generators.go b/cmd/gemini/generators.go index d15c5d8b..fc4f6a08 100644 --- a/cmd/gemini/generators.go +++ b/cmd/gemini/generators.go @@ -14,8 +14,6 @@ package main import ( - "context" - "github.com/scylladb/gemini/pkg/generators" "github.com/scylladb/gemini/pkg/typedef" @@ -23,13 +21,12 @@ import ( ) func createGenerators( - ctx context.Context, schema *typedef.Schema, schemaConfig typedef.SchemaConfig, distributionFunc generators.DistributionFunc, _, distributionSize uint64, logger *zap.Logger, -) []*generators.Generator { +) generators.Generators { partitionRangeConfig := typedef.PartitionRangeConfig{ MaxBlobLength: schemaConfig.MaxBlobLength, MinBlobLength: schemaConfig.MinBlobLength, @@ -47,7 +44,7 @@ func createGenerators( Seed: seed, PkUsedBufferSize: pkBufferReuseSize, } - g := generators.NewGenerator(ctx, table, gCfg, logger.Named("generators")) + g := generators.NewGenerator(table, gCfg, logger.Named("generators")) gs = append(gs, g) } return gs diff --git a/cmd/gemini/root.go b/cmd/gemini/root.go index 768b2f78..ba4c76cc 100644 --- a/cmd/gemini/root.go +++ b/cmd/gemini/root.go @@ -254,12 +254,12 @@ func run(_ *cobra.Command, _ []string) error { } ctx, done := context.WithTimeout(context.Background(), duration+warmup+time.Second*2) - warmupStopFlag := stop.NewFlag() - workStopFlag := stop.NewFlag() - stop.StartOsSignalsTransmitter(logger, &warmupStopFlag, &workStopFlag) + stopFlag := stop.NewFlag("main") + stop.StartOsSignalsTransmitter(logger, stopFlag) pump := jobs.NewPump(ctx, logger) - generators := createGenerators(ctx, schema, schemaConfig, distFunc, concurrency, partitionCount, logger) + gens := createGenerators(schema, schemaConfig, distFunc, concurrency, partitionCount, logger) + gens.StartAll(stopFlag) if !nonInteractive { sp := createSpinner(interactive()) @@ -268,7 +268,7 @@ func run(_ *cobra.Command, _ []string) error { defer done() for { select { - case <-ctx.Done(): + case <-stopFlag.SignalChannel(): return case <-ticker.C: sp.Set(" Running Gemini... %v", globalStatus) @@ -277,24 +277,18 @@ func run(_ *cobra.Command, _ []string) error { }() } - if warmup > 0 && !warmupStopFlag.IsHardOrSoft() { + if warmup > 0 && !stopFlag.IsHardOrSoft() { jobsList := jobs.ListFromMode(jobs.WarmupMode, warmup, concurrency) - if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, generators, globalStatus, logger, seed, &warmupStopFlag, failFast, verbose); err != nil { + if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, seed, stopFlag.CreateChild("warmup"), failFast, verbose); err != nil { logger.Error("warmup encountered an error", zap.Error(err)) } } - select { - case <-ctx.Done(): - default: - if workStopFlag.IsHardOrSoft() { - break - } + if !stopFlag.IsHardOrSoft() { jobsList := jobs.ListFromMode(mode, duration, concurrency) - if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, generators, globalStatus, logger, seed, &workStopFlag, failFast, verbose); err != nil { + if err = jobsList.Run(ctx, schema, schemaConfig, st, pump, gens, globalStatus, logger, seed, stopFlag.CreateChild("workload"), failFast, verbose); err != nil { logger.Debug("error detected", zap.Error(err)) } - } logger.Info("test finished") globalStatus.PrintResult(outFile, schema, version) diff --git a/pkg/generators/generator.go b/pkg/generators/generator.go index e062689d..4a6bf08b 100644 --- a/pkg/generators/generator.go +++ b/pkg/generators/generator.go @@ -15,17 +15,13 @@ package generators import ( - "context" - "github.com/pkg/errors" + "go.uber.org/zap" + "golang.org/x/exp/rand" - "github.com/scylladb/gemini/pkg/inflight" "github.com/scylladb/gemini/pkg/routingkey" + "github.com/scylladb/gemini/pkg/stop" "github.com/scylladb/gemini/pkg/typedef" - - "go.uber.org/zap" - "golang.org/x/exp/rand" - "golang.org/x/sync/errgroup" ) // TokenIndex represents the position of a token in the token ring. @@ -49,7 +45,6 @@ type GeneratorInterface interface { } type Generator struct { - ctx context.Context logger *zap.Logger table *typedef.Table routingKeyCreator *routingkey.Creator @@ -65,12 +60,18 @@ type Generator struct { cntEmitted uint64 } -type Partitions []*Partition - func (g *Generator) PartitionCount() uint64 { return g.partitionCount } +type Generators []*Generator + +func (g Generators) StartAll(stopFlag *stop.Flag) { + for _, gen := range g { + gen.Start(stopFlag) + } +} + type Config struct { PartitionsDistributionFunc DistributionFunc PartitionsRangeConfig typedef.PartitionRangeConfig @@ -79,21 +80,10 @@ type Config struct { PkUsedBufferSize uint64 } -func NewGenerator(ctx context.Context, table *typedef.Table, config *Config, logger *zap.Logger) *Generator { +func NewGenerator(table *typedef.Table, config *Config, logger *zap.Logger) *Generator { wakeUpSignal := make(chan struct{}) - partitions := make([]*Partition, config.PartitionsCount) - for i := 0; i < len(partitions); i++ { - partitions[i] = &Partition{ - ctx: ctx, - values: make(chan *typedef.ValueWithToken, config.PkUsedBufferSize), - oldValues: make(chan *typedef.ValueWithToken, config.PkUsedBufferSize), - inFlight: inflight.New(), - wakeUpSignal: wakeUpSignal, - } - } - gs := &Generator{ - ctx: ctx, - partitions: partitions, + return &Generator{ + partitions: NewPartitions(int(config.PartitionsCount), int(config.PkUsedBufferSize), wakeUpSignal), partitionCount: config.PartitionsCount, table: table, partitionsConfig: config.PartitionsRangeConfig, @@ -102,74 +92,46 @@ func NewGenerator(ctx context.Context, table *typedef.Table, config *Config, log logger: logger, wakeUpSignal: wakeUpSignal, } - gs.start() - return gs -} - -func (g *Generator) isContextCanceled() bool { - select { - case <-g.ctx.Done(): - return true - default: - return false - } } func (g *Generator) Get() *typedef.ValueWithToken { - if g.isContextCanceled() { - return nil - } - partition := g.partitions[uint64(g.idxFunc())%g.partitionCount] - return partition.get() + return g.partitions.GetPartitionForToken(g.idxFunc()).get() } // GetOld returns a previously used value and token or a new if // the old queue is empty. func (g *Generator) GetOld() *typedef.ValueWithToken { - if g.isContextCanceled() { - return nil - } - return g.partitions[uint64(g.idxFunc())%g.partitionCount].getOld() + return g.partitions.GetPartitionForToken(g.idxFunc()).getOld() } // GiveOld returns the supplied value for later reuse unless func (g *Generator) GiveOld(v *typedef.ValueWithToken) { - if g.isContextCanceled() { - return - } - g.partitions[v.Token%g.partitionCount].giveOld(v) + g.partitions.GetPartitionForToken(TokenIndex(v.Token)).giveOld(v) } // ReleaseToken removes the corresponding token from the in-flight tracking. func (g *Generator) ReleaseToken(token uint64) { - if g.isContextCanceled() { - return - } - g.partitions[token%g.partitionCount].releaseToken(token) + g.partitions.GetPartitionForToken(TokenIndex(token)).releaseToken(token) } -func (g *Generator) start() { - grp, gCtx := errgroup.WithContext(g.ctx) - g.ctx = gCtx - for _, partition := range g.partitions { - partition.ctx = gCtx - } - grp.Go(func() error { +func (g *Generator) Start(stopFlag *stop.Flag) { + go func() { g.logger.Info("starting partition key generation loop") g.routingKeyCreator = &routingkey.Creator{} g.r = rand.New(rand.NewSource(g.seed)) + defer g.partitions.CloseAll() for { g.fillAllPartitions() select { - case <-gCtx.Done(): + case <-stopFlag.SignalChannel(): g.logger.Debug("stopping partition key generation loop", zap.Uint64("keys_created", g.cntCreated), zap.Uint64("keys_emitted", g.cntEmitted)) - return gCtx.Err() + return case <-g.wakeUpSignal: } } - }) + }() } // fillAllPartitions guarantees that each partition was tested to be full diff --git a/pkg/generators/generator_test.go b/pkg/generators/generator_test.go index df10bbe6..a0ccc7f7 100644 --- a/pkg/generators/generator_test.go +++ b/pkg/generators/generator_test.go @@ -15,14 +15,14 @@ package generators_test import ( - "context" "sync/atomic" "testing" + "go.uber.org/zap" + "github.com/scylladb/gemini/pkg/generators" + "github.com/scylladb/gemini/pkg/stop" "github.com/scylladb/gemini/pkg/typedef" - - "go.uber.org/zap" ) func TestGenerator(t *testing.T) { @@ -45,7 +45,8 @@ func TestGenerator(t *testing.T) { }, } logger, _ := zap.NewDevelopment() - generator := generators.NewGenerator(context.Background(), table, cfg, logger) + generator := generators.NewGenerator(table, cfg, logger) + generator.Start(stop.NewFlag("main_test")) for i := uint64(0); i < cfg.PartitionsCount; i++ { atomic.StoreUint64(¤t, i) v := generator.Get() diff --git a/pkg/generators/partition.go b/pkg/generators/partition.go index fc669c27..8b3d1cb8 100644 --- a/pkg/generators/partition.go +++ b/pkg/generators/partition.go @@ -15,18 +15,19 @@ package generators import ( - "context" + "sync" "github.com/scylladb/gemini/pkg/inflight" "github.com/scylladb/gemini/pkg/typedef" ) type Partition struct { - ctx context.Context values chan *typedef.ValueWithToken oldValues chan *typedef.ValueWithToken inFlight inflight.InFlight wakeUpSignal chan<- struct{} // wakes up generator + closed bool + lock sync.RWMutex } // get returns a new value and ensures that it's corresponding token @@ -44,8 +45,6 @@ func (s *Partition) get() *typedef.ValueWithToken { // the old queue is empty. func (s *Partition) getOld() *typedef.ValueWithToken { select { - case <-s.ctx.Done(): - return nil case v := <-s.oldValues: return v default: @@ -57,8 +56,12 @@ func (s *Partition) getOld() *typedef.ValueWithToken { // is empty in which case it removes the corresponding token from the // in-flight tracking. func (s *Partition) giveOld(v *typedef.ValueWithToken) { + ch := s.safelyGetOldValuesChannel() + if ch == nil { + return + } select { - case s.oldValues <- v: + case ch <- v: default: // Old partition buffer is full, just drop the value } @@ -88,3 +91,51 @@ func (s *Partition) pick() *typedef.ValueWithToken { return <-s.values } } + +func (s *Partition) safelyGetOldValuesChannel() chan *typedef.ValueWithToken { + s.lock.RLock() + if s.closed { + // Since only giveOld could have been potentially called after partition is closed + // we need to protect it against writing to closed channel + return nil + } + defer s.lock.RUnlock() + return s.oldValues +} + +func (s *Partition) safelyCloseOldValuesChannel() { + s.lock.Lock() + s.closed = true + close(s.oldValues) + s.lock.Unlock() +} + +func (s *Partition) Close() { + close(s.values) + s.safelyCloseOldValuesChannel() +} + +type Partitions []*Partition + +func (p Partitions) CloseAll() { + for _, part := range p { + part.Close() + } +} + +func (p Partitions) GetPartitionForToken(token TokenIndex) *Partition { + return p[uint64(token)%uint64(len(p))] +} + +func NewPartitions(count, pkBufferSize int, wakeUpSignal chan struct{}) Partitions { + partitions := make(Partitions, count) + for i := 0; i < len(partitions); i++ { + partitions[i] = &Partition{ + values: make(chan *typedef.ValueWithToken, pkBufferSize), + oldValues: make(chan *typedef.ValueWithToken, pkBufferSize), + inFlight: inflight.New(), + wakeUpSignal: wakeUpSignal, + } + } + return partitions +} diff --git a/pkg/jobs/jobs.go b/pkg/jobs/jobs.go index c065b5cf..a7b90fac 100644 --- a/pkg/jobs/jobs.go +++ b/pkg/jobs/jobs.go @@ -116,12 +116,11 @@ func (l List) Run( failFast, verbose bool, ) error { logger = logger.Named(l.name) - jCtx, jobCancel := context.WithCancel(ctx) - g, gCtx := errgroup.WithContext(jCtx) - stopFlag.SetOnHardStopHandler(jobCancel) + ctx = stopFlag.CancelContextOnSignal(ctx, stop.SignalHardStop) + g, gCtx := errgroup.WithContext(ctx) time.AfterFunc(l.duration, func() { logger.Info("jobs time is up, begins jobs completion") - stopFlag.SetSoft() + stopFlag.SetSoft(true) }) partitionRangeConfig := typedef.PartitionRangeConfig{ @@ -176,7 +175,7 @@ func mutationJob( return nil } select { - case <-ctx.Done(): + case <-stopFlag.SignalChannel(): logger.Debug("mutation job terminated") return nil case hb := <-pump: @@ -189,7 +188,7 @@ func mutationJob( _ = mutation(ctx, schema, schemaConfig, table, s, r, p, g, globalStatus, true, logger) } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft() + stopFlag.SetSoft(true) return nil } } @@ -224,7 +223,7 @@ func validationJob( return nil } select { - case <-ctx.Done(): + case <-stopFlag.SignalChannel(): return nil case hb := <-pump: time.Sleep(hb) @@ -251,7 +250,7 @@ func validationJob( } if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft() + stopFlag.SetSoft(true) return nil } } @@ -282,18 +281,13 @@ func warmupJob( }() for { if stopFlag.IsHardOrSoft() { - return nil - } - select { - case <-ctx.Done(): logger.Debug("warmup job terminated") return nil - default: } // Do we care about errors during warmup? _ = mutation(ctx, schema, schemaConfig, table, s, r, p, g, globalStatus, false, logger) if failFast && globalStatus.HasErrors() { - stopFlag.SetSoft() + stopFlag.SetSoft(true) return nil } } diff --git a/pkg/stop/flag.go b/pkg/stop/flag.go index f135ad51..54c6e1f2 100644 --- a/pkg/stop/flag.go +++ b/pkg/stop/flag.go @@ -15,8 +15,11 @@ package stop import ( + "context" + "fmt" "os" "os/signal" + "sync" "sync/atomic" "syscall" @@ -24,46 +27,164 @@ import ( ) const ( - signalNoSignal uint32 = iota - signalSoftStop - signalHardStop + SignalNoop uint32 = iota + SignalSoftStop + SignalHardStop ) +type SignalChannel chan uint32 + +var closedChan = createClosedChan() + +func createClosedChan() SignalChannel { + ch := make(SignalChannel) + close(ch) + return ch +} + +type SyncList[T any] struct { + children []T + childrenLock sync.RWMutex +} + +func (f *SyncList[T]) Append(el T) { + f.childrenLock.Lock() + defer f.childrenLock.Unlock() + f.children = append(f.children, el) +} + +func (f *SyncList[T]) Get() []T { + f.childrenLock.RLock() + defer f.childrenLock.RUnlock() + return f.children +} + +type logger interface { + Debug(msg string, fields ...zap.Field) +} + type Flag struct { - hardStopHandler func() - val atomic.Uint32 // 1 - "soft stop";2 - "hard stop" + name string + log logger + ch atomic.Pointer[SignalChannel] + parent *Flag + children SyncList[*Flag] + stopHandlers SyncList[func(signal uint32)] + val atomic.Uint32 } -func (s *Flag) SetSoft() bool { - return s.val.CompareAndSwap(signalNoSignal, signalSoftStop) +func (s *Flag) Name() string { + return s.name } -func (s *Flag) SetHard() bool { - out := s.val.CompareAndSwap(signalNoSignal, signalHardStop) - if out && s.hardStopHandler != nil { - s.hardStopHandler() +func (s *Flag) closeChannel() { + ch := s.ch.Swap(&closedChan) + if ch != &closedChan { + close(*ch) + } +} + +func (s *Flag) sendSignal(signal uint32, sendToParent bool) bool { + s.log.Debug(fmt.Sprintf("flag %s received signal %s", s.name, GetStateName(signal))) + s.closeChannel() + out := s.val.CompareAndSwap(SignalNoop, signal) + if !out { + return false + } + + for _, handler := range s.stopHandlers.Get() { + handler(signal) + } + + for _, child := range s.children.Get() { + child.sendSignal(signal, sendToParent) + } + if sendToParent && s.parent != nil { + s.parent.sendSignal(signal, sendToParent) } return out } +func (s *Flag) SetHard(sendToParent bool) bool { + return s.sendSignal(SignalHardStop, sendToParent) +} + +func (s *Flag) SetSoft(sendToParent bool) bool { + return s.sendSignal(SignalSoftStop, sendToParent) +} + +func (s *Flag) CreateChild(name string) *Flag { + child := newFlag(name, s) + s.children.Append(child) + val := s.val.Load() + switch val { + case SignalSoftStop, SignalHardStop: + child.sendSignal(val, false) + } + return child +} + +func (s *Flag) SignalChannel() SignalChannel { + return *s.ch.Load() +} + func (s *Flag) IsSoft() bool { - return s.val.Load() == signalSoftStop + return s.val.Load() == SignalSoftStop } func (s *Flag) IsHard() bool { - return s.val.Load() == signalHardStop + return s.val.Load() == SignalHardStop } func (s *Flag) IsHardOrSoft() bool { - return s.val.Load() != signalNoSignal + return s.val.Load() != SignalNoop +} + +func (s *Flag) AddHandler(handler func(signal uint32)) { + s.stopHandlers.Append(handler) + val := s.val.Load() + switch val { + case SignalSoftStop, SignalHardStop: + handler(val) + } +} + +func (s *Flag) AddHandler2(handler func(), expectedSignal uint32) { + s.AddHandler(func(signal uint32) { + switch expectedSignal { + case SignalNoop: + handler() + default: + if signal == expectedSignal { + handler() + } + } + }) } -func (s *Flag) SetOnHardStopHandler(function func()) { - s.hardStopHandler = function +func (s *Flag) CancelContextOnSignal(ctx context.Context, expectedSignal uint32) context.Context { + ctx, cancel := context.WithCancel(ctx) + s.AddHandler2(cancel, expectedSignal) + return ctx } -func NewFlag() Flag { - return Flag{} +func (s *Flag) SetLogger(log logger) { + s.log = log +} + +func NewFlag(name string) *Flag { + return newFlag(name, nil) +} + +func newFlag(name string, parent *Flag) *Flag { + out := Flag{ + name: name, + parent: parent, + log: zap.NewNop(), + } + ch := make(SignalChannel) + out.ch.Store(&ch) + return &out } func StartOsSignalsTransmitter(logger *zap.Logger, flags ...*Flag) { @@ -74,14 +195,27 @@ func StartOsSignalsTransmitter(logger *zap.Logger, flags ...*Flag) { switch sig { case syscall.SIGINT: for i := range flags { - flags[i].SetSoft() + flags[i].SetSoft(true) } logger.Info("Get SIGINT signal, begin soft stop.") default: for i := range flags { - flags[i].SetHard() + flags[i].SetHard(true) } logger.Info("Get SIGTERM signal, begin hard stop.") } }() } + +func GetStateName(state uint32) string { + switch state { + case SignalSoftStop: + return "soft" + case SignalHardStop: + return "hard" + case SignalNoop: + return "no-signal" + default: + panic(fmt.Sprintf("unexpected signal %d", state)) + } +} diff --git a/pkg/stop/flag_test.go b/pkg/stop/flag_test.go index 8d5dbd6f..76e2b1f5 100644 --- a/pkg/stop/flag_test.go +++ b/pkg/stop/flag_test.go @@ -16,6 +16,8 @@ package stop_test import ( "context" + "errors" + "fmt" "reflect" "runtime" "strings" @@ -71,11 +73,10 @@ func TestSoftOrHardStop(t *testing.T) { } func initVars() (testFlag *stop.Flag, ctx context.Context, workersDone *atomic.Uint32) { - testFlagOut := stop.NewFlag() - ctx, cancelFunc := context.WithCancel(context.Background()) - testFlagOut.SetOnHardStopHandler(cancelFunc) + testFlagOut := stop.NewFlag("main_test") + ctx = testFlagOut.CancelContextOnSignal(context.Background(), stop.SignalHardStop) workersDone = &atomic.Uint32{} - return &testFlagOut, ctx, workersDone + return testFlagOut, ctx, workersDone } func testSignals( @@ -83,7 +84,7 @@ func testSignals( workersDone *atomic.Uint32, workers int, checkFunc func() bool, - setFunc func() bool, + setFunc func(propagation bool) bool, ) { t.Helper() for i := 0; i != workers; i++ { @@ -98,7 +99,7 @@ func testSignals( }() } time.Sleep(200 * time.Millisecond) - setFunc() + setFunc(false) for i := 0; i != 10; i++ { time.Sleep(100 * time.Millisecond) @@ -120,3 +121,306 @@ func testSignals( t.Errorf("Error:%s or %s functions works not correctly %[2]s=%v", setFuncName, checkFuncName, checkFunc()) } } + +func TestSendToParent(t *testing.T) { + t.Parallel() + tcases := []tCase{ + { + testName: "parent-hard-true", + parentSignal: stop.SignalHardStop, + child1Signal: stop.SignalHardStop, + child11Signal: stop.SignalHardStop, + child12Signal: stop.SignalHardStop, + child2Signal: stop.SignalHardStop, + }, + { + testName: "parent-hard-false", + parentSignal: stop.SignalHardStop, + child1Signal: stop.SignalHardStop, + child11Signal: stop.SignalHardStop, + child12Signal: stop.SignalHardStop, + child2Signal: stop.SignalHardStop, + }, + { + testName: "parent-soft-true", + parentSignal: stop.SignalSoftStop, + child1Signal: stop.SignalSoftStop, + child11Signal: stop.SignalSoftStop, + child12Signal: stop.SignalSoftStop, + child2Signal: stop.SignalSoftStop, + }, + { + testName: "parent-soft-false", + parentSignal: stop.SignalSoftStop, + child1Signal: stop.SignalSoftStop, + child11Signal: stop.SignalSoftStop, + child12Signal: stop.SignalSoftStop, + child2Signal: stop.SignalSoftStop, + }, + { + testName: "child1-soft-true", + parentSignal: stop.SignalSoftStop, + child1Signal: stop.SignalSoftStop, + child11Signal: stop.SignalSoftStop, + child12Signal: stop.SignalSoftStop, + child2Signal: stop.SignalSoftStop, + }, + { + testName: "child1-soft-false", + parentSignal: stop.SignalNoop, + child1Signal: stop.SignalSoftStop, + child11Signal: stop.SignalSoftStop, + child12Signal: stop.SignalSoftStop, + child2Signal: stop.SignalNoop, + }, + { + testName: "child11-soft-true", + parentSignal: stop.SignalSoftStop, + child1Signal: stop.SignalSoftStop, + child11Signal: stop.SignalSoftStop, + child12Signal: stop.SignalSoftStop, + child2Signal: stop.SignalSoftStop, + }, + { + testName: "child11-soft-false", + parentSignal: stop.SignalNoop, + child1Signal: stop.SignalNoop, + child11Signal: stop.SignalSoftStop, + child12Signal: stop.SignalNoop, + child2Signal: stop.SignalNoop, + }, + } + for id := range tcases { + tcase := tcases[id] + t.Run(tcase.testName, func(t *testing.T) { + t.Parallel() + if err := tcase.runTest(); err != nil { + t.Error(err) + } + }) + } +} + +// nolint: govet +type parentChildInfo struct { + parent *stop.Flag + parentSignal uint32 + child1 *stop.Flag + child1Signal uint32 + child11 *stop.Flag + child11Signal uint32 + child12 *stop.Flag + child12Signal uint32 + child2 *stop.Flag + child2Signal uint32 +} + +func (t *parentChildInfo) getFlag(flagName string) *stop.Flag { + switch flagName { + case "parent": + return t.parent + case "child1": + return t.child1 + case "child2": + return t.child2 + case "child11": + return t.child11 + case "child12": + return t.child12 + default: + panic(fmt.Sprintf("no such flag %s", flagName)) + } +} + +func (t *parentChildInfo) getFlagHandlerState(flagName string) uint32 { + switch flagName { + case "parent": + return t.parentSignal + case "child1": + return t.child1Signal + case "child2": + return t.child2Signal + case "child11": + return t.child11Signal + case "child12": + return t.child12Signal + default: + panic(fmt.Sprintf("no such flag %s", flagName)) + } +} + +func (t *parentChildInfo) checkFlagState(flag *stop.Flag, expectedState uint32) error { + var err error + flagName := flag.Name() + state := t.getFlagHandlerState(flagName) + if state != expectedState { + err = errors.Join(err, fmt.Errorf("flag %s handler has state %s while it is expected to be %s", flagName, stop.GetStateName(state), stop.GetStateName(expectedState))) + } + flagState := getFlagState(flag) + if stop.GetStateName(expectedState) != flagState { + err = errors.Join(err, fmt.Errorf("flag %s has state %s while it is expected to be %s", flagName, flagState, stop.GetStateName(expectedState))) + } + return err +} + +type tCase struct { + testName string + parentSignal uint32 + child1Signal uint32 + child11Signal uint32 + child12Signal uint32 + child2Signal uint32 +} + +func (t *tCase) runTest() error { + chunk := strings.Split(t.testName, "-") + if len(chunk) != 3 { + panic(fmt.Sprintf("wrong test name %s", t.testName)) + } + flagName := chunk[0] + signalTypeName := chunk[1] + sendToParentName := chunk[2] + + var sendToParent bool + switch sendToParentName { + case "true": + sendToParent = true + case "false": + sendToParent = false + default: + panic(fmt.Sprintf("wrong test name %s", t.testName)) + } + runt := newParentChildInfo() + flag := runt.getFlag(flagName) + switch signalTypeName { + case "soft": + flag.SetSoft(sendToParent) + case "hard": + flag.SetHard(sendToParent) + default: + panic(fmt.Sprintf("wrong test name %s", t.testName)) + } + var err error + err = errors.Join(err, runt.checkFlagState(runt.parent, t.parentSignal)) + err = errors.Join(err, runt.checkFlagState(runt.child1, t.child1Signal)) + err = errors.Join(err, runt.checkFlagState(runt.child2, t.child2Signal)) + err = errors.Join(err, runt.checkFlagState(runt.child11, t.child11Signal)) + err = errors.Join(err, runt.checkFlagState(runt.child12, t.child12Signal)) + return err +} + +func newParentChildInfo() *parentChildInfo { + parent := stop.NewFlag("parent") + child1 := parent.CreateChild("child1") + out := parentChildInfo{ + parent: parent, + child1: child1, + child11: child1.CreateChild("child11"), + child12: child1.CreateChild("child12"), + child2: parent.CreateChild("child2"), + } + + out.parent.AddHandler(func(signal uint32) { + out.parentSignal = signal + }) + out.child1.AddHandler(func(signal uint32) { + out.child1Signal = signal + }) + out.child11.AddHandler(func(signal uint32) { + out.child11Signal = signal + }) + out.child12.AddHandler(func(signal uint32) { + out.child12Signal = signal + }) + out.child2.AddHandler(func(signal uint32) { + out.child2Signal = signal + }) + return &out +} + +func getFlagState(flag *stop.Flag) string { + switch { + case flag.IsSoft(): + return "soft" + case flag.IsHard(): + return "hard" + default: + return "no-signal" + } +} + +func TestSignalChannel(t *testing.T) { + t.Parallel() + t.Run("single-no-signal", func(t *testing.T) { + t.Parallel() + flag := stop.NewFlag("parent") + select { + case <-flag.SignalChannel(): + t.Error("should not get the signal") + case <-time.Tick(200 * time.Millisecond): + } + }) + + t.Run("single-beforehand", func(t *testing.T) { + t.Parallel() + flag := stop.NewFlag("parent") + flag.SetSoft(true) + <-flag.SignalChannel() + }) + + t.Run("single-normal", func(t *testing.T) { + t.Parallel() + flag := stop.NewFlag("parent") + go func() { + time.Sleep(200 * time.Millisecond) + flag.SetSoft(true) + }() + <-flag.SignalChannel() + }) + + t.Run("parent-beforehand", func(t *testing.T) { + t.Parallel() + parent := stop.NewFlag("parent") + child := parent.CreateChild("child") + parent.SetSoft(true) + <-child.SignalChannel() + }) + + t.Run("parent-beforehand", func(t *testing.T) { + t.Parallel() + parent := stop.NewFlag("parent") + parent.SetSoft(true) + child := parent.CreateChild("child") + <-child.SignalChannel() + }) + + t.Run("parent-normal", func(t *testing.T) { + t.Parallel() + parent := stop.NewFlag("parent") + child := parent.CreateChild("child") + go func() { + time.Sleep(200 * time.Millisecond) + parent.SetSoft(true) + }() + <-child.SignalChannel() + }) + + t.Run("child-beforehand", func(t *testing.T) { + t.Parallel() + parent := stop.NewFlag("parent") + child := parent.CreateChild("child") + child.SetSoft(true) + <-parent.SignalChannel() + }) + + t.Run("child-normal", func(t *testing.T) { + t.Parallel() + parent := stop.NewFlag("parent") + child := parent.CreateChild("child") + go func() { + time.Sleep(200 * time.Millisecond) + child.SetSoft(true) + }() + <-parent.SignalChannel() + }) +}