Skip to content

Commit

Permalink
Merge pull request #64 from 2manymws/header-names-to-mask
Browse files Browse the repository at this point in the history
Support for setting header names to mask
  • Loading branch information
k1LoW authored Jan 30, 2024
2 parents 56f4ada + cceee10 commit 46d88ac
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 35 deletions.
31 changes: 27 additions & 4 deletions internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,16 @@ func TestDuplicateRequest(t *testing.T) {

func TestMaskHeader(t *testing.T) {
tests := []struct {
h http.Header
want http.Header
h http.Header
headerNamesToMask []string
want http.Header
}{
{nil, nil},
{nil, nil, nil},
{
http.Header{
"Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"},
},
defaultHeaderNamesToMask,
http.Header{
"Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"},
},
Expand All @@ -71,6 +73,7 @@ func TestMaskHeader(t *testing.T) {
"Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"},
"Set-Cookie": []string{"session=1234; Path=/; Expires=Wed, 09 Jun 2021 10:18:14 GMT; HttpOnly"},
},
defaultHeaderNamesToMask,
http.Header{
"Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"},
"Set-Cookie": []string{"*****"},
Expand All @@ -84,6 +87,7 @@ func TestMaskHeader(t *testing.T) {
"session=5678; Path=/; Expires=Wed, 09 Jun 2021 10:18:14 GMT; HttpOnly",
},
},
defaultHeaderNamesToMask,
http.Header{
"Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"},
"Set-Cookie": []string{"*****"},
Expand All @@ -95,6 +99,7 @@ func TestMaskHeader(t *testing.T) {
"Set-Cookie": []string{"session=1234; Path=/; Expires=Wed, 09 Jun 2021 10:18:14 GMT; HttpOnly"},
"Authorization": []string{"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."},
},
defaultHeaderNamesToMask,
http.Header{
"Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"},
"Set-Cookie": []string{"*****"},
Expand All @@ -110,16 +115,34 @@ func TestMaskHeader(t *testing.T) {
"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.",
},
},
defaultHeaderNamesToMask,
http.Header{
"Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"},
"Set-Cookie": []string{"*****"},
"Authorization": []string{"*****"},
},
},
{
http.Header{
"Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"},
"Set-Cookie": []string{"session=1234; Path=/; Expires=Wed, 09 Jun 2021 10:18:14 GMT; HttpOnly"},
"Authorization": []string{"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."},
},
nil,
http.Header{
"Date": []string{"Mon, 02 Jan 2006 15:04:05 GMT"},
"Set-Cookie": []string{"session=1234; Path=/; Expires=Wed, 09 Jun 2021 10:18:14 GMT; HttpOnly"},
"Authorization": []string{"Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9."},
},
},
}
for i, tt := range tests {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
got := maskHeader(tt.h)
m := &cacheMw{
headerNamesToMask: defaultHeaderNamesToMask,
}
HeaderNamesToMask(tt.headerNamesToMask)(m)
got := m.maskHeader(tt.h)
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("got %v want %v", got, tt.want)
}
Expand Down
76 changes: 45 additions & 31 deletions rc.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ import (
"github.com/2manymws/rc/rfc9111"
)

var defaultHeaderNamesToMask = []string{
"Authorization",
"Cookie",
"Set-Cookie",
}

type Cacher interface {
// Load loads the request/response cache.
// If the cache is not found, it returns ErrCacheNotFound.
Expand Down Expand Up @@ -56,15 +62,17 @@ func newCacher(c Cacher) *cacher {
}

type cacheMw struct {
cacher *cacher
useRequestBody bool
logger *slog.Logger
cacher *cacher
useRequestBody bool
logger *slog.Logger
headerNamesToMask []string
}

func newCacheMw(c Cacher, opts ...Option) *cacheMw {
cc := newCacher(c)
m := &cacheMw{
cacher: cc,
cacher: cc,
headerNamesToMask: defaultHeaderNamesToMask,
}
for _, opt := range opts {
opt(m)
Expand All @@ -86,18 +94,18 @@ func (m *cacheMw) Handler(next http.Handler) http.Handler {
if err != nil {
switch {
case errors.Is(err, ErrCacheNotFound):
m.logger.Debug("cache not found", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)))
m.logger.Debug("cache not found", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
case errors.Is(err, ErrCacheExpired):
m.logger.Debug("cache expired", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)))
m.logger.Debug("cache expired", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
case errors.Is(err, ErrShouldNotUseCache):
m.logger.Debug("should not use cache", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)))
m.logger.Debug("should not use cache", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
default:
m.logger.Error("failed to load cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)))
m.logger.Error("failed to load cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
}
}
cacheUsed, res, err := m.cacher.Handle(req, cachedReq, cachedRes, HandlerToRequester(next), now) //nostyle:handlerrors
if err != nil {
m.logger.Error("failed to handle cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)))
m.logger.Error("failed to handle cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)))
}

// Response
Expand All @@ -115,37 +123,37 @@ func (m *cacheMw) Handler(next http.Handler) http.Handler {
w.WriteHeader(res.StatusCode)
body, err := io.ReadAll(res.Body)
if err != nil {
m.logger.Error("failed to read response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", maskHeader(res.Header)))
m.logger.Error("failed to read response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
} else {
if _, err := w.Write(body); err != nil {
if errors.Is(err, http.ErrBodyNotAllowed) {
m.logger.Warn("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", maskHeader(res.Header)))
m.logger.Warn("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
} else {
m.logger.Error("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", maskHeader(res.Header)))
m.logger.Error("failed to write response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
}
}
}
if err := res.Body.Close(); err != nil {
m.logger.Error("failed to close response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", maskHeader(res.Header)))
m.logger.Error("failed to close response body", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
}

if cacheUsed {
m.logger.Debug("cache used", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode))
m.logger.Debug("cache used", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode))
return
}
ok, expires := m.cacher.Storable(preq, res, now)
if !ok {
m.logger.Debug("cache not storable", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", maskHeader(res.Header)))
m.logger.Debug("cache not storable", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode), slog.Any("response_headers", m.maskHeader(res.Header)))
return
}
// Restore response body
res.Body = io.NopCloser(bytes.NewReader(body))

// Store response as cache
if err := m.cacher.Store(preq, res, expires); err != nil {
m.logger.Error("failed to store cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode))
m.logger.Error("failed to store cache", slog.String("error", err.Error()), slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode))
}
m.logger.Debug("cache stored", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", maskHeader(preq.Header)), slog.Int("status", res.StatusCode))
m.logger.Debug("cache stored", slog.String("host", preq.Host), slog.String("method", preq.Method), slog.String("url", preq.URL.String()), slog.Any("headers", m.maskHeader(preq.Header)), slog.Int("status", res.StatusCode))
})
}

Expand All @@ -158,16 +166,27 @@ func (m *cacheMw) duplicateRequest(req *http.Request) (*http.Request, *http.Requ
}
b, err := io.ReadAll(copy.Body)
if err != nil {
m.logger.Error("failed to read request body", slog.String("error", err.Error()), slog.String("host", copy.Host), slog.String("method", copy.Method), slog.Any("headers", maskHeader(copy.Header)), slog.String("url", copy.URL.String()))
m.logger.Error("failed to read request body", slog.String("error", err.Error()), slog.String("host", copy.Host), slog.String("method", copy.Method), slog.Any("headers", m.maskHeader(copy.Header)), slog.String("url", copy.URL.String()))
}
if err := copy.Body.Close(); err != nil {
m.logger.Error("failed to close request body", slog.String("error", err.Error()), slog.String("host", copy.Host), slog.String("method", copy.Method), slog.Any("headers", maskHeader(copy.Header)), slog.String("url", copy.URL.String()))
m.logger.Error("failed to close request body", slog.String("error", err.Error()), slog.String("host", copy.Host), slog.String("method", copy.Method), slog.Any("headers", m.maskHeader(copy.Header)), slog.String("url", copy.URL.String()))
}
req.Body = io.NopCloser(bytes.NewReader(b))
copy.Body = io.NopCloser(bytes.NewReader(b))
return copy, req
}

func (m *cacheMw) maskHeader(h http.Header) http.Header {
const masked = "*****"
c := h.Clone()
for _, n := range m.headerNamesToMask {
if c.Get(n) != "" {
c.Set(n, masked)
}
}
return c
}

type Option func(*cacheMw)

// WithLogger sets logger (slog.Logger).
Expand All @@ -184,6 +203,13 @@ func UseRequestBody() Option {
}
}

// HeaderNamesToMask sets header names to mask.
func HeaderNamesToMask(names []string) Option {
return func(m *cacheMw) {
m.headerNamesToMask = names
}
}

// New returns a new response cache middleware.
func New(cacher Cacher, opts ...Option) func(next http.Handler) http.Handler {
rl := newCacheMw(cacher, opts...)
Expand All @@ -200,15 +226,3 @@ func HandlerToRequester(h http.Handler) func(*http.Request) (*http.Response, err
return res, nil
}
}

func maskHeader(h http.Header) http.Header {
const masked = "*****"
c := h.Clone()
if c.Get("Set-Cookie") != "" {
c.Set("Set-Cookie", masked)
}
if c.Get("Authorization") != "" {
c.Set("Authorization", masked)
}
return c
}

0 comments on commit 46d88ac

Please sign in to comment.