diff --git a/response.go b/response.go index 5f3b268..152a9aa 100644 --- a/response.go +++ b/response.go @@ -4,15 +4,73 @@ import ( "bytes" "encoding/json" "encoding/xml" - "errors" "fmt" "io" "net/http" - "runtime" "strconv" "strings" ) +// Responder is a callback that receives and http request and returns +// a mocked response. +type Responder func(*http.Request) (*http.Response, error) + +func (r Responder) times(name string, n int, fn ...func(...interface{})) Responder { + count := 0 + return func(req *http.Request) (*http.Response, error) { + count++ + if count > n { + err := stackTracer{ + err: fmt.Errorf("Responder not found for %s %s (coz %s and already called %d times)", req.Method, req.URL, name, count), + } + if len(fn) > 0 { + err.customFn = fn[0] + } + return nil, err + } + return r(req) + } +} + +// Times returns a Responder callable n times before returning an +// error. If the Responder is called more than n times and fn is +// passed and non-nil, it acts as the fn parameter of +// NewNotFoundResponder, allowing to dump the stack trace to localize +// the origin of the call. +func (r Responder) Times(n int, fn ...func(...interface{})) Responder { + return r.times("Times", n, fn...) +} + +// Once returns a new Responder callable once before returning an +// error. If the Responder is called 2 or more times and fn is passed +// and non-nil, it acts as the fn parameter of NewNotFoundResponder, +// allowing to dump the stack trace to localize the origin of the +// call. +func (r Responder) Once(fn ...func(...interface{})) Responder { + return r.times("Once", 1, fn...) +} + +// Trace returns a new Responder that allow to easily trace the calls +// of the original Responder using fn. It can be used in conjunction +// with the testing package as in the example below with the help of +// (*testing.T).Log method: +// import "testing" +// ... +// func TestMyApp(t *testing.T) { +// ... +// httpmock.RegisterResponder("GET", "/foo/bar", +// httpmock.NewStringResponder(200, "{}").Trace(t.Log), +// ) +func (r Responder) Trace(fn func(...interface{})) Responder { + return func(req *http.Request) (*http.Response, error) { + resp, err := r(req) + return resp, stackTracer{ + customFn: fn, + err: err, + } + } +} + // ResponderFromResponse wraps an *http.Response in a Responder func ResponderFromResponse(resp *http.Response) Responder { return func(req *http.Request) (*http.Response, error) { @@ -54,43 +112,18 @@ func NewErrorResponder(err error) Responder { // httpmock.RegisterNoResponder(httpmock.NewNotFoundResponder(t.Fatal)) // // Will abort the current test and print something like: -// response:69: Responder not found for: GET http://foo.bar/path -// Called from goroutine 20 [running]: -// github.com/jarcoal/httpmock.NewNotFoundResponder.func1(0xc00011f000, 0x0, 0x42dfb1, 0x77ece8) -// /go/src/github.com/jarcoal/httpmock/response.go:67 +0x1c1 -// github.com/jarcoal/httpmock.runCancelable(0xc00004bfc0, 0xc00011f000, 0x7692f8, 0xc, 0xc0001208b0) -// /go/src/github.com/jarcoal/httpmock/transport.go:146 +0x7e -// github.com/jarcoal/httpmock.(*MockTransport).RoundTrip(0xc00005c980, 0xc00011f000, 0xc00005c980, 0x0, 0x0) -// /go/src/github.com/jarcoal/httpmock/transport.go:140 +0x19d -// net/http.send(0xc00011f000, 0x7d3440, 0xc00005c980, 0x0, 0x0, 0x0, 0xc000010400, 0xc000047bd8, 0x1, 0x0) -// /usr/local/go/src/net/http/client.go:250 +0x461 -// net/http.(*Client).send(0x9f6e20, 0xc00011f000, 0x0, 0x0, 0x0, 0xc000010400, 0x0, 0x1, 0x9f7ac0) -// /usr/local/go/src/net/http/client.go:174 +0xfb -// net/http.(*Client).do(0x9f6e20, 0xc00011f000, 0x0, 0x0, 0x0) -// /usr/local/go/src/net/http/client.go:641 +0x279 -// net/http.(*Client).Do(...) -// /usr/local/go/src/net/http/client.go:509 -// net/http.(*Client).Get(0x9f6e20, 0xc00001e420, 0x23, 0xc00012c000, 0xb, 0x600) -// /usr/local/go/src/net/http/client.go:398 +0x9e -// net/http.Get(...) -// /usr/local/go/src/net/http/client.go:370 -// foo.bar/foobar/foobar.TestMyApp(0xc00011e000) -// /go/src/foo.bar/foobar/foobar/my_app_test.go:272 +0xdbb -// testing.tRunner(0xc00011e000, 0x77e3a8) -// /usr/local/go/src/testing/testing.go:865 +0xc0 -// created by testing.(*T).Run -// /usr/local/go/src/testing/testing.go:916 +0x35a +// transport_test.go:735: Called from net/http.Get() +// at /go/src/github.com/jarcoal/httpmock/transport_test.go:714 +// github.com/jarcoal/httpmock.TestCheckStackTracer() +// at /go/src/testing/testing.go:865 +// testing.tRunner() +// at /go/src/runtime/asm_amd64.s:1337 func NewNotFoundResponder(fn func(...interface{})) Responder { return func(req *http.Request) (*http.Response, error) { - mesg := fmt.Sprintf("Responder not found for %s %s", req.Method, req.URL) - if fn != nil { - buf := make([]byte, 4096) - n := runtime.Stack(buf, false) - buf = buf[:n] - fn(mesg + "\nCalled from " + - strings.Replace(strings.TrimSuffix(string(buf), "\n"), "\n", "\n ", -1)) + return nil, stackTracer{ + customFn: fn, + err: fmt.Errorf("Responder not found for %s %s", req.Method, req.URL), } - return nil, errors.New(mesg) } } diff --git a/response_test.go b/response_test.go index a99480a..6a0e907 100644 --- a/response_test.go +++ b/response_test.go @@ -4,10 +4,8 @@ import ( "encoding/json" "encoding/xml" "errors" - "fmt" "io/ioutil" "net/http" - "strings" "testing" ) @@ -51,10 +49,7 @@ func TestResponderFromResponse(t *testing.T) { } func TestNewNotFoundResponder(t *testing.T) { - var mesg string - responder := NewNotFoundResponder(func(args ...interface{}) { - mesg = fmt.Sprint(args[0]) - }) + responder := NewNotFoundResponder(func(args ...interface{}) {}) req, err := http.NewRequest("GET", "http://foo.bar/path", nil) if err != nil { @@ -71,15 +66,11 @@ func TestNewNotFoundResponder(t *testing.T) { t.Error("err should be not nil") } else if err.Error() != title { t.Errorf(`err mismatch, got: "%s", expected: "%s"`, - err.Error(), - "Responder not found for: GET http://foo.bar/path") - } - - if !strings.HasPrefix(mesg, title+"\nCalled from ") { - t.Error(`mesg should begin with "` + title + `\nCalled from ", but it is: "` + mesg + `"`) - } - if strings.HasSuffix(mesg, "\n") { - t.Error(`mesg should not end with \n, but it is: "` + mesg + `"`) + err, "Responder not found for: GET http://foo.bar/path") + } else if ne, ok := err.(stackTracer); !ok { + t.Errorf(`err type mismatch, got %T, expected httpmock.notFound`, err) + } else if ne.customFn == nil { + t.Error(`err customFn mismatch, got: nil, expected: non-nil`) } // nil fn @@ -93,8 +84,11 @@ func TestNewNotFoundResponder(t *testing.T) { t.Error("err should be not nil") } else if err.Error() != title { t.Errorf(`err mismatch, got: "%s", expected: "%s"`, - err.Error(), - "Responder not found for: GET http://foo.bar/path") + err, "Responder not found for: GET http://foo.bar/path") + } else if ne, ok := err.(stackTracer); !ok { + t.Errorf(`err type mismatch, got %T, expected httpmock.notFound`, err) + } else if ne.customFn != nil { + t.Errorf(`err customFn mismatch, got: %p, expected: nil`, ne.customFn) } } @@ -252,3 +246,98 @@ func TestRewindResponse(t *testing.T) { } } } + +func TestResponder(t *testing.T) { + req, err := http.NewRequest(http.MethodGet, "http://foo.bar", nil) + if err != nil { + t.Fatal("Error creating request") + } + resp := &http.Response{} + + chk := func(r Responder, expectedResp *http.Response, expectedErr string) { + //t.Helper // Only available since 1.9 + gotResp, gotErr := r(req) + if gotResp != expectedResp { + t.Errorf(`Response mismatch, expected: %v, got: %v`, expectedResp, gotResp) + } + var gotErrStr string + if gotErr != nil { + gotErrStr = gotErr.Error() + } + if gotErrStr != expectedErr { + t.Errorf(`Error mismatch, expected: %v, got: %v`, expectedErr, gotErrStr) + } + } + called := false + chkNotCalled := func() { + if called { + //t.Helper // Only available since 1.9 + t.Errorf("Original responder should not be called") + called = false + } + } + chkCalled := func() { + if !called { + //t.Helper // Only available since 1.9 + t.Errorf("Original responder should be called") + } + called = false + } + + r := Responder(func(*http.Request) (*http.Response, error) { + called = true + return resp, nil + }) + chk(r, resp, "") + chkCalled() + + // + // Once + ro := r.Once() + chk(ro, resp, "") + chkCalled() + + chk(ro, nil, "Responder not found for GET http://foo.bar (coz Once and already called 2 times)") + chkNotCalled() + + chk(ro, nil, "Responder not found for GET http://foo.bar (coz Once and already called 3 times)") + chkNotCalled() + + ro = r.Once(func(args ...interface{}) {}) + chk(ro, resp, "") + chkCalled() + + chk(ro, nil, "Responder not found for GET http://foo.bar (coz Once and already called 2 times)") + chkNotCalled() + + // + // Times + rt := r.Times(2) + chk(rt, resp, "") + chkCalled() + + chk(rt, resp, "") + chkCalled() + + chk(rt, nil, "Responder not found for GET http://foo.bar (coz Times and already called 3 times)") + chkNotCalled() + + chk(rt, nil, "Responder not found for GET http://foo.bar (coz Times and already called 4 times)") + chkNotCalled() + + rt = r.Times(1, func(args ...interface{}) {}) + chk(rt, resp, "") + chkCalled() + + chk(rt, nil, "Responder not found for GET http://foo.bar (coz Times and already called 2 times)") + chkNotCalled() + + // + // Trace + rt = r.Trace(func(args ...interface{}) {}) + chk(rt, resp, "") + chkCalled() + + chk(rt, resp, "") + chkCalled() +} diff --git a/transport.go b/transport.go index 53fae85..0136a94 100644 --- a/transport.go +++ b/transport.go @@ -6,15 +6,12 @@ import ( "fmt" "net/http" "net/url" + "runtime" "sort" "strings" "sync" ) -// Responder is a callback that receives and http request and returns -// a mocked response. -type Responder func(*http.Request) (*http.Response, error) - // NoResponderFound is returned when no responders are found for a given HTTP method and URL. var NoResponderFound = errors.New("no responder found") // nolint: golint @@ -143,7 +140,8 @@ func (m *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { func runCancelable(responder Responder, req *http.Request) (*http.Response, error) { ctx := req.Context() if req.Cancel == nil && ctx.Done() == nil { // nolint: staticcheck - return responder(req) + resp, err := responder(req) + return resp, checkStackTracer(req, err) } // Set up a goroutine that translates a close(req.Cancel) into a @@ -197,7 +195,83 @@ func runCancelable(responder Responder, req *http.Request) (*http.Response, erro // first goroutine. done <- struct{}{} - return r.response, r.err + return r.response, checkStackTracer(req, r.err) +} + +type stackTracer struct { + customFn func(...interface{}) + err error +} + +func (n stackTracer) Error() string { + if n.err == nil { + return "" + } + return n.err.Error() +} + +// checkStackTracer checks for specific error returned by +// NewNotFoundResponder function or Debug Responder method. +func checkStackTracer(req *http.Request, err error) error { + if nf, ok := err.(stackTracer); ok { + if nf.customFn != nil { + pc := make([]uintptr, 128) + npc := runtime.Callers(2, pc) + pc = pc[:npc] + + var mesg bytes.Buffer + var netHTTPBegin, netHTTPEnd bool + + // Start recording at first net/http call if any... + for { + frames := runtime.CallersFrames(pc) + + var lastFn string + for { + frame, more := frames.Next() + + if !netHTTPEnd { + if netHTTPBegin { + netHTTPEnd = !strings.HasPrefix(frame.Function, "net/http.") + } else { + netHTTPBegin = strings.HasPrefix(frame.Function, "net/http.") + } + } + + if netHTTPEnd { + if lastFn != "" { + if mesg.Len() == 0 { + if nf.err != nil { + mesg.WriteString(nf.err.Error()) + } else { + fmt.Fprintf(&mesg, "%s %s", req.Method, req.URL) + } + mesg.WriteString("\nCalled from ") + } else { + mesg.WriteString("\n ") + } + fmt.Fprintf(&mesg, "%s()\n at %s:%d", lastFn, frame.File, frame.Line) + } + } + lastFn = frame.Function + + if !more { + break + } + } + + // At least one net/http frame found + if mesg.Len() > 0 { + break + } + netHTTPEnd = true // retry without looking at net/http frames + } + + nf.customFn(mesg.String()) + } + err = nf.err + } + return err } // responderForKey returns a responder for a given key. diff --git a/transport_test.go b/transport_test.go index a59cbc4..18cf8e7 100644 --- a/transport_test.go +++ b/transport_test.go @@ -653,3 +653,102 @@ func TestRegisterResponderWithQueryPanic(t *testing.T) { } } } + +func TestCheckStackTracer(t *testing.T) { + req, err := http.NewRequest("GET", "http://foo.bar/", nil) + if err != nil { + t.Fatal(err) + } + + // no error + gotErr := checkStackTracer(req, nil) + if gotErr != nil { + t.Errorf(`checkStackTracer(nil) should return nil, not %v`, gotErr) + } + + // Classic error + err = errors.New("error") + gotErr = checkStackTracer(req, err) + if err != gotErr { + t.Errorf(`checkStackTracer(err) should return %v, not %v`, err, gotErr) + } + + // stackTracer without customFn + origErr := errors.New("foo") + errTracer := stackTracer{ + err: origErr, + } + gotErr = checkStackTracer(req, errTracer) + if gotErr != origErr { + t.Errorf(`Returned error mismatch, expected: %v, got: %v`, origErr, gotErr) + } + + // stackTracer with nil error & without customFn + errTracer = stackTracer{} + gotErr = checkStackTracer(req, errTracer) + if gotErr != nil { + t.Errorf(`Returned error mismatch, expected: nil, got: %v`, gotErr) + } + + // stackTracer + var mesg string + errTracer = stackTracer{ + err: origErr, + customFn: func(args ...interface{}) { + mesg = args[0].(string) + }, + } + gotErr = checkStackTracer(req, errTracer) + if !strings.HasPrefix(mesg, "foo\nCalled from ") || strings.HasSuffix(mesg, "\n") { + t.Errorf(`mesg does not match "^foo\nCalled from .*[^\n]\z", it is "` + mesg + `"`) + } + if gotErr != origErr { + t.Errorf(`Returned error mismatch, expected: %v, got: %v`, origErr, gotErr) + } + + // stackTracer with nil error but customFn + mesg = "" + errTracer = stackTracer{ + customFn: func(args ...interface{}) { + mesg = args[0].(string) + }, + } + gotErr = checkStackTracer(req, errTracer) + if !strings.HasPrefix(mesg, "GET http://foo.bar/\nCalled from ") || strings.HasSuffix(mesg, "\n") { + t.Errorf(`mesg does not match "^foo\nCalled from .*[^\n]\z", it is "` + mesg + `"`) + } + if gotErr != nil { + t.Errorf(`Returned error mismatch, expected: nil, got: %v`, gotErr) + } + + // Full test using Trace() Responder + Activate() + defer Deactivate() + + const url = "https://foo.bar/" + mesg = "" + RegisterResponder("GET", url, + NewStringResponder(200, "{}"). + Trace(func(args ...interface{}) { mesg = args[0].(string) })) + + resp, err := http.Get(url) + if err != nil { + t.Fatal(err) + } + defer resp.Body.Close() + + data, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatal(err) + } + + if string(data) != "{}" { + t.FailNow() + } + + // Check that first frame is the net/http.Get() call + if !strings.HasPrefix(mesg, "GET https://foo.bar/\nCalled from net/http.Get()\n at ") || + strings.HasSuffix(mesg, "\n") { + t.Errorf("Bad mesg: <%v>", mesg) + } +}