Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for setting a custom HTTP client #9

Merged
merged 2 commits into from
Jul 28, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 26 additions & 6 deletions rest.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
// Method contains the supported HTTP verbs.
type Method string

// Supported HTTP verbs.
const (
Get Method = "GET"
Post Method = "POST"
Expand All @@ -28,6 +29,16 @@ type Request struct {
Body []byte
}

// DefaultClient is used if no custom HTTP client is defined
var DefaultClient = &Client{HTTPClient: http.DefaultClient}

// Client allows modification of client headers, redirect policy
// and other settings
// See https://golang.org/pkg/net/http
type Client struct {
HTTPClient *http.Client
}

// Response holds the response from an API call.
type Response struct {
StatusCode int // e.g. 200
Expand Down Expand Up @@ -59,11 +70,7 @@ func BuildRequestObject(request Request) (*http.Request, error) {

// MakeRequest makes the API call.
func MakeRequest(req *http.Request) (*http.Response, error) {
var Client = &http.Client{
Transport: http.DefaultTransport,
}
res, err := Client.Do(req)
return res, err
return DefaultClient.HTTPClient.Do(req)
}

// BuildResponse builds the response struct.
Expand All @@ -83,6 +90,19 @@ func BuildResponse(res *http.Response) (*Response, error) {

// API is the main interface to the API.
func API(request Request) (*Response, error) {
return DefaultClient.API(request)
}

// The following functions enable the ability to define a
// custom HTTP Client

// MakeRequest makes the API call.
func (c *Client) MakeRequest(req *http.Request) (*http.Response, error) {
return c.HTTPClient.Do(req)
}

// API is the main interface to the API.
func (c *Client) API(request Request) (*Response, error) {
// Add any query parameters to the URL.
if len(request.QueryParams) != 0 {
request.BaseURL = AddQueryParameters(request.BaseURL, request.QueryParams)
Expand All @@ -95,7 +115,7 @@ func API(request Request) (*Response, error) {
}

// Build the HTTP client and make the request.
res, err := MakeRequest(req)
res, err := c.MakeRequest(req)
if err != nil {
return nil, err
}
Expand Down
28 changes: 28 additions & 0 deletions rest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"fmt"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
)

func TestBuildURL(t *testing.T) {
Expand Down Expand Up @@ -47,6 +49,7 @@ func TestBuildResponse(t *testing.T) {
fakeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "{\"message\": \"success\"}")
}))
defer fakeServer.Close()
baseURL := fakeServer.URL
method := Get
request := Request{
Expand Down Expand Up @@ -74,6 +77,7 @@ func TestRest(t *testing.T) {
fakeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintln(w, "{\"message\": \"success\"}")
}))
defer fakeServer.Close()
host := fakeServer.URL
endpoint := "/test_endpoint"
baseURL := host + endpoint
Expand Down Expand Up @@ -105,3 +109,27 @@ func TestRest(t *testing.T) {
t.Errorf("Rest failed to make a valid API request. Returned error: %v", e)
}
}

func TestCustomHTTPClient(t *testing.T) {
fakeServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
time.Sleep(time.Millisecond * 20)
fmt.Fprintln(w, "{\"message\": \"success\"}")
}))
defer fakeServer.Close()
host := fakeServer.URL
endpoint := "/test_endpoint"
baseURL := host + endpoint
method := Get
request := Request{
Method: method,
BaseURL: baseURL,
}
customClient := &Client{&http.Client{Timeout: time.Millisecond * 10}}
_, err := customClient.API(request)
if err == nil {
t.Error("A timeout did not trigger as expected")
}
if strings.Contains(err.Error(), "Client.Timeout exceeded while awaiting headers") == false {
t.Error("We did not receive the Timeout error")
}
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test doesn't check the type of error returned.

}