diff --git a/.circleci/config.yml b/.circleci/config.yml deleted file mode 100644 index c58e7ba..0000000 --- a/.circleci/config.yml +++ /dev/null @@ -1,55 +0,0 @@ -version: 2.1 - -orbs: - go: circleci/go@1.1.1 - -references: - environments: - tmp: &TEST_RESULTS_PATH /tmp/test-results - -environment: &ENVIRONMENT - TEST_RESULTS: "/tmp/test-results" - -jobs: - run-tests: - executor: - name: go/default - tag: << parameters.go-version >> - parameters: - go-version: - type: string - environment: - TEST_RESULTS: *TEST_RESULTS_PATH - steps: - - checkout - - run: mkdir -p $TEST_RESULTS/go-retryablyhttp - - go/load-cache - - go/mod-download - - go/save-cache - - run: - name: Run go format - command: | - files=$(go fmt ./...) - if [ -n "$files" ]; then - echo "The following file(s) do not conform to go fmt:" - echo "$files" - exit 1 - fi - - run: - name: Run tests with gotestsum - command: | - PACKAGE_NAMES=$(go list ./...) - gotestsum --format=short-verbose --junitfile $TEST_RESULTS/go-retryablyhttp/gotestsum-report.xml -- $PACKAGE_NAMES - - store_test_results: - path: *TEST_RESULTS_PATH - - store_artifacts: - path: *TEST_RESULTS_PATH - -workflows: - go-retryablehttp: - jobs: - - run-tests: - matrix: - parameters: - go-version: ["1.14.2"] - name: test-go-<< matrix.go-version >> diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..0cb37bd --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +version: 2 + +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "daily" + + - package-ecosystem: "gomod" + directory: "/" + schedule: + interval: "weekly" \ No newline at end of file diff --git a/.github/workflows/actionlint.yml b/.github/workflows/actionlint.yml new file mode 100644 index 0000000..42da307 --- /dev/null +++ b/.github/workflows/actionlint.yml @@ -0,0 +1,19 @@ +name: actionlint + +on: + push: + paths: + - .github/** + +permissions: + contents: read + +jobs: + actionlint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6 + - name: "Check GitHub workflow files" + uses: docker://docker.mirror.hashicorp.services/rhysd/actionlint:latest + with: + args: -color \ No newline at end of file diff --git a/.github/workflows/pr-gofmt.yaml b/.github/workflows/pr-gofmt.yaml new file mode 100644 index 0000000..dbe3089 --- /dev/null +++ b/.github/workflows/pr-gofmt.yaml @@ -0,0 +1,23 @@ +name: Go format check +on: + pull_request: + types: ['opened', 'synchronize'] + +jobs: + run-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6 + + - uses: actions/setup-go@cdcb36043654635271a94b9a6d1392de5bb323a7 # v5.0.1 + with: + go-version-file: ./.go-version + + - name: Run go format + run: |- + files=$(gofmt -s -l .) + if [ -n "$files" ]; then + echo >&2 "The following file(s) are not gofmt compliant:" + echo >&2 "$files" + exit 1 + fi diff --git a/.github/workflows/pr-unit-tests-1.19.yaml b/.github/workflows/pr-unit-tests-1.19.yaml new file mode 100644 index 0000000..28a21c7 --- /dev/null +++ b/.github/workflows/pr-unit-tests-1.19.yaml @@ -0,0 +1,17 @@ +name: Unit tests (Go 1.19) +on: + pull_request: + types: ['opened', 'synchronize'] + +jobs: + run-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6 + + - uses: actions/setup-go@cdcb36043654635271a94b9a6d1392de5bb323a7 # v5.0.1 + with: + go-version: 1.19 + + - name: Run unit tests + run: make test diff --git a/.github/workflows/pr-unit-tests-1.20.yaml b/.github/workflows/pr-unit-tests-1.20.yaml new file mode 100644 index 0000000..14fc918 --- /dev/null +++ b/.github/workflows/pr-unit-tests-1.20.yaml @@ -0,0 +1,17 @@ +name: Unit tests (Go 1.20+) +on: + pull_request: + types: ['opened', 'synchronize'] + +jobs: + run-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@a5ac7e51b41094c92402da3b24376905380afc29 # v4.1.6 + + - uses: actions/setup-go@cdcb36043654635271a94b9a6d1392de5bb323a7 # v5.0.1 + with: + go-version: 1.22 + + - name: Run unit tests + run: make test diff --git a/.go-version b/.go-version new file mode 100644 index 0000000..6fee2fe --- /dev/null +++ b/.go-version @@ -0,0 +1 @@ +1.22.2 diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..0c4c7a2 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,27 @@ +## 0.7.6 (May 9, 2024) + +ENHANCEMENTS: + +- client: support a `RetryPrepare` function for modifying the request before retrying (#216) +- client: support HTTP-date values for `Retry-After` header value (#138) +- client: avoid reading entire body when the body is a `*bytes.Reader` (#197) + +BUG FIXES: + +- client: fix a broken check for invalid server certificate in go 1.20+ (#210) + +## 0.7.5 (Nov 8, 2023) + +BUG FIXES: + +- client: fixes an issue where the request body is not preserved on temporary redirects or re-established HTTP/2 connections (#207) + +## 0.7.4 (Jun 6, 2023) + +BUG FIXES: + +- client: fixing an issue where the Content-Type header wouldn't be sent with an empty payload when using HTTP/2 (#194) + +## 0.7.3 (May 15, 2023) + +Initial release diff --git a/CODEOWNERS b/CODEOWNERS new file mode 100644 index 0000000..d6dd78a --- /dev/null +++ b/CODEOWNERS @@ -0,0 +1 @@ +* @hashicorp/go-retryablehttp-maintainers diff --git a/LICENSE b/LICENSE index e87a115..f4f97ee 100644 --- a/LICENSE +++ b/LICENSE @@ -1,3 +1,5 @@ +Copyright (c) 2015 HashiCorp, Inc. + Mozilla Public License, version 2.0 1. Definitions diff --git a/Makefile b/Makefile index da17640..5255241 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,7 @@ default: test test: go vet ./... - go test -race ./... + go test -v -race ./... updatedeps: go get -f -t -u ./... diff --git a/README.md b/README.md index 8943bec..145a62f 100644 --- a/README.md +++ b/README.md @@ -59,4 +59,4 @@ standardClient := retryClient.StandardClient() // *http.Client ``` For more usage and examples see the -[godoc](http://godoc.org/github.com/hashicorp/go-retryablehttp). +[pkg.go.dev](https://pkg.go.dev/github.com/hashicorp/go-retryablehttp). diff --git a/cert_error_go119.go b/cert_error_go119.go new file mode 100644 index 0000000..b2b27e8 --- /dev/null +++ b/cert_error_go119.go @@ -0,0 +1,14 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +//go:build !go1.20 +// +build !go1.20 + +package retryablehttp + +import "crypto/x509" + +func isCertError(err error) bool { + _, ok := err.(x509.UnknownAuthorityError) + return ok +} diff --git a/cert_error_go120.go b/cert_error_go120.go new file mode 100644 index 0000000..a3cd315 --- /dev/null +++ b/cert_error_go120.go @@ -0,0 +1,14 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +//go:build go1.20 +// +build go1.20 + +package retryablehttp + +import "crypto/tls" + +func isCertError(err error) bool { + _, ok := err.(*tls.CertificateVerificationError) + return ok +} diff --git a/client.go b/client.go index 90a738c..efee53c 100644 --- a/client.go +++ b/client.go @@ -1,3 +1,6 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + // Package retryablehttp provides a familiar HTTP client interface with // automatic retries and exponential backoff. It is a thin wrapper over the // standard net/http client library and exposes nearly the same public API. @@ -24,10 +27,8 @@ package retryablehttp import ( "bytes" "context" - "crypto/x509" "fmt" "io" - "io/ioutil" "log" "math" "math/rand" @@ -60,6 +61,10 @@ var ( // limit the size we consume to respReadLimit. respReadLimit = int64(4096) + // timeNow sets the function that returns the current time. + // This defaults to time.Now. Changes to this should only be done in tests. + timeNow = time.Now + // A regular expression to match the error returned by net/http when the // configured number of redirects is exhausted. This error isn't typed // specifically so we resort to matching on the error string. @@ -69,11 +74,33 @@ var ( // scheme specified in the URL is invalid. This error isn't typed // specifically so we resort to matching on the error string. schemeErrorRe = regexp.MustCompile(`unsupported protocol scheme`) + + // A regular expression to match the error returned by net/http when a + // request header or value is invalid. This error isn't typed + // specifically so we resort to matching on the error string. + invalidHeaderErrorRe = regexp.MustCompile(`invalid header`) + + // A regular expression to match the error returned by net/http when the + // TLS certificate is not trusted. This error isn't typed + // specifically so we resort to matching on the error string. + notTrustedErrorRe = regexp.MustCompile(`certificate is not trusted`) ) // ReaderFunc is the type of function that can be given natively to NewRequest type ReaderFunc func() (io.Reader, error) +// ResponseHandlerFunc is a type of function that takes in a Response, and does something with it. +// The ResponseHandlerFunc is called when the HTTP client successfully receives a response and the +// CheckRetry function indicates that a retry of the base request is not necessary. +// If an error is returned from this function, the CheckRetry policy will be used to determine +// whether to retry the whole request (including this handler). +// +// Make sure to check status codes! Even if the request was completed it may have a non-2xx status code. +// +// The response body is not automatically closed. It must be closed either by the ResponseHandlerFunc or +// by the caller out-of-band. Failure to do so will result in a memory leak. +type ResponseHandlerFunc func(*http.Response) error + // LenReader is an interface implemented by many in-memory io.Reader's. Used // for automatically sending the right Content-Length header when possible. type LenReader interface { @@ -86,6 +113,8 @@ type Request struct { // used to rewind the request data in between retries. body ReaderFunc + responseHandler ResponseHandlerFunc + // Embed an HTTP request directly. This makes a *Request act exactly // like an *http.Request so that all meta methods are supported. *http.Request @@ -95,11 +124,17 @@ type Request struct { // with its context changed to ctx. The provided ctx must be non-nil. func (r *Request) WithContext(ctx context.Context) *Request { return &Request{ - body: r.body, - Request: r.Request.WithContext(ctx), + body: r.body, + responseHandler: r.responseHandler, + Request: r.Request.WithContext(ctx), } } +// SetResponseHandler allows setting the response handler. +func (r *Request) SetResponseHandler(fn ResponseHandlerFunc) { + r.responseHandler = fn +} + // BodyBytes allows accessing the request body. It is an analogue to // http.Request's Body variable, but it returns a copy of the underlying data // rather than consuming it. @@ -132,6 +167,20 @@ func (r *Request) SetBody(rawBody interface{}) error { } r.body = bodyReader r.ContentLength = contentLength + if bodyReader != nil { + r.GetBody = func() (io.ReadCloser, error) { + body, err := bodyReader() + if err != nil { + return nil, err + } + if rc, ok := body.(io.ReadCloser); ok { + return rc, nil + } + return io.NopCloser(body), nil + } + } else { + r.GetBody = func() (io.ReadCloser, error) { return http.NoBody, nil } + } return nil } @@ -206,21 +255,19 @@ func getBodyReaderAndContentLength(rawBody interface{}) (ReaderFunc, int64, erro // deal with it seeking so want it to match here instead of the // io.ReadSeeker case. case *bytes.Reader: - buf, err := ioutil.ReadAll(body) - if err != nil { - return nil, 0, err - } + snapshot := *body bodyReader = func() (io.Reader, error) { - return bytes.NewReader(buf), nil + r := snapshot + return &r, nil } - contentLength = int64(len(buf)) + contentLength = int64(body.Len()) // Compat case case io.ReadSeeker: raw := body bodyReader = func() (io.Reader, error) { _, err := raw.Seek(0, 0) - return ioutil.NopCloser(raw), err + return io.NopCloser(raw), err } if lr, ok := raw.(LenReader); ok { contentLength = int64(lr.Len()) @@ -228,14 +275,21 @@ func getBodyReaderAndContentLength(rawBody interface{}) (ReaderFunc, int64, erro // Read all in so we can reset case io.Reader: - buf, err := ioutil.ReadAll(body) + buf, err := io.ReadAll(body) if err != nil { return nil, 0, err } - bodyReader = func() (io.Reader, error) { - return bytes.NewReader(buf), nil + if len(buf) == 0 { + bodyReader = func() (io.Reader, error) { + return http.NoBody, nil + } + contentLength = 0 + } else { + bodyReader = func() (io.Reader, error) { + return bytes.NewReader(buf), nil + } + contentLength = int64(len(buf)) } - contentLength = int64(len(buf)) // No body provided, nothing to do case nil: @@ -254,7 +308,7 @@ func FromRequest(r *http.Request) (*Request, error) { return nil, err } // Could assert contentLength == r.ContentLength - return &Request{bodyReader, r}, nil + return &Request{body: bodyReader, Request: r}, nil } // NewRequest creates a new wrapped request. @@ -267,18 +321,19 @@ func NewRequest(method, url string, rawBody interface{}) (*Request, error) { // The context controls the entire lifetime of a request and its response: // obtaining a connection, sending the request, and reading the response headers and body. func NewRequestWithContext(ctx context.Context, method, url string, rawBody interface{}) (*Request, error) { - bodyReader, contentLength, err := getBodyReaderAndContentLength(rawBody) + httpReq, err := http.NewRequestWithContext(ctx, method, url, nil) if err != nil { return nil, err } - httpReq, err := http.NewRequestWithContext(ctx, method, url, nil) - if err != nil { + req := &Request{ + Request: httpReq, + } + if err := req.SetBody(rawBody); err != nil { return nil, err } - httpReq.ContentLength = contentLength - return &Request{bodyReader, httpReq}, nil + return req, nil } // Logger interface allows to use other loggers than @@ -343,6 +398,9 @@ type Backoff func(min, max time.Duration, attemptNum int, resp *http.Response) t // attempted. If overriding this, be sure to close the body if needed. type ErrorHandler func(resp *http.Response, err error, numTries int) (*http.Response, error) +// PrepareRetry is called before retry operation. It can be used for example to re-sign the request +type PrepareRetry func(req *http.Request) error + // Client is used to make HTTP requests. It adds additional functionality // like automatic retries to tolerate minor outages. type Client struct { @@ -371,6 +429,9 @@ type Client struct { // ErrorHandler specifies the custom error handler to use, if any ErrorHandler ErrorHandler + // PrepareRetry can prepare the request for retry operation, for example re-sign it + PrepareRetry PrepareRetry + loggerInit sync.Once clientInit sync.Once } @@ -444,8 +505,16 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) { return false, v } + // Don't retry if the error was due to an invalid header. + if invalidHeaderErrorRe.MatchString(v.Error()) { + return false, v + } + // Don't retry if the error was due to TLS cert verification failure. - if _, ok := v.Err.(x509.UnknownAuthorityError); ok { + if notTrustedErrorRe.MatchString(v.Error()) { + return false, v + } + if isCertError(v.Err) { return false, v } } @@ -482,10 +551,8 @@ func baseRetryPolicy(resp *http.Response, err error) (bool, error) { func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) time.Duration { if resp != nil { if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable { - if s, ok := resp.Header["Retry-After"]; ok { - if sleep, err := strconv.ParseInt(s[0], 10, 64); err == nil { - return time.Second * time.Duration(sleep) - } + if sleep, ok := parseRetryAfterHeader(resp.Header["Retry-After"]); ok { + return sleep } } } @@ -498,6 +565,41 @@ func DefaultBackoff(min, max time.Duration, attemptNum int, resp *http.Response) return sleep } +// parseRetryAfterHeader parses the Retry-After header and returns the +// delay duration according to the spec: https://httpwg.org/specs/rfc7231.html#header.retry-after +// The bool returned will be true if the header was successfully parsed. +// Otherwise, the header was either not present, or was not parseable according to the spec. +// +// Retry-After headers come in two flavors: Seconds or HTTP-Date +// +// Examples: +// * Retry-After: Fri, 31 Dec 1999 23:59:59 GMT +// * Retry-After: 120 +func parseRetryAfterHeader(headers []string) (time.Duration, bool) { + if len(headers) == 0 || headers[0] == "" { + return 0, false + } + header := headers[0] + // Retry-After: 120 + if sleep, err := strconv.ParseInt(header, 10, 64); err == nil { + if sleep < 0 { // a negative sleep doesn't make sense + return 0, false + } + return time.Second * time.Duration(sleep), true + } + + // Retry-After: Fri, 31 Dec 1999 23:59:59 GMT + retryTime, err := time.Parse(time.RFC1123, header) + if err != nil { + return 0, false + } + if until := retryTime.Sub(timeNow()); until > 0 { + return until, true + } + // date is in the past + return 0, true +} + // LinearJitterBackoff provides a callback for Client.Backoff which will // perform linear backoff based on the attempt number and with jitter to // prevent a thundering herd. @@ -525,13 +627,13 @@ func LinearJitterBackoff(min, max time.Duration, attemptNum int, resp *http.Resp } // Seed rand; doing this every time is fine - rand := rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) + source := rand.New(rand.NewSource(int64(time.Now().Nanosecond()))) // Pick a random number that lies somewhere between the min and max and // multiply by the attemptNum. attemptNum starts at zero so we always // increment here. We first get a random percentage, then apply that to the // difference between min and max, and add to min. - jitter := rand.Float64() * float64(max-min) + jitter := source.Float64() * float64(max-min) jitterMin := int64(jitter) + int64(min) return time.Duration(jitterMin * int64(attemptNum)) } @@ -565,9 +667,10 @@ func (c *Client) Do(req *Request) (*http.Response, error) { var resp *http.Response var attempt int var shouldRetry bool - var doErr, checkErr error + var doErr, respErr, checkErr, prepareErr error for i := 0; ; i++ { + doErr, respErr, prepareErr = nil, nil, nil attempt++ // Always rewind the request body when non-nil. @@ -580,7 +683,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) { if c, ok := body.(io.ReadCloser); ok { req.Body = c } else { - req.Body = ioutil.NopCloser(body) + req.Body = io.NopCloser(body) } } @@ -600,13 +703,21 @@ func (c *Client) Do(req *Request) (*http.Response, error) { // Check if we should continue with retries. shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, doErr) + if !shouldRetry && doErr == nil && req.responseHandler != nil { + respErr = req.responseHandler(resp) + shouldRetry, checkErr = c.CheckRetry(req.Context(), resp, respErr) + } - if doErr != nil { + err := doErr + if respErr != nil { + err = respErr + } + if err != nil { switch v := logger.(type) { case LeveledLogger: - v.Error("request failed", "error", doErr, "method", req.Method, "url", redactURL(req.URL)) + v.Error("request failed", "error", err, "method", req.Method, "url", redactURL(req.URL)) case Logger: - v.Printf("[ERR] %s %s request failed: %v", req.Method, redactURL(req.URL), doErr) + v.Printf("[ERR] %s %s request failed: %v", req.Method, redactURL(req.URL), err) } } else { // Call this here to maintain the behavior of logging all requests, @@ -666,18 +777,31 @@ func (c *Client) Do(req *Request) (*http.Response, error) { // without racing against the closeBody call in persistConn.writeLoop. httpreq := *req.Request req.Request = &httpreq + + if c.PrepareRetry != nil { + if err := c.PrepareRetry(req.Request); err != nil { + prepareErr = err + break + } + } } // this is the closest we have to success criteria - if doErr == nil && checkErr == nil && !shouldRetry { + if doErr == nil && respErr == nil && checkErr == nil && prepareErr == nil && !shouldRetry { return resp, nil } defer c.HTTPClient.CloseIdleConnections() - err := doErr - if checkErr != nil { + var err error + if prepareErr != nil { + err = prepareErr + } else if checkErr != nil { err = checkErr + } else if respErr != nil { + err = respErr + } else { + err = doErr } if c.ErrorHandler != nil { @@ -704,7 +828,7 @@ func (c *Client) Do(req *Request) (*http.Response, error) { // Try to read the response body so we can reuse this connection. func (c *Client) drainBody(body io.ReadCloser) { defer body.Close() - _, err := io.Copy(ioutil.Discard, io.LimitReader(body, respReadLimit)) + _, err := io.Copy(io.Discard, io.LimitReader(body, respReadLimit)) if err != nil { if c.logger() != nil { switch v := c.logger().(type) { diff --git a/client_test.go b/client_test.go index 9438648..cc05e91 100644 --- a/client_test.go +++ b/client_test.go @@ -1,3 +1,6 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + package retryablehttp import ( @@ -6,12 +9,12 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/http" "net/http/httptest" "net/http/httputil" "net/url" + "strconv" "strings" "sync/atomic" "testing" @@ -167,13 +170,13 @@ func testClientDo(t *testing.T, body interface{}) { // Send the request var resp *http.Response doneCh := make(chan struct{}) + errCh := make(chan error, 1) go func() { defer close(doneCh) + defer close(errCh) var err error resp, err = client.Do(req) - if err != nil { - t.Fatalf("err: %v", err) - } + errCh <- err }() select { @@ -202,7 +205,7 @@ func testClientDo(t *testing.T, body interface{}) { } // Check the payload - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("err: %s", err) } @@ -247,6 +250,228 @@ func testClientDo(t *testing.T, body interface{}) { if retryCount < 0 { t.Fatal("request log hook was not called") } + + err = <-errCh + if err != nil { + t.Fatalf("err: %v", err) + } +} + +func TestClient_Do_WithResponseHandler(t *testing.T) { + // Create the client. Use short retry windows so we fail faster. + client := NewClient() + client.RetryWaitMin = 10 * time.Millisecond + client.RetryWaitMax = 10 * time.Millisecond + client.RetryMax = 2 + + var checks int + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + checks++ + if err != nil && strings.Contains(err.Error(), "nonretryable") { + return false, nil + } + return DefaultRetryPolicy(context.TODO(), resp, err) + } + + // Mock server which always responds 200. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + var shouldSucceed bool + tests := []struct { + name string + handler ResponseHandlerFunc + expectedChecks int // often 2x number of attempts since we check twice + err string + }{ + { + name: "nil handler", + handler: nil, + expectedChecks: 1, + }, + { + name: "handler always succeeds", + handler: func(*http.Response) error { + return nil + }, + expectedChecks: 2, + }, + { + name: "handler always fails in a retryable way", + handler: func(*http.Response) error { + return errors.New("retryable failure") + }, + expectedChecks: 6, + }, + { + name: "handler always fails in a nonretryable way", + handler: func(*http.Response) error { + return errors.New("nonretryable failure") + }, + expectedChecks: 2, + }, + { + name: "handler succeeds on second attempt", + handler: func(*http.Response) error { + if shouldSucceed { + return nil + } + shouldSucceed = true + return errors.New("retryable failure") + }, + expectedChecks: 4, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + checks = 0 + shouldSucceed = false + // Create the request + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + req.SetResponseHandler(tt.handler) + + // Send the request. + _, err = client.Do(req) + if err != nil && !strings.Contains(err.Error(), tt.err) { + t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error()) + } + if err == nil && tt.err != "" { + t.Fatalf("no error, expected: %s", tt.err) + } + + if checks != tt.expectedChecks { + t.Fatalf("expected %d attempts, got %d attempts", tt.expectedChecks, checks) + } + }) + } +} + +func TestClient_Do_WithPrepareRetry(t *testing.T) { + // Create the client. Use short retry windows so we fail faster. + client := NewClient() + client.RetryWaitMin = 10 * time.Millisecond + client.RetryWaitMax = 10 * time.Millisecond + client.RetryMax = 2 + + var checks int + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + checks++ + if err != nil && strings.Contains(err.Error(), "nonretryable") { + return false, nil + } + return DefaultRetryPolicy(context.TODO(), resp, err) + } + + var prepareChecks int + client.PrepareRetry = func(req *http.Request) error { + prepareChecks++ + req.Header.Set("foo", strconv.Itoa(prepareChecks)) + return nil + } + + // Mock server which always responds 200. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + var shouldSucceed bool + tests := []struct { + name string + handler ResponseHandlerFunc + expectedChecks int // often 2x number of attempts since we check twice + expectedPrepareChecks int + err string + }{ + { + name: "nil handler", + handler: nil, + expectedChecks: 1, + expectedPrepareChecks: 0, + }, + { + name: "handler always succeeds", + handler: func(*http.Response) error { + return nil + }, + expectedChecks: 2, + expectedPrepareChecks: 0, + }, + { + name: "handler always fails in a retryable way", + handler: func(*http.Response) error { + return errors.New("retryable failure") + }, + expectedChecks: 6, + expectedPrepareChecks: 2, + }, + { + name: "handler always fails in a nonretryable way", + handler: func(*http.Response) error { + return errors.New("nonretryable failure") + }, + expectedChecks: 2, + expectedPrepareChecks: 0, + }, + { + name: "handler succeeds on second attempt", + handler: func(*http.Response) error { + if shouldSucceed { + return nil + } + shouldSucceed = true + return errors.New("retryable failure") + }, + expectedChecks: 4, + expectedPrepareChecks: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + checks = 0 + prepareChecks = 0 + shouldSucceed = false + // Create the request + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + req.SetResponseHandler(tt.handler) + + // Send the request. + _, err = client.Do(req) + if err != nil && !strings.Contains(err.Error(), tt.err) { + t.Fatalf("error does not match expectation, expected: %s, got: %s", tt.err, err.Error()) + } + if err == nil && tt.err != "" { + t.Fatalf("no error, expected: %s", tt.err) + } + + if checks != tt.expectedChecks { + t.Fatalf("expected %d attempts, got %d attempts", tt.expectedChecks, checks) + } + + if prepareChecks != tt.expectedPrepareChecks { + t.Fatalf("expected %d attempts of prepare check, got %d attempts", tt.expectedPrepareChecks, prepareChecks) + } + header := req.Request.Header.Get("foo") + if tt.expectedPrepareChecks == 0 && header != "" { + t.Fatalf("expected no changes to request header 'foo', but got '%s'", header) + } + expectedHeader := strconv.Itoa(tt.expectedPrepareChecks) + if tt.expectedPrepareChecks != 0 && header != expectedHeader { + t.Fatalf("expected changes in request header 'foo' '%s', but got '%s'", expectedHeader, header) + } + + }) + } } func TestClient_Do_fails(t *testing.T) { @@ -339,6 +564,12 @@ func TestClient_RequestLogHook(t *testing.T) { t.Run("RequestLogHook successfully called with nil Logger", func(t *testing.T) { testClientRequestLogHook(t, nil) }) + t.Run("RequestLogHook successfully called with nil typed Logger", func(t *testing.T) { + testClientRequestLogHook(t, Logger(nil)) + }) + t.Run("RequestLogHook successfully called with nil typed LeveledLogger", func(t *testing.T) { + testClientRequestLogHook(t, LeveledLogger(nil)) + }) } func testClientRequestLogHook(t *testing.T, logger interface{}) { @@ -400,6 +631,14 @@ func TestClient_ResponseLogHook(t *testing.T) { buf := new(bytes.Buffer) testClientResponseLogHook(t, nil, buf) }) + t.Run("ResponseLogHook successfully called with nil typed Logger", func(t *testing.T) { + buf := new(bytes.Buffer) + testClientResponseLogHook(t, Logger(nil), buf) + }) + t.Run("ResponseLogHook successfully called with nil typed LeveledLogger", func(t *testing.T) { + buf := new(bytes.Buffer) + testClientResponseLogHook(t, LeveledLogger(nil), buf) + }) } func testClientResponseLogHook(t *testing.T, l interface{}, buf *bytes.Buffer) { @@ -432,7 +671,7 @@ func testClientResponseLogHook(t *testing.T, l interface{}, buf *bytes.Buffer) { } } else { // Log the response body when we get a 500 - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("err: %v", err) } @@ -453,7 +692,7 @@ func testClientResponseLogHook(t *testing.T, l interface{}, buf *bytes.Buffer) { // Make sure we can read the response body still, since we did not // read or close it from the response log hook. - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { t.Fatalf("err: %v", err) } @@ -553,12 +792,67 @@ func TestClient_CheckRetry(t *testing.T) { } } +func testStaticTime(t *testing.T) { + timeNow = func() time.Time { + now, err := time.Parse(time.RFC1123, "Fri, 31 Dec 1999 23:59:57 GMT") + if err != nil { + panic(err) + } + return now + } + t.Cleanup(func() { + timeNow = time.Now + }) +} + +func TestParseRetryAfterHeader(t *testing.T) { + testStaticTime(t) + tests := []struct { + name string + headers []string + sleep time.Duration + ok bool + }{ + {"seconds", []string{"2"}, time.Second * 2, true}, + {"date", []string{"Fri, 31 Dec 1999 23:59:59 GMT"}, time.Second * 2, true}, + {"past-date", []string{"Fri, 31 Dec 1999 23:59:00 GMT"}, 0, true}, + {"nil", nil, 0, false}, + {"two-headers", []string{"2", "3"}, time.Second * 2, true}, + {"empty", []string{""}, 0, false}, + {"negative", []string{"-2"}, 0, false}, + {"bad-date", []string{"Fri, 32 Dec 1999 23:59:59 GMT"}, 0, false}, + {"bad-date-format", []string{"badbadbad"}, 0, false}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + sleep, ok := parseRetryAfterHeader(test.headers) + if ok != test.ok { + t.Fatalf("expected ok=%t, got ok=%t", test.ok, ok) + } + if sleep != test.sleep { + t.Fatalf("expected sleep=%v, got sleep=%v", test.sleep, sleep) + } + }) + } +} + func TestClient_DefaultBackoff(t *testing.T) { - for _, code := range []int{http.StatusTooManyRequests, http.StatusServiceUnavailable} { - t.Run(fmt.Sprintf("http_%d", code), func(t *testing.T) { + testStaticTime(t) + tests := []struct { + name string + code int + retryHeader string + }{ + {"http_429_seconds", http.StatusTooManyRequests, "2"}, + {"http_429_date", http.StatusTooManyRequests, "Fri, 31 Dec 1999 23:59:59 GMT"}, + {"http_503_seconds", http.StatusServiceUnavailable, "2"}, + {"http_503_date", http.StatusServiceUnavailable, "Fri, 31 Dec 1999 23:59:59 GMT"}, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Retry-After", "2") - http.Error(w, fmt.Sprintf("test_%d_body", code), code) + w.Header().Set("Retry-After", test.retryHeader) + http.Error(w, fmt.Sprintf("test_%d_body", test.code), test.code) })) defer ts.Close() @@ -613,7 +907,7 @@ func TestClient_DefaultRetryPolicy_TLS(t *testing.T) { func TestClient_DefaultRetryPolicy_redirects(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "/", 302) + http.Redirect(w, r, "/", http.StatusFound) })) defer ts.Close() @@ -656,6 +950,60 @@ func TestClient_DefaultRetryPolicy_invalidscheme(t *testing.T) { } } +func TestClient_DefaultRetryPolicy_invalidheadername(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + attempts := 0 + client := NewClient() + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + attempts++ + return DefaultRetryPolicy(context.TODO(), resp, err) + } + + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + req.Header.Set("Header-Name-\033", "header value") + _, err = client.StandardClient().Do(req) + if err == nil { + t.Fatalf("expected header error, got nil") + } + if attempts != 1 { + t.Fatalf("expected 1 attempt, got %d", attempts) + } +} + +func TestClient_DefaultRetryPolicy_invalidheadervalue(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(200) + })) + defer ts.Close() + + attempts := 0 + client := NewClient() + client.CheckRetry = func(_ context.Context, resp *http.Response, err error) (bool, error) { + attempts++ + return DefaultRetryPolicy(context.TODO(), resp, err) + } + + req, err := http.NewRequest(http.MethodGet, ts.URL, nil) + if err != nil { + t.Fatalf("err: %v", err) + } + req.Header.Set("Header-Name", "bad header value \033") + _, err = client.StandardClient().Do(req) + if err == nil { + t.Fatalf("expected header value error, got nil") + } + if attempts != 1 { + t.Fatalf("expected 1 attempt, got %d", attempts) + } +} + func TestClient_CheckRetryStop(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { http.Error(w, "test_500_body", http.StatusInternalServerError) @@ -717,7 +1065,7 @@ func TestClient_Post(t *testing.T) { } // Check the payload - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("err: %s", err) } @@ -755,7 +1103,7 @@ func TestClient_PostForm(t *testing.T) { } // Check the payload - body, err := ioutil.ReadAll(r.Body) + body, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("err: %s", err) } @@ -876,3 +1224,62 @@ func TestClient_StandardClient(t *testing.T) { t.Fatalf("expected %v, got %v", client, v) } } + +func TestClient_RedirectWithBody(t *testing.T) { + var redirects int32 + // Mock server which always responds 200. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.RequestURI { + case "/redirect": + w.Header().Set("Location", "/target") + w.WriteHeader(http.StatusTemporaryRedirect) + case "/target": + atomic.AddInt32(&redirects, 1) + w.WriteHeader(http.StatusCreated) + default: + t.Fatalf("bad uri: %s", r.RequestURI) + } + })) + defer ts.Close() + + client := NewClient() + client.RequestLogHook = func(logger Logger, req *http.Request, retryNumber int) { + if _, err := req.GetBody(); err != nil { + t.Fatalf("unexpected error with GetBody: %v", err) + } + } + // create a request with a body + req, err := NewRequest(http.MethodPost, ts.URL+"/redirect", strings.NewReader(`{"foo":"bar"}`)) + if err != nil { + t.Fatalf("err: %v", err) + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + t.Fatalf("expected status code 201, got: %d", resp.StatusCode) + } + + // now one without a body + if err := req.SetBody(nil); err != nil { + t.Fatalf("err: %v", err) + } + + resp, err = client.Do(req) + if err != nil { + t.Fatalf("err: %v", err) + } + resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + t.Fatalf("expected status code 201, got: %d", resp.StatusCode) + } + + if atomic.LoadInt32(&redirects) != 2 { + t.Fatalf("Expected the client to be redirected 2 times, got: %d", atomic.LoadInt32(&redirects)) + } +} diff --git a/go.mod b/go.mod index 7cc02b7..12c7872 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,15 @@ module github.com/hashicorp/go-retryablehttp require ( - github.com/hashicorp/go-cleanhttp v0.5.1 - github.com/hashicorp/go-hclog v0.9.2 + github.com/hashicorp/go-cleanhttp v0.5.2 + github.com/hashicorp/go-hclog v1.6.3 ) -go 1.13 +require ( + github.com/fatih/color v1.16.0 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + golang.org/x/sys v0.20.0 // indirect +) + +go 1.19 diff --git a/go.sum b/go.sum index 71afe56..a5da2ce 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,36 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/hashicorp/go-cleanhttp v0.5.1 h1:dH3aiDG9Jvb5r5+bYHsikaOUIpcM0xvgMXVoDkXMzJM= -github.com/hashicorp/go-cleanhttp v0.5.1/go.mod h1:JpRdi6/HCYpAwUzNwuwqhbovhLtngrth3wmdIIUrZ80= -github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI= -github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= +github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk= +github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= +github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= +github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= +github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= +github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= +github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= +github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/testify v1.2.2 h1:bSDNvY7ZPG5RlJ8otE/7V6gMiyenm9RtJ7IUVIAoJ1w= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8s= +github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220503163025-988cb79eb6c6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/roundtripper.go b/roundtripper.go index 8f3ee35..8c407ad 100644 --- a/roundtripper.go +++ b/roundtripper.go @@ -1,3 +1,6 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + package retryablehttp import ( diff --git a/roundtripper_test.go b/roundtripper_test.go index 93ff0c5..975d8b1 100644 --- a/roundtripper_test.go +++ b/roundtripper_test.go @@ -1,9 +1,12 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + package retryablehttp import ( "context" "errors" - "io/ioutil" + "io" "net" "net/http" "net/http/httptest" @@ -85,7 +88,7 @@ func TestRoundTripper_RoundTrip(t *testing.T) { if resp.StatusCode != 200 { t.Fatalf("expected 200, got %d", resp.StatusCode) } - if v, err := ioutil.ReadAll(resp.Body); err != nil { + if v, err := io.ReadAll(resp.Body); err != nil { t.Fatal(err) } else if string(v) != "success!" { t.Fatalf("expected %q, got %q", "success!", v) @@ -107,12 +110,12 @@ func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) { expectedError := &url.Error{ Op: "Get", - URL: "http://this-url-does-not-exist-ed2fb.com/", + URL: "http://999.999.999.999:999/", Err: &net.OpError{ Op: "dial", Net: "tcp", Err: &net.DNSError{ - Name: "this-url-does-not-exist-ed2fb.com", + Name: "999.999.999.999", Err: "no such host", IsNotFound: true, }, @@ -121,10 +124,10 @@ func TestRoundTripper_TransportFailureErrorHandling(t *testing.T) { // Get the standard client and execute the request. client := retryClient.StandardClient() - _, err := client.Get("http://this-url-does-not-exist-ed2fb.com/") + _, err := client.Get("http://999.999.999.999:999/") // assert expectations - if !reflect.DeepEqual(normalizeError(err), expectedError) { + if !reflect.DeepEqual(expectedError, normalizeError(err)) { t.Fatalf("expected %q, got %q", expectedError, err) } }