diff --git a/rest.go b/rest.go index d8aef1b..de5c2f9 100644 --- a/rest.go +++ b/rest.go @@ -11,6 +11,7 @@ import ( // Method contains the supported HTTP verbs. type Method string +// Supported HTTP verbs. const ( Get Method = "GET" Post Method = "POST" @@ -28,6 +29,16 @@ type Request struct { Body []byte } +// DefaultClient is used if no custom HTTP client is defined +var DefaultClient = &Client{HTTPClient: http.DefaultClient} + +// Client allows modification of client headers, redirect policy +// and other settings +// See https://golang.org/pkg/net/http +type Client struct { + HTTPClient *http.Client +} + // Response holds the response from an API call. type Response struct { StatusCode int // e.g. 200 @@ -59,11 +70,7 @@ func BuildRequestObject(request Request) (*http.Request, error) { // MakeRequest makes the API call. func MakeRequest(req *http.Request) (*http.Response, error) { - var Client = &http.Client{ - Transport: http.DefaultTransport, - } - res, err := Client.Do(req) - return res, err + return DefaultClient.HTTPClient.Do(req) } // BuildResponse builds the response struct. @@ -83,6 +90,19 @@ func BuildResponse(res *http.Response) (*Response, error) { // API is the main interface to the API. func API(request Request) (*Response, error) { + return DefaultClient.API(request) +} + +// The following functions enable the ability to define a +// custom HTTP Client + +// MakeRequest makes the API call. +func (c *Client) MakeRequest(req *http.Request) (*http.Response, error) { + return c.HTTPClient.Do(req) +} + +// API is the main interface to the API. +func (c *Client) API(request Request) (*Response, error) { // Add any query parameters to the URL. if len(request.QueryParams) != 0 { request.BaseURL = AddQueryParameters(request.BaseURL, request.QueryParams) @@ -95,7 +115,7 @@ func API(request Request) (*Response, error) { } // Build the HTTP client and make the request. - res, err := MakeRequest(req) + res, err := c.MakeRequest(req) if err != nil { return nil, err } diff --git a/rest_test.go b/rest_test.go index afbec8b..e8d3d03 100644 --- a/rest_test.go +++ b/rest_test.go @@ -4,7 +4,9 @@ import ( "fmt" "net/http" "net/http/httptest" + "strings" "testing" + "time" ) func TestBuildURL(t *testing.T) { @@ -47,6 +49,7 @@ func TestBuildResponse(t *testing.T) { fakeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "{\"message\": \"success\"}") })) + defer fakeServer.Close() baseURL := fakeServer.URL method := Get request := Request{ @@ -74,6 +77,7 @@ func TestRest(t *testing.T) { fakeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintln(w, "{\"message\": \"success\"}") })) + defer fakeServer.Close() host := fakeServer.URL endpoint := "/test_endpoint" baseURL := host + endpoint @@ -105,3 +109,27 @@ func TestRest(t *testing.T) { t.Errorf("Rest failed to make a valid API request. Returned error: %v", e) } } + +func TestCustomHTTPClient(t *testing.T) { + fakeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(time.Millisecond * 20) + fmt.Fprintln(w, "{\"message\": \"success\"}") + })) + defer fakeServer.Close() + host := fakeServer.URL + endpoint := "/test_endpoint" + baseURL := host + endpoint + method := Get + request := Request{ + Method: method, + BaseURL: baseURL, + } + customClient := &Client{&http.Client{Timeout: time.Millisecond * 10}} + _, err := customClient.API(request) + if err == nil { + t.Error("A timeout did not trigger as expected") + } + if strings.Contains(err.Error(), "Client.Timeout exceeded while awaiting headers") == false { + t.Error("We did not receive the Timeout error") + } +}