From cceee10e70acbdfb5a44be05210dfc9f4ef4f500 Mon Sep 17 00:00:00 2001 From: k1LoW Date: Wed, 31 Jan 2024 06:55:28 +0900 Subject: [PATCH] Support for setting header names to mask --- internal_test.go | 31 +++++++++++++++++--- rc.go | 76 ++++++++++++++++++++++++++++-------------------- 2 files changed, 72 insertions(+), 35 deletions(-) diff --git a/internal_test.go b/internal_test.go index f4e0b3f..2e7994b 100644 --- a/internal_test.go +++ b/internal_test.go @@ -54,14 +54,16 @@ func TestDuplicateRequest(t *testing.T) { func TestMaskHeader(t *testing.T) { tests := []struct { - h http.Header - want http.Header + h http.Header + headerNamesToMask []string + want http.Header }{ - {nil, nil}, + {nil, nil, nil}, { http.Header{ "Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"}, }, + defaultHeaderNamesToMask, http.Header{ "Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"}, }, @@ -71,6 +73,7 @@ func TestMaskHeader(t *testing.T) { "Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"}, "Set-Cookie": []string{"session=1234; Path=/; Expires=Wed, 09 Jun 2021 10:18:14 GMT; HttpOnly"}, }, + defaultHeaderNamesToMask, http.Header{ "Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"}, "Set-Cookie": []string{"*****"}, @@ -84,6 +87,7 @@ func TestMaskHeader(t *testing.T) { "session=5678; Path=/; Expires=Wed, 09 Jun 2021 10:18:14 GMT; HttpOnly", }, }, + defaultHeaderNamesToMask, http.Header{ "Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"}, "Set-Cookie": []string{"*****"}, @@ -95,6 +99,7 @@ func TestMaskHeader(t *testing.T) { "Set-Cookie": []string{"session=1234; Path=/; Expires=Wed, 09 Jun 2021 10:18:14 GMT; HttpOnly"}, "Authorization": []string{"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."}, }, + defaultHeaderNamesToMask, http.Header{ "Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"}, "Set-Cookie": []string{"*****"}, @@ -110,16 +115,34 @@ func TestMaskHeader(t *testing.T) { "Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.", }, }, + defaultHeaderNamesToMask, http.Header{ "Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"}, "Set-Cookie": []string{"*****"}, "Authorization": []string{"*****"}, }, }, + { + http.Header{ + "Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"}, + "Set-Cookie": []string{"session=1234; Path=/; Expires=Wed, 09 Jun 2021 10:18:14 GMT; HttpOnly"}, + "Authorization": []string{"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."}, + }, + nil, + http.Header{ + "Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"}, + "Set-Cookie": []string{"session=1234; Path=/; Expires=Wed, 09 Jun 2021 10:18:14 GMT; HttpOnly"}, + "Authorization": []string{"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."}, + }, + }, } for i, tt := range tests { t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { - got := maskHeader(tt.h) + m := &cacheMw{ + headerNamesToMask: defaultHeaderNamesToMask, + } + HeaderNamesToMask(tt.headerNamesToMask)(m) + got := m.maskHeader(tt.h) if !reflect.DeepEqual(got, tt.want) { t.Errorf("got %v want %v", got, tt.want) } diff --git a/rc.go b/rc.go index 56e7b9f..258312f 100644 --- a/rc.go +++ b/rc.go @@ -12,6 +12,12 @@ import ( "github.com/2manymws/rc/rfc9111" ) +var defaultHeaderNamesToMask = []string{ + "Authorization", + "Cookie", + "Set-Cookie", +} + type Cacher interface { // Load loads the request/response cache. // If the cache is not found, it returns ErrCacheNotFound. @@ -56,15 +62,17 @@ func newCacher(c Cacher) *cacher { } type cacheMw struct { - cacher *cacher - useRequestBody bool - logger *slog.Logger + cacher *cacher + useRequestBody bool + logger *slog.Logger + headerNamesToMask []string } func newCacheMw(c Cacher, opts ...Option) *cacheMw { cc := newCacher(c) m := &cacheMw{ - cacher: cc, + cacher: cc, + headerNamesToMask: defaultHeaderNamesToMask, } for _, opt := range opts { opt(m) @@ -86,18 +94,18 @@ func (m *cacheMw) Handler(next http.Handler) http.Handler { if err != nil { switch { case errors.Is(err, ErrCacheNotFound): - m.logger.Debug("cache not found", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header))) + m.logger.Debug("cache not found", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header))) case errors.Is(err, ErrCacheExpired): - m.logger.Debug("cache expired", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header))) + m.logger.Debug("cache expired", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header))) case errors.Is(err, ErrShouldNotUseCache): - m.logger.Debug("should not use cache", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header))) + m.logger.Debug("should not use cache", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header))) default: - m.logger.Error("failed to load cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header))) + m.logger.Error("failed to load cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header))) } } cacheUsed, res, err := m.cacher.Handle(req, cachedReq, cachedRes, HandlerToRequester(next), now) //nostyle:handlerrors if err != nil { - m.logger.Error("failed to handle cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header))) + m.logger.Error("failed to handle cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header))) } // Response @@ -115,27 +123,27 @@ func (m *cacheMw) Handler(next http.Handler) http.Handler { w.WriteHeader(res.StatusCode) body, err := io.ReadAll(res.Body) if err != nil { - m.logger.Error("failed to read response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", maskHeader(res.Header))) + m.logger.Error("failed to read response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header))) } else { if _, err := w.Write(body); err != nil { if errors.Is(err, http.ErrBodyNotAllowed) { - m.logger.Warn("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", maskHeader(res.Header))) + m.logger.Warn("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header))) } else { - m.logger.Error("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", maskHeader(res.Header))) + m.logger.Error("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header))) } } } if err := res.Body.Close(); err != nil { - m.logger.Error("failed to close response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", maskHeader(res.Header))) + m.logger.Error("failed to close response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header))) } if cacheUsed { - m.logger.Debug("cache used", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode)) + m.logger.Debug("cache used", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode)) return } ok, expires := m.cacher.Storable(preq, res, now) if !ok { - m.logger.Debug("cache not storable", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", maskHeader(res.Header))) + m.logger.Debug("cache not storable", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header))) return } // Restore response body @@ -143,9 +151,9 @@ func (m *cacheMw) Handler(next http.Handler) http.Handler { // Store response as cache if err := m.cacher.Store(preq, res, expires); err != nil { - m.logger.Error("failed to store cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode)) + m.logger.Error("failed to store cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode)) } - m.logger.Debug("cache stored", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode)) + m.logger.Debug("cache stored", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode)) }) } @@ -158,16 +166,27 @@ func (m *cacheMw) duplicateRequest(req *http.Request) (*http.Request, *http.Requ } b, err := io.ReadAll(copy.Body) if err != nil { - m.logger.Error("failed to read request body", slog.String("error", err.Error()), slog.String("host", copy.Host), slog.String("method", copy.Method), slog.Any("headers", maskHeader(copy.Header)), slog.String("url", copy.URL.String())) + m.logger.Error("failed to read request body", slog.String("error", err.Error()), slog.String("host", copy.Host), slog.String("method", copy.Method), slog.Any("headers", m.maskHeader(copy.Header)), slog.String("url", copy.URL.String())) } if err := copy.Body.Close(); err != nil { - m.logger.Error("failed to close request body", slog.String("error", err.Error()), slog.String("host", copy.Host), slog.String("method", copy.Method), slog.Any("headers", maskHeader(copy.Header)), slog.String("url", copy.URL.String())) + m.logger.Error("failed to close request body", slog.String("error", err.Error()), slog.String("host", copy.Host), slog.String("method", copy.Method), slog.Any("headers", m.maskHeader(copy.Header)), slog.String("url", copy.URL.String())) } req.Body = io.NopCloser(bytes.NewReader(b)) copy.Body = io.NopCloser(bytes.NewReader(b)) return copy, req } +func (m *cacheMw) maskHeader(h http.Header) http.Header { + const masked = "*****" + c := h.Clone() + for _, n := range m.headerNamesToMask { + if c.Get(n) != "" { + c.Set(n, masked) + } + } + return c +} + type Option func(*cacheMw) // WithLogger sets logger (slog.Logger). @@ -184,6 +203,13 @@ func UseRequestBody() Option { } } +// HeaderNamesToMask sets header names to mask. +func HeaderNamesToMask(names []string) Option { + return func(m *cacheMw) { + m.headerNamesToMask = names + } +} + // New returns a new response cache middleware. func New(cacher Cacher, opts ...Option) func(next http.Handler) http.Handler { rl := newCacheMw(cacher, opts...) @@ -200,15 +226,3 @@ func HandlerToRequester(h http.Handler) func(*http.Request) (*http.Response, err return res, nil } } - -func maskHeader(h http.Header) http.Header { - const masked = "*****" - c := h.Clone() - if c.Get("Set-Cookie") != "" { - c.Set("Set-Cookie", masked) - } - if c.Get("Authorization") != "" { - c.Set("Authorization", masked) - } - return c -}