Skip to content

Commit

Permalink
Move validation of redirect to separate function, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Антон Костенко authored and stephen committed Oct 29, 2018
1 parent b094daf commit 9691907
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 29 deletions.
59 changes: 30 additions & 29 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,20 +465,10 @@ func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error)
if err != nil {
return
}
redirect = req.Form.Get("rd")
if p.allowedURL != "" {
matched, err := regexp.MatchString(p.allowedURL, redirect)
if err != nil {
log.Printf("error parsing regexp %s", err)
return redirect, err
}
if !matched {
redirect = "/"
}
} else {
if redirect == "" || !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
redirect = "/"
}
redirect, err = p.getValidatedRedirect(req.Form.Get("rd"))
if err != nil {
log.Printf("failed to validate redirect %s", err)
return redirect, err
}
return
}
Expand Down Expand Up @@ -550,6 +540,26 @@ func (p *OAuthProxy) SignIn(rw http.ResponseWriter, req *http.Request) {
}
}

func (p *OAuthProxy) getValidatedRedirect(redirect string) (string, error) {
fallbackRedirect := "/"
// We using 2 types of validation - basic checks or based on allowedURLs
switch {
case redirect == "":
return fallbackRedirect, nil
case p.allowedURL != "":
matched, err := regexp.MatchString(p.allowedURL, redirect)
if err != nil {
return fallbackRedirect, err
}
if !matched {
return fallbackRedirect, err
}
case !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//"):
return fallbackRedirect, nil
}
return redirect, nil
}

func (p *OAuthProxy) SignOut(rw http.ResponseWriter, req *http.Request) {
p.ClearSessionCookie(rw, req)
http.Redirect(rw, req, "/", 302)
Expand Down Expand Up @@ -599,7 +609,12 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
return
}
nonce := s[0]
redirect := s[1]
redirect, err := p.getValidatedRedirect(s[1])
if err != nil {
log.Printf("failed to validate redirect %s", err)
p.ErrorPage(rw, 500, "Internal Error", "Internal Error")
return
}
c, err := req.Cookie(p.CSRFCookieName)
if err != nil {
p.ErrorPage(rw, 403, "Permission Denied", err.Error())
Expand All @@ -611,20 +626,6 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
p.ErrorPage(rw, 403, "Permission Denied", "csrf failed")
return
}
if p.allowedURL != "" {
matched, err := regexp.MatchString(p.allowedURL, redirect)
if err != nil {
log.Printf("error parsing regexp %s", err)
return
}
if !matched {
redirect = "/"
}
} else {
if !strings.HasPrefix(redirect, "/") || strings.HasPrefix(redirect, "//") {
redirect = "/"
}
}
// set cookie, or deny
if p.Validator(session.Email) && p.provider.ValidateGroup(session.Email) {
log.Printf("%s authentication complete %s", remoteAddr, session)
Expand Down
46 changes: 46 additions & 0 deletions oauthproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"testing"
"time"

"github.com/bitly/oauth2_proxy/providers"
"github.com/bmizerany/assert"
"github.com/mbland/hmacauth"
"github.com/samsarahq/oauth2_proxy/providers"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -131,6 +133,50 @@ func (tp *TestProvider) ValidateSessionState(session *providers.SessionState) bo
return tp.ValidToken
}

func TestGetValidatedRedirect(t *testing.T) {
opts := NewOptions()
opts.ClientID = "bazquux"
opts.ClientSecret = "foobar"
opts.CookieSecret = "xyzzyplugh"

opts.AllowedURL = ".+\\.internals\\.example\\.com"
opts.Validate()

proxy := NewOAuthProxy(opts, func(string) bool { return true })
fallbackRedirect := "/"

noRD, err := proxy.getValidatedRedirect("")
assert.Equal(t, nil, err)
assert.Equal(t, fallbackRedirect, noRD)

singleSlash, err := proxy.getValidatedRedirect("/redirect")
assert.Equal(t, nil, err)
assert.Equal(t, singleSlash, singleSlash)

doubleSlash, err := proxy.getValidatedRedirect("//redirect")
assert.Equal(t, nil, err)
assert.Equal(t, fallbackRedirect, doubleSlash)

validHttp, err := proxy.getValidatedRedirect("http://internals.example.com/redirect")
assert.Equal(t, nil, err)
assert.Equal(t, validHttp, validHttp)

validHttps, err := proxy.getValidatedRedirect("https://internals.example.com/redirect")
assert.Equal(t, nil, err)
assert.Equal(t, validHttps, validHttps)

invalidHttp, err := proxy.getValidatedRedirect("http://internals.corporate.com/redirect")
assert.Equal(t, nil, err)
assert.Equal(t, fallbackRedirect, invalidHttp)

// Test for incorrect regexp
opts.AllowedURL = "*"
opts.Validate()
proxy = NewOAuthProxy(opts, func(string) bool { return true })
_, hasErr := proxy.getValidatedRedirect("http://internals.corporate.com/redirect")
assert.NotEqual(t, nil, hasErr)
}

func TestBasicAuthPassword(t *testing.T) {
provider_server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Printf("%#v", r)
Expand Down

0 comments on commit 9691907

Please sign in to comment.