diff --git a/cluster_test.go b/cluster_test.go index f54883e5..20ec1898 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -1,6 +1,7 @@ package zk import ( + "log/slog" "sync" "testing" "time" @@ -112,10 +113,10 @@ func TestNoQuorum(t *testing.T) { t.Fatalf("Failed to connect and get session") } initialSessionID := zk.sessionID - DefaultLogger.Printf(" Session established: id=%d, timeout=%d", zk.sessionID, zk.sessionTimeoutMs) + t.Logf("Session established: id=%d, timeout=%s", zk.sessionID, zk.sessionTimeout) // Kill the ZooKeeper leader and wait for the session to reconnect. - DefaultLogger.Printf(" Kill the leader") + t.Logf("Kill the leader") disconnectWatcher1 := sl.NewWatcher(sessionStateMatcher(StateDisconnected)) hasSessionWatcher2 := sl.NewWatcher(sessionStateMatcher(StateHasSession)) tc.StopServer(hasSessionEvent1.Server) @@ -135,7 +136,7 @@ func TestNoQuorum(t *testing.T) { } // Kill the ZooKeeper leader leaving the cluster without quorum. - DefaultLogger.Printf(" Kill the leader") + t.Logf("Kill the leader") disconnectWatcher2 := sl.NewWatcher(sessionStateMatcher(StateDisconnected)) tc.StopServer(hasSessionEvent2.Server) @@ -151,7 +152,7 @@ func TestNoQuorum(t *testing.T) { // Make sure that we keep retrying connecting to the only remaining // ZooKeeper server, but the attempts are being dropped because there is // no quorum. 
- DefaultLogger.Printf(" Retrying no luck...") + t.Logf("Retrying no luck...") var firstDisconnect *Event begin := time.Now() for time.Now().Sub(begin) < 6*time.Second { @@ -269,7 +270,7 @@ func NewStateLogger(eventCh <-chan Event) *EventLogger { sw.matchCh <- event } } - DefaultLogger.Printf(" event received: %v\n", event) + slog.Info("event received", "event", event) el.events = append(el.events, event) el.lock.Unlock() } diff --git a/conn.go b/conn.go index 9d880e36..e0c7df3d 100644 --- a/conn.go +++ b/conn.go @@ -16,6 +16,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "strings" "sync" @@ -31,9 +32,6 @@ var ErrNoServer = errors.New("zk: could not connect to a server") // an invalid path. (e.g. empty path). var ErrInvalidPath = errors.New("zk: invalid path") -// DefaultLogger uses the stdlib log package for logging. -var DefaultLogger Logger = defaultLogger{} - const ( bufferSize = 1536 * 1024 eventChanSize = 6 @@ -60,14 +58,6 @@ type watchPathType struct { wType watchType } -// Dialer is a function to be used to establish a connection to a single host. -type Dialer func(network, address string, timeout time.Duration) (net.Conn, error) - -// Logger is an interface that can be implemented to provide custom log output. -type Logger interface { - Printf(string, ...interface{}) -} - type authCreds struct { scheme string auth []byte @@ -75,12 +65,12 @@ type authCreds struct { // Conn is the client connection and tracks all details for communication with the server. 
type Conn struct { - lastZxid int64 - sessionID int64 - state State // must be 32-bit aligned - xid uint32 - sessionTimeoutMs int32 // session timeout in milliseconds - passwd []byte + lastZxid int64 + sessionID int64 + state State // must be 32-bit aligned + xid uint32 + sessionTimeout time.Duration + passwd []byte dialer Dialer hostProvider HostProvider @@ -116,14 +106,19 @@ type Conn struct { debugCloseRecvLoop bool resendZkAuthFn func(context.Context, *Conn) error - logger Logger - logInfo bool // true if information messages are logged; false if only errors are logged - buf []byte } -// connOption represents a connection option. -type connOption func(c *Conn) +// ConnOption represents a connection option. +type ConnOption interface { + apply(*Conn) +} + +type connOption func(*Conn) + +func (c connOption) apply(conn *Conn) { + c(conn) +} type request struct { xid int32 @@ -181,20 +176,13 @@ type HostProvider interface { Connected() } -// ConnectWithDialer establishes a new connection to a pool of zookeeper servers -// using a custom Dialer. See Connect for further information about session timeout. -// This method is deprecated and provided for compatibility: use the WithDialer option instead. -func ConnectWithDialer(servers []string, sessionTimeout time.Duration, dialer Dialer) (*Conn, <-chan Event, error) { - return Connect(servers, sessionTimeout, WithDialer(dialer)) -} - // Connect establishes a new connection to a pool of zookeeper // servers. The provided session timeout sets the amount of time for which // a session is considered valid after losing connection to a server. Within // the session timeout it's possible to reestablish a connection to a different // server and keep the same session. This is means any ephemeral nodes and // watches are maintained. 
-func Connect(servers []string, sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) { +func Connect(servers []string, sessionTimeout time.Duration, options ...ConnOption) (*Conn, <-chan Event, error) { if len(servers) == 0 { return nil, nil, errors.New("zk: server list must not be empty") } @@ -206,7 +194,7 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti ec := make(chan Event, eventChanSize) conn := &Conn{ - dialer: net.DialTimeout, + dialer: new(net.Dialer), hostProvider: new(StaticHostProvider), conn: nil, state: StateDisconnected, @@ -217,8 +205,6 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti requests: make(map[int32]*request), watchers: make(map[watchPathType][]EventQueue), passwd: emptyPassword, - logger: DefaultLogger, - logInfo: true, // default is true for backwards compatability buf: make([]byte, bufferSize), resendZkAuthFn: resendZkAuth, metricReceiver: UnimplementedMetricReceiver{}, @@ -226,14 +212,14 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti // Set provided options. for _, option := range options { - option(conn) + option.apply(conn) } if err := conn.hostProvider.Init(srvs); err != nil { return nil, nil, err } - conn.setTimeouts(int32(sessionTimeout / time.Millisecond)) + conn.setTimeouts(sessionTimeout) // TODO: This context should be passed in by the caller to be the connection lifecycle context. ctx := context.Background() @@ -246,33 +232,25 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti return conn, ec, nil } -// WithDialer returns a connection option specifying a non-default Dialer. -func WithDialer(dialer Dialer) connOption { - return func(c *Conn) { +// Dialer is an interface implemented by the standard [net.Dialer] but also by [crypto/tls.Dialer]. +type Dialer interface { + // DialContext will be invoked when connecting to a ZK host. 
+ DialContext(ctx context.Context, network, address string) (net.Conn, error) +} + +// WithDialer returns a connection option specifying a non-default Dialer. This can be used, for +// example, to enable TLS when connecting to the ZK server by passing in a [crypto/tls.Dialer]. +func WithDialer(dialer Dialer) ConnOption { + return connOption(func(c *Conn) { c.dialer = dialer - } + }) } // WithHostProvider returns a connection option specifying a non-default HostProvider. -func WithHostProvider(hostProvider HostProvider) connOption { - return func(c *Conn) { +func WithHostProvider(hostProvider HostProvider) ConnOption { + return connOption(func(c *Conn) { c.hostProvider = hostProvider - } -} - -// WithLogger returns a connection option specifying a non-default Logger. -func WithLogger(logger Logger) connOption { - return func(c *Conn) { - c.logger = logger - } -} - -// WithLogInfo returns a connection option specifying whether or not information messages -// should be logged. -func WithLogInfo(logInfo bool) connOption { - return func(c *Conn) { - c.logInfo = logInfo - } + }) } // EventCallback is a function that is called when an Event occurs. @@ -281,10 +259,10 @@ type EventCallback func(Event) // WithEventCallback returns a connection option that specifies an event // callback. // The callback must not block - doing so would delay the ZK go routines. -func WithEventCallback(cb EventCallback) connOption { - return func(c *Conn) { +func WithEventCallback(cb EventCallback) ConnOption { + return connOption(func(c *Conn) { c.eventCallback = cb - } + }) } // WithMaxBufferSize sets the maximum buffer size used to read and decode @@ -311,26 +289,26 @@ func WithEventCallback(cb EventCallback) connOption { // the child names without an increased buffer size in the client, but they work // by inspecting the servers' transaction logs to enumerate children instead of // sending an online request to a server. 
-func WithMaxBufferSize(maxBufferSize int) connOption { - return func(c *Conn) { +func WithMaxBufferSize(maxBufferSize int) ConnOption { + return connOption(func(c *Conn) { c.maxBufferSize = maxBufferSize - } + }) } // WithMaxConnBufferSize sets maximum buffer size used to send and encode // packets to Zookeeper server. The standard Zookeeper client for java defaults // to a limit of 1mb. This option should be used for non-standard server setup // where znode is bigger than default 1mb. -func WithMaxConnBufferSize(maxBufferSize int) connOption { - return func(c *Conn) { +func WithMaxConnBufferSize(maxBufferSize int) ConnOption { + return connOption(func(c *Conn) { c.buf = make([]byte, maxBufferSize) - } + }) } -func WithMetricReceiver(mr MetricReceiver) connOption { - return func(c *Conn) { +func WithMetricReceiver(mr MetricReceiver) ConnOption { + return connOption(func(c *Conn) { c.metricReceiver = mr - } + }) } // Close will submit a close request with ZK and signal the connection to stop @@ -356,15 +334,8 @@ func (c *Conn) SessionID() int64 { return atomic.LoadInt64(&c.sessionID) } -// SetLogger sets the logger to be used for printing errors. -// Logger is an interface provided by this package. 
-func (c *Conn) SetLogger(l Logger) { - c.logger = l -} - -func (c *Conn) setTimeouts(sessionTimeoutMs int32) { - c.sessionTimeoutMs = sessionTimeoutMs - sessionTimeout := time.Duration(sessionTimeoutMs) * time.Millisecond +func (c *Conn) setTimeouts(sessionTimeout time.Duration) { + c.sessionTimeout = sessionTimeout c.recvTimeout = sessionTimeout * 2 / 3 c.pingInterval = c.recvTimeout / 2 } @@ -395,17 +366,22 @@ func (c *Conn) connect() (err error) { c.setState(StateConnecting) - zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout) + dial := func() (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), c.connectTimeout) + defer cancel() + return c.dialer.DialContext(ctx, "tcp", c.Server()) + } + + slog.Info("Dialing ZK server", "server", c.Server()) + zkConn, err := dial() if err == nil { c.conn = zkConn c.setState(StateConnected) - if c.logInfo { - c.logger.Printf("connected to %s", c.Server()) - } + slog.Info("Connection established", "server", c.Server(), "addr", zkConn.RemoteAddr()) return nil } - c.logger.Printf("failed to connect to %s: %v", c.Server(), err) + slog.Warn("Failed to connect to ZK server", "server", c.Server(), "err", err) if retryStart { c.flushUnsentRequests(ErrNoServer) @@ -453,18 +429,18 @@ func (c *Conn) loop(ctx context.Context) { return } - err := c.authenticate() + var sendLoopErr, recvLoopErr error + + authErr := c.authenticate() switch { - case err == ErrSessionExpired: - c.logger.Printf("authentication failed: %s", err) - c.invalidateWatches(err) - case err != nil && c.conn != nil: - c.logger.Printf("authentication failed: %s", err) + case errors.Is(authErr, ErrSessionExpired): + slog.Warn("authentication failed", "err", authErr) + c.invalidateWatches(authErr) + case authErr != nil && c.conn != nil: + slog.Warn("authentication failed", "err", authErr) c.conn.Close() - case err == nil: - if c.logInfo { - c.logger.Printf("authenticated: id=%d, timeout=%d", c.SessionID(), c.sessionTimeoutMs) - } + case 
authErr == nil: + slog.Info("authenticated", "sessionId", c.SessionID(), "timeout", c.sessionTimeout) c.hostProvider.Connected() // mark success c.closeChan = make(chan struct{}) // channel to tell send loop stop @@ -475,13 +451,15 @@ func (c *Conn) loop(ctx context.Context) { defer c.conn.Close() // causes recv loop to EOF/exit defer wg.Done() - if err := c.resendZkAuthFn(ctx, c); err != nil { - c.logger.Printf("error in resending auth creds: %v", err) + if sendLoopErr = c.resendZkAuthFn(ctx, c); sendLoopErr != nil { + slog.Warn("error in resending auth creds", "err", sendLoopErr) return } - if err := c.sendLoop(); err != nil || c.logInfo { - c.logger.Printf("send loop terminated: %v", err) + if sendLoopErr = c.sendLoop(); sendLoopErr != nil { + slog.Warn("Send loop terminated with error", "err", sendLoopErr) + } else { + slog.Info("Send loop terminated") } }() @@ -490,16 +468,22 @@ func (c *Conn) loop(ctx context.Context) { defer close(c.closeChan) // tell send loop to exit defer wg.Done() - var err error if c.debugCloseRecvLoop { - err = errors.New("DEBUG: close recv loop") + recvLoopErr = errors.New("DEBUG: close recv loop") } else { - err = c.recvLoop(c.conn) + recvLoopErr = c.recvLoop(c.conn) + } + + switch { + case errors.Is(recvLoopErr, io.EOF): + slog.Info("recv loop terminated") + } - if err != io.EOF || c.logInfo { - c.logger.Printf("recv loop terminated: %v", err) + if !errors.Is(recvLoopErr, io.EOF) { + slog.Warn("recv loop terminated with error", "err", recvLoopErr) + } else { } - if err == nil { + if recvLoopErr == nil { panic("zk: recvLoop should never return nil error") } }() @@ -517,8 +501,14 @@ func (c *Conn) loop(ctx context.Context) { default: } - if err != ErrSessionExpired { - err = ErrConnectionClosed + // Surface an error that contains all errors that could have caused the connection to get closed + err := errors.Join(authErr, sendLoopErr, recvLoopErr) + if !errors.Is(err, ErrSessionExpired) && !errors.Is(err, ErrConnectionClosed) { + // Always 
default to ErrConnectionClosed for any error that doesn't already have it or + // ErrSessionExpired as a cause, makes error handling more straightforward. Note that by definition, + // reaching this point in the code means the connection was closed, hence using it as the default + // value. + err = errors.Join(ErrConnectionClosed, err) } c.flushRequests(err) @@ -613,7 +603,7 @@ func (c *Conn) invalidateWatches(err error) { if len(c.watchers) > 0 { for pathType, watchers := range c.watchers { - if err == ErrSessionExpired && pathType.wType.isPersistent() { + if errors.Is(err, ErrSessionExpired) && pathType.wType.isPersistent() { // Ignore ErrSessionExpired for persistent watchers as the client will either automatically reconnect, // or this is a shutdown-worthy error in which case there will be a followup invocation of this method // with ErrClosing @@ -639,6 +629,8 @@ func (c *Conn) sendSetWatches() { return } + slog.Info("Resetting watches after reconnect", "watchCount", len(c.watchers)) + // NB: A ZK server, by default, rejects packets >1mb. So, if we have too // many watches to reset, we need to break this up into multiple packets // to avoid hitting that limit. 
Mirroring the Java client behavior: we are @@ -733,7 +725,7 @@ func (c *Conn) sendSetWatches() { } }) if err != nil { - c.logger.Printf("Failed to set previous watches: %v", err) + slog.Warn("Failed to set previous watches", "err", err) break } } @@ -747,9 +739,10 @@ func (c *Conn) authenticate() error { n, err := encodePacket(buf[4:], &connectRequest{ ProtocolVersion: protocolVersion, LastZxidSeen: c.lastZxid, - TimeOut: c.sessionTimeoutMs, - SessionID: c.SessionID(), - Passwd: c.passwd, + // The timeout in the connect request is milliseconds + TimeOut: int32(c.sessionTimeout / time.Millisecond), + SessionID: c.SessionID(), + Passwd: c.passwd, }) if err != nil { return err @@ -796,7 +789,7 @@ func (c *Conn) authenticate() error { } atomic.StoreInt64(&c.sessionID, r.SessionID) - c.setTimeouts(r.TimeOut) + c.setTimeouts(time.Duration(r.TimeOut) * time.Millisecond) c.passwd = r.Passwd c.setState(StateHasSession) @@ -885,17 +878,20 @@ func (c *Conn) recvLoop(conn net.Conn) error { for { // package length if err := conn.SetReadDeadline(time.Now().Add(c.recvTimeout)); err != nil { - c.logger.Printf("failed to set connection deadline: %v", err) + slog.Warn("failed to set connection deadline", "err", err) } _, err := io.ReadFull(conn, buf[:4]) if err != nil { - return fmt.Errorf("failed to read from connection: %v", err) + return fmt.Errorf("failed to read from connection: %w", err) } blen := int(binary.BigEndian.Uint32(buf[:4])) if cap(buf) < blen { if c.maxBufferSize > 0 && blen > c.maxBufferSize { - return fmt.Errorf("received packet from server with length %d, which exceeds max buffer size %d", blen, c.maxBufferSize) + return fmt.Errorf( + "%w (packet size %d, max buffer size: %d)", + ErrResponseBufferSizeExceeded, blen, c.maxBufferSize, + ) } buf = make([]byte, blen) } @@ -937,7 +933,7 @@ func (c *Conn) recvLoop(conn net.Conn) error { Timestamp: time.Now(), }) } else if res.Xid < 0 { - c.logger.Printf("Xid < 0 (%d) but not ping or watcher event", res.Xid) + 
slog.Warn("Xid < 0 but not ping or watcher event", "Xid", res.Xid) } else { if res.Zxid > 0 { c.lastZxid = res.Zxid @@ -951,7 +947,7 @@ func (c *Conn) recvLoop(conn net.Conn) error { c.requestsLock.Unlock() if !ok { - c.logger.Printf("Response for unknown request with xid %d", res.Xid) + slog.Warn("Ignoring response for unknown request", "xid", res.Xid) } else { if res.Err != 0 { err = res.Err.toError() @@ -1001,7 +997,7 @@ func (c *Conn) queueRequest(opcode int32, req interface{}, res interface{}, recv select { case c.sendChan <- rq: case <-time.After(c.connectTimeout * 2): - c.logger.Printf("gave up trying to send opClose to server") + slog.Warn("gave up trying to send opClose to server") rq.recvChan <- response{-1, ErrConnectionClosed} } default: @@ -1078,7 +1074,7 @@ func (c *Conn) Children(path string) ([]string, *Stat, error) { res := &getChildren2Response{} _, err := c.request(opGetChildren2, &getChildren2Request{Path: path, Watch: false}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return nil, nil, err } return res.Children, res.Stat, err @@ -1112,7 +1108,7 @@ func (c *Conn) Get(path string) ([]byte, *Stat, error) { res := &getDataResponse{} _, err := c.request(opGetData, &getDataRequest{Path: path, Watch: false}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return nil, nil, err } return res.Data, res.Stat, err @@ -1146,7 +1142,7 @@ func (c *Conn) Set(path string, data []byte, version int32) (*Stat, error) { res := &setDataResponse{} _, err := c.request(opSetData, &SetDataRequest{path, data, version}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return nil, err } return res.Stat, err @@ -1167,7 +1163,7 @@ func (c *Conn) Create(path string, data []byte, flags int32, acl []ACL) (string, res := &createResponse{} _, err := c.request(opCreate, &CreateRequest{path, data, acl, flags}, res, nil) - if err == ErrConnectionClosed { + if 
errors.Is(err, ErrConnectionClosed) { return "", err } return res.Path, err @@ -1185,7 +1181,7 @@ func (c *Conn) CreateAndReturnStat(path string, data []byte, flags int32, acl [] res := &create2Response{} _, err := c.request(opCreate2, &CreateRequest{path, data, acl, flags}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return "", nil, err } return res.Path, res.Stat, err @@ -1243,10 +1239,10 @@ func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl var newPath string for i := 0; i < 3; i++ { newPath, err = c.Create(protectedPath, data, FlagEphemeral|FlagSequence, acl) - switch err { - case ErrSessionExpired: + switch { + case errors.Is(err, ErrSessionExpired): // No need to search for the node since it can't exist. Just try again. - case ErrConnectionClosed: + case errors.Is(err, ErrConnectionClosed): children, _, err := c.Children(rootPath) if err != nil { return "", err @@ -1259,7 +1255,7 @@ func (c *Conn) CreateProtectedEphemeralSequential(path string, data []byte, acl } } } - case nil: + case err == nil: return newPath, nil default: return "", err @@ -1286,11 +1282,11 @@ func (c *Conn) Exists(path string) (bool, *Stat, error) { res := &existsResponse{} _, err := c.request(opExists, &existsRequest{Path: path, Watch: false}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return false, nil, err } exists := true - if err == ErrNoNode { + if errors.Is(err, ErrNoNode) { exists = false err = nil } @@ -1309,12 +1305,12 @@ func (c *Conn) ExistsW(path string) (bool, *Stat, <-chan Event, error) { ech = newChanEventChannel() if err == nil { c.addWatcher(path, watchTypeData, ech) - } else if err == ErrNoNode { + } else if errors.Is(err, ErrNoNode) { c.addWatcher(path, watchTypeExist, ech) } }) exists := true - if err == ErrNoNode { + if errors.Is(err, ErrNoNode) { exists = false err = nil } @@ -1332,7 +1328,7 @@ func (c *Conn) GetACL(path string) ([]ACL, *Stat, error) 
{ res := &getAclResponse{} _, err := c.request(opGetAcl, &getAclRequest{Path: path}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return nil, nil, err } return res.Acl, res.Stat, err @@ -1346,7 +1342,7 @@ func (c *Conn) SetACL(path string, acl []ACL, version int32) (*Stat, error) { res := &setAclResponse{} _, err := c.request(opSetAcl, &setAclRequest{Path: path, Acl: acl, Version: version}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return nil, err } return res.Stat, err @@ -1362,7 +1358,7 @@ func (c *Conn) Sync(path string) (string, error) { res := &syncResponse{} _, err := c.request(opSync, &syncRequest{Path: path}, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return "", err } return res.Path, err @@ -1400,7 +1396,7 @@ func (c *Conn) Multi(ops ...interface{}) ([]MultiResponse, error) { } res := &multiResponse{} _, err := c.request(opMulti, req, res, nil) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return nil, err } mr := make([]MultiResponse, len(res.Ops)) @@ -1512,7 +1508,7 @@ func (c *Conn) AddPersistentWatch(path string, mode AddWatchMode) (ch EventQueue c.addWatcher(path, wt, ch) } }) - if err == ErrConnectionClosed { + if errors.Is(err, ErrConnectionClosed) { return nil, err } return ch, err @@ -1599,9 +1595,7 @@ func resendZkAuth(ctx context.Context, c *Conn) error { c.credsMu.Lock() defer c.credsMu.Unlock() - if c.logInfo { - c.logger.Printf("re-submitting `%d` credentials after reconnect", len(c.creds)) - } + slog.Info("re-submitting credentials after reconnect", "credentialCount", len(c.creds)) for _, cred := range c.creds { // return early before attempting to send request. 
@@ -1620,23 +1614,23 @@ func resendZkAuth(ctx context.Context, c *Conn) error { nil, /* recvFunc*/ ) if err != nil { - return fmt.Errorf("failed to send auth request: %v", err) + return fmt.Errorf("failed to send auth request: %w", err) } var res response select { case res = <-resChan: case <-c.closeChan: - c.logger.Printf("recv closed, cancel re-submitting credentials") + slog.Info("recv closed, cancel re-submitting credentials") return nil case <-c.shouldQuit: - c.logger.Printf("should quit, cancel re-submitting credentials") + slog.Warn("Connection closing, cancel re-submitting credentials") return nil case <-ctx.Done(): return ctx.Err() } if res.err != nil { - return fmt.Errorf("failed connection setAuth request: %v", res.err) + return fmt.Errorf("failed connection setAuth request: %w", res.err) } } @@ -1662,3 +1656,9 @@ func SplitPath(path string) (dir, name string) { } return dir, name } + +// IsConnClosing checks if the given error returned by a Conn function is due to Close having been +// called on the Conn, which is fatal in that no subsequent Conn operations will succeed. +func IsConnClosing(err error) bool { + return errors.Is(err, ErrClosing) +} diff --git a/conn_test.go b/conn_test.go index 6e2a2d95..6b96b9d6 100644 --- a/conn_test.go +++ b/conn_test.go @@ -93,7 +93,6 @@ func TestDeadlockInClose(t *testing.T) { shouldQuit: make(chan struct{}), connectTimeout: 1 * time.Second, sendChan: make(chan *request, sendChanSize), - logger: DefaultLogger, } for i := 0; i < sendChanSize; i++ { diff --git a/constants.go b/constants.go index 7162afbb..d2e44e04 100644 --- a/constants.go +++ b/constants.go @@ -121,26 +121,27 @@ type ErrCode int32 var ( // ErrConnectionClosed means the connection has been closed. 
- ErrConnectionClosed = errors.New("zk: connection closed") - ErrUnknown = errors.New("zk: unknown error") - ErrAPIError = errors.New("zk: api error") - ErrNoNode = errors.New("zk: node does not exist") - ErrNoAuth = errors.New("zk: not authenticated") - ErrBadVersion = errors.New("zk: version conflict") - ErrNoChildrenForEphemerals = errors.New("zk: ephemeral nodes may not have children") - ErrNodeExists = errors.New("zk: node already exists") - ErrNotEmpty = errors.New("zk: node has children") - ErrSessionExpired = errors.New("zk: session has been expired by the server") - ErrInvalidACL = errors.New("zk: invalid ACL specified") - ErrInvalidFlags = errors.New("zk: invalid flags specified") - ErrAuthFailed = errors.New("zk: client authentication failed") - ErrClosing = errors.New("zk: zookeeper is closing") - ErrNothing = errors.New("zk: no server responses to process") - ErrSessionMoved = errors.New("zk: session moved to another server, so operation is ignored") - ErrReconfigDisabled = errors.New("attempts to perform a reconfiguration operation when reconfiguration feature is disabled") - ErrBadArguments = errors.New("invalid arguments") - ErrNoWatcher = errors.New("zk: no such watcher") - ErrUnimplemented = errors.New("zk: Not implemented") + ErrConnectionClosed = errors.New("zk: connection closed") + ErrUnknown = errors.New("zk: unknown error") + ErrAPIError = errors.New("zk: api error") + ErrNoNode = errors.New("zk: node does not exist") + ErrNoAuth = errors.New("zk: not authenticated") + ErrBadVersion = errors.New("zk: version conflict") + ErrNoChildrenForEphemerals = errors.New("zk: ephemeral nodes may not have children") + ErrNodeExists = errors.New("zk: node already exists") + ErrNotEmpty = errors.New("zk: node has children") + ErrSessionExpired = errors.New("zk: session has been expired by the server") + ErrInvalidACL = errors.New("zk: invalid ACL specified") + ErrInvalidFlags = errors.New("zk: invalid flags specified") + ErrAuthFailed = errors.New("zk: 
client authentication failed") + ErrClosing = errors.New("zk: zookeeper is closing") + ErrNothing = errors.New("zk: no server responses to process") + ErrSessionMoved = errors.New("zk: session moved to another server, so operation is ignored") + ErrReconfigDisabled = errors.New("attempts to perform a reconfiguration operation when reconfiguration feature is disabled") + ErrBadArguments = errors.New("invalid arguments") + ErrNoWatcher = errors.New("zk: no such watcher") + ErrUnimplemented = errors.New("zk: Not implemented") + ErrResponseBufferSizeExceeded = errors.New("zk: server response exceeds max buffer size") // ErrInvalidCallback = errors.New("zk: invalid callback specified") errCodeToError = map[ErrCode]error{ diff --git a/dnshostprovider_test.go b/dnshostprovider_test.go index 00bdea80..61a113a5 100644 --- a/dnshostprovider_test.go +++ b/dnshostprovider_test.go @@ -1,6 +1,7 @@ package zk import ( + "errors" "fmt" "log" "testing" @@ -33,7 +34,7 @@ func TestDNSHostProviderCreate(t *testing.T) { path := "/gozk-test" - if err := zk.Delete(path, -1); err != nil && err != ErrNoNode { + if err := zk.Delete(path, -1); err != nil && !errors.Is(err, ErrNoNode) { t.Fatalf("Delete returned error: %+v", err) } if p, err := zk.Create(path, []byte{1, 2, 3, 4}, 0, WorldACL(PermAll)); err != nil { diff --git a/server_help_test.go b/server_help_test.go index 4ea1e7d0..b2e4adea 100644 --- a/server_help_test.go +++ b/server_help_test.go @@ -171,7 +171,7 @@ func (tc *TestCluster) ConnectAllTimeout(sessionTimeout time.Duration) (*Conn, < return tc.ConnectWithOptions(sessionTimeout) } -func (tc *TestCluster) ConnectWithOptions(sessionTimeout time.Duration, options ...connOption) (*Conn, <-chan Event, error) { +func (tc *TestCluster) ConnectWithOptions(sessionTimeout time.Duration, options ...ConnOption) (*Conn, <-chan Event, error) { hosts := make([]string, len(tc.Servers)) for i, srv := range tc.Servers { hosts[i] = fmt.Sprintf("127.0.0.1:%d", srv.Port) diff --git a/structs.go 
b/structs.go index 66c10463..523050a4 100644 --- a/structs.go +++ b/structs.go @@ -3,8 +3,6 @@ package zk import ( "encoding/binary" "errors" - "fmt" - "log" "reflect" "runtime" "strings" @@ -17,12 +15,6 @@ var ( ErrShortBuffer = errors.New("zk: buffer too small") ) -type defaultLogger struct{} - -func (defaultLogger) Printf(format string, v ...interface{}) { - log.Output(3, fmt.Sprintf(format, v...)) -} - type ACL struct { Perms int32 Scheme string diff --git a/zk_test.go b/zk_test.go index c5918f32..75c69ecb 100644 --- a/zk_test.go +++ b/zk_test.go @@ -12,10 +12,8 @@ import ( "os" "path/filepath" "reflect" - "regexp" "sort" "strings" - "sync" "sync/atomic" "testing" "time" @@ -1249,15 +1247,13 @@ func TestMaxBufferSize(t *testing.T) { defer ts.Stop() // no buffer size zk, _, err := ts.ConnectWithOptions(15 * time.Second) - var l testLogger + //var l testLogger if err != nil { t.Fatalf("Connect returned error: %+v", err) } defer zk.Close() // 1k buffer size, logs to custom test logger - zkLimited, _, err := ts.ConnectWithOptions(15*time.Second, WithMaxBufferSize(1024), func(conn *Conn) { - conn.SetLogger(&l) - }) + zkLimited, _, err := ts.ConnectWithOptions(15*time.Second, WithMaxBufferSize(1024)) if err != nil { t.Fatalf("Connect returned error: %+v", err) } @@ -1306,11 +1302,8 @@ func TestMaxBufferSize(t *testing.T) { t.Fatalf("Create returned error: %+v", err) } _, _, err = zkLimited.Get("/bar") - // NB: Sadly, without actually de-serializing the too-large response packet, we can't send the - // right error to the corresponding outstanding request. So the request just sees ErrConnectionClosed - // while the log will see the actual reason the connection was closed. expectErr(t, err, ErrConnectionClosed) - expectLogMessage(t, &l, "received packet from server with length .*, which exceeds max buffer size 1024") + expectErr(t, err, ErrResponseBufferSizeExceeded) // Or with large number of children... 
totalLen := 0 @@ -1327,7 +1320,7 @@ func TestMaxBufferSize(t *testing.T) { sort.Strings(children) _, _, err = zkLimited.Children("/bar") expectErr(t, err, ErrConnectionClosed) - expectLogMessage(t, &l, "received packet from server with length .*, which exceeds max buffer size 1024") + expectErr(t, err, ErrResponseBufferSizeExceeded) // Other client (without buffer size limit) can successfully query the node and its children, of course resultData, _, err = zk.Get("/bar") @@ -1409,47 +1402,7 @@ func expectErr(t *testing.T, err error, expected error) { if err == nil { t.Fatalf("Get for node that is too large should have returned error!") } - if err != expected { + if !errors.Is(err, expected) { t.Fatalf("Get returned wrong error; expecting ErrClosing, got %+v", err) } } - -func expectLogMessage(t *testing.T, logger *testLogger, pattern string) { - re := regexp.MustCompile(pattern) - events := logger.Reset() - if len(events) == 0 { - t.Fatalf("Failed to log error; expecting message that matches pattern: %s", pattern) - } - var found []string - for _, e := range events { - if re.Match([]byte(e)) { - found = append(found, e) - } - } - if len(found) == 0 { - t.Fatalf("Failed to log error; expecting message that matches pattern: %s", pattern) - } else if len(found) > 1 { - t.Fatalf("Logged error redundantly %d times:\n%+v", len(found), found) - } -} - -type testLogger struct { - mu sync.Mutex - events []string -} - -func (l *testLogger) Printf(msgFormat string, args ...interface{}) { - msg := fmt.Sprintf(msgFormat, args...) - fmt.Println(msg) - l.mu.Lock() - defer l.mu.Unlock() - l.events = append(l.events, msg) -} - -func (l *testLogger) Reset() []string { - l.mu.Lock() - defer l.mu.Unlock() - ret := l.events - l.events = nil - return ret -}