diff --git a/pkg/logging/http.go b/pkg/logging/http.go index 2e6d6a9d0df..2409fb6b158 100644 --- a/pkg/logging/http.go +++ b/pkg/logging/http.go @@ -39,7 +39,11 @@ func (m *HTTPServerMiddleware) HTTPMiddleware(name string, next http.Handler) ht return func(w http.ResponseWriter, r *http.Request) { wrapped := httputil.WrapResponseWriterWithStatus(w) start := time.Now() - _, port, err := net.SplitHostPort(r.Host) + hostPort := r.Host + if hostPort == "" { + hostPort = r.URL.Host + } + _, port, err := net.SplitHostPort(hostPort) if err != nil { level.Error(m.logger).Log("msg", "failed to parse host port for http log decision", "err", err) next.ServeHTTP(w, r) diff --git a/pkg/logging/http_test.go b/pkg/logging/http_test.go index 73e7b0225d3..152c08dc090 100644 --- a/pkg/logging/http_test.go +++ b/pkg/logging/http_test.go @@ -4,10 +4,13 @@ package logging import ( + "bytes" "io" "io/ioutil" "net/http" "net/http/httptest" + "net/url" + "strings" "testing" "github.com/go-kit/kit/log" @@ -15,23 +18,50 @@ import ( ) func TestHTTPServerMiddleware(t *testing.T) { - m := NewHTTPServerMiddleware(log.NewNopLogger()) + b := bytes.Buffer{} + + m := NewHTTPServerMiddleware(log.NewLogfmtLogger(io.Writer(&b))) handler := func(w http.ResponseWriter, r *http.Request) { _, err := io.WriteString(w, "Test Works") if err != nil { - t.Log(err) + testutil.Ok(t, err) } } hm := m.HTTPMiddleware("test", http.HandlerFunc(handler)) - req := httptest.NewRequest("GET", "http://example.com/foo", nil) + // Cortex way: + u, err := url.Parse("http://example.com:5555/foo") + testutil.Ok(t, err) + req := &http.Request{ + Method: "GET", + URL: u, + Body: nil, + } + w := httptest.NewRecorder() hm(w, req) resp := w.Result() - body, _ := ioutil.ReadAll(resp.Body) + body, err := ioutil.ReadAll(resp.Body) + testutil.Ok(t, err) + + testutil.Equals(t, 200, resp.StatusCode) + testutil.Equals(t, "Test Works", string(body)) + testutil.Assert(t, !strings.Contains(b.String(), "err=")) + + // Typical way: + req = httptest.NewRequest("GET", "http://example.com:5555/foo", nil) + b.Reset() + + w = httptest.NewRecorder() + hm(w, req) + + resp = w.Result() + body, err = ioutil.ReadAll(resp.Body) + testutil.Ok(t, err) testutil.Equals(t, 200, resp.StatusCode) testutil.Equals(t, "Test Works", string(body)) + testutil.Assert(t, !strings.Contains(b.String(), "err=")) }