diff --git a/README.md b/README.md index 86531f5..5afd5cc 100644 --- a/README.md +++ b/README.md @@ -37,7 +37,7 @@ err := hx.Get(ctx, "https://api.example.com/contents/1", ```go func init() { - defaultTransport := hx.CloneTransport(http.DefaultTransport.(*http.Transport)) + defaultTransport := hxutil.CloneTransport(http.DefaultTransport.(*http.Transport)) // Tweak keep-alive configuration defaultTransport.MaxIdleConns = 500 diff --git a/client_test.go b/client_test.go index c1f5866..79d1476 100644 --- a/client_test.go +++ b/client_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/izumin5210/hx" + "github.com/izumin5210/hx/hxutil" ) func TestClient(t *testing.T) { @@ -321,8 +322,8 @@ func TestClient(t *testing.T) { }) t.Run("with Transport", func(t *testing.T) { - transport := &fakeTransport{ - RoundTripFunc: func(rt http.RoundTripper, req *http.Request) (*http.Response, error) { + transport := &hxutil.RoundTripperWrapper{ + Func: func(req *http.Request, rt http.RoundTripper) (*http.Response, error) { req.SetBasicAuth("foo", "bar") return rt.RoundTrip(req) }, @@ -339,9 +340,8 @@ func TestClient(t *testing.T) { t.Run("with TransportFrom", func(t *testing.T) { err := hx.Get(context.Background(), ts.URL+"/basic_auth", hx.TransportFrom(func(base http.RoundTripper) http.RoundTripper { - return &fakeTransport{ - Base: base, - RoundTripFunc: func(rt http.RoundTripper, req *http.Request) (*http.Response, error) { + return &hxutil.RoundTripperWrapper{ + Func: func(req *http.Request, rt http.RoundTripper) (*http.Response, error) { req.SetBasicAuth("foo", "bar") return rt.RoundTrip(req) }, @@ -353,6 +353,19 @@ func TestClient(t *testing.T) { t.Errorf("returned %v, want nil", err) } }) + + t.Run("with TransportFunc", func(t *testing.T) { + err := hx.Get(context.Background(), ts.URL+"/basic_auth", + hx.TransportFunc(func(r *http.Request, next http.RoundTripper) (*http.Response, error) { + r.SetBasicAuth("foo", "bar") + return next.RoundTrip(r) + }), + hx.WhenFailure(hx.AsError()), + ) + if err != nil { + t.Errorf("returned %v, want nil", err) + } + }) } type fakeError struct { @@ -360,16 +373,3 @@ type fakeError struct { } func (e fakeError) Error() string { return e.Message } - -type fakeTransport struct { - Base http.RoundTripper - RoundTripFunc func(http.RoundTripper, *http.Request) (*http.Response, error) -} - -func (t *fakeTransport) RoundTrip(req *http.Request) (*http.Response, error) { - base := t.Base - if base == nil { - base = http.DefaultTransport - } - return t.RoundTripFunc(base, req) -} diff --git a/helper.go b/helper.go index 547e48c..db13b82 100644 --- a/helper.go +++ b/helper.go @@ -1,49 +1,12 @@ package hx import ( - "bytes" "fmt" - "io/ioutil" - "net/http" "net/url" "path" - "reflect" "strings" ) -func DrainResponseBody(r *http.Response) error { - var buf bytes.Buffer - _, err := buf.ReadFrom(r.Body) - if err != nil { - return err - } - err = r.Body.Close() - if err != nil { - return err - } - r.Body = ioutil.NopCloser(&buf) - return nil -} - -// CloneTransport creates a new *http.Transport object that has copied attributes from a given one. -func CloneTransport(in *http.Transport) *http.Transport { - out := new(http.Transport) - outRv := reflect.ValueOf(out).Elem() - - rv := reflect.ValueOf(in).Elem() - rt := rv.Type() - - n := rt.NumField() - for i := 0; i < n; i++ { - src, dst := rv.Field(i), outRv.Field(i) - if src.Type().AssignableTo(dst.Type()) && dst.CanSet() { - dst.Set(src) - } - } - - return out -} - func Path(elem ...interface{}) string { chunks := make([]string, len(elem)) for i, e := range elem { diff --git a/helper_test.go b/helper_test.go index f2ad3ba..deffb13 100644 --- a/helper_test.go +++ b/helper_test.go @@ -1,70 +1,11 @@ package hx_test import ( - "net" - "net/http" "testing" - "time" "github.com/izumin5210/hx" ) -func TestCloneTransport(t *testing.T) { - // https://github.com/golang/go/blob/go1.13.4/src/net/http/transport.go#L42-L54 - base := &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: (&net.Dialer{ - Timeout: 30 * time.Second, - KeepAlive: 30 * time.Second, - DualStack: true, - }).DialContext, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - - cloned := hx.CloneTransport(base) - cloned.MaxIdleConns = 500 - cloned.MaxIdleConnsPerHost = 100 - - if cloned.Proxy == nil { - t.Errorf("Proxy should be copied") - } - - if cloned.DialContext == nil { - t.Errorf("DialContext should be copied") - } - - if got, want := cloned.IdleConnTimeout, base.IdleConnTimeout; got != want { - t.Errorf("cloned IdleConnTimeout is %s, want %s", got, want) - } - - if got, want := cloned.TLSHandshakeTimeout, base.TLSHandshakeTimeout; got != want { - t.Errorf("cloned TLSHandshakeTimeout is %s, want %s", got, want) - } - - if got, want := cloned.ExpectContinueTimeout, base.ExpectContinueTimeout; got != want { - t.Errorf("cloned ExpectContinueTimeout is %s, want %s", got, want) - } - - if got, want := base.MaxIdleConns, 100; got != want { - t.Errorf("base MaxIdleConns is %d, want %d", got, want) - } - - if got, want := cloned.MaxIdleConns, 500; got != want { - t.Errorf("cloned MaxIdleConns is %d, want %d", got, want) - } - - if got, want := base.MaxIdleConnsPerHost, 0; got != want { - t.Errorf("base MaxIdleConnsPerHost is %d, want %d", got, want) - } - - if got, want := cloned.MaxIdleConnsPerHost, 100; got != want { - t.Errorf("cloned MaxIdleConnsPerHost is %d, want %d", got, want) - } -} - func TestPath(t *testing.T) { cases := []struct { test string diff --git a/hxutil/drain.go b/hxutil/drain.go new file mode 100644 index 0000000..7d876a7 --- /dev/null +++ b/hxutil/drain.go @@ -0,0 +1,21 @@ +package hxutil + +import ( + "bytes" + "io/ioutil" + "net/http" +) + +func DrainResponseBody(r *http.Response) error { + var buf bytes.Buffer + _, err := buf.ReadFrom(r.Body) + if err != nil { + return err + } + err = r.Body.Close() + if err != nil { + return err + } + r.Body = ioutil.NopCloser(&buf) + return nil +} diff --git a/hxutil/drain_test.go b/hxutil/drain_test.go new file mode 100644 index 0000000..d31c709 --- /dev/null +++ b/hxutil/drain_test.go @@ -0,0 +1,53 @@ +package hxutil + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" +) + +func TestDrainResponseBody(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodGet && r.URL.Path == "/ping": + w.Write([]byte("pong")) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + + t.Run("success", func(t *testing.T) { + resp, err := http.Get(ts.URL + "/ping") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + err = DrainResponseBody(resp) + if err != nil { + t.Errorf("returned %v, want nil", err) + } + + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Errorf("returned %v, want nil", err) + } else if got, want := string(data), "pong"; got != want { + t.Errorf("returned %q, want %q", got, want) + } + }) + + t.Run("failure", func(t *testing.T) { + resp, err := http.Get(ts.URL + "/ping") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + resp.Body.Close() + + err = DrainResponseBody(resp) + if err == nil { + t.Errorf("returned nil, want an error") + } + }) +} diff --git a/hxutil/round_tripper.go b/hxutil/round_tripper.go new file mode 100644 index 0000000..04b855c --- /dev/null +++ b/hxutil/round_tripper.go @@ -0,0 +1,22 @@ +package hxutil + +import "net/http" + +type RoundTripperFunc func(*http.Request, http.RoundTripper) (*http.Response, error) + +func (f RoundTripperFunc) Wrap(rt http.RoundTripper) http.RoundTripper { + return &RoundTripperWrapper{Next: rt, Func: f} +} + +type RoundTripperWrapper struct { + Next http.RoundTripper + Func func(*http.Request, http.RoundTripper) (*http.Response, error) +} + +func (w *RoundTripperWrapper) RoundTrip(r *http.Request) (*http.Response, error) { + next := w.Next + if next == nil { + next = http.DefaultTransport + } + return w.Func(r, next) +} diff --git a/hxutil/round_tripper_test.go b/hxutil/round_tripper_test.go new file mode 100644 index 0000000..9b286b4 --- /dev/null +++ b/hxutil/round_tripper_test.go @@ -0,0 +1,71 @@ +package hxutil_test + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "testing" + + "github.com/izumin5210/hx/hxutil" +) + +func TestRoundTripperFunc(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch { + case r.Method == http.MethodPost && r.URL.Path == "/echo": + cnt, _ := strconv.Atoi(r.Header.Get("Count")) + if cnt == 0 { + cnt = 1 + } + var buf bytes.Buffer + io.Copy(&buf, r.Body) + w.Write([]byte(strings.Repeat(buf.String(), cnt))) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer ts.Close() + + cases := []struct { + test string + base http.RoundTripper + }{ + {test: "no base"}, + {test: "specify base", base: http.DefaultTransport}, + } + + for _, tc := range cases { + t.Run(tc.test, func(t *testing.T) { + cli := &http.Client{ + Transport: hxutil.RoundTripperFunc(func(r *http.Request, rt http.RoundTripper) (*http.Response, error) { + r.Header.Set("Count", "3") + return rt.RoundTrip(r) + }).Wrap(tc.base), + } + + req, err := http.NewRequest(http.MethodPost, ts.URL+"/echo", bytes.NewBufferString("test")) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + resp, err := cli.Do(req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer resp.Body.Close() + + var buf bytes.Buffer + _, err = io.Copy(&buf, resp.Body) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if got, want := buf.String(), "testtesttest"; got != want { + t.Errorf("returned %q, want %q", got, want) + } + }) + } +} diff --git a/hxutil/transport.go b/hxutil/transport.go new file mode 100644 index 0000000..dc6ab9a --- /dev/null +++ b/hxutil/transport.go @@ -0,0 +1,25 @@ +package hxutil + +import ( + "net/http" + "reflect" +) + +// CloneTransport creates a new *http.Transport object that has copied attributes from a given one. +func CloneTransport(in *http.Transport) *http.Transport { + out := new(http.Transport) + outRv := reflect.ValueOf(out).Elem() + + rv := reflect.ValueOf(in).Elem() + rt := rv.Type() + + n := rt.NumField() + for i := 0; i < n; i++ { + src, dst := rv.Field(i), outRv.Field(i) + if src.Type().AssignableTo(dst.Type()) && dst.CanSet() { + dst.Set(src) + } + } + + return out +} diff --git a/hxutil/transport_test.go b/hxutil/transport_test.go new file mode 100644 index 0000000..2061746 --- /dev/null +++ b/hxutil/transport_test.go @@ -0,0 +1,66 @@ +package hxutil_test + +import ( + "net" + "net/http" + "testing" + "time" + + "github.com/izumin5210/hx/hxutil" +) + +func TestCloneTransport(t *testing.T) { + // https://github.com/golang/go/blob/go1.13.4/src/net/http/transport.go#L42-L54 + base := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + DualStack: true, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + cloned := hxutil.CloneTransport(base) + cloned.MaxIdleConns = 500 + cloned.MaxIdleConnsPerHost = 100 + + if cloned.Proxy == nil { + t.Errorf("Proxy should be copied") + } + + if cloned.DialContext == nil { + t.Errorf("DialContext should be copied") + } + + if got, want := cloned.IdleConnTimeout, base.IdleConnTimeout; got != want { + t.Errorf("cloned IdleConnTimeout is %s, want %s", got, want) + } + + if got, want := cloned.TLSHandshakeTimeout, base.TLSHandshakeTimeout; got != want { + t.Errorf("cloned TLSHandshakeTimeout is %s, want %s", got, want) + } + + if got, want := cloned.ExpectContinueTimeout, base.ExpectContinueTimeout; got != want { + t.Errorf("cloned ExpectContinueTimeout is %s, want %s", got, want) + } + + if got, want := base.MaxIdleConns, 100; got != want { + t.Errorf("base MaxIdleConns is %d, want %d", got, want) + } + + if got, want := cloned.MaxIdleConns, 500; got != want { + t.Errorf("cloned MaxIdleConns is %d, want %d", got, want) + } + + if got, want := base.MaxIdleConnsPerHost, 0; got != want { + t.Errorf("base MaxIdleConnsPerHost is %d, want %d", got, want) + } + + if got, want := cloned.MaxIdleConnsPerHost, 100; got != want { + t.Errorf("cloned MaxIdleConnsPerHost is %d, want %d", got, want) + } +} diff --git a/request_handler.go b/request_handler.go index 627d749..89f7806 100644 --- a/request_handler.go +++ b/request_handler.go @@ -3,6 +3,8 @@ package hx import ( "net/http" "time" + + "github.com/izumin5210/hx/hxutil" ) type RequestHandler func(*http.Client, *http.Request) (*http.Client, *http.Request, error) @@ -34,6 +36,10 @@ func TransportFrom(f func(http.RoundTripper) http.RoundTripper) Option { }) } +func TransportFunc(f func(*http.Request, http.RoundTripper) (*http.Response, error)) Option { + return TransportFrom(hxutil.RoundTripperFunc(f).Wrap) +} + // Timeout sets the max duration for http request(s). func Timeout(t time.Duration) Option { return RequestHandler(func(c *http.Client, r *http.Request) (*http.Client, *http.Request, error) { diff --git a/response_handler.go b/response_handler.go index 877328a..e52d834 100644 --- a/response_handler.go +++ b/response_handler.go @@ -4,6 +4,8 @@ import ( "encoding/json" "fmt" "net/http" + + "github.com/izumin5210/hx/hxutil" ) type ResponseHandler func(*http.Response, error) (*http.Response, error) @@ -51,7 +53,7 @@ func AsError() ResponseHandler { if r == nil || err != nil { return r, err } - err = DrainResponseBody(r) + err = hxutil.DrainResponseBody(r) if err != nil { return nil, &ResponseError{Response: r, Err: err} }