-
Notifications
You must be signed in to change notification settings - Fork 588
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Setup body wrapper in otelmux #6650
base: main
Are you sure you want to change the base?
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
// Copyright The OpenTelemetry Authors | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package request // import "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux/internal/request" | ||
|
||
import ( | ||
"io" | ||
"sync" | ||
) | ||
|
||
var _ io.ReadCloser = &BodyWrapper{} | ||
|
||
// BodyWrapper wraps a http.Request.Body (an io.ReadCloser) to track the number | ||
// of bytes read and the last error. | ||
type BodyWrapper struct { | ||
io.ReadCloser | ||
OnRead func(n int64) // must not be nil | ||
|
||
mu sync.Mutex | ||
read int64 | ||
err error | ||
} | ||
|
||
// NewBodyWrapper creates a new BodyWrapper. | ||
// | ||
// The onRead attribute is a callback that will be called every time the data | ||
// is read, with the number of bytes being read. | ||
func NewBodyWrapper(body io.ReadCloser, onRead func(int64)) *BodyWrapper { | ||
return &BodyWrapper{ | ||
ReadCloser: body, | ||
OnRead: onRead, | ||
} | ||
} | ||
|
||
// Read reads the data from the io.ReadCloser, and stores the number of bytes | ||
// read and the error. | ||
func (w *BodyWrapper) Read(b []byte) (int, error) { | ||
n, err := w.ReadCloser.Read(b) | ||
n1 := int64(n) | ||
|
||
w.updateReadData(n1, err) | ||
w.OnRead(n1) | ||
return n, err | ||
} | ||
|
||
func (w *BodyWrapper) updateReadData(n int64, err error) { | ||
w.mu.Lock() | ||
defer w.mu.Unlock() | ||
|
||
w.read += n | ||
if err != nil { | ||
w.err = err | ||
} | ||
} | ||
|
||
// Closes closes the io.ReadCloser. | ||
func (w *BodyWrapper) Close() error { | ||
return w.ReadCloser.Close() | ||
Check warning on line 58 in instrumentation/github.com/gorilla/mux/otelmux/internal/request/body_wrapper.go Codecov / codecov/patchinstrumentation/github.com/gorilla/mux/otelmux/internal/request/body_wrapper.go#L57-L58
|
||
} | ||
|
||
// BytesRead returns the number of bytes read up to this point. | ||
func (w *BodyWrapper) BytesRead() int64 { | ||
w.mu.Lock() | ||
defer w.mu.Unlock() | ||
|
||
return w.read | ||
} | ||
|
||
// Error returns the last error. | ||
func (w *BodyWrapper) Error() error { | ||
w.mu.Lock() | ||
defer w.mu.Unlock() | ||
|
||
return w.err | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
// Copyright The OpenTelemetry Authors | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package request | ||
|
||
import ( | ||
"errors" | ||
"io" | ||
"strings" | ||
"testing" | ||
"time" | ||
|
||
"github.com/stretchr/testify/assert" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
var errFirstCall = errors.New("first call") | ||
|
||
func TestBodyWrapper(t *testing.T) { | ||
bw := NewBodyWrapper(io.NopCloser(strings.NewReader("hello world")), func(int64) {}) | ||
|
||
data, err := io.ReadAll(bw) | ||
require.NoError(t, err) | ||
assert.Equal(t, "hello world", string(data)) | ||
|
||
assert.Equal(t, int64(11), bw.BytesRead()) | ||
assert.Equal(t, io.EOF, bw.Error()) | ||
} | ||
|
||
type multipleErrorsReader struct { | ||
calls int | ||
} | ||
|
||
type errorWrapper struct{} | ||
|
||
func (errorWrapper) Error() string { | ||
return "subsequent calls" | ||
} | ||
|
||
func (mer *multipleErrorsReader) Read([]byte) (int, error) { | ||
mer.calls = mer.calls + 1 | ||
if mer.calls == 1 { | ||
return 0, errFirstCall | ||
} | ||
|
||
return 0, errorWrapper{} | ||
} | ||
|
||
func TestBodyWrapperWithErrors(t *testing.T) { | ||
bw := NewBodyWrapper(io.NopCloser(&multipleErrorsReader{}), func(int64) {}) | ||
|
||
data, err := io.ReadAll(bw) | ||
require.Equal(t, errFirstCall, err) | ||
assert.Equal(t, "", string(data)) | ||
require.Equal(t, errFirstCall, bw.Error()) | ||
|
||
data, err = io.ReadAll(bw) | ||
require.Equal(t, errorWrapper{}, err) | ||
assert.Equal(t, "", string(data)) | ||
require.Equal(t, errorWrapper{}, bw.Error()) | ||
} | ||
|
||
func TestConcurrentBodyWrapper(t *testing.T) { | ||
bw := NewBodyWrapper(io.NopCloser(strings.NewReader("hello world")), func(int64) {}) | ||
|
||
go func() { | ||
_, _ = io.ReadAll(bw) | ||
}() | ||
|
||
assert.NotNil(t, bw.BytesRead()) | ||
assert.Eventually(t, func() bool { | ||
return errors.Is(bw.Error(), io.EOF) | ||
}, time.Second, 10*time.Millisecond) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
// Copyright The OpenTelemetry Authors | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package request // import "go.opentelemetry.io/contrib/instrumentation/github.com/gorilla/mux/otelmux/internal/request" | ||
|
||
import ( | ||
"net/http" | ||
"sync" | ||
) | ||
|
||
var _ http.ResponseWriter = &RespWriterWrapper{} | ||
|
||
// RespWriterWrapper wraps a http.ResponseWriter in order to track the number of | ||
// bytes written, the last error, and to catch the first written statusCode. | ||
// TODO: The wrapped http.ResponseWriter doesn't implement any of the optional | ||
// types (http.Hijacker, http.Pusher, http.CloseNotifier, etc) | ||
// that may be useful when using it in real life situations. | ||
type RespWriterWrapper struct { | ||
http.ResponseWriter | ||
OnWrite func(n int64) // must not be nil | ||
|
||
mu sync.RWMutex | ||
written int64 | ||
statusCode int | ||
err error | ||
wroteHeader bool | ||
} | ||
|
||
// NewRespWriterWrapper creates a new RespWriterWrapper. | ||
// | ||
// The onWrite attribute is a callback that will be called every time the data | ||
// is written, with the number of bytes that were written. | ||
func NewRespWriterWrapper(w http.ResponseWriter, onWrite func(int64)) *RespWriterWrapper { | ||
return &RespWriterWrapper{ | ||
ResponseWriter: w, | ||
OnWrite: onWrite, | ||
statusCode: http.StatusOK, // default status code in case the Handler doesn't write anything | ||
} | ||
} | ||
|
||
// Write writes the bytes array into the [ResponseWriter], and tracks the | ||
// number of bytes written and last error. | ||
func (w *RespWriterWrapper) Write(p []byte) (int, error) { | ||
w.mu.Lock() | ||
defer w.mu.Unlock() | ||
|
||
if !w.wroteHeader { | ||
w.writeHeader(http.StatusOK) | ||
} | ||
|
||
n, err := w.ResponseWriter.Write(p) | ||
Check warning Code scanning / CodeQL Reflected cross-site scripting Medium
Cross-site scripting vulnerability due to
user-provided value Error loading related location Loading |
||
n1 := int64(n) | ||
w.OnWrite(n1) | ||
w.written += n1 | ||
w.err = err | ||
return n, err | ||
} | ||
|
||
// WriteHeader persists initial statusCode for span attribution. | ||
// All calls to WriteHeader will be propagated to the underlying ResponseWriter | ||
// and will persist the statusCode from the first call. | ||
// Blocking consecutive calls to WriteHeader alters expected behavior and will | ||
// remove warning logs from net/http where developers will notice incorrect handler implementations. | ||
func (w *RespWriterWrapper) WriteHeader(statusCode int) { | ||
w.mu.Lock() | ||
defer w.mu.Unlock() | ||
|
||
w.writeHeader(statusCode) | ||
} | ||
|
||
// writeHeader persists the status code for span attribution, and propagates | ||
// the call to the underlying ResponseWriter. | ||
// It does not acquire a lock, and therefore assumes that is being handled by a | ||
// parent method. | ||
func (w *RespWriterWrapper) writeHeader(statusCode int) { | ||
if !w.wroteHeader { | ||
w.wroteHeader = true | ||
w.statusCode = statusCode | ||
} | ||
w.ResponseWriter.WriteHeader(statusCode) | ||
} | ||
|
||
// Flush implements [http.Flusher]. | ||
func (w *RespWriterWrapper) Flush() { | ||
w.mu.Lock() | ||
defer w.mu.Unlock() | ||
|
||
if !w.wroteHeader { | ||
w.writeHeader(http.StatusOK) | ||
} | ||
|
||
if f, ok := w.ResponseWriter.(http.Flusher); ok { | ||
f.Flush() | ||
} | ||
} | ||
|
||
// BytesWritten returns the number of bytes written. | ||
func (w *RespWriterWrapper) BytesWritten() int64 { | ||
w.mu.RLock() | ||
defer w.mu.RUnlock() | ||
|
||
return w.written | ||
} | ||
|
||
// BytesWritten returns the HTTP status code that was sent. | ||
func (w *RespWriterWrapper) StatusCode() int { | ||
w.mu.RLock() | ||
defer w.mu.RUnlock() | ||
|
||
return w.statusCode | ||
} | ||
|
||
// Error returns the last error. | ||
func (w *RespWriterWrapper) Error() error { | ||
w.mu.RLock() | ||
defer w.mu.RUnlock() | ||
|
||
return w.err | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
// Copyright The OpenTelemetry Authors | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
package request | ||
|
||
import ( | ||
"net/http" | ||
"net/http/httptest" | ||
"testing" | ||
|
||
"github.com/stretchr/testify/assert" | ||
) | ||
|
||
func TestRespWriterWriteHeader(t *testing.T) { | ||
rw := NewRespWriterWrapper(&httptest.ResponseRecorder{}, func(int64) {}) | ||
|
||
rw.WriteHeader(http.StatusTeapot) | ||
assert.Equal(t, http.StatusTeapot, rw.statusCode) | ||
assert.True(t, rw.wroteHeader) | ||
|
||
rw.WriteHeader(http.StatusGone) | ||
assert.Equal(t, http.StatusTeapot, rw.statusCode) | ||
} | ||
|
||
func TestRespWriterFlush(t *testing.T) { | ||
rw := NewRespWriterWrapper(&httptest.ResponseRecorder{}, func(int64) {}) | ||
|
||
rw.Flush() | ||
assert.Equal(t, http.StatusOK, rw.statusCode) | ||
assert.True(t, rw.wroteHeader) | ||
} | ||
|
||
type nonFlushableResponseWriter struct{} | ||
|
||
func (_ nonFlushableResponseWriter) Header() http.Header { | ||
return http.Header{} | ||
} | ||
|
||
func (_ nonFlushableResponseWriter) Write([]byte) (int, error) { | ||
return 0, nil | ||
} | ||
|
||
func (_ nonFlushableResponseWriter) WriteHeader(int) {} | ||
|
||
func TestRespWriterFlushNoFlusher(t *testing.T) { | ||
rw := NewRespWriterWrapper(nonFlushableResponseWriter{}, func(int64) {}) | ||
|
||
rw.Flush() | ||
assert.Equal(t, http.StatusOK, rw.statusCode) | ||
assert.True(t, rw.wroteHeader) | ||
} | ||
|
||
func TestConcurrentRespWriterWrapper(t *testing.T) { | ||
rw := NewRespWriterWrapper(&httptest.ResponseRecorder{}, func(int64) {}) | ||
|
||
go func() { | ||
_, _ = rw.Write([]byte("hello world")) | ||
}() | ||
|
||
assert.NotNil(t, rw.BytesWritten()) | ||
assert.NotNil(t, rw.StatusCode()) | ||
assert.NoError(t, rw.Error()) | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this still apply, since the way you wrap the default writer preserves the original implemented interfaces. Why would you want to implement them in RespWriterWrapper directly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note that this package is a copy from the otelhttp one.
https://github.com/open-telemetry/opentelemetry-go-contrib/tree/main/instrumentation/net/http/otelhttp/internal/request
Which makes me think maybe it should be templatized instead, to ensure things remain in sync.