Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

built-in interpreter support statefulset #3009

Merged
merged 1 commit into from
Jan 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions pkg/resourceinterpreter/defaultinterpreter/replica.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type replicaInterpreter func(object *unstructured.Unstructured) (int32, *workv1a
func getAllDefaultReplicaInterpreter() map[schema.GroupVersionKind]replicaInterpreter {
s := make(map[schema.GroupVersionKind]replicaInterpreter)
s[appsv1.SchemeGroupVersion.WithKind(util.DeploymentKind)] = deployReplica
s[appsv1.SchemeGroupVersion.WithKind(util.StatefulSetKind)] = statefulSetReplica
s[batchv1.SchemeGroupVersion.WithKind(util.JobKind)] = jobReplica
return s
}
Expand All @@ -38,6 +39,22 @@ func deployReplica(object *unstructured.Unstructured) (int32, *workv1alpha2.Repl
return replica, requirement, nil
}

func statefulSetReplica(object *unstructured.Unstructured) (int32, *workv1alpha2.ReplicaRequirements, error) {
sts := &appsv1.StatefulSet{}
if err := helper.ConvertToTypedObject(object, sts); err != nil {
klog.Errorf("Failed to convert object(%s), err", object.GroupVersionKind().String(), err)
return 0, nil, err
}

var replica int32
if sts.Spec.Replicas != nil {
replica = *sts.Spec.Replicas
}
requirement := helper.GenerateReplicaRequirements(&sts.Spec.Template)

return replica, requirement, nil
}

func jobReplica(object *unstructured.Unstructured) (int32, *workv1alpha2.ReplicaRequirements, error) {
job := &batchv1.Job{}
err := helper.ConvertToTypedObject(object, job)
Expand Down
8 changes: 8 additions & 0 deletions pkg/resourceinterpreter/defaultinterpreter/revisereplica.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type reviseReplicaInterpreter func(object *unstructured.Unstructured, replica in
func getAllDefaultReviseReplicaInterpreter() map[schema.GroupVersionKind]reviseReplicaInterpreter {
s := make(map[schema.GroupVersionKind]reviseReplicaInterpreter)
s[appsv1.SchemeGroupVersion.WithKind(util.DeploymentKind)] = reviseDeploymentReplica
s[appsv1.SchemeGroupVersion.WithKind(util.StatefulSetKind)] = reviseStatefulSetReplica
s[batchv1.SchemeGroupVersion.WithKind(util.JobKind)] = reviseJobReplica
return s
}
Expand All @@ -26,6 +27,13 @@ func reviseDeploymentReplica(object *unstructured.Unstructured, replica int64) (
return object, nil
}

func reviseStatefulSetReplica(object *unstructured.Unstructured, replica int64) (*unstructured.Unstructured, error) {
if err := helper.ApplyReplica(object, replica, util.ReplicasField); err != nil {
return nil, err
}
return object, nil
}

func reviseJobReplica(object *unstructured.Unstructured, replica int64) (*unstructured.Unstructured, error) {
if err := helper.ApplyReplica(object, replica, util.ParallelismField); err != nil {
return nil, err
Expand Down
74 changes: 74 additions & 0 deletions pkg/resourceinterpreter/defaultinterpreter/revisereplica_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,77 @@ func TestReviseJobReplica(t *testing.T) {
})
}
}

func TestReviseStatefulSetReplica(t *testing.T) {
tests := []struct {
name string
object *unstructured.Unstructured
replica int64
expected *unstructured.Unstructured
expectError bool
}{
{
name: "StatefulSet .spec.replicas accessor error, expected int64",
object: &unstructured.Unstructured{
Object: map[string]interface{}{
"apiVersion": "apps/v1",
"kind": "StatefulSet",
"metadata": map[string]interface{}{
"name": "fake-statefulset",
},
"spec": map[string]interface{}{
"replicas": 1,
},
},
},
replica: 3,
expectError: true,
},
{
name: "revise statefulset replica",
object: &unstructured.Unstructured{
Object: map[string]interface{}{
"apiVersion": "apps/v1",
"kind": "StatefulSet",
"metadata": map[string]interface{}{
"name": "fake-statefulset",
},
"spec": map[string]interface{}{
"replicas": int64(1),
},
},
},
replica: 3,
expected: &unstructured.Unstructured{
Object: map[string]interface{}{
"apiVersion": "apps/v1",
"kind": "StatefulSet",
"metadata": map[string]interface{}{
"name": "fake-statefulset",
},
"spec": map[string]interface{}{
"replicas": int64(3),
},
},
},
expectError: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
res, err := reviseStatefulSetReplica(tt.object, tt.replica)
if err == nil && tt.expectError == true {
t.Fatal("expect an error but got none")
}
if err != nil && tt.expectError != true {
t.Fatalf("expect no error but got: %v", err)
}
if err == nil && tt.expectError == false {
if !reflect.DeepEqual(res, tt.expected) {
t.Errorf("reviseStatefulSetReplica() = %v, want %v", res, tt.expected)
}
}
})
}
}