From 4dc080847d90ad889a31036f7fd804aa3827b43e Mon Sep 17 00:00:00 2001 From: Paul Chesnais Date: Fri, 2 Feb 2024 18:40:39 -0500 Subject: [PATCH] Upgrade HostProvider (#6) * Upgrade HostProvider This change fixes the behavior of DNSHostProvider where it does not refresh its cached IP addresses that it resolves once on startup for the configured ZK servers. This new behavior more closely matches the Java client's behavior by randomly selecting an address after resolving the host. It slightly changes the semantics of `HostProvider` with an off-by-one, otherwise the `connect` loop could end up in a situation where it attempts to connect to a stale address. This is fixed by moving the backoff to _before_ getting the address, rather than _after_. * Bump linter version to support generics * Fix linter and integration test actions * Add docs --- conn.go | 39 ++++++----- dnshostprovider.go | 47 ++++++------- dnshostprovider_test.go | 125 +++++++++++++++++++--------------- staticdnshostprovider.go | 115 +++++++++++++++++++++++++++++++ staticdnshostprovider_test.go | 71 +++++++++++++++++++ tcp_server_test.go | 8 +-- util.go | 9 ++- 7 files changed, 304 insertions(+), 110 deletions(-) create mode 100644 staticdnshostprovider.go create mode 100644 staticdnshostprovider_test.go diff --git a/conn.go b/conn.go index b3e52d6d..9d880e36 100644 --- a/conn.go +++ b/conn.go @@ -173,12 +173,11 @@ type Event struct { type HostProvider interface { // Init is called first, with the servers specified in the connection string. Init(servers []string) error - // Len returns the number of servers. - Len() int - // Next returns the next server to connect to. retryStart will be true if we've looped through - // all known servers without Connected() being called. + // Next returns the next server to connect to. retryStart should be true if this call to Next + // exhausted the list of known servers without Connected being called. If connecting to this final + // host fails, the connect loop will back off before invoking Next again for a fresh server. Next() (server string, retryStart bool) - // Notify the HostProvider of a successful connection. + // Connected notifies the HostProvider of a successful connection. Connected() } @@ -203,12 +202,12 @@ func Connect(servers []string, sessionTimeout time.Duration, options ...connOpti srvs := FormatServers(servers) // Randomize the order of the servers to avoid creating hotspots - stringShuffle(srvs) + shuffleSlice(srvs) ec := make(chan Event, eventChanSize) conn := &Conn{ dialer: net.DialTimeout, - hostProvider: &DNSHostProvider{}, + hostProvider: new(StaticHostProvider), conn: nil, state: StateDisconnected, eventChan: ec, @@ -387,7 +386,7 @@ func (c *Conn) sendEvent(evt Event) { } } -func (c *Conn) connect() error { +func (c *Conn) connect() (err error) { var retryStart bool for { c.serverMu.Lock() @@ -396,18 +395,6 @@ func (c *Conn) connect() error { c.setState(StateConnecting) - if retryStart { - c.flushUnsentRequests(ErrNoServer) - select { - case <-time.After(time.Second): - // pass - case <-c.shouldQuit: - c.setState(StateDisconnected) - c.flushUnsentRequests(ErrClosing) - return ErrClosing - } - } - zkConn, err := c.dialer("tcp", c.Server(), c.connectTimeout) if err == nil { c.conn = zkConn @@ -419,6 +406,18 @@ func (c *Conn) connect() error { } c.logger.Printf("failed to connect to %s: %v", c.Server(), err) + + if retryStart { + c.flushUnsentRequests(ErrNoServer) + select { + case <-time.After(time.Second): + // pass + case <-c.shouldQuit: + c.setState(StateDisconnected) + c.flushUnsentRequests(ErrClosing) + return ErrClosing + } + } } } diff --git a/dnshostprovider.go b/dnshostprovider.go index f4bba8d0..3dd74c87 100644 --- a/dnshostprovider.go +++ b/dnshostprovider.go @@ -6,10 +6,12 @@ import ( "sync" ) -// DNSHostProvider is the default HostProvider. It currently matches -// the Java StaticHostProvider, resolving hosts from DNS once during -// the call to Init. It could be easily extended to re-query DNS -// periodically or if there is trouble connecting. +// DNSHostProvider is a simple implementation of a HostProvider. It resolves the hosts once during +// Init, and iterates through the resolved addresses for every call to Next. Note that if the +// addresses that back the ZK hosts change, those changes will not be reflected. +// +// Deprecated: Because this HostProvider does not attempt to re-read from DNS, it can lead to issues +// if the addresses of the hosts change. It is preserved for backwards compatibility. type DNSHostProvider struct { mu sync.Mutex // Protects everything, so we can add asynchronous updates later. servers []string @@ -30,7 +32,7 @@ func (hp *DNSHostProvider) Init(servers []string) error { lookupHost = net.LookupHost } - found := []string{} + var found []string for _, server := range servers { host, port, err := net.SplitHostPort(server) if err != nil { @@ -46,43 +48,38 @@ func (hp *DNSHostProvider) Init(servers []string) error { } if len(found) == 0 { - return fmt.Errorf("No hosts found for addresses %q", servers) + return fmt.Errorf("zk: no hosts found for addresses %q", servers) } // Randomize the order of the servers to avoid creating hotspots - stringShuffle(found) + shuffleSlice(found) hp.servers = found - hp.curr = -1 - hp.last = -1 + hp.curr = 0 + hp.last = len(hp.servers) - 1 return nil } -// Len returns the number of servers available -func (hp *DNSHostProvider) Len() int { - hp.mu.Lock() - defer hp.mu.Unlock() - return len(hp.servers) -} - -// Next returns the next server to connect to. retryStart will be true -// if we've looped through all known servers without Connected() being -// called. +// Next returns the next server to connect to. retryStart should be true if this call to Next +// exhausted the list of known servers without Connected being called. If connecting to this final +// host fails, the connect loop will back off before invoking Next again for a fresh server. func (hp *DNSHostProvider) Next() (server string, retryStart bool) { hp.mu.Lock() defer hp.mu.Unlock() - hp.curr = (hp.curr + 1) % len(hp.servers) retryStart = hp.curr == hp.last - if hp.last == -1 { - hp.last = 0 - } - return hp.servers[hp.curr], retryStart + server = hp.servers[hp.curr] + hp.curr = (hp.curr + 1) % len(hp.servers) + return server, retryStart } // Connected notifies the HostProvider of a successful connection. func (hp *DNSHostProvider) Connected() { hp.mu.Lock() defer hp.mu.Unlock() - hp.last = hp.curr + if hp.curr == 0 { + hp.last = len(hp.servers) - 1 + } else { + hp.last = hp.curr - 1 + } } diff --git a/dnshostprovider_test.go b/dnshostprovider_test.go index 48000a5f..00bdea80 100644 --- a/dnshostprovider_test.go +++ b/dnshostprovider_test.go @@ -68,7 +68,6 @@ func newLocalHostPortsFacade(inner HostProvider, ports []int) *localHostPortsFac } } -func (lhpf *localHostPortsFacade) Len() int { return lhpf.inner.Len() } func (lhpf *localHostPortsFacade) Connected() { lhpf.inner.Connected() } func (lhpf *localHostPortsFacade) Init(servers []string) error { return lhpf.inner.Init(servers) } func (lhpf *localHostPortsFacade) Next() (string, bool) { @@ -165,60 +164,78 @@ func TestDNSHostProviderReconnect(t *testing.T) { } } -// TestDNSHostProviderRetryStart tests the `retryStart` functionality -// of DNSHostProvider. -// It's also probably the clearest visual explanation of exactly how -// it works. -func TestDNSHostProviderRetryStart(t *testing.T) { +// TestHostProvidersRetryStart tests the `retryStart` functionality of DNSHostProvider and +// StaticHostProvider. +// It's also probably the clearest visual explanation of exactly how it works. +func TestHostProvidersRetryStart(t *testing.T) { t.Parallel() - hp := &DNSHostProvider{lookupHost: func(host string) ([]string, error) { - return []string{"192.0.2.1", "192.0.2.2", "192.0.2.3"}, nil - }} - - if err := hp.Init([]string{"foo.example.com:12345"}); err != nil { - t.Fatal(err) - } - - testdata := []struct { - retryStartWant bool - callConnected bool - }{ - // Repeated failures. - {false, false}, - {false, false}, - {false, false}, - {true, false}, - {false, false}, - {false, false}, - {true, true}, - - // One success offsets things. - {false, false}, - {false, true}, - {false, true}, - - // Repeated successes. - {false, true}, - {false, true}, - {false, true}, - {false, true}, - {false, true}, - - // And some more failures. - {false, false}, - {false, false}, - {true, false}, // Looped back to last known good server: all alternates failed. - {false, false}, - } - - for i, td := range testdata { - _, retryStartGot := hp.Next() - if retryStartGot != td.retryStartWant { - t.Errorf("%d: retryStart=%v; want %v", i, retryStartGot, td.retryStartWant) - } - if td.callConnected { - hp.Connected() - } + lookupHost := func(host string) ([]string, error) { + return []string{host}, nil + } + + providers := []HostProvider{ + &DNSHostProvider{ + lookupHost: lookupHost, + }, + &StaticHostProvider{ + lookupHost: lookupHost, + }, + } + + for _, hp := range providers { + t.Run(fmt.Sprintf("%T", hp), func(t *testing.T) { + if err := hp.Init([]string{"foo.com:2121", "bar.com:2121", "baz.com:2121"}); err != nil { + t.Fatal(err) + } + + testdata := []struct { + retryStartWant bool + callConnected bool + }{ + // Repeated failures. + {false, false}, + {false, false}, + {true, false}, + {false, false}, + {false, false}, + {true, false}, + {false, true}, + + // One success offsets things. + {false, false}, + {false, true}, + {false, true}, + + // Repeated successes. + {false, true}, + {false, true}, + {false, true}, + {false, true}, + {false, true}, + + // And some more failures. + {false, false}, + {false, false}, + {true, false}, // Looped back to last known good server: all alternates failed. + {false, false}, + {false, false}, + {true, false}, + {false, false}, + {false, false}, + {true, false}, + {false, false}, + } + + for i, td := range testdata { + _, retryStartGot := hp.Next() + if retryStartGot != td.retryStartWant { + t.Errorf("%d: retryStart=%v; want %v", i, retryStartGot, td.retryStartWant) + } + if td.callConnected { + hp.Connected() + } + } + }) } } diff --git a/staticdnshostprovider.go b/staticdnshostprovider.go new file mode 100644 index 00000000..cb298ce3 --- /dev/null +++ b/staticdnshostprovider.go @@ -0,0 +1,115 @@ +package zk + +import ( + "fmt" + "log/slog" + "math/rand" + "net" + "sync" +) + +type hostPort struct { + host, port string +} + +func (hp *hostPort) String() string { + return hp.host + ":" + hp.port +} + +// StaticHostProvider is the default HostProvider, and replaces the now deprecated DNSHostProvider. +// It will iterate through the ZK hosts on every call to Next, and return a random address selected +// from the resolved addresses of the ZK host (if the host is already an IP, it will return that +// directly). It is important to manually resolve and shuffle the addresses because the DNS record +// that backs a host may rarely (or never) change, so repeated calls to connect to this host may +// always connect to the same IP. This mode is the default mode, and matches the Java client's +// implementation. Note that if the host cannot be resolved, Next will return it directly, instead of +// an error. This will cause Dial to fail and the loop will move on to a new host. It is implemented +// as a pound-for-pound copy of the standard Java client's equivalent: +// https://github.com/linkedin/zookeeper/blob/629518b5ea2b26d88a9ec53d5a422afe9b12e452/zookeeper-server/src/main/java/org/apache/zookeeper/client/StaticHostProvider.java#L368 +type StaticHostProvider struct { + mu sync.Mutex // Protects everything, so we can add asynchronous updates later. + servers []hostPort + // nextServer is the index (in servers) of the next server that will be returned by Next. + nextServer int + // lastConnectedServer is the index (in servers) of the last server to which a successful connection + // was established. Used to track whether Next iterated through all available servers without + // successfully connecting. + lastConnectedServer int + lookupHost func(string) ([]string, error) // Override of net.LookupHost, for testing. +} + +func (shp *StaticHostProvider) Init(servers []string) error { + shp.mu.Lock() + defer shp.mu.Unlock() + + if shp.lookupHost == nil { + shp.lookupHost = net.LookupHost + } + + var found []hostPort + for _, server := range servers { + host, port, err := net.SplitHostPort(server) + if err != nil { + return err + } + // Perform the lookup to validate the initial set of hosts, but discard the results as the addresses + // will be resolved dynamically when Next is called. + _, err = shp.lookupHost(host) + if err != nil { + return err + } + + found = append(found, hostPort{host, port}) + } + + if len(found) == 0 { + return fmt.Errorf("zk: no hosts found for addresses %q", servers) + } + + // Randomize the order of the servers to avoid creating hotspots + shuffleSlice(found) + + shp.servers = found + shp.nextServer = 0 + shp.lastConnectedServer = len(shp.servers) - 1 + + return nil +} + +// Next returns the next server to connect to. retryStart should be true if this call to Next +// exhausted the list of known servers without Connected being called. If connecting to this final +// host fails, the connect loop will back off before invoking Next again for a fresh server. +func (shp *StaticHostProvider) Next() (server string, retryStart bool) { + shp.mu.Lock() + defer shp.mu.Unlock() + retryStart = shp.nextServer == shp.lastConnectedServer + + next := shp.servers[shp.nextServer] + addrs, err := shp.lookupHost(next.host) + if len(addrs) == 0 { + if err == nil { + // If for whatever reason lookupHosts returned an empty list of addresses but a nil error, use a + // default error + err = fmt.Errorf("zk: no hosts resolved by lookup for %q", next.host) + } + slog.Warn("Could not resolve ZK host", "host", next.host, "err", err) + server = next.String() + } else { + server = addrs[rand.Intn(len(addrs))] + ":" + next.port + } + + shp.nextServer = (shp.nextServer + 1) % len(shp.servers) + + return server, retryStart +} + +// Connected notifies the HostProvider of a successful connection. +func (shp *StaticHostProvider) Connected() { + shp.mu.Lock() + defer shp.mu.Unlock() + if shp.nextServer == 0 { + shp.lastConnectedServer = len(shp.servers) - 1 + } else { + shp.lastConnectedServer = shp.nextServer - 1 + } +} diff --git a/staticdnshostprovider_test.go b/staticdnshostprovider_test.go new file mode 100644 index 00000000..7cd2ae86 --- /dev/null +++ b/staticdnshostprovider_test.go @@ -0,0 +1,71 @@ +package zk + +import "testing" + +// The test in TestHostProvidersRetryStart checks that the semantics of StaticHostProvider's +// implementation of Next are correct, this test only checks that the provider correctly interacts +// with the resolver. +func TestStaticHostProvider(t *testing.T) { + const fooPort, barPort = "2121", "6464" + const fooHost, barHost = "foo.com", "bar.com" + hostToPort := map[string]string{ + fooHost: fooPort, + barHost: barPort, + } + hostToAddrs := map[string][]string{ + fooHost: {"0.0.0.1", "0.0.0.2", "0.0.0.3"}, + barHost: {"0.0.0.4", "0.0.0.5", "0.0.0.6"}, + } + addrToHost := map[string]string{} + for host, addrs := range hostToAddrs { + for _, addr := range addrs { + addrToHost[addr+":"+hostToPort[host]] = host + } + } + + hp := &StaticHostProvider{ + lookupHost: func(host string) ([]string, error) { + addrs, ok := hostToAddrs[host] + if !ok { + t.Fatalf("Unexpected argument to lookupHost %q", host) + } + return addrs, nil + }, + } + + err := hp.Init([]string{fooHost + ":" + fooPort, barHost + ":" + barPort}) + if err != nil { + t.Fatalf("Unexpected err from Init %v", err) + } + + addr1, retryStart := hp.Next() + if retryStart { + t.Fatalf("retryStart should be false") + } + addr2, retryStart := hp.Next() + if !retryStart { + t.Fatalf("retryStart should be true") + } + host1, host2 := addrToHost[addr1], addrToHost[addr2] + if host1 == host2 { + t.Fatalf("Next yielded addresses from same host (%q)", host1) + } + + // Final sanity check that it is shuffling the addresses + seenAddresses := map[string]map[string]bool{ + fooHost: {}, + barHost: {}, + } + for i := 0; i < 10_000; i++ { + addr, _ := hp.Next() + seenAddresses[addrToHost[addr]][addr] = true + } + + for host, addrs := range hostToAddrs { + for _, addr := range addrs { + if !seenAddresses[host][addr+":"+hostToPort[host]] { + t.Fatalf("expected addr %q for host %q not seen (seen: %v)", addr, host, seenAddresses) + } + } + } +} diff --git a/tcp_server_test.go b/tcp_server_test.go index 09254948..72bbd09c 100644 --- a/tcp_server_test.go +++ b/tcp_server_test.go @@ -1,17 +1,13 @@ package zk import ( - "fmt" - "math/rand" "net" "testing" "time" ) func WithListenServer(t *testing.T, test func(server string)) { - startPort := int(rand.Int31n(6000) + 10000) - server := fmt.Sprintf("localhost:%d", startPort) - l, err := net.Listen("tcp", server) + l, err := net.Listen("tcp", "localhost:0") if err != nil { t.Fatalf("Failed to start listen server: %v", err) } @@ -26,7 +22,7 @@ func WithListenServer(t *testing.T, test func(server string)) { handleRequest(conn) }() - test(server) + test(l.Addr().String()) } // Handles incoming requests. diff --git a/util.go b/util.go index 5a92b66b..9244a0bb 100644 --- a/util.go +++ b/util.go @@ -49,12 +49,11 @@ func FormatServers(servers []string) []string { return srvs } -// stringShuffle performs a Fisher-Yates shuffle on a slice of strings -func stringShuffle(s []string) { - for i := len(s) - 1; i > 0; i-- { - j := rand.Intn(i + 1) +// shuffleSlice invokes rand.Shuffle on the given slice. +func shuffleSlice[T any](s []T) { + rand.Shuffle(len(s), func(i, j int) { s[i], s[j] = s[j], s[i] - } + }) } // validatePath will make sure a path is valid before sending the request