Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: s3/transfermanager (v2): round-robin DNS and multi-NIC #2975

Merged
merged 4 commits into from
Jan 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions feature/s3/transfermanager/dns_cache.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package transfermanager

import (
"sync"
"time"

"github.com/aws/smithy-go/container/private/cache"
"github.com/aws/smithy-go/container/private/cache/lru"
)

// dnsCache implements an LRU cache of DNS query results by host.
//
// Cache retrievals will automatically rotate between IP addresses for
// multi-value query results.
type dnsCache struct {
// mu serializes access to addrs and to the entries stored in it; GetAddr and
// PutAddrs hold it for their whole duration.
mu sync.Mutex
// addrs maps a host name to its *dnsCacheEntry.
addrs cache.Cache
}

// newDNSCache builds an empty dnsCache whose backing LRU is created with the
// given capacity.
func newDNSCache(cap int) *dnsCache {
	c := new(dnsCache)
	c.addrs = lru.New(cap)
	return c
}

// GetAddr returns the next IP address for the given host if present in the
// cache.
//
// The boolean result is false when the host has no entry, when the entry has
// expired, or when the entry holds no addresses. On a hit, the entry's
// rotation position advances so that the following call for the same host
// returns the next address, wrapping around at the end of the list.
func (c *dnsCache) GetAddr(host string) (string, bool) {
	c.mu.Lock()
	defer c.mu.Unlock()

	v, ok := c.addrs.Get(host)
	if !ok {
		return "", false
	}

	// An entry without addresses cannot be rotated: indexing it or taking the
	// modulo of its length would panic, so report it as a miss instead.
	record, ok := v.(*dnsCacheEntry)
	if !ok || len(record.addrs) == 0 {
		return "", false
	}
	if timeNow().After(record.expires) {
		return "", false
	}

	addr := record.addrs[record.index]
	record.index = (record.index + 1) % len(record.addrs)
	return addr, true
}

// PutAddrs records addrs as the DNS query result for host, valid until
// expires. An entry already cached for the host is overwritten, and rotation
// for the new entry starts at the first address.
func (c *dnsCache) PutAddrs(host string, addrs []string, expires time.Time) {
	entry := &dnsCacheEntry{
		addrs:   addrs,
		expires: expires,
		index:   0,
	}

	c.mu.Lock()
	defer c.mu.Unlock()
	c.addrs.Put(host, entry)
}

// dnsCacheEntry is the value stored per host in dnsCache.
type dnsCacheEntry struct {
// addrs holds the addresses passed to PutAddrs for the host.
addrs []string
// expires is the instant after which GetAddr treats the entry as a miss.
expires time.Time
// index is the position in addrs that the next GetAddr call returns.
index int
}
160 changes: 160 additions & 0 deletions feature/s3/transfermanager/rrdns.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
package transfermanager

import (
"context"
"fmt"
"net"
"net/http"
"sync"
"time"

"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight"
)

// timeNow is the clock used for cache expiry; it is a variable so tests can
// replace it with a fake.
var timeNow = time.Now

// WithRoundRobinDNS configures an http.Transport to spread HTTP connections
// across multiple IP addresses for a given host.
//
// This is recommended by the [S3 performance guide] in high-concurrency
// application environments.
//
// Options default to a TTL of 30 seconds and a MaxHosts of 100; each function
// in opts may override them.
//
// WithRoundRobinDNS wraps the underlying DialContext hook on http.Transport.
// Future modifications to this hook MUST preserve said wrapping in order for
// round-robin DNS to operate.
//
// If the Transport has no DialContext hook when the returned function is
// applied, connections are dialed with a zero-value net.Dialer, mirroring
// http.Transport's documented behavior of dialing with package net when the
// hook is unset. The deprecated Transport.Dial hook is not consulted.
//
// [S3 performance guide]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-design-patterns.html
func WithRoundRobinDNS(opts ...func(*RoundRobinDNSOptions)) func(*http.Transport) {
	options := &RoundRobinDNSOptions{
		TTL:      30 * time.Second,
		MaxHosts: 100,
	}
	for _, opt := range opts {
		opt(options)
	}

	return func(t *http.Transport) {
		// A Transport built as &http.Transport{} has a nil DialContext.
		// Wrapping nil would panic on the first connection attempt.
		dial := t.DialContext
		if dial == nil {
			dial = (&net.Dialer{}).DialContext
		}

		rr := &rrDNS{
			cache:       newDNSCache(options.MaxHosts),
			expiry:      options.TTL,
			resolver:    &net.Resolver{},
			dialContext: dial,
		}
		t.DialContext = rr.DialContext
	}
}

// RoundRobinDNSOptions configures use of round-robin DNS.
//
// WithRoundRobinDNS starts from a TTL of 30 seconds and a MaxHosts of 100
// before applying caller-supplied option functions.
type RoundRobinDNSOptions struct {
// The length of time for which the results of a DNS query are valid.
TTL time.Duration

// A limit to the number of DNS query results, cached by hostname, which are
// stored. Round-robin DNS uses an LRU cache.
MaxHosts int
}

// resolver is the host lookup dependency of rrDNS. WithRoundRobinDNS supplies
// a *net.Resolver.
type resolver interface {
LookupHost(context.Context, string) ([]string, error)
}

// rrDNS wraps a dial function so that hosts are resolved through a rotating
// address cache before dialing.
type rrDNS struct {
// sf is keyed by host name in lookupHost.
sf singleflight.Group
// cache holds resolved addresses per host.
cache *dnsCache

// expiry is added to the current time to compute when a freshly resolved
// entry expires.
expiry time.Duration
// resolver performs the host lookup on a cache miss.
resolver resolver

// dialContext is the wrapped dial hook; it is called with the host replaced by
// the selected IP address.
dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
}

// DialContext satisfies the DialContext hook of http.Transport. It splits addr
// into host and port, swaps the host for the next cached IP address (resolving
// it first when needed) and hands the rewritten address to the wrapped dial
// function.
func (r *rrDNS) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	host, port, splitErr := net.SplitHostPort(addr)
	if splitErr != nil {
		return nil, fmt.Errorf("rrdns split host/port: %w", splitErr)
	}

	ip, lookupErr := r.getAddr(ctx, host)
	if lookupErr != nil {
		return nil, fmt.Errorf("rrdns lookup host: %w", lookupErr)
	}

	target := net.JoinHostPort(ip, port)
	return r.dialContext(ctx, network, target)
}

func (r *rrDNS) getAddr(ctx context.Context, host string) (string, error) {
addr, ok := r.cache.GetAddr(host)
if ok {
return addr, nil
}
return r.lookupHost(ctx, host)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So when we have a cache miss, we'll refresh the whole cache? Is this what we want?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not the whole cache? We just re-query for that one host, or am I missing something?

Copy link
Contributor Author

@lucix-aws lucix-aws Jan 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cache is hostname -> []<ip address>

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, my thinking was that we were storing things like

hostname -> {
  ip1/expiry1
  ip2/expiry2
  ip3/expiry3
}

But we're storing all results of the lookup with the same expiry (since we don't get TTL), so the entries are

hostname -> {
  ip1/expiry1
  ip2/expiry1
  ip3/expiry1
}

Then it doesn't really matter if one expires, since all entries from the lookup will expire at the same time. And anyway, anytime we do a lookup for the same hostname, we replace the existing values, so they effectively always have the same expiry.

}

// lookupHost resolves host with the configured resolver, stores the result in
// the cache with an expiry of now plus r.expiry, and returns an address to
// dial.
//
// The resolver query is issued through r.sf keyed by host. The function
// returns ctx.Err() if ctx is done before the query result arrives.
//
// An error is returned when the resolver fails or yields no addresses; an
// empty address is never returned together with a nil error.
//
// NOTE(review): the query closure captures the ctx of whichever caller started
// it, so callers sharing that query presumably observe that caller's
// cancellation as a lookup error - confirm whether this is intended.
func (r *rrDNS) lookupHost(ctx context.Context, host string) (string, error) {
	ch := r.sf.DoChan(host, func() (interface{}, error) {
		addrs, err := r.resolver.LookupHost(ctx, host)
		if err != nil {
			return nil, err
		}
		// Never cache an empty result: there is nothing to rotate through.
		if len(addrs) == 0 {
			return nil, fmt.Errorf("no addresses found for host %q", host)
		}

		expires := timeNow().Add(r.expiry)
		r.cache.PutAddrs(host, addrs, expires)
		return addrs, nil
	})

	select {
	case result := <-ch:
		if result.Err != nil {
			return "", result.Err
		}

		if addr, ok := r.cache.GetAddr(host); ok {
			return addr, nil
		}

		// The entry may already be unavailable: evicted by a lookup for
		// another host, or expired immediately when the TTL is zero. Use the
		// freshly resolved addresses rather than returning an empty address.
		addrs, _ := result.Val.([]string)
		if len(addrs) == 0 {
			return "", fmt.Errorf("no addresses found for host %q", host)
		}
		return addrs[0], nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

// WithRotoDialer configures an http.Transport to cycle through multiple local
// network addresses when creating new HTTP connections. One net.Dialer is
// created per entry in addrs, each bound to that entry as its local address.
//
// WithRotoDialer REPLACES the root DialContext hook on the underlying
// Transport, thereby destroying any previously-applied wrappings around it. If
// the caller needs to apply additional decorations to the DialContext hook,
// they must do so after applying WithRotoDialer.
func WithRotoDialer(addrs []net.Addr) func(*http.Transport) {
	return func(t *http.Transport) {
		roto := &rotoDialer{
			dialers: make([]*net.Dialer, len(addrs)),
		}
		for i := range addrs {
			roto.dialers[i] = &net.Dialer{LocalAddr: addrs[i]}
		}

		t.DialContext = roto.DialContext
	}
}

// rotoDialer hands out its dialers in rotation, so that successive
// connections are made through successive dialers.
type rotoDialer struct {
	mu      sync.Mutex // guards index
	dialers []*net.Dialer
	index   int // position in dialers of the dialer returned by the next call to next
}

// DialContext dials addr on network using the next dialer in the rotation.
func (r *rotoDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	return r.next().DialContext(ctx, network, addr)
}

// next returns the dialer at the current rotation position and advances the
// position, wrapping around at the end of the list.
//
// When no dialers are configured it returns a zero-value net.Dialer instead of
// panicking on an out-of-range index and a modulo by zero.
func (r *rotoDialer) next() *net.Dialer {
	r.mu.Lock()
	defer r.mu.Unlock()

	if len(r.dialers) == 0 {
		return &net.Dialer{}
	}

	d := r.dialers[r.index]
	r.index = (r.index + 1) % len(r.dialers)
	return d
}
166 changes: 166 additions & 0 deletions feature/s3/transfermanager/rrdns_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
package transfermanager

import (
"context"
"errors"
"net"
"testing"
"time"
)

// these tests also cover the cache impl (cycling+expiry+evict)

// mockNow is a manually advanced clock for tests.
type mockNow struct {
	now time.Time
}

// Now reports the clock's current instant.
func (m *mockNow) Now() time.Time {
	current := m.now
	return current
}

// Add moves the clock forward by d.
func (m *mockNow) Add(d time.Duration) {
	advanced := m.now.Add(d)
	m.now = advanced
}

// useMockNow installs m as the package clock and returns a function that puts
// back whatever clock was installed before the call, so that nested or
// sequential uses unwind correctly.
func useMockNow(m *mockNow) func() {
	prev := timeNow
	timeNow = m.Now
	return func() {
		timeNow = prev
	}
}

// errDialContextOK is the sentinel returned by mockDialContext.DialContext; tests
// check for it to confirm the wrapped dial hook was reached.
var errDialContextOK = errors.New("dial context ok")

// mockResolver is a canned resolver: it answers lookups from a fixed table and
// returns a fixed error value alongside every answer.
type mockResolver struct {
	addrs map[string][]string
	err   error
}

// LookupHost returns the table entry for host (nil when absent) together with
// the configured error.
func (m *mockResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
	found := m.addrs[host]
	return found, m.err
}

// mockDialContext is a fake dial hook that remembers the last address it was
// asked to dial.
type mockDialContext struct {
	calledWith string
}

// DialContext records addr and returns errDialContextOK without opening a
// connection.
func (m *mockDialContext) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
	var conn net.Conn // no connection is ever made
	m.calledWith = addr
	return conn, errDialContextOK
}

// TestRoundRobinDNS_CycleIPs checks that successive dials to one host walk
// through its addresses in order and wrap back to the first.
func TestRoundRobinDNS_CycleIPs(t *testing.T) {
	defer useMockNow(&mockNow{})()

	const host = "s3.us-east-1.amazonaws.com"
	ips := []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"}

	dc := &mockDialContext{}
	rr := &rrDNS{
		cache:       newDNSCache(1),
		resolver:    &mockResolver{addrs: map[string][]string{host: ips}},
		dialContext: dc.DialContext,
	}

	// Three dials consume the list; the fourth wraps around.
	for _, want := range []string{ips[0], ips[1], ips[2], ips[0]} {
		expectDialContext(t, rr, dc, host, want)
	}
}

// TestRoundRobinDNS_MultiIP checks that rotation is tracked independently for
// each host when dials to two hosts are interleaved.
func TestRoundRobinDNS_MultiIP(t *testing.T) {
	defer useMockNow(&mockNow{})()

	const (
		first  = "host1.com"
		second = "host2.com"
	)
	table := map[string][]string{
		first:  {"0.0.0.1", "0.0.0.2", "0.0.0.3"},
		second: {"1.0.0.1", "1.0.0.2", "1.0.0.3"},
	}

	dc := &mockDialContext{}
	rr := &rrDNS{
		cache:       newDNSCache(2),
		resolver:    &mockResolver{addrs: table},
		dialContext: dc.DialContext,
	}

	steps := []struct{ host, want string }{
		{first, table[first][0]},
		{second, table[second][0]},
		{first, table[first][1]},
		{second, table[second][1]},
	}
	for _, step := range steps {
		expectDialContext(t, rr, dc, step.host, step.want)
	}
}

// TestRoundRobinDNS_MaxHosts checks that with a cache capacity of one, dialing
// a second host evicts the first, so rotation for an evicted host starts over.
func TestRoundRobinDNS_MaxHosts(t *testing.T) {
	defer useMockNow(&mockNow{})()

	const (
		first  = "host1.com"
		second = "host2.com"
	)
	table := map[string][]string{
		first:  {"0.0.0.1", "0.0.0.2", "0.0.0.3"},
		second: {"0.0.0.1", "0.0.0.2", "0.0.0.3"},
	}

	dc := &mockDialContext{}
	rr := &rrDNS{
		cache:       newDNSCache(1),
		resolver:    &mockResolver{addrs: table},
		dialContext: dc.DialContext,
	}

	steps := []struct{ host, want string }{
		{first, table[first][0]},
		{first, table[first][1]},
		{second, table[second][0]}, // evicts host1
		{first, table[first][0]},   // evicts host2
		{second, table[second][0]},
	}
	for _, step := range steps {
		expectDialContext(t, rr, dc, step.host, step.want)
	}
}

// TestRoundRobinDNS_Expires checks that a cached entry keeps rotating while it
// is within its expiry and is resolved afresh, restarting at the first
// address, once the clock has moved past it.
func TestRoundRobinDNS_Expires(t *testing.T) {
	clock := &mockNow{now: time.Unix(0, 0)}
	defer useMockNow(clock)()

	const host = "host1.com"
	ips := []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"}

	dc := &mockDialContext{}
	rr := &rrDNS{
		cache:       newDNSCache(2),
		expiry:      30 * time.Nanosecond,
		resolver:    &mockResolver{addrs: map[string][]string{host: ips}},
		dialContext: dc.DialContext,
	}

	expectDialContext(t, rr, dc, host, ips[0])

	clock.Add(16 * time.Nanosecond) // hasn't expired
	expectDialContext(t, rr, dc, host, ips[1])

	clock.Add(16 * time.Nanosecond) // expired, starts over
	expectDialContext(t, rr, dc, host, ips[0])
}

// expectDialContext dials host on port 443 through rr and asserts that the
// wrapped dial hook dc was reached with the host replaced by expect.
//
// The test is stopped immediately if the dial did not end in the mock's
// sentinel error: in that case dc.calledWith still holds the address from an
// earlier call and comparing it would report a misleading mismatch.
func expectDialContext(t *testing.T, rr *rrDNS, dc *mockDialContext, host, expect string) {
	const port = "443"

	t.Helper()
	_, err := rr.DialContext(context.Background(), "", net.JoinHostPort(host, port))
	if !errors.Is(err, errDialContextOK) {
		t.Fatalf("expect sentinel err, got %v", err)
	}

	actual, _, err := net.SplitHostPort(dc.calledWith)
	if err != nil {
		t.Fatal(err)
	}
	if expect != actual {
		t.Errorf("expect addr %s, got %s", expect, actual)
	}
}
Loading