diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
index ab319561cf..3d51ac183d 100644
--- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
+++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
@@ -972,113 +972,34 @@ func TestBuildResourcePytorchV1WithRunPolicy(t *testing.T) {
 	}
 }
 
-func TestBuildResourcePytorchV1WithOnlyWorkerSpec(t *testing.T) {
+func TestBuildResourcePytorchV1WithDifferentWorkersNumber(t *testing.T) {
 	taskConfigs := []*kfplugins.DistributedPyTorchTrainingTask{
 		{
+			// Test case 1: Zero workers - should only have master
 			WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
-				Replicas: 100,
+				Replicas: 0,
+			},
+			MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
+				Image: testImageMaster,
 				Resources: &core.Resources{
 					Requests: []*core.Resources_ResourceEntry{
-						{Name: core.Resources_CPU, Value: "1024m"},
-						{Name: core.Resources_MEMORY, Value: "1Gi"},
+						{Name: core.Resources_CPU, Value: "250m"},
+						{Name: core.Resources_MEMORY, Value: "250Mi"},
 					},
 					Limits: []*core.Resources_ResourceEntry{
-						{Name: core.Resources_CPU, Value: "2048m"},
-						{Name: core.Resources_MEMORY, Value: "2Gi"},
+						{Name: core.Resources_CPU, Value: "500m"},
+						{Name: core.Resources_MEMORY, Value: "500Mi"},
					},
 				},
 			},
 		},
 		{
+			// Test case 2: One worker - should have both master and worker
 			WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
-				Common: &kfplugins.CommonReplicaSpec{
-					Replicas: 100,
-					Resources: &core.Resources{
-						Requests: []*core.Resources_ResourceEntry{
-							{Name: core.Resources_CPU, Value: "1024m"},
-							{Name: core.Resources_MEMORY, Value: "1Gi"},
-						},
-						Limits: []*core.Resources_ResourceEntry{
-							{Name: core.Resources_CPU, Value: "2048m"},
-							{Name: core.Resources_MEMORY, Value: "2Gi"},
-						},
-					},
-				},
-			},
-		},
-	}
-
-	for _, taskConfig := range taskConfigs {
-		// Master Replica should use resource from task override if not set
-		taskOverrideResourceRequirements := &corev1.ResourceRequirements{
-			Limits: corev1.ResourceList{
-				corev1.ResourceCPU:         resource.MustParse("1000m"),
-				corev1.ResourceMemory:      resource.MustParse("1Gi"),
-				flytek8s.ResourceNvidiaGPU: resource.MustParse("1"),
-			},
-			Requests: corev1.ResourceList{
-				corev1.ResourceCPU:         resource.MustParse("100m"),
-				corev1.ResourceMemory:      resource.MustParse("512Mi"),
-				flytek8s.ResourceNvidiaGPU: resource.MustParse("1"),
-			},
-		}
-
-		workerResourceRequirements := &corev1.ResourceRequirements{
-			Requests: corev1.ResourceList{
-				corev1.ResourceCPU:    resource.MustParse("1024m"),
-				corev1.ResourceMemory: resource.MustParse("1Gi"),
-			},
-			Limits: corev1.ResourceList{
-				corev1.ResourceCPU:    resource.MustParse("2048m"),
-				corev1.ResourceMemory: resource.MustParse("2Gi"),
+				Replicas: 1,
 			},
-		}
-
-		pytorchResourceHandler := pytorchOperatorResourceHandler{}
-
-		taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig)
-		taskTemplate.TaskTypeVersion = 1
-
-		res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
-		assert.NoError(t, err)
-		assert.NotNil(t, res)
-
-		pytorchJob, ok := res.(*kubeflowv1.PyTorchJob)
-		assert.True(t, ok)
-
-		assert.Equal(t, int32(100), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas)
-		assert.Equal(t, int32(1), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Replicas)
-
-		assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Image)
-		assert.Equal(t, testImage, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Image)
-
-		assert.Equal(t, *taskOverrideResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Containers[0].Resources)
-		assert.Equal(t, *workerResourceRequirements, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers[0].Resources)
-
-		assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].RestartPolicy)
-		assert.Equal(t, commonOp.RestartPolicyNever, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].RestartPolicy)
-
-		assert.Nil(t, pytorchJob.Spec.ElasticPolicy)
-	}
-}
-
-func TestBuildResourcePytorchV1ResourceTolerations(t *testing.T) {
-	gpuToleration := corev1.Toleration{
-		Key:      "nvidia.com/gpu",
-		Value:    "present",
-		Operator: corev1.TolerationOpEqual,
-		Effect:   corev1.TaintEffectNoSchedule,
-	}
-	assert.NoError(t, flytek8sConfig.SetK8sPluginConfig(&flytek8sConfig.K8sPluginConfig{
-		GpuResourceName: flytek8s.ResourceNvidiaGPU,
-		ResourceTolerations: map[corev1.ResourceName][]corev1.Toleration{
-			flytek8s.ResourceNvidiaGPU: {gpuToleration},
-		},
-	}))
-
-	taskConfigs := []*kfplugins.DistributedPyTorchTrainingTask{
-		{
 			MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
+				Image: testImageMaster,
 				Resources: &core.Resources{
 					Requests: []*core.Resources_ResourceEntry{
 						{Name: core.Resources_CPU, Value: "250m"},
@@ -1090,149 +1011,49 @@ func TestBuildResourcePytorchV1ResourceTolerations(t *testing.T) {
 					},
 				},
 			},
-			WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
-				Replicas: 100,
-				Resources: &core.Resources{
-					Requests: []*core.Resources_ResourceEntry{
-						{Name: core.Resources_CPU, Value: "1024m"},
-						{Name: core.Resources_MEMORY, Value: "1Gi"},
-						{Name: core.Resources_GPU, Value: "1"},
-					},
-					Limits: []*core.Resources_ResourceEntry{
-						{Name: core.Resources_CPU, Value: "2048m"},
-						{Name: core.Resources_MEMORY, Value: "2Gi"},
-						{Name: core.Resources_GPU, Value: "1"},
-					},
-				},
-			},
-		},
-		{
-			MasterReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
-				Common: &kfplugins.CommonReplicaSpec{
-					Resources: &core.Resources{
-						Requests: []*core.Resources_ResourceEntry{
-							{Name: core.Resources_CPU, Value: "250m"},
-							{Name: core.Resources_MEMORY, Value: "250Mi"},
-						},
-						Limits: []*core.Resources_ResourceEntry{
-							{Name: core.Resources_CPU, Value: "500m"},
-							{Name: core.Resources_MEMORY, Value: "500Mi"},
-						},
-					},
-				},
-			},
-			WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
-				Common: &kfplugins.CommonReplicaSpec{
-					Replicas: 100,
-					Resources: &core.Resources{
-						Requests: []*core.Resources_ResourceEntry{
-							{Name: core.Resources_CPU, Value: "1024m"},
-							{Name: core.Resources_MEMORY, Value: "1Gi"},
-							{Name: core.Resources_GPU, Value: "1"},
-						},
-						Limits: []*core.Resources_ResourceEntry{
-							{Name: core.Resources_CPU, Value: "2048m"},
-							{Name: core.Resources_MEMORY, Value: "2Gi"},
-							{Name: core.Resources_GPU, Value: "1"},
-						},
-					},
-				},
-			},
 		},
 	}
 
-	for _, taskConfig := range taskConfigs {
-		pytorchResourceHandler := pytorchOperatorResourceHandler{}
-
-		taskTemplate := dummyPytorchTaskTemplate("job4", taskConfig)
-		taskTemplate.TaskTypeVersion = 1
-
-		res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
-		assert.NoError(t, err)
-		assert.NotNil(t, res)
-
-		pytorchJob, ok := res.(*kubeflowv1.PyTorchJob)
-		assert.True(t, ok)
-
-		assert.NotContains(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster].Template.Spec.Tolerations, gpuToleration)
-		assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Tolerations, gpuToleration)
-	}
-}
-
-func TestBuildResourcePytorchV1WithElastic(t *testing.T) {
-	taskConfigs := []*kfplugins.DistributedPyTorchTrainingTask{
-		{
-			WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
-				Replicas: 2,
-			},
-			ElasticConfig: &kfplugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"},
-		},
-		{
-			WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
-				Common: &kfplugins.CommonReplicaSpec{
-					Replicas: 2,
-				},
-			},
-			ElasticConfig: &kfplugins.ElasticConfig{MinReplicas: 1, MaxReplicas: 2, NprocPerNode: 4, RdzvBackend: "c10d"},
-		},
-	}
-
-	for _, taskConfig := range taskConfigs {
-		taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig)
-		taskTemplate.TaskTypeVersion = 1
-
-		pytorchResourceHandler := pytorchOperatorResourceHandler{}
-		resource, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
-		assert.NoError(t, err)
-		assert.NotNil(t, resource)
+	for i, taskConfig := range taskConfigs {
+		t.Run(fmt.Sprintf("Case %d", i+1), func(t *testing.T) {
+			pytorchResourceHandler := pytorchOperatorResourceHandler{}
 
-		pytorchJob, ok := resource.(*kubeflowv1.PyTorchJob)
-		assert.True(t, ok)
-		assert.Equal(t, int32(2), *pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Replicas)
-		assert.NotNil(t, pytorchJob.Spec.ElasticPolicy)
-		assert.Equal(t, int32(1), *pytorchJob.Spec.ElasticPolicy.MinReplicas)
-		assert.Equal(t, int32(2), *pytorchJob.Spec.ElasticPolicy.MaxReplicas)
-		assert.Equal(t, int32(4), *pytorchJob.Spec.ElasticPolicy.NProcPerNode)
-		assert.Equal(t, kubeflowv1.RDZVBackend("c10d"), *pytorchJob.Spec.ElasticPolicy.RDZVBackend)
+			taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig)
+			taskTemplate.TaskTypeVersion = 1
 
-		assert.Equal(t, 1, len(pytorchJob.Spec.PyTorchReplicaSpecs))
-		assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker)
+			res, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
+			assert.NoError(t, err)
+			assert.NotNil(t, res)
 
-		var hasContainerWithDefaultPytorchName = false
+			pytorchJob, ok := res.(*kubeflowv1.PyTorchJob)
+			assert.True(t, ok)
 
-		for _, container := range pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker].Template.Spec.Containers {
-			if container.Name == kubeflowv1.PytorchJobDefaultContainerName {
-				hasContainerWithDefaultPytorchName = true
+			if taskConfig.WorkerReplicas.Replicas == 0 {
+				// Should only contain master spec
+				assert.Equal(t, 1, len(pytorchJob.Spec.PyTorchReplicaSpecs))
+				assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeMaster)
+				assert.NotContains(t, pytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker)
+
+				// Verify master spec details
+				masterSpec := pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster]
+				assert.Equal(t, int32(1), *masterSpec.Replicas)
+				assert.Equal(t, testImageMaster, masterSpec.Template.Spec.Containers[0].Image)
+			} else {
+				// Should contain both master and worker specs
+				assert.Equal(t, 2, len(pytorchJob.Spec.PyTorchReplicaSpecs))
+				assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeMaster)
+				assert.Contains(t, pytorchJob.Spec.PyTorchReplicaSpecs, kubeflowv1.PyTorchJobReplicaTypeWorker)
+
+				// Verify master spec details
+				masterSpec := pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeMaster]
+				assert.Equal(t, int32(1), *masterSpec.Replicas)
+				assert.Equal(t, testImageMaster, masterSpec.Template.Spec.Containers[0].Image)
+
+				// Verify worker spec details
+				workerSpec := pytorchJob.Spec.PyTorchReplicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker]
+				assert.Equal(t, int32(1), *workerSpec.Replicas)
 			}
-		}
-
-		assert.True(t, hasContainerWithDefaultPytorchName)
-	}
-}
-
-func TestBuildResourcePytorchV1WithZeroWorker(t *testing.T) {
-	taskConfigs := []*kfplugins.DistributedPyTorchTrainingTask{
-		{
-			WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
-				Replicas: 0,
-			},
-		},
-		{
-			WorkerReplicas: &kfplugins.DistributedPyTorchTrainingReplicaSpec{
-				Common: &kfplugins.CommonReplicaSpec{
-					Replicas: 0,
-				},
-			},
-		},
-	}
-
-	for _, taskConfig := range taskConfigs {
-		pytorchResourceHandler := pytorchOperatorResourceHandler{}
-
-		taskTemplate := dummyPytorchTaskTemplate("job5", taskConfig)
-		taskTemplate.TaskTypeVersion = 1
-		_, err := pytorchResourceHandler.BuildResource(context.TODO(), dummyPytorchTaskContext(taskTemplate, resourceRequirements, nil, "", k8s.PluginState{}))
-		assert.Error(t, err)
+		})
 	}
 }
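The consolidated test encodes a behavior change: a zero-worker task now builds a master-only PyTorchJob, whereas the deleted TestBuildResourcePytorchV1WithZeroWorker asserted that BuildResource returns an error. Below is a minimal, self-contained sketch of the replica-map shape the new assertions check; the replicaType alias and buildReplicaSpecs helper are hypothetical stand-ins for illustration, not the plugin's actual API or the real kubeflowv1 types.

```go
package main

import "fmt"

// replicaType stands in for kubeflowv1's replica-type keys; the real test
// uses kubeflowv1.PyTorchJobReplicaTypeMaster and ...Worker.
type replicaType string

const (
	replicaTypeMaster replicaType = "Master"
	replicaTypeWorker replicaType = "Worker"
)

// buildReplicaSpecs is a hypothetical helper illustrating the shape the new
// assertions expect: the master always has exactly one replica, and the
// worker entry is omitted entirely when workerReplicas is 0.
func buildReplicaSpecs(workerReplicas int32) map[replicaType]int32 {
	specs := map[replicaType]int32{replicaTypeMaster: 1}
	if workerReplicas > 0 {
		specs[replicaTypeWorker] = workerReplicas
	}
	return specs
}

func main() {
	fmt.Println(buildReplicaSpecs(0)) // map[Master:1] -> len == 1, no Worker key
	fmt.Println(buildReplicaSpecs(1)) // map[Master:1 Worker:1] -> len == 2
}
```

The new subtests can be exercised in isolation with standard `go test -run` filtering, e.g. `go test -run TestBuildResourcePytorchV1WithDifferentWorkersNumber ./flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/`. Note that the added code calls fmt.Sprintf, so the file's import block must include "fmt"; the import hunk is not shown in this diff.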