From eccc74aafec4e5711bdcec4773aab02365a37e19 Mon Sep 17 00:00:00 2001 From: Peter Bourgon Date: Tue, 22 Sep 2015 16:02:29 +0200 Subject: [PATCH] Make the resolver concurrent --- probe/resolver.go | 140 +++++++++++++++++++++++------------------ probe/resolver_test.go | 76 ++++++++++++++++------ 2 files changed, 136 insertions(+), 80 deletions(-) diff --git a/probe/resolver.go b/probe/resolver.go index 6feae5d419..1d99e41c0b 100644 --- a/probe/resolver.go +++ b/probe/resolver.go @@ -15,95 +15,115 @@ var ( lookupIP = net.LookupIP ) +const maxConcurrentLookup = 10 + type staticResolver struct { - quit chan struct{} - set func(string, []string) - peers []peer + set func(string, []string) + targets []target + sema semaphore + quit chan struct{} } -type peer struct { - hostname string - port string -} +type target struct{ host, port string } + +func (t target) String() string { return net.JoinHostPort(t.host, t.port) } -// NewResolver starts a new resolver that periodically -// tries to resolve peers and the calls add() with all the -// resolved IPs. It explictiy supports hostnames which -// resolve to multiple IPs; it will repeatedly call -// add with the same IP, expecting the target to dedupe. -func newStaticResolver(peers []string, set func(target string, endpoints []string)) staticResolver { +// newStaticResolver periodically resolves the targets, and calls the set +// function with all the resolved IPs. It explictiy supports targets which +// resolve to multiple IPs. +func newStaticResolver(targets []string, set func(target string, endpoints []string)) staticResolver { r := staticResolver{ - quit: make(chan struct{}), - set: set, - peers: prepareNames(peers), + targets: prepare(targets), + set: set, + sema: newSemaphore(maxConcurrentLookup), + quit: make(chan struct{}), } go r.loop() return r } -func prepareNames(strs []string) []peer { - var results []peer - for _, s := range strs { - var ( - hostname string - port string - ) - - if strings.Contains(s, ":") { - var err error - hostname, port, err = net.SplitHostPort(s) - if err != nil { - log.Printf("invalid address %s: %v", s, err) - continue - } - } else { - hostname, port = s, strconv.Itoa(xfer.AppPort) - } - - results = append(results, peer{hostname, port}) - } - return results -} - func (r staticResolver) loop() { - r.resolveHosts() + r.resolve() t := tick(time.Minute) for { select { case <-t: - r.resolveHosts() - + r.resolve() case <-r.quit: return } } } -func (r staticResolver) resolveHosts() { - for _, peer := range r.peers { - var addrs []net.IP - if addr := net.ParseIP(peer.hostname); addr != nil { - addrs = []net.IP{addr} - } else { +func (r staticResolver) Stop() { + close(r.quit) +} + +func prepare(strs []string) []target { + var targets []target + for _, s := range strs { + var host, port string + if strings.Contains(s, ":") { var err error - addrs, err = lookupIP(peer.hostname) + host, port, err = net.SplitHostPort(s) if err != nil { + log.Printf("invalid address %s: %v", s, err) continue } + } else { + host, port = s, strconv.Itoa(xfer.AppPort) } + targets = append(targets, target{host, port}) + } + return targets +} - endpoints := make([]string, 0, len(addrs)) - for _, addr := range addrs { - // For now, ignore IPv6 - if addr.To4() == nil { - continue - } - endpoints = append(endpoints, net.JoinHostPort(addr.String(), peer.port)) +func (r staticResolver) resolve() { + for t, endpoints := range resolveMany(r.sema, r.targets) { + r.set(t.String(), endpoints) + } +} + +func resolveMany(s semaphore, targets []target) map[target][]string { + result := map[target][]string{} + for _, t := range targets { + c := make(chan []string) + go func(t target) { s.p(); defer s.v(); c <- resolveOne(t) }(t) + result[t] = <-c + } + return result +} + +func resolveOne(t target) []string { + var addrs []net.IP + if addr := net.ParseIP(t.host); addr != nil { + addrs = []net.IP{addr} + } else { + var err error + addrs, err = lookupIP(t.host) + if err != nil { + return []string{} + } + } + endpoints := make([]string, 0, len(addrs)) + for _, addr := range addrs { + // For now, ignore IPv6 + if addr.To4() == nil { + continue } - r.set(peer.hostname, endpoints) + endpoints = append(endpoints, net.JoinHostPort(addr.String(), t.port)) } + return endpoints } -func (r staticResolver) Stop() { - close(r.quit) +type semaphore chan struct{} + +func newSemaphore(n int) semaphore { + c := make(chan struct{}, n) + for i := 0; i < n; i++ { + c <- struct{}{} + } + return semaphore(c) } +func (s semaphore) p() { <-s } +func (s semaphore) v() { s <- struct{}{} } diff --git a/probe/resolver_test.go b/probe/resolver_test.go index 63826380c8..88e5807aae 100644 --- a/probe/resolver_test.go +++ b/probe/resolver_test.go @@ -48,42 +48,44 @@ func TestResolver(t *testing.T) { r := newStaticResolver([]string{"symbolic.name" + port, "namewithnoport", ip1 + port, ip2}, set) - assertAdd := func(want string) { + assertAdd := func(want ...string) { + remaining := map[string]struct{}{} + for _, s := range want { + remaining[s] = struct{}{} + } _, _, line, _ := runtime.Caller(1) - select { - case have := <-sets: - if want != have { - t.Errorf("line %d: want %q, have %q", line, want, have) + for len(remaining) > 0 { + select { + case s := <-sets: + if _, ok := remaining[s]; ok { + t.Logf("line %d: got %q OK", line, s) + delete(remaining, s) + } else { + t.Errorf("line %d: got unexpected %q", line, s) + } + case <-time.After(100 * time.Millisecond): + t.Fatalf("line %d: didn't get the adds in time", line) } - case <-time.After(100 * time.Millisecond): - t.Fatalf("line %d: didn't get add in time", line) } } // Initial resolve should just give us IPs - assertAdd(ip1 + port) - assertAdd(fmt.Sprintf("%s:%d", ip2, xfer.AppPort)) + assertAdd(ip1+port, fmt.Sprintf("%s:%d", ip2, xfer.AppPort)) // Trigger another resolve with a tick; again, // just want ips. c <- time.Now() - assertAdd(ip1 + port) - assertAdd(fmt.Sprintf("%s:%d", ip2, xfer.AppPort)) + assertAdd(ip1+port, fmt.Sprintf("%s:%d", ip2, xfer.AppPort)) ip3 := "1.2.3.4" updateIPs("symbolic.name", makeIPs(ip3)) - c <- time.Now() // trigger a resolve - assertAdd(ip3 + port) // we want 1 add - assertAdd(ip1 + port) - assertAdd(fmt.Sprintf("%s:%d", ip2, xfer.AppPort)) + c <- time.Now() // trigger a resolve + assertAdd(ip3+port, ip1+port, fmt.Sprintf("%s:%d", ip2, xfer.AppPort)) ip4 := "10.10.10.10" updateIPs("symbolic.name", makeIPs(ip3, ip4)) - c <- time.Now() // trigger another resolve, this time with 2 adds - assertAdd(ip3 + port) // first add - assertAdd(ip4 + port) // second add - assertAdd(ip1 + port) - assertAdd(fmt.Sprintf("%s:%d", ip2, xfer.AppPort)) + c <- time.Now() // trigger another resolve, this time with 2 adds + assertAdd(ip3+port, ip4+port, ip1+port, fmt.Sprintf("%s:%d", ip2, xfer.AppPort)) done := make(chan struct{}) go func() { r.Stop(); close(done) }() @@ -94,6 +96,40 @@ func TestResolver(t *testing.T) { } } +func TestSemaphore(t *testing.T) { + n := 3 + s := newSemaphore(n) + + // First n should be fine + for i := 0; i < n; i++ { + ok := make(chan struct{}) + go func() { s.p(); close(ok) }() + select { + case <-ok: + case <-time.After(10 * time.Millisecond): + t.Errorf("p (%d) failed", i+1) + } + } + + // This should block + ok := make(chan struct{}) + go func() { s.p(); close(ok) }() + select { + case <-ok: + t.Errorf("%dth p OK, but should block", n+1) + case <-time.After(10 * time.Millisecond): + //t.Logf("%dth p blocks, as expected", n+1) + } + + s.v() + + select { + case <-ok: + case <-time.After(10 * time.Millisecond): + t.Errorf("%dth p didn't resolve in time", n+1) + } +} + func makeIPs(addrs ...string) []net.IP { var ips []net.IP for _, addr := range addrs {