Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TransportFunc and move utilities into under hxutil #12

Merged
merged 3 commits into from
Nov 9, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ err := hx.Get(ctx, "https://api.example.com/contents/1",

```go
func init() {
defaultTransport := hx.CloneTransport(http.DefaultTransport.(*http.Transport))
defaultTransport := hxutil.CloneTransport(http.DefaultTransport.(*http.Transport))

// Tweak keep-alive configuration
defaultTransport.MaxIdleConns = 500
Expand Down
36 changes: 18 additions & 18 deletions client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"time"

"github.com/izumin5210/hx"
"github.com/izumin5210/hx/hxutil"
)

func TestClient(t *testing.T) {
Expand Down Expand Up @@ -321,8 +322,8 @@ func TestClient(t *testing.T) {
})

t.Run("with Transport", func(t *testing.T) {
transport := &fakeTransport{
RoundTripFunc: func(rt http.RoundTripper, req *http.Request) (*http.Response, error) {
transport := &hxutil.RoundTripperWrapper{
Func: func(req *http.Request, rt http.RoundTripper) (*http.Response, error) {
req.SetBasicAuth("foo", "bar")
return rt.RoundTrip(req)
},
Expand All @@ -339,9 +340,8 @@ func TestClient(t *testing.T) {
t.Run("with TransportFrom", func(t *testing.T) {
err := hx.Get(context.Background(), ts.URL+"/basic_auth",
hx.TransportFrom(func(base http.RoundTripper) http.RoundTripper {
return &fakeTransport{
Base: base,
RoundTripFunc: func(rt http.RoundTripper, req *http.Request) (*http.Response, error) {
return &hxutil.RoundTripperWrapper{
Func: func(req *http.Request, rt http.RoundTripper) (*http.Response, error) {
req.SetBasicAuth("foo", "bar")
return rt.RoundTrip(req)
},
Expand All @@ -353,23 +353,23 @@ func TestClient(t *testing.T) {
t.Errorf("returned %v, want nil", err)
}
})

t.Run("with TransportFunc", func(t *testing.T) {
err := hx.Get(context.Background(), ts.URL+"/basic_auth",
hx.TransportFunc(func(r *http.Request, next http.RoundTripper) (*http.Response, error) {
r.SetBasicAuth("foo", "bar")
return next.RoundTrip(r)
}),
hx.WhenFailure(hx.AsError()),
)
if err != nil {
t.Errorf("returned %v, want nil", err)
}
})
}

type fakeError struct {
Message string `json:"message"`
}

func (e fakeError) Error() string { return e.Message }

type fakeTransport struct {
Base http.RoundTripper
RoundTripFunc func(http.RoundTripper, *http.Request) (*http.Response, error)
}

func (t *fakeTransport) RoundTrip(req *http.Request) (*http.Response, error) {
base := t.Base
if base == nil {
base = http.DefaultTransport
}
return t.RoundTripFunc(base, req)
}
37 changes: 0 additions & 37 deletions helper.go
Original file line number Diff line number Diff line change
@@ -1,49 +1,12 @@
package hx

import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"path"
"reflect"
"strings"
)

func DrainResponseBody(r *http.Response) error {
var buf bytes.Buffer
_, err := buf.ReadFrom(r.Body)
if err != nil {
return err
}
err = r.Body.Close()
if err != nil {
return err
}
r.Body = ioutil.NopCloser(&buf)
return nil
}

// CloneTransport creates a new *http.Transport object that has copied attributes from a given one.
func CloneTransport(in *http.Transport) *http.Transport {
out := new(http.Transport)
outRv := reflect.ValueOf(out).Elem()

rv := reflect.ValueOf(in).Elem()
rt := rv.Type()

n := rt.NumField()
for i := 0; i < n; i++ {
src, dst := rv.Field(i), outRv.Field(i)
if src.Type().AssignableTo(dst.Type()) && dst.CanSet() {
dst.Set(src)
}
}

return out
}

func Path(elem ...interface{}) string {
chunks := make([]string, len(elem))
for i, e := range elem {
Expand Down
59 changes: 0 additions & 59 deletions helper_test.go
Original file line number Diff line number Diff line change
@@ -1,70 +1,11 @@
package hx_test

import (
"net"
"net/http"
"testing"
"time"

"github.com/izumin5210/hx"
)

func TestCloneTransport(t *testing.T) {
// https://github.com/golang/go/blob/go1.13.4/src/net/http/transport.go#L42-L54
base := &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: (&net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
DualStack: true,
}).DialContext,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}

cloned := hx.CloneTransport(base)
cloned.MaxIdleConns = 500
cloned.MaxIdleConnsPerHost = 100

if cloned.Proxy == nil {
t.Errorf("Proxy should be copied")
}

if cloned.DialContext == nil {
t.Errorf("DialContext should be copied")
}

if got, want := cloned.IdleConnTimeout, base.IdleConnTimeout; got != want {
t.Errorf("cloned IdleConnTimeout is %s, want %s", got, want)
}

if got, want := cloned.TLSHandshakeTimeout, base.TLSHandshakeTimeout; got != want {
t.Errorf("cloned TLSHandshakeTimeout is %s, want %s", got, want)
}

if got, want := cloned.ExpectContinueTimeout, base.ExpectContinueTimeout; got != want {
t.Errorf("cloned ExpectContinueTimeout is %s, want %s", got, want)
}

if got, want := base.MaxIdleConns, 100; got != want {
t.Errorf("base MaxIdleConns is %d, want %d", got, want)
}

if got, want := cloned.MaxIdleConns, 500; got != want {
t.Errorf("cloned MaxIdleConns is %d, want %d", got, want)
}

if got, want := base.MaxIdleConnsPerHost, 0; got != want {
t.Errorf("base MaxIdleConnsPerHost is %d, want %d", got, want)
}

if got, want := cloned.MaxIdleConnsPerHost, 100; got != want {
t.Errorf("cloned MaxIdleConnsPerHost is %d, want %d", got, want)
}
}

func TestPath(t *testing.T) {
cases := []struct {
test string
Expand Down
21 changes: 21 additions & 0 deletions hxutil/drain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package hxutil

import (
"bytes"
"io/ioutil"
"net/http"
)

func DrainResponseBody(r *http.Response) error {
var buf bytes.Buffer
_, err := buf.ReadFrom(r.Body)
if err != nil {
return err
}
err = r.Body.Close()
if err != nil {
return err
}
r.Body = ioutil.NopCloser(&buf)
return nil
}
53 changes: 53 additions & 0 deletions hxutil/drain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package hxutil

import (
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
)

func TestDrainResponseBody(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodGet && r.URL.Path == "/ping":
w.Write([]byte("pong"))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer ts.Close()

t.Run("success", func(t *testing.T) {
resp, err := http.Get(ts.URL + "/ping")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()

err = DrainResponseBody(resp)
if err != nil {
t.Errorf("returned %v, want nil", err)
}

data, err := ioutil.ReadAll(resp.Body)
if err != nil {
t.Errorf("returned %v, want nil", err)
} else if got, want := string(data), "pong"; got != want {
t.Errorf("returned %q, want %q", got, want)
}
})

t.Run("failure", func(t *testing.T) {
resp, err := http.Get(ts.URL + "/ping")
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
resp.Body.Close()

err = DrainResponseBody(resp)
if err == nil {
t.Errorf("returned nil, want an error")
}
})
}
22 changes: 22 additions & 0 deletions hxutil/round_tripper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package hxutil

import "net/http"

type RoundTripperFunc func(*http.Request, http.RoundTripper) (*http.Response, error)

func (f RoundTripperFunc) Wrap(rt http.RoundTripper) http.RoundTripper {
return &RoundTripperWrapper{Next: rt, Func: f}
}

type RoundTripperWrapper struct {
Next http.RoundTripper
Func func(*http.Request, http.RoundTripper) (*http.Response, error)
}

func (w *RoundTripperWrapper) RoundTrip(r *http.Request) (*http.Response, error) {
next := w.Next
if next == nil {
next = http.DefaultTransport
}
return w.Func(r, next)
}
71 changes: 71 additions & 0 deletions hxutil/round_tripper_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package hxutil_test

import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"

"github.com/izumin5210/hx/hxutil"
)

func TestRoundTripperFunc(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch {
case r.Method == http.MethodPost && r.URL.Path == "/echo":
cnt, _ := strconv.Atoi(r.Header.Get("Count"))
if cnt == 0 {
cnt = 1
}
var buf bytes.Buffer
io.Copy(&buf, r.Body)
w.Write([]byte(strings.Repeat(buf.String(), cnt)))
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer ts.Close()

cases := []struct {
test string
base http.RoundTripper
}{
{test: "no base"},
{test: "specify base", base: http.DefaultTransport},
}

for _, tc := range cases {
t.Run(tc.test, func(t *testing.T) {
cli := &http.Client{
Transport: hxutil.RoundTripperFunc(func(r *http.Request, rt http.RoundTripper) (*http.Response, error) {
r.Header.Set("Count", "3")
return rt.RoundTrip(r)
}).Wrap(tc.base),
}

req, err := http.NewRequest(http.MethodPost, ts.URL+"/echo", bytes.NewBufferString("test"))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

resp, err := cli.Do(req)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer resp.Body.Close()

var buf bytes.Buffer
_, err = io.Copy(&buf, resp.Body)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if got, want := buf.String(), "testtesttest"; got != want {
t.Errorf("returned %q, want %q", got, want)
}
})
}
}
25 changes: 25 additions & 0 deletions hxutil/transport.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package hxutil

import (
"net/http"
"reflect"
)

// CloneTransport creates a new *http.Transport object that has copied attributes from a given one.
func CloneTransport(in *http.Transport) *http.Transport {
out := new(http.Transport)
outRv := reflect.ValueOf(out).Elem()

rv := reflect.ValueOf(in).Elem()
rt := rv.Type()

n := rt.NumField()
for i := 0; i < n; i++ {
src, dst := rv.Field(i), outRv.Field(i)
if src.Type().AssignableTo(dst.Type()) && dst.CanSet() {
dst.Set(src)
}
}

return out
}
Loading