Skip to content

Commit

Permalink
Refactor weighted round robin scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
Shaddoll committed Jan 23, 2025
1 parent 70edd8d commit a0aed21
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 136 deletions.
18 changes: 11 additions & 7 deletions common/dynamicconfig/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -5084,13 +5084,9 @@ var MapKeys = map[MapKey]DynamicMap{
DefaultValue: definition.GetDefaultIndexedKeys(),
},
TaskSchedulerRoundRobinWeights: {
KeyName: "history.taskSchedulerRoundRobinWeight",
Description: "TaskSchedulerRoundRobinWeights is the priority weight for weighted round robin task scheduler",
DefaultValue: common.ConvertIntMapToDynamicConfigMapProperty(map[int]int{
common.GetTaskPriority(common.HighPriorityClass, common.DefaultPrioritySubclass): 500,
common.GetTaskPriority(common.DefaultPriorityClass, common.DefaultPrioritySubclass): 20,
common.GetTaskPriority(common.LowPriorityClass, common.DefaultPrioritySubclass): 5,
}),
KeyName: "history.taskSchedulerRoundRobinWeight",
Description: "TaskSchedulerRoundRobinWeights is the priority weight for weighted round robin task scheduler",
DefaultValue: common.ConvertIntMapToDynamicConfigMapProperty(DefaultTaskSchedulerRoundRobinWeights),
},
QueueProcessorPendingTaskSplitThreshold: {
KeyName: "history.queueProcessorPendingTaskSplitThreshold",
Expand Down Expand Up @@ -5171,3 +5167,11 @@ func init() {
_keyNames[v.KeyName] = k
}
}

var (
DefaultTaskSchedulerRoundRobinWeights = map[int]int{
common.GetTaskPriority(common.HighPriorityClass, common.DefaultPrioritySubclass): 500,
common.GetTaskPriority(common.DefaultPriorityClass, common.DefaultPrioritySubclass): 20,
common.GetTaskPriority(common.LowPriorityClass, common.DefaultPrioritySubclass): 5,
}
)
28 changes: 15 additions & 13 deletions common/task/scheduler_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,21 @@ import (
"github.com/uber/cadence/common/dynamicconfig"
)

type SchedulerOptions struct {
type SchedulerOptions[K comparable] struct {
SchedulerType SchedulerType
FIFOSchedulerOptions *FIFOTaskSchedulerOptions
WRRSchedulerOptions *WeightedRoundRobinTaskSchedulerOptions
WRRSchedulerOptions *WeightedRoundRobinTaskSchedulerOptions[K]
}

func NewSchedulerOptions(
func NewSchedulerOptions[K comparable](
schedulerType int,
queueSize int,
workerCount dynamicconfig.IntPropertyFn,
dispatcherCount int,
weights dynamicconfig.MapPropertyFn,
) (*SchedulerOptions, error) {
options := &SchedulerOptions{
taskToChannelKeyFn func(PriorityTask) K,
channelKeyToWeightFn func(K) int,
) (*SchedulerOptions[K], error) {
options := &SchedulerOptions[K]{
SchedulerType: SchedulerType(schedulerType),
}
switch options.SchedulerType {
Expand All @@ -52,20 +53,21 @@ func NewSchedulerOptions(
RetryPolicy: common.CreateTaskProcessingRetryPolicy(),
}
case SchedulerTypeWRR:
options.WRRSchedulerOptions = &WeightedRoundRobinTaskSchedulerOptions{
Weights: weights,
QueueSize: queueSize,
WorkerCount: workerCount,
DispatcherCount: dispatcherCount,
RetryPolicy: common.CreateTaskProcessingRetryPolicy(),
options.WRRSchedulerOptions = &WeightedRoundRobinTaskSchedulerOptions[K]{
QueueSize: queueSize,
WorkerCount: workerCount,
DispatcherCount: dispatcherCount,
RetryPolicy: common.CreateTaskProcessingRetryPolicy(),
TaskToChannelKeyFn: taskToChannelKeyFn,
ChannelKeyToWeightFn: channelKeyToWeightFn,
}
default:
return nil, fmt.Errorf("unknown task scheduler type: %v", schedulerType)
}
return options, nil
}

func (o *SchedulerOptions) String() string {
func (o *SchedulerOptions[K]) String() string {
return fmt.Sprintf("{schedulerType:%v, fifoSchedulerOptions:%s, wrrSchedulerOptions:%s}",
o.SchedulerType, o.FIFOSchedulerOptions, o.WRRSchedulerOptions)
}
9 changes: 2 additions & 7 deletions common/task/scheduler_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ func TestSchedulerOptionsString(t *testing.T) {
queueSize int
workerCount dynamicconfig.IntPropertyFn
dispatcherCount int
weights dynamicconfig.MapPropertyFn
wantErr bool
want string
}{
Expand All @@ -51,11 +50,7 @@ func TestSchedulerOptionsString(t *testing.T) {
queueSize: 3,
workerCount: dynamicconfig.GetIntPropertyFn(4),
dispatcherCount: 5,
weights: dynamicconfig.GetMapPropertyFn(map[string]interface{}{
"1": 500,
"9": 20,
}),
want: "{schedulerType:2, fifoSchedulerOptions:<nil>, wrrSchedulerOptions:{QueueSize: 3, WorkerCount: 4, DispatcherCount: 5, Weights: map[1:500 9:20]}}",
want: "{schedulerType:2, fifoSchedulerOptions:<nil>, wrrSchedulerOptions:{QueueSize: 3, WorkerCount: 4, DispatcherCount: 5}}",
},
{
desc: "InvalidSchedulerType",
Expand All @@ -66,7 +61,7 @@ func TestSchedulerOptionsString(t *testing.T) {

for _, tc := range tests {
t.Run(tc.desc, func(t *testing.T) {
o, err := NewSchedulerOptions(tc.schedulerType, tc.queueSize, tc.workerCount, tc.dispatcherCount, tc.weights)
o, err := NewSchedulerOptions[int](tc.schedulerType, tc.queueSize, tc.workerCount, tc.dispatcherCount, nil, nil)
if (err != nil) != tc.wantErr {
t.Errorf("Got error: %v, wantErr: %v", err, tc.wantErr)
}
Expand Down
115 changes: 32 additions & 83 deletions common/task/weighted_round_robin_task_scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ package task

import (
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
Expand All @@ -33,18 +32,17 @@ import (
"github.com/uber/cadence/common/metrics"
)

type weightedRoundRobinTaskSchedulerImpl struct {
type weightedRoundRobinTaskSchedulerImpl[K comparable] struct {
sync.RWMutex

status int32
weights atomic.Value // store the currently used weights
taskChs map[int]chan PriorityTask
taskChs map[K]chan PriorityTask
shutdownCh chan struct{}
notifyCh chan struct{}
dispatcherWG sync.WaitGroup
logger log.Logger
metricsScope metrics.Scope
options *WeightedRoundRobinTaskSchedulerOptions
options *WeightedRoundRobinTaskSchedulerOptions[K]

processor Processor
}
Expand All @@ -60,23 +58,14 @@ var (
)

// NewWeightedRoundRobinTaskScheduler creates a new WRR task scheduler
func NewWeightedRoundRobinTaskScheduler(
func NewWeightedRoundRobinTaskScheduler[K comparable](
logger log.Logger,
metricsClient metrics.Client,
options *WeightedRoundRobinTaskSchedulerOptions,
options *WeightedRoundRobinTaskSchedulerOptions[K],
) (Scheduler, error) {
weights, err := common.ConvertDynamicConfigMapPropertyToIntMap(options.Weights())
if err != nil {
return nil, err
}

if len(weights) == 0 {
return nil, errors.New("weight is not specified in the scheduler option")
}

scheduler := &weightedRoundRobinTaskSchedulerImpl{
scheduler := &weightedRoundRobinTaskSchedulerImpl[K]{
status: common.DaemonStatusInitialized,
taskChs: make(map[int]chan PriorityTask),
taskChs: make(map[K]chan PriorityTask),
shutdownCh: make(chan struct{}),
notifyCh: make(chan struct{}, 1),
logger: logger,
Expand All @@ -92,12 +81,11 @@ func NewWeightedRoundRobinTaskScheduler(
},
),
}
scheduler.weights.Store(weights)

return scheduler, nil
}

func (w *weightedRoundRobinTaskSchedulerImpl) Start() {
func (w *weightedRoundRobinTaskSchedulerImpl[K]) Start() {
if !atomic.CompareAndSwapInt32(&w.status, common.DaemonStatusInitialized, common.DaemonStatusStarted) {
return
}
Expand All @@ -108,12 +96,10 @@ func (w *weightedRoundRobinTaskSchedulerImpl) Start() {
for i := 0; i != w.options.DispatcherCount; i++ {
go w.dispatcher()
}
go w.updateWeights()

w.logger.Info("Weighted round robin task scheduler started.")
}

func (w *weightedRoundRobinTaskSchedulerImpl) Stop() {
func (w *weightedRoundRobinTaskSchedulerImpl[K]) Stop() {
if !atomic.CompareAndSwapInt32(&w.status, common.DaemonStatusStarted, common.DaemonStatusStopped) {
return
}
Expand All @@ -135,7 +121,7 @@ func (w *weightedRoundRobinTaskSchedulerImpl) Stop() {
w.logger.Info("Weighted round robin task scheduler shutdown.")
}

func (w *weightedRoundRobinTaskSchedulerImpl) Submit(task PriorityTask) error {
func (w *weightedRoundRobinTaskSchedulerImpl[K]) Submit(task PriorityTask) error {
w.metricsScope.IncCounter(metrics.PriorityTaskSubmitRequest)
sw := w.metricsScope.StartTimer(metrics.PriorityTaskSubmitLatency)
defer sw.Stop()
Expand All @@ -144,11 +130,7 @@ func (w *weightedRoundRobinTaskSchedulerImpl) Submit(task PriorityTask) error {
return ErrTaskSchedulerClosed
}

taskCh, err := w.getOrCreateTaskChan(task.Priority())
if err != nil {
return err
}

taskCh := w.getOrCreateTaskChan(task)
select {
case taskCh <- task:
w.notifyDispatcher()
Expand All @@ -161,17 +143,14 @@ func (w *weightedRoundRobinTaskSchedulerImpl) Submit(task PriorityTask) error {
}
}

func (w *weightedRoundRobinTaskSchedulerImpl) TrySubmit(
func (w *weightedRoundRobinTaskSchedulerImpl[K]) TrySubmit(
task PriorityTask,
) (bool, error) {
if w.isStopped() {
return false, ErrTaskSchedulerClosed
}

taskCh, err := w.getOrCreateTaskChan(task.Priority())
if err != nil {
return false, err
}
taskCh := w.getOrCreateTaskChan(task)

select {
case taskCh <- task:
Expand All @@ -189,11 +168,11 @@ func (w *weightedRoundRobinTaskSchedulerImpl) TrySubmit(
}
}

func (w *weightedRoundRobinTaskSchedulerImpl) dispatcher() {
func (w *weightedRoundRobinTaskSchedulerImpl[K]) dispatcher() {
defer w.dispatcherWG.Done()

outstandingTasks := false
taskChs := make(map[int]chan PriorityTask)
taskChs := make(map[K]chan PriorityTask)

for {
if !outstandingTasks {
Expand All @@ -211,15 +190,10 @@ func (w *weightedRoundRobinTaskSchedulerImpl) dispatcher() {

outstandingTasks = false
w.updateTaskChs(taskChs)
weights := w.getWeights()
for priority, taskCh := range taskChs {
count, ok := weights[priority]
if !ok {
w.logger.Error("weights not found for task priority", tag.Dynamic("priority", priority), tag.Dynamic("weights", weights))
continue
}
for key, taskCh := range taskChs {
weight := w.options.ChannelKeyToWeightFn(key)
Submit_Loop:
for i := 0; i < count; i++ {
for i := 0; i < weight; i++ {
select {
case task := <-taskCh:
// dispatched at least one task in this round
Expand All @@ -240,40 +214,37 @@ func (w *weightedRoundRobinTaskSchedulerImpl) dispatcher() {
}
}

func (w *weightedRoundRobinTaskSchedulerImpl) getOrCreateTaskChan(priority int) (chan PriorityTask, error) {
if _, ok := w.getWeights()[priority]; !ok {
return nil, fmt.Errorf("unknown task priority: %v", priority)
}

func (w *weightedRoundRobinTaskSchedulerImpl[K]) getOrCreateTaskChan(task PriorityTask) chan PriorityTask {
key := w.options.TaskToChannelKeyFn(task)
w.RLock()
if taskCh, ok := w.taskChs[priority]; ok {
if taskCh, ok := w.taskChs[key]; ok {
w.RUnlock()
return taskCh, nil
return taskCh
}
w.RUnlock()

w.Lock()
defer w.Unlock()
if taskCh, ok := w.taskChs[priority]; ok {
return taskCh, nil
if taskCh, ok := w.taskChs[key]; ok {
return taskCh
}
taskCh := make(chan PriorityTask, w.options.QueueSize)
w.taskChs[priority] = taskCh
return taskCh, nil
w.taskChs[key] = taskCh
return taskCh
}

func (w *weightedRoundRobinTaskSchedulerImpl) updateTaskChs(taskChs map[int]chan PriorityTask) {
func (w *weightedRoundRobinTaskSchedulerImpl[K]) updateTaskChs(taskChs map[K]chan PriorityTask) {
w.RLock()
defer w.RUnlock()

for priority, taskCh := range w.taskChs {
if _, ok := taskChs[priority]; !ok {
taskChs[priority] = taskCh
for key, taskCh := range w.taskChs {
if _, ok := taskChs[key]; !ok {
taskChs[key] = taskCh
}
}
}

func (w *weightedRoundRobinTaskSchedulerImpl) notifyDispatcher() {
func (w *weightedRoundRobinTaskSchedulerImpl[K]) notifyDispatcher() {
select {
case w.notifyCh <- struct{}{}:
// sent a notification to the dispatcher
Expand All @@ -282,29 +253,7 @@ func (w *weightedRoundRobinTaskSchedulerImpl) notifyDispatcher() {
}
}

func (w *weightedRoundRobinTaskSchedulerImpl) getWeights() map[int]int {
return w.weights.Load().(map[int]int)
}

func (w *weightedRoundRobinTaskSchedulerImpl) updateWeights() {
ticker := time.NewTicker(defaultUpdateWeightsInterval)
for {
select {
case <-ticker.C:
weights, err := common.ConvertDynamicConfigMapPropertyToIntMap(w.options.Weights())
if err != nil {
w.logger.Error("failed to update weight for round robin task scheduler", tag.Error(err))
} else {
w.weights.Store(weights)
}
case <-w.shutdownCh:
ticker.Stop()
return
}
}
}

func (w *weightedRoundRobinTaskSchedulerImpl) isStopped() bool {
func (w *weightedRoundRobinTaskSchedulerImpl[K]) isStopped() bool {
return atomic.LoadInt32(&w.status) == common.DaemonStatusStopped
}

Expand Down
17 changes: 9 additions & 8 deletions common/task/weighted_round_robin_task_scheduler_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,15 @@ import (
)

// WeightedRoundRobinTaskSchedulerOptions configs WRR task scheduler
type WeightedRoundRobinTaskSchedulerOptions struct {
Weights dynamicconfig.MapPropertyFn
QueueSize int
WorkerCount dynamicconfig.IntPropertyFn
DispatcherCount int
RetryPolicy backoff.RetryPolicy
type WeightedRoundRobinTaskSchedulerOptions[K comparable] struct {
QueueSize int
WorkerCount dynamicconfig.IntPropertyFn
DispatcherCount int
RetryPolicy backoff.RetryPolicy
TaskToChannelKeyFn func(PriorityTask) K
ChannelKeyToWeightFn func(K) int
}

func (o *WeightedRoundRobinTaskSchedulerOptions) String() string {
return fmt.Sprintf("{QueueSize: %v, WorkerCount: %v, DispatcherCount: %v, Weights: %v}", o.QueueSize, o.WorkerCount(), o.DispatcherCount, o.Weights())
func (o *WeightedRoundRobinTaskSchedulerOptions[K]) String() string {
return fmt.Sprintf("{QueueSize: %v, WorkerCount: %v, DispatcherCount: %v}", o.QueueSize, o.WorkerCount(), o.DispatcherCount)
}
Loading

0 comments on commit a0aed21

Please sign in to comment.