Skip to content

Commit

Permalink
Merge pull request #32 from tri-adam/error-is
Browse files Browse the repository at this point in the history
Implement Is() for Error type
  • Loading branch information
tri-adam authored Jul 20, 2020
2 parents 4e0832d + 9fda8f8 commit 4933ba4
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 54 deletions.
34 changes: 23 additions & 11 deletions json_response.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,37 @@ import (
"net/http"
)

// Error describes an error condition.
type Error struct {
Code int `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}

func (e *Error) Error() string {
if e.Message != "" {
return fmt.Sprintf("%v (%v %v)", e.Message, e.Code, http.StatusText(e.Code))
}
return fmt.Sprintf("%v %v", e.Code, http.StatusText(e.Code))
}

// Is compares e against target. If target is an Error and matches the non-zero fields of e, true
// is returned.
func (e *Error) Is(target error) bool {
t, ok := target.(*Error)
if !ok {
return false
}
return ((e.Code == t.Code) || t.Code == 0) &&
((e.Message == t.Message) || t.Message == "")
}

// PageDetails specifies paging information.
type PageDetails struct {
Prev string `json:"prev,omitempty"`
Next string `json:"next,omitempty"`
TotalSize int `json:"totalSize,omitempty"`
}

// Error describes an error condition.
type Error struct {
Code int `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}

var (
// JSONErrorUnauthorized is a generic 401 unauthorized response
JSONErrorUnauthorized = &Error{
Expand Down Expand Up @@ -74,18 +85,19 @@ func encodeResponse(w http.ResponseWriter, jr Response, code int) error {
return nil
}

// WriteError encodes the supplied error in a response, and writes to w.
func WriteError(w http.ResponseWriter, error string, code int) error {
// WriteError writes a status code and JSON response containing the supplied error message and
// status code to w.
func WriteError(w http.ResponseWriter, message string, code int) error {
jr := Response{
Error: &Error{
Code: code,
Message: error,
Message: message,
},
}
return encodeResponse(w, jr, code)
}

// WriteResponsePage encodes the supplied data in a paged JSON response, and writes to w.
// WriteResponsePage writes a status code and JSON response containing data and pd to w.
func WriteResponsePage(w http.ResponseWriter, data interface{}, pd *PageDetails, code int) error {
jr := Response{
Data: data,
Expand All @@ -94,7 +106,7 @@ func WriteResponsePage(w http.ResponseWriter, data interface{}, pd *PageDetails,
return encodeResponse(w, jr, code)
}

// WriteResponse encodes the supplied data in a response, and writes to w.
// WriteResponse writes a status code and JSON response containing data to w.
func WriteResponse(w http.ResponseWriter, data interface{}, code int) error {
return WriteResponsePage(w, data, nil, code)
}
Expand Down
78 changes: 35 additions & 43 deletions json_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package jsonresp
import (
"bytes"
"encoding/json"
"errors"
"io"
"log"
"net/http"
Expand All @@ -21,10 +22,21 @@ func TestError(t *testing.T) {
name string
code int
message string
wantErr error
wantErrString string
}{
{"NoMessage", http.StatusNotFound, "", "404 Not Found"},
{"Message", http.StatusNotFound, "blah", "blah (404 Not Found)"},
{
name: "NoMessage",
code: http.StatusNotFound,
wantErr: &Error{Code: http.StatusNotFound},
wantErrString: "404 Not Found",
},
{
name: "Message",
code: http.StatusNotFound,
message: "blah",
wantErr: &Error{Code: http.StatusNotFound, Message: "blah"},
wantErrString: "blah (404 Not Found)"},
}

for _, tt := range tests {
Expand All @@ -36,6 +48,9 @@ func TestError(t *testing.T) {
if je.Message != tt.message {
t.Errorf("got message %v, want %v", je.Message, tt.message)
}
if !errors.Is(je, tt.wantErr) {
t.Errorf("got error %v, want %v", je, tt.wantErr)
}
if s := je.Error(); s != tt.wantErrString {
t.Errorf("got string %v, want %v", s, tt.wantErrString)
}
Expand All @@ -45,14 +60,13 @@ func TestError(t *testing.T) {

func TestWriteError(t *testing.T) {
tests := []struct {
name string
error string
code int
wantMessage string
wantCode int
name string
error string
code int
wantErr error
}{
{"NoMessage", "", http.StatusNotFound, "", http.StatusNotFound},
{"NoMessage", "blah", http.StatusNotFound, "blah", http.StatusNotFound},
{"NoMessage", "", http.StatusNotFound, &Error{Code: http.StatusNotFound}},
{"NoMessage", "blah", http.StatusNotFound, &Error{Code: http.StatusNotFound, Message: "blah"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -62,22 +76,16 @@ func TestWriteError(t *testing.T) {
t.Fatalf("failed to write error: %v", err)
}

if rr.Code != tt.wantCode {
t.Errorf("got code %v, want %v", rr.Code, tt.wantCode)
if rr.Code != tt.code {
t.Errorf("got code %v, want %v", rr.Code, tt.code)
}

var jr Response
if err := json.NewDecoder(rr.Body).Decode(&jr); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if jr.Error == nil {
t.Fatalf("nil error received")
}
if jr.Error.Message != tt.wantMessage {
t.Errorf("got message %v, want %v", jr.Error.Message, tt.wantMessage)
}
if jr.Error.Code != tt.wantCode {
t.Errorf("got code %v, want %v", jr.Error.Code, tt.wantCode)
if got, want := jr.Error, tt.wantErr; !errors.Is(got, want) {
t.Errorf("got error %v, want %v", got, want)
}
})
}
Expand Down Expand Up @@ -309,34 +317,18 @@ func TestReadError(t *testing.T) {
}

tests := []struct {
name string
r io.Reader
wantErr bool
wantMessage string
wantCode int
name string
r io.Reader
wantErr error
}{
{"Empty", bytes.NewReader(nil), false, "", 0},
{"Response", getResponseBody(TestStruct{"blah"}), false, "", 0},
{"Error", getErrorBody(), true, "blah", http.StatusNotFound},
{"Empty", bytes.NewReader(nil), nil},
{"Response", getResponseBody(TestStruct{"blah"}), nil},
{"Error", getErrorBody(), &Error{Code: http.StatusNotFound, Message: "blah"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ReadError(tt.r)
if (err != nil) != tt.wantErr {
t.Errorf("ReadError() error = %v, wantErr %v", err, tt.wantErr)
}

if err != nil {
err, ok := err.(*Error)
if !ok {
t.Fatal("invalid error type")
}
if got, want := err.Message, tt.wantMessage; got != want {
t.Errorf("got message %v, want %v", got, want)
}
if got, want := err.Code, tt.wantCode; got != want {
t.Errorf("got code %v, want %v", got, want)
}
if got, want := ReadError(tt.r), tt.wantErr; !errors.Is(got, want) {
t.Errorf("got error %v, want %v", got, want)
}
})
}
Expand Down

0 comments on commit 4933ba4

Please sign in to comment.