Skip to content

Commit

Permalink
distinguish between explicit and implicit star (#118)
Browse files Browse the repository at this point in the history
  • Loading branch information
cainlevy authored and elithrar committed Apr 16, 2018
1 parent 9066371 commit c5874fa
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 28 deletions.
24 changes: 16 additions & 8 deletions cors.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,17 @@ func (ch *cors) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

returnOrigin := origin
for _, o := range ch.allowedOrigins {
// A configuration of * is different than explicitly setting an allowed
// origin. Returning arbitrary origin headers an an access control allow
// origin header is unsafe and is not required by any use case.
if o == corsOriginMatchAll {
returnOrigin = "*"
break
if ch.allowedOriginValidator == nil && len(ch.allowedOrigins) == 0 {
returnOrigin = "*"
} else {
for _, o := range ch.allowedOrigins {
// A configuration of * is different than explicitly setting an allowed
// origin. Returning arbitrary origin headers an an access control allow
// origin header is unsafe and is not required by any use case.
if o == corsOriginMatchAll {
returnOrigin = "*"
break
}
}
}
w.Header().Set(corsAllowOriginHeader, returnOrigin)
Expand Down Expand Up @@ -159,7 +163,7 @@ func parseCORSOptions(opts ...CORSOption) *cors {
ch := &cors{
allowedMethods: defaultCorsMethods,
allowedHeaders: defaultCorsHeaders,
allowedOrigins: []string{corsOriginMatchAll},
allowedOrigins: []string{},
}

for _, option := range opts {
Expand Down Expand Up @@ -307,6 +311,10 @@ func (ch *cors) isOriginAllowed(origin string) bool {
return ch.allowedOriginValidator(origin)
}

if len(ch.allowedOrigins) == 0 {
return true
}

for _, allowedOrigin := range ch.allowedOrigins {
if allowedOrigin == origin || allowedOrigin == corsOriginMatchAll {
return true
Expand Down
42 changes: 22 additions & 20 deletions cors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func TestCORSWithMultipleHandlers(t *testing.T) {
}
}

func TestCORSHandlerWithCustomValidator(t *testing.T) {
func TestCORSOriginValidatorWithImplicitStar(t *testing.T) {
r := newRequest("GET", "http://a.example.com")
r.Header.Set("Origin", r.URL.String())
rr := httptest.NewRecorder()
Expand All @@ -327,45 +327,47 @@ func TestCORSHandlerWithCustomValidator(t *testing.T) {
return false
}

// Specially craft a CORS object.
handleFunc := func(h http.Handler) http.Handler {
c := &cors{
allowedMethods: defaultCorsMethods,
allowedHeaders: defaultCorsHeaders,
allowedOrigins: []string{"http://a.example.com"},
h: h,
}
AllowedOriginValidator(originValidator)(c)
return c
}

handleFunc(testHandler).ServeHTTP(rr, r)
CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r)
header := rr.HeaderMap.Get(corsAllowOriginHeader)
if header != r.URL.String() {
t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, r.URL.String(), header)
}

}

func TestCORSAllowStar(t *testing.T) {
func TestCORSOriginValidatorWithExplicitStar(t *testing.T) {
r := newRequest("GET", "http://a.example.com")
r.Header.Set("Origin", r.URL.String())
rr := httptest.NewRecorder()

testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

originValidator := func(origin string) bool {
if strings.HasSuffix(origin, ".example.com") {
return true
}
return false
}

CORS(AllowedOriginValidator(originValidator))(testHandler).ServeHTTP(rr, r)
CORS(
AllowedOriginValidator(originValidator),
AllowedOrigins([]string{"*"}),
)(testHandler).ServeHTTP(rr, r)
header := rr.HeaderMap.Get(corsAllowOriginHeader)
// Because * is the default CORS policy (which is safe), we should be
// expect a * returned here as the Access Control Allow Origin header
if header != "*" {
t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, r.URL.String(), header)
t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, "*", header)
}
}

func TestCORSAllowStar(t *testing.T) {
r := newRequest("GET", "http://a.example.com")
r.Header.Set("Origin", r.URL.String())
rr := httptest.NewRecorder()

testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})

CORS()(testHandler).ServeHTTP(rr, r)
header := rr.HeaderMap.Get(corsAllowOriginHeader)
if header != "*" {
t.Fatalf("bad header: expected %s to be %s, got %s.", corsAllowOriginHeader, "*", header)
}
}

0 comments on commit c5874fa

Please sign in to comment.