Skip to content

Commit

Permalink
Set request body before every retry
Browse files Browse the repository at this point in the history
This patch changes the interfaces of `NewRequest` and `Do` around a
little so that we can set a new request body with every request.

In the era of HTTP (1), it was safe to reuse a `Request` object, but
with the addition of HTTP/2, it's now only sometimes safe. Reusing a
`Request` with a body will break.

See some more information here:

golang/go#19653 (comment)

Fixes #642.
  • Loading branch information
brandur committed Aug 4, 2018
1 parent a3c54f6 commit 591e6a1
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 22 deletions.
59 changes: 48 additions & 11 deletions stripe.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,12 @@ func (s *BackendConfiguration) Call(method, path, key string, params ParamsConta
func (s *BackendConfiguration) CallMultipart(method, path, key, boundary string, body io.Reader, params *Params, v interface{}) error {
contentType := "multipart/form-data; boundary=" + boundary

req, err := s.NewRequest(method, path, key, contentType, body, params)
req, err := s.NewRequest(method, path, key, contentType, params)
if err != nil {
return err
}

if err := s.Do(req, v); err != nil {
if err := s.Do(req, body, v); err != nil {
return err
}

Expand All @@ -198,22 +198,24 @@ func (s *BackendConfiguration) CallMultipart(method, path, key, boundary string,

// CallRaw is the implementation for invoking Stripe APIs internally without a backend.
func (s *BackendConfiguration) CallRaw(method, path, key string, form *form.Values, params *Params, v interface{}) error {
var body io.Reader
var data string
if form != nil && !form.Empty() {
data := form.Encode()
data = form.Encode()

// On `GET`, move the payload into the URL
if method == http.MethodGet {
path += "?" + data
} else {
body = bytes.NewBufferString(data)
data = ""
}
}
dataBuffer := bytes.NewBufferString(data)

req, err := s.NewRequest(method, path, key, "application/x-www-form-urlencoded", body, params)
req, err := s.NewRequest(method, path, key, "application/x-www-form-urlencoded", params)
if err != nil {
return err
}

if err := s.Do(req, v); err != nil {
if err := s.Do(req, dataBuffer, v); err != nil {
return err
}

Expand All @@ -222,14 +224,15 @@ func (s *BackendConfiguration) CallRaw(method, path, key string, form *form.Valu

// NewRequest is used by Call to generate an http.Request. It handles encoding
// parameters and attaching the appropriate headers.
func (s *BackendConfiguration) NewRequest(method, path, key, contentType string, body io.Reader, params *Params) (*http.Request, error) {
func (s *BackendConfiguration) NewRequest(method, path, key, contentType string, params *Params) (*http.Request, error) {
if !strings.HasPrefix(path, "/") {
path = "/" + path
}

path = s.URL + path

req, err := http.NewRequest(method, path, body)
// Body is set later by `Do`.
req, err := http.NewRequest(method, path, nil)
if err != nil {
if s.LogLevel > 0 {
s.Logger.Printf("Cannot create Stripe request: %v\n", err)
Expand Down Expand Up @@ -278,7 +281,7 @@ func (s *BackendConfiguration) NewRequest(method, path, key, contentType string,
// Do is used by Call to execute an API request and parse the response. It uses
// the backend's HTTP client to execute the request and unmarshals the response
// into v. It also handles unmarshaling errors returned by the API.
func (s *BackendConfiguration) Do(req *http.Request, v interface{}) error {
func (s *BackendConfiguration) Do(req *http.Request, body io.Reader, v interface{}) error {
if s.LogLevel > 1 {
s.Logger.Printf("Requesting %v %v%v\n", req.Method, req.URL.Host, req.URL.Path)
}
Expand All @@ -288,6 +291,30 @@ func (s *BackendConfiguration) Do(req *http.Request, v interface{}) error {
for retry := 0; ; {
start := time.Now()

// This might look a little strange, but we set the request's body
// outside of `NewRequest` so that we can get a fresh version every
// time.
//
// The background is that back in the era of old style HTTP, it was
// safe to reuse `Request` objects, but with the addition of HTTP/2,
// it's now only sometimes safe. Reusing a `Request` with a body will
// break.
//
// See some details here:
//
// https://github.com/golang/go/issues/19653#issuecomment-341539160
//
// And our original bug report here:
//
// https://github.com/stripe/stripe-go/issues/642
//
// To workaround the problem, we put a fresh `Body` onto the `Request`
// every time we execute it, and this seems to empirically resolve the
// problem.
if body != nil {
req.Body = nopReadCloser{body}
}

res, err = s.HTTPClient.Do(req)

if s.LogLevel > 2 {
Expand Down Expand Up @@ -728,6 +755,16 @@ const uploadsURL = "https://uploads.stripe.com"
// Private types
//

// nopReadCloser's sole purpose is to give us a way to turn an `io.Reader` into
// an `io.ReadCloser` by adding a no-op implementation of the `Closer`
// interface. We need this because `http.Request`'s `Body` takes an
// `io.ReadCloser` instead of a `io.Reader`.
type nopReadCloser struct {
io.Reader
}

func (nopReadCloser) Close() error { return nil }

// stripeClientUserAgent contains information about the current runtime which
// is serialized and sent in the `X-Stripe-Client-User-Agent` as additional
// debugging information.
Expand Down
23 changes: 12 additions & 11 deletions stripe_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package stripe_test
import (
"context"
"encoding/json"
"io"
"net/http"
"regexp"
"runtime"
Expand All @@ -18,7 +19,7 @@ func TestBearerAuth(t *testing.T) {
c := stripe.GetBackend(stripe.APIBackend).(*stripe.BackendConfiguration)
key := "apiKey"

req, err := c.NewRequest("", "", key, "", nil, nil)
req, err := c.NewRequest("", "", key, "", nil)
assert.NoError(t, err)

assert.Equal(t, "Bearer "+key, req.Header.Get("Authorization"))
Expand All @@ -28,7 +29,7 @@ func TestContext(t *testing.T) {
c := stripe.GetBackend(stripe.APIBackend).(*stripe.BackendConfiguration)
p := &stripe.Params{Context: context.Background()}

req, err := c.NewRequest("", "", "", "", nil, p)
req, err := c.NewRequest("", "", "", "", p)
assert.NoError(t, err)

assert.Equal(t, p.Context, req.Context())
Expand All @@ -39,7 +40,7 @@ func TestContext_Cancel(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
p := &stripe.Params{Context: ctx}

req, err := c.NewRequest("", "", "", "", nil, p)
req, err := c.NewRequest("", "", "", "", p)
assert.NoError(t, err)

assert.Equal(t, ctx, req.Context())
Expand All @@ -50,7 +51,7 @@ func TestContext_Cancel(t *testing.T) {
cancel()

var v interface{}
err = c.Do(req, &v)
err = c.Do(req, func() io.Reader { return nil }, &v)

// Go 1.7 will produce an error message like:
//
Expand Down Expand Up @@ -107,7 +108,7 @@ func TestMultipleAPICalls(t *testing.T) {
c := stripe.GetBackend(stripe.APIBackend).(*stripe.BackendConfiguration)
key := "apiKey"

req, err := c.NewRequest("", "", key, "", nil, nil)
req, err := c.NewRequest("", "", key, "", nil)
assert.NoError(t, err)

assert.Equal(t, "Bearer "+key, req.Header.Get("Authorization"))
Expand All @@ -120,7 +121,7 @@ func TestIdempotencyKey(t *testing.T) {
c := stripe.GetBackend(stripe.APIBackend).(*stripe.BackendConfiguration)
p := &stripe.Params{IdempotencyKey: stripe.String("idempotency-key")}

req, err := c.NewRequest("", "", "", "", nil, p)
req, err := c.NewRequest("", "", "", "", p)
assert.NoError(t, err)

assert.Equal(t, "idempotency-key", req.Header.Get("Idempotency-Key"))
Expand All @@ -138,7 +139,7 @@ func TestStripeAccount(t *testing.T) {
p := &stripe.Params{}
p.SetStripeAccount(TestMerchantID)

req, err := c.NewRequest("", "", "", "", nil, p)
req, err := c.NewRequest("", "", "", "", p)
assert.NoError(t, err)

assert.Equal(t, TestMerchantID, req.Header.Get("Stripe-Account"))
Expand All @@ -147,7 +148,7 @@ func TestStripeAccount(t *testing.T) {
func TestUserAgent(t *testing.T) {
c := stripe.GetBackend(stripe.APIBackend).(*stripe.BackendConfiguration)

req, err := c.NewRequest("", "", "", "", nil, nil)
req, err := c.NewRequest("", "", "", "", nil)
assert.NoError(t, err)

// We keep out version constant private to the package, so use a regexp
Expand All @@ -170,7 +171,7 @@ func TestUserAgentWithAppInfo(t *testing.T) {

c := stripe.GetBackend(stripe.APIBackend).(*stripe.BackendConfiguration)

req, err := c.NewRequest("", "", "", "", nil, nil)
req, err := c.NewRequest("", "", "", "", nil)
assert.NoError(t, err)

//
Expand Down Expand Up @@ -206,7 +207,7 @@ func TestUserAgentWithAppInfo(t *testing.T) {
func TestStripeClientUserAgent(t *testing.T) {
c := stripe.GetBackend(stripe.APIBackend).(*stripe.BackendConfiguration)

req, err := c.NewRequest("", "", "", "", nil, nil)
req, err := c.NewRequest("", "", "", "", nil)
assert.NoError(t, err)

encodedUserAgent := req.Header.Get("X-Stripe-Client-User-Agent")
Expand Down Expand Up @@ -240,7 +241,7 @@ func TestStripeClientUserAgentWithAppInfo(t *testing.T) {

c := stripe.GetBackend(stripe.APIBackend).(*stripe.BackendConfiguration)

req, err := c.NewRequest("", "", "", "", nil, nil)
req, err := c.NewRequest("", "", "", "", nil)
assert.NoError(t, err)

encodedUserAgent := req.Header.Get("X-Stripe-Client-User-Agent")
Expand Down

0 comments on commit 591e6a1

Please sign in to comment.