Skip to content

Commit

Permalink
Only log request body in debug loglevel
Browse files Browse the repository at this point in the history
Until now bramble has logged the entire body of each request. This has
been useful for debugging and seeing it in operation but it has a few
downsides:

- Entries can be huge
- Risk of sensitive information being leaked
- IO delay of creating large log events
  • Loading branch information
pkqk committed Nov 20, 2024
1 parent 63a1c40 commit fa39ca5
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 11 deletions.
35 changes: 32 additions & 3 deletions gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"strings"
Expand Down Expand Up @@ -67,6 +68,34 @@ func TestGatewayQuery(t *testing.T) {
assert.JSONEq(t, `{"data": { "test": "Hello" }}`, rec.Body.String())
}

func TestRequestNoBodyLoggingOnInfo(t *testing.T) {
server := NewGateway(NewExecutableSchema(nil, 50, nil), nil).Router(&Config{})

body := map[string]interface{}{
"foo": "bar",
}
jr, jw := io.Pipe()
go func() {
enc := json.NewEncoder(jw)
enc.Encode(body)
jw.Close()
}()
defer jr.Close()

req := httptest.NewRequest("POST", "/query", jr)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
obj := collectLogEvent(t, &slog.HandlerOptions{Level: slog.LevelInfo}, func() {
server.ServeHTTP(w, req)
})
resp := w.Result()

assert.NotNil(t, obj)
assert.Equal(t, float64(resp.StatusCode), obj["response.status"])
assert.Empty(t, obj["request.content-type"])
assert.Empty(t, obj["request.body"])
}

func TestRequestJSONBodyLogging(t *testing.T) {
server := NewGateway(NewExecutableSchema(nil, 50, nil), nil).Router(&Config{})

Expand All @@ -84,7 +113,7 @@ func TestRequestJSONBodyLogging(t *testing.T) {
req := httptest.NewRequest("POST", "/query", jr)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
obj := collectLogEvent(t, func() {
obj := collectLogEvent(t, &slog.HandlerOptions{Level: slog.LevelDebug}, func() {
server.ServeHTTP(w, req)
})
resp := w.Result()
Expand All @@ -110,7 +139,7 @@ func TestRequestInvalidJSONBodyLogging(t *testing.T) {
req := httptest.NewRequest("POST", "/query", jr)
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
obj := collectLogEvent(t, func() {
obj := collectLogEvent(t, &slog.HandlerOptions{Level: slog.LevelDebug}, func() {
server.ServeHTTP(w, req)
})
w.Result()
Expand All @@ -136,7 +165,7 @@ func TestRequestTextBodyLogging(t *testing.T) {
req := httptest.NewRequest("POST", "/query", jr)
req.Header.Set("Content-Type", "text/plain")
w := httptest.NewRecorder()
obj := collectLogEvent(t, func() {
obj := collectLogEvent(t, &slog.HandlerOptions{Level: slog.LevelDebug}, func() {
server.ServeHTTP(w, req)
})
w.Result()
Expand Down
4 changes: 4 additions & 0 deletions instrumentation.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ func (e *event) addFields(fields EventFields) {
e.fieldLock.Unlock()
}

func (e *event) debugEnabled() bool {
return log.Default().Enabled(context.Background(), log.LevelDebug)
}

func (e *event) finish() {
e.writeLock.Do(func() {
attrs := make([]any, 0, len(e.fields))
Expand Down
46 changes: 39 additions & 7 deletions instrumentation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ import (
// can only run one test at a time that takes over the log output
var logLock = sync.Mutex{}

func collectLogEvent(t *testing.T, f func()) map[string]interface{} {
func collectLogEvent(t *testing.T, o *log.HandlerOptions, f func()) map[string]interface{} {
t.Helper()
r, w := io.Pipe()
defer r.Close()
prevlogger := log.Default()
logLock.Lock()
defer logLock.Unlock()
log.SetDefault(log.New(log.NewJSONHandler(w, nil)))
log.SetDefault(log.New(log.NewJSONHandler(w, o)))
t.Cleanup(func() {
logLock.Lock()
defer logLock.Unlock()
Expand All @@ -41,9 +41,9 @@ func collectLogEvent(t *testing.T, f func()) map[string]interface{} {
return obj
}

func collectEventFromContext(ctx context.Context, t *testing.T, f func(*event)) map[string]interface{} {
func collectEventFromContext(ctx context.Context, t *testing.T, o *log.HandlerOptions, f func(*event)) map[string]interface{} {
t.Helper()
return collectLogEvent(t, func() {
return collectLogEvent(t, o, func() {
e := getEvent(ctx)
f(e)
if e != nil {
Expand All @@ -63,7 +63,7 @@ func TestDropsField(t *testing.T) {

func TestEventLogOnFinish(t *testing.T) {
ctx, _ := startEvent(context.TODO(), testEventName)
output := collectEventFromContext(ctx, t, func(*event) {
output := collectEventFromContext(ctx, t, nil, func(*event) {
AddField(ctx, "val", "test")
})

Expand All @@ -72,7 +72,7 @@ func TestEventLogOnFinish(t *testing.T) {

func TestAddMultipleToEventOnContext(t *testing.T) {
ctx, _ := startEvent(context.TODO(), testEventName)
output := collectEventFromContext(ctx, t, func(*event) {
output := collectEventFromContext(ctx, t, nil, func(*event) {
AddFields(ctx, EventFields{
"gizmo": "foo",
"gimmick": "bar",
Expand All @@ -86,7 +86,7 @@ func TestAddMultipleToEventOnContext(t *testing.T) {
func TestEventMeasurement(t *testing.T) {
start := time.Now()
ctx, _ := startEvent(context.TODO(), testEventName)
output := collectEventFromContext(ctx, t, func(*event) {
output := collectEventFromContext(ctx, t, nil, func(*event) {
time.Sleep(time.Microsecond)
})

Expand All @@ -105,3 +105,35 @@ func TestEventMeasurement(t *testing.T) {
assert.Fail(t, "missing duration")
}
}

func TestDebugDisabled(t *testing.T) {
ctx, _ := startEvent(context.TODO(), testEventName)

o := &log.HandlerOptions{
Level: log.LevelInfo,
}

output := collectEventFromContext(ctx, t, o, func(e *event) {
if e.debugEnabled() {
AddField(ctx, "val", "test")
}
})

assert.Empty(t, output["val"])
}

func TestDebugEnabled(t *testing.T) {
ctx, _ := startEvent(context.TODO(), testEventName)

o := &log.HandlerOptions{
Level: log.LevelDebug,
}

output := collectEventFromContext(ctx, t, o, func(e *event) {
if e.debugEnabled() {
AddField(ctx, "val", "test")
}
})

assert.Equal(t, "test", output["val"])
}
4 changes: 3 additions & 1 deletion middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ func monitoringMiddleware(h http.Handler) http.Handler {

r = r.WithContext(ctx)

addRequestBody(event, r, buf)
if event.debugEnabled() {
addRequestBody(event, r, buf)
}

m := httpsnoop.CaptureMetrics(h, w, r)

Expand Down

0 comments on commit fa39ca5

Please sign in to comment.