diff --git a/br/pkg/restore/snap_client/systable_restore_test.go b/br/pkg/restore/snap_client/systable_restore_test.go
index e0181815852e8..9917e3b1ad06c 100644
--- a/br/pkg/restore/snap_client/systable_restore_test.go
+++ b/br/pkg/restore/snap_client/systable_restore_test.go
@@ -116,5 +116,5 @@ func TestCheckSysTableCompatibility(t *testing.T) {
 //
 // The above variables are in the file br/pkg/restore/systable_restore.go
 func TestMonitorTheSystemTableIncremental(t *testing.T) {
-	require.Equal(t, int64(242), session.CurrentBootstrapVersion)
+	require.Equal(t, int64(243), session.CurrentBootstrapVersion)
 }
diff --git a/pkg/ddl/backfilling_dist_scheduler_test.go b/pkg/ddl/backfilling_dist_scheduler_test.go
index a76f8c97ab349..4f471bb33d3e2 100644
--- a/pkg/ddl/backfilling_dist_scheduler_test.go
+++ b/pkg/ddl/backfilling_dist_scheduler_test.go
@@ -158,7 +158,7 @@ func TestBackfillingSchedulerGlobalSortMode(t *testing.T) {
 	ext.(*ddl.LitBackfillScheduler).GlobalSort = true
 	sch.Extension = ext
 
-	taskID, err := mgr.CreateTask(ctx, task.Key, proto.Backfill, 1, "", task.Meta)
+	taskID, err := mgr.CreateTask(ctx, task.Key, proto.Backfill, 1, "", 0, task.Meta)
 	require.NoError(t, err)
 	task.ID = taskID
 	execIDs := []string{":4000"}
diff --git a/pkg/ddl/executor.go b/pkg/ddl/executor.go
index 936fcb325cd80..3376ff292d27f 100644
--- a/pkg/ddl/executor.go
+++ b/pkg/ddl/executor.go
@@ -72,6 +72,7 @@ import (
 	"github.com/pingcap/tidb/pkg/util/stringutil"
 	"github.com/pingcap/tidb/pkg/util/tracing"
 	"github.com/tikv/client-go/v2/oracle"
+	"github.com/tikv/client-go/v2/tikv"
 	pdhttp "github.com/tikv/pd/client/http"
 	"go.uber.org/zap"
 )
@@ -4931,8 +4932,29 @@ func (e *executor) createIndex(ctx sessionctx.Context, ti ast.Ident, keyType ast
 	return errors.Trace(err)
 }
 
+// GetDXFDefaultMaxNodeCntAuto calculates a default max node count for distributed task execution.
+func GetDXFDefaultMaxNodeCntAuto(store kv.Storage) int {
+	tikvStore, ok := store.(tikv.Storage)
+	if !ok {
+		logutil.DDLLogger().Warn("not a TiKV or TiFlash store instance", zap.String("type", fmt.Sprintf("%T", store)))
+		return 0
+	}
+	pdClient := tikvStore.GetRegionCache().PDClient()
+	if pdClient == nil {
+		logutil.DDLLogger().Warn("pd unavailable when getting default max node count")
+		return 0
+	}
+	stores, err := pdClient.GetAllStores(context.Background())
+	if err != nil {
+		logutil.DDLLogger().Warn("get all stores failed when getting default max node count", zap.Error(err))
+		return 0
+	}
+	return max(3, len(stores)/3)
+}
+
 func initJobReorgMetaFromVariables(job *model.Job, sctx sessionctx.Context) error {
 	m := NewDDLReorgMeta(sctx)
+
 	setReorgParam := func() {
 		if sv, ok := sctx.GetSessionVars().GetSystemVar(vardef.TiDBDDLReorgWorkerCount); ok {
 			m.SetConcurrency(variable.TidbOptInt(sv, 0))
@@ -4946,6 +4968,12 @@ func initJobReorgMetaFromVariables(job *model.Job, sctx sessionctx.Context) erro
 	m.IsDistReorg = vardef.EnableDistTask.Load()
 	m.IsFastReorg = vardef.EnableFastReorg.Load()
 	m.TargetScope = vardef.ServiceScope.Load()
+	if sv, ok := sctx.GetSessionVars().GetSystemVar(vardef.TiDBMaxDistTaskNodes); ok {
+		m.MaxNodeCount = variable.TidbOptInt(sv, 0)
+		if m.MaxNodeCount == -1 { // -1 means calculate automatically
+			m.MaxNodeCount = GetDXFDefaultMaxNodeCntAuto(sctx.GetStore())
+		}
+	}
 	if hasSysDB(job) {
 		if m.IsDistReorg {
 			logutil.DDLLogger().Info("cannot use distributed task execution on system DB",
@@ -5002,6 +5030,7 @@ func initJobReorgMetaFromVariables(job *model.Job, sctx sessionctx.Context) erro
 		zap.Bool("enableDistTask", m.IsDistReorg),
 		zap.Bool("enableFastReorg", m.IsFastReorg),
 		zap.String("targetScope", m.TargetScope),
+		zap.Int("maxNodeCount", m.MaxNodeCount),
 		zap.Int("concurrency", m.GetConcurrency()),
 		zap.Int("batchSize", m.GetBatchSize()),
 	)
diff --git a/pkg/ddl/index.go b/pkg/ddl/index.go
index 5845bfba91c45..8c2f9473f4be6 100644
--- a/pkg/ddl/index.go
+++ b/pkg/ddl/index.go
@@ -2571,7 +2571,9 @@ func (w *worker) executeDistTask(stepCtx context.Context, t table.Table, reorgIn
 	g.Go(func() error {
 		defer close(done)
-		err := submitAndWaitTask(ctx, taskKey, taskType, concurrency, reorgInfo.ReorgMeta.TargetScope, metaData)
+		targetScope := reorgInfo.ReorgMeta.TargetScope
+		maxNodeCnt := reorgInfo.ReorgMeta.MaxNodeCount
+		err := submitAndWaitTask(ctx, taskKey, taskType, concurrency, targetScope, maxNodeCnt, metaData)
 		failpoint.InjectCall("pauseAfterDistTaskFinished")
 		if err := w.isReorgRunnable(stepCtx, true); err != nil {
 			if dbterror.ErrPausedDDLJob.Equal(err) {
@@ -2728,8 +2730,8 @@ func (w *worker) updateDistTaskRowCount(taskKey string, jobID int64) {
 }
 
 // submitAndWaitTask submits a task and wait for it to finish.
-func submitAndWaitTask(ctx context.Context, taskKey string, taskType proto.TaskType, concurrency int, targetScope string, taskMeta []byte) error { - task, err := handle.SubmitTask(ctx, taskKey, taskType, concurrency, targetScope, taskMeta) +func submitAndWaitTask(ctx context.Context, taskKey string, taskType proto.TaskType, concurrency int, targetScope string, maxNodeCnt int, taskMeta []byte) error { + task, err := handle.SubmitTask(ctx, taskKey, taskType, concurrency, targetScope, maxNodeCnt, taskMeta) if err != nil { return err } diff --git a/pkg/disttask/example/app_test.go b/pkg/disttask/example/app_test.go index c8731fb72101f..3ea73a2da9205 100644 --- a/pkg/disttask/example/app_test.go +++ b/pkg/disttask/example/app_test.go @@ -52,7 +52,7 @@ func TestExampleApplication(t *testing.T) { } bytes, err := json.Marshal(meta) require.NoError(t, err) - task, err := handle.SubmitTask(ctx, "test", proto.TaskTypeExample, 1, "", bytes) + task, err := handle.SubmitTask(ctx, "test", proto.TaskTypeExample, 1, "", 0, bytes) require.NoError(t, err) require.NoError(t, handle.WaitTaskDoneByKey(ctx, task.Key)) } diff --git a/pkg/disttask/framework/handle/handle.go b/pkg/disttask/framework/handle/handle.go index bb42c10b27003..2339d38a50ccd 100644 --- a/pkg/disttask/framework/handle/handle.go +++ b/pkg/disttask/framework/handle/handle.go @@ -55,7 +55,7 @@ func GetCPUCountOfNode(ctx context.Context) (int, error) { } // SubmitTask submits a task. -func SubmitTask(ctx context.Context, taskKey string, taskType proto.TaskType, concurrency int, targetScope string, taskMeta []byte) (*proto.Task, error) { +func SubmitTask(ctx context.Context, taskKey string, taskType proto.TaskType, concurrency int, targetScope string, maxNodeCnt int, taskMeta []byte) (*proto.Task, error) { taskManager, err := storage.GetTaskManager() if err != nil { return nil, err @@ -68,7 +68,7 @@ func SubmitTask(ctx context.Context, taskKey string, taskType proto.TaskType, co return nil, storage.ErrTaskAlreadyExists } - taskID, err := taskManager.CreateTask(ctx, taskKey, taskType, concurrency, targetScope, taskMeta) + taskID, err := taskManager.CreateTask(ctx, taskKey, taskType, concurrency, targetScope, maxNodeCnt, taskMeta) if err != nil { return nil, err } diff --git a/pkg/disttask/framework/handle/handle_test.go b/pkg/disttask/framework/handle/handle_test.go index dc73bba383bb5..8285ebdc46b2a 100644 --- a/pkg/disttask/framework/handle/handle_test.go +++ b/pkg/disttask/framework/handle/handle_test.go @@ -49,7 +49,7 @@ func TestHandle(t *testing.T) { storage.SetTaskManager(mgr) // no scheduler registered - task, err := handle.SubmitTask(ctx, "1", proto.TaskTypeExample, 2, "", proto.EmptyMeta) + task, err := handle.SubmitTask(ctx, "1", proto.TaskTypeExample, 2, "", 0, proto.EmptyMeta) require.NoError(t, err) waitedTaskBase, err := handle.WaitTask(ctx, task.ID, func(task *proto.TaskBase) bool { return task.IsDone() @@ -72,12 +72,12 @@ func TestHandle(t *testing.T) { require.NoError(t, handle.CancelTask(ctx, "1")) - task, err = handle.SubmitTask(ctx, "2", proto.TaskTypeExample, 2, "", proto.EmptyMeta) + task, err = handle.SubmitTask(ctx, "2", proto.TaskTypeExample, 2, "", 0, proto.EmptyMeta) require.NoError(t, err) require.Equal(t, "2", task.Key) // submit same task. - task, err = handle.SubmitTask(ctx, "2", proto.TaskTypeExample, 2, "", proto.EmptyMeta) + task, err = handle.SubmitTask(ctx, "2", proto.TaskTypeExample, 2, "", 0, proto.EmptyMeta) require.Nil(t, task) require.Error(t, storage.ErrTaskAlreadyExists, err) // pause and resume task. 
@@ -85,10 +85,10 @@ func TestHandle(t *testing.T) { require.NoError(t, handle.ResumeTask(ctx, "2")) // submit task with same key - task, err = handle.SubmitTask(ctx, "3", proto.TaskTypeExample, 2, "", proto.EmptyMeta) + task, err = handle.SubmitTask(ctx, "3", proto.TaskTypeExample, 2, "", 0, proto.EmptyMeta) require.NoError(t, err) require.NoError(t, mgr.TransferTasks2History(ctx, []*proto.Task{task})) - task, err = handle.SubmitTask(ctx, "3", proto.TaskTypeExample, 2, "", proto.EmptyMeta) + task, err = handle.SubmitTask(ctx, "3", proto.TaskTypeExample, 2, "", 0, proto.EmptyMeta) require.Nil(t, task) require.Error(t, storage.ErrTaskAlreadyExists, err) } diff --git a/pkg/disttask/framework/integrationtests/bench_test.go b/pkg/disttask/framework/integrationtests/bench_test.go index 59f07f648642f..3efc752f2b2fe 100644 --- a/pkg/disttask/framework/integrationtests/bench_test.go +++ b/pkg/disttask/framework/integrationtests/bench_test.go @@ -95,7 +95,7 @@ func BenchmarkSchedulerOverhead(b *testing.B) { for i := 0; i < 4*proto.MaxConcurrentTask; i++ { taskKey := fmt.Sprintf("task-%03d", i) taskMeta := make([]byte, *taskMetaSize) - _, err := handle.SubmitTask(c.Ctx, taskKey, proto.TaskTypeExample, 1, "", taskMeta) + _, err := handle.SubmitTask(c.Ctx, taskKey, proto.TaskTypeExample, 1, "", 0, taskMeta) require.NoError(c.T, err) } // task has 2 steps, each step has 1 subtask,wait in serial to reduce WaitTask check overhead. diff --git a/pkg/disttask/framework/integrationtests/modify_test.go b/pkg/disttask/framework/integrationtests/modify_test.go index 235c4eb0d4efc..13e43525b0b36 100644 --- a/pkg/disttask/framework/integrationtests/modify_test.go +++ b/pkg/disttask/framework/integrationtests/modify_test.go @@ -93,7 +93,7 @@ func TestModifyTaskConcurrency(t *testing.T) { var theTask *proto.Task testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() { once.Do(func() { - task, err := handle.SubmitTask(c.Ctx, "k1", proto.TaskTypeExample, 3, "", []byte("init")) + task, err := handle.SubmitTask(c.Ctx, "k1", proto.TaskTypeExample, 3, "", 0, []byte("init")) require.NoError(t, err) require.Equal(t, 3, task.Concurrency) require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{ @@ -144,7 +144,7 @@ func TestModifyTaskConcurrency(t *testing.T) { <-modifySyncCh }) }) - task, err := handle.SubmitTask(c.Ctx, "k2", proto.TaskTypeExample, 3, "", nil) + task, err := handle.SubmitTask(c.Ctx, "k2", proto.TaskTypeExample, 3, "", 0, nil) require.NoError(t, err) require.Equal(t, 3, task.Concurrency) // finish StepOne @@ -193,7 +193,7 @@ func TestModifyTaskConcurrency(t *testing.T) { } }, ) - task, err := handle.SubmitTask(c.Ctx, "k2-2", proto.TaskTypeExample, 3, "", nil) + task, err := handle.SubmitTask(c.Ctx, "k2-2", proto.TaskTypeExample, 3, "", 0, nil) require.NoError(t, err) require.Equal(t, 3, task.Concurrency) for i := 0; i < 5; i++ { @@ -217,7 +217,7 @@ func TestModifyTaskConcurrency(t *testing.T) { var theTask *proto.Task testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() { once.Do(func() { - task, err := handle.SubmitTask(c.Ctx, "k3", proto.TaskTypeExample, 3, "", nil) + task, err := handle.SubmitTask(c.Ctx, "k3", proto.TaskTypeExample, 3, "", 0, nil) require.NoError(t, err) require.Equal(t, 3, task.Concurrency) found, err := c.TaskMgr.PauseTask(c.Ctx, task.Key) @@ -263,7 +263,7 @@ func TestModifyTaskConcurrency(t *testing.T) { var theTask *proto.Task 
testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() { once.Do(func() { - task, err := handle.SubmitTask(c.Ctx, "k4", proto.TaskTypeExample, 3, "", nil) + task, err := handle.SubmitTask(c.Ctx, "k4", proto.TaskTypeExample, 3, "", 0, nil) require.NoError(t, err) require.Equal(t, 3, task.Concurrency) require.NoError(t, c.TaskMgr.ModifyTaskByID(c.Ctx, task.ID, &proto.ModifyParam{ @@ -317,7 +317,7 @@ func TestModifyTaskConcurrency(t *testing.T) { var theTask *proto.Task testfailpoint.EnableCall(t, "github.com/pingcap/tidb/pkg/disttask/framework/scheduler/beforeGetSchedulableTasks", func() { once.Do(func() { - task, err := handle.SubmitTask(c.Ctx, "k5", proto.TaskTypeExample, 3, "", []byte("init")) + task, err := handle.SubmitTask(c.Ctx, "k5", proto.TaskTypeExample, 3, "", 0, []byte("init")) require.NoError(t, err) require.Equal(t, 3, task.Concurrency) require.EqualValues(t, []byte("init"), task.Meta) diff --git a/pkg/disttask/framework/integrationtests/resource_control_test.go b/pkg/disttask/framework/integrationtests/resource_control_test.go index d14130cf6d765..42ecdef1b18ee 100644 --- a/pkg/disttask/framework/integrationtests/resource_control_test.go +++ b/pkg/disttask/framework/integrationtests/resource_control_test.go @@ -96,7 +96,7 @@ func (c *resourceCtrlCaseContext) init(subtaskCntMap map[int64]map[proto.Step]in func (c *resourceCtrlCaseContext) runTaskAsync(prefix string, concurrencies []int) { for i, concurrency := range concurrencies { taskKey := fmt.Sprintf("%s-%d", prefix, i) - _, err := handle.SubmitTask(c.Ctx, taskKey, proto.TaskTypeExample, concurrency, "", nil) + _, err := handle.SubmitTask(c.Ctx, taskKey, proto.TaskTypeExample, concurrency, "", 0, nil) require.NoError(c.T, err) c.taskWG.RunWithLog(func() { task := testutil.WaitTaskDoneOrPaused(c.Ctx, c.T, taskKey) diff --git a/pkg/disttask/framework/planner/plan.go b/pkg/disttask/framework/planner/plan.go index d0d2cee557973..aeeb4636a3a32 100644 --- a/pkg/disttask/framework/planner/plan.go +++ b/pkg/disttask/framework/planner/plan.go @@ -32,6 +32,7 @@ type PlanCtx struct { TaskKey string TaskType proto.TaskType ThreadCnt int + MaxNodeCnt int // PreviousSubtaskMetas is subtask metas of previous steps. // We can remove this field if we find a better way to pass the result between steps. diff --git a/pkg/disttask/framework/planner/planner.go b/pkg/disttask/framework/planner/planner.go index 6b5c6f4a35684..850d3d4a58dd3 100644 --- a/pkg/disttask/framework/planner/planner.go +++ b/pkg/disttask/framework/planner/planner.go @@ -46,6 +46,7 @@ func (*Planner) Run(planCtx PlanCtx, plan LogicalPlan) (int64, error) { planCtx.TaskType, planCtx.ThreadCnt, config.GetGlobalConfig().Instance.TiDBServiceScope, + planCtx.MaxNodeCnt, taskMeta, ) } diff --git a/pkg/disttask/framework/proto/task.go b/pkg/disttask/framework/proto/task.go index 261b29050cc8e..14264d7082c9d 100644 --- a/pkg/disttask/framework/proto/task.go +++ b/pkg/disttask/framework/proto/task.go @@ -84,8 +84,9 @@ type TaskBase struct { // contain the tidb_service_scope=TargetScope label. // To be compatible with previous version, if it's "" or "background", the task try run on nodes of "background" scope, // if there is no such nodes, will try nodes of "" scope. - TargetScope string - CreateTime time.Time + TargetScope string + CreateTime time.Time + MaxNodeCount int } // IsDone checks if the task is done. 
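Illustrative sketch (not part of the patch): the new node limit travels from task submission into proto.TaskBase.MaxNodeCount, with handle.SubmitTask and TaskManager.CreateTask now taking it between targetScope and the task meta. The task key, concurrency, and meta below are made up; passing 0 keeps the previous "no limit" behaviour.

package example

import (
	"context"
	"fmt"

	"github.com/pingcap/tidb/pkg/disttask/framework/handle"
	"github.com/pingcap/tidb/pkg/disttask/framework/proto"
)

// submitWithNodeLimit submits an example task that may run on at most two
// executor nodes, using the updated SubmitTask signature (maxNodeCnt sits
// between targetScope and taskMeta). It assumes a storage.TaskManager has
// already been registered, as the framework tests above do.
func submitWithNodeLimit(ctx context.Context) error {
	task, err := handle.SubmitTask(ctx, "example-key", proto.TaskTypeExample, 4, "", 2, []byte("meta"))
	if err != nil {
		return err
	}
	fmt.Printf("task %d submitted, max node count %d\n", task.ID, task.MaxNodeCount)
	return nil
}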
diff --git a/pkg/disttask/framework/scheduler/balancer.go b/pkg/disttask/framework/scheduler/balancer.go index ee6d20b24ac34..4b66fbe7a6686 100644 --- a/pkg/disttask/framework/scheduler/balancer.go +++ b/pkg/disttask/framework/scheduler/balancer.go @@ -16,6 +16,7 @@ package scheduler import ( "context" + "sort" "time" "github.com/pingcap/errors" @@ -101,11 +102,26 @@ func (b *balancer) balanceSubtasks(ctx context.Context, sch Scheduler, managedNo if len(eligibleNodes) == 0 { return errors.New("no eligible nodes to balance subtasks") } - return b.doBalanceSubtasks(ctx, task.ID, eligibleNodes) + return b.doBalanceSubtasks(ctx, task, eligibleNodes) } -func (b *balancer) doBalanceSubtasks(ctx context.Context, taskID int64, eligibleNodes []string) (err error) { - subtasks, err := b.taskMgr.GetActiveSubtasks(ctx, taskID) +func filterNodesByMaxNodeCnt(nodes []string, subtasks []*proto.SubtaskBase, maxNodeCnt int) []string { + if maxNodeCnt == 0 || len(nodes) <= maxNodeCnt { + return nodes + } + // Order nodes by subtask count. + nodeSubtaskCnt := make(map[string]int, len(nodes)) + for _, st := range subtasks { + nodeSubtaskCnt[st.ExecID]++ + } + sort.SliceStable(nodes, func(i, j int) bool { + return nodeSubtaskCnt[nodes[i]] > nodeSubtaskCnt[nodes[j]] + }) + return nodes[:maxNodeCnt] +} + +func (b *balancer) doBalanceSubtasks(ctx context.Context, task *proto.Task, eligibleNodes []string) (err error) { + subtasks, err := b.taskMgr.GetActiveSubtasks(ctx, task.ID) if err != nil { return err } @@ -120,6 +136,7 @@ func (b *balancer) doBalanceSubtasks(ctx context.Context, taskID int64, eligible failpoint.Inject("mockNoEnoughSlots", func(_ failpoint.Value) { adjustedNodes = []string{} }) + adjustedNodes = filterNodesByMaxNodeCnt(adjustedNodes, subtasks, task.MaxNodeCount) if len(adjustedNodes) == 0 { // no node has enough slots to run the subtasks, skip balance and skip // update used slots. @@ -161,7 +178,7 @@ func (b *balancer) doBalanceSubtasks(ctx context.Context, taskID int64, eligible for node, sts := range executorSubtasks { if _, ok := adjustedNodeMap[node]; !ok { b.logger.Info("dead node or not have enough slots, schedule subtasks away", - zap.Int64("task-id", taskID), + zap.Int64("task-id", task.ID), zap.String("node", node), zap.Int("slot-capacity", b.slotMgr.getCapacity()), zap.Int("used-slots", b.currUsedSlots[node])) diff --git a/pkg/disttask/framework/scheduler/balancer_test.go b/pkg/disttask/framework/scheduler/balancer_test.go index 984b8cab53f54..31a164bfe116b 100644 --- a/pkg/disttask/framework/scheduler/balancer_test.go +++ b/pkg/disttask/framework/scheduler/balancer_test.go @@ -30,6 +30,7 @@ import ( type balanceTestCase struct { subtasks []*proto.SubtaskBase eligibleNodes []string + maxNodeCount int initUsedSlots map[string]int expectedSubtasks []*proto.SubtaskBase expectedUsedSlots map[string]int @@ -224,6 +225,44 @@ func TestBalanceOneTask(t *testing.T) { }, expectedUsedSlots: map[string]int{"tidb2": 16, "tidb3": 16}, }, + // balanced, but max node count is limited. 
+ { + subtasks: []*proto.SubtaskBase{ + {ID: 1, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 3, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + maxNodeCount: 1, + eligibleNodes: []string{"tidb1", "tidb2"}, + initUsedSlots: map[string]int{"tidb1": 0, "tidb2": 0}, + expectedSubtasks: []*proto.SubtaskBase{ + {ID: 1, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 3, ExecID: "tidb2", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + expectedUsedSlots: map[string]int{"tidb1": 0, "tidb2": 16}, + }, + // scale out, but max node count is limited. + { + subtasks: []*proto.SubtaskBase{ + {ID: 1, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 3, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 4, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 5, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + eligibleNodes: []string{"tidb1", "tidb2", "tidb3"}, + maxNodeCount: 2, + initUsedSlots: map[string]int{"tidb1": 0, "tidb2": 0, "tidb3": 0}, + expectedSubtasks: []*proto.SubtaskBase{ + {ID: 1, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStateRunning}, + {ID: 2, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 3, ExecID: "tidb3", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 4, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + {ID: 5, ExecID: "tidb1", Concurrency: 16, State: proto.SubtaskStatePending}, + }, + expectedUsedSlots: map[string]int{"tidb1": 16, "tidb2": 0, "tidb3": 16}, + }, } ctx := context.Background() @@ -235,7 +274,7 @@ func TestBalanceOneTask(t *testing.T) { mockTaskMgr.EXPECT().UpdateSubtasksExecIDs(gomock.Any(), gomock.Any()).Return(nil) } mockScheduler := mock.NewMockScheduler(ctrl) - mockScheduler.EXPECT().GetTask().Return(&proto.Task{TaskBase: proto.TaskBase{ID: 1}}).Times(2) + mockScheduler.EXPECT().GetTask().Return(&proto.Task{TaskBase: proto.TaskBase{ID: 1, MaxNodeCount: c.maxNodeCount}}).Times(2) mockScheduler.EXPECT().GetEligibleInstances(gomock.Any(), gomock.Any()).Return(nil, nil) slotMgr := newSlotManager() diff --git a/pkg/disttask/framework/scheduler/scheduler.go b/pkg/disttask/framework/scheduler/scheduler.go index a7ae30f9e28c6..9f64368653a10 100644 --- a/pkg/disttask/framework/scheduler/scheduler.go +++ b/pkg/disttask/framework/scheduler/scheduler.go @@ -468,6 +468,12 @@ func (s *BaseScheduler) switch2NextStep() error { if err != nil { return err } + if task.MaxNodeCount > 0 && len(eligibleNodes) > task.MaxNodeCount { + // OnNextSubtasksBatch may use len(eligibleNodes) as a hint to + // calculate the number of subtasks, so we need to do this before + // filtering nodes by available slots in scheduleSubtask. 
+ eligibleNodes = eligibleNodes[:task.MaxNodeCount] + } s.logger.Info("eligible instances", zap.Int("num", len(eligibleNodes))) if len(eligibleNodes) == 0 { diff --git a/pkg/disttask/framework/scheduler/scheduler_manager_test.go b/pkg/disttask/framework/scheduler/scheduler_manager_test.go index daa1029741f4f..f9d4d0fcdc8df 100644 --- a/pkg/disttask/framework/scheduler/scheduler_manager_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_manager_test.go @@ -46,7 +46,7 @@ func TestCleanUpRoutine(t *testing.T) { mockCleanupRoutine.EXPECT().CleanUp(gomock.Any(), gomock.Any()).Return(nil).AnyTimes() sch.Start() defer sch.Stop() - taskID, err := mgr.CreateTask(ctx, "test", proto.TaskTypeExample, 1, "", nil) + taskID, err := mgr.CreateTask(ctx, "test", proto.TaskTypeExample, 1, "", 0, nil) require.NoError(t, err) checkTaskRunningCnt := func() []*proto.Task { diff --git a/pkg/disttask/framework/scheduler/scheduler_test.go b/pkg/disttask/framework/scheduler/scheduler_test.go index e6e452b9ee0a8..e7ba6784490ca 100644 --- a/pkg/disttask/framework/scheduler/scheduler_test.go +++ b/pkg/disttask/framework/scheduler/scheduler_test.go @@ -130,7 +130,7 @@ func TestTaskFailInManager(t *testing.T) { defer schManager.Stop() // unknown task type - taskID, err := mgr.CreateTask(ctx, "test", "test-type", 1, "", nil) + taskID, err := mgr.CreateTask(ctx, "test", "test-type", 1, "", 0, nil) require.NoError(t, err) require.Eventually(t, func() bool { task, err := mgr.GetTaskByID(ctx, taskID) @@ -140,7 +140,7 @@ func TestTaskFailInManager(t *testing.T) { }, time.Second*10, time.Millisecond*300) // scheduler init error - taskID, err = mgr.CreateTask(ctx, "test2", proto.TaskTypeExample, 1, "", nil) + taskID, err = mgr.CreateTask(ctx, "test2", proto.TaskTypeExample, 1, "", 0, nil) require.NoError(t, err) require.Eventually(t, func() bool { task, err := mgr.GetTaskByID(ctx, taskID) @@ -215,7 +215,7 @@ func checkSchedule(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, // Mock add tasks. taskIDs := make([]int64, 0, taskCnt) for i := 0; i < taskCnt; i++ { - taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", i), proto.TaskTypeExample, 0, "background", nil) + taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", i), proto.TaskTypeExample, 0, "background", 0, nil) require.NoError(t, err) taskIDs = append(taskIDs, taskID) } @@ -225,7 +225,7 @@ func checkSchedule(t *testing.T, taskCnt int, isSucc, isCancel, isSubtaskCancel, checkSubtaskCnt(tasks, taskIDs) // test parallelism control if taskCnt == 1 { - taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", taskCnt), proto.TaskTypeExample, 0, "background", nil) + taskID, err := mgr.CreateTask(ctx, fmt.Sprintf("%d", taskCnt), proto.TaskTypeExample, 0, "background", 0, nil) require.NoError(t, err) checkGetRunningTaskCnt(taskCnt) // Clean the task. 
@@ -460,7 +460,7 @@ func TestManagerScheduleLoop(t *testing.T) { }, ) for i := 0; i < len(concurrencies); i++ { - _, err := taskMgr.CreateTask(ctx, fmt.Sprintf("key/%d", i), proto.TaskTypeExample, concurrencies[i], "", []byte("{}")) + _, err := taskMgr.CreateTask(ctx, fmt.Sprintf("key/%d", i), proto.TaskTypeExample, concurrencies[i], "", 0, []byte("{}")) require.NoError(t, err) } getRunningTaskKeys := func() []string { diff --git a/pkg/disttask/framework/storage/converter.go b/pkg/disttask/framework/storage/converter.go index e4c140bbecbc7..dbc55aa64110c 100644 --- a/pkg/disttask/framework/storage/converter.go +++ b/pkg/disttask/framework/storage/converter.go @@ -73,6 +73,8 @@ func Row2Task(r chunk.Row) *proto.Task { logutil.BgLogger().Error("unmarshal task modify param", zap.Error(err)) } } + maxNodeCnt := r.GetInt64(15) + task.MaxNodeCount = int(maxNodeCnt) return task } diff --git a/pkg/disttask/framework/storage/table_test.go b/pkg/disttask/framework/storage/table_test.go index 3868f26ef9d96..2af962ee43dd8 100644 --- a/pkg/disttask/framework/storage/table_test.go +++ b/pkg/disttask/framework/storage/table_test.go @@ -46,11 +46,11 @@ func TestTaskTable(t *testing.T) { require.NoError(t, gm.InitMeta(ctx, ":4000", "")) - _, err := gm.CreateTask(ctx, "key1", "test", 999, "", []byte("test")) + _, err := gm.CreateTask(ctx, "key1", "test", 999, "", 0, []byte("test")) require.ErrorContains(t, err, "task concurrency(999) larger than cpu count") timeBeforeCreate := time.Unix(time.Now().Unix(), 0) - id, err := gm.CreateTask(ctx, "key1", "test", 4, "", []byte("test")) + id, err := gm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) require.NoError(t, err) require.Equal(t, int64(1), id) @@ -99,11 +99,11 @@ func TestTaskTable(t *testing.T) { require.Equal(t, task.State, task6.State) // test cannot insert task with dup key - _, err = gm.CreateTask(ctx, "key1", "test2", 4, "", []byte("test2")) + _, err = gm.CreateTask(ctx, "key1", "test2", 4, "", 0, []byte("test2")) require.EqualError(t, err, "[kv:1062]Duplicate entry 'key1' for key 'tidb_global_task.task_key'") // test cancel task - id, err = gm.CreateTask(ctx, "key2", "test", 4, "", []byte("test")) + id, err = gm.CreateTask(ctx, "key2", "test", 4, "", 0, []byte("test")) require.NoError(t, err) cancelling, err := testutil.IsTaskCancelling(ctx, gm, id) @@ -115,7 +115,7 @@ func TestTaskTable(t *testing.T) { require.NoError(t, err) require.True(t, cancelling) - id, err = gm.CreateTask(ctx, "key-fail", "test2", 4, "", []byte("test2")) + id, err = gm.CreateTask(ctx, "key-fail", "test2", 4, "", 0, []byte("test2")) require.NoError(t, err) // state not right, update nothing require.NoError(t, gm.FailTask(ctx, id, proto.TaskStateRunning, errors.New("test error"))) @@ -135,7 +135,7 @@ func TestTaskTable(t *testing.T) { require.GreaterOrEqual(t, endTime, curTime) // succeed a pending task, no effect - id, err = gm.CreateTask(ctx, "key-success", "test", 4, "", []byte("test")) + id, err = gm.CreateTask(ctx, "key-success", "test", 4, "", 0, []byte("test")) require.NoError(t, err) require.NoError(t, gm.SucceedTask(ctx, id)) task, err = gm.GetTaskByID(ctx, id) @@ -154,7 +154,7 @@ func TestTaskTable(t *testing.T) { require.GreaterOrEqual(t, task.StateUpdateTime, startTime) // reverted a pending task, no effect - id, err = gm.CreateTask(ctx, "key-reverted", "test", 4, "", []byte("test")) + id, err = gm.CreateTask(ctx, "key-reverted", "test", 4, "", 0, []byte("test")) require.NoError(t, err) require.NoError(t, gm.RevertedTask(ctx, id)) task, err = 
gm.GetTaskByID(ctx, id) @@ -178,7 +178,7 @@ func TestTaskTable(t *testing.T) { require.Equal(t, proto.TaskStateReverted, task.State) // paused - id, err = gm.CreateTask(ctx, "key-paused", "test", 4, "", []byte("test")) + id, err = gm.CreateTask(ctx, "key-paused", "test", 4, "", 0, []byte("test")) require.NoError(t, err) require.NoError(t, gm.PausedTask(ctx, id)) task, err = gm.GetTaskByID(ctx, id) @@ -229,7 +229,7 @@ func TestSwitchTaskStep(t *testing.T) { tk := testkit.NewTestKit(t, store) require.NoError(t, tm.InitMeta(ctx, ":4000", "")) - taskID, err := tm.CreateTask(ctx, "key1", "test", 4, "", []byte("test")) + taskID, err := tm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) require.NoError(t, err) task, err := tm.GetTaskByID(ctx, taskID) require.NoError(t, err) @@ -281,7 +281,7 @@ func TestSwitchTaskStepInBatch(t *testing.T) { require.NoError(t, tm.InitMeta(ctx, ":4000", "")) // normal flow prepare := func(taskKey string) (*proto.Task, []*proto.Subtask) { - taskID, err := tm.CreateTask(ctx, taskKey, "test", 4, "", []byte("test")) + taskID, err := tm.CreateTask(ctx, taskKey, "test", 4, "", 0, []byte("test")) require.NoError(t, err) task, err := tm.GetTaskByID(ctx, taskID) require.NoError(t, err) @@ -373,7 +373,7 @@ func TestGetTopUnfinishedTasks(t *testing.T) { } for i, state := range taskStates { taskKey := fmt.Sprintf("key/%d", i) - _, err := gm.CreateTask(ctx, taskKey, "test", 4, "", []byte("test")) + _, err := gm.CreateTask(ctx, taskKey, "test", 4, "", 0, []byte("test")) require.NoError(t, err) require.NoError(t, gm.WithNewSession(func(se sessionctx.Context) error { _, err := se.GetSQLExecutor().ExecuteInternal(ctx, ` @@ -446,7 +446,7 @@ func TestGetUsedSlotsOnNodes(t *testing.T) { func TestGetActiveSubtasks(t *testing.T) { _, tm, ctx := testutil.InitTableTest(t) require.NoError(t, tm.InitMeta(ctx, ":4000", "")) - id, err := tm.CreateTask(ctx, "key1", "test", 4, "", []byte("test")) + id, err := tm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) require.NoError(t, err) require.Equal(t, int64(1), id) task, err := tm.GetTaskByID(ctx, id) @@ -478,7 +478,7 @@ func TestSubTaskTable(t *testing.T) { _, sm, ctx := testutil.InitTableTest(t) timeBeforeCreate := time.Unix(time.Now().Unix(), 0) require.NoError(t, sm.InitMeta(ctx, ":4000", "")) - id, err := sm.CreateTask(ctx, "key1", "test", 4, "", []byte("test")) + id, err := sm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) require.NoError(t, err) require.Equal(t, int64(1), id) err = sm.SwitchTaskStep( @@ -634,7 +634,7 @@ func TestSubTaskTable(t *testing.T) { func TestBothTaskAndSubTaskTable(t *testing.T) { _, sm, ctx := testutil.InitTableTest(t) require.NoError(t, sm.InitMeta(ctx, ":4000", "")) - id, err := sm.CreateTask(ctx, "key1", "test", 4, "", []byte("test")) + id, err := sm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) require.NoError(t, err) require.Equal(t, int64(1), id) @@ -843,9 +843,9 @@ func TestTaskHistoryTable(t *testing.T) { _, gm, ctx := testutil.InitTableTest(t) require.NoError(t, gm.InitMeta(ctx, ":4000", "")) - _, err := gm.CreateTask(ctx, "1", proto.TaskTypeExample, 1, "", nil) + _, err := gm.CreateTask(ctx, "1", proto.TaskTypeExample, 1, "", 0, nil) require.NoError(t, err) - taskID, err := gm.CreateTask(ctx, "2", proto.TaskTypeExample, 1, "", nil) + taskID, err := gm.CreateTask(ctx, "2", proto.TaskTypeExample, 1, "", 0, nil) require.NoError(t, err) tasks, err := gm.GetTasksInStates(ctx, proto.TaskStatePending) @@ -878,7 +878,7 @@ func TestTaskHistoryTable(t *testing.T) { 
require.NotNil(t, task) // task with fail transfer - _, err = gm.CreateTask(ctx, "3", proto.TaskTypeExample, 1, "", nil) + _, err = gm.CreateTask(ctx, "3", proto.TaskTypeExample, 1, "", 0, nil) require.NoError(t, err) tasks, err = gm.GetTasksInStates(ctx, proto.TaskStatePending) require.NoError(t, err) @@ -1135,7 +1135,7 @@ func TestGetActiveTaskExecInfo(t *testing.T) { taskStates := []proto.TaskState{proto.TaskStateRunning, proto.TaskStateReverting, proto.TaskStateReverting, proto.TaskStatePausing} tasks := make([]*proto.Task, 0, len(taskStates)) for i, expectedState := range taskStates { - taskID, err := tm.CreateTask(ctx, fmt.Sprintf("key-%d", i), proto.TaskTypeExample, 8, "", []byte("")) + taskID, err := tm.CreateTask(ctx, fmt.Sprintf("key-%d", i), proto.TaskTypeExample, 8, "", 0, []byte("")) require.NoError(t, err) task, err := tm.GetTaskByID(ctx, taskID) require.NoError(t, err) diff --git a/pkg/disttask/framework/storage/task_state_test.go b/pkg/disttask/framework/storage/task_state_test.go index 0a515675d95ff..3afa557d8c0a5 100644 --- a/pkg/disttask/framework/storage/task_state_test.go +++ b/pkg/disttask/framework/storage/task_state_test.go @@ -40,7 +40,7 @@ func TestTaskState(t *testing.T) { require.NoError(t, gm.InitMeta(ctx, ":4000", "")) // 1. cancel task - id, err := gm.CreateTask(ctx, "key1", "test", 4, "", []byte("test")) + id, err := gm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) require.NoError(t, err) // require.Equal(t, int64(1), id) TODO: unstable for infoschema v2 require.NoError(t, gm.CancelTask(ctx, id)) @@ -49,7 +49,7 @@ func TestTaskState(t *testing.T) { checkTaskStateStep(t, task, proto.TaskStateCancelling, proto.StepInit) // 2. cancel task by key session - id, err = gm.CreateTask(ctx, "key2", "test", 4, "", []byte("test")) + id, err = gm.CreateTask(ctx, "key2", "test", 4, "", 0, []byte("test")) require.NoError(t, err) // require.Equal(t, int64(2), id) TODO: unstable for infoschema v2 require.NoError(t, gm.WithNewTxn(ctx, func(se sessionctx.Context) error { @@ -61,7 +61,7 @@ func TestTaskState(t *testing.T) { checkTaskStateStep(t, task, proto.TaskStateCancelling, proto.StepInit) // 3. fail task - id, err = gm.CreateTask(ctx, "key3", "test", 4, "", []byte("test")) + id, err = gm.CreateTask(ctx, "key3", "test", 4, "", 0, []byte("test")) require.NoError(t, err) // require.Equal(t, int64(3), id) TODO: unstable for infoschema v2 failedErr := errors.New("test err") @@ -72,7 +72,7 @@ func TestTaskState(t *testing.T) { require.ErrorContains(t, task.Error, "test err") // 4. Reverted task - id, err = gm.CreateTask(ctx, "key4", "test", 4, "", []byte("test")) + id, err = gm.CreateTask(ctx, "key4", "test", 4, "", 0, []byte("test")) require.NoError(t, err) // require.Equal(t, int64(4), id) TODO: unstable for infoschema v2 task, err = gm.GetTaskByID(ctx, id) @@ -90,7 +90,7 @@ func TestTaskState(t *testing.T) { checkTaskStateStep(t, task, proto.TaskStateReverted, proto.StepInit) // 5. pause task - id, err = gm.CreateTask(ctx, "key5", "test", 4, "", []byte("test")) + id, err = gm.CreateTask(ctx, "key5", "test", 4, "", 0, []byte("test")) require.NoError(t, err) // require.Equal(t, int64(5), id) TODO: unstable for infoschema v2 found, err := gm.PauseTask(ctx, "key5") @@ -119,7 +119,7 @@ func TestTaskState(t *testing.T) { require.Equal(t, proto.TaskStateRunning, task.State) // 8. 
succeed task - id, err = gm.CreateTask(ctx, "key6", "test", 4, "", []byte("test")) + id, err = gm.CreateTask(ctx, "key6", "test", 4, "", 0, []byte("test")) require.NoError(t, err) // require.Equal(t, int64(6), id) TODO: unstable for infoschema v2 task, err = gm.GetTaskByID(ctx, id) @@ -139,7 +139,7 @@ func TestModifyTask(t *testing.T) { _, gm, ctx := testutil.InitTableTest(t) require.NoError(t, gm.InitMeta(ctx, ":4000", "")) - id, err := gm.CreateTask(ctx, "key1", "test", 4, "", []byte("test")) + id, err := gm.CreateTask(ctx, "key1", "test", 4, "", 0, []byte("test")) require.NoError(t, err) require.ErrorIs(t, gm.ModifyTaskByID(ctx, id, &proto.ModifyParam{ diff --git a/pkg/disttask/framework/storage/task_table.go b/pkg/disttask/framework/storage/task_table.go index 353ea90b8fc20..eb5e1ad271483 100644 --- a/pkg/disttask/framework/storage/task_table.go +++ b/pkg/disttask/framework/storage/task_table.go @@ -39,9 +39,9 @@ const ( basicTaskColumns = `t.id, t.task_key, t.type, t.state, t.step, t.priority, t.concurrency, t.create_time, t.target_scope` // TaskColumns is the columns for task. // TODO: dispatcher_id will update to scheduler_id later - TaskColumns = basicTaskColumns + `, t.start_time, t.state_update_time, t.meta, t.dispatcher_id, t.error, t.modify_params` + TaskColumns = basicTaskColumns + `, t.start_time, t.state_update_time, t.meta, t.dispatcher_id, t.error, t.modify_params, t.max_node_count` // InsertTaskColumns is the columns used in insert task. - InsertTaskColumns = `task_key, type, state, priority, concurrency, step, meta, create_time, target_scope` + InsertTaskColumns = `task_key, type, state, priority, concurrency, step, meta, create_time, target_scope, max_node_count` basicSubtaskColumns = `id, step, task_key, type, exec_id, state, concurrency, create_time, ordinal, start_time` // SubtaskColumns is the columns for subtask. SubtaskColumns = basicSubtaskColumns + `, state_update_time, meta, summary` @@ -202,10 +202,18 @@ func (mgr *TaskManager) ExecuteSQLWithNewSession(ctx context.Context, sql string } // CreateTask adds a new task to task table. 
-func (mgr *TaskManager) CreateTask(ctx context.Context, key string, tp proto.TaskType, concurrency int, targetScope string, meta []byte) (taskID int64, err error) { +func (mgr *TaskManager) CreateTask( + ctx context.Context, + key string, + tp proto.TaskType, + concurrency int, + targetScope string, + maxNodeCnt int, + meta []byte, +) (taskID int64, err error) { err = mgr.WithNewSession(func(se sessionctx.Context) error { var err2 error - taskID, err2 = mgr.CreateTaskWithSession(ctx, se, key, tp, concurrency, targetScope, meta) + taskID, err2 = mgr.CreateTaskWithSession(ctx, se, key, tp, concurrency, targetScope, maxNodeCnt, meta) return err2 }) return @@ -219,6 +227,7 @@ func (mgr *TaskManager) CreateTaskWithSession( tp proto.TaskType, concurrency int, targetScope string, + maxNodeCount int, meta []byte, ) (taskID int64, err error) { cpuCount, err := mgr.getCPUCountOfNode(ctx, se) @@ -230,8 +239,8 @@ func (mgr *TaskManager) CreateTaskWithSession( } _, err = sqlexec.ExecSQL(ctx, se.GetSQLExecutor(), ` insert into mysql.tidb_global_task(`+InsertTaskColumns+`) - values (%?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), %?)`, - key, tp, proto.TaskStatePending, proto.NormalPriority, concurrency, proto.StepInit, meta, targetScope) + values (%?, %?, %?, %?, %?, %?, %?, CURRENT_TIMESTAMP(), %?, %?)`, + key, tp, proto.TaskStatePending, proto.NormalPriority, concurrency, proto.StepInit, meta, targetScope, maxNodeCount) if err != nil { return 0, err } diff --git a/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go b/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go index a6e81a0fedd6a..055932e87a894 100644 --- a/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go +++ b/pkg/disttask/framework/taskexecutor/task_executor_testkit_test.go @@ -36,7 +36,7 @@ import ( ) func runOneTask(ctx context.Context, t *testing.T, mgr *storage.TaskManager, taskKey string, subtaskCnt int) { - taskID, err := mgr.CreateTask(ctx, taskKey, proto.TaskTypeExample, 1, "", nil) + taskID, err := mgr.CreateTask(ctx, taskKey, proto.TaskTypeExample, 1, "", 0, nil) require.NoError(t, err) task, err := mgr.GetTaskByID(ctx, taskID) require.NoError(t, err) diff --git a/pkg/disttask/framework/testutil/disttest_util.go b/pkg/disttask/framework/testutil/disttest_util.go index 75b2319cac5fb..7f7db281101f2 100644 --- a/pkg/disttask/framework/testutil/disttest_util.go +++ b/pkg/disttask/framework/testutil/disttest_util.go @@ -108,7 +108,7 @@ func RegisterTaskTypeForRollback(t testing.TB, ctrl *gomock.Controller, schedule // SubmitAndWaitTask schedule one task. 
func SubmitAndWaitTask(ctx context.Context, t testing.TB, taskKey string, targetScope string, concurrency int) *proto.TaskBase { - _, err := handle.SubmitTask(ctx, taskKey, proto.TaskTypeExample, concurrency, targetScope, nil) + _, err := handle.SubmitTask(ctx, taskKey, proto.TaskTypeExample, concurrency, targetScope, 0, nil) require.NoError(t, err) return WaitTaskDoneOrPaused(ctx, t, taskKey) } diff --git a/pkg/disttask/importinto/job.go b/pkg/disttask/importinto/job.go index 174b3e04b2624..6cb71b25776b7 100644 --- a/pkg/disttask/importinto/job.go +++ b/pkg/disttask/importinto/job.go @@ -89,6 +89,7 @@ func doSubmitTask(ctx context.Context, plan *importer.Plan, stmt string, instanc TaskKey: TaskKey(jobID), TaskType: proto.ImportInto, ThreadCnt: plan.ThreadCnt, + MaxNodeCnt: plan.MaxNodeCnt, } p := planner.NewPlanner() taskID, err2 = p.Run(planCtx, logicalPlan) diff --git a/pkg/disttask/importinto/job_testkit_test.go b/pkg/disttask/importinto/job_testkit_test.go index 14dd732d9052c..888b00d71a63e 100644 --- a/pkg/disttask/importinto/job_testkit_test.go +++ b/pkg/disttask/importinto/job_testkit_test.go @@ -53,7 +53,7 @@ func TestGetTaskImportedRows(t *testing.T) { } bytes, err := json.Marshal(taskMeta) require.NoError(t, err) - taskID, err := manager.CreateTask(ctx, importinto.TaskKey(111), proto.ImportInto, 1, "", bytes) + taskID, err := manager.CreateTask(ctx, importinto.TaskKey(111), proto.ImportInto, 1, "", 0, bytes) require.NoError(t, err) importStepMetas := []*importinto.ImportStepMeta{ { @@ -85,7 +85,7 @@ func TestGetTaskImportedRows(t *testing.T) { } bytes, err = json.Marshal(taskMeta) require.NoError(t, err) - taskID, err = manager.CreateTask(ctx, importinto.TaskKey(222), proto.ImportInto, 1, "", bytes) + taskID, err = manager.CreateTask(ctx, importinto.TaskKey(222), proto.ImportInto, 1, "", 0, bytes) require.NoError(t, err) ingestStepMetas := []*importinto.WriteIngestStepMeta{ { diff --git a/pkg/disttask/importinto/scheduler_testkit_test.go b/pkg/disttask/importinto/scheduler_testkit_test.go index 733334986999e..7e6ab15b90bab 100644 --- a/pkg/disttask/importinto/scheduler_testkit_test.go +++ b/pkg/disttask/importinto/scheduler_testkit_test.go @@ -86,7 +86,7 @@ func TestSchedulerExtLocalSort(t *testing.T) { require.NoError(t, err) taskMeta, err := json.Marshal(task) require.NoError(t, err) - taskID, err := manager.CreateTask(ctx, importinto.TaskKey(jobID), proto.ImportInto, 1, "", taskMeta) + taskID, err := manager.CreateTask(ctx, importinto.TaskKey(jobID), proto.ImportInto, 1, "", 0, taskMeta) require.NoError(t, err) task.ID = taskID @@ -229,7 +229,7 @@ func TestSchedulerExtGlobalSort(t *testing.T) { require.NoError(t, err) taskMeta, err := json.Marshal(task) require.NoError(t, err) - taskID, err := manager.CreateTask(ctx, importinto.TaskKey(jobID), proto.ImportInto, 1, "", taskMeta) + taskID, err := manager.CreateTask(ctx, importinto.TaskKey(jobID), proto.ImportInto, 1, "", 0, taskMeta) require.NoError(t, err) task.ID = taskID diff --git a/pkg/executor/importer/BUILD.bazel b/pkg/executor/importer/BUILD.bazel index c8b586e17b615..80f18b3cf30ce 100644 --- a/pkg/executor/importer/BUILD.bazel +++ b/pkg/executor/importer/BUILD.bazel @@ -17,6 +17,7 @@ go_library( "//br/pkg/storage", "//br/pkg/streamhelper", "//pkg/config", + "//pkg/ddl", "//pkg/ddl/util", "//pkg/disttask/framework/handle", "//pkg/disttask/framework/proto", @@ -48,6 +49,7 @@ go_library( "//pkg/planner/util", "//pkg/sessionctx", "//pkg/sessionctx/vardef", + "//pkg/sessionctx/variable", "//pkg/sessiontxn", "//pkg/table", 
"//pkg/table/tables", diff --git a/pkg/executor/importer/import.go b/pkg/executor/importer/import.go index f89232361809c..17bd7620e45df 100644 --- a/pkg/executor/importer/import.go +++ b/pkg/executor/importer/import.go @@ -32,6 +32,7 @@ import ( "github.com/pingcap/log" "github.com/pingcap/tidb/br/pkg/storage" tidb "github.com/pingcap/tidb/pkg/config" + "github.com/pingcap/tidb/pkg/ddl" "github.com/pingcap/tidb/pkg/ddl/util" "github.com/pingcap/tidb/pkg/disttask/framework/handle" "github.com/pingcap/tidb/pkg/expression" @@ -51,6 +52,7 @@ import ( plannerutil "github.com/pingcap/tidb/pkg/planner/util" "github.com/pingcap/tidb/pkg/sessionctx" "github.com/pingcap/tidb/pkg/sessionctx/vardef" + "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/table" tidbutil "github.com/pingcap/tidb/pkg/util" "github.com/pingcap/tidb/pkg/util/chunk" @@ -229,6 +231,7 @@ type Plan struct { DiskQuota config.ByteSize Checksum config.PostOpLevel ThreadCnt int + MaxNodeCnt int MaxWriteSpeed config.ByteSize SplitFile bool MaxRecordedErrors int64 @@ -751,6 +754,13 @@ func (p *Plan) initOptions(ctx context.Context, seCtx sessionctx.Context, option p.ForceMergeStep = true } + if sv, ok := seCtx.GetSessionVars().GetSystemVar(vardef.TiDBMaxDistTaskNodes); ok { + p.MaxNodeCnt = variable.TidbOptInt(sv, 0) + if p.MaxNodeCnt == -1 { // -1 means calculate automatically + p.MaxNodeCnt = ddl.GetDXFDefaultMaxNodeCntAuto(seCtx.GetStore()) + } + } + // when split-file is set, data file will be split into chunks of 256 MiB. // skip_rows should be 0 or 1, we add this restriction to simplify skip_rows // logic, so we only need to skip on the first chunk for each data file. diff --git a/pkg/executor/show_ddl_jobs.go b/pkg/executor/show_ddl_jobs.go index c0b3205de06a0..26dc43021a568 100644 --- a/pkg/executor/show_ddl_jobs.go +++ b/pkg/executor/show_ddl_jobs.go @@ -330,6 +330,9 @@ func showCommentsFromJob(job *model.Job) string { if m.TargetScope != "" { labels = append(labels, fmt.Sprintf("service_scope=%s", m.TargetScope)) } + if m.MaxNodeCount != 0 { + labels = append(labels, fmt.Sprintf("max_node_count=%d", m.MaxNodeCount)) + } } return strings.Join(labels, ", ") } diff --git a/pkg/executor/show_ddl_jobs_test.go b/pkg/executor/show_ddl_jobs_test.go index 87183e1d30342..8f6338e83d995 100644 --- a/pkg/executor/show_ddl_jobs_test.go +++ b/pkg/executor/show_ddl_jobs_test.go @@ -64,6 +64,15 @@ func TestShowCommentsFromJob(t *testing.T) { res = showCommentsFromJob(job) require.Equal(t, "ingest, DXF, cloud", res) + job.ReorgMeta = &model.DDLReorgMeta{ + ReorgTp: model.ReorgTypeLitMerge, + IsDistReorg: true, + UseCloudStorage: true, + MaxNodeCount: 5, + } + res = showCommentsFromJob(job) + require.Equal(t, "ingest, DXF, cloud, max_node_count=5", res) + job.ReorgMeta = &model.DDLReorgMeta{ ReorgTp: model.ReorgTypeLitMerge, IsDistReorg: true, diff --git a/pkg/meta/model/reorg.go b/pkg/meta/model/reorg.go index a0b2d0b6f810f..cb931a7d52214 100644 --- a/pkg/meta/model/reorg.go +++ b/pkg/meta/model/reorg.go @@ -73,6 +73,7 @@ type DDLReorgMeta struct { ResourceGroupName string `json:"resource_group_name"` Version int64 `json:"version"` TargetScope string `json:"target_scope"` + MaxNodeCount int `json:"max_node_count"` // These two variables are used to control the concurrency and batch size of the reorganization process. // They can be adjusted dynamically through `admin alter ddl jobs` command. // Note: Don't get or set these two variables directly, use the functions instead. 
diff --git a/pkg/session/bootstrap.go b/pkg/session/bootstrap.go index 6e37280bd4e51..92438f8618a0a 100644 --- a/pkg/session/bootstrap.go +++ b/pkg/session/bootstrap.go @@ -600,6 +600,7 @@ const ( target_scope VARCHAR(256) DEFAULT "", error BLOB, modify_params json, + max_node_count INT DEFAULT 0, key(state), UNIQUE KEY task_key(task_key) );` @@ -622,6 +623,7 @@ const ( target_scope VARCHAR(256) DEFAULT "", error BLOB, modify_params json, + max_node_count INT DEFAULT 0, key(state), UNIQUE KEY task_key(task_key) );` @@ -1248,11 +1250,14 @@ const ( // version 242 // insert `cluster_id` into the `mysql.tidb` table. version242 = 242 + + // Add max_node_count column to tidb_global_task and tidb_global_task_history. + version243 = 243 ) // currentBootstrapVersion is defined as a variable, so we can modify its value for testing. // please make sure this is the largest version -var currentBootstrapVersion int64 = version242 +var currentBootstrapVersion int64 = version243 // DDL owner key's expired time is ManagerSessionTTL seconds, we should wait the time and give more time to have a chance to finish it. var internalSQLTimeout = owner.ManagerSessionTTL + 15 @@ -1430,6 +1435,7 @@ var ( upgradeToVer240, upgradeToVer241, upgradeToVer242, + upgradeToVer243, } ) @@ -3349,6 +3355,14 @@ func upgradeToVer242(s sessiontypes.Session, ver int64) { writeClusterID(s) } +func upgradeToVer243(s sessiontypes.Session, ver int64) { + if ver >= version243 { + return + } + doReentrantDDL(s, "ALTER TABLE mysql.tidb_global_task ADD COLUMN max_node_count INT DEFAULT 0 AFTER `modify_params`;", infoschema.ErrColumnExists) + doReentrantDDL(s, "ALTER TABLE mysql.tidb_global_task_history ADD COLUMN max_node_count INT DEFAULT 0 AFTER `modify_params`;", infoschema.ErrColumnExists) +} + // initGlobalVariableIfNotExists initialize a global variable with specific val if it does not exist. func initGlobalVariableIfNotExists(s sessiontypes.Session, name string, val any) { ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnBootstrap) diff --git a/pkg/sessionctx/vardef/tidb_vars.go b/pkg/sessionctx/vardef/tidb_vars.go index 6efe98cb3ffd7..bae379b47b70d 100644 --- a/pkg/sessionctx/vardef/tidb_vars.go +++ b/pkg/sessionctx/vardef/tidb_vars.go @@ -1054,6 +1054,8 @@ const ( TiDBAutoAnalyzeConcurrency = "tidb_auto_analyze_concurrency" // TiDBEnableDistTask indicates whether to enable the distributed execute background tasks(For example DDL, Import etc). TiDBEnableDistTask = "tidb_enable_dist_task" + // TiDBMaxDistTaskNodes indicates the max node count that could be used by distributed execution framework. + TiDBMaxDistTaskNodes = "tidb_max_dist_task_nodes" // TiDBEnableFastCreateTable indicates whether to enable the fast create table feature. TiDBEnableFastCreateTable = "tidb_enable_fast_create_table" // TiDBGenerateBinaryPlan indicates whether binary plan should be generated in slow log and statements summary. 
@@ -1446,6 +1448,7 @@ const ( DefTiDBEnableWorkloadBasedLearning = false DefTiDBWorkloadBasedLearningInterval = 24 * time.Hour DefTiDBEnableDistTask = true + DefTiDBMaxDistTaskNodes = -1 DefTiDBEnableFastCreateTable = true DefTiDBSimplifiedMetrics = false DefTiDBEnablePaging = true diff --git a/pkg/sessionctx/variable/sysvar.go b/pkg/sessionctx/variable/sysvar.go index cbdf158594265..ec43def999045 100644 --- a/pkg/sessionctx/variable/sysvar.go +++ b/pkg/sessionctx/variable/sysvar.go @@ -3422,6 +3422,7 @@ var defaultSysVars = []*SysVar{ s.SharedLockPromotion = TiDBOptOn(val) return nil }}, + {Scope: vardef.ScopeGlobal | vardef.ScopeSession, Name: vardef.TiDBMaxDistTaskNodes, Value: strconv.Itoa(vardef.DefTiDBMaxDistTaskNodes), Type: vardef.TypeInt, MinValue: -1, MaxValue: 128}, {Scope: vardef.ScopeGlobal, Name: vardef.TiDBTSOClientRPCMode, Value: vardef.DefTiDBTSOClientRPCMode, Type: vardef.TypeEnum, PossibleValues: []string{vardef.TSOClientRPCModeDefault, vardef.TSOClientRPCModeParallel, vardef.TSOClientRPCModeParallelFast}, SetGlobal: func(_ context.Context, s *SessionVars, val string) error { return (*SetPDClientDynamicOption.Load())(vardef.TiDBTSOClientRPCMode, val)