From ac3450cc7f1d45e976e149007177d1b2f470e7f4 Mon Sep 17 00:00:00 2001 From: Ryan Leung Date: Wed, 15 Apr 2020 17:11:53 +0800 Subject: [PATCH] address comments Signed-off-by: Ryan Leung --- pkg/component/manager.go | 55 ++++++++++++++++++++++++++++++----- pkg/component/manager_test.go | 15 ++++++++-- server/api/component.go | 24 +++++++++++++-- server/api/component_test.go | 38 ++++++++++++++++++++---- server/api/router.go | 3 +- 5 files changed, 117 insertions(+), 18 deletions(-) diff --git a/pkg/component/manager.go b/pkg/component/manager.go index 168107dcb7a2..faa6e30eaefb 100644 --- a/pkg/component/manager.go +++ b/pkg/component/manager.go @@ -41,7 +41,7 @@ func NewManager() *Manager { func (c *Manager) GetComponentAddrs(component string) []string { c.RLock() defer c.RUnlock() - var addresses []string + addresses := []string{} if ca, ok := c.Addresses[component]; ok { addresses = append(addresses, ca...) } @@ -52,7 +52,13 @@ func (c *Manager) GetComponentAddrs(component string) []string { func (c *Manager) GetAllComponentAddrs() map[string][]string { c.RLock() defer c.RUnlock() - return c.Addresses + n := make(map[string][]string) + for k, v := range c.Addresses { + b := make([]string, len(v)) + copy(b, v) + n[k] = b + } + return n } // GetComponent returns the component from a given component ID. @@ -60,7 +66,7 @@ func (c *Manager) GetComponent(addr string) string { c.RLock() defer c.RUnlock() for component, ca := range c.Addresses { - if contains(ca, addr) { + if exist, _ := contains(ca, addr); exist { return component } } @@ -81,7 +87,7 @@ func (c *Manager) Register(component, addr string) error { } ca, ok := c.Addresses[component] - if ok && contains(ca, addr) { + if exist, _ := contains(ca, addr); ok && exist { log.Info("address has already been registered", zap.String("component", component), zap.String("address", addr)) return fmt.Errorf("component %s address %s has already been registered", component, addr) } @@ -92,12 +98,45 @@ func (c *Manager) Register(component, addr string) error { return nil } -func contains(slice []string, item string) bool { - for _, s := range slice { +// UnRegister is used for unregistering a component with an address from PD. +func (c *Manager) UnRegister(component, addr string) error { + c.Lock() + defer c.Unlock() + + str := strings.Split(addr, ":") + if len(str) != 0 { + ip := net.ParseIP(str[0]) + if ip == nil { + return fmt.Errorf("failed to parse address %s of component %s", addr, component) + } + } + + ca, ok := c.Addresses[component] + if !ok { + return fmt.Errorf("component %s not found", component) + } + + if exist, idx := contains(ca, addr); exist { + ca = append(ca[:idx], ca[idx+1:]...) + log.Info("address has successfully been unregistered", zap.String("component", component), zap.String("address", addr)) + if len(ca) == 0 { + delete(c.Addresses, component) + return nil + } + + c.Addresses[component] = ca + return nil + } + + return fmt.Errorf("address %s not found", addr) +} + +func contains(slice []string, item string) (bool, int) { + for i, s := range slice { if s == item { - return true + return true, i } } - return false + return false, 0 } diff --git a/pkg/component/manager_test.go b/pkg/component/manager_test.go index e638aa7ca3db..840a85f1e921 100644 --- a/pkg/component/manager_test.go +++ b/pkg/component/manager_test.go @@ -36,17 +36,28 @@ func (s *testManagerSuite) TestManager(c *C) { // register repeatedly c.Assert(strings.Contains(m.Register("c1", "127.0.0.1:2").Error(), "already"), IsTrue) c.Assert(m.Register("c2", "127.0.0.1:3"), IsNil) - // register illegal address - c.Assert(m.Register("c2", "abcde"), NotNil) + // get all addresses all := map[string][]string{ "c1": {"127.0.0.1:1", "127.0.0.1:2"}, "c2": {"127.0.0.1:3"}, } c.Assert(m.GetAllComponentAddrs(), DeepEquals, all) + + // get the specific component addresses c.Assert(m.GetComponentAddrs("c1"), DeepEquals, all["c1"]) c.Assert(m.GetComponentAddrs("c2"), DeepEquals, all["c2"]) + + // get the component from the address c.Assert(m.GetComponent("127.0.0.1:1"), Equals, "c1") c.Assert(m.GetComponent("127.0.0.1:2"), Equals, "c1") c.Assert(m.GetComponent("127.0.0.1:3"), Equals, "c2") + + // unregister address + c.Assert(m.UnRegister("c1", "127.0.0.1:1"), IsNil) + c.Assert(m.GetComponentAddrs("c1"), DeepEquals, []string{"127.0.0.1:2"}) + c.Assert(m.UnRegister("c1", "127.0.0.1:2"), IsNil) + c.Assert(m.GetComponentAddrs("c1"), DeepEquals, []string{}) + all = map[string][]string{"c2": {"127.0.0.1:3"}} + c.Assert(m.GetAllComponentAddrs(), DeepEquals, all) } diff --git a/server/api/component.go b/server/api/component.go index c331bbcb09be..c599ace64e2f 100644 --- a/server/api/component.go +++ b/server/api/component.go @@ -43,8 +43,9 @@ func newComponentHandler(svr *server.Server, rd *render.Render) *componentHandle // @Summary Register component address. // @Produce json // @Success 200 {string} string -// @Failure 400 {string} string "PD server failed to proceed the request." -// @Router /component/register [post] +// @Failure 400 {string} string "The input is invalid." +// @Failure 500 {string} string "PD server failed to proceed the request." +// @Router /component [post] func (h *componentHandler) Register(w http.ResponseWriter, r *http.Request) { input := make(map[string]string) if err := apiutil.ReadJSONRespondError(h.rd, w, r.Body, &input); err != nil { @@ -69,6 +70,25 @@ func (h *componentHandler) Register(w http.ResponseWriter, r *http.Request) { h.rd.JSON(w, http.StatusOK, nil) } +// @Tags component +// @Summary Unregister component address. +// @Produce json +// @Success 200 {string} string +// @Failure 400 {string} string "The input is invalid." +// @Router /component [delete] +func (h *componentHandler) UnRegister(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + component := vars["component"] + addr := vars["addr"] + m := h.svr.GetComponentManager() + err := m.UnRegister(component, addr) + if err != nil { + h.rd.JSON(w, http.StatusBadRequest, err.Error()) + return + } + h.rd.JSON(w, http.StatusOK, nil) +} + // @Tags component // @Summary List all component addresses // @Produce json diff --git a/server/api/component_test.go b/server/api/component_test.go index 9d0819c2306b..3fe6de3d8ac6 100644 --- a/server/api/component_test.go +++ b/server/api/component_test.go @@ -42,7 +42,8 @@ func (s *testComponentSuite) TearDownSuite(c *C) { s.cleanup() } -func (s *testComponentSuite) TestRegister(c *C) { +func (s *testComponentSuite) TestComponent(c *C) { + // register not happen addr := fmt.Sprintf("%s/component", s.urlPrefix) output := make(map[string][]string) err := readJSON(addr, &output) @@ -55,7 +56,7 @@ func (s *testComponentSuite) TestRegister(c *C) { c.Assert(strings.Contains(err.Error(), "404"), IsTrue) c.Assert(len(output1), Equals, 0) - addr2 := fmt.Sprintf("%s/component/register", s.urlPrefix) + // register 2 c1 and 1 c2 reqs := []map[string]string{ {"component": "c1", "addr": "127.0.0.1:1"}, {"component": "c1", "addr": "127.0.0.1:2"}, @@ -64,10 +65,11 @@ func (s *testComponentSuite) TestRegister(c *C) { for _, req := range reqs { postData, err := json.Marshal(req) c.Assert(err, IsNil) - err = postJSON(addr2, postData) + err = postJSON(addr, postData) c.Assert(err, IsNil) } + // get all addresses expected := map[string][]string{ "c1": {"127.0.0.1:1", "127.0.0.1:2"}, "c2": {"127.0.0.1:3"}, @@ -78,16 +80,42 @@ func (s *testComponentSuite) TestRegister(c *C) { c.Assert(err, IsNil) c.Assert(output, DeepEquals, expected) + // get the specific component addresses expected1 := []string{"127.0.0.1:1", "127.0.0.1:2"} var output2 []string err = readJSON(addr1, &output2) c.Assert(err, IsNil) c.Assert(output2, DeepEquals, expected1) - addr3 := fmt.Sprintf("%s/component/c2", s.urlPrefix) + addr2 := fmt.Sprintf("%s/component/c2", s.urlPrefix) expected2 := []string{"127.0.0.1:3"} var output3 []string - err = readJSON(addr3, &output3) + err = readJSON(addr2, &output3) c.Assert(err, IsNil) c.Assert(output3, DeepEquals, expected2) + + // unregister address + addr3 := fmt.Sprintf("%s/component/c1/127.0.0.1:1", s.urlPrefix) + res, err := doDelete(addr3) + c.Assert(err, IsNil) + c.Assert(res.StatusCode, Equals, 200) + + expected3 := map[string][]string{ + "c1": {"127.0.0.1:2"}, + "c2": {"127.0.0.1:3"}, + } + output = make(map[string][]string) + err = readJSON(addr, &output) + c.Assert(err, IsNil) + c.Assert(output, DeepEquals, expected3) + + addr4 := fmt.Sprintf("%s/component/c1/127.0.0.1:2", s.urlPrefix) + res, err = doDelete(addr4) + c.Assert(err, IsNil) + c.Assert(res.StatusCode, Equals, 200) + expected4 := map[string][]string{"c2": {"127.0.0.1:3"}} + output = make(map[string][]string) + err = readJSON(addr, &output) + c.Assert(err, IsNil) + c.Assert(output, DeepEquals, expected4) } diff --git a/server/api/router.go b/server/api/router.go index 53f417e91a5e..9c08fe4f2629 100644 --- a/server/api/router.go +++ b/server/api/router.go @@ -178,7 +178,8 @@ func createRouter(ctx context.Context, prefix string, svr *server.Server) *mux.R clusterRouter.HandleFunc("/replication_mode/status", replicationModeHandler.GetStatus) componentHandler := newComponentHandler(svr, rd) - apiRouter.HandleFunc("/component/register", componentHandler.Register).Methods("POST") + apiRouter.HandleFunc("/component", componentHandler.Register).Methods("POST") + apiRouter.HandleFunc("/component/{component}/{addr}", componentHandler.UnRegister).Methods("DELETE") apiRouter.HandleFunc("/component", componentHandler.GetAllAddress).Methods("GET") apiRouter.HandleFunc("/component/{type}", componentHandler.GetAddress).Methods("GET")