-
Notifications
You must be signed in to change notification settings - Fork 671
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: s3/transfermanager (v2): round-robin DNS and multi-NIC #2975
Merged
lucix-aws
merged 4 commits into
feat-transfer-manager-v2
from
feat-transfer-manager-v2-rrdns
Jan 21, 2025
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
package transfermanager | ||
|
||
import ( | ||
"sync" | ||
"time" | ||
|
||
"github.com/aws/smithy-go/container/private/cache" | ||
"github.com/aws/smithy-go/container/private/cache/lru" | ||
) | ||
|
||
// dnsCache implements an LRU cache of DNS query results by host. | ||
// | ||
// Cache retrievals will automatically rotate between IP addresses for | ||
// multi-value query results. | ||
type dnsCache struct { | ||
mu sync.Mutex | ||
addrs cache.Cache | ||
} | ||
|
||
// newDNSCache returns an initialized dnsCache with given capacity. | ||
func newDNSCache(cap int) *dnsCache { | ||
return &dnsCache{ | ||
addrs: lru.New(cap), | ||
} | ||
} | ||
|
||
// GetAddr returns the next IP address for the given host if present in the | ||
// cache. | ||
func (c *dnsCache) GetAddr(host string) (string, bool) { | ||
c.mu.Lock() | ||
defer c.mu.Unlock() | ||
|
||
v, ok := c.addrs.Get(host) | ||
if !ok { | ||
return "", false | ||
} | ||
|
||
record := v.(*dnsCacheEntry) | ||
if timeNow().After(record.expires) { | ||
return "", false | ||
} | ||
|
||
addr := record.addrs[record.index] | ||
record.index = (record.index + 1) % len(record.addrs) | ||
return addr, true | ||
} | ||
|
||
// PutAddrs stores a DNS query result in the cache, overwriting any present | ||
// entry for the host if it exists. | ||
func (c *dnsCache) PutAddrs(host string, addrs []string, expires time.Time) { | ||
c.mu.Lock() | ||
defer c.mu.Unlock() | ||
|
||
c.addrs.Put(host, &dnsCacheEntry{addrs, expires, 0}) | ||
} | ||
|
||
type dnsCacheEntry struct { | ||
addrs []string | ||
expires time.Time | ||
index int | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
package transfermanager | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"net" | ||
"net/http" | ||
"sync" | ||
"time" | ||
|
||
"github.com/aws/aws-sdk-go-v2/internal/sync/singleflight" | ||
) | ||
|
||
var timeNow = time.Now | ||
|
||
// WithRoundRobinDNS configures an http.Transport to spread HTTP connections | ||
// across multiple IP addresses for a given host. | ||
// | ||
// This is recommended by the [S3 performance guide] in high-concurrency | ||
// application environments. | ||
// | ||
// WithRoundRobinDNS wraps the underlying DialContext hook on http.Transport. | ||
// Future modifications to this hook MUST preserve said wrapping in order for | ||
// round-robin DNS to operate. | ||
// | ||
// [S3 performance guide]: https://docs.aws.amazon.com/AmazonS3/latest/userguide/optimizing-performance-design-patterns.html | ||
func WithRoundRobinDNS(opts ...func(*RoundRobinDNSOptions)) func(*http.Transport) { | ||
options := &RoundRobinDNSOptions{ | ||
TTL: 30 * time.Second, | ||
MaxHosts: 100, | ||
} | ||
for _, opt := range opts { | ||
opt(options) | ||
} | ||
|
||
return func(t *http.Transport) { | ||
rr := &rrDNS{ | ||
cache: newDNSCache(options.MaxHosts), | ||
expiry: options.TTL, | ||
resolver: &net.Resolver{}, | ||
dialContext: t.DialContext, | ||
} | ||
t.DialContext = rr.DialContext | ||
} | ||
} | ||
|
||
// RoundRobinDNSOptions configures use of round-robin DNS. | ||
type RoundRobinDNSOptions struct { | ||
// The length of time for which the results of a DNS query are valid. | ||
TTL time.Duration | ||
|
||
// A limit to the number of DNS query results, cached by hostname, which are | ||
// stored. Round-robin DNS uses an LRU cache. | ||
MaxHosts int | ||
} | ||
|
||
type resolver interface { | ||
LookupHost(context.Context, string) ([]string, error) | ||
} | ||
|
||
type rrDNS struct { | ||
sf singleflight.Group | ||
cache *dnsCache | ||
|
||
expiry time.Duration | ||
resolver resolver | ||
|
||
dialContext func(ctx context.Context, network, addr string) (net.Conn, error) | ||
} | ||
|
||
// DialContext implements the DialContext hook used by http.Transport, | ||
// pre-caching IP addresses for a given host and distributing them evenly | ||
// across new connections. | ||
func (r *rrDNS) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { | ||
host, port, err := net.SplitHostPort(addr) | ||
if err != nil { | ||
return nil, fmt.Errorf("rrdns split host/port: %w", err) | ||
} | ||
|
||
ipaddr, err := r.getAddr(ctx, host) | ||
if err != nil { | ||
return nil, fmt.Errorf("rrdns lookup host: %w", err) | ||
} | ||
|
||
return r.dialContext(ctx, network, net.JoinHostPort(ipaddr, port)) | ||
} | ||
|
||
func (r *rrDNS) getAddr(ctx context.Context, host string) (string, error) { | ||
addr, ok := r.cache.GetAddr(host) | ||
if ok { | ||
return addr, nil | ||
} | ||
return r.lookupHost(ctx, host) | ||
} | ||
|
||
func (r *rrDNS) lookupHost(ctx context.Context, host string) (string, error) { | ||
ch := r.sf.DoChan(host, func() (interface{}, error) { | ||
addrs, err := r.resolver.LookupHost(ctx, host) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
expires := timeNow().Add(r.expiry) | ||
r.cache.PutAddrs(host, addrs, expires) | ||
return nil, nil | ||
}) | ||
|
||
select { | ||
case result := <-ch: | ||
if result.Err != nil { | ||
return "", result.Err | ||
} | ||
|
||
addr, _ := r.cache.GetAddr(host) | ||
return addr, nil | ||
case <-ctx.Done(): | ||
return "", ctx.Err() | ||
} | ||
} | ||
|
||
// WithRotoDialer configures an http.Transport to cycle through multiple local | ||
// network addresses when creating new HTTP connections. | ||
// | ||
// WithRotoDialer REPLACES the root DialContext hook on the underlying | ||
// Transport, thereby destroying any previously-applied wrappings around it. If | ||
// the caller needs to apply additional decorations to the DialContext hook, | ||
// they must do so after applying WithRotoDialer. | ||
func WithRotoDialer(addrs []net.Addr) func(*http.Transport) { | ||
return func(t *http.Transport) { | ||
var dialers []*net.Dialer | ||
for _, addr := range addrs { | ||
dialers = append(dialers, &net.Dialer{ | ||
LocalAddr: addr, | ||
}) | ||
} | ||
|
||
t.DialContext = (&rotoDialer{ | ||
dialers: dialers, | ||
}).DialContext | ||
} | ||
} | ||
|
||
type rotoDialer struct { | ||
mu sync.Mutex | ||
dialers []*net.Dialer | ||
index int | ||
} | ||
|
||
func (r *rotoDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { | ||
return r.next().DialContext(ctx, network, addr) | ||
} | ||
|
||
func (r *rotoDialer) next() *net.Dialer { | ||
r.mu.Lock() | ||
defer r.mu.Unlock() | ||
|
||
d := r.dialers[r.index] | ||
r.index = (r.index + 1) % len(r.dialers) | ||
return d | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
package transfermanager | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"net" | ||
"testing" | ||
"time" | ||
) | ||
|
||
// these tests also cover the cache impl (cycling+expiry+evict) | ||
|
||
type mockNow struct { | ||
now time.Time | ||
} | ||
|
||
func (m *mockNow) Now() time.Time { | ||
return m.now | ||
} | ||
|
||
func (m *mockNow) Add(d time.Duration) { | ||
m.now = m.now.Add(d) | ||
} | ||
|
||
func useMockNow(m *mockNow) func() { | ||
timeNow = m.Now | ||
return func() { | ||
timeNow = time.Now | ||
} | ||
} | ||
|
||
var errDialContextOK = errors.New("dial context ok") | ||
|
||
type mockResolver struct { | ||
addrs map[string][]string | ||
err error | ||
} | ||
|
||
func (m *mockResolver) LookupHost(ctx context.Context, host string) ([]string, error) { | ||
return m.addrs[host], m.err | ||
} | ||
|
||
type mockDialContext struct { | ||
calledWith string | ||
} | ||
|
||
func (m *mockDialContext) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { | ||
m.calledWith = addr | ||
return nil, errDialContextOK | ||
} | ||
|
||
func TestRoundRobinDNS_CycleIPs(t *testing.T) { | ||
restore := useMockNow(&mockNow{}) | ||
defer restore() | ||
|
||
addrs := []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"} | ||
r := &mockResolver{ | ||
addrs: map[string][]string{ | ||
"s3.us-east-1.amazonaws.com": addrs, | ||
}, | ||
} | ||
dc := &mockDialContext{} | ||
|
||
rr := &rrDNS{ | ||
cache: newDNSCache(1), | ||
resolver: r, | ||
dialContext: dc.DialContext, | ||
} | ||
|
||
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[0]) | ||
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[1]) | ||
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[2]) | ||
expectDialContext(t, rr, dc, "s3.us-east-1.amazonaws.com", addrs[0]) | ||
} | ||
|
||
func TestRoundRobinDNS_MultiIP(t *testing.T) { | ||
restore := useMockNow(&mockNow{}) | ||
defer restore() | ||
|
||
r := &mockResolver{ | ||
addrs: map[string][]string{ | ||
"host1.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"}, | ||
"host2.com": []string{"1.0.0.1", "1.0.0.2", "1.0.0.3"}, | ||
}, | ||
} | ||
dc := &mockDialContext{} | ||
|
||
rr := &rrDNS{ | ||
cache: newDNSCache(2), | ||
resolver: r, | ||
dialContext: dc.DialContext, | ||
} | ||
|
||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0]) | ||
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][0]) | ||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][1]) | ||
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][1]) | ||
} | ||
|
||
func TestRoundRobinDNS_MaxHosts(t *testing.T) { | ||
restore := useMockNow(&mockNow{}) | ||
defer restore() | ||
|
||
r := &mockResolver{ | ||
addrs: map[string][]string{ | ||
"host1.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"}, | ||
"host2.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"}, | ||
}, | ||
} | ||
dc := &mockDialContext{} | ||
|
||
rr := &rrDNS{ | ||
cache: newDNSCache(1), | ||
resolver: r, | ||
dialContext: dc.DialContext, | ||
} | ||
|
||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0]) | ||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][1]) | ||
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][0]) // evicts host1 | ||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0]) // evicts host2 | ||
expectDialContext(t, rr, dc, "host2.com", r.addrs["host2.com"][0]) | ||
} | ||
|
||
func TestRoundRobinDNS_Expires(t *testing.T) { | ||
now := &mockNow{time.Unix(0, 0)} | ||
restore := useMockNow(now) | ||
defer restore() | ||
|
||
r := &mockResolver{ | ||
addrs: map[string][]string{ | ||
"host1.com": []string{"0.0.0.1", "0.0.0.2", "0.0.0.3"}, | ||
}, | ||
} | ||
dc := &mockDialContext{} | ||
|
||
rr := &rrDNS{ | ||
cache: newDNSCache(2), | ||
expiry: 30, | ||
resolver: r, | ||
dialContext: dc.DialContext, | ||
} | ||
|
||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0]) | ||
now.Add(16) // hasn't expired | ||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][1]) | ||
now.Add(16) // expired, starts over | ||
expectDialContext(t, rr, dc, "host1.com", r.addrs["host1.com"][0]) | ||
} | ||
|
||
func expectDialContext(t *testing.T, rr *rrDNS, dc *mockDialContext, host, expect string) { | ||
const port = "443" | ||
|
||
t.Helper() | ||
_, err := rr.DialContext(context.Background(), "", net.JoinHostPort(host, port)) | ||
if err != errDialContextOK { | ||
t.Errorf("expect sentinel err, got %v", err) | ||
} | ||
actual, _, err := net.SplitHostPort(dc.calledWith) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
if expect != actual { | ||
t.Errorf("expect addr %s, got %s", expect, actual) | ||
} | ||
} |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So when we have a cache miss we'll refresh the whole cache? is this what we want?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not the whole cache? We just re-query for that one host, or am I missing something?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cache is
hostname -> []<ip address>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You're right, my thinking was that we were storing things like
But we're storing all results of the lookup with the same expiry (since we don't get TTL), so the entries are
Then it doesn't really matter if one expires since all entries from the lookup will expire at the same time. And anyway, anytime we do a lookup for the same hostname, we replace existing values so they effectively always have the same expiry