diff --git a/middleware/cache/README.md b/middleware/cache/README.md index 10972e2f01..a5f3921a7a 100644 --- a/middleware/cache/README.md +++ b/middleware/cache/README.md @@ -125,6 +125,12 @@ type Config struct { // // Default: 0 MaxBytes uint + + // You can specify HTTP methods to cache. + // The middleware just caches the routes of its methods in this slice. + // + // Default: []string{fiber.MethodGet, fiber.MethodHead} + Methods []string } ``` @@ -144,5 +150,6 @@ var ConfigDefault = Config{ StoreResponseHeaders: false, Storage: nil, MaxBytes: 0, + Methods: []string{fiber.MethodGet, fiber.MethodHead}, } ``` diff --git a/middleware/cache/cache.go b/middleware/cache/cache.go index e761ff1731..a8485f450f 100644 --- a/middleware/cache/cache.go +++ b/middleware/cache/cache.go @@ -83,8 +83,15 @@ func New(config ...Config) fiber.Handler { // Return new handler return func(c *fiber.Ctx) error { - // Only cache GET and HEAD methods - if c.Method() != fiber.MethodGet && c.Method() != fiber.MethodHead { + // Only cache selected methods + var isExists bool + for _, method := range cfg.Methods { + if c.Method() == method { + isExists = true + } + } + + if !isExists { c.Set(cfg.CacheHeader, cacheUnreachable) return c.Next() } diff --git a/middleware/cache/cache_test.go b/middleware/cache/cache_test.go index f991d24b0a..78aeab4f71 100644 --- a/middleware/cache/cache_test.go +++ b/middleware/cache/cache_test.go @@ -173,7 +173,7 @@ func Test_Cache_Invalid_Expiration(t *testing.T) { utils.AssertEqual(t, cachedBody, body) } -func Test_Cache_Invalid_Method(t *testing.T) { +func Test_Cache_Get(t *testing.T) { t.Parallel() app := fiber.New() @@ -213,6 +213,48 @@ func Test_Cache_Invalid_Method(t *testing.T) { utils.AssertEqual(t, "123", string(body)) } +func Test_Cache_Post(t *testing.T) { + t.Parallel() + + app := fiber.New() + + app.Use(New(Config{ + Methods: []string{fiber.MethodPost}, + })) + + app.Post("/", func(c *fiber.Ctx) error { + return c.SendString(c.Query("cache")) + }) + + app.Get("/get", func(c *fiber.Ctx) error { + return c.SendString(c.Query("cache")) + }) + + resp, err := app.Test(httptest.NewRequest("POST", "/?cache=123", nil)) + utils.AssertEqual(t, nil, err) + body, err := ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "123", string(body)) + + resp, err = app.Test(httptest.NewRequest("POST", "/?cache=12345", nil)) + utils.AssertEqual(t, nil, err) + body, err = ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "123", string(body)) + + resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=123", nil)) + utils.AssertEqual(t, nil, err) + body, err = ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "123", string(body)) + + resp, err = app.Test(httptest.NewRequest("GET", "/get?cache=12345", nil)) + utils.AssertEqual(t, nil, err) + body, err = ioutil.ReadAll(resp.Body) + utils.AssertEqual(t, nil, err) + utils.AssertEqual(t, "12345", string(body)) +} + func Test_Cache_NothingToCache(t *testing.T) { t.Parallel() @@ -428,10 +470,12 @@ func Test_Cache_WithHead(t *testing.T) { req := httptest.NewRequest("HEAD", "/", nil) resp, err := app.Test(req) + utils.AssertEqual(t, nil, err) utils.AssertEqual(t, cacheMiss, resp.Header.Get("X-Cache")) cachedReq := httptest.NewRequest("HEAD", "/", nil) cachedResp, err := app.Test(cachedReq) + utils.AssertEqual(t, nil, err) utils.AssertEqual(t, cacheHit, cachedResp.Header.Get("X-Cache")) body, err := ioutil.ReadAll(resp.Body) diff --git a/middleware/cache/config.go b/middleware/cache/config.go index 625d1c478b..12f81e2ae8 100644 --- a/middleware/cache/config.go +++ b/middleware/cache/config.go @@ -66,6 +66,12 @@ type Config struct { // // Default: 0 MaxBytes uint + + // You can specify HTTP methods to cache. + // The middleware just caches the routes of its methods in this slice. + // + // Default: []string{fiber.MethodGet, fiber.MethodHead} + Methods []string } // ConfigDefault is the default config @@ -81,6 +87,7 @@ var ConfigDefault = Config{ StoreResponseHeaders: false, Storage: nil, MaxBytes: 0, + Methods: []string{fiber.MethodGet, fiber.MethodHead}, } // Helper function to set default values @@ -114,5 +121,8 @@ func configDefault(config ...Config) Config { if cfg.KeyGenerator == nil { cfg.KeyGenerator = ConfigDefault.KeyGenerator } + if len(cfg.Methods) == 0 { + cfg.Methods = ConfigDefault.Methods + } return cfg }