Skip to content

Commit

Permalink
Reduce timeout for accessing EC2 instance metadata service
Browse files Browse the repository at this point in the history
Previously we were re-using our shared HTTP client, which has a rather
high timeout (120 seconds) that causes the HTTP client to wait around
for a long time. This is generally intentional (since it includes the
time spent downloading a request body), but is a bad idea when running
into EC2's IDMSv2 service that has a network-hop based limit. If that
hop limit is exceeded, the requests just go to nowhere, causing the
client to wait for a multiple of 120 seconds (~10 minutes were observed).

This instead uses a special client for the EC2 instance metadata service
that has a much lower timeout (1 second, like in the AWS SDK itself), to
avoid the problem.

See also aws/aws-sdk-go#2972
  • Loading branch information
lfittl committed Mar 10, 2021
1 parent 1e7bc4e commit eec7426
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 27 deletions.
12 changes: 10 additions & 2 deletions config/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,8 +264,8 @@ func CreateHTTPClient(conf ServerConfig) *http.Client {
matchesProxyURL = true
}
}
// Require secure conection for everything except proxies, the EC2 and ECS metadata services
if !matchesProxyURL && !strings.HasSuffix(addr, ":443") && addr != "169.254.169.254:80" && addr != "169.254.170.2:80" {
// Require secure conection for everything except proxies
if !matchesProxyURL && !strings.HasSuffix(addr, ":443") {
return nil, fmt.Errorf("Unencrypted connection is not permitted by pganalyze configuration")
}
return (&net.Dialer{Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, DualStack: true}).DialContext(ctx, network, addr)
Expand All @@ -279,6 +279,14 @@ func CreateHTTPClient(conf ServerConfig) *http.Client {
}
}

// CreateEC2IMDSHTTPClient - Create HTTP client for EC2 instance meta data service (IMDS)
func CreateEC2IMDSHTTPClient(conf ServerConfig) *http.Client {
// Match https://github.com/aws/aws-sdk-go/pull/3066
return &http.Client{
Timeout: 1 * time.Second,
}
}

func writeValueToTempfile(value string) (string, error) {
file, err := ioutil.TempFile("", "")
if err != nil {
Expand Down
69 changes: 44 additions & 25 deletions util/awsutil/amazon.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,67 +4,86 @@ import (
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/aws/aws-sdk-go/aws/endpoints"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/pganalyze/collector/config"
)

// GetAwsSession - Returns an AWS session for the specified server configuration
func GetAwsSession(config config.ServerConfig) (*session.Session, error) {
var creds *credentials.Credentials

if config.AwsAccessKeyID != "" {
creds = credentials.NewStaticCredentials(config.AwsAccessKeyID, config.AwsSecretAccessKey, "")
}
// GetAwsSession - Returns an AWS session for the specified server cfguration
func GetAwsSession(cfg config.ServerConfig) (*session.Session, error) {
var providers []credentials.Provider

customResolver := func(service, region string, optFns ...func(*endpoints.Options)) (endpoints.ResolvedEndpoint, error) {
if service == endpoints.RdsServiceID && config.AwsEndpointRdsURL != "" {
if service == endpoints.RdsServiceID && cfg.AwsEndpointRdsURL != "" {
return endpoints.ResolvedEndpoint{
URL: config.AwsEndpointRdsURL,
SigningRegion: config.AwsEndpointSigningRegion,
URL: cfg.AwsEndpointRdsURL,
SigningRegion: cfg.AwsEndpointSigningRegion,
}, nil
}
if service == endpoints.Ec2ServiceID && config.AwsEndpointEc2URL != "" {
if service == endpoints.Ec2ServiceID && cfg.AwsEndpointEc2URL != "" {
return endpoints.ResolvedEndpoint{
URL: config.AwsEndpointEc2URL,
SigningRegion: config.AwsEndpointSigningRegion,
URL: cfg.AwsEndpointEc2URL,
SigningRegion: cfg.AwsEndpointSigningRegion,
}, nil
}
if service == endpoints.MonitoringServiceID && config.AwsEndpointCloudwatchURL != "" {
if service == endpoints.MonitoringServiceID && cfg.AwsEndpointCloudwatchURL != "" {
return endpoints.ResolvedEndpoint{
URL: config.AwsEndpointCloudwatchURL,
SigningRegion: config.AwsEndpointSigningRegion,
URL: cfg.AwsEndpointCloudwatchURL,
SigningRegion: cfg.AwsEndpointSigningRegion,
}, nil
}
if service == endpoints.LogsServiceID && config.AwsEndpointCloudwatchLogsURL != "" {
if service == endpoints.LogsServiceID && cfg.AwsEndpointCloudwatchLogsURL != "" {
return endpoints.ResolvedEndpoint{
URL: config.AwsEndpointCloudwatchLogsURL,
SigningRegion: config.AwsEndpointSigningRegion,
URL: cfg.AwsEndpointCloudwatchLogsURL,
SigningRegion: cfg.AwsEndpointSigningRegion,
}, nil
}

return endpoints.DefaultResolver().EndpointFor(service, region, optFns...)
}

if config.AwsAssumeRole != "" {
if cfg.AwsAccessKeyID != "" {
providers = append(providers, &credentials.StaticProvider{
Value: credentials.Value{
AccessKeyID: cfg.AwsAccessKeyID,
SecretAccessKey: cfg.AwsSecretAccessKey,
SessionToken: "",
},
})
}

// add default providers
providers = append(providers, &credentials.EnvProvider{})
providers = append(providers, &credentials.SharedCredentialsProvider{Filename: "", Profile: ""})

// add the metadata service
def := defaults.Get()
def.Config.HTTPClient = config.CreateEC2IMDSHTTPClient(cfg)
def.Config.MaxRetries = aws.Int(2)
providers = append(providers, defaults.RemoteCredProvider(*def.Config, def.Handlers))

creds := credentials.NewChainCredentials(providers)

if cfg.AwsAssumeRole != "" {
sess, err := session.NewSession(&aws.Config{
Credentials: creds,
CredentialsChainVerboseErrors: aws.Bool(true),
Region: aws.String(config.AwsRegion),
HTTPClient: config.HTTPClient,
Region: aws.String(cfg.AwsRegion),
HTTPClient: cfg.HTTPClient,
EndpointResolver: endpoints.ResolverFunc(customResolver),
})
if err != nil {
return nil, err
}
creds = stscreds.NewCredentials(sess, config.AwsAssumeRole)
creds = stscreds.NewCredentials(sess, cfg.AwsAssumeRole)
}

return session.NewSession(&aws.Config{
Credentials: creds,
CredentialsChainVerboseErrors: aws.Bool(true),
Region: aws.String(config.AwsRegion),
HTTPClient: config.HTTPClient,
Region: aws.String(cfg.AwsRegion),
HTTPClient: cfg.HTTPClient,
EndpointResolver: endpoints.ResolverFunc(customResolver),
})
}
Expand Down

0 comments on commit eec7426

Please sign in to comment.