Skip to content

Commit

Permalink
✨ feat(cors): Added new 'AllowOriginsFunc' function. (#2394)
Browse files Browse the repository at this point in the history
* ✨ feat(cors): Added new 'AllowOriginsFunc' function.

* feat(cors): Added warning log for when both 'AllowOrigins' and 'AllowOriginsFunc' are set.

* feat(docs): Updated docs to include note about discouraging the use of this function in production workloads.

---------

Co-authored-by: RW <[email protected]>
  • Loading branch information
Jamess-Lucass and ReneWerner87 authored Apr 11, 2023
1 parent fcf708d commit 866d5b7
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 43 deletions.
108 changes: 67 additions & 41 deletions docs/api/middleware/cors.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,52 +35,77 @@ app.Use(cors.New(cors.Config{
}))
```

Using the `AllowOriginsFunc` function. In this example any origin will be allowed via CORS.

For example, if a browser running on `http://localhost:3000` sends a request, this will be accepted and the `access-control-allow-origin` response header will be set to `http://localhost:3000`.

**Note: Using this feature is discouraged in production and it's best practice to explicitly set CORS origins via `AllowOrigins`.**

```go
app.Use(cors.New())

app.Use(cors.New(cors.Config{
AllowOriginsFunc: func(origin string) bool {
return os.Getenv("ENVIRONMENT") == "development"
},
}))
```

## Config

```go
// Config defines the config for middleware.
type Config struct {
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool

// AllowOrigin defines a list of origins that may access the resource.
//
// Optional. Default value "*"
AllowOrigins string

// AllowMethods defines a list of methods allowed when accessing the resource.
// This is used in response to a preflight request.
//
// Optional. Default value "GET,POST,HEAD,PUT,DELETE,PATCH"
AllowMethods string

// AllowHeaders defines a list of request headers that can be used when
// making the actual request. This is in response to a preflight request.
//
// Optional. Default value "".
AllowHeaders string

// AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials.
//
// Optional. Default value false.
AllowCredentials bool

// ExposeHeaders defines a whitelist headers that clients are allowed to
// access.
//
// Optional. Default value "".
ExposeHeaders string

// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached.
//
// Optional. Default value 0.
MaxAge int
// Next defines a function to skip this middleware when returned true.
//
// Optional. Default: nil
Next func(c *fiber.Ctx) bool

// AllowOriginsFunc defines a function that will set the 'access-control-allow-origin'
// response header to the 'origin' request header when returned true.
//
// Note: Using this feature is discouraged in production and it's best practice to explicitly
// set CORS origins via 'AllowOrigins'
//
// Optional. Default: nil
AllowOriginsFunc func(origin string) bool

// AllowOrigin defines a list of origins that may access the resource.
//
// Optional. Default value "*"
AllowOrigins string

// AllowMethods defines a list methods allowed when accessing the resource.
// This is used in response to a preflight request.
//
// Optional. Default value "GET,POST,HEAD,PUT,DELETE,PATCH"
AllowMethods string

// AllowHeaders defines a list of request headers that can be used when
// making the actual request. This is in response to a preflight request.
//
// Optional. Default value "".
AllowHeaders string

// AllowCredentials indicates whether or not the response to the request
// can be exposed when the credentials flag is true. When used as part of
// a response to a preflight request, this indicates whether or not the
// actual request can be made using credentials.
//
// Optional. Default value false.
AllowCredentials bool

// ExposeHeaders defines a whitelist headers that clients are allowed to
// access.
//
// Optional. Default value "".
ExposeHeaders string

// MaxAge indicates how long (in seconds) the results of a preflight request
// can be cached.
//
// Optional. Default value 0.
MaxAge int
}
```

Expand All @@ -89,6 +114,7 @@ type Config struct {
```go
var ConfigDefault = Config{
Next: nil,
AllowOriginsFunc: nil,
AllowOrigins: "*",
AllowMethods: strings.Join([]string{
fiber.MethodGet,
Expand Down
26 changes: 24 additions & 2 deletions middleware/cors/cors.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cors

import (
"log"
"strconv"
"strings"

Expand All @@ -14,6 +15,12 @@ type Config struct {
// Optional. Default: nil
Next func(c *fiber.Ctx) bool

// AllowOriginsFunc defines a function that will set the 'access-control-allow-origin'
// response header to the 'origin' request header when returned true.
//
// Optional. Default: nil
AllowOriginsFunc func(origin string) bool

// AllowOrigin defines a list of origins that may access the resource.
//
// Optional. Default value "*"
Expand Down Expand Up @@ -54,8 +61,9 @@ type Config struct {

// ConfigDefault is the default config
var ConfigDefault = Config{
Next: nil,
AllowOrigins: "*",
Next: nil,
AllowOriginsFunc: nil,
AllowOrigins: "*",
AllowMethods: strings.Join([]string{
fiber.MethodGet,
fiber.MethodPost,
Expand Down Expand Up @@ -88,6 +96,11 @@ func New(config ...Config) fiber.Handler {
}
}

// Warning logs if both AllowOrigins and AllowOriginsFunc are set
if cfg.AllowOrigins != "" && cfg.AllowOriginsFunc != nil {
log.Printf("[CORS] - [Warning] Both 'AllowOrigins' and 'AllowOriginsFunc' have been defined.\n")
}

// Convert string to slice
allowOrigins := strings.Split(strings.ReplaceAll(cfg.AllowOrigins, " ", ""), ",")

Expand Down Expand Up @@ -126,6 +139,15 @@ func New(config ...Config) fiber.Handler {
}
}

// Run AllowOriginsFunc if the logic for
// handling the value in 'AllowOrigins' does
// not result in allowOrigin being set.
if allowOrigin == "" && cfg.AllowOriginsFunc != nil {
if cfg.AllowOriginsFunc(origin) {
allowOrigin = origin
}
}

// Simple request
if c.Method() != fiber.MethodOptions {
c.Vary(fiber.HeaderOrigin)
Expand Down
52 changes: 52 additions & 0 deletions middleware/cors/cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cors

import (
"net/http/httptest"
"strings"
"testing"

"github.com/gofiber/fiber/v2"
Expand Down Expand Up @@ -242,3 +243,54 @@ func Test_CORS_Next(t *testing.T) {
utils.AssertEqual(t, nil, err)
utils.AssertEqual(t, fiber.StatusNotFound, resp.StatusCode)
}

func Test_CORS_AllowOriginsFunc(t *testing.T) {
t.Parallel()
// New fiber instance
app := fiber.New()
app.Use("/", New(Config{
AllowOrigins: "http://example-1.com",
AllowOriginsFunc: func(origin string) bool {
return strings.Contains(origin, "example-2")
},
}))

// Get handler pointer
handler := app.Handler()

// Make request with disallowed origin
ctx := &fasthttp.RequestCtx{}
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://google.com")

// Perform request
handler(ctx)

// Allow-Origin header should be "" because http://google.com does not satisfy http://*.example.com
utils.AssertEqual(t, "", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))

ctx.Request.Reset()
ctx.Response.Reset()

// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-1.com")

handler(ctx)

utils.AssertEqual(t, "http://example-1.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))

ctx.Request.Reset()
ctx.Response.Reset()

// Make request with allowed origin
ctx.Request.SetRequestURI("/")
ctx.Request.Header.SetMethod(fiber.MethodOptions)
ctx.Request.Header.Set(fiber.HeaderOrigin, "http://example-2.com")

handler(ctx)

utils.AssertEqual(t, "http://example-2.com", string(ctx.Response.Header.Peek(fiber.HeaderAccessControlAllowOrigin)))
}

0 comments on commit 866d5b7

Please sign in to comment.