diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go index 4ba21a4620..ce1aca9ee6 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch.go @@ -146,17 +146,16 @@ func (p pytorchOperatorResourceHandler) BuildResource(ctx context.Context, taskC } jobSpec := kubeflowv1.PyTorchJobSpec{} - if *workerReplicaSpec.Replicas <= 0 { - replicaSpecs := map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ - kubeflowv1.PyTorchJobReplicaTypeMaster: masterReplicaSpec, - } - if workerReplicaSpec != nil { - replicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker] = workerReplicaSpec - } - jobSpec = kubeflowv1.PyTorchJobSpec{ - PyTorchReplicaSpecs: replicaSpecs, - RunPolicy: runPolicy, - } + replicaSpecs := map[commonOp.ReplicaType]*commonOp.ReplicaSpec{ + kubeflowv1.PyTorchJobReplicaTypeMaster: masterReplicaSpec, + } + if *workerReplicaSpec.Replicas > 0 { + replicaSpecs[kubeflowv1.PyTorchJobReplicaTypeWorker] = workerReplicaSpec + } + + jobSpec = kubeflowv1.PyTorchJobSpec{ + PyTorchReplicaSpecs: replicaSpecs, + RunPolicy: runPolicy, } if elasticPolicy != nil {