From cb7eb2972cc828dd30a2b046fd5636e96bdc4358 Mon Sep 17 00:00:00 2001 From: Robert Catmull Date: Thu, 4 Mar 2021 16:14:53 -0700 Subject: [PATCH] Cleanup worker function and add semaphore for controlling worker counts. --- go.mod | 1 + workers.go | 32 ++++++++++++++++---------------- workers_test.go | 12 +----------- 3 files changed, 18 insertions(+), 27 deletions(-) diff --git a/go.mod b/go.mod index a4775a3..5fa12c8 100644 --- a/go.mod +++ b/go.mod @@ -5,4 +5,5 @@ go 1.15 require ( github.com/panjf2000/ants v1.2.0 github.com/stretchr/testify v1.7.0 + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c ) diff --git a/workers.go b/workers.go index 32c8257..c169443 100644 --- a/workers.go +++ b/workers.go @@ -5,6 +5,7 @@ import ( "context" "errors" "fmt" + "golang.org/x/sync/semaphore" "io" "os" "os/signal" @@ -14,7 +15,6 @@ import ( const ( internalBufferFlushLimit = 512 - minNumberOfWorkersLimit = 1 signalChannelBufferSize = 1 ) @@ -29,9 +29,10 @@ type Worker struct { Ctx context.Context workerFunction WorkerObject err error - numberOfWorkers int + numberOfWorkers int64 inChan chan interface{} outChan chan interface{} + sema *semaphore.Weighted sigChan chan os.Signal timeout time.Duration cancel context.CancelFunc @@ -42,14 +43,14 @@ type Worker struct { } // NewWorker factory method to return new Worker -func NewWorker(ctx context.Context, workerFunction WorkerObject, numberOfWorkers int) (worker *Worker) { +func NewWorker(ctx context.Context, workerFunction WorkerObject, numberOfWorkers int64) (worker *Worker) { c, cancel := context.WithCancel(ctx) - numberOfWorkers = getCorrectWorkers(numberOfWorkers) return &Worker{ numberOfWorkers: numberOfWorkers, Ctx: c, workerFunction: workerFunction, inChan: make(chan interface{}, numberOfWorkers), + sema: semaphore.NewWeighted(numberOfWorkers), sigChan: make(chan os.Signal, signalChannelBufferSize), timeout: time.Duration(0), cancel: cancel, @@ -83,12 +84,14 @@ func (iw *Worker) Work() *Worker { if iw.timeout > 0 { iw.Ctx, iw.cancel = context.WithTimeout(iw.Ctx, iw.timeout) } + iw.wg.Add(1) go func() { - iw.wg.Add(1) defer iw.wg.Done() + var wg = new(sync.WaitGroup) for { select { case <-iw.IsDone(): + wg.Wait() if len(iw.inChan) > 0 { continue } @@ -97,9 +100,14 @@ func (iw *Worker) Work() *Worker { } return case in := <-iw.inChan: - iw.wg.Add(1) + err := iw.sema.Acquire(iw.Ctx, 1) + if err != nil { + return + } + wg.Add(1) go func(in interface{}) { - defer iw.wg.Done() + defer wg.Done() + defer iw.sema.Release(1) if err := iw.workerFunction.Work(iw, in); err != nil { iw.once.Do(func() { iw.err = err @@ -107,6 +115,7 @@ func (iw *Worker) Work() *Worker { iw.cancel() } }) + return } }(in) } @@ -144,15 +153,6 @@ func (iw *Worker) waitForSignal(signals ...os.Signal) { }() } -// getCorrectWorkers don't let oversizing occur on workers -func getCorrectWorkers(numberOfWorkers int) int { - if numberOfWorkers < minNumberOfWorkersLimit { - numberOfWorkers = minNumberOfWorkersLimit - } - - return numberOfWorkers -} - // Wait waits for all the workers to finish up func (iw *Worker) Wait() (err error) { iw.wg.Wait() diff --git a/workers_test.go b/workers_test.go index ac67a84..4bc06bd 100644 --- a/workers_test.go +++ b/workers_test.go @@ -49,16 +49,6 @@ var ( workerObject: NewTestWorkerObject(workBasicPrint()), numWorkers: workerCount, }, - { - name: "work basic less than minimum worker count", - workerObject: NewTestWorkerObject(workBasic()), - numWorkers: 0, - }, - { - name: "work basic more than maximum worker count", - workerObject: NewTestWorkerObject(workBasic()), - numWorkers: 20000, - }, { name: "work basic with timeout", timeout: workerTimeout, @@ -115,7 +105,7 @@ type workerTest struct { timeout time.Duration deadline func() time.Time workerObject WorkerObject - numWorkers int + numWorkers int64 testSignal bool errExpected bool }