diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml index 9b0baa67..6bfe523d 100644 --- a/.github/workflows/integration.yaml +++ b/.github/workflows/integration.yaml @@ -6,14 +6,14 @@ jobs: name: integration_test strategy: matrix: - zk-version: [3.5.8, 3.6.1] + zk-version: [3.5.8, 3.6.3] go-version: ['oldstable', 'stable'] runs-on: ubuntu-latest steps: - - name: Go ${{ matrix.go }} setup + - name: Go ${{ matrix.go-version }} setup uses: actions/setup-go@v4 with: - go-version: ${{ matrix.go-version }} + go-version: ${{ matrix.go-version }} - name: Setup Java 14 uses: actions/setup-java@v3 @@ -22,7 +22,7 @@ jobs: java-version: 14 - name: Checkout code - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Test code run: make test ZK_VERSION=${{ matrix.zk-version }} diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index e9a9f69d..a347d9fa 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -1,14 +1,21 @@ - name: lint -on: [pull_request] +on: [push, pull_request] + jobs: lint: name: lint + strategy: + matrix: + go-version: ["1.21"] runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v1 - - name: Lint code - uses: reviewdog/action-golangci-lint@v1 + uses: actions/checkout@v4 + + - name: Install go + uses: actions/setup-go@v4 with: - github_token: ${{ secrets.github_token }} + go-version: ${{ matrix.go-version }} + + - name: Lint code + uses: golangci/golangci-lint-action@v3 diff --git a/.github/workflows/unittest.yaml b/.github/workflows/unittest.yaml index 0d5d7f9f..42d009cd 100644 --- a/.github/workflows/unittest.yaml +++ b/.github/workflows/unittest.yaml @@ -9,13 +9,13 @@ jobs: go-version: ['oldstable', 'stable'] runs-on: ubuntu-latest steps: - - name: Go ${{ matrix.go }} setup + - name: Go ${{ matrix.go-version }} setup uses: actions/setup-go@v4 with: - go-version: ${{ matrix.go-version }} + go-version: ${{ matrix.go-version }} - name: Checkout code uses: actions/checkout@v3 - name: Run unittest ${{ matrix.go }} - run: make unittest + run: make unittest diff --git a/Makefile b/Makefile index 5492ffa0..de1eaf77 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ # make file to hold the logic of build and test setup -ZK_VERSION ?= 3.5.6 +export ZK_VERSION ?= 3.6.3 # Apache changed the name of the archive in version 3.5.x and seperated out # src and binary packages @@ -20,10 +20,12 @@ $(ZK): tar -zxf $(ZK).tar.gz rm $(ZK).tar.gz +.PHONY: zookeeper zookeeper: $(ZK) # we link to a standard directory path so then the tests dont need to find based on version # in the test code. this allows backward compatable testing. - ln -s $(ZK) zookeeper + rm -f $@ + ln -s $(ZK) $@ .PHONY: setup setup: zookeeper diff --git a/cluster_test.go b/cluster_test.go index c48b1685..ca83d3e6 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -1,6 +1,7 @@ package zk import ( + "log/slog" "sync" "testing" "time" @@ -112,10 +113,10 @@ func TestIntegration_NoQuorum(t *testing.T) { t.Fatalf("Failed to connect and get session") } initialSessionID := zk.sessionID - DefaultLogger.Printf(" Session established: id=%d, timeout=%d", zk.sessionID, zk.sessionTimeoutMs) + t.Logf("Session established: id=%d, timeout=%d", zk.sessionID, zk.sessionTimeout) // Kill the ZooKeeper leader and wait for the session to reconnect. 
- DefaultLogger.Printf(" Kill the leader") + t.Logf("Kill the leader") disconnectWatcher1 := sl.NewWatcher(sessionStateMatcher(StateDisconnected)) hasSessionWatcher2 := sl.NewWatcher(sessionStateMatcher(StateHasSession)) tc.StopServer(hasSessionEvent1.Server) @@ -135,7 +136,7 @@ func TestIntegration_NoQuorum(t *testing.T) { } // Kill the ZooKeeper leader leaving the cluster without quorum. - DefaultLogger.Printf(" Kill the leader") + t.Logf("Kill the leader") disconnectWatcher2 := sl.NewWatcher(sessionStateMatcher(StateDisconnected)) tc.StopServer(hasSessionEvent2.Server) @@ -151,7 +152,7 @@ func TestIntegration_NoQuorum(t *testing.T) { // Make sure that we keep retrying connecting to the only remaining // ZooKeeper server, but the attempts are being dropped because there is // no quorum. - DefaultLogger.Printf(" Retrying no luck...") + t.Logf("Retrying no luck...") var firstDisconnect *Event begin := time.Now() for time.Now().Sub(begin) < 6*time.Second { @@ -269,7 +270,7 @@ func NewStateLogger(eventCh <-chan Event) *EventLogger { sw.matchCh <- event } } - DefaultLogger.Printf(" event received: %v\n", event) + slog.Info("event received", "event", event) el.events = append(el.events, event) el.lock.Unlock() } diff --git a/conn.go b/conn.go index 9afd2d27..62ff4325 100644 --- a/conn.go +++ b/conn.go @@ -16,6 +16,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "strings" "sync" @@ -31,9 +32,6 @@ var ErrNoServer = errors.New("zk: could not connect to a server") // an invalid path. (e.g. empty path). var ErrInvalidPath = errors.New("zk: invalid path") -// DefaultLogger uses the stdlib log package for logging. -var DefaultLogger Logger = defaultLogger{} - const ( bufferSize = 1536 * 1024 eventChanSize = 6 @@ -47,21 +45,19 @@ const ( watchTypeData watchType = iota watchTypeExist watchTypeChild + watchTypePersistent + watchTypePersistentRecursive ) +func (w watchType) isPersistent() bool { + return w == watchTypePersistent || w == watchTypePersistentRecursive +} + type watchPathType struct { path string wType watchType } -// Dialer is a function to be used to establish a connection to a single host. -type Dialer func(network, address string, timeout time.Duration) (net.Conn, error) - -// Logger is an interface that can be implemented to provide custom log output. -type Logger interface { - Printf(string, ...interface{}) -} - type authCreds struct { scheme string auth []byte @@ -69,12 +65,12 @@ type authCreds struct { // Conn is the client connection and tracks all details for communication with the server. 
type Conn struct { - lastZxid int64 - sessionID int64 - state State // must be 32-bit aligned - xid uint32 - sessionTimeoutMs int32 // session timeout in milliseconds - passwd []byte + lastZxid int64 + sessionID int64 + state State // must be 32-bit aligned + xid uint32 + sessionTimeout time.Duration + passwd []byte dialer Dialer hostProvider HostProvider @@ -89,6 +85,7 @@ type Conn struct { recvTimeout time.Duration connectTimeout time.Duration maxBufferSize int + metricReceiver MetricReceiver creds []authCreds credsMu sync.Mutex // protects server @@ -96,7 +93,7 @@ type Conn struct { sendChan chan *request requests map[int32]*request // Xid -> pending request requestsLock sync.Mutex - watchers map[watchPathType][]chan Event + watchers map[watchPathType][]EventQueue watchersLock sync.Mutex closeChan chan struct{} // channel to tell send loop stop @@ -109,14 +106,19 @@ type Conn struct { debugCloseRecvLoop bool resendZkAuthFn func(context.Context, *Conn) error - logger Logger - logInfo bool // true if information messages are logged; false if only errors are logged - buf []byte } -// connOption represents a connection option. -type connOption func(c *Conn) +// ConnOption represents a connection option. +type ConnOption interface { + apply(*Conn) +} + +type connOption func(*Conn) + +func (c connOption) apply(conn *Conn) { + c(conn) +} type request struct { xid int32 @@ -148,6 +150,16 @@ type Event struct { Path string // For non-session events, the path of the watched node. Err error Server string // For connection events + // For watch events, the zxid that caused the change (starting with ZK 3.9.0). For ping events, the zxid that the + // server last processed. Note that the last processed zxid is only updated once the watch events have been + // triggered. Since ZK operates over one connection, the watch events are therefore queued up before the ping. This + // means watch events should always be received before pings, and receiving a ping with a given zxid means any watch + // event for a lower zxid have already been received (if any). + Zxid int64 + // This is the time at which the event was received by the client. Useful to understand lag across the entire + // system. Note that this is NOT the time at which the event fired in the quorum. Only set for watch events and + // pings. + Timestamp time.Time } // HostProvider is used to represent a set of hosts a ZooKeeper client should connect to. @@ -156,29 +168,21 @@ type Event struct { type HostProvider interface { // Init is called first, with the servers specified in the connection string. Init(servers []string) error - // Len returns the number of servers. - Len() int - // Next returns the next server to connect to. retryStart will be true if we've looped through - // all known servers without Connected() being called. + // Next returns the next server to connect to. retryStart should be true if this call to Next + // exhausted the list of known servers without Connected being called. If connecting to this final + // host fails, the connect loop will back off before invoking Next again for a fresh server. Next() (server string, retryStart bool) - // Notify the HostProvider of a successful connection. + // Connected notifies the HostProvider of a successful connection. Connected() } -// ConnectWithDialer establishes a new connection to a pool of zookeeper servers -// using a custom Dialer. See Connect for further information about session timeout. 
-// This method is deprecated and provided for compatibility: use the WithDialer option instead. -func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) { - return Connect(servers, sessionTimeout, WithDialer(dialer)) -} - // Connect establishes a new connection to a pool of zookeeper // servers. The provided session timeout sets the amount of time for which // a session is considered valid after losing connection to a server. Within // the session timeout it's possible to reestablish a connection to a different // server and keep the same session. This is means any ephemeral nodes and // watches are maintained. -func Connect(servers []string, sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) { +func Connect(servers []string, sessionTimeout time.Duration, options ...ConnOption) (*Conn, <-chan Event, error) { if len(servers) == 0 { return nil, nil, errors.New("zk: server list must not be empty") } @@ -186,12 +190,12 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti srvs := FormatServers(servers) // Randomize the order of the servers to avoid creating hotspots - stringShuffle(srvs) + shuffleSlice(srvs) ec := make(chan Event, eventChanSize) conn := &Conn{ - dialer: net.DialTimeout, - hostProvider: &DNSHostProvider{}, + dialer: new(net.Dialer), + hostProvider: new(StaticHostProvider), conn: nil, state: StateDisconnected, eventChan: ec, @@ -199,24 +203,23 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti connectTimeout: 1 * time.Second, sendChan: make(chan *request, sendChanSize), requests: make(map[int32]*request), - watchers: make(map[watchPathType][]chan Event), + watchers: make(map[watchPathType][]EventQueue), passwd: emptyPassword, - logger: DefaultLogger, - logInfo: true, // default is true for backwards compatability buf: make([]byte, bufferSize), resendZkAuthFn: resendZkAuth, + metricReceiver: UnimplementedMetricReceiver{}, } // Set provided options. for _, option := range options { - option(conn) + option.apply(conn) } if err := conn.hostProvider.Init(srvs); err != nil { return nil, nil, err } - conn.setTimeouts(int32(sessionTimeout / time.Millisecond)) + conn.setTimeouts(sessionTimeout) // TODO: This context should be passed in by the caller to be the connection lifecycle context. ctx := context.Background() @@ -229,33 +232,25 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti return conn, ec, nil } -// WithDialer returns a connection option specifying a non-default Dialer. -func WithDialer(dialer Dialer) connOption { - return func(c *Conn) { +// Dialer is an interface implemented by the standard [net.Dialer] but also by [crypto/tls.Dialer]. +type Dialer interface { + // DialContext will be invoked when connecting to a ZK host. + DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// WithDialer returns a connection option specifying a non-default Dialer. This can be used, for +// example, to enable TLS when connecting to the ZK server by passing in a [crypto/tls.Dialer]. +func WithDialer(dialer Dialer) ConnOption { + return connOption(func(c *Conn) { c.dialer = dialer - } + }) } // WithHostProvider returns a connection option specifying a non-default HostProvider. 
-func WithHostProvider(hostProvider HostProvider) connOption { - return func(c *Conn) { +func WithHostProvider(hostProvider HostProvider) ConnOption { + return connOption(func(c *Conn) { c.hostProvider = hostProvider - } -} - -// WithLogger returns a connection option specifying a non-default Logger. -func WithLogger(logger Logger) connOption { - return func(c *Conn) { - c.logger = logger - } -} - -// WithLogInfo returns a connection option specifying whether or not information messages -// should be logged. -func WithLogInfo(logInfo bool) connOption { - return func(c *Conn) { - c.logInfo = logInfo - } + }) } // EventCallback is a function that is called when an Event occurs. @@ -264,10 +259,10 @@ type EventCallback func(Event) // WithEventCallback returns a connection option that specifies an event // callback. // The callback must not block - doing so would delay the ZK go routines. -func WithEventCallback(cb EventCallback) connOption { - return func(c *Conn) { +func WithEventCallback(cb EventCallback) ConnOption { + return connOption(func(c *Conn) { c.eventCallback = cb - } + }) } // WithMaxBufferSize sets the maximum buffer size used to read and decode @@ -294,20 +289,26 @@ func WithEventCallback(cb EventCallback) connOption { // the child names without an increased buffer size in the client, but they work // by inspecting the servers' transaction logs to enumerate children instead of // sending an online request to a server. -func WithMaxBufferSize(maxBufferSize int) connOption { - return func(c *Conn) { +func WithMaxBufferSize(maxBufferSize int) ConnOption { + return connOption(func(c *Conn) { c.maxBufferSize = maxBufferSize - } + }) } // WithMaxConnBufferSize sets maximum buffer size used to send and encode // packets to Zookeeper server. The standard Zookeeper client for java defaults // to a limit of 1mb. This option should be used for non-standard server setup // where znode is bigger than default 1mb. -func WithMaxConnBufferSize(maxBufferSize int) connOption { - return func(c *Conn) { +func WithMaxConnBufferSize(maxBufferSize int) ConnOption { + return connOption(func(c *Conn) { c.buf = make([]byte, maxBufferSize) - } + }) +} + +func WithMetricReceiver(mr MetricReceiver) ConnOption { + return connOption(func(c *Conn) { + c.metricReceiver = mr + }) } // Close will submit a close request with ZK and signal the connection to stop @@ -333,15 +334,8 @@ func (c *Conn) SessionID() int64 { return atomic.LoadInt64(&c.sessionID) } -// SetLogger sets the logger to be used for printing errors. -// Logger is an interface provided by this package. 
-func (c *Conn) SetLogger(l Logger) { - c.logger = l -} - -func (c *Conn) setTimeouts(sessionTimeoutMs int32) { - c.sessionTimeoutMs = sessionTimeoutMs - sessionTimeout := time.Duration(sessionTimeoutMs) * time.Millisecond +func (c *Conn) setTimeouts(sessionTimeout time.Duration) { + c.sessionTimeout = sessionTimeout c.recvTimeout = sessionTimeout * 2 / 3 c.pingInterval = c.recvTimeout / 2 } @@ -372,6 +366,23 @@ func (c *Conn) connect() error { c.setState(StateConnecting) + dial := func() (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), c.connectTimeout) + defer cancel() + return c.dialer.DialContext(ctx, "tcp", c.Server()) + } + + slog.Info("Dialing ZK server", "server", c.Server()) + zkConn, err := dial() + if err == nil { + c.conn = zkConn + c.setState(StateConnected) + slog.Info("Connection established", "server", c.Server(), "addr", zkConn.RemoteAddr()) + return nil + } + + slog.Warn("Failed to connect to ZK server", "server", c.Server(), "err", err) + if retryStart { c.flushUnsentRequests(ErrNoServer) select { @@ -383,18 +394,6 @@ func (c *Conn) connect() error { return ErrClosing } } - - zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout) - if err == nil { - c.conn = zkConn - c.setState(StateConnected) - if c.logInfo { - c.logger.Printf("connected to %s", c.Server()) - } - return nil - } - - c.logger.Printf("failed to connect to %s: %v", c.Server(), err) } } @@ -430,18 +429,18 @@ func (c *Conn) loop(ctx context.Context) { return } - err := c.authenticate() + var sendLoopErr, recvLoopErr error + + authErr := c.authenticate() switch { - case err == ErrSessionExpired: - c.logger.Printf("authentication failed: %s", err) - c.invalidateWatches(err) - case err != nil && c.conn != nil: - c.logger.Printf("authentication failed: %s", err) + case errors.Is(authErr, ErrSessionExpired): + slog.Warn("authentication failed", "err", authErr) + c.invalidateWatches(authErr) + case authErr != nil && c.conn != nil: + slog.Warn("authentication failed", "err", authErr) c.conn.Close() - case err == nil: - if c.logInfo { - c.logger.Printf("authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs) - } + case authErr == nil: + slog.Info("authenticated", "sessionId", c.SessionID(), "timeout", c.sessionTimeout) c.hostProvider.Connected() // mark success c.closeChan = make(chan struct{}) // channel to tell send loop stop @@ -452,13 +451,15 @@ func (c *Conn) loop(ctx context.Context) { defer c.conn.Close() // causes recv loop to EOF/exit defer wg.Done() - if err := c.resendZkAuthFn(ctx, c); err != nil { - c.logger.Printf("error in resending auth creds: %v", err) + if sendLoopErr = c.resendZkAuthFn(ctx, c); sendLoopErr != nil { + slog.Warn("error in resending auth creds", "err", sendLoopErr) return } - if err := c.sendLoop(); err != nil || c.logInfo { - c.logger.Printf("send loop terminated: %v", err) + if sendLoopErr = c.sendLoop(); sendLoopErr != nil { + slog.Warn("Send loop terminated with error", "err", sendLoopErr) + } else { + slog.Info("Send loop terminated") } }() @@ -467,16 +468,22 @@ func (c *Conn) loop(ctx context.Context) { defer close(c.closeChan) // tell send loop to exit defer wg.Done() - var err error if c.debugCloseRecvLoop { - err = errors.New("DEBUG: close recv loop") + recvLoopErr = errors.New("DEBUG: close recv loop") } else { - err = c.recvLoop(c.conn) + recvLoopErr = c.recvLoop(c.conn) } - if err != io.EOF || c.logInfo { - c.logger.Printf("recv loop terminated: %v", err) + + switch { + case errors.Is(recvLoopErr, io.EOF): + slog.Info("recv 
loop terminated") + default: + slog.Warn("recv loop terminated with error", "err", recvLoopErr) } - if err == nil { + if recvLoopErr == nil { panic("zk: recvLoop should never return nil error") } }() @@ -494,8 +501,14 @@ func (c *Conn) loop(ctx context.Context) { default: } - if err != ErrSessionExpired { - err = ErrConnectionClosed + // Surface an error that contains all errors that could have caused the connection to get closed + err := errors.Join(authErr, sendLoopErr, recvLoopErr) + if !errors.Is(err, ErrSessionExpired) && !errors.Is(err, ErrConnectionClosed) { + // Always default to ErrConnectionClosed for any error that doesn't already have it or + // ErrSessionExpired as a cause, makes error handling more straightforward. Note that by definition, + // reaching this point in the code means the connection was closed, hence using it as the default + // value. + err = errors.Join(ErrConnectionClosed, err) } c.flushRequests(err) @@ -530,29 +543,55 @@ func (c *Conn) flushRequests(err error) { c.requestsLock.Unlock() } +var eventWatchTypes = map[EventType][]watchType{ + EventNodeCreated: {watchTypeExist, watchTypePersistent, watchTypePersistentRecursive}, + EventNodeDataChanged: {watchTypeExist, watchTypeData, watchTypePersistent, watchTypePersistentRecursive}, + EventNodeChildrenChanged: {watchTypeChild, watchTypePersistent}, + EventNodeDeleted: {watchTypeExist, watchTypeData, watchTypeChild, watchTypePersistent, watchTypePersistentRecursive}, + EventPingReceived: nil, +} +var persistentWatchTypes = []watchType{watchTypePersistent, watchTypePersistentRecursive} + // Send event to all interested watchers func (c *Conn) notifyWatches(ev Event) { - var wTypes []watchType - switch ev.Type { - case EventNodeCreated: - wTypes = []watchType{watchTypeExist} - case EventNodeDataChanged: - wTypes = []watchType{watchTypeExist, watchTypeData} - case EventNodeChildrenChanged: - wTypes = []watchType{watchTypeChild} - case EventNodeDeleted: - wTypes = []watchType{watchTypeExist, watchTypeData, watchTypeChild} + wTypes, ok := eventWatchTypes[ev.Type] + if !ok { + return } + c.watchersLock.Lock() defer c.watchersLock.Unlock() - for _, t := range wTypes { - wpt := watchPathType{ev.Path, t} - if watchers := c.watchers[wpt]; len(watchers) > 0 { - for _, ch := range watchers { - ch <- ev - close(ch) + + if ev.Type == EventPingReceived { + for wpt, watchers := range c.watchers { + if wpt.wType.isPersistent() { + for _, ch := range watchers { + ch.Push(ev) + } + } + } + } else { + broadcast := func(wpt watchPathType) { + for _, ch := range c.watchers[wpt] { + ch.Push(ev) + if !wpt.wType.isPersistent() { + ch.Close() + delete(c.watchers, wpt) + } + } + } + + for _, t := range wTypes { + if t == watchTypePersistentRecursive { + for p := ev.Path; ; p, _ = SplitPath(p) { + broadcast(watchPathType{p, t}) + if p == "/" { + break + } + } + } else { + broadcast(watchPathType{ev.Path, t}) } - delete(c.watchers, wpt) } } } @@ -562,16 +601,23 @@ func (c *Conn) invalidateWatches(err error) { c.watchersLock.Lock() defer c.watchersLock.Unlock() - if len(c.watchers) >= 0 { + if len(c.watchers) > 0 { for pathType, watchers := range c.watchers { + if errors.Is(err, ErrSessionExpired) && pathType.wType.isPersistent() { + // Ignore ErrSessionExpired for persistent watchers as the client will either automatically reconnect, + // or this is a shutdown-worthy error in which case there will be a followup invocation of this method + // with ErrClosing + continue + } + ev := Event{Type:
EventNotWatching, State: StateDisconnected, Path: pathType.path, Err: err} c.sendEvent(ev) // also publish globally for _, ch := range watchers { - ch <- ev - close(ch) + ch.Push(ev) + ch.Close() } + delete(c.watchers, pathType) } - c.watchers = make(map[watchPathType][]chan Event) } } @@ -583,6 +629,8 @@ func (c *Conn) sendSetWatches() { return } + slog.Info("Resetting watches after reconnect", "watchCount", len(c.watchers)) + // NB: A ZK server, by default, rejects packets >1mb. So, if we have too // many watches to reset, we need to break this up into multiple packets // to avoid hitting that limit. Mirroring the Java client behavior: we are @@ -610,12 +658,7 @@ func (c *Conn) sendSetWatches() { reqs = append(reqs, req) } sizeSoFar = 28 // fixed overhead of a set-watches packet - req = &setWatchesRequest{ - RelativeZxid: c.lastZxid, - DataWatches: make([]string, 0), - ExistWatches: make([]string, 0), - ChildWatches: make([]string, 0), - } + req = &setWatchesRequest{RelativeZxid: c.lastZxid} } sizeSoFar += addlLen switch pathType.wType { @@ -625,6 +668,10 @@ func (c *Conn) sendSetWatches() { req.ExistWatches = append(req.ExistWatches, pathType.path) case watchTypeChild: req.ChildWatches = append(req.ChildWatches, pathType.path) + case watchTypePersistent: + req.PersistentWatches = append(req.PersistentWatches, pathType.path) + case watchTypePersistentRecursive: + req.PersistentRecursiveWatches = append(req.PersistentRecursiveWatches, pathType.path) } n++ } @@ -646,9 +693,39 @@ func (c *Conn) sendSetWatches() { // aren't failure modes where a blocking write to the channel of requests // could hang indefinitely and cause this goroutine to leak... for _, req := range reqs { - _, err := c.request(opSetWatches, req, res, nil) + var op int32 = opSetWatches + if len(req.PersistentWatches) > 0 || len(req.PersistentRecursiveWatches) > 0 { + // to maintain compatibility with older servers, only send opSetWatches2 if persistent watches are used + op = opSetWatches2 + } + + _, err := c.request(op, req, res, func(r *request, header *responseHeader, err error) { + if err == nil && op == opSetWatches2 { + // If the setWatches was successful, notify the persistent watchers they've been reconnected. + // Because we process responses in one routine, we know that the following will execute before + // subsequent responses are processed. This means we won't end up in a situation where events are + // sent to watchers before the reconnect event is sent. 
+ c.watchersLock.Lock() + defer c.watchersLock.Unlock() + for _, wt := range persistentWatchTypes { + var paths []string + if wt == watchTypePersistent { + paths = req.PersistentWatches + } else { + paths = req.PersistentRecursiveWatches + } + for _, p := range paths { + e := Event{Type: EventWatching, State: StateConnected, Path: p} + c.sendEvent(e) // also publish globally + for _, ch := range c.watchers[watchPathType{path: p, wType: wt}] { + ch.Push(e) + } + } + } + } + }) if err != nil { - c.logger.Printf("Failed to set previous watches: %v", err) + slog.Warn("Failed to set previous watches", "err", err) break } } @@ -662,9 +739,10 @@ func (c *Conn) authenticate() error { n, err := encodePacket(buf[4:], &connectRequest{ ProtocolVersion: protocolVersion, LastZxidSeen: c.lastZxid, - TimeOut: c.sessionTimeoutMs, - SessionID: c.SessionID(), - Passwd: c.passwd, + // The timeout in the connect request is milliseconds + TimeOut: int32(c.sessionTimeout / time.Millisecond), + SessionID: c.SessionID(), + Passwd: c.passwd, }) if err != nil { return err @@ -711,7 +789,7 @@ func (c *Conn) authenticate() error { } atomic.StoreInt64(&c.sessionID, r.SessionID) - c.setTimeouts(r.TimeOut) + c.setTimeouts(time.Duration(r.TimeOut) * time.Millisecond) c.passwd = r.Passwd c.setState(StateHasSession) @@ -784,6 +862,7 @@ func (c *Conn) sendLoop() error { c.conn.Close() return err } + c.metricReceiver.PingSent() case <-c.closeChan: return nil } @@ -799,17 +878,20 @@ func (c *Conn) recvLoop(conn net.Conn) error { for { // package length if err := conn.SetReadDeadline(time.Now().Add(c.recvTimeout)); err != nil { - c.logger.Printf("failed to set connection deadline: %v", err) + slog.Warn("failed to set connection deadline", "err", err) } _, err := io.ReadFull(conn, buf[:4]) if err != nil { - return fmt.Errorf("failed to read from connection: %v", err) + return fmt.Errorf("failed to read from connection: %w", err) } blen := int(binary.BigEndian.Uint32(buf[:4])) if cap(buf) < blen { if c.maxBufferSize > 0 && blen > c.maxBufferSize { - return fmt.Errorf("received packet from server with length %d, which exceeds max buffer size %d", blen, c.maxBufferSize) + return fmt.Errorf( + "%w (packet size %d, max buffer size: %d)", + ErrResponseBufferSizeExceeded, blen, c.maxBufferSize, + ) } buf = make([]byte, blen) } @@ -827,23 +909,31 @@ func (c *Conn) recvLoop(conn net.Conn) error { } if res.Xid == -1 { - res := &watcherEvent{} - _, err = decodePacket(buf[16:blen], res) + we := &watcherEvent{} + _, err = decodePacket(buf[16:blen], we) if err != nil { return err } ev := Event{ - Type: res.Type, - State: res.State, - Path: res.Path, - Err: nil, + Type: we.Type, + State: we.State, + Path: we.Path, + Err: nil, + Timestamp: time.Now(), } c.sendEvent(ev) c.notifyWatches(ev) } else if res.Xid == -2 { // Ping response. Ignore. 
+ c.metricReceiver.PongReceived() + c.notifyWatches(Event{ + Type: EventPingReceived, + State: StateHasSession, + Zxid: res.Zxid, + Timestamp: time.Now(), + }) } else if res.Xid < 0 { - c.logger.Printf("Xid < 0 (%d) but not ping or watcher event", res.Xid) + slog.Warn("Xid < 0 but not ping or watcher event", "Xid", res.Xid) } else { if res.Zxid > 0 { c.lastZxid = res.Zxid @@ -857,7 +947,7 @@ func (c *Conn) recvLoop(conn net.Conn) error { c.requestsLock.Unlock() if !ok { - c.logger.Printf("Response for unknown request with xid %d", res.Xid) + slog.Warn("Ignoring response for unknown request", "xid", res.Xid) } else { if res.Err != 0 { err = res.Err.toError() @@ -880,14 +970,15 @@ func (c *Conn) nextXid() int32 { return int32(atomic.AddUint32(&c.xid, 1) & 0x7fffffff) } -func (c *Conn) addWatcher(path string, watchType watchType) <-chan Event { +func (c *Conn) addWatcher(path string, watchType watchType, ch EventQueue) { c.watchersLock.Lock() defer c.watchersLock.Unlock() - ch := make(chan Event, 1) wpt := watchPathType{path, watchType} c.watchers[wpt] = append(c.watchers[wpt], ch) - return ch + if watchType.isPersistent() { + ch.Push(Event{Type: EventWatching, State: StateConnected, Path: path}) + } } func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) <-chan response { @@ -906,7 +997,7 @@ func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recv select { case c.sendChan <- rq: case <-time.After(c.connectTimeout * 2): - c.logger.Printf("gave up trying to send opClose to server") + slog.Warn("gave up trying to send opClose to server") rq.recvChan <- response{-1, ErrConnectionClosed} } default: @@ -928,7 +1019,12 @@ func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recv return rq.recvChan } -func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (int64, error) { +func (c *Conn) request(opcode int32, req interface{}, res interface{}, recvFunc func(*request, *responseHeader, error)) (_ int64, err error) { + start := time.Now() + defer func() { + c.metricReceiver.RequestCompleted(time.Now().Sub(start), err) + }() + recv := c.queueRequest(opcode, req, res, recvFunc) select { case r := <-recv: @@ -978,10 +1074,10 @@ func (c *Conn) Children(path string) ([]string, *Stat, error) { res := &getChildren2Response{} _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return nil, nil, err } - return res.Children, &res.Stat, err + return res.Children, res.Stat, err } // ChildrenW returns the children of a znode and sets a watch. @@ -990,17 +1086,18 @@ func (c *Conn) ChildrenW(path string) ([]string, *Stat, <-chan Event, error) { return nil, nil, nil, err } - var ech <-chan Event + var ech ChanQueue[Event] res := &getChildren2Response{} _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { if err == nil { - ech = c.addWatcher(path, watchTypeChild) + ech = newChanEventChannel() + c.addWatcher(path, watchTypeChild, ech) } }) if err != nil { return nil, nil, nil, err } - return res.Children, &res.Stat, ech, err + return res.Children, res.Stat, ech, err } // Get gets the contents of a znode. 
@@ -1011,10 +1108,10 @@ func (c *Conn) Get(path string) ([]byte, *Stat, error) { res := &getDataResponse{} _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: false}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return nil, nil, err } - return res.Data, &res.Stat, err + return res.Data, res.Stat, err } // GetW returns the contents of a znode and sets a watch @@ -1023,17 +1120,18 @@ func (c *Conn) GetW(path string) ([]byte, *Stat, <-chan Event, error) { return nil, nil, nil, err } - var ech <-chan Event + var ech ChanQueue[Event] res := &getDataResponse{} _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { if err == nil { - ech = c.addWatcher(path, watchTypeData) + ech = newChanEventChannel() + c.addWatcher(path, watchTypeData, ech) } }) if err != nil { return nil, nil, nil, err } - return res.Data, &res.Stat, ech, err + return res.Data, res.Stat, ech, err } // Set updates the contents of a znode. @@ -1044,13 +1142,13 @@ func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) { res := &setDataResponse{} _, err := c.request(opSetData, &SetDataRequest{path, data, version}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return nil, err } - return &res.Stat, err + return res.Stat, err } -// Create creates a znode. +// Create creates a znode. If acl is empty, it uses the global WorldACL with PermAll // The returned path is the new path assigned by the server, it may not be the // same as the input, for example when creating a sequence znode the returned path // will be the input path with a sequence number appended. @@ -1059,14 +1157,36 @@ func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, return "", err } + if len(acl) == 0 { + acl = WorldACL(PermAll) + } + res := &createResponse{} _, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return "", err } return res.Path, err } +// CreateAndReturnStat is the equivalent of Create, but it also returns the Stat of the created node. +func (c *Conn) CreateAndReturnStat(path string, data []byte, flags int32, acl []ACL) (string, *Stat, error) { + if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil { + return "", nil, err + } + + if len(acl) == 0 { + acl = WorldACL(PermAll) + } + + res := &create2Response{} + _, err := c.request(opCreate2, &CreateRequest{path, data, acl, flags}, res, nil) + if errors.Is(err, ErrConnectionClosed) { + return "", nil, err + } + return res.Path, res.Stat, err +} + // CreateContainer creates a container znode and returns the path. func (c *Conn) CreateContainer(path string, data []byte, flags int32, acl []ACL) (string, error) { if err := validatePath(path, flags&FlagSequence == FlagSequence); err != nil { @@ -1119,10 +1239,10 @@ func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl var newPath string for i := 0; i < 3; i++ { newPath, err = c.Create(protectedPath, data, FlagEphemeral|FlagSequence, acl) - switch err { - case ErrSessionExpired: + switch { + case errors.Is(err, ErrSessionExpired): // No need to search for the node since it can't exist. Just try again. 
- case ErrConnectionClosed: + case errors.Is(err, ErrConnectionClosed): children, _, err := c.Children(rootPath) if err != nil { return "", err @@ -1135,7 +1255,7 @@ func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl } } } - case nil: + case err == nil: return newPath, nil default: return "", err @@ -1162,15 +1282,15 @@ func (c *Conn) Exists(path string) (bool, *Stat, error) { res := &existsResponse{} _, err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return false, nil, err } exists := true - if err == ErrNoNode { + if errors.Is(err, ErrNoNode) { exists = false err = nil } - return exists, &res.Stat, err + return exists, res.Stat, err } // ExistsW tells the existence of a znode and sets a watch. @@ -1179,24 +1299,25 @@ func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) { return false, nil, nil, err } - var ech <-chan Event + var ech ChanQueue[Event] res := &existsResponse{} _, err := c.request(opExists, &existsRequest{Path: path, Watch: true}, res, func(req *request, res *responseHeader, err error) { + ech = newChanEventChannel() if err == nil { - ech = c.addWatcher(path, watchTypeData) - } else if err == ErrNoNode { - ech = c.addWatcher(path, watchTypeExist) + c.addWatcher(path, watchTypeData, ech) + } else if errors.Is(err, ErrNoNode) { + c.addWatcher(path, watchTypeExist, ech) } }) exists := true - if err == ErrNoNode { + if errors.Is(err, ErrNoNode) { exists = false err = nil } if err != nil { return false, nil, nil, err } - return exists, &res.Stat, ech, err + return exists, res.Stat, ech, err } // GetACL gets the ACLs of a znode. @@ -1207,10 +1328,10 @@ func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) { res := &getAclResponse{} _, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return nil, nil, err } - return res.Acl, &res.Stat, err + return res.Acl, res.Stat, err } // SetACL updates the ACLs of a znode. @@ -1221,10 +1342,10 @@ func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) { res := &setAclResponse{} _, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return nil, err } - return &res.Stat, err + return res.Stat, err } // Sync flushes the channel between process and the leader of a given znode, @@ -1237,7 +1358,7 @@ func (c *Conn) Sync(path string) (string, error) { res := &syncResponse{} _, err := c.request(opSync, &syncRequest{Path: path}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return "", err } return res.Path, err @@ -1255,8 +1376,7 @@ type MultiResponse struct { // *CheckVersionRequest. 
func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { req := &multiRequest{ - Ops: make([]multiRequestOp, 0, len(ops)), - DoneHeader: multiHeader{Type: -1, Done: true, Err: -1}, + Ops: make([]multiRequestOp, 0, len(ops)), } for _, op := range ops { var opCode int32 @@ -1276,7 +1396,7 @@ func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { } res := &multiResponse{} _, err := c.request(opMulti, req, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return nil, err } mr := make([]MultiResponse, len(res.Ops)) @@ -1286,6 +1406,44 @@ func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { return mr, err } +// MultiRead executes multiple ZooKeeper read operations at once. The provided ops must be one of GetDataOp or +// GetChildrenOp. Returns an error on network or connectivity errors, not on any op errors such as ErrNoNode. To check +// if any ops failed, check the corresponding MultiReadResponse.Err. +func (c *Conn) MultiRead(ops ...ReadOp) ([]MultiReadResponse, error) { + req := &multiRequest{ + Ops: make([]multiRequestOp, len(ops)), + } + for i, op := range ops { + req.Ops[i] = multiRequestOp{ + Header: multiHeader{op.opCode(), false, -1}, + Op: pathWatchRequest{Path: op.GetPath()}, + } + } + res := &multiReadResponse{} + _, err := c.request(opMultiRead, req, res, nil) + return res.OpResults, err +} + +// GetDataAndChildren executes a multi-read to get the given node's data and its children in one call. +func (c *Conn) GetDataAndChildren(path string) ([]byte, *Stat, []string, error) { + if err := validatePath(path, false); err != nil { + return nil, nil, nil, err + } + + opResults, err := c.MultiRead(GetDataOp(path), GetChildrenOp(path)) + if err != nil { + return nil, nil, nil, err + } + + for _, r := range opResults { + if r.Err != nil { + return nil, nil, nil, r.Err + } + } + + return opResults[0].Data, opResults[0].Stat, opResults[1].Children, nil +} + // IncrementalReconfig is the zookeeper reconfiguration api that allows adding and removing servers // by lists of members. For more info refer to the ZK documentation. // @@ -1321,7 +1479,7 @@ func (c *Conn) Reconfig(members []string, version int64) (*Stat, error) { func (c *Conn) internalReconfig(request *reconfigRequest) (*Stat, error) { response := &reconfigReponse{} _, err := c.request(opReconfig, request, response, nil) - return &response.Stat, err + return response.Stat, err } // Server returns the current or last-connected server name. 
@@ -1331,6 +1489,97 @@ func (c *Conn) Server() string { return c.server } +func (c *Conn) AddPersistentWatch(path string, mode AddWatchMode) (ch EventQueue, err error) { + if err = validatePath(path, false); err != nil { + return nil, err + } + + res := &addWatchResponse{} + _, err = c.request(opAddWatch, &addWatchRequest{Path: path, Mode: mode}, res, func(r *request, header *responseHeader, err error) { + if err == nil { + var wt watchType + if mode == AddWatchModePersistent { + wt = watchTypePersistent + } else { + wt = watchTypePersistentRecursive + } + + ch = NewUnlimitedQueue[Event]() + c.addWatcher(path, wt, ch) + } + }) + if errors.Is(err, ErrConnectionClosed) { + return nil, err + } + return ch, err +} + +func (c *Conn) RemovePersistentWatch(path string, ch EventQueue) (err error) { + if err = validatePath(path, false); err != nil { + return err + } + + deleted := false + + res := &checkWatchesResponse{} + _, err = c.request(opCheckWatches, &checkWatchesRequest{Path: path, Type: WatcherTypeAny}, res, func(r *request, header *responseHeader, err error) { + if err != nil { + return + } + + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + + for _, wt := range persistentWatchTypes { + wpt := watchPathType{path: path, wType: wt} + for i, w := range c.watchers[wpt] { + if w == ch { + deleted = true + c.watchers[wpt] = append(c.watchers[wpt][:i], c.watchers[wpt][i+1:]...) + w.Push(Event{Type: EventNotWatching, State: c.State(), Path: path, Err: ErrNoWatcher}) + w.Close() + return + } + } + } + }) + + if err != nil { + return err + } + + if !deleted { + return ErrNoWatcher + } + + return nil +} + +func (c *Conn) RemoveAllPersistentWatches(path string) (err error) { + if err = validatePath(path, false); err != nil { + return err + } + + res := &checkWatchesResponse{} + _, err = c.request(opRemoveWatches, &checkWatchesRequest{Path: path, Type: WatcherTypeAny}, res, func(r *request, header *responseHeader, err error) { + if err != nil { + return + } + + c.watchersLock.Lock() + defer c.watchersLock.Unlock() + for _, wt := range persistentWatchTypes { + wpt := watchPathType{path: path, wType: wt} + for _, ch := range c.watchers[wpt] { + ch.Push(Event{Type: EventNotWatching, State: c.State(), Path: path, Err: ErrNoWatcher}) + ch.Close() + } + delete(c.watchers, wpt) + } + }) + return err +} + func resendZkAuth(ctx context.Context, c *Conn) error { shouldCancel := func() bool { select { @@ -1346,9 +1595,7 @@ func resendZkAuth(ctx context.Context, c *Conn) error { c.credsMu.Lock() defer c.credsMu.Unlock() - if c.logInfo { - c.logger.Printf("re-submitting `%d` credentials after reconnect", len(c.creds)) - } + slog.Info("re-submitting credentials after reconnect", "credentialCount", len(c.creds)) for _, cred := range c.creds { // return early before attempting to send request. 
@@ -1367,25 +1614,51 @@ func resendZkAuth(ctx context.Context, c *Conn) error { nil, /* recvFunc*/ ) if err != nil { - return fmt.Errorf("failed to send auth request: %v", err) + return fmt.Errorf("failed to send auth request: %w", err) } var res response select { case res = <-resChan: case <-c.closeChan: - c.logger.Printf("recv closed, cancel re-submitting credentials") + slog.Info("recv closed, cancel re-submitting credentials") return nil case <-c.shouldQuit: - c.logger.Printf("should quit, cancel re-submitting credentials") + slog.Warn("Connection closing, cancel re-submitting credentials") return nil case <-ctx.Done(): return ctx.Err() } if res.err != nil { - return fmt.Errorf("failed connection setAuth request: %v", res.err) + return fmt.Errorf("failed connection setAuth request: %w", res.err) } } return nil } + +func JoinPath(parent, child string) string { + if !strings.HasSuffix(parent, "/") { + parent += "/" + } + if strings.HasPrefix(child, "/") { + child = child[1:] + } + return parent + child +} + +func SplitPath(path string) (dir, name string) { + i := strings.LastIndex(path, "/") + if i == 0 { + dir, name = "/", path[1:] + } else { + dir, name = path[:i], path[i+1:] + } + return dir, name +} + +// IsConnClosing checks if the given error returned by a Conn function is due to Close having been +// called on the Conn, which is fatal in that no subsequent Conn operations will succeed. +func IsConnClosing(err error) bool { + return errors.Is(err, ErrClosing) +} diff --git a/conn_test.go b/conn_test.go index f0c4e11e..a6df4e81 100644 --- a/conn_test.go +++ b/conn_test.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io/ioutil" + "strings" "sync" "testing" "time" @@ -92,7 +93,6 @@ func TestDeadlockInClose(t *testing.T) { shouldQuit: make(chan struct{}), connectTimeout: 1 * time.Second, sendChan: make(chan *request, sendChanSize), - logger: DefaultLogger, } for i := 0; i < sendChanSize; i++ { @@ -113,85 +113,144 @@ func TestDeadlockInClose(t *testing.T) { } func TestNotifyWatches(t *testing.T) { + queueImpls := []struct { + name string + new func() EventQueue + }{ + { + name: "chan", + new: func() EventQueue { return newChanEventChannel() }, + }, + { + name: "unlimited", + new: func() EventQueue { return NewUnlimitedQueue[Event]() }, + }, + } + cases := []struct { eType EventType path string watches map[watchPathType]bool }{ { - EventNodeCreated, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: true, - {"/", watchTypeChild}: false, - {"/", watchTypeData}: false, - }, - }, - { - EventNodeCreated, "/a", - map[watchPathType]bool{ + eType: EventNodeCreated, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: true, {"/b", watchTypeExist}: false, + + {"/a", watchTypeChild}: false, + + {"/a", watchTypeData}: false, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: true, + {"/", watchTypePersistentRecursive}: true, }, }, { - EventNodeDataChanged, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: true, - {"/", watchTypeData}: true, - {"/", watchTypeChild}: false, + eType: EventNodeDataChanged, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: true, + {"/a", watchTypeData}: true, + {"/a", watchTypeChild}: false, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: true, + {"/", watchTypePersistentRecursive}: true, }, }, { - EventNodeChildrenChanged, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: 
false, - {"/", watchTypeData}: false, - {"/", watchTypeChild}: true, + eType: EventNodeChildrenChanged, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: false, + {"/a", watchTypeData}: false, + {"/a", watchTypeChild}: true, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: false, + {"/", watchTypePersistentRecursive}: false, }, }, { - EventNodeDeleted, "/", - map[watchPathType]bool{ - {"/", watchTypeExist}: true, - {"/", watchTypeData}: true, - {"/", watchTypeChild}: true, + eType: EventNodeDeleted, + path: "/a", + watches: map[watchPathType]bool{ + {"/a", watchTypeExist}: true, + {"/a", watchTypeData}: true, + {"/a", watchTypeChild}: true, + + {"/a", watchTypePersistent}: true, + {"/", watchTypePersistent}: false, + + {"/a", watchTypePersistentRecursive}: true, + {"/", watchTypePersistentRecursive}: true, }, }, } - conn := &Conn{watchers: make(map[watchPathType][]chan Event)} - - for idx, c := range cases { - t.Run(fmt.Sprintf("#%d %s", idx, c.eType), func(t *testing.T) { - c := c - - notifications := make([]struct { - path string - notify bool - ch <-chan Event - }, len(c.watches)) - - var idx int - for wpt, expectEvent := range c.watches { - ch := conn.addWatcher(wpt.path, wpt.wType) - notifications[idx].path = wpt.path - notifications[idx].notify = expectEvent - notifications[idx].ch = ch - idx++ - } - ev := Event{Type: c.eType, Path: c.path} - conn.notifyWatches(ev) - - for _, res := range notifications { - select { - case e := <-res.ch: - if !res.notify || e.Path != res.path { - t.Fatal("unexpeted notification received") + for _, impl := range queueImpls { + t.Run(impl.name, func(t *testing.T) { + for idx, c := range cases { + c := c + t.Run(fmt.Sprintf("#%d %s", idx, c.eType), func(t *testing.T) { + notifications := make([]struct { + watchPathType + notify bool + ch EventQueue + }, len(c.watches)) + + conn := &Conn{watchers: make(map[watchPathType][]EventQueue)} + + var idx int + for wpt, expectEvent := range c.watches { + notifications[idx].watchPathType = wpt + notifications[idx].notify = expectEvent + ch := impl.new() + conn.addWatcher(wpt.path, wpt.wType, ch) + notifications[idx].ch = ch + if wpt.wType.isPersistent() { + e, _ := ch.Next(context.Background()) + if e.Type != EventWatching { + t.Fatalf("First event on persistent watcher should always be EventWatching") + } + } + idx++ } - default: - if res.notify { - t.Fatal("expected notification not received") + + conn.notifyWatches(Event{Type: c.eType, Path: c.path}) + + for _, res := range notifications { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + t.Cleanup(cancel) + + e, err := res.ch.Next(ctx) + if err == nil { + isPathCorrect := + (res.wType == watchTypePersistentRecursive && strings.HasPrefix(e.Path, res.path)) || + e.Path == res.path + if !res.notify || !isPathCorrect { + t.Logf("unexpected notification received by %+v: %+v", res, e) + t.Fail() + } + } else { + if res.notify { + t.Logf("expected notification not received for %+v", res) + t.Fail() + } + } } - } + }) } }) } diff --git a/constants.go b/constants.go index 84455d2b..d2e44e04 100644 --- a/constants.go +++ b/constants.go @@ -26,12 +26,18 @@ const ( opGetChildren2 = 12 opCheck = 13 opMulti = 14 + opCreate2 = 15 opReconfig = 16 + opCheckWatches = 17 + opRemoveWatches = 18 opCreateContainer = 19 opCreateTTL = 21 + opMultiRead = 22 opClose = -11 opSetAuth
= 100 opSetWatches = 101 + opSetWatches2 = 105 + opAddWatch = 106 opError = -1 // Not in protocol, used internally opWatcherEvent = -2 @@ -45,8 +51,10 @@ const ( EventNodeChildrenChanged EventType = 4 // EventSession represents a session event. - EventSession EventType = -1 - EventNotWatching EventType = -2 + EventSession EventType = -1 + EventNotWatching EventType = -2 + EventWatching EventType = -3 + EventPingReceived EventType = -4 ) var ( @@ -57,6 +65,8 @@ var ( EventNodeChildrenChanged: "EventNodeChildrenChanged", EventSession: "EventSession", EventNotWatching: "EventNotWatching", + EventWatching: "EventWatching", + EventPingReceived: "EventPingReceived", } ) @@ -111,24 +121,27 @@ type ErrCode int32 var ( // ErrConnectionClosed means the connection has been closed. - ErrConnectionClosed = errors.New("zk: connection closed") - ErrUnknown = errors.New("zk: unknown error") - ErrAPIError = errors.New("zk: api error") - ErrNoNode = errors.New("zk: node does not exist") - ErrNoAuth = errors.New("zk: not authenticated") - ErrBadVersion = errors.New("zk: version conflict") - ErrNoChildrenForEphemerals = errors.New("zk: ephemeral nodes may not have children") - ErrNodeExists = errors.New("zk: node already exists") - ErrNotEmpty = errors.New("zk: node has children") - ErrSessionExpired = errors.New("zk: session has been expired by the server") - ErrInvalidACL = errors.New("zk: invalid ACL specified") - ErrInvalidFlags = errors.New("zk: invalid flags specified") - ErrAuthFailed = errors.New("zk: client authentication failed") - ErrClosing = errors.New("zk: zookeeper is closing") - ErrNothing = errors.New("zk: no server responses to process") - ErrSessionMoved = errors.New("zk: session moved to another server, so operation is ignored") - ErrReconfigDisabled = errors.New("attempts to perform a reconfiguration operation when reconfiguration feature is disabled") - ErrBadArguments = errors.New("invalid arguments") + ErrConnectionClosed = errors.New("zk: connection closed") + ErrUnknown = errors.New("zk: unknown error") + ErrAPIError = errors.New("zk: api error") + ErrNoNode = errors.New("zk: node does not exist") + ErrNoAuth = errors.New("zk: not authenticated") + ErrBadVersion = errors.New("zk: version conflict") + ErrNoChildrenForEphemerals = errors.New("zk: ephemeral nodes may not have children") + ErrNodeExists = errors.New("zk: node already exists") + ErrNotEmpty = errors.New("zk: node has children") + ErrSessionExpired = errors.New("zk: session has been expired by the server") + ErrInvalidACL = errors.New("zk: invalid ACL specified") + ErrInvalidFlags = errors.New("zk: invalid flags specified") + ErrAuthFailed = errors.New("zk: client authentication failed") + ErrClosing = errors.New("zk: zookeeper is closing") + ErrNothing = errors.New("zk: no server responses to process") + ErrSessionMoved = errors.New("zk: session moved to another server, so operation is ignored") + ErrReconfigDisabled = errors.New("attempts to perform a reconfiguration operation when reconfiguration feature is disabled") + ErrBadArguments = errors.New("invalid arguments") + ErrNoWatcher = errors.New("zk: no such watcher") + ErrUnimplemented = errors.New("zk: Not implemented") + ErrResponseBufferSizeExceeded = errors.New("zk: server response exceeds max buffer size") // ErrInvalidCallback = errors.New("zk: invalid callback specified") errCodeToError = map[ErrCode]error{ @@ -147,8 +160,10 @@ var ( errClosing: ErrClosing, errNothing: ErrNothing, errSessionMoved: ErrSessionMoved, + errNoWatcher: ErrNoWatcher, 
errZReconfigDisabled: ErrReconfigDisabled, errBadArguments: ErrBadArguments, + errUnimplemented: ErrUnimplemented, } ) @@ -186,6 +201,7 @@ const ( errClosing ErrCode = -116 errNothing ErrCode = -117 errSessionMoved ErrCode = -118 + errNoWatcher ErrCode = -121 // Attempts to perform a reconfiguration operation when reconfiguration feature is disabled errZReconfigDisabled ErrCode = -123 ) @@ -224,6 +240,7 @@ var ( opClose: "close", opSetAuth: "setAuth", opSetWatches: "setWatches", + opAddWatch: "addWatch", opWatcherEvent: "watcherEvent", } @@ -263,3 +280,33 @@ var ( ModeStandalone: "standalone", } ) + +// AddWatchMode specifies the mode of a persistent watch registered with Conn.AddPersistentWatch. +type AddWatchMode int32 + +func (m AddWatchMode) String() string { + if name, ok := addWatchModeNames[m]; ok { + return name + } + return "unknown" +} + +const ( + AddWatchModePersistent AddWatchMode = iota + AddWatchModePersistentRecursive AddWatchMode = iota +) + +var ( + addWatchModeNames = map[AddWatchMode]string{ + AddWatchModePersistent: "persistent", + AddWatchModePersistentRecursive: "persistentRecursive", + } +) + +type WatcherType int32 + +const ( + WatcherTypeChildren = WatcherType(1) + WatcherTypeData = WatcherType(2) + WatcherTypeAny = WatcherType(3) +) diff --git a/dnshostprovider.go b/dnshostprovider.go index f4bba8d0..3dd74c87 100644 --- a/dnshostprovider.go +++ b/dnshostprovider.go @@ -6,10 +6,12 @@ import ( "sync" ) -// DNSHostProvider is the default HostProvider. It currently matches -// the Java StaticHostProvider, resolving hosts from DNS once during -// the call to Init. It could be easily extended to re-query DNS -// periodically or if there is trouble connecting. +// DNSHostProvider is a simple implementation of a HostProvider. It resolves the hosts once during +// Init, and iterates through the resolved addresses for every call to Next. Note that if the +// addresses that back the ZK hosts change, those changes will not be reflected. +// +// Deprecated: Because this HostProvider does not attempt to re-read from DNS, it can lead to issues +// if the addresses of the hosts change. It is preserved for backwards compatibility. type DNSHostProvider struct { mu sync.Mutex // Protects everything, so we can add asynchronous updates later. servers []string @@ -30,7 +32,7 @@ func (hp *DNSHostProvider) Init(servers []string) error { lookupHost = net.LookupHost } - found := []string{} + var found []string for _, server := range servers { host, port, err := net.SplitHostPort(server) if err != nil { @@ -46,43 +48,38 @@ func (hp *DNSHostProvider) Init(servers []string) error { } if len(found) == 0 { - return fmt.Errorf("No hosts found for addresses %q", servers) + return fmt.Errorf("zk: no hosts found for addresses %q", servers) } // Randomize the order of the servers to avoid creating hotspots - stringShuffle(found) + shuffleSlice(found) hp.servers = found - hp.curr = -1 - hp.last = -1 + hp.curr = 0 + hp.last = len(hp.servers) - 1 return nil } -// Len returns the number of servers available -func (hp *DNSHostProvider) Len() int { - hp.mu.Lock() - defer hp.mu.Unlock() - return len(hp.servers) -} - -// Next returns the next server to connect to. retryStart will be true -// if we've looped through all known servers without Connected() being -// called. +// Next returns the next server to connect to. retryStart should be true if this call to Next +// exhausted the list of known servers without Connected being called. If connecting to this final +// host fails, the connect loop will back off before invoking Next again for a fresh server.
func (hp *DNSHostProvider) Next() (server string, retryStart bool) { hp.mu.Lock() defer hp.mu.Unlock() - hp.curr = (hp.curr + 1) % len(hp.servers) retryStart = hp.curr == hp.last - if hp.last == -1 { - hp.last = 0 - } - return hp.servers[hp.curr], retryStart + server = hp.servers[hp.curr] + hp.curr = (hp.curr + 1) % len(hp.servers) + return server, retryStart } // Connected notifies the HostProvider of a successful connection. func (hp *DNSHostProvider) Connected() { hp.mu.Lock() defer hp.mu.Unlock() - hp.last = hp.curr + if hp.curr == 0 { + hp.last = len(hp.servers) - 1 + } else { + hp.last = hp.curr - 1 + } } diff --git a/dnshostprovider_test.go b/dnshostprovider_test.go index d31d34c3..f010eb4d 100644 --- a/dnshostprovider_test.go +++ b/dnshostprovider_test.go @@ -1,6 +1,7 @@ package zk import ( + "errors" "fmt" "log" "testing" @@ -33,7 +34,7 @@ func TestIntegration_DNSHostProviderCreate(t *testing.T) { path := "/gozk-test" - if err := zk.Delete(path, -1); err != nil && err != ErrNoNode { + if err := zk.Delete(path, -1); err != nil && !errors.Is(err, ErrNoNode) { t.Fatalf("Delete returned error: %+v", err) } if p, err := zk.Create(path, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll)); err != nil { @@ -68,7 +69,6 @@ func newLocalHostPortsFacade(inner HostProvider, ports []int) *localHostPortsFac } } -func (lhpf *localHostPortsFacade) Len() int { return lhpf.inner.Len() } func (lhpf *localHostPortsFacade) Connected() { lhpf.inner.Connected() } func (lhpf *localHostPortsFacade) Init(servers []string) error { return lhpf.inner.Init(servers) } func (lhpf *localHostPortsFacade) Next() (string, bool) { @@ -165,60 +165,78 @@ func TestIntegration_DNSHostProviderReconnect(t *testing.T) { } } -// TestDNSHostProviderRetryStart tests the `retryStart` functionality -// of DNSHostProvider. -// It's also probably the clearest visual explanation of exactly how -// it works. -func TestDNSHostProviderRetryStart(t *testing.T) { +// TestHostProvidersRetryStart tests the `retryStart` functionality of DNSHostProvider and +// StaticHostProvider. +// It's also probably the clearest visual explanation of exactly how it works. +func TestHostProvidersRetryStart(t *testing.T) { t.Parallel() - hp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) { - return []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}, nil - }} - - if err := hp.Init([]string{"foo.example.com:12345"}); err != nil { - t.Fatal(err) - } - - testdata := []struct { - retryStartWant bool - callConnected bool - }{ - // Repeated failures. - {false, false}, - {false, false}, - {false, false}, - {true, false}, - {false, false}, - {false, false}, - {true, true}, - - // One success offsets things. - {false, false}, - {false, true}, - {false, true}, - - // Repeated successes. - {false, true}, - {false, true}, - {false, true}, - {false, true}, - {false, true}, - - // And some more failures. - {false, false}, - {false, false}, - {true, false}, // Looped back to last known good server: all alternates failed. 
- {false, false}, - } - - for i, td := range testdata { - _, retryStartGot := hp.Next() - if retryStartGot != td.retryStartWant { - t.Errorf("%d: retryStart=%v; want %v", i, retryStartGot, td.retryStartWant) - } - if td.callConnected { - hp.Connected() - } + lookupHost := func(host string) ([]string, error) { + return []string{host}, nil + } + + providers := []HostProvider{ + &DNSHostProvider{ + lookupHost: lookupHost, + }, + &StaticHostProvider{ + lookupHost: lookupHost, + }, + } + + for _, hp := range providers { + t.Run(fmt.Sprintf("%T", hp), func(t *testing.T) { + if err := hp.Init([]string{"foo.com:2121", "bar.com:2121", "baz.com:2121"}); err != nil { + t.Fatal(err) + } + + testdata := []struct { + retryStartWant bool + callConnected bool + }{ + // Repeated failures. + {false, false}, + {false, false}, + {true, false}, + {false, false}, + {false, false}, + {true, false}, + {false, true}, + + // One success offsets things. + {false, false}, + {false, true}, + {false, true}, + + // Repeated successes. + {false, true}, + {false, true}, + {false, true}, + {false, true}, + {false, true}, + + // And some more failures. + {false, false}, + {false, false}, + {true, false}, // Looped back to last known good server: all alternates failed. + {false, false}, + {false, false}, + {true, false}, + {false, false}, + {false, false}, + {true, false}, + {false, false}, + } + + for i, td := range testdata { + _, retryStartGot := hp.Next() + if retryStartGot != td.retryStartWant { + t.Errorf("%d: retryStart=%v; want %v", i, retryStartGot, td.retryStartWant) + } + if td.callConnected { + hp.Connected() + } + } + }) } } diff --git a/go.mod b/go.mod index a2662730..5ca02c98 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/go-zookeeper/zk -go 1.13 +go 1.20 diff --git a/metrics.go b/metrics.go new file mode 100644 index 00000000..9294d3ed --- /dev/null +++ b/metrics.go @@ -0,0 +1,20 @@ +package zk + +import ( + "time" +) + +type MetricReceiver interface { + PingSent() + PongReceived() + RequestCompleted(duration time.Duration, err error) +} + +var _ MetricReceiver = UnimplementedMetricReceiver{} + +type UnimplementedMetricReceiver struct { +} + +func (u UnimplementedMetricReceiver) PingSent() {} +func (u UnimplementedMetricReceiver) PongReceived() {} +func (u UnimplementedMetricReceiver) RequestCompleted(time.Duration, error) {} diff --git a/server_help_test.go b/server_help_test.go index 8ae52d94..80990b4e 100644 --- a/server_help_test.go +++ b/server_help_test.go @@ -7,6 +7,8 @@ import ( "math/rand" "os" "path/filepath" + "runtime/debug" + "strconv" "strings" "testing" "time" @@ -34,6 +36,38 @@ type TestCluster struct { Servers []TestServer } +func WithTestCluster(t *testing.T, testTimeout time.Duration, f func(ts *TestCluster, zk *Conn)) { + ts, err := StartTestCluster(t, 1, nil, logWriter{t: t, p: "[ZKERR] "}) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + ts.Stop() + }) + zk, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + t.Cleanup(func() { + zk.Close() + }) + doneChan := make(chan struct{}) + go func() { + defer func() { + close(doneChan) + if r := recover(); r != nil { + t.Error(r, string(debug.Stack())) + } + }() + f(ts, zk) + }() + select { + case <-doneChan: + case <-time.After(testTimeout): + t.Fatalf("Test did not complete within timeout") + } +} + // TODO: pull this into its own package to allow for better isolation of integration tests vs. unit // testing. 
This should be used on CI systems and local only when needed whereas unit tests should remain // fast and not rely on external dependencies. @@ -55,7 +89,7 @@ func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCl } tmpPath, err := ioutil.TempDir("", "gozk") - requireNoError(t, err, "failed to create tmp dir for test server setup") + requireNoErrorf(t, err, "failed to create tmp dir for test server setup") success := false startPort := int(rand.Int31n(6000) + 10000) @@ -69,7 +103,7 @@ func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCl for serverN := 0; serverN < size; serverN++ { srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv%d", serverN+1)) - requireNoError(t, os.Mkdir(srvPath, 0700), "failed to make server path") + requireNoErrorf(t, os.Mkdir(srvPath, 0700), "failed to make server path") port := startPort + serverN*3 cfg := ServerConfig{ @@ -90,20 +124,20 @@ func StartTestCluster(t *testing.T, size int, stdout, stderr io.Writer) (*TestCl cfgPath := filepath.Join(srvPath, _testConfigName) fi, err := os.Create(cfgPath) - requireNoError(t, err) + requireNoErrorf(t, err) - requireNoError(t, cfg.Marshall(fi)) + requireNoErrorf(t, cfg.Marshall(fi)) fi.Close() fi, err = os.Create(filepath.Join(srvPath, _testMyIDFileName)) - requireNoError(t, err) + requireNoErrorf(t, err) _, err = fmt.Fprintf(fi, "%d\n", serverN+1) fi.Close() - requireNoError(t, err) + requireNoErrorf(t, err) srv, err := NewIntegrationTestServer(t, cfgPath, stdout, stderr) - requireNoError(t, err) + requireNoErrorf(t, err) if err := srv.Start(); err != nil { return nil, err @@ -139,7 +173,7 @@ func (tc *TestCluster) ConnectAllTimeout(sessionTimeout time.Duration) (*Conn, < return tc.ConnectWithOptions(sessionTimeout) } -func (tc *TestCluster) ConnectWithOptions(sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) { +func (tc *TestCluster) ConnectWithOptions(sessionTimeout time.Duration, options ...ConnOption) (*Conn, <-chan Event, error) { hosts := make([]string, len(tc.Servers)) for i, srv := range tc.Servers { hosts[i] = fmt.Sprintf("127.0.0.1:%d", srv.Port) @@ -253,9 +287,36 @@ func (tc *TestCluster) StopAllServers() error { return nil } -func requireNoError(t *testing.T, err error, msgAndArgs ...interface{}) { +func requireNoErrorf(t *testing.T, err error, msgAndArgs ...interface{}) { if err != nil { + t.Helper() t.Logf("received unexpected error: %v", err) - t.Fatal(msgAndArgs...) + t.Fatalf(msgAndArgs[0].(string), msgAndArgs[1:]...) + } +} + +func RequireMinimumZkVersion(t *testing.T, minimum string) { + if val, ok := os.LookupEnv("ZK_VERSION"); ok { + split := func(v string) (parts []int) { + for _, s := range strings.Split(v, ".") { + i, err := strconv.Atoi(s) + if err != nil { + t.Fatalf("invalid version segment: %q", s) + } + parts = append(parts, i) + } + return parts + } + + minimumV, actualV := split(minimum), split(val) + for i, p := range minimumV { + if actualV[i] < p { + if !strings.HasPrefix(val, minimum) { + t.Skipf("running with zookeeper that does not support this api (requires at least %s)", minimum) + } + } + } + } else { + t.Skip("did not detect zk_version from env. 
skipping test") } } diff --git a/staticdnshostprovider.go b/staticdnshostprovider.go new file mode 100644 index 00000000..3a227fd5 --- /dev/null +++ b/staticdnshostprovider.go @@ -0,0 +1,115 @@ +package zk + +import ( + "fmt" + "log/slog" + "math/rand" + "net" + "sync" +) + +type hostPort struct { + host, port string +} + +func (hp *hostPort) String() string { + return hp.host + ":" + hp.port +} + +// StaticHostProvider is the default HostProvider, and replaces the now deprecated DNSHostProvider. +// It will iterate through the ZK hosts on every call to Next, and return a random address selected +// from the resolved addresses of the ZK host (if the host is already an IP, it will return that +// directly). It is important to manually resolve and shuffle the addresses because the DNS record +// that backs a host may rarely (or never) change, so repeated calls to connect to this host may +// always connect to the same IP. This mode is the default mode, and matches the Java client's +// implementation. Note that if the host cannot be resolved, Next will return it directly, instead of +// an error. This will cause Dial to fail and the loop will move on to a new host. It is implemented +// as a pound-for-pound copy of the standard Java client's equivalent: +// https://github.com/linkedin/zookeeper/blob/629518b5ea2b26d88a9ec53d5a422afe9b12e452/zookeeper-server/src/main/java/org/apache/zookeeper/client/StaticHostProvider.java#L368 +type StaticHostProvider struct { + mu sync.Mutex // Protects everything, so we can add asynchronous updates later. + servers []hostPort + // nextServer is the index (in servers) of the next server that will be returned by Next. + nextServer int + // lastConnectedServer is the index (in servers) of the last server to which a successful connection + // was established. Used to track whether Next iterated through all available servers without + // successfully connecting. + lastConnectedServer int + lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing. +} + +func (shp *StaticHostProvider) Init(servers []string) error { + shp.mu.Lock() + defer shp.mu.Unlock() + + if shp.lookupHost == nil { + shp.lookupHost = net.LookupHost + } + + var found []hostPort + for _, server := range servers { + host, port, err := net.SplitHostPort(server) + if err != nil { + return err + } + // Perform the lookup to validate the initial set of hosts, but discard the results as the addresses + // will be resolved dynamically when Next is called. + _, err = shp.lookupHost(host) + if err != nil { + return err + } + + found = append(found, hostPort{host, port}) + } + + if len(found) == 0 { + return fmt.Errorf("zk: no hosts found for addresses %q", servers) + } + + // Randomize the order of the servers to avoid creating hotspots + shuffleSlice(found) + + shp.servers = found + shp.nextServer = 0 + shp.lastConnectedServer = len(shp.servers) - 1 + + return nil +} + +// Next returns the next server to connect to. retryStart should be true if this call to Next +// exhausted the list of known servers without Connected being called. If connecting to this final +// host fails, the connect loop will back off before invoking Next again for a fresh server. 
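A brief usage sketch of the provider described above, assuming only its exported Init/Next API; the hosts and ports are placeholders (localhost is used so the lookup resolves):

package main

import (
	"fmt"
	"log"

	"github.com/go-zookeeper/zk"
)

func main() {
	hp := &zk.StaticHostProvider{}
	if err := hp.Init([]string{"localhost:2181", "localhost:2182"}); err != nil {
		log.Fatalf("init host provider: %v", err)
	}
	// Each call to Next re-resolves the next host and returns one of its
	// addresses at random; retryStart becomes true once per full pass made
	// without an intervening call to Connected.
	for i := 0; i < 4; i++ {
		addr, retryStart := hp.Next()
		fmt.Println(addr, retryStart)
	}
}

If the package's WithHostProvider connection option is available in this version, the provider can be handed to Connect that way instead of being driven by hand.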
+func (shp *StaticHostProvider) Next() (server string, retryStart bool) { + shp.mu.Lock() + defer shp.mu.Unlock() + retryStart = shp.nextServer == shp.lastConnectedServer + + next := shp.servers[shp.nextServer] + addrs, err := shp.lookupHost(next.host) + if len(addrs) == 0 { + if err == nil { + // If for whatever reason lookupHosts returned an empty list of addresses but a nil error, use a + // default error + err = fmt.Errorf("zk: no hosts resolved by lookup for %q", next.host) + } + slog.Warn("Could not resolve ZK host", "host", next.host, "err", err) + server = next.String() + } else { + server = net.JoinHostPort(addrs[rand.Intn(len(addrs))], next.port) + } + + shp.nextServer = (shp.nextServer + 1) % len(shp.servers) + + return server, retryStart +} + +// Connected notifies the HostProvider of a successful connection. +func (shp *StaticHostProvider) Connected() { + shp.mu.Lock() + defer shp.mu.Unlock() + if shp.nextServer == 0 { + shp.lastConnectedServer = len(shp.servers) - 1 + } else { + shp.lastConnectedServer = shp.nextServer - 1 + } +} diff --git a/staticdnshostprovider_test.go b/staticdnshostprovider_test.go new file mode 100644 index 00000000..7cd2ae86 --- /dev/null +++ b/staticdnshostprovider_test.go @@ -0,0 +1,71 @@ +package zk + +import "testing" + +// The test in TestHostProvidersRetryStart checks that the semantics of StaticHostProvider's +// implementation of Next are correct, this test only checks that the provider correctly interacts +// with the resolver. +func TestStaticHostProvider(t *testing.T) { + const fooPort, barPort = "2121", "6464" + const fooHost, barHost = "foo.com", "bar.com" + hostToPort := map[string]string{ + fooHost: fooPort, + barHost: barPort, + } + hostToAddrs := map[string][]string{ + fooHost: {"0.0.0.1", "0.0.0.2", "0.0.0.3"}, + barHost: {"0.0.0.4", "0.0.0.5", "0.0.0.6"}, + } + addrToHost := map[string]string{} + for host, addrs := range hostToAddrs { + for _, addr := range addrs { + addrToHost[addr+":"+hostToPort[host]] = host + } + } + + hp := &StaticHostProvider{ + lookupHost: func(host string) ([]string, error) { + addrs, ok := hostToAddrs[host] + if !ok { + t.Fatalf("Unexpected argument to lookupHost %q", host) + } + return addrs, nil + }, + } + + err := hp.Init([]string{fooHost + ":" + fooPort, barHost + ":" + barPort}) + if err != nil { + t.Fatalf("Unexpected err from Init %v", err) + } + + addr1, retryStart := hp.Next() + if retryStart { + t.Fatalf("retryStart should be false") + } + addr2, retryStart := hp.Next() + if !retryStart { + t.Fatalf("retryStart should be true") + } + host1, host2 := addrToHost[addr1], addrToHost[addr2] + if host1 == host2 { + t.Fatalf("Next yielded addresses from same host (%q)", host1) + } + + // Final sanity check that it is shuffling the addresses + seenAddresses := map[string]map[string]bool{ + fooHost: {}, + barHost: {}, + } + for i := 0; i < 10_000; i++ { + addr, _ := hp.Next() + seenAddresses[addrToHost[addr]][addr] = true + } + + for host, addrs := range hostToAddrs { + for _, addr := range addrs { + if !seenAddresses[host][addr+":"+hostToPort[host]] { + t.Fatalf("expected addr %q for host %q not seen (seen: %v)", addr, host, seenAddresses) + } + } + } +} diff --git a/structs.go b/structs.go index 8eb41e39..523050a4 100644 --- a/structs.go +++ b/structs.go @@ -3,7 +3,6 @@ package zk import ( "encoding/binary" "errors" - "log" "reflect" "runtime" "strings" @@ -16,12 +15,6 @@ var ( ErrShortBuffer = errors.New("zk: buffer too small") ) -type defaultLogger struct{} - -func (defaultLogger) Printf(format string, 
a ...interface{}) { - log.Printf(format, a...) -} - type ACL struct { Perms int32 Scheme string @@ -135,7 +128,7 @@ type pathResponse struct { } type statResponse struct { - Stat Stat + Stat *Stat } // @@ -177,6 +170,10 @@ type CreateTTLRequest struct { } type createResponse pathResponse +type create2Response struct { + Path string + Stat *Stat +} type DeleteRequest PathVersionRequest type deleteResponse struct{} @@ -190,7 +187,7 @@ type getAclRequest pathRequest type getAclResponse struct { Acl []ACL - Stat Stat + Stat *Stat } type getChildrenRequest pathRequest @@ -199,18 +196,59 @@ type getChildrenResponse struct { Children []string } +type ReadOp interface { + GetPath() string + IsGetData() bool + IsGetChildren() bool + opCode() int32 +} + type getChildren2Request pathWatchRequest +type GetChildrenOp string + +func (g GetChildrenOp) IsGetData() bool { + return false +} + +func (g GetChildrenOp) IsGetChildren() bool { + return true +} + +func (g GetChildrenOp) GetPath() string { + return string(g) +} + +func (g GetChildrenOp) opCode() int32 { + return opGetChildren +} type getChildren2Response struct { Children []string - Stat Stat + Stat *Stat } type getDataRequest pathWatchRequest +type GetDataOp string + +func (g GetDataOp) IsGetData() bool { + return true +} + +func (g GetDataOp) IsGetChildren() bool { + return false +} + +func (g GetDataOp) GetPath() string { + return string(g) +} + +func (g GetDataOp) opCode() int32 { + return opGetData +} type getDataResponse struct { Data []byte - Stat Stat + Stat *Stat } type getMaxChildrenRequest pathRequest @@ -256,10 +294,12 @@ type setSaslResponse struct { } type setWatchesRequest struct { - RelativeZxid int64 - DataWatches []string - ExistWatches []string - ChildWatches []string + RelativeZxid int64 + DataWatches []string + ExistWatches []string + ChildWatches []string + PersistentWatches []string + PersistentRecursiveWatches []string } type setWatchesResponse struct{} @@ -275,8 +315,7 @@ type multiRequestOp struct { Op interface{} } type multiRequest struct { - Ops []multiRequestOp - DoneHeader multiHeader + Ops []multiRequestOp } type multiResponseOp struct { Header multiHeader @@ -285,8 +324,15 @@ type multiResponseOp struct { Err ErrCode } type multiResponse struct { - Ops []multiResponseOp - DoneHeader multiHeader + Ops []multiResponseOp +} +type MultiReadResponse struct { + getDataResponse + getChildrenResponse + Err error +} +type multiReadResponse struct { + OpResults []MultiReadResponse } // zk version 3.5 reconfig API @@ -301,6 +347,20 @@ type reconfigRequest struct { type reconfigReponse getDataResponse +type addWatchRequest struct { + Path string + Mode AddWatchMode +} + +type addWatchResponse struct{} + +type checkWatchesRequest struct { + Path string + Type WatcherType +} + +type checkWatchesResponse struct{} + func (r *multiRequest) Encode(buf []byte) (int, error) { total := 0 for _, op := range r.Ops { @@ -311,8 +371,7 @@ func (r *multiRequest) Encode(buf []byte) (int, error) { } total += n } - r.DoneHeader.Done = true - n, err := encodePacketValue(buf[total:], reflect.ValueOf(r.DoneHeader)) + n, err := encodePacketValue(buf[total:], reflect.ValueOf(multiHeader{Type: -1, Done: true, Err: -1})) if err != nil { return total, err } @@ -323,7 +382,6 @@ func (r *multiRequest) Encode(buf []byte) (int, error) { func (r *multiRequest) Decode(buf []byte) (int, error) { r.Ops = make([]multiRequestOp, 0) - r.DoneHeader = multiHeader{-1, true, -1} total := 0 for { header := &multiHeader{} @@ -333,7 +391,6 @@ func (r *multiRequest) 
Decode(buf []byte) (int, error) { } total += n if header.Done { - r.DoneHeader = *header break } @@ -355,7 +412,6 @@ func (r *multiResponse) Decode(buf []byte) (int, error) { var multiErr error r.Ops = make([]multiResponseOp, 0) - r.DoneHeader = multiHeader{-1, true, -1} total := 0 for { header := &multiHeader{} @@ -365,7 +421,6 @@ func (r *multiResponse) Decode(buf []byte) (int, error) { } total += n if header.Done { - r.DoneHeader = *header break } @@ -399,6 +454,48 @@ func (r *multiResponse) Decode(buf []byte) (int, error) { return total, multiErr } +func (r *multiReadResponse) Decode(buf []byte) (total int, multiErr error) { + for { + header := &multiHeader{} + n, err := decodePacketValue(buf[total:], reflect.ValueOf(header)) + if err != nil { + return total, err + } + total += n + if header.Done { + break + } + + var res MultiReadResponse + var errCode ErrCode + var w reflect.Value + switch header.Type { + case opGetData: + w = reflect.ValueOf(&res.getDataResponse) + case opGetChildren: + w = reflect.ValueOf(&res.getChildrenResponse) + case opError: + w = reflect.ValueOf(&errCode) + default: + return total, ErrAPIError + } + + n, err = decodePacketValue(buf[total:], w) + if err != nil { + return total, err + } + total += n + + if errCode != errOk { + res.Err = errCode.toError() + } + + r.OpResults = append(r.OpResults, res) + } + + return total, nil +} + type watcherEvent struct { Type EventType State State @@ -598,7 +695,7 @@ func requestStructForOp(op int32) interface{} { switch op { case opClose: return &closeRequest{} - case opCreate: + case opCreate, opCreate2: return &CreateRequest{} case opCreateContainer: return &CreateContainerRequest{} @@ -622,7 +719,7 @@ func requestStructForOp(op int32) interface{} { return &setAclRequest{} case opSetData: return &SetDataRequest{} - case opSetWatches: + case opSetWatches, opSetWatches2: return &setWatchesRequest{} case opSync: return &syncRequest{} @@ -634,6 +731,8 @@ func requestStructForOp(op int32) interface{} { return &multiRequest{} case opReconfig: return &reconfigRequest{} + case opAddWatch: + return &addWatchRequest{} } return nil } diff --git a/structs_test.go b/structs_test.go index 3a38ab45..9c4e9024 100644 --- a/structs_test.go +++ b/structs_test.go @@ -10,7 +10,7 @@ func TestEncodeDecodePacket(t *testing.T) { encodeDecodeTest(t, &requestHeader{-2, 5}) encodeDecodeTest(t, &connectResponse{1, 2, 3, nil}) encodeDecodeTest(t, &connectResponse{1, 2, 3, []byte{4, 5, 6}}) - encodeDecodeTest(t, &getAclResponse{[]ACL{{12, "s", "anyone"}}, Stat{}}) + encodeDecodeTest(t, &getAclResponse{[]ACL{{12, "s", "anyone"}}, &Stat{}}) encodeDecodeTest(t, &getChildrenResponse{[]string{"foo", "bar"}}) encodeDecodeTest(t, &pathWatchRequest{"path", true}) encodeDecodeTest(t, &pathWatchRequest{"path", false}) diff --git a/tcp_server_test.go b/tcp_server_test.go index 09254948..72bbd09c 100644 --- a/tcp_server_test.go +++ b/tcp_server_test.go @@ -1,17 +1,13 @@ package zk import ( - "fmt" - "math/rand" "net" "testing" "time" ) func WithListenServer(t *testing.T, test func(server string)) { - startPort := int(rand.Int31n(6000) + 10000) - server := fmt.Sprintf("localhost:%d", startPort) - l, err := net.Listen("tcp", server) + l, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to start listen server: %v", err) } @@ -26,7 +22,7 @@ func WithListenServer(t *testing.T, test func(server string)) { handleRequest(conn) }() - test(server) + test(l.Addr().String()) } // Handles incoming requests. 
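The tcp_server_test.go hunk above replaces a randomly picked port with the standard ephemeral-port idiom; a self-contained sketch of that pattern:

package main

import (
	"fmt"
	"log"
	"net"
)

func main() {
	// Port 0 asks the OS for any free port, avoiding the collisions that
	// rand-based port selection can hit when tests run in parallel.
	l, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		log.Fatalf("listen: %v", err)
	}
	defer l.Close()
	fmt.Println("listening on", l.Addr().String())
}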
diff --git a/unlimited_channel.go b/unlimited_channel.go new file mode 100644 index 00000000..fda556b5 --- /dev/null +++ b/unlimited_channel.go @@ -0,0 +1,121 @@ +package zk + +import ( + "context" + "errors" + "sync" +) + +var ErrEventQueueClosed = errors.New("zk: event queue closed") + +type Queue[T any] interface { + // Next waits for a new element to be received until the context expires or the queue is closed. + Next(ctx context.Context) (T, error) + // Push adds the given element to the queue and notifies any in-flight calls to Next that a new element is + // available. + Push(e T) + // Close functions like closing a channel. Subsequent calls to Next will drain whatever elements remain in the + // buffer while subsequent calls to Push will panic. Once remaining elements are drained, Next will return + // ErrEventQueueClosed. + Close() +} + +// EventQueue is added to preserve the old EventQueue type which had an equivalent interface, for backwards +// compatibility in method signatures. +type EventQueue = Queue[Event] + +type ChanQueue[T any] chan T + +func (c ChanQueue[T]) Next(ctx context.Context) (T, error) { + select { + case <-ctx.Done(): + var t T + return t, ctx.Err() + case e, ok := <-c: + if !ok { + var t T + return t, ErrEventQueueClosed + } else { + return e, nil + } + } +} + +func (c ChanQueue[T]) Push(e T) { + c <- e +} + +func (c ChanQueue[T]) Close() { + close(c) +} + +func newChanEventChannel() ChanQueue[Event] { + return make(chan Event, 1) +} + +type unlimitedEventQueue[T any] struct { + lock sync.Mutex + newElement chan struct{} + elements []T +} + +func NewUnlimitedQueue[T any]() Queue[T] { + return &unlimitedEventQueue[T]{ + newElement: make(chan struct{}), + } +} + +func (q *unlimitedEventQueue[T]) Push(e T) { + q.lock.Lock() + defer q.lock.Unlock() + + if q.newElement == nil { + // Panic like a closed channel + panic("send on closed unlimited channel") + } + + q.elements = append(q.elements, e) + close(q.newElement) + q.newElement = make(chan struct{}) +} + +func (q *unlimitedEventQueue[T]) Close() { + q.lock.Lock() + defer q.lock.Unlock() + + if q.newElement == nil { + // Panic like a closed channel + panic("close of closed Queue") + } + + close(q.newElement) + q.newElement = nil +} + +func (q *unlimitedEventQueue[T]) Next(ctx context.Context) (T, error) { + for { + q.lock.Lock() + if len(q.elements) > 0 { + e := q.elements[0] + q.elements = q.elements[1:] + q.lock.Unlock() + return e, nil + } + + ch := q.newElement + if ch == nil { + q.lock.Unlock() + var t T + return t, ErrEventQueueClosed + } + q.lock.Unlock() + + select { + case <-ctx.Done(): + var t T + return t, ctx.Err() + case <-ch: + continue + } + } +} diff --git a/unlimited_channel_test.go b/unlimited_channel_test.go new file mode 100644 index 00000000..4535518e --- /dev/null +++ b/unlimited_channel_test.go @@ -0,0 +1,117 @@ +//go:build go1.18 + +package zk + +import ( + "context" + "errors" + "fmt" + "reflect" + "sync" + "testing" + "time" +) + +func newEvent(i int) Event { + return Event{Path: fmt.Sprintf("/%d", i)} +} + +func TestUnlimitedChannel(t *testing.T) { + names := []string{"notClosedAfterPushes", "closeAfterPushes"} + for i, closeAfterPushes := range []bool{false, true} { + t.Run(names[i], func(t *testing.T) { + ch := NewUnlimitedQueue[Event]() + const eventCount = 10 + + // check that elements can be pushed without consumers + for i := 0; i < eventCount; i++ { + ch.Push(newEvent(i)) + } + if closeAfterPushes { + ch.Close() + } + + for events := 0; events < eventCount; events++ { + 
actual, err := ch.Next(context.Background()) + if err != nil { + t.Fatalf("Unexpected error returned from Next (events %d): %+v", events, err) + } + expected := newEvent(events) + if !reflect.DeepEqual(actual, expected) { + t.Fatalf("Did not receive expected event from queue: actual %+v expected %+v", actual, expected) + } + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + t.Cleanup(cancel) + + _, err := ch.Next(ctx) + if closeAfterPushes { + if err != ErrEventQueueClosed { + t.Fatalf("Did not receive expected error (%v) from Next: %v", ErrEventQueueClosed, err) + } + } else { + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Next did not exit with cancelled context: %+v", err) + } + } + }) + } + t.Run("interleaving", func(t *testing.T) { + ch := NewUnlimitedQueue[Event]() + + for i := 0; i < 10; i++ { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*10) + t.Cleanup(cancel) + + expected := newEvent(i) + + ctx = &customContext{ + Context: ctx, + f: func() { + ch.Push(expected) + }, + } + + actual, err := ch.Next(ctx) + if err != nil { + t.Fatalf("Received unexpected error from Next: %+v", err) + } + if !reflect.DeepEqual(expected, actual) { + t.Fatalf("Unexpected event received from Next (expected %+v, actual %+v", expected, actual) + } + } + }) + t.Run("multiple consumers", func(t *testing.T) { + ch := NewUnlimitedQueue[Event]() + for i := 0; i < 20; i++ { + ch.Push(newEvent(i)) + } + ch.Close() + var wg sync.WaitGroup + wg.Add(20) + for i := 0; i < 5; i++ { + go func() { + for { + _, err := ch.Next(context.Background()) + if errors.Is(err, ErrEventQueueClosed) { + return + } + requireNoErrorf(t, err) + wg.Done() + } + }() + } + wg.Wait() + }) +} + +type customContext struct { + context.Context + f func() +} + +func (c *customContext) Done() <-chan struct{} { + c.f() + return c.Context.Done() +} diff --git a/util.go b/util.go index 5a92b66b..9244a0bb 100644 --- a/util.go +++ b/util.go @@ -49,12 +49,11 @@ func FormatServers(servers []string) []string { return srvs } -// stringShuffle performs a Fisher-Yates shuffle on a slice of strings -func stringShuffle(s []string) { - for i := len(s) - 1; i > 0; i-- { - j := rand.Intn(i + 1) +// shuffleSlice invokes rand.Shuffle on the given slice. +func shuffleSlice[T any](s []T) { + rand.Shuffle(len(s), func(i, j int) { s[i], s[j] = s[j], s[i] - } + }) } // validatePath will make sure a path is valid before sending the request diff --git a/zk_test.go b/zk_test.go index ceaacda4..8756cdf2 100644 --- a/zk_test.go +++ b/zk_test.go @@ -1,8 +1,9 @@ package zk import ( + "bytes" "context" - "encoding/hex" + "errors" "fmt" "io" "io/ioutil" @@ -11,10 +12,8 @@ import ( "os" "path/filepath" "reflect" - "regexp" "sort" "strings" - "sync" "sync/atomic" "testing" "time" @@ -187,27 +186,22 @@ func TestIntegration_CreateContainer(t *testing.T) { } func TestIntegration_IncrementalReconfig(t *testing.T) { - if val, ok := os.LookupEnv("zk_version"); ok { - if !strings.HasPrefix(val, "3.5") { - t.Skip("running with zookeeper that does not support this api") - } - } else { - t.Skip("did not detect zk_version from env. skipping reconfig test") - } + RequireMinimumZkVersion(t, "3.5") + ts, err := StartTestCluster(t, 3, nil, logWriter{t: t, p: "[ZKERR] "}) - requireNoError(t, err, "failed to setup test cluster") + requireNoErrorf(t, err, "failed to setup test cluster") defer ts.Stop() // start and add a new server. 
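The unlimited_channel.go file above introduces a generic, unbounded Queue; a short producer/consumer sketch of its intended use (the element type and goroutine layout are illustrative):

package main

import (
	"context"
	"errors"
	"fmt"

	"github.com/go-zookeeper/zk"
)

func main() {
	q := zk.NewUnlimitedQueue[int]()

	// Push never blocks, whether or not a consumer is currently waiting.
	go func() {
		for i := 0; i < 3; i++ {
			q.Push(i)
		}
		q.Close()
	}()

	for {
		v, err := q.Next(context.Background())
		if errors.Is(err, zk.ErrEventQueueClosed) {
			return // queue drained and closed
		}
		if err != nil {
			panic(err)
		}
		fmt.Println(v)
	}
}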
tmpPath, err := ioutil.TempDir("", "gozk") - requireNoError(t, err, "failed to create tmp dir for test server setup") + requireNoErrorf(t, err, "failed to create tmp dir for test server setup") defer os.RemoveAll(tmpPath) startPort := int(rand.Int31n(6000) + 10000) srvPath := filepath.Join(tmpPath, fmt.Sprintf("srv4")) if err := os.Mkdir(srvPath, 0700); err != nil { - requireNoError(t, err, "failed to make server path") + requireNoErrorf(t, err, "failed to make server path") } testSrvConfig := ServerConfigServer{ ID: 4, @@ -224,35 +218,35 @@ func TestIntegration_IncrementalReconfig(t *testing.T) { // TODO: clean all this server creating up to a better helper method cfgPath := filepath.Join(srvPath, _testConfigName) fi, err := os.Create(cfgPath) - requireNoError(t, err) + requireNoErrorf(t, err) - requireNoError(t, cfg.Marshall(fi)) + requireNoErrorf(t, cfg.Marshall(fi)) fi.Close() fi, err = os.Create(filepath.Join(srvPath, _testMyIDFileName)) - requireNoError(t, err) + requireNoErrorf(t, err) _, err = fmt.Fprintln(fi, "4") fi.Close() - requireNoError(t, err) + requireNoErrorf(t, err) testServer, err := NewIntegrationTestServer(t, cfgPath, nil, nil) - requireNoError(t, err) - requireNoError(t, testServer.Start()) + requireNoErrorf(t, err) + requireNoErrorf(t, testServer.Start()) defer testServer.Stop() zk, events, err := ts.ConnectAll() - requireNoError(t, err, "failed to connect to cluster") + requireNoErrorf(t, err, "failed to connect to cluster") defer zk.Close() err = zk.AddAuth("digest", []byte("super:test")) - requireNoError(t, err, "failed to auth to cluster") + requireNoErrorf(t, err, "failed to auth to cluster") waitCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() err = waitForSession(waitCtx, events) - requireNoError(t, err, "failed to wail for session") + requireNoErrorf(t, err, "failed to wail for session") _, _, err = zk.Get("/zookeeper/config") if err != nil { @@ -265,48 +259,42 @@ func TestIntegration_IncrementalReconfig(t *testing.T) { // remove node 3. _, err = zk.IncrementalReconfig(nil, []string{"3"}, -1) - if err != nil && err == ErrConnectionClosed { + if err != nil && errors.Is(err, ErrConnectionClosed) { t.Log("conneciton closed is fine since the cluster re-elects and we dont reconnect") } else { - requireNoError(t, err, "failed to remove node from cluster") + requireNoErrorf(t, err, "failed to remove node from cluster") } // add node a new 4th node server := fmt.Sprintf("server.%d=%s:%d:%d;%d", testSrvConfig.ID, testSrvConfig.Host, testSrvConfig.PeerPort, testSrvConfig.LeaderElectionPort, cfg.ClientPort) _, err = zk.IncrementalReconfig([]string{server}, nil, -1) - if err != nil && err == ErrConnectionClosed { + if err != nil && errors.Is(err, ErrConnectionClosed) { t.Log("conneciton closed is fine since the cluster re-elects and we dont reconnect") } else { - requireNoError(t, err, "failed to add new server to cluster") + requireNoErrorf(t, err, "failed to add new server to cluster") } } func TestIntegration_Reconfig(t *testing.T) { - if val, ok := os.LookupEnv("zk_version"); ok { - if !strings.HasPrefix(val, "3.5") { - t.Skip("running with zookeeper that does not support this api") - } - } else { - t.Skip("did not detect zk_version from env. 
skipping reconfig test") - } + RequireMinimumZkVersion(t, "3.5") // This test enures we can do an non-incremental reconfig ts, err := StartTestCluster(t, 3, nil, logWriter{t: t, p: "[ZKERR] "}) - requireNoError(t, err, "failed to setup test cluster") + requireNoErrorf(t, err, "failed to setup test cluster") defer ts.Stop() zk, events, err := ts.ConnectAll() - requireNoError(t, err, "failed to connect to cluster") + requireNoErrorf(t, err, "failed to connect to cluster") defer zk.Close() err = zk.AddAuth("digest", []byte("super:test")) - requireNoError(t, err, "failed to auth to cluster") + requireNoErrorf(t, err, "failed to auth to cluster") waitCtx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() err = waitForSession(waitCtx, events) - requireNoError(t, err, "failed to wail for session") + requireNoErrorf(t, err, "failed to wail for session") _, _, err = zk.Get("/zookeeper/config") if err != nil { @@ -320,7 +308,7 @@ func TestIntegration_Reconfig(t *testing.T) { } _, err = zk.Reconfig(s, -1) - requireNoError(t, err, "failed to reconfig cluster") + requireNoErrorf(t, err, "failed to reconfig cluster") // reconfig to all the hosts again s = []string{} @@ -329,7 +317,7 @@ func TestIntegration_Reconfig(t *testing.T) { } _, err = zk.Reconfig(s, -1) - requireNoError(t, err, "failed to reconfig cluster") + requireNoErrorf(t, err, "failed to reconfig cluster") } func TestIntegration_OpsAfterCloseDontDeadlock(t *testing.T) { @@ -400,6 +388,136 @@ func TestIntegration_Multi(t *testing.T) { } } +func TestIntegration_MultiRead(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + nodeChildren := map[string][]string{} + nodeData := map[string][]byte{} + var ops []ReadOp + + create := func(path string, data []byte) { + if _, err := zk.Create(path, data, 0, nil); err != nil { + requireNoErrorf(t, err, "create returned an error") + } else { + dir, name := SplitPath(path) + nodeChildren[dir] = append(nodeChildren[dir], name) + nodeData[path] = data + ops = append(ops, GetDataOp(path), GetChildrenOp(path)) + } + } + + root := "/gozk-test" + create(root, nil) + + for i := byte(0); i < 10; i++ { + child := JoinPath(root, fmt.Sprint(i)) + create(child, []byte{i}) + } + + const foo = "foo" + create(JoinPath(JoinPath(root, "0"), foo), []byte(foo)) + + opResults, err := zk.MultiRead(ops...) + if err != nil { + t.Fatalf("MultiRead returned error: %+v", err) + } else if len(opResults) != len(ops) { + t.Fatalf("Expected %d responses got %d", len(ops), len(opResults)) + } + + nodeStats := map[string]*Stat{} + for k := range nodeData { + _, nodeStats[k], err = zk.Exists(k) + requireNoErrorf(t, err, "exists returned an error") + } + + for i, res := range opResults { + opPath := ops[i].GetPath() + switch op := ops[i].(type) { + case GetDataOp: + if res.Err != nil { + t.Fatalf("GetDataOp(%q) returned an error: %+v", op, res.Err) + } + if !bytes.Equal(res.Data, nodeData[opPath]) { + t.Fatalf("GetDataOp(%q).Data did not return %+v, got %+v", op, nodeData[opPath], res.Data) + } + if !reflect.DeepEqual(res.Stat, nodeStats[opPath]) { + t.Fatalf("GetDataOp(%q).Stat did not return %+v, got %+v", op, nodeStats[opPath], res.Stat) + } + case GetChildrenOp: + if res.Err != nil { + t.Fatalf("GetChildrenOp(%q) returned an error: %+v", opPath, res.Err) + } + // Cannot use DeepEqual here because it fails for []string{} == nil, even though in practice they are + // the same. 
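A one-line illustration of the reflect.DeepEqual caveat mentioned in the comment above, which is why the test compares the sorted children element by element:

package main

import (
	"fmt"
	"reflect"
)

func main() {
	var nilSlice []string
	// DeepEqual treats a nil slice and an empty, non-nil slice as different.
	fmt.Println(reflect.DeepEqual(nilSlice, []string{})) // false
}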
+ actual, expected := res.Children, nodeChildren[opPath] + if len(actual) != len(expected) { + t.Fatalf("GetChildrenOp(%q) did not return %+v, got %+v", opPath, expected, actual) + } + sort.Strings(actual) + sort.Strings(expected) + for i, c := range expected { + if actual[i] != c { + t.Fatalf("GetChildrenOp(%q) did not return %+v, got %+v", opPath, expected, actual) + } + } + } + } + + opResults, err = zk.MultiRead(GetDataOp("/invalid"), GetDataOp(root)) + requireNoErrorf(t, err, "MultiRead returned error") + + if opResults[0].Err != ErrNoNode { + t.Fatalf("MultiRead on invalid node did not return error") + } + if opResults[1].Err != nil { + t.Fatalf("MultiRead on valid node did not return error") + } + if !reflect.DeepEqual(opResults[1].Data, nodeData[root]) { + t.Fatalf("MultiRead on valid node did not return correct data") + } + }) +} + +func TestIntegration_GetDataAndChildren(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + + const path = "/test" + _, _, _, err := zk.GetDataAndChildren(path) + if err != ErrNoNode { + t.Fatalf("GetDataAndChildren(%q) did not return an error", path) + } + + create := func(path string, data []byte) { + if _, err := zk.Create(path, data, 0, nil); err != nil { + requireNoErrorf(t, err, "create returned an error") + } + } + expectedData := []byte{1, 2, 3, 4} + create(path, expectedData) + var expectedChildren []string + for i := 0; i < 10; i++ { + child := fmt.Sprint(i) + create(JoinPath(path, child), nil) + expectedChildren = append(expectedChildren, child) + } + + data, _, children, err := zk.GetDataAndChildren(path) + requireNoErrorf(t, err, "GetDataAndChildren return an error") + + if !bytes.Equal(data, expectedData) { + t.Fatalf("GetDataAndChildren(%q) did not return expected data (expected %v): %v", path, expectedData, data) + } + sort.Strings(children) + if !reflect.DeepEqual(children, expectedChildren) { + t.Fatalf("GetDataAndChildren(%q) did not return expected children (expected %v): %v", + path, expectedChildren, children) + } + }) +} + func TestIntegration_IfAuthdataSurvivesReconnect(t *testing.T) { // This test case ensures authentication data is being resubmited after // reconnect. @@ -452,6 +570,113 @@ func TestIntegration_IfAuthdataSurvivesReconnect(t *testing.T) { } } +func TestIntegration_PersistentWatchOnReconnect(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + + WithTestCluster(t, 10*time.Second, func(ts *TestCluster, zk *Conn) { + zk.reconnectLatch = make(chan struct{}) + + zk2, _, err := ts.ConnectAll() + if err != nil { + t.Fatalf("Connect returned error: %+v", err) + } + defer zk2.Close() + + const testNode = "/gozk-test" + + if err := zk.Delete(testNode, -1); err != nil && err != ErrNoNode { + t.Fatalf("Delete returned error: %+v", err) + } + + watchEventsQueue, err := zk.AddPersistentWatch(testNode, AddWatchModePersistent) + if err != nil { + t.Fatalf("AddPersistentWatch returned error: %+v", err) + } + + // wait for the initial EventWatching + waitForEvent(t, time.Second, watchEventsQueue, EventWatching) + + _, err = zk2.Create(testNode, []byte{1}, 0, WorldACL(PermAll)) + if err != nil { + t.Fatalf("Create returned an error: %+v", err) + } + + e := waitForEvent(t, time.Second, watchEventsQueue, EventNodeCreated) + if e.Path != testNode { + t.Fatalf("Event on persistent watch did not fire for expected node %q, got %q", testNode, e.Path) + } + + // Simulate network error by brutally closing the network connection. 
+ zk.conn.Close() + + _, err = zk2.Set(testNode, []byte{2}, -1) + if err != nil { + t.Fatalf("Set returned error: %+v", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + t.Cleanup(cancel) + // zk should still be waiting to reconnect, so none of the watches should have been triggered + e, err = watchEventsQueue.Next(ctx) + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Persistent watcher for %q should not have triggered yet (%+v)", testNode, e) + } + + // now we let the reconnect occur and make sure it resets watches + close(zk.reconnectLatch) + + // wait for reconnect event + waitForEvent(t, 5*time.Second, watchEventsQueue, EventWatching) + + _, err = zk2.Set(testNode, []byte{3}, -1) + if err != nil { + t.Fatalf("Set returned error: %+v", err) + } + + waitForEvent(t, 1*time.Second, watchEventsQueue, EventNodeDataChanged) + }) +} + +func waitForEvent(t *testing.T, timeout time.Duration, ch EventQueue, expectedType EventType) Event { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + t.Cleanup(cancel) + e, err := ch.Next(ctx) + requireNoErrorf(t, err) + + if e.Type != expectedType { + t.Fatalf("Did not receive event of type %s, got %s instead", expectedType, e.Type) + } + + return e +} + +func TestIntegration_PersistentWatchOnClose(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + + WithTestCluster(t, 10*time.Second, func(_ *TestCluster, zk *Conn) { + ch, err := zk.AddPersistentWatch("/", AddWatchModePersistent) + requireNoErrorf(t, err, "could not add persistent watch") + + waitForEvent(t, 2*time.Second, ch, EventWatching) + zk.Close() + waitForEvent(t, 2*time.Second, ch, EventNotWatching) + }) +} + +func TestIntegration_PersistentWatchGetsPinged(t *testing.T) { + RequireMinimumZkVersion(t, "3.6") + + WithTestCluster(t, 60*time.Second, func(_ *TestCluster, zk *Conn) { + ch, err := zk.AddPersistentWatch("/", AddWatchModePersistent) + if err != nil { + t.Fatalf("Could not add persistent watch: %+v", err) + } + + waitForEvent(t, time.Minute, ch, EventWatching) + waitForEvent(t, time.Minute, ch, EventPingReceived) + }) +} + func TestIntegration_MultiFailures(t *testing.T) { // This test case ensures that we return the errors associated with each // opeThis in the event a call to Multi() fails. 
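The tests above exercise the new AddPersistentWatch API end to end; a trimmed-down usage sketch, where the watched path and the per-iteration timeout are placeholders and error handling is abbreviated:

package main

import (
	"context"
	"errors"
	"log"
	"time"

	"github.com/go-zookeeper/zk"
)

func watchNode(conn *zk.Conn) error {
	// A persistent watch keeps delivering matching events on the returned
	// queue, unlike classic one-shot watches that must be re-armed.
	events, err := conn.AddPersistentWatch("/config", zk.AddWatchModePersistent)
	if err != nil {
		return err
	}
	for {
		ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
		e, err := events.Next(ctx)
		cancel()
		if errors.Is(err, context.DeadlineExceeded) {
			continue // nothing happened in this window; keep waiting
		}
		if err != nil {
			return err // e.g. zk.ErrEventQueueClosed after the connection closes
		}
		log.Printf("event %s on %s", e.Type, e.Path)
	}
}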
@@ -604,6 +829,7 @@ func TestIntegration_Auth(t *testing.T) { } } +// Tests that we correctly handle a response larger than the default buffer size func TestIntegration_Children(t *testing.T) { ts, err := StartTestCluster(t, 1, nil, logWriter{t: t, p: "[ZKERR] "}) if err != nil { @@ -622,38 +848,44 @@ func TestIntegration_Children(t *testing.T) { } } - deleteNode("/gozk-test-big") + testNode := "/gozk-test-big" + deleteNode(testNode) - if path, err := zk.Create("/gozk-test-big", []byte{1, 2, 3, 4}, 0, WorldACL(PermAll)); err != nil { + if _, err := zk.Create(testNode, nil, 0, WorldACL(PermAll)); err != nil { t.Fatalf("Create returned error: %+v", err) - } else if path != "/gozk-test-big" { - t.Fatalf("Create returned different path '%s' != '/gozk-test-big'", path) } - rb := make([]byte, 1000) - hb := make([]byte, 2000) - prefix := []byte("/gozk-test-big/") - for i := 0; i < 10000; i++ { - _, err := rand.Read(rb) - if err != nil { - t.Fatal("Cannot create random znode name") - } - hex.Encode(hb, rb) + const ( + nodesToCreate = 100 + // By creating many nodes with long names, the response from the Children call should be significantly longer + // than the buffer size, forcing recvLoop to allocate a bigger buffer + nameLength = 2 * bufferSize / nodesToCreate + ) - expect := string(append(prefix, hb...)) - if path, err := zk.Create(expect, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll)); err != nil { + format := fmt.Sprintf("%%0%dd", nameLength) + if name := fmt.Sprintf(format, 0); len(name) != nameLength { + // Sanity check that the generated format string creates strings of the right length + t.Fatalf("Length of generated name was not %d, got %d", nameLength, len(name)) + } + + var createdNodes []string + for i := 0; i < nodesToCreate; i++ { + name := fmt.Sprintf(format, i) + createdNodes = append(createdNodes, name) + path := testNode + "/" + name + if _, err := zk.Create(path, nil, 0, WorldACL(PermAll)); err != nil { t.Fatalf("Create returned error: %+v", err) - } else if path != expect { - t.Fatalf("Create returned different path '%s' != '%s'", path, expect) } - defer deleteNode(string(expect)) + defer deleteNode(path) } - children, _, err := zk.Children("/gozk-test-big") + children, _, err := zk.Children(testNode) if err != nil { t.Fatalf("Children returned error: %+v", err) - } else if len(children) != 10000 { - t.Fatal("Children returned wrong number of nodes") + } + sort.Strings(children) + if !reflect.DeepEqual(children, createdNodes) { + t.Fatal("Children did not return expected nodes") } } @@ -765,10 +997,16 @@ func TestIntegration_SetWatchers(t *testing.T) { } }() - // we create lots of paths to watch, to make sure a "set watches" request - // on re-create will be too big and be required to span multiple packets - for i := 0; i < 1000; i++ { - testPath, err := zk.Create(fmt.Sprintf("/gozk-test-%d", i), []byte{}, 0, WorldACL(PermAll)) + // we create lots of long paths to watch, to make sure a "set watches" request on will be too big and be broken + // into multiple packets. The size is chosen such that each packet can hold exactly 2 watches, meaning we should + // see half as many packets as there are watches. 
+ const ( + watches = 50 + watchedNodeNameFormat = "/gozk-test-%0450d" + ) + + for i := 0; i < watches; i++ { + testPath, err := zk.Create(fmt.Sprintf(watchedNodeNameFormat, i), []byte{}, 0, WorldACL(PermAll)) if err != nil { t.Fatalf("Create returned: %+v", err) } @@ -852,9 +1090,9 @@ func TestIntegration_SetWatchers(t *testing.T) { buf := make([]byte, bufferSize) totalWatches := 0 actualReqs := setWatchReqs.Load().([]*setWatchesRequest) - if len(actualReqs) < 12 { - // sanity check: we should have generated *at least* 12 requests to reset watches - t.Fatalf("too few setWatchesRequest messages: %d", len(actualReqs)) + if len(actualReqs) != watches/2 { + // sanity check: we should have generated exactly 25 requests to reset watches + t.Fatalf("Did not send exactly %d setWatches requests, got %d instead", watches/2, len(actualReqs)) } for _, r := range actualReqs { totalWatches += len(r.ChildWatches) + len(r.DataWatches) + len(r.ExistWatches) @@ -1014,15 +1252,13 @@ func TestIntegration_MaxBufferSize(t *testing.T) { defer ts.Stop() // no buffer size zk, _, err := ts.ConnectWithOptions(15 * time.Second) - var l testLogger + //var l testLogger if err != nil { t.Fatalf("Connect returned error: %+v", err) } defer zk.Close() // 1k buffer size, logs to custom test logger - zkLimited, _, err := ts.ConnectWithOptions(15*time.Second, WithMaxBufferSize(1024), func(conn *Conn) { - conn.SetLogger(&l) - }) + zkLimited, _, err := ts.ConnectWithOptions(15*time.Second, WithMaxBufferSize(1024)) if err != nil { t.Fatalf("Connect returned error: %+v", err) } @@ -1071,11 +1307,8 @@ func TestIntegration_MaxBufferSize(t *testing.T) { t.Fatalf("Create returned error: %+v", err) } _, _, err = zkLimited.Get("/bar") - // NB: Sadly, without actually de-serializing the too-large response packet, we can't send the - // right error to the corresponding outstanding request. So the request just sees ErrConnectionClosed - // while the log will see the actual reason the connection was closed. expectErr(t, err, ErrConnectionClosed) - expectLogMessage(t, &l, "received packet from server with length .*, which exceeds max buffer size 1024") + expectErr(t, err, ErrResponseBufferSizeExceeded) // Or with large number of children... 
totalLen := 0 @@ -1092,7 +1325,7 @@ func TestIntegration_MaxBufferSize(t *testing.T) { sort.Strings(children) _, _, err = zkLimited.Children("/bar") expectErr(t, err, ErrConnectionClosed) - expectLogMessage(t, &l, "received packet from server with length .*, which exceeds max buffer size 1024") + expectErr(t, err, ErrResponseBufferSizeExceeded) // Other client (without buffer size limit) can successfully query the node and its children, of course resultData, _, err = zk.Get("/bar") @@ -1174,47 +1407,7 @@ func expectErr(t *testing.T, err error, expected error) { if err == nil { t.Fatalf("Get for node that is too large should have returned error!") } - if err != expected { + if !errors.Is(err, expected) { t.Fatalf("Get returned wrong error; expecting ErrClosing, got %+v", err) } } - -func expectLogMessage(t *testing.T, logger *testLogger, pattern string) { - re := regexp.MustCompile(pattern) - events := logger.Reset() - if len(events) == 0 { - t.Fatalf("Failed to log error; expecting message that matches pattern: %s", pattern) - } - var found []string - for _, e := range events { - if re.Match([]byte(e)) { - found = append(found, e) - } - } - if len(found) == 0 { - t.Fatalf("Failed to log error; expecting message that matches pattern: %s", pattern) - } else if len(found) > 1 { - t.Fatalf("Logged error redundantly %d times:\n%+v", len(found), found) - } -} - -type testLogger struct { - mu sync.Mutex - events []string -} - -func (l *testLogger) Printf(msgFormat string, args ...interface{}) { - msg := fmt.Sprintf(msgFormat, args...) - fmt.Println(msg) - l.mu.Lock() - defer l.mu.Unlock() - l.events = append(l.events, msg) -} - -func (l *testLogger) Reset() []string { - l.mu.Lock() - defer l.mu.Unlock() - ret := l.events - l.events = nil - return ret -}
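With the logger-based assertions removed, an oversized server response now surfaces as a typed error. A hedged sketch of how a caller might detect it when a maximum buffer size is configured; the server address, node path, and 1024-byte cap are placeholders:

package main

import (
	"errors"
	"log"
	"time"

	"github.com/go-zookeeper/zk"
)

func main() {
	conn, _, err := zk.Connect([]string{"127.0.0.1:2181"}, 15*time.Second, zk.WithMaxBufferSize(1024))
	if err != nil {
		log.Fatalf("connect: %v", err)
	}
	defer conn.Close()

	_, _, err = conn.Get("/some/large/node")
	switch {
	case errors.Is(err, zk.ErrResponseBufferSizeExceeded):
		// The reply was larger than the configured cap.
		log.Print("response exceeded max buffer size")
	case err != nil:
		log.Printf("get failed: %v", err)
	}
}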