Skip to content

Commit

Permalink
disttask: fix run one step SubtaskExecutor twice (#46106)
Browse files Browse the repository at this point in the history
close #46098
  • Loading branch information
ywqzzy authored Aug 17, 2023
1 parent 4fc7970 commit 24122b5
Show file tree
Hide file tree
Showing 9 changed files with 98 additions and 103 deletions.
2 changes: 1 addition & 1 deletion disttask/framework/framework_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ type testMiniTask struct{}
func (testMiniTask) IsMinimalTask() {}

func (testMiniTask) String() string {
return ""
return "testMiniTask"
}

type testScheduler struct{}
Expand Down
4 changes: 2 additions & 2 deletions disttask/framework/scheduler/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ import (
type TaskTable interface {
GetGlobalTasksInStates(states ...interface{}) (task []*proto.Task, err error)
GetGlobalTaskByID(taskID int64) (task *proto.Task, err error)
GetSubtaskInStates(instanceID string, taskID int64, states ...interface{}) (*proto.Subtask, error)
GetSubtaskInStates(instanceID string, taskID int64, step int64, states ...interface{}) (*proto.Subtask, error)
UpdateSubtaskStateAndError(id int64, state string, err error) error
FinishSubtask(id int64, meta []byte) error
HasSubtasksInStates(instanceID string, taskID int64, states ...interface{}) (bool, error)
HasSubtasksInStates(instanceID string, taskID int64, step int64, states ...interface{}) (bool, error)
UpdateErrorToSubtask(tidbID string, err error) error
}

Expand Down
8 changes: 4 additions & 4 deletions disttask/framework/scheduler/interface_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ func (t *MockTaskTable) GetGlobalTaskByID(id int64) (*proto.Task, error) {
}

// GetSubtaskInStates implements SubtaskTable.GetSubtaskInStates.
func (t *MockTaskTable) GetSubtaskInStates(instanceID string, taskID int64, states ...interface{}) (*proto.Subtask, error) {
args := t.Called(instanceID, taskID, states)
func (t *MockTaskTable) GetSubtaskInStates(instanceID string, taskID int64, step int64, states ...interface{}) (*proto.Subtask, error) {
args := t.Called(instanceID, taskID, step, states)
if args.Error(1) != nil {
return nil, args.Error(1)
} else if args.Get(0) == nil {
Expand All @@ -76,8 +76,8 @@ func (t *MockTaskTable) FinishSubtask(id int64, meta []byte) error {
}

// HasSubtasksInStates implements SubtaskTable.HasSubtasksInStates.
func (t *MockTaskTable) HasSubtasksInStates(instanceID string, taskID int64, states ...interface{}) (bool, error) {
args := t.Called(instanceID, taskID, states)
func (t *MockTaskTable) HasSubtasksInStates(instanceID string, taskID int64, step int64, states ...interface{}) (bool, error) {
args := t.Called(instanceID, taskID, step, states)
return args.Bool(0), args.Error(1)
}

Expand Down
25 changes: 12 additions & 13 deletions disttask/framework/scheduler/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ type Manager struct {
subtaskExecutorPools map[string]Pool
mu struct {
sync.RWMutex
// taskID -> cancelFunc
// cancelFunc is used to fast cancel the scheduler.Run
// taskID -> cancelFunc.
// cancelFunc is used to fast cancel the scheduler.Run.
handlingTasks map[int64]context.CancelFunc
}
id string
Expand Down Expand Up @@ -188,7 +188,7 @@ func (m *Manager) onRunnableTasks(ctx context.Context, tasks []*proto.Task) {
logutil.Logger(m.logCtx).Error("unknown task type", zap.String("type", task.Type))
continue
}
exist, err := m.taskTable.HasSubtasksInStates(m.id, task.ID, proto.TaskStatePending, proto.TaskStateRevertPending)
exist, err := m.taskTable.HasSubtasksInStates(m.id, task.ID, task.Step, proto.TaskStatePending, proto.TaskStateRevertPending)
if err != nil {
logutil.Logger(m.logCtx).Error("check subtask exist failed", zap.Error(err))
m.onError(err)
Expand All @@ -197,14 +197,14 @@ func (m *Manager) onRunnableTasks(ctx context.Context, tasks []*proto.Task) {
if !exist {
continue
}
logutil.Logger(m.logCtx).Info("detect new subtask", zap.Any("id", task.ID))
logutil.Logger(m.logCtx).Info("detect new subtask", zap.Any("task_id", task.ID))
m.addHandlingTask(task.ID)
t := task
err = m.schedulerPool.Run(func() {
m.onRunnableTask(ctx, t.ID, t.Type)
m.removeHandlingTask(t.ID)
m.removeHandlingTask(task.ID)
})
// pool closed
// pool closed.
if err != nil {
m.removeHandlingTask(task.ID)
m.onError(err)
Expand All @@ -218,7 +218,7 @@ func (m *Manager) onCanceledTasks(_ context.Context, tasks []*proto.Task) {
m.mu.RLock()
defer m.mu.RUnlock()
for _, task := range tasks {
logutil.Logger(m.logCtx).Info("onCanceledTasks", zap.Any("id", task.ID))
logutil.Logger(m.logCtx).Info("onCanceledTasks", zap.Any("task_id", task.ID))
if cancel, ok := m.mu.handlingTasks[task.ID]; ok && cancel != nil {
cancel()
}
Expand All @@ -230,7 +230,7 @@ func (m *Manager) cancelAllRunningTasks() {
m.mu.RLock()
defer m.mu.RUnlock()
for id, cancel := range m.mu.handlingTasks {
logutil.Logger(m.logCtx).Info("cancelAllRunningTasks", zap.Any("id", id))
logutil.Logger(m.logCtx).Info("cancelAllRunningTasks", zap.Any("task_id", id))
if cancel != nil {
cancel()
}
Expand All @@ -254,12 +254,12 @@ func (m *Manager) filterAlreadyHandlingTasks(tasks []*proto.Task) []*proto.Task

// onRunnableTask handles a runnable task.
func (m *Manager) onRunnableTask(ctx context.Context, taskID int64, taskType string) {
logutil.Logger(m.logCtx).Info("onRunnableTask", zap.Any("id", taskID), zap.Any("type", taskType))
logutil.Logger(m.logCtx).Info("onRunnableTask", zap.Any("task_id", taskID), zap.Any("type", taskType))
if _, ok := m.subtaskExecutorPools[taskType]; !ok {
m.onError(errors.Errorf("task type %s not found", taskType))
return
}
// runCtx only used in scheduler.Run, cancel in m.fetchAndFastCancelTasks
// runCtx only used in scheduler.Run, cancel in m.fetchAndFastCancelTasks.
scheduler := m.newScheduler(ctx, m.id, taskID, m.taskTable, m.subtaskExecutorPools[taskType])
scheduler.Start()
defer scheduler.Stop()
Expand All @@ -275,11 +275,10 @@ func (m *Manager) onRunnableTask(ctx context.Context, taskID int64, taskType str
return
}
if task.State != proto.TaskStateRunning && task.State != proto.TaskStateReverting {
logutil.Logger(m.logCtx).Info("onRunnableTask exit", zap.Any("id", taskID), zap.Any("state", task.State))
logutil.Logger(m.logCtx).Info("onRunnableTask exit", zap.Any("task_id", taskID), zap.Int64("step", task.Step), zap.Any("state", task.State))
return
}
// TODO: intergrate with heartbeat mechanism
if exist, err := m.taskTable.HasSubtasksInStates(m.id, task.ID, proto.TaskStatePending, proto.TaskStateRevertPending); err != nil {
if exist, err := m.taskTable.HasSubtasksInStates(m.id, task.ID, task.Step, proto.TaskStatePending, proto.TaskStateRevertPending); err != nil {
m.onError(err)
return
} else if !exist {
Expand Down
28 changes: 14 additions & 14 deletions disttask/framework/scheduler/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,36 +93,36 @@ func TestOnRunnableTasks(t *testing.T) {
m.subtaskExecutorPools["type"] = mockPool

// get subtask failed
mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, errors.New("get subtask failed")).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, errors.New("get subtask failed")).Once()
m.onRunnableTasks(context.Background(), []*proto.Task{task})

// no subtask
mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, nil).Once()
m.onRunnableTasks(context.Background(), []*proto.Task{task})

// pool error
mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockPool.On("Run", mock.Anything).Return(errors.New("pool error")).Once()
m.onRunnableTasks(context.Background(), []*proto.Task{task})

// step 0 succeed
mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockPool.On("Run", mock.Anything).Return(nil).Once()
mockInternalScheduler.On("Start").Once()
mockTaskTable.On("GetGlobalTaskByID", taskID).Return(task, nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockInternalScheduler.On("Run", mock.Anything, task).Return(nil).Once()
m.onRunnableTasks(context.Background(), []*proto.Task{task})

// step 1 canceled
task1 := &proto.Task{ID: taskID, State: proto.TaskStateRunning, Step: proto.StepTwo}
mockTaskTable.On("GetGlobalTaskByID", taskID).Return(task1, nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockInternalScheduler.On("Run", mock.Anything, task1).Return(errors.New("run errr")).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID, proto.StepTwo, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockInternalScheduler.On("Run", mock.Anything, task1).Return(errors.New("run err")).Once()

task2 := &proto.Task{ID: taskID, State: proto.TaskStateReverting, Step: proto.StepTwo}
mockTaskTable.On("GetGlobalTaskByID", taskID).Return(task2, nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID, proto.StepTwo, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockInternalScheduler.On("Rollback", mock.Anything, task2).Return(nil).Once()

task3 := &proto.Task{ID: taskID, State: proto.TaskStateReverted, Step: proto.StepTwo}
Expand Down Expand Up @@ -161,22 +161,22 @@ func TestManager(t *testing.T) {
mockTaskTable.On("GetGlobalTasksInStates", proto.TaskStateRunning, proto.TaskStateReverting).Return([]*proto.Task{task1, task2}, nil)
mockTaskTable.On("GetGlobalTasksInStates", proto.TaskStateReverting).Return([]*proto.Task{task2}, nil)
// task1
mockTaskTable.On("HasSubtasksInStates", id, taskID1, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID1, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockPool.On("Run", mock.Anything).Return(nil).Once()
mockInternalScheduler.On("Start").Once()
mockTaskTable.On("GetGlobalTaskByID", taskID1).Return(task1, nil)
mockTaskTable.On("HasSubtasksInStates", id, taskID1, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID1, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockInternalScheduler.On("Run", mock.Anything, task1).Return(nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID1, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, nil)
mockTaskTable.On("HasSubtasksInStates", id, taskID1, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, nil)
mockInternalScheduler.On("Stop").Once()
// task2
mockTaskTable.On("HasSubtasksInStates", id, taskID2, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID2, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockPool.On("Run", mock.Anything).Return(nil).Once()
mockInternalScheduler.On("Start").Once()
mockTaskTable.On("GetGlobalTaskByID", taskID2).Return(task2, nil)
mockTaskTable.On("HasSubtasksInStates", id, taskID2, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID2, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(true, nil).Once()
mockInternalScheduler.On("Rollback", mock.Anything, task2).Return(nil).Once()
mockTaskTable.On("HasSubtasksInStates", id, taskID2, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, nil)
mockTaskTable.On("HasSubtasksInStates", id, taskID2, proto.StepOne, []interface{}{proto.TaskStatePending, proto.TaskStateRevertPending}).Return(false, nil)
mockInternalScheduler.On("Stop").Once()
// once for scheduler pool, once for subtask pool
mockPool.On("ReleaseAndWait").Twice()
Expand Down
44 changes: 20 additions & 24 deletions disttask/framework/scheduler/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,6 @@ func (s *InternalSchedulerImpl) Stop() {
s.wg.Wait()
}

// func (s *InternalSchedulerImpl) heartbeat() {
// ticker := time.NewTicker(proto.HeartbeatInterval)
// for {
// select {
// case <-s.ctx.Done():
// return
// case <-ticker.C:
// if err := s.subtaskTable.UpdateHeartbeat(s.id, s.taskID, time.Now()); err != nil {
// s.onError(err)
// return
// }
// }
// }
// }

// Run runs the scheduler task.
func (s *InternalSchedulerImpl) Run(ctx context.Context, task *proto.Task) error {
err := s.run(ctx, task)
Expand All @@ -107,6 +92,10 @@ func (s *InternalSchedulerImpl) Run(ctx context.Context, task *proto.Task) error
}

func (s *InternalSchedulerImpl) run(ctx context.Context, task *proto.Task) error {
if ctx.Err() != nil {
s.onError(ctx.Err())
return s.getError()
}
runCtx, runCancel := context.WithCancel(ctx)
defer runCancel()
s.registerCancelFunc(runCancel)
Expand Down Expand Up @@ -143,12 +132,12 @@ func (s *InternalSchedulerImpl) run(ctx context.Context, task *proto.Task) error
}

for {
// check if any error occurs
// check if any error occurs.
if err := s.getError(); err != nil {
break
}

subtask, err := s.taskTable.GetSubtaskInStates(s.id, task.ID, proto.TaskStatePending)
subtask, err := s.taskTable.GetSubtaskInStates(s.id, task.ID, task.Step, proto.TaskStatePending)
if err != nil {
s.onError(err)
break
Expand All @@ -160,12 +149,12 @@ func (s *InternalSchedulerImpl) run(ctx context.Context, task *proto.Task) error
if err := s.getError(); err != nil {
break
}
s.runSubtask(runCtx, scheduler, subtask, task.Step, minimalTaskCh)
s.runSubtask(runCtx, scheduler, subtask, minimalTaskCh)
}
return s.getError()
}

func (s *InternalSchedulerImpl) runSubtask(ctx context.Context, scheduler Scheduler, subtask *proto.Subtask, step int64, minimalTaskCh chan func()) {
func (s *InternalSchedulerImpl) runSubtask(ctx context.Context, scheduler Scheduler, subtask *proto.Subtask, minimalTaskCh chan func()) {
minimalTasks, err := scheduler.SplitSubtask(ctx, subtask.Meta)
if err != nil {
s.onError(err)
Expand All @@ -177,14 +166,21 @@ func (s *InternalSchedulerImpl) runSubtask(ctx context.Context, scheduler Schedu
s.markErrorHandled()
return
}
logutil.Logger(s.logCtx).Info("split subTask", zap.Int("cnt", len(minimalTasks)), zap.Int64("subtask-id", subtask.ID))
if ctx.Err() != nil {
s.onError(ctx.Err())
return
}
logutil.Logger(s.logCtx).Info("split subTask",
zap.Int("cnt", len(minimalTasks)),
zap.Int64("subtask_id", subtask.ID),
zap.Int64("subtask_step", subtask.Step))

var minimalTaskWg sync.WaitGroup
for _, minimalTask := range minimalTasks {
minimalTaskWg.Add(1)
j := minimalTask
minimalTaskCh <- func() {
s.runMinimalTask(ctx, j, subtask.Type, step)
s.runMinimalTask(ctx, j, subtask.Type, subtask.Step)
minimalTaskWg.Done()
}
}
Expand Down Expand Up @@ -226,7 +222,7 @@ func (s *InternalSchedulerImpl) onSubtaskFinished(ctx context.Context, scheduler
}

func (s *InternalSchedulerImpl) runMinimalTask(minimalTaskCtx context.Context, minimalTask proto.MinimalTask, tp string, step int64) {
logutil.Logger(s.logCtx).Info("scheduler run a minimalTask", zap.Any("step", step), zap.Stringer("minimal-task", minimalTask))
logutil.Logger(s.logCtx).Info("scheduler run a minimalTask", zap.Any("step", step), zap.Stringer("minimal_task", minimalTask))
select {
case <-minimalTaskCtx.Done():
s.onError(minimalTaskCtx.Err())
Expand Down Expand Up @@ -276,7 +272,7 @@ func (s *InternalSchedulerImpl) Rollback(ctx context.Context, task *proto.Task)

// We should cancel all subtasks before rolling back
for {
subtask, err := s.taskTable.GetSubtaskInStates(s.id, task.ID, proto.TaskStatePending, proto.TaskStateRunning)
subtask, err := s.taskTable.GetSubtaskInStates(s.id, task.ID, task.Step, proto.TaskStatePending, proto.TaskStateRunning)
if err != nil {
s.onError(err)
return s.getError()
Expand All @@ -297,7 +293,7 @@ func (s *InternalSchedulerImpl) Rollback(ctx context.Context, task *proto.Task)
s.onError(err)
return s.getError()
}
subtask, err := s.taskTable.GetSubtaskInStates(s.id, task.ID, proto.TaskStateRevertPending)
subtask, err := s.taskTable.GetSubtaskInStates(s.id, task.ID, task.Step, proto.TaskStateRevertPending)
if err != nil {
s.onError(err)
return s.getError()
Expand Down
Loading

0 comments on commit 24122b5

Please sign in to comment.