diff --git a/conn.go b/conn.go index 9afd2d27..931416aa 100644 --- a/conn.go +++ b/conn.go @@ -890,6 +890,29 @@ func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event { return ch } +func (c *Conn) removeWatcher(path string, chnl <-chan Event, watchType watchType) bool { + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + + wpt := watchPathType{path, watchType} + watchers := c.watchers[wpt] + for ind, ch := range watchers { + if ch == chnl { + close(ch) + + // Remove the entry at index ind, by swapping it with the last entry + // and creating a slice without the last entry. + watchers[ind], watchers[len(watchers)-1] = + watchers[len(watchers)-1], watchers[ind] + + c.watchers[wpt] = watchers[:len(watchers)-1] + return true + } + } + + return false +} + func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response { rq := &request{ xid: c.nextXid(), @@ -1003,6 +1026,13 @@ func (c *Conn) ChildrenW(path string) ([]string, *Stat, <-chan Event, error) { return res.Children, &res.Stat, ech, err } +// RemoveChildW removes the channel mentioned from the list +// of watches. Returns true, if the channel was found and +// removed. +func (c *Conn) RemoveChildW(path string, ch <-chan Event) bool { + return c.removeWatcher(path, ch, watchTypeChild) +} + // Get gets the contents of a znode. func (c *Conn) Get(path string) ([]byte, *Stat, error) { if err := validatePath(path, false); err != nil { @@ -1036,6 +1066,13 @@ func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) { return res.Data, &res.Stat, ech, err } +// RemoveGetW removes the channel mentioned from the list +// of watches. Returns true, if the channel was found and +// removed. +func (c *Conn) RemoveGetW(path string, ch <-chan Event) bool { + return c.removeWatcher(path, ch, watchTypeData) +} + // Set updates the contents of a znode. func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) { if err := validatePath(path, false); err != nil { @@ -1199,6 +1236,20 @@ func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) { return exists, &res.Stat, ech, err } +// RemoveExistsW removes the channel mentioned from the list +// of watches. Returns true, if the channel was found and +// removed. +func (c *Conn) RemoveExistsW(path string, ch <-chan Event) bool { + removedExistsWatch := c.removeWatcher(path, ch, watchTypeExist) + + // A data watch would have been created for this, + // had the zk node existed when the watch was initiated. + if !removedExistsWatch { + return c.removeWatcher(path, ch, watchTypeData) + } + return removedExistsWatch +} + // GetACL gets the ACLs of a znode. func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) { if err := validatePath(path, false); err != nil { diff --git a/conn_test.go b/conn_test.go index 96299280..ca7202fe 100644 --- a/conn_test.go +++ b/conn_test.go @@ -196,3 +196,39 @@ func TestNotifyWatches(t *testing.T) { }) } } + +func TestTemoveGetWatches(t *testing.T) { + ch := make(chan Event, 1) + zkPath := "/a" + conn := &Conn{watchers: make(map[watchPathType][]chan Event)} + watcherInfo := watchPathType{zkPath, watchTypeData} + conn.watchers[watcherInfo] = append(conn.watchers[watcherInfo], ch) + + // Assert that the map has the required number of watchers + if len(conn.watchers[watcherInfo]) != 1 { + t.Fatalf("Failed to add a data watcher for path %s", zkPath) + } + conn.RemoveGetW(zkPath, ch) + + // Assert that the channel is closed and removed from the map + var closed bool + select { + case _, ok := <-ch: + closed = !ok + default: + closed = false + } + + if !closed { + t.Fatalf("Channel used for notifying data watch changes was not closed on removal") + } + + if len(conn.watchers[watcherInfo]) != 0 { + t.Fatalf("Failed to remove channel used to notify data watch changes") + } + + // Try to remove the same channel and expect failure. + if removed := conn.RemoveGetW(zkPath, ch); removed { + t.Fatalf("Removed the same channel twice") + } +}