diff --git a/api/conf/conf.go b/api/conf/conf.go index b34c339db6..6ea9af02de 100644 --- a/api/conf/conf.go +++ b/api/conf/conf.go @@ -101,7 +101,11 @@ type Config struct { Authentication Authentication } +// TODO: it is just for integration tests, we should call "InitLog" explicitly when remove all handler's integration tests func init() { + InitConf() +} +func InitConf() { //go test if workDir := os.Getenv("APISIX_API_WORKDIR"); workDir != "" { WorkDir = workDir diff --git a/api/filter/logging_test.go b/api/filter/logging_test.go index 087d8b0e85..688befa525 100644 --- a/api/filter/logging_test.go +++ b/api/filter/logging_test.go @@ -17,30 +17,30 @@ package filter import ( - "net/http" - "net/http/httptest" - "testing" + "net/http" + "net/http/httptest" + "testing" - "github.com/gin-gonic/gin" - "github.com/stretchr/testify/assert" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/assert" - "github.com/apisix/manager-api/log" + "github.com/apisix/manager-api/log" ) func performRequest(r http.Handler, method, path string) *httptest.ResponseRecorder { - req := httptest.NewRequest(method, path, nil) - w := httptest.NewRecorder() - r.ServeHTTP(w, req) - return w + req := httptest.NewRequest(method, path, nil) + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + return w } func TestRequestLogHandler(t *testing.T) { - r := gin.New() - logger := log.GetLogger(log.AccessLog) - r.Use(RequestLogHandler(logger)) - r.GET("/", func(c *gin.Context) { - }) + r := gin.New() + logger := log.GetLogger(log.AccessLog) + r.Use(RequestLogHandler(logger)) + r.GET("/", func(c *gin.Context) { + }) - w := performRequest(r, "GET", "/") - assert.Equal(t, 200, w.Code) + w := performRequest(r, "GET", "/") + assert.Equal(t, 200, w.Code) } diff --git a/api/internal/core/store/store.go b/api/internal/core/store/store.go index d9c7e36915..ee3fa5038e 100644 --- a/api/internal/core/store/store.go +++ b/api/internal/core/store/store.go @@ -38,7 +38,7 @@ type Interface interface { Get(key string) (interface{}, error) List(input ListInput) (*ListOutput, error) Create(ctx context.Context, obj interface{}) error - Update(ctx context.Context, obj interface{}, createOnFail bool) error + Update(ctx context.Context, obj interface{}, createIfNotExist bool) error BatchDelete(ctx context.Context, keys []string) error } diff --git a/api/internal/core/store/store_mock.go b/api/internal/core/store/store_mock.go new file mode 100644 index 0000000000..258f083970 --- /dev/null +++ b/api/internal/core/store/store_mock.go @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package store + +import ( + "context" + "github.com/stretchr/testify/mock" +) + +type MockInterface struct { + mock.Mock +} + +func (m *MockInterface) Get(key string) (interface{}, error) { + ret := m.Mock.Called(key) + return ret.Get(0), ret.Error(1) +} + +func (m *MockInterface) List(input ListInput) (*ListOutput, error) { + ret := m.Called(input) + + var ( + r0 *ListOutput + r1 error + ) + + if rf, ok := ret.Get(0).(func(ListInput) *ListOutput); ok { + r0 = rf(input) + } else { + r0 = ret.Get(0).(*ListOutput) + } + r1 = ret.Error(1) + + return r0, r1 +} + +func (m *MockInterface) Create(ctx context.Context, obj interface{}) error { + ret := m.Mock.Called(ctx, obj) + return ret.Error(0) +} + +func (m *MockInterface) Update(ctx context.Context, obj interface{}, createOnFail bool) error { + ret := m.Mock.Called(ctx, obj, createOnFail) + return ret.Error(0) +} + +func (m *MockInterface) BatchDelete(ctx context.Context, keys []string) error { + ret := m.Mock.Called(ctx, keys) + return ret.Error(0) +} diff --git a/api/internal/handler/consumer/consumer.go b/api/internal/handler/consumer/consumer.go index a5854c583c..31b7af37f2 100644 --- a/api/internal/handler/consumer/consumer.go +++ b/api/internal/handler/consumer/consumer.go @@ -17,21 +17,17 @@ package consumer import ( - "fmt" - "net/http" "reflect" "strings" "github.com/gin-gonic/gin" "github.com/shiningrush/droplet" - "github.com/shiningrush/droplet/data" "github.com/shiningrush/droplet/wrapper" wgin "github.com/shiningrush/droplet/wrapper/gin" "github.com/apisix/manager-api/internal/core/entity" "github.com/apisix/manager-api/internal/core/store" "github.com/apisix/manager-api/internal/handler" - "github.com/apisix/manager-api/internal/utils" ) type Handler struct { @@ -56,7 +52,7 @@ func (h *Handler) ApplyRoute(r *gin.Engine) { r.PUT("/apisix/admin/consumers", wgin.Wraps(h.Update, wrapper.InputType(reflect.TypeOf(UpdateInput{})))) r.DELETE("/apisix/admin/consumers/:usernames", wgin.Wraps(h.BatchDelete, - wrapper.InputType(reflect.TypeOf(BatchDelete{})))) + wrapper.InputType(reflect.TypeOf(BatchDeleteInput{})))) } type GetInput struct { @@ -134,19 +130,9 @@ func (h *Handler) List(c droplet.Context) (interface{}, error) { func (h *Handler) Create(c droplet.Context) (interface{}, error) { input := c.Input().(*entity.Consumer) - if input.ID != nil && utils.InterfaceToString(input.ID) != input.Username { - return &data.SpecCodeResponse{StatusCode: http.StatusBadRequest}, - fmt.Errorf("consumer's id and username must be a same value") - } input.ID = input.Username - if _, ok := input.Plugins["jwt-auth"]; ok { - jwt := input.Plugins["jwt-auth"].(map[string]interface{}) - jwt["exp"] = 86400 - - input.Plugins["jwt-auth"] = jwt - } - + ensurePluginsDefValue(input.Plugins) if err := h.consumerStore.Create(c.Context(), input); err != nil { return handler.SpecCodeResponse(err), err } @@ -161,42 +147,34 @@ type UpdateInput struct { func (h *Handler) Update(c droplet.Context) (interface{}, error) { input := c.Input().(*UpdateInput) - if input.ID != nil && utils.InterfaceToString(input.ID) != input.Username { - return &data.SpecCodeResponse{StatusCode: http.StatusBadRequest}, - fmt.Errorf("consumer's id and username must be a same value") - } if input.Username != "" { input.Consumer.Username = input.Username } input.Consumer.ID = input.Consumer.Username - - if _, ok := input.Consumer.Plugins["jwt-auth"]; ok { - jwt := input.Consumer.Plugins["jwt-auth"].(map[string]interface{}) - jwt["exp"] = 86400 - - input.Consumer.Plugins["jwt-auth"] = jwt - } + ensurePluginsDefValue(input.Plugins) if err := h.consumerStore.Update(c.Context(), &input.Consumer, true); err != nil { - //if not exists, create - if err.Error() == fmt.Sprintf("key: %s is not found", input.Username) { - if err := h.consumerStore.Create(c.Context(), &input.Consumer); err != nil { - return handler.SpecCodeResponse(err), err - } - } else { - return handler.SpecCodeResponse(err), err - } + return handler.SpecCodeResponse(err), err } return nil, nil } -type BatchDelete struct { +func ensurePluginsDefValue(plugins map[string]interface{}) { + if plugins["jwt-auth"] != nil { + jwtAuth, ok := plugins["jwt-auth"].(map[string]interface{}) + if ok && jwtAuth["exp"] == nil { + jwtAuth["exp"] = 86400 + } + } +} + +type BatchDeleteInput struct { UserNames string `auto_read:"usernames,path"` } func (h *Handler) BatchDelete(c droplet.Context) (interface{}, error) { - input := c.Input().(*BatchDelete) + input := c.Input().(*BatchDeleteInput) if err := h.consumerStore.BatchDelete(c.Context(), strings.Split(input.UserNames, ",")); err != nil { return handler.SpecCodeResponse(err), err diff --git a/api/internal/handler/consumer/consumer_test.go b/api/internal/handler/consumer/consumer_test.go index ba02f28f8f..6a8215453a 100644 --- a/api/internal/handler/consumer/consumer_test.go +++ b/api/internal/handler/consumer/consumer_test.go @@ -18,232 +18,391 @@ package consumer import ( - "encoding/json" - - "testing" - "time" - - "github.com/shiningrush/droplet" - "github.com/stretchr/testify/assert" - - "github.com/apisix/manager-api/conf" + "context" + "fmt" "github.com/apisix/manager-api/internal/core/entity" - "github.com/apisix/manager-api/internal/core/storage" "github.com/apisix/manager-api/internal/core/store" + "github.com/shiningrush/droplet" + "github.com/shiningrush/droplet/data" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "net/http" + "testing" ) -func TestConsumer(t *testing.T) { - // init - err := storage.InitETCDClient(conf.ETCDConfig) - assert.Nil(t, err) - err = store.InitStores() - assert.Nil(t, err) +func TestHandler_Get(t *testing.T) { + tests := []struct { + caseDesc string + giveInput *GetInput + giveRet interface{} + giveErr error + wantErr error + wantGetKey string + wantRet interface{} + }{ + { + caseDesc: "normal", + giveInput: &GetInput{Username: "test"}, + wantGetKey: "test", + giveRet: "hello", + wantRet: "hello", + }, + { + caseDesc: "store get failed", + giveInput: &GetInput{Username: "failed key"}, + wantGetKey: "failed key", + giveErr: fmt.Errorf("get failed"), + wantErr: fmt.Errorf("get failed"), + wantRet: &data.SpecCodeResponse{ + StatusCode: http.StatusInternalServerError, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.caseDesc, func(t *testing.T) { + getCalled := true + mStore := &store.MockInterface{} + mStore.On("Get", mock.Anything).Run(func(args mock.Arguments) { + getCalled = true + assert.Equal(t, tc.wantGetKey, args.Get(0)) + }).Return(tc.giveRet, tc.giveErr) + + h := Handler{consumerStore: mStore} + ctx := droplet.NewContext() + ctx.SetInput(tc.giveInput) + ret, err := h.Get(ctx) + assert.True(t, getCalled) + assert.Equal(t, tc.wantRet, ret) + assert.Equal(t, tc.wantErr, err) + }) + } +} + +func TestHandler_List(t *testing.T) { + tests := []struct { + caseDesc string + giveInput *ListInput + giveData []*entity.Consumer + giveErr error + wantErr error + wantInput store.ListInput + wantRet interface{} + }{ + { + caseDesc: "list all condition", + giveInput: &ListInput{ + Username: "testUser", + Pagination: store.Pagination{ + PageSize: 10, + PageNumber: 10, + }, + }, + wantInput: store.ListInput{ + PageSize: 10, + PageNumber: 10, + }, + giveData: []*entity.Consumer{ + {Username: "user1"}, + {Username: "testUser"}, + {Username: "iam-testUser"}, + {Username: "testUser-is-me"}, + }, + wantRet: &store.ListOutput{ + Rows: []interface{}{ + &entity.Consumer{Username: "testUser"}, + &entity.Consumer{Username: "iam-testUser"}, + &entity.Consumer{Username: "testUser-is-me"}, + }, + TotalSize: 3, + }, + }, + { + caseDesc: "store list failed", + giveInput: &ListInput{ + Username: "testUser", + Pagination: store.Pagination{ + PageSize: 10, + PageNumber: 10, + }, + }, + wantInput: store.ListInput{ + PageSize: 10, + PageNumber: 10, + }, + giveData: []*entity.Consumer{}, + giveErr: fmt.Errorf("list failed"), + wantErr: fmt.Errorf("list failed"), + }, + } + + for _, tc := range tests { + t.Run(tc.caseDesc, func(t *testing.T) { + getCalled := true + mStore := &store.MockInterface{} + mStore.On("List", mock.Anything).Run(func(args mock.Arguments) { + getCalled = true + input := args.Get(0).(store.ListInput) + assert.Equal(t, tc.wantInput.PageSize, input.PageSize) + assert.Equal(t, tc.wantInput.PageNumber, input.PageNumber) + }).Return(func(input store.ListInput) *store.ListOutput { + var returnData []interface{} + for _, c := range tc.giveData { + if input.Predicate(c) { + returnData = append(returnData, c) + } + } + return &store.ListOutput{ + Rows: returnData, + TotalSize: len(returnData), + } + }, tc.giveErr) - handler := &Handler{ - consumerStore: store.GetStore(store.HubKeyConsumer), + h := Handler{consumerStore: mStore} + ctx := droplet.NewContext() + ctx.SetInput(tc.giveInput) + ret, err := h.List(ctx) + assert.True(t, getCalled) + assert.Equal(t, tc.wantRet, ret) + assert.Equal(t, tc.wantErr, err) + }) } - assert.NotNil(t, handler) - - //create consumer - ctx := droplet.NewContext() - consumer := &entity.Consumer{} - reqBody := `{ - "username": "jack", - "plugins": { - "limit-count": { - "count": 2, - "time_window": 60, - "rejected_code": 503, - "key": "remote_addr" - } - }, - "desc": "test description" - }` - err = json.Unmarshal([]byte(reqBody), consumer) - assert.Nil(t, err) - ctx.SetInput(consumer) - _, err = handler.Create(ctx) - assert.Nil(t, err) - - //create consumer 2 - consumer2 := &entity.Consumer{} - reqBody = `{ - "username": "pony", - "plugins": { - "limit-count": { - "count": 2, - "time_window": 60, - "rejected_code": 503, - "key": "remote_addr" - } +} + +func TestHandler_Create(t *testing.T) { + tests := []struct { + caseDesc string + giveInput *entity.Consumer + giveCtx context.Context + giveErr error + wantErr error + wantInput *entity.Consumer + wantRet interface{} + wantCalled bool + }{ + { + caseDesc: "normal", + giveInput: &entity.Consumer{ + Username: "name", + Plugins: map[string]interface{}{ + "jwt-auth": map[string]interface{}{}, + }, + }, + giveCtx: context.WithValue(context.Background(), "test", "value"), + wantInput: &entity.Consumer{ + BaseInfo: entity.BaseInfo{ + ID: "name", + }, + Username: "name", + Plugins: map[string]interface{}{ + "jwt-auth": map[string]interface{}{ + "exp": 86400, + }, + }, + }, + wantRet: nil, + wantCalled: true, }, - "desc": "test description" - }` - err = json.Unmarshal([]byte(reqBody), consumer2) - assert.Nil(t, err) - ctx.SetInput(consumer2) - _, err = handler.Create(ctx) - assert.Nil(t, err) - - //sleep - time.Sleep(time.Duration(100) * time.Millisecond) - - //get consumer - input := &GetInput{} - reqBody = `{"username": "jack"}` - err = json.Unmarshal([]byte(reqBody), input) - assert.Nil(t, err) - ctx.SetInput(input) - ret, err := handler.Get(ctx) - stored := ret.(*entity.Consumer) - assert.Nil(t, err) - assert.Equal(t, stored.ID, consumer.ID) - assert.Equal(t, stored.Username, consumer.Username) - - //update consumer - consumer3 := &UpdateInput{} - consumer3.Username = "pony" - reqBody = `{ - "username": "pony", - "plugins": { - "limit-count": { - "count": 2, - "time_window": 60, - "rejected_code": 503, - "key": "remote_addr" - } + { + caseDesc: "store create failed", + giveInput: &entity.Consumer{ + Username: "name", + Plugins: map[string]interface{}{ + "jwt-auth": map[string]interface{}{ + "exp": 5000, + }, + }, + }, + giveErr: fmt.Errorf("create failed"), + wantInput: &entity.Consumer{ + BaseInfo: entity.BaseInfo{ + ID: "name", + }, + Username: "name", + Plugins: map[string]interface{}{ + "jwt-auth": map[string]interface{}{ + "exp": 5000, + }, + }, + }, + wantErr: fmt.Errorf("create failed"), + wantRet: &data.SpecCodeResponse{ + StatusCode: http.StatusInternalServerError, + }, + wantCalled: true, }, - "desc": "test description2" - }` - err = json.Unmarshal([]byte(reqBody), consumer3) - assert.Nil(t, err) - ctx.SetInput(consumer3) - _, err = handler.Update(ctx) - assert.Nil(t, err) - - //sleep - time.Sleep(time.Duration(100) * time.Millisecond) - - //check update - input3 := &GetInput{} - reqBody = `{"username": "pony"}` - err = json.Unmarshal([]byte(reqBody), input3) - assert.Nil(t, err) - ctx.SetInput(input3) - ret3, err := handler.Get(ctx) - stored3 := ret3.(*entity.Consumer) - assert.Nil(t, err) - assert.Equal(t, stored3.Desc, "test description2") //consumer3.Desc) - assert.Equal(t, stored3.Username, consumer3.Username) - - //list page 1 - listInput := &ListInput{} - reqBody = `{"page_size": 1, "page": 1}` - err = json.Unmarshal([]byte(reqBody), listInput) - assert.Nil(t, err) - ctx.SetInput(listInput) - retPage1, err := handler.List(ctx) - assert.Nil(t, err) - dataPage1 := retPage1.(*store.ListOutput) - assert.Equal(t, len(dataPage1.Rows), 1) - - //list page 2 - listInput2 := &ListInput{} - reqBody = `{"page_size": 1, "page": 2}` - err = json.Unmarshal([]byte(reqBody), listInput2) - assert.Nil(t, err) - ctx.SetInput(listInput2) - retPage2, err := handler.List(ctx) - assert.Nil(t, err) - dataPage2 := retPage2.(*store.ListOutput) - assert.Equal(t, len(dataPage2.Rows), 1) - - //list search match - listInput3 := &ListInput{} - reqBody = `{"page_size": 1, "page": 1, "username": "pony"}` - err = json.Unmarshal([]byte(reqBody), listInput3) - assert.Nil(t, err) - ctx.SetInput(listInput3) - retPage, err := handler.List(ctx) - assert.Nil(t, err) - dataPage := retPage.(*store.ListOutput) - assert.Equal(t, len(dataPage.Rows), 1) - - //list search not match - listInput4 := &ListInput{} - reqBody = `{"page_size": 1, "page": 1, "username": "not-exists"}` - err = json.Unmarshal([]byte(reqBody), listInput4) - assert.Nil(t, err) - ctx.SetInput(listInput4) - retPage, err = handler.List(ctx) - assert.Nil(t, err) - dataPage = retPage.(*store.ListOutput) - assert.Equal(t, len(dataPage.Rows), 0) - - //delete consumer - inputDel := &BatchDelete{} - reqBody = `{"usernames": "jack"}` - err = json.Unmarshal([]byte(reqBody), inputDel) - assert.Nil(t, err) - ctx.SetInput(inputDel) - _, err = handler.BatchDelete(ctx) - assert.Nil(t, err) - - reqBody = `{"usernames": "pony"}` - err = json.Unmarshal([]byte(reqBody), inputDel) - assert.Nil(t, err) - ctx.SetInput(inputDel) - _, err = handler.BatchDelete(ctx) - assert.Nil(t, err) - - //create consumer fail - consumer_fail := &entity.Consumer{} - reqBody = `{ - "plugins": { - "limit-count": { - "count": 2, - "time_window": 60, - "rejected_code": 503, - "key": "remote_addr" - } - }, - "desc": "test description" - }` - err = json.Unmarshal([]byte(reqBody), consumer_fail) - assert.Nil(t, err) - ctx.SetInput(consumer_fail) - _, err = handler.Create(ctx) - assert.NotNil(t, err) - - //create consumer using Update - consumer6 := &UpdateInput{} - reqBody = `{ - "username": "nnn", - "plugins": { - "limit-count": { - "count": 2, - "time_window": 60, - "rejected_code": 503, - "key": "remote_addr" - } - }, - "desc": "test description" - }` - err = json.Unmarshal([]byte(reqBody), consumer6) - assert.Nil(t, err) - ctx.SetInput(consumer6) - _, err = handler.Update(ctx) - assert.Nil(t, err) - - //sleep - time.Sleep(time.Duration(100) * time.Millisecond) - - //delete consumer - reqBody = `{"usernames": "nnn"}` - err = json.Unmarshal([]byte(reqBody), inputDel) - assert.Nil(t, err) - ctx.SetInput(inputDel) - _, err = handler.BatchDelete(ctx) - assert.Nil(t, err) + } + + for _, tc := range tests { + t.Run(tc.caseDesc, func(t *testing.T) { + methodCalled := true + mStore := &store.MockInterface{} + mStore.On("Create", mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + methodCalled = true + assert.Equal(t, tc.giveCtx, args.Get(0)) + assert.Equal(t, tc.wantInput, args.Get(1)) + }).Return(tc.giveErr) + + h := Handler{consumerStore: mStore} + ctx := droplet.NewContext() + ctx.SetInput(tc.giveInput) + ctx.SetContext(tc.giveCtx) + ret, err := h.Create(ctx) + assert.Equal(t, tc.wantCalled, methodCalled) + assert.Equal(t, tc.wantRet, ret) + assert.Equal(t, tc.wantErr, err) + }) + } +} + +func TestHandler_Update(t *testing.T) { + tests := []struct { + caseDesc string + giveInput *UpdateInput + giveCtx context.Context + giveErr error + wantErr error + wantInput *entity.Consumer + wantRet interface{} + wantCalled bool + }{ + { + caseDesc: "normal", + giveInput: &UpdateInput{ + Username: "name", + Consumer: entity.Consumer{ + Plugins: map[string]interface{}{ + "jwt-auth": map[string]interface{}{ + "exp": 500, + }, + }, + }, + }, + giveCtx: context.WithValue(context.Background(), "test", "value"), + wantInput: &entity.Consumer{ + BaseInfo: entity.BaseInfo{ + ID: "name", + }, + Username: "name", + Plugins: map[string]interface{}{ + "jwt-auth": map[string]interface{}{ + "exp": 500, + }, + }, + }, + wantRet: nil, + wantCalled: true, + }, + { + caseDesc: "store update failed", + giveInput: &UpdateInput{ + Username: "name", + Consumer: entity.Consumer{ + Plugins: map[string]interface{}{ + "jwt-auth": map[string]interface{}{}, + }, + }, + }, + giveErr: fmt.Errorf("create failed"), + wantInput: &entity.Consumer{ + BaseInfo: entity.BaseInfo{ + ID: "name", + }, + Username: "name", + Plugins: map[string]interface{}{ + "jwt-auth": map[string]interface{}{ + "exp": 86400, + }, + }, + }, + wantErr: fmt.Errorf("create failed"), + wantRet: &data.SpecCodeResponse{ + StatusCode: http.StatusInternalServerError, + }, + wantCalled: true, + }, + } + for _, tc := range tests { + t.Run(tc.caseDesc, func(t *testing.T) { + methodCalled := true + mStore := &store.MockInterface{} + mStore.On("Update", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + methodCalled = true + assert.Equal(t, tc.giveCtx, args.Get(0)) + assert.Equal(t, tc.wantInput, args.Get(1)) + assert.True(t, args.Bool(2)) + }).Return(tc.giveErr) + + h := Handler{consumerStore: mStore} + ctx := droplet.NewContext() + ctx.SetInput(tc.giveInput) + ctx.SetContext(tc.giveCtx) + ret, err := h.Update(ctx) + assert.Equal(t, tc.wantCalled, methodCalled) + assert.Equal(t, tc.wantRet, ret) + assert.Equal(t, tc.wantErr, err) + }) + } +} + +func TestHandler_BatchDelete(t *testing.T) { + tests := []struct { + caseDesc string + giveInput *BatchDeleteInput + giveCtx context.Context + giveErr error + wantErr error + wantInput []string + wantRet interface{} + }{ + { + caseDesc: "normal", + giveInput: &BatchDeleteInput{ + UserNames: "user1,user2", + }, + giveCtx: context.WithValue(context.Background(), "test", "value"), + wantInput: []string{ + "user1", + "user2", + }, + }, + { + caseDesc: "store delete failed", + giveInput: &BatchDeleteInput{ + UserNames: "user1,user2", + }, + giveCtx: context.WithValue(context.Background(), "test", "value"), + giveErr: fmt.Errorf("delete failed"), + wantInput: []string{ + "user1", + "user2", + }, + wantErr: fmt.Errorf("delete failed"), + wantRet: &data.SpecCodeResponse{ + StatusCode: http.StatusInternalServerError, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.caseDesc, func(t *testing.T) { + methodCalled := true + mStore := &store.MockInterface{} + mStore.On("BatchDelete", mock.Anything, mock.Anything, mock.Anything).Run(func(args mock.Arguments) { + methodCalled = true + assert.Equal(t, tc.giveCtx, args.Get(0)) + assert.Equal(t, tc.wantInput, args.Get(1)) + }).Return(tc.giveErr) + + h := Handler{consumerStore: mStore} + ctx := droplet.NewContext() + ctx.SetInput(tc.giveInput) + ctx.SetContext(tc.giveCtx) + ret, err := h.BatchDelete(ctx) + assert.True(t, methodCalled) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantRet, ret) + }) + } } diff --git a/api/log/zap.go b/api/log/zap.go index 66379d84f6..cd36b5323e 100644 --- a/api/log/zap.go +++ b/api/log/zap.go @@ -27,10 +27,13 @@ import ( var logger *zap.SugaredLogger +// TODO: it is just for integration tests, we should call "InitLog" explicitly when remove all handler's integration tests func init() { + InitLogger() +} +func InitLogger() { logger = GetLogger(ErrorLog) } - func GetLogger(logType Type) *zap.SugaredLogger { writeSyncer := fileWriter(logType) encoder := getEncoder(logType)