diff --git a/pkg/drivers/http/driver.go b/pkg/drivers/http/driver.go index 9e86aba8..f36087c4 100644 --- a/pkg/drivers/http/driver.go +++ b/pkg/drivers/http/driver.go @@ -26,18 +26,7 @@ func NewDriver(opts ...Option) *Driver { drv := new(Driver) drv.options = newOptions(opts) - if drv.options.Proxy == "" { - drv.client = pester.New() - } else { - client, err := newClientWithProxy(drv.options) - - if err != nil { - drv.client = pester.New() - } else { - drv.client = pester.NewExtendedClient(client) - } - } - + drv.client = newHTTPClient(drv.options) drv.client.Concurrency = drv.options.Concurrency drv.client.MaxRetries = drv.options.MaxRetries drv.client.Backoff = drv.options.Backoff @@ -45,17 +34,47 @@ func NewDriver(opts ...Option) *Driver { return drv } -func newClientWithProxy(options *Options) (*http.Client, error) { - proxyURL, err := url.Parse(options.Proxy) +func newHTTPClient(options *Options) (httpClient *pester.Client) { + httpClient = pester.New() + + if options.HTTPTransport != nil { + httpClient.Transport = options.HTTPTransport + } + + if options.Proxy == "" { + return + } + + if err := addProxy(httpClient, options.Proxy); err != nil { + return + } + + httpClient = pester.NewExtendedClient(&http.Client{Transport: httpClient.Transport}) + + return +} + +func addProxy(httpClient *pester.Client, proxyStr string) error { + if proxyStr == "" { + return nil + } + proxyURL, err := url.Parse(proxyStr) if err != nil { - return nil, err + return err } proxy := http.ProxyURL(proxyURL) - tr := &http.Transport{Proxy: proxy} - return &http.Client{Transport: tr}, nil + if httpClient.Transport != nil { + httpClient.Transport.(*http.Transport).Proxy = proxy + + return nil + } + + httpClient.Transport = &http.Transport{Proxy: proxy} + + return nil } func (drv *Driver) Name() string { diff --git a/pkg/drivers/http/driver_test.go b/pkg/drivers/http/driver_test.go new file mode 100644 index 00000000..bb117516 --- /dev/null +++ b/pkg/drivers/http/driver_test.go @@ -0,0 +1,101 @@ +package http + +import ( + "crypto/tls" + "net/http" + "reflect" + "testing" + "unsafe" + + "github.com/smartystreets/goconvey/convey" +) + +func Test_newHTTPClientWithTransport(t *testing.T) { + httpTransport := (http.DefaultTransport).(*http.Transport) + httpTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + + type args struct { + options *Options + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "check transport exist with pester.New()", + args: args{options: &Options{ + Proxy: "http://0.0.0.|", + HTTPTransport: httpTransport, + }}, + }, + { + name: "check transport exist with pester.NewExtendedClient()", + args: args{options: &Options{ + Proxy: "http://0.0.0.0", + HTTPTransport: httpTransport, + }}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + convey.Convey(tt.name, t, func() { + var ( + transport *http.Transport + client = newHTTPClient(tt.args.options) + rValue = reflect.ValueOf(client).Elem() + rField = rValue.Field(0) + ) + + rField = reflect.NewAt(rField.Type(), unsafe.Pointer(rField.UnsafeAddr())).Elem() + hc := rField.Interface().(*http.Client) + + if hc != nil { + transport = hc.Transport.(*http.Transport) + } else { + transport = client.Transport.(*http.Transport) + } + + verify := transport.TLSClientConfig.InsecureSkipVerify + + convey.So(verify, convey.ShouldBeTrue) + }) + }) + } +} + +func Test_newHTTPClient(t *testing.T) { + + convey.Convey("pester.New()", t, func() { + var ( + client = newHTTPClient(&Options{ + Proxy: "http://0.0.0.|", + }) + + rValue = reflect.ValueOf(client).Elem() + rField = rValue.Field(0) + ) + + rField = reflect.NewAt(rField.Type(), unsafe.Pointer(rField.UnsafeAddr())).Elem() + hc := rField.Interface().(*http.Client) + + convey.So(hc, convey.ShouldBeNil) + }) + + convey.Convey("pester.NewExtend()", t, func() { + var ( + client = newHTTPClient(&Options{ + Proxy: "http://0.0.0.0", + }) + + rValue = reflect.ValueOf(client).Elem() + rField = rValue.Field(0) + ) + + rField = reflect.NewAt(rField.Type(), unsafe.Pointer(rField.UnsafeAddr())).Elem() + hc := rField.Interface().(*http.Client) + + convey.So(hc, convey.ShouldNotBeNil) + }) + +} diff --git a/pkg/drivers/http/options.go b/pkg/drivers/http/options.go index 4615ee1a..10a3e352 100644 --- a/pkg/drivers/http/options.go +++ b/pkg/drivers/http/options.go @@ -20,6 +20,7 @@ type ( Headers drivers.HTTPHeaders Cookies drivers.HTTPCookies AllowedHTTPCodes map[int]struct{} + HTTPTransport *stdhttp.Transport } ) @@ -143,3 +144,9 @@ func WithAllowedHTTPCodes(httpCodes []int) Option { } } } + +func WithCustomTransport(transport *stdhttp.Transport) Option { + return func(opts *Options) { + opts.HTTPTransport = transport + } +}