diff --git a/examples/deadline_worker/deadlineworker.go b/examples/deadline_worker/deadlineworker.go index 3f91a26..832670c 100644 --- a/examples/deadline_worker/deadlineworker.go +++ b/examples/deadline_worker/deadlineworker.go @@ -5,7 +5,7 @@ package main import ( "context" "fmt" - worker "github.com/catmullet/go-workers" + "github.com/catmullet/go-workers" "time" ) @@ -13,14 +13,14 @@ func main() { ctx := context.Background() t := time.Now() - deadlineWorker := worker.NewWorker(ctx, NewDeadlineWorker(), 100). - SetDeadline(t.Add(200 * time.Millisecond)).Work() + deadlineWorker := workers.NewRunner(ctx, NewDeadlineWorker(), 100). + SetDeadline(t.Add(200 * time.Millisecond)).Start() for i := 0; i < 1000000; i++ { deadlineWorker.Send("hello") } - err := deadlineWorker.Close() + err := deadlineWorker.Wait() if err != nil { fmt.Println(err) } @@ -29,12 +29,12 @@ func main() { type DeadlineWorker struct{} -func NewDeadlineWorker() *DeadlineWorker { +func NewDeadlineWorker() workers.Worker { return &DeadlineWorker{} } -func (dlw *DeadlineWorker) Work(w *worker.Worker, in interface{}) error { - w.Println(in) +func (dlw *DeadlineWorker) Work(in interface{}, out chan<- interface{}) error { + fmt.Println(in) time.Sleep(1 * time.Second) return nil } diff --git a/examples/multiple_workers/multipleworkers.go b/examples/multiple_workers/multipleworkers.go index 7ca79af..86f76d7 100644 --- a/examples/multiple_workers/multipleworkers.go +++ b/examples/multiple_workers/multipleworkers.go @@ -5,50 +5,80 @@ package main import ( "context" "fmt" - worker "github.com/catmullet/go-workers" + "github.com/catmullet/go-workers" "math/rand" + "sync" +) + +var ( + count = make(map[string]int) + mut = sync.RWMutex{} ) func main() { ctx := context.Background() - workerOne := worker.NewWorker(ctx, NewWorkerOne(), 1000).Work() - workerTwo := worker.NewWorker(ctx, NewWorkerTwo(), 1000).InFrom(workerOne).Work() - for i := 0; i < 1000000; i++ { - workerOne.Send(rand.Intn(100)) - } + workerOne := workers.NewRunner(ctx, NewWorkerOne(), 1000).Start() + workerTwo := workers.NewRunner(ctx, NewWorkerTwo(), 1000).InFrom(workerOne).Start() - if err := workerOne.Close(); err != nil { - fmt.Println(err) - } + go func() { + for i := 0; i < 100000; i++ { + workerOne.Send(rand.Intn(100)) + } + if err := workerOne.Wait(); err != nil { + fmt.Println(err) + } + }() - if err := workerTwo.Close(); err != nil { + if err := workerTwo.Wait(); err != nil { fmt.Println(err) } + fmt.Println("worker_one", count["worker_one"]) + fmt.Println("worker_two", count["worker_two"]) fmt.Println("finished") } -type WorkerOne struct{} -type WorkerTwo struct{} +type WorkerOne struct { +} +type WorkerTwo struct { +} -func NewWorkerOne() *WorkerOne { +func NewWorkerOne() workers.Worker { return &WorkerOne{} } -func NewWorkerTwo() *WorkerTwo { +func NewWorkerTwo() workers.Worker { return &WorkerTwo{} } -func (wo *WorkerOne) Work(w *worker.Worker, in interface{}) error { +func (wo *WorkerOne) Work(in interface{}, out chan<- interface{}) error { + var workerOne = "worker_one" + mut.Lock() + if val, ok := count[workerOne]; ok { + count[workerOne] = val + 1 + } else { + count[workerOne] = 1 + } + mut.Unlock() + total := in.(int) * 2 - w.Println(fmt.Sprintf("%d * 2 = %d", in.(int), total)) - w.Out(total) + fmt.Println("worker1", fmt.Sprintf("%d * 2 = %d", in.(int), total)) + out <- total return nil } -func (wt *WorkerTwo) Work(w *worker.Worker, in interface{}) error { +func (wt *WorkerTwo) Work(in interface{}, out chan<- interface{}) error { + var workerTwo = "worker_two" + mut.Lock() + if val, ok := count[workerTwo]; ok { + count[workerTwo] = val + 1 + } else { + count[workerTwo] = 1 + } + mut.Unlock() + totalFromWorkerOne := in.(int) - w.Println(fmt.Sprintf("%d * 4 = %d", totalFromWorkerOne, totalFromWorkerOne*4)) + fmt.Println("worker2", fmt.Sprintf("%d * 4 = %d", totalFromWorkerOne, totalFromWorkerOne*4)) return nil } diff --git a/examples/passing_fields/passingfields.go b/examples/passing_fields/passingfields.go index 6fab9b7..b4cc768 100644 --- a/examples/passing_fields/passingfields.go +++ b/examples/passing_fields/passingfields.go @@ -5,27 +5,24 @@ package main import ( "context" "fmt" - worker "github.com/catmullet/go-workers" + "github.com/catmullet/go-workers" "math/rand" ) func main() { ctx := context.Background() - workerOne := worker.NewWorker(ctx, NewWorkerOne(2), 10). - Work() - workerTwo := worker.NewWorker(ctx, NewWorkerTwo(4), 10). - InFrom(workerOne). - Work() + workerOne := workers.NewRunner(ctx, NewWorkerOne(2), 100).Start() + workerTwo := workers.NewRunner(ctx, NewWorkerTwo(4), 100).InFrom(workerOne).Start() - for i := 0; i < 10; i++ { + for i := 0; i < 15; i++ { workerOne.Send(rand.Intn(100)) } - if err := workerOne.Close(); err != nil { + if err := workerOne.Wait(); err != nil { fmt.Println(err) } - if err := workerTwo.Close(); err != nil { + if err := workerTwo.Wait(); err != nil { fmt.Println(err) } } @@ -37,27 +34,27 @@ type WorkerTwo struct { amountToMultiply int } -func NewWorkerOne(amountToMultiply int) *WorkerOne { +func NewWorkerOne(amountToMultiply int) workers.Worker { return &WorkerOne{ amountToMultiply: amountToMultiply, } } -func NewWorkerTwo(amountToMultiply int) *WorkerTwo { +func NewWorkerTwo(amountToMultiply int) workers.Worker { return &WorkerTwo{ amountToMultiply, } } -func (wo *WorkerOne) Work(w *worker.Worker, in interface{}) error { +func (wo *WorkerOne) Work(in interface{}, out chan<- interface{}) error { total := in.(int) * wo.amountToMultiply - fmt.Println(fmt.Sprintf("%d * %d = %d", in.(int), wo.amountToMultiply, total)) - w.Out(total) + fmt.Println("worker1", fmt.Sprintf("%d * %d = %d", in.(int), wo.amountToMultiply, total)) + out <- total return nil } -func (wt *WorkerTwo) Work(w *worker.Worker, in interface{}) error { +func (wt *WorkerTwo) Work(in interface{}, out chan<- interface{}) error { totalFromWorkerOne := in.(int) - fmt.Println(fmt.Sprintf("%d * %d = %d", totalFromWorkerOne, wt.amountToMultiply, totalFromWorkerOne*wt.amountToMultiply)) + fmt.Println("worker2", fmt.Sprintf("%d * %d = %d", totalFromWorkerOne, wt.amountToMultiply, totalFromWorkerOne*wt.amountToMultiply)) return nil } diff --git a/examples/quickstart/quickstart.go b/examples/quickstart/quickstart.go index 70fc8dc..2829039 100644 --- a/examples/quickstart/quickstart.go +++ b/examples/quickstart/quickstart.go @@ -5,7 +5,7 @@ package main import ( "context" "fmt" - worker "github.com/catmullet/go-workers" + "github.com/catmullet/go-workers" "math/rand" "time" ) @@ -13,13 +13,13 @@ import ( func main() { ctx := context.Background() t := time.Now() - w := worker.NewWorker(ctx, NewWorker(), 1000).Work() + rnr := workers.NewRunner(ctx, NewWorker(), 100).Start() for i := 0; i < 1000000; i++ { - w.Send(rand.Intn(100)) + rnr.Send(rand.Intn(100)) } - if err := w.Close(); err != nil { + if err := rnr.Wait(); err != nil { fmt.Println(err) } @@ -27,15 +27,15 @@ func main() { fmt.Printf("total time %dms\n", totalTime) } -type Worker struct { +type WorkerOne struct { } -func NewWorker() *Worker { - return &Worker{} +func NewWorker() workers.Worker { + return &WorkerOne{} } -func (wo *Worker) Work(w *worker.Worker, in interface{}) error { +func (wo *WorkerOne) Work(in interface{}, out chan<- interface{}) error { total := in.(int) * 2 - defer w.Println(fmt.Sprintf("%d * 2 = %d", in.(int), total)) + fmt.Println(fmt.Sprintf("%d * 2 = %d", in.(int), total)) return nil } diff --git a/examples/timeout_worker/timeoutworker.go b/examples/timeout_worker/timeoutworker.go index 1670406..f5226a1 100644 --- a/examples/timeout_worker/timeoutworker.go +++ b/examples/timeout_worker/timeoutworker.go @@ -5,21 +5,20 @@ package main import ( "context" "fmt" - worker "github.com/catmullet/go-workers" + "github.com/catmullet/go-workers" "time" ) func main() { ctx := context.Background() - timeoutWorker := worker.NewWorker(ctx, NewTimeoutWorker(), 10).Work() - timeoutWorker.SetTimeout(100 * time.Millisecond) + timeoutWorker := workers.NewRunner(ctx, NewTimeoutWorker(), 10).SetTimeout(100 * time.Millisecond).Start() for i := 0; i < 1000000; i++ { timeoutWorker.Send("hello") } - err := timeoutWorker.Close() + err := timeoutWorker.Wait() if err != nil { fmt.Println(err) } @@ -27,11 +26,11 @@ func main() { type TimeoutWorker struct{} -func NewTimeoutWorker() *TimeoutWorker { +func NewTimeoutWorker() workers.Worker { return &TimeoutWorker{} } -func (tw *TimeoutWorker) Work(w *worker.Worker, in interface{}) error { +func (tw *TimeoutWorker) Work(in interface{}, out chan<- interface{}) error { fmt.Println(in) time.Sleep(1 * time.Second) return nil diff --git a/go.mod b/go.mod index 75dba99..684b7a6 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,3 @@ module github.com/catmullet/go-workers go 1.15 - -require golang.org/x/sync v0.0.0-20210220032951-036812b2e83c diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/workers.go b/workers.go index e302297..230b388 100644 --- a/workers.go +++ b/workers.go @@ -1,266 +1,253 @@ package workers import ( - "bufio" "context" "errors" - "fmt" - "golang.org/x/sync/semaphore" - "io" "os" "os/signal" "sync" + "syscall" "time" ) -const ( - internalBufferFlushLimit = 512 - signalChannelBufferSize = 1 -) - -var ( - // Exported variables - // ErrOutChannelUpdate When the output channel is not able to be updated. - ErrOutChannelUpdate = errors.New("out channel already set") -) - -// WorkerObject interface to be implemented -type WorkerObject interface { - Work(w *Worker, in interface{}) error -} - -// Worker The object to hold all necessary configuration and channels for the worker -// only accessible by it's methods. -type Worker struct { - Ctx context.Context - workerFunction WorkerObject - err error - numberOfWorkers int64 - inChan chan interface{} - outChan chan interface{} - sema *semaphore.Weighted - sigChan chan os.Signal - timeout time.Duration - cancel context.CancelFunc - writer *bufio.Writer - mu *sync.RWMutex - wg *sync.WaitGroup - once *sync.Once -} - -// NewWorker factory method to return new Worker -func NewWorker(ctx context.Context, workerFunction WorkerObject, numberOfWorkers int64) (worker *Worker) { - c, cancel := context.WithCancel(ctx) - return &Worker{ - numberOfWorkers: numberOfWorkers, - Ctx: c, - workerFunction: workerFunction, - inChan: make(chan interface{}, numberOfWorkers), - sema: semaphore.NewWeighted(numberOfWorkers), - sigChan: make(chan os.Signal, signalChannelBufferSize), - timeout: time.Duration(0), - cancel: cancel, - writer: bufio.NewWriter(os.Stdout), - wg: new(sync.WaitGroup), - mu: new(sync.RWMutex), - once: new(sync.Once), +var defaultWatchSignals = []os.Signal{syscall.SIGINT, syscall.SIGTERM, syscall.SIGKILL} + +// Worker Contains the work function. Allows an input and output to a channel or another worker for pipeline work. +// Return nil if you want the Runner to continue otherwise any error will cause the Runner to shutdown and return the +// error. +type Worker interface { + Work(in interface{}, out chan<- interface{}) error +} + +// Runner Handles the running the Worker logic. +type Runner interface { + BeforeFunc(func(ctx context.Context) error) Runner + AfterFunc(func(ctx context.Context, err error) error) Runner + SetDeadline(t time.Time) Runner + SetTimeout(duration time.Duration) Runner + SetFollower() + Send(in interface{}) + InFrom(w ...Runner) Runner + SetOut(chan interface{}) + Start() Runner + Stop() chan error + Wait() error +} + +type runner struct { + ctx context.Context + cancel context.CancelFunc + inChan chan interface{} + outChan chan interface{} + errChan chan error + signalChan chan os.Signal + limiter chan struct{} + + afterFunc func(ctx context.Context, err error) error + workFunc func(in interface{}, out chan<- interface{}) error + beforeFunc func(ctx context.Context) error + + timeout time.Duration + deadline time.Duration + + isLeader bool + stopCalled bool + + numWorkers int64 + lock *sync.RWMutex + wg *sync.WaitGroup + done *sync.Once + once *sync.Once +} + +// NewRunner Factory function for a new Runner. The Runner will handle running the workers logic. +func NewRunner(ctx context.Context, w Worker, numWorkers int64) Runner { + var runnerCtx, runnerCancel = context.WithCancel(ctx) + var runner = &runner{ + ctx: runnerCtx, + cancel: runnerCancel, + inChan: make(chan interface{}, numWorkers), + outChan: nil, + errChan: make(chan error, 1), + signalChan: make(chan os.Signal, 1), + limiter: make(chan struct{}, numWorkers), + afterFunc: func(ctx context.Context, err error) error { return err }, + workFunc: w.Work, + beforeFunc: func(ctx context.Context) error { return nil }, + numWorkers: numWorkers, + isLeader: true, + lock: new(sync.RWMutex), + wg: new(sync.WaitGroup), + once: new(sync.Once), + done: new(sync.Once), } + runner.waitForSignal(defaultWatchSignals...) + return runner } -// Send wrapper to send interface through workers "in" channel -func (iw *Worker) Send(in interface{}) { +// Send Send an object to the worker for processing. +func (r *runner) Send(in interface{}) { select { - case <-iw.IsDone(): + case <-r.ctx.Done(): return - case iw.inChan <- in: - return - } -} - -// InFrom assigns workers out channel to this workers in channel -func (iw *Worker) InFrom(inWorker ...*Worker) *Worker { - for _, worker := range inWorker { - worker.outChan = iw.inChan + case r.inChan <- in: } - return iw } -// Work start up the number of workers specified by the numberOfWorkers variable -func (iw *Worker) Work() *Worker { - if iw.timeout > 0 { - iw.Ctx, iw.cancel = context.WithTimeout(iw.Ctx, iw.timeout) +// InFrom Set a worker to accept output from another worker(s). +func (r *runner) InFrom(w ...Runner) Runner { + r.SetFollower() + for _, wr := range w { + wr.SetOut(r.inChan) } - iw.wg.Add(1) - go func() { - defer iw.wg.Done() - var wg = new(sync.WaitGroup) - for { - select { - case <-iw.IsDone(): - wg.Wait() - if len(iw.inChan) > 0 { - continue - } - if iw.err == nil { - iw.err = context.Canceled - } - return - case in := <-iw.inChan: - err := iw.sema.Acquire(iw.Ctx, 1) - if err != nil { - return - } - wg.Add(1) - go func(in interface{}) { - defer wg.Done() - defer iw.sema.Release(1) - if err := iw.workerFunction.Work(iw, in); err != nil { - iw.once.Do(func() { - iw.err = err - if iw.cancel != nil { - iw.cancel() - } - }) - return - } - }(in) - } - } - }() - return iw + return r } -// OutChannel Sets the workers output channel to one provided. -// If the worker already has a child worker attached this function will return an error (workers.ErrOutChannelUpdate). -func (iw *Worker) OutChannel(out chan interface{}) error { - if iw.outChan != nil { - return ErrOutChannelUpdate - } - iw.outChan = out - return nil +// SetFollower Sets the worker as a follower and does not need to close it's in channel. +func (r *runner) SetFollower() { + r.lock.Lock() + r.isLeader = false + r.lock.Unlock() } -// InChannel Returns the workers intake channel for use with legacy systems, otherwise use workers Send() method. -func (iw *Worker) InChannel() chan interface{} { - return iw.inChan +// Start Starts the worker on processing. +func (r *runner) Start() Runner { + r.startWork() + return r } -// Out pushes value to workers out channel -func (iw *Worker) Out(out interface{}) { - select { - case <-iw.Ctx.Done(): - return - case iw.outChan <- out: - return - } +// BeforeFunc Function to be run before worker starts processing. +func (r *runner) BeforeFunc(f func(ctx context.Context) error) Runner { + r.beforeFunc = f + return r } -// CancelOnSignal will send cancel to workers when signals specified are received -func (iw *Worker) CancelOnSignal(signals ...os.Signal) *Worker { - if len(signals) > 0 { - iw.waitForSignal(signals...) - } - return iw +// AfterFunc Function to be run after worker has stopped. +func (r *runner) AfterFunc(f func(ctx context.Context, err error) error) Runner { + r.afterFunc = f + return r } -// waitForSignal make sure we wait for a term signal and shutdown correctly -func (iw *Worker) waitForSignal(signals ...os.Signal) { - go func() { - signal.Notify(iw.sigChan, signals...) - <-iw.sigChan - if iw.cancel != nil { - iw.cancel() - } - }() -} - -// Wait waits for all the workers to finish up -func (iw *Worker) Wait() (err error) { - iw.wg.Wait() - if iw.cancel != nil { - iw.cancel() +// SetOut Allows the setting of a workers out channel, if not already set. +func (r *runner) SetOut(c chan interface{}) { + if r.outChan != nil { + return } - return iw.err + r.outChan = c } // SetDeadline allows a time to be set when the workers should stop. // Deadline needs to be handled by the IsDone method. -func (iw *Worker) SetDeadline(t time.Time) *Worker { - iw.mu.Lock() - defer iw.mu.Unlock() - iw.Ctx, iw.cancel = context.WithDeadline(iw.Ctx, t) - return iw +func (r *runner) SetDeadline(t time.Time) Runner { + r.lock.Lock() + defer r.lock.Unlock() + r.ctx, r.cancel = context.WithDeadline(r.ctx, t) + return r } // SetTimeout allows a time duration to be set when the workers should stop. // Timeout needs to be handled by the IsDone method. -func (iw *Worker) SetTimeout(duration time.Duration) *Worker { - iw.mu.Lock() - defer iw.mu.Unlock() - iw.timeout = duration - return iw -} - -// IsDone returns a context's cancellation or error -func (iw *Worker) IsDone() <-chan struct{} { - return iw.Ctx.Done() -} - -// SetWriterOut sets the writer for the Print* functions -// (ex. -// f, err := os.Create("output.txt")) -// defer f.Close() -// worker.SetWriteOut(f) -// ) -// If you have to print anything to stdout using the provided -// Print functions can significantly improve performance by using -// buffered output. -func (iw *Worker) SetWriterOut(writer io.Writer) *Worker { - iw.writer.Reset(writer) - return iw +func (r *runner) SetTimeout(duration time.Duration) Runner { + r.lock.Lock() + defer r.lock.Unlock() + r.timeout = duration + return r +} + +// Wait calls stop on workers and waits for the channel to drain. +// !!Should only be called when certain nothing will send to worker. +func (r *runner) Wait() (err error) { + r.waitForDrain() + if err = <-r.Stop(); err != nil || !errors.Is(err, context.Canceled) { + return + } + return nil } -// Println prints line output of a -func (iw *Worker) Println(a ...interface{}) { - iw.mu.Lock() - defer iw.mu.Unlock() - iw.internalBufferFlush() - _, _ = iw.writer.WriteString(fmt.Sprintln(a...)) +// Stop Stops the processing of a worker and closes it's channel in. +// Returns a blocking channel with type error. +// !!Should only be called when certain nothing will send to worker. +func (r *runner) Stop() chan error { + r.done.Do(func() { + if r.inChan != nil && r.isLeader { + close(r.inChan) + } + }) + return r.errChan } -// Printf prints based on format provided and a -func (iw *Worker) Printf(format string, a ...interface{}) { - iw.mu.Lock() - defer iw.mu.Unlock() - iw.internalBufferFlush() - _, _ = iw.writer.WriteString(fmt.Sprintf(format, a...)) +// IsDone returns a channel signaling the workers context has been canceled. +func (r *runner) IsDone() <-chan struct{} { + return r.ctx.Done() } -// Print prints output of a -func (iw *Worker) Print(a ...interface{}) { - iw.mu.Lock() - defer iw.mu.Unlock() - iw.internalBufferFlush() - _, _ = iw.writer.WriteString(fmt.Sprint(a...)) +// waitForSignal make sure we wait for a term signal and shutdown correctly +func (r *runner) waitForSignal(signals ...os.Signal) { + go func() { + signal.Notify(r.signalChan, signals...) + <-r.signalChan + if r.cancel != nil { + r.cancel() + } + }() } -// internalBufferFlush makes sure we haven't used up the available buffer -// by flushing the buffer when we get below the danger zone. -func (iw *Worker) internalBufferFlush() { - if iw.writer.Available() < internalBufferFlushLimit { - _ = iw.writer.Flush() +// waitForDrain Waits for the limiter to be zeroed out and the in channel to be empty. +func (r *runner) waitForDrain() { + for len(r.limiter) > 0 || len(r.inChan) > 0 { + // Wait for the drain. } } -// Close Note that it is only necessary to close a channel if the receiver is -// looking for a close. Closing the channel is a control signal on the -// channel indicating that no more data follows. Thus it makes sense to only close the -// in channel on the worker. For now we will just send the cancel signal -func (iw *Worker) Close() error { - iw.cancel() - defer func() { _ = iw.writer.Flush() }() - if err := iw.Wait(); err != nil && !errors.Is(err, context.Canceled) { - return err +// startWork Runs the before function and starts processing until one of three things happen. +// 1. A term signal is received or cancellation of context. +// 2. Stop function is called. +// 3. Worker returns an error. +func (r *runner) startWork() { + var err error + if err = r.beforeFunc(r.ctx); err != nil { + r.errChan <- err + return } - return nil + if r.timeout > 0 { + r.ctx, r.cancel = context.WithTimeout(r.ctx, r.timeout) + } + r.wg.Add(1) + go func() { + var workerWG = new(sync.WaitGroup) + var closeOnce = new(sync.Once) + defer func() { + workerWG.Wait() + r.errChan <- err + closeOnce.Do(func() { + if r.outChan != nil { + close(r.outChan) + } + }) + r.wg.Done() + }() + for in := range r.inChan { + select { + case <-r.ctx.Done(): + err = context.Canceled + continue + default: + r.limiter <- struct{}{} + workerWG.Add(1) + go func() { + defer func() { + <-r.limiter + workerWG.Done() + }() + if workErr := r.workFunc(in, r.outChan); workErr != nil { + r.once.Do(func() { + errors.As(err, &workErr) + r.cancel() + return + }) + } + }() + } + } + }() } diff --git a/workers_test.go b/workers_test.go index 8447378..a1e6b11 100644 --- a/workers_test.go +++ b/workers_test.go @@ -6,79 +6,106 @@ import ( "fmt" "math/rand" "os" - "path/filepath" "runtime" "runtime/debug" + "sync" "testing" "time" ) const ( - workerCount = 10000 + workerCount = 1000 workerTimeout = time.Millisecond * 300 - RunTimes = 1000000 + runTimes = 100000 ) +type WorkerOne struct { +} +type WorkerTwo struct { +} + +func NewWorkerOne() Worker { + return &WorkerOne{} +} + +func NewWorkerTwo() Worker { + return &WorkerTwo{} +} + +func (wo *WorkerOne) Work(in interface{}, out chan<- interface{}) error { + var workerOne = "worker_one" + mut.Lock() + if val, ok := count[workerOne]; ok { + count[workerOne] = val + 1 + } else { + count[workerOne] = 1 + } + mut.Unlock() + + total := in.(int) * 2 + out <- total + return nil +} + +func (wt *WorkerTwo) Work(in interface{}, out chan<- interface{}) error { + var workerTwo = "worker_two" + mut.Lock() + if val, ok := count[workerTwo]; ok { + count[workerTwo] = val + 1 + } else { + count[workerTwo] = 1 + } + mut.Unlock() + return nil +} + var ( + count = make(map[string]int) + mut = sync.RWMutex{} err = errors.New("test error") deadline = func() time.Time { return time.Now().Add(workerTimeout) } workerTestScenarios = []workerTest{ { - name: "work basic", - workerObject: NewTestWorkerObject(workBasic()), - numWorkers: workerCount, + name: "work basic", + worker: NewTestWorkerObject(workBasic()), + numWorkers: workerCount, }, { - name: "work basic with Println", - workerObject: NewTestWorkerObject(workBasicPrintln()), - numWorkers: workerCount, + name: "work basic with timeout", + timeout: workerTimeout, + worker: NewTestWorkerObject(workBasic()), + numWorkers: workerCount, }, { - name: "work basic with Printf", - workerObject: NewTestWorkerObject(workBasicPrintf()), - numWorkers: workerCount, + name: "work basic with deadline", + deadline: deadline, + worker: NewTestWorkerObject(workBasic()), + numWorkers: workerCount, }, { - name: "work basic with Print", - workerObject: NewTestWorkerObject(workBasicPrint()), - numWorkers: workerCount, + name: "work with return of error", + worker: NewTestWorkerObject(workWithError(err)), + errExpected: true, + numWorkers: workerCount, }, { - name: "work basic with timeout", - timeout: workerTimeout, - workerObject: NewTestWorkerObject(workBasic()), - numWorkers: workerCount, + name: "work with return of error with timeout", + timeout: workerTimeout, + worker: NewTestWorkerObject(workWithError(err)), + errExpected: true, + numWorkers: workerCount, }, { - name: "work basic with deadline", - deadline: deadline, - workerObject: NewTestWorkerObject(workBasic()), - numWorkers: workerCount, - }, - { - name: "work with return of error", - workerObject: NewTestWorkerObject(workWithError(err)), - errExpected: true, - numWorkers: workerCount, - }, - { - name: "work with return of error with timeout", - timeout: workerTimeout, - workerObject: NewTestWorkerObject(workWithError(err)), - errExpected: true, - numWorkers: workerCount, - }, - { - name: "work with return of error with deadline", - deadline: deadline, - workerObject: NewTestWorkerObject(workWithError(err)), - errExpected: true, - numWorkers: workerCount, + name: "work with return of error with deadline", + deadline: deadline, + worker: NewTestWorkerObject(workWithError(err)), + errExpected: true, + numWorkers: workerCount, }, } - getWorker = func(ctx context.Context, wt workerTest) *Worker { - worker := NewWorker(ctx, wt.workerObject, wt.numWorkers) + getWorker = func(ctx context.Context, wt workerTest) Runner { + worker := NewRunner(ctx, wt.worker, wt.numWorkers) if wt.timeout > 0 { worker.SetTimeout(wt.timeout) } @@ -90,74 +117,50 @@ var ( ) type workerTest struct { - name string - timeout time.Duration - deadline func() time.Time - workerObject WorkerObject - numWorkers int64 - testSignal bool - errExpected bool + name string + timeout time.Duration + deadline func() time.Time + worker Worker + numWorkers int64 + testSignal bool + errExpected bool } type TestWorkerObject struct { - workFunc func(w *Worker, in interface{}) error + workFunc func(in interface{}, out chan<- interface{}) error } -func NewTestWorkerObject(wf func(w *Worker, in interface{}) error) *TestWorkerObject { +func NewTestWorkerObject(wf func(in interface{}, out chan<- interface{}) error) Worker { return &TestWorkerObject{wf} } -func (tw *TestWorkerObject) Work(w *Worker, in interface{}) error { - return tw.workFunc(w, in) +func (tw *TestWorkerObject) Work(in interface{}, out chan<- interface{}) error { + return tw.workFunc(in, out) } -func workBasicNoOut() func(w *Worker, in interface{}) error { - return func(w *Worker, in interface{}) error { +func workBasicNoOut() func(in interface{}, out chan<- interface{}) error { + return func(in interface{}, out chan<- interface{}) error { _ = in.(int) return nil } } -func workBasicPrintln() func(w *Worker, in interface{}) error { - return func(w *Worker, in interface{}) error { +func workBasic() func(in interface{}, out chan<- interface{}) error { + return func(in interface{}, out chan<- interface{}) error { i := in.(int) - w.Println(i) + out <- i return nil } } -func workBasicPrintf() func(w *Worker, in interface{}) error { - return func(w *Worker, in interface{}) error { - i := in.(int) - w.Printf("test_number:%d", i) - return nil - } -} - -func workBasicPrint() func(w *Worker, in interface{}) error { - return func(w *Worker, in interface{}) error { - i := in.(int) - w.Print(i) - return nil - } -} - -func workBasic() func(w *Worker, in interface{}) error { - return func(w *Worker, in interface{}) error { - i := in.(int) - w.Out(i) - return nil - } -} - -func workWithError(err error) func(w *Worker, in interface{}) error { - return func(w *Worker, in interface{}) error { +func workWithError(err error) func(in interface{}, out chan<- interface{}) error { + return func(in interface{}, out chan<- interface{}) error { i := in.(int) total := i * rand.Intn(1000) if i == 100 { return err } - w.Out(total) + out <- total return nil } } @@ -170,27 +173,22 @@ func TestMain(m *testing.M) { } func TestWorkers(t *testing.T) { - f, err := os.Create(filepath.Join(os.TempDir(), "testfile.txt")) - if err != nil { - t.Fail() - } for _, tt := range workerTestScenarios { t.Run(tt.name, func(t *testing.T) { ctx := context.Background() - workerOne := getWorker(ctx, tt).SetWriterOut(f).Work() + workerOne := getWorker(ctx, tt).Start() // always need a consumer for the out tests so using basic here. - workerTwo := NewWorker(ctx, NewTestWorkerObject(workBasicNoOut()), workerCount) - workerTwo.InFrom(workerOne).Work() + workerTwo := NewRunner(ctx, NewTestWorkerObject(workBasicNoOut()), workerCount).InFrom(workerOne).Start() - for i := 0; i < RunTimes; i++ { + for i := 0; i < runTimes; i++ { workerOne.Send(i) } - if err := workerOne.Close(); err != nil && !tt.errExpected { + if err := workerOne.Wait(); err != nil && !tt.errExpected { fmt.Println(err) t.Fail() } - if err := workerTwo.Close(); err != nil && !tt.errExpected { + if err := workerTwo.Wait(); err != nil && !tt.errExpected { fmt.Println(err) t.Fail() } @@ -198,16 +196,46 @@ func TestWorkers(t *testing.T) { } } +func TestWorkersFinish(t *testing.T) { + ctx := context.Background() + workerOne := NewRunner(ctx, NewWorkerOne(), 1000).Start() + workerTwo := NewRunner(ctx, NewWorkerTwo(), 1000).InFrom(workerOne).Start() + + for i := 0; i < 100000; i++ { + workerOne.Send(rand.Intn(100)) + } + + if err := workerOne.Wait(); err != nil { + fmt.Println(err) + } + + if err := workerTwo.Wait(); err != nil { + fmt.Println(err) + } + + if count["worker_one"] != 100000 { + fmt.Println("worker one failed to finish,", "worker_one count", count["worker_one"], "/ 100000") + t.Fail() + } + if count["worker_two"] != 100000 { + fmt.Println("worker two failed to finish,", "worker_two count", count["worker_two"], "/ 100000") + t.Fail() + } +} + func BenchmarkGoWorkers(b *testing.B) { ctx := context.Background() - worker := NewWorker(ctx, NewTestWorkerObject(workBasicNoOut()), RunTimes).Work() - defer worker.Close() + worker := NewRunner(ctx, NewTestWorkerObject(workBasicNoOut()), workerCount).Start() b.StartTimer() for i := 0; i < b.N; i++ { - for j := 0; j < RunTimes; j++ { + for j := 0; j < runTimes; j++ { worker.Send(j) } } + b.StopTimer() + if err := worker.Wait(); err != nil { + b.Error(err) + } }