diff --git a/rc.go b/rc.go index 6271729..cc0140d 100644 --- a/rc.go +++ b/rc.go @@ -6,7 +6,6 @@ import ( "io" "log/slog" "net/http" - "net/http/httptest" "os" "strings" "syscall" @@ -241,7 +240,7 @@ func New(cacher Cacher, opts ...Option) func(next http.Handler) http.Handler { // HandlerToRequester converts http.Handler to func(*http.Request) (*http.Response, error). func HandlerToRequester(h http.Handler) func(*http.Request) (*http.Response, error) { return func(req *http.Request) (*http.Response, error) { - rec := httptest.NewRecorder() + rec := newRecorder() h.ServeHTTP(rec, req) res := rec.Result() res.Header = rec.Header() @@ -249,6 +248,43 @@ func HandlerToRequester(h http.Handler) func(*http.Request) (*http.Response, err } } +type recorder struct { + statusCode int + header http.Header + buf *bytes.Buffer +} + +var _ http.ResponseWriter = (*recorder)(nil) + +func newRecorder() *recorder { + return &recorder{ + buf: new(bytes.Buffer), + header: make(http.Header), + } +} + +func (r *recorder) Header() http.Header { + return r.header +} + +func (r *recorder) Write(b []byte) (int, error) { + return r.buf.Write(b) +} + +func (r *recorder) WriteHeader(statusCode int) { + r.statusCode = statusCode +} + +func (r *recorder) Result() *http.Response { + return &http.Response{ + Status: http.StatusText(r.statusCode), + StatusCode: r.statusCode, + Header: r.header.Clone(), + Body: io.NopCloser(r.buf), + ContentLength: int64(r.buf.Len()), + } +} + func contains(s []string, e string) bool { for _, v := range s { if e == v {