From f5f41825eb28c804a850caad9ba3357a3e1d41d7 Mon Sep 17 00:00:00 2001 From: ByronHsu Date: Mon, 17 Apr 2023 08:36:46 -0700 Subject: [PATCH] Override primary container name instead of flyte generated name (#340) Signed-off-by: byhsu Co-authored-by: byhsu --- .../pluginmachinery/flytek8s/pod_helper.go | 8 ++++---- .../flytek8s/pod_helper_test.go | 18 +++++++++--------- .../k8s/kfoperators/common/common_operator.go | 6 ++---- go/tasks/plugins/k8s/kfoperators/mpi/mpi.go | 4 ++-- .../plugins/k8s/kfoperators/pytorch/pytorch.go | 4 ++-- .../k8s/kfoperators/tensorflow/tensorflow.go | 4 ++-- 6 files changed, 21 insertions(+), 23 deletions(-) diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/go/tasks/pluginmachinery/flytek8s/pod_helper.go index ee26ce4dc..2e4493ccc 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -255,20 +255,20 @@ func ApplyFlytePodConfiguration(ctx context.Context, tCtx pluginsCore.TaskExecut // ToK8sPodSpec builds a PodSpec and ObjectMeta based on the definition passed by the TaskExecutionContext. This // involves parsing the raw PodSpec definition and applying all Flyte configuration options. -func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, *metav1.ObjectMeta, error) { +func ToK8sPodSpec(ctx context.Context, tCtx pluginsCore.TaskExecutionContext) (*v1.PodSpec, *metav1.ObjectMeta, string, error) { // build raw PodSpec and ObjectMeta podSpec, objectMeta, primaryContainerName, err := BuildRawPod(ctx, tCtx) if err != nil { - return nil, nil, err + return nil, nil, "", err } // add flyte configuration podSpec, objectMeta, err = ApplyFlytePodConfiguration(ctx, tCtx, podSpec, objectMeta, primaryContainerName) if err != nil { - return nil, nil, err + return nil, nil, "", err } - return podSpec, objectMeta, nil + return podSpec, objectMeta, primaryContainerName, nil } // getBasePodTemplate attempts to retrieve the PodTemplate to use as the base for k8s Pod configuration. This value can diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index 58a5e012a..bb99bafb5 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -324,7 +324,7 @@ func toK8sPodInterruptible(t *testing.T) { }, }) - p, _, err := ToK8sPodSpec(ctx, x) + p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.Len(t, p.Tolerations, 2) assert.Equal(t, "x/flyte", p.Tolerations[1].Key) @@ -391,7 +391,7 @@ func TestToK8sPod(t *testing.T) { }, }) - p, _, err := ToK8sPodSpec(ctx, x) + p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.Equal(t, len(p.Tolerations), 1) }) @@ -408,7 +408,7 @@ func TestToK8sPod(t *testing.T) { }, }) - p, _, err := ToK8sPodSpec(ctx, x) + p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.Equal(t, len(p.Tolerations), 0) assert.Equal(t, "some-acceptable-name", p.Containers[0].Name) @@ -435,7 +435,7 @@ func TestToK8sPod(t *testing.T) { DefaultMemoryRequest: resource.MustParse("1024Mi"), })) - p, _, err := ToK8sPodSpec(ctx, x) + p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.Equal(t, 1, len(p.NodeSelector)) assert.Equal(t, "myScheduler", p.SchedulerName) @@ -452,7 +452,7 @@ func TestToK8sPod(t *testing.T) { })) x := dummyExecContext(&v1.ResourceRequirements{}) - p, _, err := ToK8sPodSpec(ctx, x) + p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.NotNil(t, p.SecurityContext) assert.Equal(t, *p.SecurityContext.RunAsGroup, v) @@ -464,7 +464,7 @@ func TestToK8sPod(t *testing.T) { EnableHostNetworkingPod: &enabled, })) x := dummyExecContext(&v1.ResourceRequirements{}) - p, _, err := ToK8sPodSpec(ctx, x) + p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.True(t, p.HostNetwork) }) @@ -475,7 +475,7 @@ func TestToK8sPod(t *testing.T) { EnableHostNetworkingPod: &enabled, })) x := dummyExecContext(&v1.ResourceRequirements{}) - p, _, err := ToK8sPodSpec(ctx, x) + p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.False(t, p.HostNetwork) }) @@ -483,7 +483,7 @@ func TestToK8sPod(t *testing.T) { t.Run("skipSettingHostNetwork", func(t *testing.T) { assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{})) x := dummyExecContext(&v1.ResourceRequirements{}) - p, _, err := ToK8sPodSpec(ctx, x) + p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.False(t, p.HostNetwork) }) @@ -517,7 +517,7 @@ func TestToK8sPod(t *testing.T) { })) x := dummyExecContext(&v1.ResourceRequirements{}) - p, _, err := ToK8sPodSpec(ctx, x) + p, _, _, err := ToK8sPodSpec(ctx, x) assert.NoError(t, err) assert.NotNil(t, p.DNSConfig) assert.Equal(t, []string{"8.8.8.8", "8.8.4.4"}, p.DNSConfig.Nameservers) diff --git a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go index 41eb89644..88419b64c 100644 --- a/go/tasks/plugins/k8s/kfoperators/common/common_operator.go +++ b/go/tasks/plugins/k8s/kfoperators/common/common_operator.go @@ -166,17 +166,15 @@ func GetLogs(taskType string, name string, namespace string, return taskLogs, nil } -func OverrideDefaultContainerName(taskCtx pluginsCore.TaskExecutionContext, podSpec *v1.PodSpec, - defaultContainerName string) { +func OverridePrimaryContainerName(podSpec *v1.PodSpec, primaryContainerName string, defaultContainerName string) { // Pytorch operator forces pod to have container named 'pytorch' // https://github.com/kubeflow/pytorch-operator/blob/037cd1b18eb77f657f2a4bc8a8334f2a06324b57/pkg/apis/pytorch/validation/validation.go#L54-L62 // Tensorflow operator forces pod to have container named 'tensorflow' // https://github.com/kubeflow/tf-operator/blob/984adc287e6fe82841e4ca282dc9a2cbb71e2d4a/pkg/apis/tensorflow/validation/validation.go#L55-L63 // hence we have to override the name set here // https://github.com/flyteorg/flyteplugins/blob/209c52d002b4e6a39be5d175bc1046b7e631c153/go/tasks/pluginmachinery/flytek8s/container_helper.go#L116 - flyteDefaultContainerName := taskCtx.TaskExecutionMetadata().GetTaskExecutionID().GetGeneratedName() for idx, c := range podSpec.Containers { - if c.Name == flyteDefaultContainerName { + if c.Name == primaryContainerName { podSpec.Containers[idx].Name = defaultContainerName return } diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go index f605088e5..eaf6c3780 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi.go @@ -62,11 +62,11 @@ func (mpiOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx plu launcherReplicas := mpiTaskExtraArgs.GetNumLauncherReplicas() slots := mpiTaskExtraArgs.GetSlots() - podSpec, objectMeta, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } - common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.MPIJobDefaultContainerName) + common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.MPIJobDefaultContainerName) // workersPodSpec is deepCopy of podSpec submitted by flyte // WorkerPodSpec doesn't need any Argument & command. It will be trigger from launcher pod diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 287590ed5..71550c4bc 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -62,11 +62,11 @@ func (pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskCtx return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - podSpec, objectMeta, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } - common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.PytorchJobDefaultContainerName) + common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.PytorchJobDefaultContainerName) workers := pytorchTaskExtraArgs.GetWorkers() if workers == 0 { diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go index f03bbab5e..b5a5a675f 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow.go @@ -62,11 +62,11 @@ func (tensorflowOperatorResourceHandler) BuildResource(ctx context.Context, task return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "invalid TaskSpecification [%v], Err: [%v]", taskTemplate.GetCustom(), err.Error()) } - podSpec, objectMeta, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) + podSpec, objectMeta, primaryContainerName, err := flytek8s.ToK8sPodSpec(ctx, taskCtx) if err != nil { return nil, flyteerr.Errorf(flyteerr.BadTaskSpecification, "Unable to create pod spec: [%v]", err.Error()) } - common.OverrideDefaultContainerName(taskCtx, podSpec, kubeflowv1.TFJobDefaultContainerName) + common.OverridePrimaryContainerName(podSpec, primaryContainerName, kubeflowv1.TFJobDefaultContainerName) workers := tensorflowTaskExtraArgs.GetWorkers() psReplicas := tensorflowTaskExtraArgs.GetPsReplicas()