From 58408c717e69c5331a117033e8ecd4b407c31be5 Mon Sep 17 00:00:00 2001 From: chaunceyjiang Date: Thu, 29 Dec 2022 18:38:38 +0800 Subject: [PATCH] built-in interpreter support statefulset Signed-off-by: chaunceyjiang --- .../defaultinterpreter/replica.go | 17 +++++ .../defaultinterpreter/revisereplica.go | 8 ++ .../defaultinterpreter/revisereplica_test.go | 74 +++++++++++++++++++ 3 files changed, 99 insertions(+) diff --git a/pkg/resourceinterpreter/defaultinterpreter/replica.go b/pkg/resourceinterpreter/defaultinterpreter/replica.go index d52c18c5fb6d..a22401245454 100644 --- a/pkg/resourceinterpreter/defaultinterpreter/replica.go +++ b/pkg/resourceinterpreter/defaultinterpreter/replica.go @@ -18,6 +18,7 @@ type replicaInterpreter func(object *unstructured.Unstructured) (int32, *workv1a func getAllDefaultReplicaInterpreter() map[schema.GroupVersionKind]replicaInterpreter { s := make(map[schema.GroupVersionKind]replicaInterpreter) s[appsv1.SchemeGroupVersion.WithKind(util.DeploymentKind)] = deployReplica + s[appsv1.SchemeGroupVersion.WithKind(util.StatefulSetKind)] = statefulSetReplica s[batchv1.SchemeGroupVersion.WithKind(util.JobKind)] = jobReplica return s } @@ -38,6 +39,22 @@ func deployReplica(object *unstructured.Unstructured) (int32, *workv1alpha2.Repl return replica, requirement, nil } +func statefulSetReplica(object *unstructured.Unstructured) (int32, *workv1alpha2.ReplicaRequirements, error) { + sts := &appsv1.StatefulSet{} + if err := helper.ConvertToTypedObject(object, sts); err != nil { + klog.Errorf("Failed to convert object(%s), err", object.GroupVersionKind().String(), err) + return 0, nil, err + } + + var replica int32 + if sts.Spec.Replicas != nil { + replica = *sts.Spec.Replicas + } + requirement := helper.GenerateReplicaRequirements(&sts.Spec.Template) + + return replica, requirement, nil +} + func jobReplica(object *unstructured.Unstructured) (int32, *workv1alpha2.ReplicaRequirements, error) { job := &batchv1.Job{} err := helper.ConvertToTypedObject(object, job) diff --git a/pkg/resourceinterpreter/defaultinterpreter/revisereplica.go b/pkg/resourceinterpreter/defaultinterpreter/revisereplica.go index e6770f09d54d..9de55b07f30c 100644 --- a/pkg/resourceinterpreter/defaultinterpreter/revisereplica.go +++ b/pkg/resourceinterpreter/defaultinterpreter/revisereplica.go @@ -15,6 +15,7 @@ type reviseReplicaInterpreter func(object *unstructured.Unstructured, replica in func getAllDefaultReviseReplicaInterpreter() map[schema.GroupVersionKind]reviseReplicaInterpreter { s := make(map[schema.GroupVersionKind]reviseReplicaInterpreter) s[appsv1.SchemeGroupVersion.WithKind(util.DeploymentKind)] = reviseDeploymentReplica + s[appsv1.SchemeGroupVersion.WithKind(util.StatefulSetKind)] = reviseStatefulSetReplica s[batchv1.SchemeGroupVersion.WithKind(util.JobKind)] = reviseJobReplica return s } @@ -26,6 +27,13 @@ func reviseDeploymentReplica(object *unstructured.Unstructured, replica int64) ( return object, nil } +func reviseStatefulSetReplica(object *unstructured.Unstructured, replica int64) (*unstructured.Unstructured, error) { + if err := helper.ApplyReplica(object, replica, util.ReplicasField); err != nil { + return nil, err + } + return object, nil +} + func reviseJobReplica(object *unstructured.Unstructured, replica int64) (*unstructured.Unstructured, error) { if err := helper.ApplyReplica(object, replica, util.ParallelismField); err != nil { return nil, err diff --git a/pkg/resourceinterpreter/defaultinterpreter/revisereplica_test.go b/pkg/resourceinterpreter/defaultinterpreter/revisereplica_test.go index c7499ff11c6d..a6445023c4dc 100644 --- a/pkg/resourceinterpreter/defaultinterpreter/revisereplica_test.go +++ b/pkg/resourceinterpreter/defaultinterpreter/revisereplica_test.go @@ -154,3 +154,77 @@ func TestReviseJobReplica(t *testing.T) { }) } } + +func TestReviseStatefulSetReplica(t *testing.T) { + tests := []struct { + name string + object *unstructured.Unstructured + replica int64 + expected *unstructured.Unstructured + expectError bool + }{ + { + name: "StatefulSet .spec.replicas accessor error, expected int64", + object: &unstructured.Unstructured{ + Object: map[string]interface{}{ + "apiVersion": "apps/v1", + "kind": "StatefulSet", + "metadata": map[string]interface{}{ + "name": "fake-statefulset", + }, + "spec": map[string]interface{}{ + "replicas": 1, + }, + }, + }, + replica: 3, + expectError: true, + }, + { + name: "revise statefulset replica", + object: &unstructured.Unstructured{ + Object: map[string]interface{}{ + "apiVersion": "apps/v1", + "kind": "StatefulSet", + "metadata": map[string]interface{}{ + "name": "fake-statefulset", + }, + "spec": map[string]interface{}{ + "replicas": int64(1), + }, + }, + }, + replica: 3, + expected: &unstructured.Unstructured{ + Object: map[string]interface{}{ + "apiVersion": "apps/v1", + "kind": "StatefulSet", + "metadata": map[string]interface{}{ + "name": "fake-statefulset", + }, + "spec": map[string]interface{}{ + "replicas": int64(3), + }, + }, + }, + expectError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + res, err := reviseStatefulSetReplica(tt.object, tt.replica) + if err == nil && tt.expectError == true { + t.Fatal("expect an error but got none") + } + if err != nil && tt.expectError != true { + t.Fatalf("expect no error but got: %v", err) + } + if err == nil && tt.expectError == false { + if !reflect.DeepEqual(res, tt.expected) { + t.Errorf("reviseStatefulSetReplica() = %v, want %v", res, tt.expected) + } + } + }) + } +}