From 27ea33c0f488a74b1bd1a553e8ba272cade4123c Mon Sep 17 00:00:00 2001 From: ajanthan Date: Thu, 8 Oct 2020 18:59:06 -0700 Subject: [PATCH 01/22] feat: Adding form_post support --- authorize_error.go | 6 +- authorize_error_test.go | 29 ++++ authorize_helper.go | 122 +++++++++++++ authorize_helper_test.go | 13 ++ authorize_response.go | 13 ++ authorize_response_test.go | 4 + authorize_response_writer.go | 1 + authorize_write.go | 64 ++++--- authorize_write_test.go | 19 +++ handler/oauth2/flow_authorize_code_auth.go | 13 +- handler/oauth2/flow_authorize_implicit.go | 18 +- handler/openid/flow_hybrid.go | 24 ++- handler/openid/flow_implicit.go | 14 +- handler/openid/helper.go | 7 +- integration/authorize_form_post_test.go | 189 +++++++++++++++++++++ internal/authorize_response.go | 26 +++ oauth2.go | 6 + 17 files changed, 521 insertions(+), 47 deletions(-) create mode 100644 integration/authorize_form_post_test.go diff --git a/authorize_error.go b/authorize_error.go index 3608ca480..c025e08d1 100644 --- a/authorize_error.go +++ b/authorize_error.go @@ -66,7 +66,11 @@ func (f *Fosite) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequest query.Add("state", ar.GetState()) var redirectURIString string - if !(len(ar.GetResponseTypes()) == 0 || ar.GetResponseTypes().ExactOne("code")) && !errors.Is(err, ErrUnsupportedResponseType) { + if ar.GetRequestForm().Get("response_mode") == "form_post" { + rw.Header().Add("Content-Type", "text/html;charset=UTF-8") + WriteAuthorizeFormPostResponse(redirectURI.String(), query, rw) + return + } else if !(len(ar.GetResponseTypes()) == 0 || ar.GetResponseTypes().ExactOne("code")) && !errors.Is(err, ErrUnsupportedResponseType) { redirectURIString = redirectURI.String() + "#" + query.Encode() } else { for key, values := range redirectURI.Query() { diff --git a/authorize_error_test.go b/authorize_error_test.go index 207772b4e..34dc125c2 100644 --- a/authorize_error_test.go +++ b/authorize_error_test.go @@ -88,6 +88,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code"})) + req.EXPECT().GetRequestForm().Return(url.Values{}) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -106,6 +107,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code"})) + req.EXPECT().GetRequestForm().Return(url.Values{}) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -124,6 +126,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code"})) + req.EXPECT().GetRequestForm().Return(url.Values{}) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -142,6 +145,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"foobar"})) + req.EXPECT().GetRequestForm().Return(url.Values{}) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -160,6 +164,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"})) + req.EXPECT().GetRequestForm().Return(url.Values{}) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -178,6 +183,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"})) + req.EXPECT().GetRequestForm().Return(url.Values{}) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -196,6 +202,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code", "token"})) + req.EXPECT().GetRequestForm().Return(url.Values{}) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -214,6 +221,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code", "token"})) + req.EXPECT().GetRequestForm().Return(url.Values{}) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -233,6 +241,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code", "token"})) + req.EXPECT().GetRequestForm().Return(url.Values{}) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -252,6 +261,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"id_token"})) + req.EXPECT().GetRequestForm().Return(url.Values{}) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -271,6 +281,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"})) + req.EXPECT().GetRequestForm().Return(url.Values{}) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -282,6 +293,24 @@ func TestWriteAuthorizeError(t *testing.T) { assert.Equal(t, "no-cache", header.Get("Pragma")) }, }, + { + debug: true, + err: ErrInvalidRequest.WithDebug("with-debug"), + mock: func(rw *MockResponseWriter, req *MockAuthorizeRequester) { + req.EXPECT().IsRedirectURIValid().Return(true) + req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) + req.EXPECT().GetState().Return("foostate") + req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"})) + req.EXPECT().GetRequestForm().Return(url.Values{"response_mode": {"form_post"}}) + rw.EXPECT().Header().Times(3).Return(header) + rw.EXPECT().Write(gomock.Any()).AnyTimes() + }, + checkHeader: func(t *testing.T, k int) { + assert.Equal(t, "no-store", header.Get("Cache-Control")) + assert.Equal(t, "no-cache", header.Get("Pragma")) + assert.Equal(t, "text/html;charset=UTF-8", header.Get("Content-Type")) + }, + }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { oauth2 := &Fosite{ diff --git a/authorize_helper.go b/authorize_helper.go index 1e1251459..b21cfb040 100644 --- a/authorize_helper.go +++ b/authorize_helper.go @@ -22,9 +22,16 @@ package fosite import ( + "html/template" + "io" "net/url" "regexp" + "strconv" "strings" + "time" + + "golang.org/x/net/html" + goauth "golang.org/x/oauth2" "github.com/asaskevich/govalidator" "github.com/pkg/errors" @@ -175,3 +182,118 @@ func IsLocalhost(redirectURI *url.URL) bool { hn := redirectURI.Hostname() return strings.HasSuffix(hn, ".localhost") || hn == "127.0.0.1" || hn == "localhost" } + +func WriteAuthorizeFormPostResponse(redirectURL string, parameters url.Values, rw io.Writer) { + t := template.Must(template.New("form_post").Parse(` + + Submit This Form + + +
+ {{ range $key,$value := .Parameters }} + + {{ end }} +
+ +`)) + + _ = t.Execute(rw, struct { + RedirURL string + Parameters url.Values + }{ + RedirURL: redirectURL, + Parameters: parameters, + }) +} +func ParseFormPostResponse(redirectURL string, resp io.ReadCloser) (authorizationCode, stateFromServer, iDToken string, token goauth.Token, rFC6749Error RFC6749Error, err error) { + + token = goauth.Token{} + rFC6749Error = RFC6749Error{} + + doc, err := html.Parse(resp) + if err != nil { + return "", "", "", token, rFC6749Error, err + } + //doc>html>body + body := findBody(doc.FirstChild.FirstChild) + if body.Data != "body" { + return "", "", "", token, rFC6749Error, errors.New("Malformed html") + } + htmlEvent := body.Attr[0].Key + if htmlEvent != "onload" { + return "", "", "", token, rFC6749Error, errors.New("onload event is missing") + } + onLoadFunc := body.Attr[0].Val + if onLoadFunc != "javascript:document.forms[0].submit()" { + return "", "", "", token, rFC6749Error, errors.New("onload function is missing") + } + form := getNextNoneTextNode(body.FirstChild) + if form.Data != "form" { + return "", "", "", token, rFC6749Error, errors.New("html form is missing") + } + for _, attr := range form.Attr { + if attr.Key == "method" { + if attr.Val != "post" { + return "", "", "", token, rFC6749Error, errors.New("html form post method is missing") + } + } else { + if attr.Val != redirectURL { + return "", "", "", token, rFC6749Error, errors.New("html form post url is wrong") + } + } + } + + for node := getNextNoneTextNode(form.FirstChild); node != nil; node = getNextNoneTextNode(node.NextSibling) { + var k, v string + for _, attr := range node.Attr { + if attr.Key == "name" { + k = attr.Val + } else if attr.Key == "value" { + v = attr.Val + } + + } + switch k { + case "state": + stateFromServer = v + case "code": + authorizationCode = v + case "expires_in": + expires, err := strconv.Atoi(v) + if err != nil { + return "", "", "", token, rFC6749Error, err + } + token.Expiry = time.Now().UTC().Add(time.Duration(expires) * time.Second) + case "access_token": + token.AccessToken = v + case "token_type": + token.TokenType = v + case "refresh_token": + token.RefreshToken = v + case "error": + rFC6749Error.Name = v + case "error_description": + rFC6749Error.Description = v + case "id_token": + iDToken = v + } + } + return +} + +func getNextNoneTextNode(node *html.Node) *html.Node { + nextNode := node.NextSibling + if nextNode != nil && nextNode.Type == html.TextNode { + nextNode = getNextNoneTextNode(node.NextSibling) + } + return nextNode +} +func findBody(node *html.Node) *html.Node { + if node != nil { + if node.Data == "body" { + return node + } + return findBody(node.NextSibling) + } + return nil +} diff --git a/authorize_helper_test.go b/authorize_helper_test.go index 4d3c00946..7159a715e 100644 --- a/authorize_helper_test.go +++ b/authorize_helper_test.go @@ -22,6 +22,8 @@ package fosite import ( + "bytes" + "io/ioutil" "net/url" "testing" @@ -253,3 +255,14 @@ func TestIsRedirectURISecure(t *testing.T) { assert.Equal(t, !c.err, IsRedirectURISecure(uu), "case %d", d) } } + +func TestWriteAuthorizeFormPostResponse(t *testing.T) { + var responseBuffer bytes.Buffer + redirectURL := "https://localhost:8080/cb" + parameters := url.Values{"code": {"lshr755nsg39fgur"}, "state": {"924659540232"}} + WriteAuthorizeFormPostResponse(redirectURL, parameters, &responseBuffer) + code, state, _, _, _, err := ParseFormPostResponse(redirectURL, ioutil.NopCloser(bytes.NewReader(responseBuffer.Bytes()))) + assert.NoError(t, err) + assert.Equal(t, parameters.Get("code"), code) + assert.Equal(t, parameters.Get("state"), state) +} diff --git a/authorize_response.go b/authorize_response.go index c7feac3f5..8e053d5cb 100644 --- a/authorize_response.go +++ b/authorize_response.go @@ -31,6 +31,7 @@ type AuthorizeResponse struct { Header http.Header Query url.Values Fragment url.Values + Form url.Values code string } @@ -39,6 +40,7 @@ func NewAuthorizeResponse() *AuthorizeResponse { Header: http.Header{}, Query: url.Values{}, Fragment: url.Values{}, + Form: url.Values{}, } } @@ -62,6 +64,10 @@ func (a *AuthorizeResponse) GetFragment() url.Values { return a.Fragment } +func (a *AuthorizeResponse) GetForm() url.Values { + return a.Form +} + func (a *AuthorizeResponse) AddQuery(key, value string) { if key == "code" { a.code = value @@ -75,3 +81,10 @@ func (a *AuthorizeResponse) AddFragment(key, value string) { } a.Fragment.Add(key, value) } + +func (a *AuthorizeResponse) AddForm(key, value string) { + if key == "code" { + a.code = value + } + a.Form.Add(key, value) +} diff --git a/authorize_response_test.go b/authorize_response_test.go index bd3d95ae0..ebb371dea 100644 --- a/authorize_response_test.go +++ b/authorize_response_test.go @@ -32,13 +32,17 @@ func TestAuthorizeResponse(t *testing.T) { ar.AddFragment("foo", "bar") ar.AddQuery("foo", "baz") ar.AddHeader("foo", "foo") + ar.AddForm("bar", "bar") ar.AddFragment("code", "bar") assert.Equal(t, "bar", ar.GetCode()) ar.AddQuery("code", "baz") assert.Equal(t, "baz", ar.GetCode()) + ar.AddForm("code", "baz") + assert.Equal(t, "baz", ar.GetCode()) assert.Equal(t, "bar", ar.GetFragment().Get("foo")) assert.Equal(t, "baz", ar.GetQuery().Get("foo")) assert.Equal(t, "foo", ar.GetHeader().Get("foo")) + assert.Equal(t, "bar", ar.GetForm().Get("bar")) } diff --git a/authorize_response_writer.go b/authorize_response_writer.go index 208b104c2..b5b032708 100644 --- a/authorize_response_writer.go +++ b/authorize_response_writer.go @@ -34,6 +34,7 @@ func (f *Fosite) NewAuthorizeResponse(ctx context.Context, ar AuthorizeRequester Header: http.Header{}, Query: url.Values{}, Fragment: url.Values{}, + Form: url.Values{}, } ar.SetSession(session) diff --git a/authorize_write.go b/authorize_write.go index 9556a35c1..8fa28fb4a 100644 --- a/authorize_write.go +++ b/authorize_write.go @@ -32,15 +32,6 @@ var ( ) func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequester, resp AuthorizeResponder) { - redir := ar.GetRedirectURI() - - // Explicit grants - q := redir.Query() - rq := resp.GetQuery() - for k := range rq { - q.Set(k, rq.Get(k)) - } - redir.RawQuery = q.Encode() // Set custom headers, e.g. "X-MySuperCoolCustomHeader" or "X-DONT-CACHE-ME"... wh := rw.Header() @@ -49,27 +40,46 @@ func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequ wh.Set(k, rh.Get(k)) } - // Implicit grants - // The endpoint URI MUST NOT include a fragment component. - redir.Fragment = "" + wh.Set("Cache-Control", "no-store") + wh.Set("Pragma", "no-cache") - u := redir.String() + redir := ar.GetRedirectURI() + form := resp.GetForm() - fr := resp.GetFragment() - if len(fr) > 0 { - u = u + "#" + fr.Encode() - } + if len(form) > 0 { + //form_post + rw.Header().Add("Content-Type", "text/html;charset=UTF-8") + WriteAuthorizeFormPostResponse(redir.String(), form, rw) + } else { - u = plusMatch.ReplaceAllString(u, "%20") + // Explicit grants + q := redir.Query() + rq := resp.GetQuery() + for k := range rq { + q.Set(k, rq.Get(k)) + } + redir.RawQuery = q.Encode() - wh.Set("Cache-Control", "no-store") - wh.Set("Pragma", "no-cache") + // Implicit grants + // The endpoint URI MUST NOT include a fragment component. + redir.Fragment = "" + + u := redir.String() + + fr := resp.GetFragment() + if len(fr) > 0 { + u = u + "#" + fr.Encode() + } + + u = plusMatch.ReplaceAllString(u, "%20") + + // https://tools.ietf.org/html/rfc6749#section-4.1.1 + // When a decision is established, the authorization server directs the + // user-agent to the provided client redirection URI using an HTTP + // redirection response, or by other means available to it via the + // user-agent. + wh.Set("Location", u) + rw.WriteHeader(http.StatusFound) + } - // https://tools.ietf.org/html/rfc6749#section-4.1.1 - // When a decision is established, the authorization server directs the - // user-agent to the provided client redirection URI using an HTTP - // redirection response, or by other means available to it via the - // user-agent. - wh.Set("Location", u) - rw.WriteHeader(http.StatusFound) } diff --git a/authorize_write_test.go b/authorize_write_test.go index dd4d12a20..4630d775a 100644 --- a/authorize_write_test.go +++ b/authorize_write_test.go @@ -53,6 +53,7 @@ func TestWriteAuthorizeResponse(t *testing.T) { resp.EXPECT().GetFragment().Return(url.Values{}) resp.EXPECT().GetHeader().Return(http.Header{}) resp.EXPECT().GetQuery().Return(url.Values{}) + resp.EXPECT().GetForm().Return(url.Values{}) rw.EXPECT().Header().Return(header) rw.EXPECT().WriteHeader(http.StatusFound) @@ -72,6 +73,7 @@ func TestWriteAuthorizeResponse(t *testing.T) { resp.EXPECT().GetFragment().Return(url.Values{"bar": {"baz"}}) resp.EXPECT().GetHeader().Return(http.Header{}) resp.EXPECT().GetQuery().Return(url.Values{}) + resp.EXPECT().GetForm().Return(url.Values{}) rw.EXPECT().Header().Return(header) rw.EXPECT().WriteHeader(http.StatusFound) @@ -91,6 +93,7 @@ func TestWriteAuthorizeResponse(t *testing.T) { resp.EXPECT().GetFragment().Return(url.Values{"bar": {"baz"}}) resp.EXPECT().GetHeader().Return(http.Header{}) resp.EXPECT().GetQuery().Return(url.Values{"bar": {"baz"}}) + resp.EXPECT().GetForm().Return(url.Values{}) rw.EXPECT().Header().Return(header) rw.EXPECT().WriteHeader(http.StatusFound) @@ -110,6 +113,7 @@ func TestWriteAuthorizeResponse(t *testing.T) { resp.EXPECT().GetFragment().Return(url.Values{"bar": {"baz"}, "scope": {"a b"}}) resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) resp.EXPECT().GetQuery().Return(url.Values{"bar": {"b+az"}, "scope": {"a b"}}) + resp.EXPECT().GetForm().Return(url.Values{}) rw.EXPECT().Header().Return(header) rw.EXPECT().WriteHeader(http.StatusFound) @@ -130,6 +134,7 @@ func TestWriteAuthorizeResponse(t *testing.T) { resp.EXPECT().GetFragment().Return(url.Values{"bar": {"baz"}, "scope": {"api:*"}}) resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) resp.EXPECT().GetQuery().Return(url.Values{"bar": {"b+az"}, "scope": {"api:*"}}) + resp.EXPECT().GetForm().Return(url.Values{}) rw.EXPECT().Header().Return(header) rw.EXPECT().WriteHeader(http.StatusFound) @@ -143,6 +148,20 @@ func TestWriteAuthorizeResponse(t *testing.T) { }, header) }, }, + { + setup: func() { + redir, _ := url.Parse("https://foobar.com/?foo=bar") + ar.EXPECT().GetRedirectURI().Return(redir) + resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) + resp.EXPECT().GetForm().Return(url.Values{"code": {"poz65kqoneu"}, "state": {"qm6dnsrn"}}) + + rw.EXPECT().Header().Return(header).AnyTimes() + rw.EXPECT().Write(gomock.Any()).AnyTimes() + }, + expect: func() { + assert.Equal(t, "text/html;charset=UTF-8", header.Get("Content-Type")) + }, + }, } { t.Logf("Starting test case %d", k) c.setup() diff --git a/handler/oauth2/flow_authorize_code_auth.go b/handler/oauth2/flow_authorize_code_auth.go index 5e5b1a665..dfdf7edc2 100644 --- a/handler/oauth2/flow_authorize_code_auth.go +++ b/handler/oauth2/flow_authorize_code_auth.go @@ -110,10 +110,15 @@ func (c *AuthorizeExplicitGrantHandler) IssueAuthorizeCode(ctx context.Context, if err := c.CoreStorage.CreateAuthorizeCodeSession(ctx, signature, ar.Sanitize(c.GetSanitationWhiteList())); err != nil { return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) } - - resp.AddQuery("code", code) - resp.AddQuery("state", ar.GetState()) - resp.AddQuery("scope", strings.Join(ar.GetGrantedScopes(), " ")) + if ar.GetRequestForm().Get("response_mode") == "form_post" { + resp.AddForm("code", code) + resp.AddForm("state", ar.GetState()) + resp.AddForm("scope", strings.Join(ar.GetGrantedScopes(), " ")) + } else { + resp.AddQuery("code", code) + resp.AddQuery("state", ar.GetState()) + resp.AddQuery("scope", strings.Join(ar.GetGrantedScopes(), " ")) + } ar.SetResponseTypeHandled("code") return nil } diff --git a/handler/oauth2/flow_authorize_implicit.go b/handler/oauth2/flow_authorize_implicit.go index 99e35ede8..c6952096c 100644 --- a/handler/oauth2/flow_authorize_implicit.go +++ b/handler/oauth2/flow_authorize_implicit.go @@ -94,12 +94,20 @@ func (c *AuthorizeImplicitGrantTypeHandler) IssueImplicitAccessToken(ctx context if err := c.AccessTokenStorage.CreateAccessTokenSession(ctx, signature, ar.Sanitize([]string{})); err != nil { return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) } + if ar.GetRequestForm().Get("response_mode") == "form_post" { + resp.AddForm("access_token", token) + resp.AddForm("expires_in", strconv.FormatInt(int64(getExpiresIn(ar, fosite.AccessToken, c.AccessTokenLifespan, time.Now().UTC())/time.Second), 10)) + resp.AddForm("token_type", "bearer") + resp.AddForm("state", ar.GetState()) + resp.AddForm("scope", strings.Join(ar.GetGrantedScopes(), " ")) + } else { + resp.AddFragment("access_token", token) + resp.AddFragment("expires_in", strconv.FormatInt(int64(getExpiresIn(ar, fosite.AccessToken, c.AccessTokenLifespan, time.Now().UTC())/time.Second), 10)) + resp.AddFragment("token_type", "bearer") + resp.AddFragment("state", ar.GetState()) + resp.AddFragment("scope", strings.Join(ar.GetGrantedScopes(), " ")) + } - resp.AddFragment("access_token", token) - resp.AddFragment("expires_in", strconv.FormatInt(int64(getExpiresIn(ar, fosite.AccessToken, c.AccessTokenLifespan, time.Now().UTC())/time.Second), 10)) - resp.AddFragment("token_type", "bearer") - resp.AddFragment("state", ar.GetState()) - resp.AddFragment("scope", strings.Join(ar.GetGrantedScopes(), " ")) ar.SetResponseTypeHandled("token") return nil diff --git a/handler/openid/flow_hybrid.go b/handler/openid/flow_hybrid.go index b15636657..e42b41e5e 100644 --- a/handler/openid/flow_hybrid.go +++ b/handler/openid/flow_hybrid.go @@ -109,11 +109,14 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context. if err := c.AuthorizeExplicitGrantHandler.CoreStorage.CreateAuthorizeCodeSession(ctx, signature, ar.Sanitize(c.AuthorizeExplicitGrantHandler.GetSanitationWhiteList())); err != nil { return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) } - - resp.AddFragment("code", code) + if ar.GetRequestForm().Get("response_mode") == "form_post" { + resp.AddForm("code", code) + } else { + resp.AddFragment("code", code) + } ar.SetResponseTypeHandled("code") - hash, err := c.Enigma.Hash(ctx, []byte(resp.GetFragment().Get("code"))) + hash, err := c.Enigma.Hash(ctx, []byte(resp.GetCode())) if err != nil { return err } @@ -133,8 +136,13 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context. return errors.WithStack(err) } ar.SetResponseTypeHandled("token") - - hash, err := c.Enigma.Hash(ctx, []byte(resp.GetFragment().Get("access_token"))) + var accessToken string + if ar.GetRequestForm().Get("response_mode") == "form_post" { + accessToken = resp.GetForm().Get("access_token") + } else { + accessToken = resp.GetFragment().Get("access_token") + } + hash, err := c.Enigma.Hash(ctx, []byte(accessToken)) if err != nil { return err } @@ -142,7 +150,11 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context. } if resp.GetFragment().Get("state") == "" { - resp.AddFragment("state", ar.GetState()) + if ar.GetRequestForm().Get("response_mode") == "form_post" { + resp.AddForm("state", ar.GetState()) + } else { + resp.AddFragment("state", ar.GetState()) + } } if !ar.GetGrantedScopes().Has("openid") || !ar.GetResponseTypes().Has("id_token") { diff --git a/handler/openid/flow_implicit.go b/handler/openid/flow_implicit.go index 3f267b349..01e94b92d 100644 --- a/handler/openid/flow_implicit.go +++ b/handler/openid/flow_implicit.go @@ -91,14 +91,24 @@ func (c *OpenIDConnectImplicitHandler) HandleAuthorizeEndpointRequest(ctx contex } ar.SetResponseTypeHandled("token") - hash, err := c.RS256JWTStrategy.Hash(ctx, []byte(resp.GetFragment().Get("access_token"))) + var accessToken string + if ar.GetRequestForm().Get("response_mode") == "form_post" { + accessToken = resp.GetForm().Get("access_token") + } else { + accessToken = resp.GetFragment().Get("access_token") + } + hash, err := c.RS256JWTStrategy.Hash(ctx, []byte(accessToken)) if err != nil { return err } claims.AccessTokenHash = base64.RawURLEncoding.EncodeToString([]byte(hash[:c.RS256JWTStrategy.GetSigningMethodLength()/2])) } else { - resp.AddFragment("state", ar.GetState()) + if ar.GetRequestForm().Get("response_mode") == "form_post" { + resp.AddForm("state", ar.GetState()) + } else { + resp.AddFragment("state", ar.GetState()) + } } if err := c.IssueImplicitIDToken(ctx, ar, resp); err != nil { diff --git a/handler/openid/helper.go b/handler/openid/helper.go index 3fabfbf2e..8263774e5 100644 --- a/handler/openid/helper.go +++ b/handler/openid/helper.go @@ -64,8 +64,11 @@ func (i *IDTokenHandleHelper) IssueImplicitIDToken(ctx context.Context, ar fosit if err != nil { return err } - - resp.AddFragment("id_token", token) + if ar.GetRequestForm().Get("response_mode") == "form_post" { + resp.AddForm("id_token", token) + } else { + resp.AddFragment("id_token", token) + } return nil } diff --git a/integration/authorize_form_post_test.go b/integration/authorize_form_post_test.go new file mode 100644 index 000000000..e151b9221 --- /dev/null +++ b/integration/authorize_form_post_test.go @@ -0,0 +1,189 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package integration_test + +import ( + "fmt" + "net/http" + "strings" + "testing" + + "github.com/ory/fosite/handler/openid" + "github.com/ory/fosite/internal" + "github.com/ory/fosite/token/jwt" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + goauth "golang.org/x/oauth2" + + "github.com/ory/fosite" + "github.com/ory/fosite/compose" + "github.com/ory/fosite/handler/oauth2" +) + +func TestAuthorizeFormPostImplicitFlow(t *testing.T) { + for _, strategy := range []oauth2.AccessTokenStrategy{ + hmacStrategy, + } { + runTestAuthorizeFormPostImplicitGrant(t, strategy) + } +} + +func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { + session := &defaultSession{ + DefaultSession: &openid.DefaultSession{ + Claims: &jwt.IDTokenClaims{ + Subject: "peter", + }, + Headers: &jwt.Headers{}, + }, + } + f := compose.ComposeAllEnabled(new(compose.Config), fositeStore, []byte("some-secret-thats-random-some-secret-thats-random-"), internal.MustRSAKey()) + ts := mockServer(t, f, session) + defer ts.Close() + + oauthClient := newOAuth2Client(ts) + fositeStore.Clients["my-client"].(*fosite.DefaultClient).RedirectURIs[0] = ts.URL + "/callback" + + var state string + for k, c := range []struct { + description string + setup func() + check func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) + responseType string + }{ + //{ + // description: "should fail because of audience", + // responseType: []goauth.AuthCodeOption{goauth.SetAuthURLParam("audience", "https://www.ory.sh/not-api")}, + // setup: func() { + // state = "12345678901234567890" + // }, + // authStatusCode: http.StatusNotAcceptable, + //}, + //{ + // description: "should fail because of scope", + // responseType: []goauth.AuthCodeOption{}, + // setup: func() { + // oauthClient.Scopes = []string{"not-exist"} + // state = "12345678901234567890" + // }, + // authStatusCode: http.StatusNotAcceptable, + //}, + { + description: "implicit grant test with form_post", + responseType: "token", + setup: func() { + state = "12345678901234567890" + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) { + assert.EqualValues(t, state, stateFromServer) + assert.NotEmpty(t, token.TokenType) + assert.NotEmpty(t, token.AccessToken) + assert.NotEmpty(t, token.Expiry) + }, + }, + { + description: "explicit grant test with form_post", + responseType: "code", + setup: func() { + state = "12345678901234567890" + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) { + assert.EqualValues(t, state, stateFromServer) + assert.NotEmpty(t, code) + }, + }, + { + description: "oidc grant test with form_post", + responseType: "token%20code", + setup: func() { + state = "12345678901234567890" + oauthClient.Scopes = []string{"openid"} + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) { + assert.EqualValues(t, state, stateFromServer) + assert.NotEmpty(t, code) + assert.NotEmpty(t, token.TokenType) + assert.NotEmpty(t, token.AccessToken) + assert.NotEmpty(t, token.Expiry) + }, + }, + { + description: "hybrid grant test with form_post", + responseType: "token%20id_token%20code", + setup: func() { + state = "12345678901234567890" + oauthClient.Scopes = []string{"openid"} + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) { + assert.EqualValues(t, state, stateFromServer) + assert.NotEmpty(t, code) + assert.NotEmpty(t, iDToken) + assert.NotEmpty(t, token.TokenType) + assert.NotEmpty(t, token.AccessToken) + assert.NotEmpty(t, token.Expiry) + }, + }, + { + description: "hybrid grant test with form_post", + responseType: "id_token%20code", + setup: func() { + state = "12345678901234567890" + oauthClient.Scopes = []string{"openid"} + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) { + assert.EqualValues(t, state, stateFromServer) + assert.NotEmpty(t, code) + assert.NotEmpty(t, iDToken) + }, + }, + { + description: "error message test for form_post response", + responseType: "foo", + setup: func() { + state = "12345678901234567890" + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) { + assert.EqualValues(t, state, stateFromServer) + assert.NotEmpty(t, err.Name) + assert.NotEmpty(t, err.Description) + }, + }, + } { + t.Run(fmt.Sprintf("case=%d/description=%s", k, c.description), func(t *testing.T) { + c.setup() + authURL := strings.Replace(oauthClient.AuthCodeURL(state, goauth.SetAuthURLParam("response_mode", "form_post"), goauth.SetAuthURLParam("nonce", "111111111")), "response_type=code", "response_type="+c.responseType, -1) + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return errors.New("Dont follow redirects") + }, + } + resp, err := client.Get(authURL) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + code, state, token, iDToken, errResp, err := fosite.ParseFormPostResponse(fositeStore.Clients["my-client"].GetRedirectURIs()[0], resp.Body) + require.NoError(t, err) + c.check(t, state, code, iDToken, token, errResp) + }) + } +} diff --git a/internal/authorize_response.go b/internal/authorize_response.go index 6cde26f24..92e4c2ec0 100644 --- a/internal/authorize_response.go +++ b/internal/authorize_response.go @@ -35,6 +35,18 @@ func (m *MockAuthorizeResponder) EXPECT() *MockAuthorizeResponderMockRecorder { return m.recorder } +// AddForm mocks base method +func (m *MockAuthorizeResponder) AddForm(arg0, arg1 string) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "AddForm", arg0, arg1) +} + +// AddForm indicates an expected call of AddForm +func (mr *MockAuthorizeResponderMockRecorder) AddForm(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddForm", reflect.TypeOf((*MockAuthorizeResponder)(nil).AddForm), arg0, arg1) +} + // AddFragment mocks base method func (m *MockAuthorizeResponder) AddFragment(arg0, arg1 string) { m.ctrl.T.Helper() @@ -85,6 +97,20 @@ func (mr *MockAuthorizeResponderMockRecorder) GetCode() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCode", reflect.TypeOf((*MockAuthorizeResponder)(nil).GetCode)) } +// GetForm mocks base method +func (m *MockAuthorizeResponder) GetForm() url.Values { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetForm") + ret0, _ := ret[0].(url.Values) + return ret0 +} + +// GetForm indicates an expected call of GetForm +func (mr *MockAuthorizeResponderMockRecorder) GetForm() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetForm", reflect.TypeOf((*MockAuthorizeResponder)(nil).GetForm)) +} + // GetFragment mocks base method func (m *MockAuthorizeResponder) GetFragment() url.Values { m.ctrl.T.Helper() diff --git a/oauth2.go b/oauth2.go index 01e42009e..b0aabcee5 100644 --- a/oauth2.go +++ b/oauth2.go @@ -321,4 +321,10 @@ type AuthorizeResponder interface { // AddHeader adds a key value pair to the response's url fragment AddFragment(key, value string) + + // GetForm returns form with parameters + GetForm() (form url.Values) + + // AddForm adds a key value pair to the form post request + AddForm(key, value string) } From 54a9d50658d5cb11e7a5288e25cbaac03476031d Mon Sep 17 00:00:00 2001 From: Ajanthan Balachandran Date: Tue, 13 Oct 2020 22:56:27 -0700 Subject: [PATCH 02/22] refactor: introducing responseMode enum, setting responseMode during the initial handler and avoiding duplicate code in setting query,fragment and form parameter in authorize response --- authorize_helper.go | 28 +++---- authorize_request.go | 27 ++++++- authorize_request_handler.go | 21 +++++ authorize_response.go | 44 +++-------- authorize_response_test.go | 17 ++--- authorize_response_writer.go | 6 +- authorize_write.go | 40 +++++----- authorize_write_test.go | 75 +++++++++++------- handler/oauth2/flow_authorize_code_auth.go | 15 ++-- .../oauth2/flow_authorize_code_auth_test.go | 6 +- handler/oauth2/flow_authorize_implicit.go | 19 ++--- .../oauth2/flow_authorize_implicit_test.go | 10 +-- handler/openid/flow_hybrid.go | 28 +++---- handler/openid/flow_hybrid_test.go | 6 +- handler/openid/flow_implicit.go | 16 ++-- handler/openid/flow_implicit_test.go | 24 +++--- handler/openid/helper.go | 6 +- handler/openid/helper_test.go | 2 +- handler/pkce/handler_test.go | 2 +- internal/authorize_request.go | 27 ++++++- internal/authorize_response.go | 76 +++---------------- oauth2.go | 26 +++---- 22 files changed, 244 insertions(+), 277 deletions(-) diff --git a/authorize_helper.go b/authorize_helper.go index b21cfb040..7272495a0 100644 --- a/authorize_helper.go +++ b/authorize_helper.go @@ -37,6 +37,19 @@ import ( "github.com/pkg/errors" ) +var formPostTemplate = template.Must(template.New("form_post").Parse(` + + Submit This Form + + +
+ {{ range $key,$value := .Parameters }} + + {{ end }} +
+ +`)) + // MatchRedirectURIWithClientRedirectURIs if the given uri is a registered redirect uri. Does not perform // uri validation. // @@ -184,20 +197,7 @@ func IsLocalhost(redirectURI *url.URL) bool { } func WriteAuthorizeFormPostResponse(redirectURL string, parameters url.Values, rw io.Writer) { - t := template.Must(template.New("form_post").Parse(` - - Submit This Form - - -
- {{ range $key,$value := .Parameters }} - - {{ end }} -
- -`)) - - _ = t.Execute(rw, struct { + _ = formPostTemplate.Execute(rw, struct { RedirURL string Parameters url.Values }{ diff --git a/authorize_request.go b/authorize_request.go index dade3a3a4..74f2af67c 100644 --- a/authorize_request.go +++ b/authorize_request.go @@ -25,12 +25,22 @@ import ( "net/url" ) +type ResponseModeType string + +const ( + ResponseModeNone = ResponseModeType("") + ResponseModePost = ResponseModeType("form_post") + ResponseModeQuery = ResponseModeType("query") + ResponseModeFragment = ResponseModeType("fragment") +) + // AuthorizeRequest is an implementation of AuthorizeRequester type AuthorizeRequest struct { - ResponseTypes Arguments `json:"responseTypes" gorethink:"responseTypes"` - RedirectURI *url.URL `json:"redirectUri" gorethink:"redirectUri"` - State string `json:"state" gorethink:"state"` - HandledResponseTypes Arguments `json:"handledResponseTypes" gorethink:"handledResponseTypes"` + ResponseTypes Arguments `json:"responseTypes" gorethink:"responseTypes"` + RedirectURI *url.URL `json:"redirectUri" gorethink:"redirectUri"` + State string `json:"state" gorethink:"state"` + HandledResponseTypes Arguments `json:"handledResponseTypes" gorethink:"handledResponseTypes"` + ResponseMode ResponseModeType `json:"ResponseMode" gorethink:"ResponseMode"` Request } @@ -41,6 +51,7 @@ func NewAuthorizeRequest() *AuthorizeRequest { RedirectURI: &url.URL{}, HandledResponseTypes: Arguments{}, Request: *NewRequest(), + ResponseMode: ResponseModeQuery, } } @@ -86,3 +97,11 @@ func (d *AuthorizeRequest) DidHandleAllResponseTypes() bool { return len(d.ResponseTypes) > 0 } + +func (d *AuthorizeRequest) GetResponseMode() ResponseModeType { + return d.ResponseMode +} + +func (d *AuthorizeRequest) SetResponseMode(responseMode ResponseModeType) { + d.ResponseMode = responseMode +} diff --git a/authorize_request_handler.go b/authorize_request_handler.go index c00db2aea..e40bd37d9 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -211,6 +211,23 @@ func (f *Fosite) validateResponseTypes(r *http.Request, request *AuthorizeReques request.ResponseTypes = responseTypes return nil } +func (f *Fosite) validateResponseMode(r *http.Request, request *AuthorizeRequest) error { + responseMode := r.Form.Get("response_mode") + + switch responseMode { + case string(ResponseModeNone): + request.ResponseMode = ResponseModeNone + case string(ResponseModeFragment): + request.ResponseMode = ResponseModeFragment + case string(ResponseModeQuery): + request.ResponseMode = ResponseModeQuery + case string(ResponseModePost): + request.ResponseMode = ResponseModePost + default: + return errors.WithStack(ErrUnsupportedResponseType.WithHintf("Request with unsupported response_mode \"%s\".", responseMode)) + } + return nil +} func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (AuthorizeRequester, error) { request := &AuthorizeRequest{ @@ -259,6 +276,10 @@ func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth return request, err } + if err := f.validateResponseMode(r, request); err != nil { + return request, err + } + // rfc6819 4.4.1.8. Threat: CSRF Attack against redirect-uri // The "state" parameter should be used to link the authorization // request with the redirect URI used to deliver the access token (Section 5.3.5). diff --git a/authorize_response.go b/authorize_response.go index 8e053d5cb..60edb83e7 100644 --- a/authorize_response.go +++ b/authorize_response.go @@ -28,19 +28,15 @@ import ( // AuthorizeResponse is an implementation of AuthorizeResponder type AuthorizeResponse struct { - Header http.Header - Query url.Values - Fragment url.Values - Form url.Values - code string + Header http.Header + Parameters url.Values + code string } func NewAuthorizeResponse() *AuthorizeResponse { return &AuthorizeResponse{ - Header: http.Header{}, - Query: url.Values{}, - Fragment: url.Values{}, - Form: url.Values{}, + Header: http.Header{}, + Parameters: url.Values{}, } } @@ -56,35 +52,13 @@ func (a *AuthorizeResponse) AddHeader(key, value string) { a.Header.Add(key, value) } -func (a *AuthorizeResponse) GetQuery() url.Values { - return a.Query +func (a *AuthorizeResponse) GetParameters() url.Values { + return a.Parameters } -func (a *AuthorizeResponse) GetFragment() url.Values { - return a.Fragment -} - -func (a *AuthorizeResponse) GetForm() url.Values { - return a.Form -} - -func (a *AuthorizeResponse) AddQuery(key, value string) { - if key == "code" { - a.code = value - } - a.Query.Add(key, value) -} - -func (a *AuthorizeResponse) AddFragment(key, value string) { - if key == "code" { - a.code = value - } - a.Fragment.Add(key, value) -} - -func (a *AuthorizeResponse) AddForm(key, value string) { +func (a *AuthorizeResponse) AddParameter(key, value string) { if key == "code" { a.code = value } - a.Form.Add(key, value) + a.Parameters.Add(key, value) } diff --git a/authorize_response_test.go b/authorize_response_test.go index ebb371dea..c3a684e7f 100644 --- a/authorize_response_test.go +++ b/authorize_response_test.go @@ -29,20 +29,15 @@ import ( func TestAuthorizeResponse(t *testing.T) { ar := NewAuthorizeResponse() - ar.AddFragment("foo", "bar") - ar.AddQuery("foo", "baz") + ar.AddParameter("foo", "bar") + ar.AddParameter("bar", "bar") + ar.AddHeader("foo", "foo") - ar.AddForm("bar", "bar") - ar.AddFragment("code", "bar") + ar.AddParameter("code", "bar") assert.Equal(t, "bar", ar.GetCode()) - ar.AddQuery("code", "baz") - assert.Equal(t, "baz", ar.GetCode()) - ar.AddForm("code", "baz") - assert.Equal(t, "baz", ar.GetCode()) - assert.Equal(t, "bar", ar.GetFragment().Get("foo")) - assert.Equal(t, "baz", ar.GetQuery().Get("foo")) + assert.Equal(t, "bar", ar.GetParameters().Get("foo")) assert.Equal(t, "foo", ar.GetHeader().Get("foo")) - assert.Equal(t, "bar", ar.GetForm().Get("bar")) + assert.Equal(t, "bar", ar.GetParameters().Get("bar")) } diff --git a/authorize_response_writer.go b/authorize_response_writer.go index b5b032708..5d9163cd9 100644 --- a/authorize_response_writer.go +++ b/authorize_response_writer.go @@ -31,10 +31,8 @@ import ( func (f *Fosite) NewAuthorizeResponse(ctx context.Context, ar AuthorizeRequester, session Session) (AuthorizeResponder, error) { var resp = &AuthorizeResponse{ - Header: http.Header{}, - Query: url.Values{}, - Fragment: url.Values{}, - Form: url.Values{}, + Header: http.Header{}, + Parameters: url.Values{}, } ar.SetSession(session) diff --git a/authorize_write.go b/authorize_write.go index 8fa28fb4a..54ff1fc2b 100644 --- a/authorize_write.go +++ b/authorize_write.go @@ -44,42 +44,40 @@ func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequ wh.Set("Pragma", "no-cache") redir := ar.GetRedirectURI() - form := resp.GetForm() - if len(form) > 0 { + switch ar.GetResponseMode() { + case ResponseModePost: //form_post rw.Header().Add("Content-Type", "text/html;charset=UTF-8") - WriteAuthorizeFormPostResponse(redir.String(), form, rw) - } else { - + WriteAuthorizeFormPostResponse(redir.String(), resp.GetParameters(), rw) + case ResponseModeQuery, ResponseModeNone: // Explicit grants q := redir.Query() - rq := resp.GetQuery() + rq := resp.GetParameters() for k := range rq { q.Set(k, rq.Get(k)) } redir.RawQuery = q.Encode() - + sendRedirect(redir.String(), rw) + case ResponseModeFragment: // Implicit grants // The endpoint URI MUST NOT include a fragment component. redir.Fragment = "" u := redir.String() - - fr := resp.GetFragment() - if len(fr) > 0 { - u = u + "#" + fr.Encode() - } - + fr := resp.GetParameters() + u = u + "#" + fr.Encode() u = plusMatch.ReplaceAllString(u, "%20") - - // https://tools.ietf.org/html/rfc6749#section-4.1.1 - // When a decision is established, the authorization server directs the - // user-agent to the provided client redirection URI using an HTTP - // redirection response, or by other means available to it via the - // user-agent. - wh.Set("Location", u) - rw.WriteHeader(http.StatusFound) + sendRedirect(u, rw) } +} +// https://tools.ietf.org/html/rfc6749#section-4.1.1 +// When a decision is established, the authorization server directs the +// user-agent to the provided client redirection URI using an HTTP +// redirection response, or by other means available to it via the +// user-agent. +func sendRedirect(url string, rw http.ResponseWriter) { + rw.Header().Set("Location", url) + rw.WriteHeader(http.StatusFound) } diff --git a/authorize_write_test.go b/authorize_write_test.go index 4630d775a..9c5602ed7 100644 --- a/authorize_write_test.go +++ b/authorize_write_test.go @@ -50,12 +50,11 @@ func TestWriteAuthorizeResponse(t *testing.T) { setup: func() { redir, _ := url.Parse("https://foobar.com/?foo=bar") ar.EXPECT().GetRedirectURI().Return(redir) - resp.EXPECT().GetFragment().Return(url.Values{}) + ar.EXPECT().GetResponseMode().Return(ResponseModeNone) + resp.EXPECT().GetParameters().Return(url.Values{}) resp.EXPECT().GetHeader().Return(http.Header{}) - resp.EXPECT().GetQuery().Return(url.Values{}) - resp.EXPECT().GetForm().Return(url.Values{}) - rw.EXPECT().Header().Return(header) + rw.EXPECT().Header().Return(header).Times(2) rw.EXPECT().WriteHeader(http.StatusFound) }, expect: func() { @@ -70,12 +69,11 @@ func TestWriteAuthorizeResponse(t *testing.T) { setup: func() { redir, _ := url.Parse("https://foobar.com/?foo=bar") ar.EXPECT().GetRedirectURI().Return(redir) - resp.EXPECT().GetFragment().Return(url.Values{"bar": {"baz"}}) + ar.EXPECT().GetResponseMode().Return(ResponseModeFragment) + resp.EXPECT().GetParameters().Return(url.Values{"bar": {"baz"}}) resp.EXPECT().GetHeader().Return(http.Header{}) - resp.EXPECT().GetQuery().Return(url.Values{}) - resp.EXPECT().GetForm().Return(url.Values{}) - rw.EXPECT().Header().Return(header) + rw.EXPECT().Header().Return(header).Times(2) rw.EXPECT().WriteHeader(http.StatusFound) }, expect: func() { @@ -90,38 +88,37 @@ func TestWriteAuthorizeResponse(t *testing.T) { setup: func() { redir, _ := url.Parse("https://foobar.com/?foo=bar") ar.EXPECT().GetRedirectURI().Return(redir) - resp.EXPECT().GetFragment().Return(url.Values{"bar": {"baz"}}) + ar.EXPECT().GetResponseMode().Return(ResponseModeQuery) + resp.EXPECT().GetParameters().Return(url.Values{"bar": {"baz"}}) resp.EXPECT().GetHeader().Return(http.Header{}) - resp.EXPECT().GetQuery().Return(url.Values{"bar": {"baz"}}) - resp.EXPECT().GetForm().Return(url.Values{}) - rw.EXPECT().Header().Return(header) + rw.EXPECT().Header().Return(header).Times(2) rw.EXPECT().WriteHeader(http.StatusFound) }, expect: func() { - assert.Equal(t, http.Header{ - "Location": []string{"https://foobar.com/?bar=baz&foo=bar#bar=baz"}, - "Cache-Control": []string{"no-store"}, - "Pragma": []string{"no-cache"}, - }, header) + expectedUrl, _ := url.Parse("https://foobar.com/?foo=bar&bar=baz") + actualUrl, err := url.Parse(header.Get("Location")) + assert.Nil(t, err) + assert.Equal(t, expectedUrl.Query(), actualUrl.Query()) + assert.Equal(t, "no-cache", header.Get("Pragma")) + assert.Equal(t, "no-store", header.Get("Cache-Control")) }, }, { setup: func() { redir, _ := url.Parse("https://foobar.com/?foo=bar") ar.EXPECT().GetRedirectURI().Return(redir) - resp.EXPECT().GetFragment().Return(url.Values{"bar": {"baz"}, "scope": {"a b"}}) + ar.EXPECT().GetResponseMode().Return(ResponseModeFragment) + resp.EXPECT().GetParameters().Return(url.Values{"bar": {"b+az"}, "scope": {"a b"}}) resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) - resp.EXPECT().GetQuery().Return(url.Values{"bar": {"b+az"}, "scope": {"a b"}}) - resp.EXPECT().GetForm().Return(url.Values{}) - rw.EXPECT().Header().Return(header) + rw.EXPECT().Header().Return(header).Times(2) rw.EXPECT().WriteHeader(http.StatusFound) }, expect: func() { assert.Equal(t, http.Header{ "X-Bar": {"baz"}, - "Location": {"https://foobar.com/?bar=b%2Baz&foo=bar&scope=a%20b#bar=baz&scope=a%20b"}, + "Location": {"https://foobar.com/?foo=bar#bar=b%2Baz&scope=a%20b"}, "Cache-Control": []string{"no-store"}, "Pragma": []string{"no-cache"}, }, header) @@ -131,18 +128,39 @@ func TestWriteAuthorizeResponse(t *testing.T) { setup: func() { redir, _ := url.Parse("https://foobar.com/?foo=bar") ar.EXPECT().GetRedirectURI().Return(redir) - resp.EXPECT().GetFragment().Return(url.Values{"bar": {"baz"}, "scope": {"api:*"}}) + ar.EXPECT().GetResponseMode().Return(ResponseModeQuery) + resp.EXPECT().GetParameters().Return(url.Values{"bar": {"b+az"}, "scope": {"a b"}}) + resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) + + rw.EXPECT().Header().Return(header).Times(2) + rw.EXPECT().WriteHeader(http.StatusFound) + }, + expect: func() { + expectedUrl, err := url.Parse("https://foobar.com/?foo=bar&bar=b%2Baz&scope=a+b") + assert.Nil(t, err) + actualUrl, err := url.Parse(header.Get("Location")) + assert.Nil(t, err) + assert.Equal(t, expectedUrl.Query(), actualUrl.Query()) + assert.Equal(t, "no-cache", header.Get("Pragma")) + assert.Equal(t, "no-store", header.Get("Cache-Control")) + assert.Equal(t, "baz", header.Get("X-Bar")) + }, + }, + { + setup: func() { + redir, _ := url.Parse("https://foobar.com/?foo=bar") + ar.EXPECT().GetRedirectURI().Return(redir) + ar.EXPECT().GetResponseMode().Return(ResponseModeFragment) + resp.EXPECT().GetParameters().Return(url.Values{"bar": {"baz"}, "scope": {"api:*"}}) resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) - resp.EXPECT().GetQuery().Return(url.Values{"bar": {"b+az"}, "scope": {"api:*"}}) - resp.EXPECT().GetForm().Return(url.Values{}) - rw.EXPECT().Header().Return(header) + rw.EXPECT().Header().Return(header).Times(2) rw.EXPECT().WriteHeader(http.StatusFound) }, expect: func() { assert.Equal(t, http.Header{ "X-Bar": {"baz"}, - "Location": {"https://foobar.com/?bar=b%2Baz&foo=bar&scope=api%3A%2A#bar=baz&scope=api%3A%2A"}, + "Location": {"https://foobar.com/?foo=bar#bar=baz&scope=api%3A%2A"}, "Cache-Control": []string{"no-store"}, "Pragma": []string{"no-cache"}, }, header) @@ -152,8 +170,9 @@ func TestWriteAuthorizeResponse(t *testing.T) { setup: func() { redir, _ := url.Parse("https://foobar.com/?foo=bar") ar.EXPECT().GetRedirectURI().Return(redir) + ar.EXPECT().GetResponseMode().Return(ResponseModePost) resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) - resp.EXPECT().GetForm().Return(url.Values{"code": {"poz65kqoneu"}, "state": {"qm6dnsrn"}}) + resp.EXPECT().GetParameters().Return(url.Values{"code": {"poz65kqoneu"}, "state": {"qm6dnsrn"}}) rw.EXPECT().Header().Return(header).AnyTimes() rw.EXPECT().Write(gomock.Any()).AnyTimes() diff --git a/handler/oauth2/flow_authorize_code_auth.go b/handler/oauth2/flow_authorize_code_auth.go index dfdf7edc2..4f0113595 100644 --- a/handler/oauth2/flow_authorize_code_auth.go +++ b/handler/oauth2/flow_authorize_code_auth.go @@ -110,15 +110,14 @@ func (c *AuthorizeExplicitGrantHandler) IssueAuthorizeCode(ctx context.Context, if err := c.CoreStorage.CreateAuthorizeCodeSession(ctx, signature, ar.Sanitize(c.GetSanitationWhiteList())); err != nil { return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) } - if ar.GetRequestForm().Get("response_mode") == "form_post" { - resp.AddForm("code", code) - resp.AddForm("state", ar.GetState()) - resp.AddForm("scope", strings.Join(ar.GetGrantedScopes(), " ")) - } else { - resp.AddQuery("code", code) - resp.AddQuery("state", ar.GetState()) - resp.AddQuery("scope", strings.Join(ar.GetGrantedScopes(), " ")) + + resp.AddParameter("code", code) + resp.AddParameter("state", ar.GetState()) + resp.AddParameter("scope", strings.Join(ar.GetGrantedScopes(), " ")) + if ar.GetResponseMode() == fosite.ResponseModeNone { + ar.SetResponseMode(fosite.ResponseModeQuery) } + ar.SetResponseTypeHandled("code") return nil } diff --git a/handler/oauth2/flow_authorize_code_auth_test.go b/handler/oauth2/flow_authorize_code_auth_test.go index 198208137..0f20224a9 100644 --- a/handler/oauth2/flow_authorize_code_auth_test.go +++ b/handler/oauth2/flow_authorize_code_auth_test.go @@ -122,11 +122,11 @@ func TestAuthorizeCode_HandleAuthorizeEndpointRequest(t *testing.T) { }, description: "should pass", expect: func(t *testing.T, areq *fosite.AuthorizeRequest, aresp *fosite.AuthorizeResponse) { - code := aresp.GetQuery().Get("code") + code := aresp.GetParameters().Get("code") assert.NotEmpty(t, code) - assert.Equal(t, strings.Join(areq.GrantedScope, " "), aresp.GetQuery().Get("scope")) - assert.Equal(t, areq.State, aresp.GetQuery().Get("state")) + assert.Equal(t, strings.Join(areq.GrantedScope, " "), aresp.GetParameters().Get("scope")) + assert.Equal(t, areq.State, aresp.GetParameters().Get("state")) }, }, } { diff --git a/handler/oauth2/flow_authorize_implicit.go b/handler/oauth2/flow_authorize_implicit.go index c6952096c..19c440e11 100644 --- a/handler/oauth2/flow_authorize_implicit.go +++ b/handler/oauth2/flow_authorize_implicit.go @@ -94,18 +94,13 @@ func (c *AuthorizeImplicitGrantTypeHandler) IssueImplicitAccessToken(ctx context if err := c.AccessTokenStorage.CreateAccessTokenSession(ctx, signature, ar.Sanitize([]string{})); err != nil { return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) } - if ar.GetRequestForm().Get("response_mode") == "form_post" { - resp.AddForm("access_token", token) - resp.AddForm("expires_in", strconv.FormatInt(int64(getExpiresIn(ar, fosite.AccessToken, c.AccessTokenLifespan, time.Now().UTC())/time.Second), 10)) - resp.AddForm("token_type", "bearer") - resp.AddForm("state", ar.GetState()) - resp.AddForm("scope", strings.Join(ar.GetGrantedScopes(), " ")) - } else { - resp.AddFragment("access_token", token) - resp.AddFragment("expires_in", strconv.FormatInt(int64(getExpiresIn(ar, fosite.AccessToken, c.AccessTokenLifespan, time.Now().UTC())/time.Second), 10)) - resp.AddFragment("token_type", "bearer") - resp.AddFragment("state", ar.GetState()) - resp.AddFragment("scope", strings.Join(ar.GetGrantedScopes(), " ")) + resp.AddParameter("access_token", token) + resp.AddParameter("expires_in", strconv.FormatInt(int64(getExpiresIn(ar, fosite.AccessToken, c.AccessTokenLifespan, time.Now().UTC())/time.Second), 10)) + resp.AddParameter("token_type", "bearer") + resp.AddParameter("state", ar.GetState()) + resp.AddParameter("scope", strings.Join(ar.GetGrantedScopes(), " ")) + if ar.GetResponseMode() == fosite.ResponseModeNone { + ar.SetResponseMode(fosite.ResponseModeFragment) } ar.SetResponseTypeHandled("token") diff --git a/handler/oauth2/flow_authorize_implicit_test.go b/handler/oauth2/flow_authorize_implicit_test.go index d34c59202..3110565a0 100644 --- a/handler/oauth2/flow_authorize_implicit_test.go +++ b/handler/oauth2/flow_authorize_implicit_test.go @@ -118,11 +118,11 @@ func TestAuthorizeImplicit_EndpointHandler(t *testing.T) { store.EXPECT().CreateAccessTokenSession(nil, "ats", gomock.Eq(areq.Sanitize([]string{}))).AnyTimes().Return(nil) - aresp.EXPECT().AddFragment("access_token", "access.ats") - aresp.EXPECT().AddFragment("expires_in", gomock.Any()) - aresp.EXPECT().AddFragment("token_type", "bearer") - aresp.EXPECT().AddFragment("state", "state") - aresp.EXPECT().AddFragment("scope", "scope") + aresp.EXPECT().AddParameter("access_token", "access.ats") + aresp.EXPECT().AddParameter("expires_in", gomock.Any()) + aresp.EXPECT().AddParameter("token_type", "bearer") + aresp.EXPECT().AddParameter("state", "state") + aresp.EXPECT().AddParameter("scope", "scope") }, expectErr: nil, }, diff --git a/handler/openid/flow_hybrid.go b/handler/openid/flow_hybrid.go index e42b41e5e..18512fdd6 100644 --- a/handler/openid/flow_hybrid.go +++ b/handler/openid/flow_hybrid.go @@ -109,14 +109,15 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context. if err := c.AuthorizeExplicitGrantHandler.CoreStorage.CreateAuthorizeCodeSession(ctx, signature, ar.Sanitize(c.AuthorizeExplicitGrantHandler.GetSanitationWhiteList())); err != nil { return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) } - if ar.GetRequestForm().Get("response_mode") == "form_post" { - resp.AddForm("code", code) - } else { - resp.AddFragment("code", code) + + resp.AddParameter("code", code) + if ar.GetResponseMode() == fosite.ResponseModeNone { + ar.SetResponseMode(fosite.ResponseModeFragment) } + ar.SetResponseTypeHandled("code") - hash, err := c.Enigma.Hash(ctx, []byte(resp.GetCode())) + hash, err := c.Enigma.Hash(ctx, []byte(resp.GetParameters().Get("code"))) if err != nil { return err } @@ -136,25 +137,16 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context. return errors.WithStack(err) } ar.SetResponseTypeHandled("token") - var accessToken string - if ar.GetRequestForm().Get("response_mode") == "form_post" { - accessToken = resp.GetForm().Get("access_token") - } else { - accessToken = resp.GetFragment().Get("access_token") - } - hash, err := c.Enigma.Hash(ctx, []byte(accessToken)) + + hash, err := c.Enigma.Hash(ctx, []byte(resp.GetParameters().Get("access_token"))) if err != nil { return err } claims.AccessTokenHash = base64.RawURLEncoding.EncodeToString([]byte(hash[:c.Enigma.GetSigningMethodLength()/2])) } - if resp.GetFragment().Get("state") == "" { - if ar.GetRequestForm().Get("response_mode") == "form_post" { - resp.AddForm("state", ar.GetState()) - } else { - resp.AddFragment("state", ar.GetState()) - } + if resp.GetParameters().Get("state") == "" { + resp.AddParameter("state", ar.GetState()) } if !ar.GetGrantedScopes().Has("openid") || !ar.GetResponseTypes().Has("id_token") { diff --git a/handler/openid/flow_hybrid_test.go b/handler/openid/flow_hybrid_test.go index 7d0ebeea8..734716bfc 100644 --- a/handler/openid/flow_hybrid_test.go +++ b/handler/openid/flow_hybrid_test.go @@ -249,9 +249,9 @@ func TestHybrid_HandleAuthorizeEndpointRequest(t *testing.T) { return makeOpenIDConnectHybridHandler(fosite.MinParameterEntropy) }, check: func() { - assert.NotEmpty(t, aresp.GetFragment().Get("id_token")) - assert.NotEmpty(t, aresp.GetFragment().Get("code")) - assert.NotEmpty(t, aresp.GetFragment().Get("access_token")) + assert.NotEmpty(t, aresp.GetParameters().Get("id_token")) + assert.NotEmpty(t, aresp.GetParameters().Get("code")) + assert.NotEmpty(t, aresp.GetParameters().Get("access_token")) assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(fosite.AuthorizeCode)) }, }, diff --git a/handler/openid/flow_implicit.go b/handler/openid/flow_implicit.go index 01e94b92d..45cf096bd 100644 --- a/handler/openid/flow_implicit.go +++ b/handler/openid/flow_implicit.go @@ -91,23 +91,17 @@ func (c *OpenIDConnectImplicitHandler) HandleAuthorizeEndpointRequest(ctx contex } ar.SetResponseTypeHandled("token") - var accessToken string - if ar.GetRequestForm().Get("response_mode") == "form_post" { - accessToken = resp.GetForm().Get("access_token") - } else { - accessToken = resp.GetFragment().Get("access_token") - } - hash, err := c.RS256JWTStrategy.Hash(ctx, []byte(accessToken)) + hash, err := c.RS256JWTStrategy.Hash(ctx, []byte(resp.GetParameters().Get("access_token"))) if err != nil { return err } claims.AccessTokenHash = base64.RawURLEncoding.EncodeToString([]byte(hash[:c.RS256JWTStrategy.GetSigningMethodLength()/2])) } else { - if ar.GetRequestForm().Get("response_mode") == "form_post" { - resp.AddForm("state", ar.GetState()) - } else { - resp.AddFragment("state", ar.GetState()) + + resp.AddParameter("state", ar.GetState()) + if ar.GetResponseMode() == fosite.ResponseModeNone { + ar.SetResponseMode(fosite.ResponseModeFragment) } } diff --git a/handler/openid/flow_implicit_test.go b/handler/openid/flow_implicit_test.go index 890dca9b2..03a00b094 100644 --- a/handler/openid/flow_implicit_test.go +++ b/handler/openid/flow_implicit_test.go @@ -204,9 +204,9 @@ func TestImplicit_HandleAuthorizeEndpointRequest(t *testing.T) { return makeOpenIDConnectImplicitHandler(fosite.MinParameterEntropy) }, check: func() { - assert.NotEmpty(t, aresp.GetFragment().Get("id_token")) - assert.NotEmpty(t, aresp.GetFragment().Get("state")) - assert.Empty(t, aresp.GetFragment().Get("access_token")) + assert.NotEmpty(t, aresp.GetParameters().Get("id_token")) + assert.NotEmpty(t, aresp.GetParameters().Get("state")) + assert.Empty(t, aresp.GetParameters().Get("access_token")) }, }, { @@ -216,9 +216,9 @@ func TestImplicit_HandleAuthorizeEndpointRequest(t *testing.T) { return makeOpenIDConnectImplicitHandler(fosite.MinParameterEntropy) }, check: func() { - assert.NotEmpty(t, aresp.GetFragment().Get("id_token")) - assert.NotEmpty(t, aresp.GetFragment().Get("state")) - assert.NotEmpty(t, aresp.GetFragment().Get("access_token")) + assert.NotEmpty(t, aresp.GetParameters().Get("id_token")) + assert.NotEmpty(t, aresp.GetParameters().Get("state")) + assert.NotEmpty(t, aresp.GetParameters().Get("access_token")) }, }, { @@ -229,9 +229,9 @@ func TestImplicit_HandleAuthorizeEndpointRequest(t *testing.T) { return makeOpenIDConnectImplicitHandler(fosite.MinParameterEntropy) }, check: func() { - assert.NotEmpty(t, aresp.GetFragment().Get("id_token")) - assert.NotEmpty(t, aresp.GetFragment().Get("state")) - assert.NotEmpty(t, aresp.GetFragment().Get("access_token")) + assert.NotEmpty(t, aresp.GetParameters().Get("id_token")) + assert.NotEmpty(t, aresp.GetParameters().Get("state")) + assert.NotEmpty(t, aresp.GetParameters().Get("access_token")) }, }, { @@ -241,9 +241,9 @@ func TestImplicit_HandleAuthorizeEndpointRequest(t *testing.T) { return makeOpenIDConnectImplicitHandler(4) }, check: func() { - assert.NotEmpty(t, aresp.GetFragment().Get("id_token")) - assert.NotEmpty(t, aresp.GetFragment().Get("state")) - assert.NotEmpty(t, aresp.GetFragment().Get("access_token")) + assert.NotEmpty(t, aresp.GetParameters().Get("id_token")) + assert.NotEmpty(t, aresp.GetParameters().Get("state")) + assert.NotEmpty(t, aresp.GetParameters().Get("access_token")) }, }, } { diff --git a/handler/openid/helper.go b/handler/openid/helper.go index 8263774e5..321198e71 100644 --- a/handler/openid/helper.go +++ b/handler/openid/helper.go @@ -64,11 +64,7 @@ func (i *IDTokenHandleHelper) IssueImplicitIDToken(ctx context.Context, ar fosit if err != nil { return err } - if ar.GetRequestForm().Get("response_mode") == "form_post" { - resp.AddForm("id_token", token) - } else { - resp.AddFragment("id_token", token) - } + resp.AddParameter("id_token", token) return nil } diff --git a/handler/openid/helper_test.go b/handler/openid/helper_test.go index ed2bdb159..035924c62 100644 --- a/handler/openid/helper_test.go +++ b/handler/openid/helper_test.go @@ -117,7 +117,7 @@ func TestIssueImplicitToken(t *testing.T) { Subject: "peter", }, Headers: &jwt.Headers{}}) - resp.EXPECT().AddFragment("id_token", gomock.Any()) + resp.EXPECT().AddParameter("id_token", gomock.Any()) h := &IDTokenHandleHelper{IDTokenStrategy: strat} err := h.IssueImplicitIDToken(nil, ar, resp) assert.NoError(t, err) diff --git a/handler/pkce/handler_test.go b/handler/pkce/handler_test.go index 27d86388a..2bc7c0656 100644 --- a/handler/pkce/handler_test.go +++ b/handler/pkce/handler_test.go @@ -62,7 +62,7 @@ func TestPKCEHandleAuthorizeEndpointRequest(t *testing.T) { c := &fosite.DefaultClient{} r.Client = c - w.AddQuery("code", "foo") + w.AddParameter("code", "foo") r.Form.Add("code_challenge", "challenge") r.Form.Add("code_challenge_method", "plain") diff --git a/internal/authorize_request.go b/internal/authorize_request.go index ee88913a6..7e620679e 100644 --- a/internal/authorize_request.go +++ b/internal/authorize_request.go @@ -10,7 +10,6 @@ import ( time "time" gomock "github.com/golang/mock/gomock" - fosite "github.com/ory/fosite" ) @@ -189,6 +188,20 @@ func (mr *MockAuthorizeRequesterMockRecorder) GetRequestedScopes() *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRequestedScopes", reflect.TypeOf((*MockAuthorizeRequester)(nil).GetRequestedScopes)) } +// GetResponseMode mocks base method +func (m *MockAuthorizeRequester) GetResponseMode() fosite.ResponseModeType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetResponseMode") + ret0, _ := ret[0].(fosite.ResponseModeType) + return ret0 +} + +// GetResponseMode indicates an expected call of GetResponseMode +func (mr *MockAuthorizeRequesterMockRecorder) GetResponseMode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponseMode", reflect.TypeOf((*MockAuthorizeRequester)(nil).GetResponseMode)) +} + // GetResponseTypes mocks base method func (m *MockAuthorizeRequester) GetResponseTypes() fosite.Arguments { m.ctrl.T.Helper() @@ -331,6 +344,18 @@ func (mr *MockAuthorizeRequesterMockRecorder) SetRequestedScopes(arg0 interface{ return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetRequestedScopes", reflect.TypeOf((*MockAuthorizeRequester)(nil).SetRequestedScopes), arg0) } +// SetResponseMode mocks base method +func (m *MockAuthorizeRequester) SetResponseMode(arg0 fosite.ResponseModeType) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetResponseMode", arg0) +} + +// SetResponseMode indicates an expected call of SetResponseMode +func (mr *MockAuthorizeRequesterMockRecorder) SetResponseMode(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetResponseMode", reflect.TypeOf((*MockAuthorizeRequester)(nil).SetResponseMode), arg0) +} + // SetResponseTypeHandled mocks base method func (m *MockAuthorizeRequester) SetResponseTypeHandled(arg0 string) { m.ctrl.T.Helper() diff --git a/internal/authorize_response.go b/internal/authorize_response.go index 92e4c2ec0..09bfe76ce 100644 --- a/internal/authorize_response.go +++ b/internal/authorize_response.go @@ -35,30 +35,6 @@ func (m *MockAuthorizeResponder) EXPECT() *MockAuthorizeResponderMockRecorder { return m.recorder } -// AddForm mocks base method -func (m *MockAuthorizeResponder) AddForm(arg0, arg1 string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddForm", arg0, arg1) -} - -// AddForm indicates an expected call of AddForm -func (mr *MockAuthorizeResponderMockRecorder) AddForm(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddForm", reflect.TypeOf((*MockAuthorizeResponder)(nil).AddForm), arg0, arg1) -} - -// AddFragment mocks base method -func (m *MockAuthorizeResponder) AddFragment(arg0, arg1 string) { - m.ctrl.T.Helper() - m.ctrl.Call(m, "AddFragment", arg0, arg1) -} - -// AddFragment indicates an expected call of AddFragment -func (mr *MockAuthorizeResponderMockRecorder) AddFragment(arg0, arg1 interface{}) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFragment", reflect.TypeOf((*MockAuthorizeResponder)(nil).AddFragment), arg0, arg1) -} - // AddHeader mocks base method func (m *MockAuthorizeResponder) AddHeader(arg0, arg1 string) { m.ctrl.T.Helper() @@ -71,16 +47,16 @@ func (mr *MockAuthorizeResponderMockRecorder) AddHeader(arg0, arg1 interface{}) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddHeader", reflect.TypeOf((*MockAuthorizeResponder)(nil).AddHeader), arg0, arg1) } -// AddQuery mocks base method -func (m *MockAuthorizeResponder) AddQuery(arg0, arg1 string) { +// AddParameter mocks base method +func (m *MockAuthorizeResponder) AddParameter(arg0, arg1 string) { m.ctrl.T.Helper() - m.ctrl.Call(m, "AddQuery", arg0, arg1) + m.ctrl.Call(m, "AddParameter", arg0, arg1) } -// AddQuery indicates an expected call of AddQuery -func (mr *MockAuthorizeResponderMockRecorder) AddQuery(arg0, arg1 interface{}) *gomock.Call { +// AddParameter indicates an expected call of AddParameter +func (mr *MockAuthorizeResponderMockRecorder) AddParameter(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddQuery", reflect.TypeOf((*MockAuthorizeResponder)(nil).AddQuery), arg0, arg1) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddParameter", reflect.TypeOf((*MockAuthorizeResponder)(nil).AddParameter), arg0, arg1) } // GetCode mocks base method @@ -97,34 +73,6 @@ func (mr *MockAuthorizeResponderMockRecorder) GetCode() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetCode", reflect.TypeOf((*MockAuthorizeResponder)(nil).GetCode)) } -// GetForm mocks base method -func (m *MockAuthorizeResponder) GetForm() url.Values { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetForm") - ret0, _ := ret[0].(url.Values) - return ret0 -} - -// GetForm indicates an expected call of GetForm -func (mr *MockAuthorizeResponderMockRecorder) GetForm() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetForm", reflect.TypeOf((*MockAuthorizeResponder)(nil).GetForm)) -} - -// GetFragment mocks base method -func (m *MockAuthorizeResponder) GetFragment() url.Values { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetFragment") - ret0, _ := ret[0].(url.Values) - return ret0 -} - -// GetFragment indicates an expected call of GetFragment -func (mr *MockAuthorizeResponderMockRecorder) GetFragment() *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetFragment", reflect.TypeOf((*MockAuthorizeResponder)(nil).GetFragment)) -} - // GetHeader mocks base method func (m *MockAuthorizeResponder) GetHeader() http.Header { m.ctrl.T.Helper() @@ -139,16 +87,16 @@ func (mr *MockAuthorizeResponderMockRecorder) GetHeader() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetHeader", reflect.TypeOf((*MockAuthorizeResponder)(nil).GetHeader)) } -// GetQuery mocks base method -func (m *MockAuthorizeResponder) GetQuery() url.Values { +// GetParameters mocks base method +func (m *MockAuthorizeResponder) GetParameters() url.Values { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetQuery") + ret := m.ctrl.Call(m, "GetParameters") ret0, _ := ret[0].(url.Values) return ret0 } -// GetQuery indicates an expected call of GetQuery -func (mr *MockAuthorizeResponderMockRecorder) GetQuery() *gomock.Call { +// GetParameters indicates an expected call of GetParameters +func (mr *MockAuthorizeResponderMockRecorder) GetParameters() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetQuery", reflect.TypeOf((*MockAuthorizeResponder)(nil).GetQuery)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetParameters", reflect.TypeOf((*MockAuthorizeResponder)(nil).GetParameters)) } diff --git a/oauth2.go b/oauth2.go index b0aabcee5..05eafb168 100644 --- a/oauth2.go +++ b/oauth2.go @@ -268,6 +268,12 @@ type AuthorizeRequester interface { // GetState returns the request's state. GetState() (state string) + //GetResponseMode returns response_mode of the authorization request + GetResponseMode() ResponseModeType + + //SetResponseMode sets response mode of the authorization request + SetResponseMode(responseMode ResponseModeType) + Requester } @@ -310,21 +316,9 @@ type AuthorizeResponder interface { // AddHeader adds an header key value pair to the response AddHeader(key, value string) - // GetQuery returns the response's query - GetQuery() (query url.Values) - - // AddQuery adds an url query key value pair to the response - AddQuery(key, value string) - - // GetHeader returns the response's url fragments - GetFragment() (fragment url.Values) - - // AddHeader adds a key value pair to the response's url fragment - AddFragment(key, value string) - - // GetForm returns form with parameters - GetForm() (form url.Values) + // GetParameters returns the response's parameters + GetParameters() (query url.Values) - // AddForm adds a key value pair to the form post request - AddForm(key, value string) + // AddParameter adds key value pair to the response + AddParameter(key, value string) } From ec3e000dc0efd5b41a3c6868983873cbcf88caf5 Mon Sep 17 00:00:00 2001 From: Ajanthan Balachandran Date: Tue, 20 Oct 2020 12:44:30 -0700 Subject: [PATCH 03/22] refactor: Reusing response mode enum in error handling --- authorize_error.go | 2 +- authorize_error_test.go | 24 ++++++++++++------------ authorize_request_handler.go | 8 ++++---- integration/authorize_form_post_test.go | 17 ----------------- 4 files changed, 17 insertions(+), 34 deletions(-) diff --git a/authorize_error.go b/authorize_error.go index c025e08d1..872a57ff3 100644 --- a/authorize_error.go +++ b/authorize_error.go @@ -66,7 +66,7 @@ func (f *Fosite) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequest query.Add("state", ar.GetState()) var redirectURIString string - if ar.GetRequestForm().Get("response_mode") == "form_post" { + if ar.GetResponseMode() == ResponseModePost { rw.Header().Add("Content-Type", "text/html;charset=UTF-8") WriteAuthorizeFormPostResponse(redirectURI.String(), query, rw) return diff --git a/authorize_error_test.go b/authorize_error_test.go index 34dc125c2..1473ec883 100644 --- a/authorize_error_test.go +++ b/authorize_error_test.go @@ -88,7 +88,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code"})) - req.EXPECT().GetRequestForm().Return(url.Values{}) + req.EXPECT().GetResponseMode().Return(ResponseModeNone) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -107,7 +107,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code"})) - req.EXPECT().GetRequestForm().Return(url.Values{}) + req.EXPECT().GetResponseMode().Return(ResponseModeNone) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -126,7 +126,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code"})) - req.EXPECT().GetRequestForm().Return(url.Values{}) + req.EXPECT().GetResponseMode().Return(ResponseModeNone) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -145,7 +145,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"foobar"})) - req.EXPECT().GetRequestForm().Return(url.Values{}) + req.EXPECT().GetResponseMode().Return(ResponseModeNone) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -164,7 +164,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"})) - req.EXPECT().GetRequestForm().Return(url.Values{}) + req.EXPECT().GetResponseMode().Return(ResponseModeNone) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -183,7 +183,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"})) - req.EXPECT().GetRequestForm().Return(url.Values{}) + req.EXPECT().GetResponseMode().Return(ResponseModeNone) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -202,7 +202,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code", "token"})) - req.EXPECT().GetRequestForm().Return(url.Values{}) + req.EXPECT().GetResponseMode().Return(ResponseModeNone) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -221,7 +221,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code", "token"})) - req.EXPECT().GetRequestForm().Return(url.Values{}) + req.EXPECT().GetResponseMode().Return(ResponseModeNone) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -241,7 +241,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code", "token"})) - req.EXPECT().GetRequestForm().Return(url.Values{}) + req.EXPECT().GetResponseMode().Return(ResponseModeNone) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -261,7 +261,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"id_token"})) - req.EXPECT().GetRequestForm().Return(url.Values{}) + req.EXPECT().GetResponseMode().Return(ResponseModeNone) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -281,7 +281,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"})) - req.EXPECT().GetRequestForm().Return(url.Values{}) + req.EXPECT().GetResponseMode().Return(ResponseModeNone) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -301,7 +301,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"})) - req.EXPECT().GetRequestForm().Return(url.Values{"response_mode": {"form_post"}}) + req.EXPECT().GetResponseMode().Return(ResponseModePost) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().Write(gomock.Any()).AnyTimes() }, diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 0835eaf96..91430bc5b 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -245,6 +245,10 @@ func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth state := request.Form.Get("state") request.State = state + if err := f.validateResponseMode(r, request); err != nil { + return request, err + } + client, err := f.Store.GetClient(ctx, request.GetRequestForm().Get("client_id")) if err != nil { return request, errors.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 Client does not exist.").WithCause(err).WithDebug(err.Error())) @@ -275,10 +279,6 @@ func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth return request, err } - if err := f.validateResponseMode(r, request); err != nil { - return request, err - } - // rfc6819 4.4.1.8. Threat: CSRF Attack against redirect-uri // The "state" parameter should be used to link the authorization // request with the redirect URI used to deliver the access token (Section 5.3.5). diff --git a/integration/authorize_form_post_test.go b/integration/authorize_form_post_test.go index e151b9221..5efeec2a8 100644 --- a/integration/authorize_form_post_test.go +++ b/integration/authorize_form_post_test.go @@ -72,23 +72,6 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { check func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) responseType string }{ - //{ - // description: "should fail because of audience", - // responseType: []goauth.AuthCodeOption{goauth.SetAuthURLParam("audience", "https://www.ory.sh/not-api")}, - // setup: func() { - // state = "12345678901234567890" - // }, - // authStatusCode: http.StatusNotAcceptable, - //}, - //{ - // description: "should fail because of scope", - // responseType: []goauth.AuthCodeOption{}, - // setup: func() { - // oauthClient.Scopes = []string{"not-exist"} - // state = "12345678901234567890" - // }, - // authStatusCode: http.StatusNotAcceptable, - //}, { description: "implicit grant test with form_post", responseType: "token", From 0ebd04d0f7187727e21ce6eb16261c29bf4629c4 Mon Sep 17 00:00:00 2001 From: Ajanthan Balachandran Date: Tue, 20 Oct 2020 21:08:08 -0700 Subject: [PATCH 04/22] fix: bulding authrize error message according to responseMode --- authorize_error.go | 4 +- authorize_error_test.go | 24 ++++---- authorize_request.go | 2 +- handler/oauth2/flow_authorize_code_auth.go | 6 +- .../oauth2/flow_authorize_code_auth_test.go | 29 +++++++++ handler/oauth2/flow_authorize_implicit.go | 7 +-- .../oauth2/flow_authorize_implicit_test.go | 59 +++++++++++++++---- handler/openid/flow_hybrid.go | 8 +-- handler/openid/flow_hybrid_test.go | 13 ++++ handler/openid/flow_implicit.go | 7 ++- handler/openid/flow_implicit_test.go | 15 +++++ 11 files changed, 133 insertions(+), 41 deletions(-) diff --git a/authorize_error.go b/authorize_error.go index 872a57ff3..cc7a328fe 100644 --- a/authorize_error.go +++ b/authorize_error.go @@ -25,8 +25,6 @@ import ( "encoding/json" "fmt" "net/http" - - "github.com/pkg/errors" ) func (f *Fosite) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequester, err error) { @@ -70,7 +68,7 @@ func (f *Fosite) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequest rw.Header().Add("Content-Type", "text/html;charset=UTF-8") WriteAuthorizeFormPostResponse(redirectURI.String(), query, rw) return - } else if !(len(ar.GetResponseTypes()) == 0 || ar.GetResponseTypes().ExactOne("code")) && !errors.Is(err, ErrUnsupportedResponseType) { + } else if ar.GetResponseMode() == ResponseModeFragment { redirectURIString = redirectURI.String() + "#" + query.Encode() } else { for key, values := range redirectURI.Query() { diff --git a/authorize_error_test.go b/authorize_error_test.go index 1473ec883..b8446e826 100644 --- a/authorize_error_test.go +++ b/authorize_error_test.go @@ -88,7 +88,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code"})) - req.EXPECT().GetResponseMode().Return(ResponseModeNone) + req.EXPECT().GetResponseMode().Return(ResponseModeQuery).Times(2) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -107,7 +107,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code"})) - req.EXPECT().GetResponseMode().Return(ResponseModeNone) + req.EXPECT().GetResponseMode().Return(ResponseModeNone).Times(2) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -126,7 +126,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code"})) - req.EXPECT().GetResponseMode().Return(ResponseModeNone) + req.EXPECT().GetResponseMode().Return(ResponseModeQuery).Times(2) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -145,7 +145,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"foobar"})) - req.EXPECT().GetResponseMode().Return(ResponseModeNone) + req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -164,7 +164,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"})) - req.EXPECT().GetResponseMode().Return(ResponseModeNone) + req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -183,7 +183,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"})) - req.EXPECT().GetResponseMode().Return(ResponseModeNone) + req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -202,7 +202,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code", "token"})) - req.EXPECT().GetResponseMode().Return(ResponseModeNone) + req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -221,7 +221,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code", "token"})) - req.EXPECT().GetResponseMode().Return(ResponseModeNone) + req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -241,7 +241,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code", "token"})) - req.EXPECT().GetResponseMode().Return(ResponseModeNone) + req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -261,7 +261,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"id_token"})) - req.EXPECT().GetResponseMode().Return(ResponseModeNone) + req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -281,7 +281,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"})) - req.EXPECT().GetResponseMode().Return(ResponseModeNone) + req.EXPECT().GetResponseMode().Return(ResponseModeFragment).Times(2) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, @@ -301,7 +301,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"})) - req.EXPECT().GetResponseMode().Return(ResponseModePost) + req.EXPECT().GetResponseMode().Return(ResponseModePost).Times(1) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().Write(gomock.Any()).AnyTimes() }, diff --git a/authorize_request.go b/authorize_request.go index 74f2af67c..b41e2f5f6 100644 --- a/authorize_request.go +++ b/authorize_request.go @@ -51,7 +51,7 @@ func NewAuthorizeRequest() *AuthorizeRequest { RedirectURI: &url.URL{}, HandledResponseTypes: Arguments{}, Request: *NewRequest(), - ResponseMode: ResponseModeQuery, + ResponseMode: ResponseModeNone, } } diff --git a/handler/oauth2/flow_authorize_code_auth.go b/handler/oauth2/flow_authorize_code_auth.go index 4f0113595..55f2f626c 100644 --- a/handler/oauth2/flow_authorize_code_auth.go +++ b/handler/oauth2/flow_authorize_code_auth.go @@ -77,6 +77,9 @@ func (c *AuthorizeExplicitGrantHandler) HandleAuthorizeEndpointRequest(ctx conte return nil } + if ar.GetResponseMode() == fosite.ResponseModeNone { + ar.SetResponseMode(fosite.ResponseModeQuery) + } // Disabled because this is already handled at the authorize_request_handler // if !ar.GetClient().GetResponseTypes().Has("code") { // return errors.WithStack(fosite.ErrInvalidGrant) @@ -114,9 +117,6 @@ func (c *AuthorizeExplicitGrantHandler) IssueAuthorizeCode(ctx context.Context, resp.AddParameter("code", code) resp.AddParameter("state", ar.GetState()) resp.AddParameter("scope", strings.Join(ar.GetGrantedScopes(), " ")) - if ar.GetResponseMode() == fosite.ResponseModeNone { - ar.SetResponseMode(fosite.ResponseModeQuery) - } ar.SetResponseTypeHandled("code") return nil diff --git a/handler/oauth2/flow_authorize_code_auth_test.go b/handler/oauth2/flow_authorize_code_auth_test.go index 0f20224a9..63d9916ee 100644 --- a/handler/oauth2/flow_authorize_code_auth_test.go +++ b/handler/oauth2/flow_authorize_code_auth_test.go @@ -129,6 +129,35 @@ func TestAuthorizeCode_HandleAuthorizeEndpointRequest(t *testing.T) { assert.Equal(t, areq.State, aresp.GetParameters().Get("state")) }, }, + { + areq: &fosite.AuthorizeRequest{ + ResponseTypes: fosite.Arguments{"code"}, + Request: fosite.Request{ + Client: &fosite.DefaultClient{ + ResponseTypes: fosite.Arguments{"code"}, + RedirectURIs: []string{"https://asdf.de/cb"}, + Audience: []string{"https://www.ory.sh/api"}, + }, + RequestedAudience: []string{"https://www.ory.sh/api"}, + GrantedScope: fosite.Arguments{"a", "b"}, + Session: &fosite.DefaultSession{ + ExpiresAt: map[fosite.TokenType]time.Time{fosite.AccessToken: time.Now().UTC().Add(time.Hour)}, + }, + RequestedAt: time.Now().UTC(), + }, + State: "superstate", + RedirectURI: parseUrl("https://asdf.de/cb"), + }, + description: "Default responseMode check", + expect: func(t *testing.T, areq *fosite.AuthorizeRequest, aresp *fosite.AuthorizeResponse) { + code := aresp.GetParameters().Get("code") + assert.NotEmpty(t, code) + + assert.Equal(t, strings.Join(areq.GrantedScope, " "), aresp.GetParameters().Get("scope")) + assert.Equal(t, areq.State, aresp.GetParameters().Get("state")) + assert.Equal(t, fosite.ResponseModeQuery, areq.GetResponseMode()) + }, + }, } { t.Run("case="+c.description, func(t *testing.T) { aresp := fosite.NewAuthorizeResponse() diff --git a/handler/oauth2/flow_authorize_implicit.go b/handler/oauth2/flow_authorize_implicit.go index 19c440e11..cd32fe0cf 100644 --- a/handler/oauth2/flow_authorize_implicit.go +++ b/handler/oauth2/flow_authorize_implicit.go @@ -52,7 +52,9 @@ func (c *AuthorizeImplicitGrantTypeHandler) HandleAuthorizeEndpointRequest(ctx c if !ar.GetResponseTypes().ExactOne("token") { return nil } - + if ar.GetResponseMode() == fosite.ResponseModeNone { + ar.SetResponseMode(fosite.ResponseModeFragment) + } // Disabled because this is already handled at the authorize_request_handler // if !ar.GetClient().GetResponseTypes().Has("token") { // return errors.WithStack(fosite.ErrInvalidGrant.WithDebug("The client is not allowed to use response type token")) @@ -99,9 +101,6 @@ func (c *AuthorizeImplicitGrantTypeHandler) IssueImplicitAccessToken(ctx context resp.AddParameter("token_type", "bearer") resp.AddParameter("state", ar.GetState()) resp.AddParameter("scope", strings.Join(ar.GetGrantedScopes(), " ")) - if ar.GetResponseMode() == fosite.ResponseModeNone { - ar.SetResponseMode(fosite.ResponseModeFragment) - } ar.SetResponseTypeHandled("token") diff --git a/handler/oauth2/flow_authorize_implicit_test.go b/handler/oauth2/flow_authorize_implicit_test.go index 3110565a0..6740bc7be 100644 --- a/handler/oauth2/flow_authorize_implicit_test.go +++ b/handler/oauth2/flow_authorize_implicit_test.go @@ -26,6 +26,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/golang/mock/gomock" "github.com/pkg/errors" "github.com/stretchr/testify/require" @@ -36,21 +38,12 @@ import ( func TestAuthorizeImplicit_EndpointHandler(t *testing.T) { ctrl := gomock.NewController(t) - store := internal.NewMockAccessTokenStorage(ctrl) - chgen := internal.NewMockAccessTokenStrategy(ctrl) - aresp := internal.NewMockAuthorizeResponder(ctrl) defer ctrl.Finish() areq := fosite.NewAuthorizeRequest() areq.Session = new(fosite.DefaultSession) + h, store, chgen, aresp := makeAuthorizeImplicitGrantTypeHandler(ctrl) - h := AuthorizeImplicitGrantTypeHandler{ - AccessTokenStorage: store, - AccessTokenStrategy: chgen, - AccessTokenLifespan: time.Hour, - ScopeStrategy: fosite.HierarchicScopeStrategy, - AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, - } for k, c := range []struct { description string setup func() @@ -138,3 +131,49 @@ func TestAuthorizeImplicit_EndpointHandler(t *testing.T) { }) } } +func makeAuthorizeImplicitGrantTypeHandler(ctrl *gomock.Controller) (AuthorizeImplicitGrantTypeHandler, + *internal.MockAccessTokenStorage, *internal.MockAccessTokenStrategy, *internal.MockAuthorizeResponder) { + store := internal.NewMockAccessTokenStorage(ctrl) + chgen := internal.NewMockAccessTokenStrategy(ctrl) + aresp := internal.NewMockAuthorizeResponder(ctrl) + + h := AuthorizeImplicitGrantTypeHandler{ + AccessTokenStorage: store, + AccessTokenStrategy: chgen, + AccessTokenLifespan: time.Hour, + ScopeStrategy: fosite.HierarchicScopeStrategy, + AudienceMatchingStrategy: fosite.DefaultAudienceMatchingStrategy, + } + + return h, store, chgen, aresp +} + +func TestDefaultResponseMode_AuthorizeImplicit_EndpointHandler(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + areq := fosite.NewAuthorizeRequest() + areq.Session = new(fosite.DefaultSession) + h, store, chgen, aresp := makeAuthorizeImplicitGrantTypeHandler(ctrl) + + areq.State = "state" + areq.GrantedScope = fosite.Arguments{"scope"} + areq.ResponseTypes = fosite.Arguments{"token"} + areq.Client = &fosite.DefaultClient{ + GrantTypes: fosite.Arguments{"implicit"}, + ResponseTypes: fosite.Arguments{"token"}, + } + + store.EXPECT().CreateAccessTokenSession(nil, "ats", gomock.Eq(areq.Sanitize([]string{}))).AnyTimes().Return(nil) + + aresp.EXPECT().AddParameter("access_token", "access.ats") + aresp.EXPECT().AddParameter("expires_in", gomock.Any()) + aresp.EXPECT().AddParameter("token_type", "bearer") + aresp.EXPECT().AddParameter("state", "state") + aresp.EXPECT().AddParameter("scope", "scope") + chgen.EXPECT().GenerateAccessToken(nil, areq).AnyTimes().Return("access.ats", "ats", nil) + + err := h.HandleAuthorizeEndpointRequest(nil, areq, aresp) + assert.NoError(t, err) + assert.Equal(t, fosite.ResponseModeFragment, areq.GetResponseMode()) +} diff --git a/handler/openid/flow_hybrid.go b/handler/openid/flow_hybrid.go index 18512fdd6..92c95aaeb 100644 --- a/handler/openid/flow_hybrid.go +++ b/handler/openid/flow_hybrid.go @@ -54,7 +54,9 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context. if !(ar.GetResponseTypes().Matches("token", "id_token", "code") || ar.GetResponseTypes().Matches("token", "code") || ar.GetResponseTypes().Matches("id_token", "code")) { return nil } - + if ar.GetResponseMode() == fosite.ResponseModeNone { + ar.SetResponseMode(fosite.ResponseModeFragment) + } // Disabled because this is already handled at the authorize_request_handler //if ar.GetResponseTypes().Matches("token") && !ar.GetClient().GetResponseTypes().Has("token") { // return errors.WithStack(fosite.ErrInvalidGrant.WithDebug("The client is not allowed to use the token response type")) @@ -111,10 +113,6 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context. } resp.AddParameter("code", code) - if ar.GetResponseMode() == fosite.ResponseModeNone { - ar.SetResponseMode(fosite.ResponseModeFragment) - } - ar.SetResponseTypeHandled("code") hash, err := c.Enigma.Hash(ctx, []byte(resp.GetParameters().Get("code"))) diff --git a/handler/openid/flow_hybrid_test.go b/handler/openid/flow_hybrid_test.go index 734716bfc..de78f5293 100644 --- a/handler/openid/flow_hybrid_test.go +++ b/handler/openid/flow_hybrid_test.go @@ -255,6 +255,19 @@ func TestHybrid_HandleAuthorizeEndpointRequest(t *testing.T) { assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(fosite.AuthorizeCode)) }, }, + { + description: "Default responseMode check", + setup: func() OpenIDConnectHybridHandler { + return makeOpenIDConnectHybridHandler(fosite.MinParameterEntropy) + }, + check: func() { + assert.NotEmpty(t, aresp.GetParameters().Get("id_token")) + assert.NotEmpty(t, aresp.GetParameters().Get("code")) + assert.NotEmpty(t, aresp.GetParameters().Get("access_token")) + assert.Equal(t, fosite.ResponseModeFragment, areq.GetResponseMode()) + assert.Equal(t, time.Now().Add(time.Hour).UTC().Round(time.Second), areq.GetSession().GetExpiresAt(fosite.AuthorizeCode)) + }, + }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { h := c.setup() diff --git a/handler/openid/flow_implicit.go b/handler/openid/flow_implicit.go index 45cf096bd..d59f72a66 100644 --- a/handler/openid/flow_implicit.go +++ b/handler/openid/flow_implicit.go @@ -51,6 +51,10 @@ func (c *OpenIDConnectImplicitHandler) HandleAuthorizeEndpointRequest(ctx contex return nil } + if ar.GetResponseMode() == fosite.ResponseModeNone { + ar.SetResponseMode(fosite.ResponseModeFragment) + } + if !ar.GetClient().GetGrantTypes().Has("implicit") { return errors.WithStack(fosite.ErrInvalidGrant.WithHint("The OAuth 2.0 Client is not allowed to use the authorization grant \"implicit\".")) } @@ -100,9 +104,6 @@ func (c *OpenIDConnectImplicitHandler) HandleAuthorizeEndpointRequest(ctx contex } else { resp.AddParameter("state", ar.GetState()) - if ar.GetResponseMode() == fosite.ResponseModeNone { - ar.SetResponseMode(fosite.ResponseModeFragment) - } } if err := c.IssueImplicitIDToken(ctx, ar, resp); err != nil { diff --git a/handler/openid/flow_implicit_test.go b/handler/openid/flow_implicit_test.go index 03a00b094..b9d870563 100644 --- a/handler/openid/flow_implicit_test.go +++ b/handler/openid/flow_implicit_test.go @@ -246,6 +246,21 @@ func TestImplicit_HandleAuthorizeEndpointRequest(t *testing.T) { assert.NotEmpty(t, aresp.GetParameters().Get("access_token")) }, }, + { + description: "default responseMode check", + setup: func() OpenIDConnectImplicitHandler { + areq.Form.Set("nonce", "some-random-foo-nonce-wow") + areq.ResponseTypes = fosite.Arguments{"id_token", "token"} + areq.RequestedScope = fosite.Arguments{"fosite", "openid"} + return makeOpenIDConnectImplicitHandler(fosite.MinParameterEntropy) + }, + check: func() { + assert.NotEmpty(t, aresp.GetParameters().Get("id_token")) + assert.NotEmpty(t, aresp.GetParameters().Get("state")) + assert.NotEmpty(t, aresp.GetParameters().Get("access_token")) + assert.Equal(t, fosite.ResponseModeFragment, areq.GetResponseMode()) + }, + }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { h := c.setup() From f6dea638966cb2ca2ebd0af91b045ffa289e55db Mon Sep 17 00:00:00 2001 From: Ajanthan Balachandran Date: Wed, 28 Oct 2020 09:38:16 -0700 Subject: [PATCH 05/22] refactor: Moving ParseFormPostResponse test helper to internal package --- authorize_helper.go | 97 ------------------ authorize_helper_test.go | 77 ++++++++------- integration/authorize_form_post_test.go | 20 ++-- internal/test_helpers.go | 126 ++++++++++++++++++++++++ 4 files changed, 176 insertions(+), 144 deletions(-) create mode 100644 internal/test_helpers.go diff --git a/authorize_helper.go b/authorize_helper.go index 6b6a98ee5..f65a00f31 100644 --- a/authorize_helper.go +++ b/authorize_helper.go @@ -26,12 +26,7 @@ import ( "io" "net/url" "regexp" - "strconv" "strings" - "time" - - "golang.org/x/net/html" - goauth "golang.org/x/oauth2" "github.com/asaskevich/govalidator" "github.com/pkg/errors" @@ -212,95 +207,3 @@ func WriteAuthorizeFormPostResponse(redirectURL string, parameters url.Values, r Parameters: parameters, }) } -func ParseFormPostResponse(redirectURL string, resp io.ReadCloser) (authorizationCode, stateFromServer, iDToken string, token goauth.Token, rFC6749Error RFC6749Error, err error) { - - token = goauth.Token{} - rFC6749Error = RFC6749Error{} - - doc, err := html.Parse(resp) - if err != nil { - return "", "", "", token, rFC6749Error, err - } - //doc>html>body - body := findBody(doc.FirstChild.FirstChild) - if body.Data != "body" { - return "", "", "", token, rFC6749Error, errors.New("Malformed html") - } - htmlEvent := body.Attr[0].Key - if htmlEvent != "onload" { - return "", "", "", token, rFC6749Error, errors.New("onload event is missing") - } - onLoadFunc := body.Attr[0].Val - if onLoadFunc != "javascript:document.forms[0].submit()" { - return "", "", "", token, rFC6749Error, errors.New("onload function is missing") - } - form := getNextNoneTextNode(body.FirstChild) - if form.Data != "form" { - return "", "", "", token, rFC6749Error, errors.New("html form is missing") - } - for _, attr := range form.Attr { - if attr.Key == "method" { - if attr.Val != "post" { - return "", "", "", token, rFC6749Error, errors.New("html form post method is missing") - } - } else { - if attr.Val != redirectURL { - return "", "", "", token, rFC6749Error, errors.New("html form post url is wrong") - } - } - } - - for node := getNextNoneTextNode(form.FirstChild); node != nil; node = getNextNoneTextNode(node.NextSibling) { - var k, v string - for _, attr := range node.Attr { - if attr.Key == "name" { - k = attr.Val - } else if attr.Key == "value" { - v = attr.Val - } - - } - switch k { - case "state": - stateFromServer = v - case "code": - authorizationCode = v - case "expires_in": - expires, err := strconv.Atoi(v) - if err != nil { - return "", "", "", token, rFC6749Error, err - } - token.Expiry = time.Now().UTC().Add(time.Duration(expires) * time.Second) - case "access_token": - token.AccessToken = v - case "token_type": - token.TokenType = v - case "refresh_token": - token.RefreshToken = v - case "error": - rFC6749Error.Name = v - case "error_description": - rFC6749Error.Description = v - case "id_token": - iDToken = v - } - } - return -} - -func getNextNoneTextNode(node *html.Node) *html.Node { - nextNode := node.NextSibling - if nextNode != nil && nextNode.Type == html.TextNode { - nextNode = getNextNoneTextNode(node.NextSibling) - } - return nextNode -} -func findBody(node *html.Node) *html.Node { - if node != nil { - if node.Data == "body" { - return node - } - return findBody(node.NextSibling) - } - return nil -} diff --git a/authorize_helper_test.go b/authorize_helper_test.go index a36e4e0f9..27078d1b6 100644 --- a/authorize_helper_test.go +++ b/authorize_helper_test.go @@ -19,7 +19,7 @@ * */ -package fosite +package fosite_test import ( "bytes" @@ -27,6 +27,9 @@ import ( "net/url" "testing" + "github.com/ory/fosite" + "github.com/ory/fosite/internal" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -45,7 +48,7 @@ func TestIsLocalhost(t *testing.T) { {expect: true, rawurl: "https://test.localhost"}, } { u, _ := url.Parse(c.rawurl) - assert.Equal(t, c.expect, IsLocalhost(u), "case %d", k) + assert.Equal(t, c.expect, fosite.IsLocalhost(u), "case %d", k) } } @@ -62,172 +65,172 @@ func TestIsLocalhost(t *testing.T) { // of pre-registered redirect URIs (see Section 5.2.3.5). func TestDoesClientWhiteListRedirect(t *testing.T) { for k, c := range []struct { - client Client + client fosite.Client url string isError bool expected string }{ { - client: &DefaultClient{RedirectURIs: []string{""}}, + client: &fosite.DefaultClient{RedirectURIs: []string{""}}, url: "https://foo.com/cb", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"wta://auth"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"wta://auth"}}, url: "wta://auth", expected: "wta://auth", isError: false, }, { - client: &DefaultClient{RedirectURIs: []string{"wta:///auth"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"wta:///auth"}}, url: "wta:///auth", expected: "wta:///auth", isError: false, }, { - client: &DefaultClient{RedirectURIs: []string{"wta://foo/auth"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"wta://foo/auth"}}, url: "wta://foo/auth", expected: "wta://foo/auth", isError: false, }, { - client: &DefaultClient{RedirectURIs: []string{"https://bar.com/cb"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"https://bar.com/cb"}}, url: "https://foo.com/cb", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"https://bar.com/cb"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"https://bar.com/cb"}}, url: "", isError: false, expected: "https://bar.com/cb", }, { - client: &DefaultClient{RedirectURIs: []string{""}}, + client: &fosite.DefaultClient{RedirectURIs: []string{""}}, url: "", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"https://bar.com/cb"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"https://bar.com/cb"}}, url: "https://bar.com/cb", isError: false, expected: "https://bar.com/cb", }, { - client: &DefaultClient{RedirectURIs: []string{"https://bar.com/cb"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"https://bar.com/cb"}}, url: "https://bar.com/cb123", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"http://[::1]"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://[::1]"}}, url: "http://[::1]:1024", expected: "http://[::1]:1024", isError: false, }, { - client: &DefaultClient{RedirectURIs: []string{"http://[::1]"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://[::1]"}}, url: "http://[::1]:1024/cb", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"http://[::1]/cb"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://[::1]/cb"}}, url: "http://[::1]:1024/cb", expected: "http://[::1]:1024/cb", isError: false, }, { - client: &DefaultClient{RedirectURIs: []string{"http://[::1]"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://[::1]"}}, url: "http://foo.bar/bar", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"http://127.0.0.1"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://127.0.0.1"}}, url: "http://127.0.0.1:1024", expected: "http://127.0.0.1:1024", isError: false, }, { - client: &DefaultClient{RedirectURIs: []string{"http://127.0.0.1/cb"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://127.0.0.1/cb"}}, url: "http://127.0.0.1:64000/cb", expected: "http://127.0.0.1:64000/cb", isError: false, }, { - client: &DefaultClient{RedirectURIs: []string{"http://127.0.0.1"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://127.0.0.1"}}, url: "http://127.0.0.1:64000/cb", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"http://127.0.0.1"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://127.0.0.1"}}, url: "http://127.0.0.1", expected: "http://127.0.0.1", isError: false, }, { - client: &DefaultClient{RedirectURIs: []string{"http://127.0.0.1/Cb"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://127.0.0.1/Cb"}}, url: "http://127.0.0.1:8080/Cb", expected: "http://127.0.0.1:8080/Cb", isError: false, }, { - client: &DefaultClient{RedirectURIs: []string{"http://127.0.0.1"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://127.0.0.1"}}, url: "http://foo.bar/bar", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"http://127.0.0.1"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://127.0.0.1"}}, url: ":/invalid.uri)bar", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"http://127.0.0.1:8080/cb"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://127.0.0.1:8080/cb"}}, url: "http://127.0.0.1:8080/Cb", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"http://127.0.0.1:8080/cb"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://127.0.0.1:8080/cb"}}, url: "http://127.0.0.1:8080/cb?foo=bar", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"http://127.0.0.1:8080/cb?foo=bar"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://127.0.0.1:8080/cb?foo=bar"}}, url: "http://127.0.0.1:8080/cb?foo=bar", expected: "http://127.0.0.1:8080/cb?foo=bar", isError: false, }, { - client: &DefaultClient{RedirectURIs: []string{"http://127.0.0.1:8080/cb?foo=bar"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://127.0.0.1:8080/cb?foo=bar"}}, url: "http://127.0.0.1:8080/cb?baz=bar&foo=bar", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"http://127.0.0.1:8080/cb?foo=bar&baz=bar"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://127.0.0.1:8080/cb?foo=bar&baz=bar"}}, url: "http://127.0.0.1:8080/cb?baz=bar&foo=bar", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"https://www.ory.sh/cb"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"https://www.ory.sh/cb"}}, url: "http://127.0.0.1:8080/cb", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"http://127.0.0.1:8080/cb"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"http://127.0.0.1:8080/cb"}}, url: "https://www.ory.sh/cb", isError: true, }, { - client: &DefaultClient{RedirectURIs: []string{"web+application://callback"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"web+application://callback"}}, url: "web+application://callback", isError: false, expected: "web+application://callback", }, { - client: &DefaultClient{RedirectURIs: []string{"https://google.com/?foo=bar%20foo+baz"}}, + client: &fosite.DefaultClient{RedirectURIs: []string{"https://google.com/?foo=bar%20foo+baz"}}, url: "https://google.com/?foo=bar%20foo+baz", isError: false, expected: "https://google.com/?foo=bar%20foo+baz", }, } { - redir, err := MatchRedirectURIWithClientRedirectURIs(c.url, c.client) + redir, err := fosite.MatchRedirectURIWithClientRedirectURIs(c.url, c.client) assert.Equal(t, c.isError, err != nil, "%d: %+v", k, c) if err == nil { require.NotNil(t, redir, "%d", k) @@ -255,7 +258,7 @@ func TestIsRedirectURISecure(t *testing.T) { } { uu, err := url.Parse(c.u) require.NoError(t, err) - assert.Equal(t, !c.err, IsRedirectURISecure(uu), "case %d", d) + assert.Equal(t, !c.err, fosite.IsRedirectURISecure(uu), "case %d", d) } } @@ -263,8 +266,8 @@ func TestWriteAuthorizeFormPostResponse(t *testing.T) { var responseBuffer bytes.Buffer redirectURL := "https://localhost:8080/cb" parameters := url.Values{"code": {"lshr755nsg39fgur"}, "state": {"924659540232"}} - WriteAuthorizeFormPostResponse(redirectURL, parameters, &responseBuffer) - code, state, _, _, _, err := ParseFormPostResponse(redirectURL, ioutil.NopCloser(bytes.NewReader(responseBuffer.Bytes()))) + fosite.WriteAuthorizeFormPostResponse(redirectURL, parameters, &responseBuffer) + code, state, _, _, _, err := internal.ParseFormPostResponse(redirectURL, ioutil.NopCloser(bytes.NewReader(responseBuffer.Bytes()))) assert.NoError(t, err) assert.Equal(t, parameters.Get("code"), code) assert.Equal(t, parameters.Get("state"), state) @@ -288,6 +291,6 @@ func TestIsRedirectURISecureStrict(t *testing.T) { } { uu, err := url.Parse(c.u) require.NoError(t, err) - assert.Equal(t, !c.err, IsRedirectURISecureStrict(uu), "case %d", d) + assert.Equal(t, !c.err, fosite.IsRedirectURISecureStrict(uu), "case %d", d) } } diff --git a/integration/authorize_form_post_test.go b/integration/authorize_form_post_test.go index 5efeec2a8..6ffc6c27c 100644 --- a/integration/authorize_form_post_test.go +++ b/integration/authorize_form_post_test.go @@ -69,7 +69,7 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { for k, c := range []struct { description string setup func() - check func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) + check func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) responseType string }{ { @@ -78,7 +78,7 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { setup: func() { state = "12345678901234567890" }, - check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) { + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.EqualValues(t, state, stateFromServer) assert.NotEmpty(t, token.TokenType) assert.NotEmpty(t, token.AccessToken) @@ -91,7 +91,7 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { setup: func() { state = "12345678901234567890" }, - check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) { + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.EqualValues(t, state, stateFromServer) assert.NotEmpty(t, code) }, @@ -103,7 +103,7 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { state = "12345678901234567890" oauthClient.Scopes = []string{"openid"} }, - check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) { + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.EqualValues(t, state, stateFromServer) assert.NotEmpty(t, code) assert.NotEmpty(t, token.TokenType) @@ -118,7 +118,7 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { state = "12345678901234567890" oauthClient.Scopes = []string{"openid"} }, - check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) { + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.EqualValues(t, state, stateFromServer) assert.NotEmpty(t, code) assert.NotEmpty(t, iDToken) @@ -134,7 +134,7 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { state = "12345678901234567890" oauthClient.Scopes = []string{"openid"} }, - check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) { + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.EqualValues(t, state, stateFromServer) assert.NotEmpty(t, code) assert.NotEmpty(t, iDToken) @@ -146,10 +146,10 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { setup: func() { state = "12345678901234567890" }, - check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err fosite.RFC6749Error) { + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.EqualValues(t, state, stateFromServer) - assert.NotEmpty(t, err.Name) - assert.NotEmpty(t, err.Description) + assert.NotEmpty(t, err["Name"]) + assert.NotEmpty(t, err["Description"]) }, }, } { @@ -164,7 +164,7 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { resp, err := client.Get(authURL) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) - code, state, token, iDToken, errResp, err := fosite.ParseFormPostResponse(fositeStore.Clients["my-client"].GetRedirectURIs()[0], resp.Body) + code, state, token, iDToken, errResp, err := internal.ParseFormPostResponse(fositeStore.Clients["my-client"].GetRedirectURIs()[0], resp.Body) require.NoError(t, err) c.check(t, state, code, iDToken, token, errResp) }) diff --git a/internal/test_helpers.go b/internal/test_helpers.go new file mode 100644 index 000000000..f1d1757b5 --- /dev/null +++ b/internal/test_helpers.go @@ -0,0 +1,126 @@ +/* + * Copyright © 2015-2020 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2020 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package internal + +import ( + "errors" + + "io" + "strconv" + "time" + + "golang.org/x/net/html" + goauth "golang.org/x/oauth2" +) + +func ParseFormPostResponse(redirectURL string, resp io.ReadCloser) (authorizationCode, stateFromServer, iDToken string, token goauth.Token, rFC6749Error map[string]string, err error) { + + token = goauth.Token{} + rFC6749Error = map[string]string{} + + doc, err := html.Parse(resp) + if err != nil { + return "", "", "", token, rFC6749Error, err + } + //doc>html>body + body := findBody(doc.FirstChild.FirstChild) + if body.Data != "body" { + return "", "", "", token, rFC6749Error, errors.New("Malformed html") + } + htmlEvent := body.Attr[0].Key + if htmlEvent != "onload" { + return "", "", "", token, rFC6749Error, errors.New("onload event is missing") + } + onLoadFunc := body.Attr[0].Val + if onLoadFunc != "javascript:document.forms[0].submit()" { + return "", "", "", token, rFC6749Error, errors.New("onload function is missing") + } + form := getNextNoneTextNode(body.FirstChild) + if form.Data != "form" { + return "", "", "", token, rFC6749Error, errors.New("html form is missing") + } + for _, attr := range form.Attr { + if attr.Key == "method" { + if attr.Val != "post" { + return "", "", "", token, rFC6749Error, errors.New("html form post method is missing") + } + } else { + if attr.Val != redirectURL { + return "", "", "", token, rFC6749Error, errors.New("html form post url is wrong") + } + } + } + + for node := getNextNoneTextNode(form.FirstChild); node != nil; node = getNextNoneTextNode(node.NextSibling) { + var k, v string + for _, attr := range node.Attr { + if attr.Key == "name" { + k = attr.Val + } else if attr.Key == "value" { + v = attr.Val + } + + } + switch k { + case "state": + stateFromServer = v + case "code": + authorizationCode = v + case "expires_in": + expires, err := strconv.Atoi(v) + if err != nil { + return "", "", "", token, rFC6749Error, err + } + token.Expiry = time.Now().UTC().Add(time.Duration(expires) * time.Second) + case "access_token": + token.AccessToken = v + case "token_type": + token.TokenType = v + case "refresh_token": + token.RefreshToken = v + case "error": + rFC6749Error["Name"] = v + case "error_description": + rFC6749Error["Description"] = v + case "id_token": + iDToken = v + } + } + return +} + +func getNextNoneTextNode(node *html.Node) *html.Node { + nextNode := node.NextSibling + if nextNode != nil && nextNode.Type == html.TextNode { + nextNode = getNextNoneTextNode(node.NextSibling) + } + return nextNode +} +func findBody(node *html.Node) *html.Node { + if node != nil { + if node.Data == "body" { + return node + } + return findBody(node.NextSibling) + } + return nil +} From ded62bed66e766c5980cb81c8c319df72c4d5518 Mon Sep 17 00:00:00 2001 From: Ajanthan Balachandran Date: Wed, 28 Oct 2020 17:17:31 -0700 Subject: [PATCH 06/22] refactor: Renaming ResponseModeNone --- authorize_error_test.go | 2 +- authorize_request.go | 4 ++-- authorize_request_handler.go | 4 ++-- authorize_write.go | 2 +- authorize_write_test.go | 2 +- handler/oauth2/flow_authorize_code_auth.go | 2 +- handler/oauth2/flow_authorize_implicit.go | 2 +- handler/openid/flow_hybrid.go | 2 +- handler/openid/flow_implicit.go | 2 +- 9 files changed, 11 insertions(+), 11 deletions(-) diff --git a/authorize_error_test.go b/authorize_error_test.go index b8446e826..ae16a1631 100644 --- a/authorize_error_test.go +++ b/authorize_error_test.go @@ -107,7 +107,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[0])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"code"})) - req.EXPECT().GetResponseMode().Return(ResponseModeNone).Times(2) + req.EXPECT().GetResponseMode().Return(ResponseModeDefault).Times(2) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().WriteHeader(http.StatusFound) }, diff --git a/authorize_request.go b/authorize_request.go index b41e2f5f6..7319cc7dd 100644 --- a/authorize_request.go +++ b/authorize_request.go @@ -28,7 +28,7 @@ import ( type ResponseModeType string const ( - ResponseModeNone = ResponseModeType("") + ResponseModeDefault = ResponseModeType("") ResponseModePost = ResponseModeType("form_post") ResponseModeQuery = ResponseModeType("query") ResponseModeFragment = ResponseModeType("fragment") @@ -51,7 +51,7 @@ func NewAuthorizeRequest() *AuthorizeRequest { RedirectURI: &url.URL{}, HandledResponseTypes: Arguments{}, Request: *NewRequest(), - ResponseMode: ResponseModeNone, + ResponseMode: ResponseModeDefault, } } diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 91430bc5b..b0e1f5a2a 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -214,8 +214,8 @@ func (f *Fosite) validateResponseMode(r *http.Request, request *AuthorizeRequest responseMode := r.Form.Get("response_mode") switch responseMode { - case string(ResponseModeNone): - request.ResponseMode = ResponseModeNone + case string(ResponseModeDefault): + request.ResponseMode = ResponseModeDefault case string(ResponseModeFragment): request.ResponseMode = ResponseModeFragment case string(ResponseModeQuery): diff --git a/authorize_write.go b/authorize_write.go index 54ff1fc2b..4e755ee3f 100644 --- a/authorize_write.go +++ b/authorize_write.go @@ -50,7 +50,7 @@ func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequ //form_post rw.Header().Add("Content-Type", "text/html;charset=UTF-8") WriteAuthorizeFormPostResponse(redir.String(), resp.GetParameters(), rw) - case ResponseModeQuery, ResponseModeNone: + case ResponseModeQuery, ResponseModeDefault: // Explicit grants q := redir.Query() rq := resp.GetParameters() diff --git a/authorize_write_test.go b/authorize_write_test.go index 9c5602ed7..2e90b3ac8 100644 --- a/authorize_write_test.go +++ b/authorize_write_test.go @@ -50,7 +50,7 @@ func TestWriteAuthorizeResponse(t *testing.T) { setup: func() { redir, _ := url.Parse("https://foobar.com/?foo=bar") ar.EXPECT().GetRedirectURI().Return(redir) - ar.EXPECT().GetResponseMode().Return(ResponseModeNone) + ar.EXPECT().GetResponseMode().Return(ResponseModeDefault) resp.EXPECT().GetParameters().Return(url.Values{}) resp.EXPECT().GetHeader().Return(http.Header{}) diff --git a/handler/oauth2/flow_authorize_code_auth.go b/handler/oauth2/flow_authorize_code_auth.go index 55f2f626c..76e5538db 100644 --- a/handler/oauth2/flow_authorize_code_auth.go +++ b/handler/oauth2/flow_authorize_code_auth.go @@ -77,7 +77,7 @@ func (c *AuthorizeExplicitGrantHandler) HandleAuthorizeEndpointRequest(ctx conte return nil } - if ar.GetResponseMode() == fosite.ResponseModeNone { + if ar.GetResponseMode() == fosite.ResponseModeDefault { ar.SetResponseMode(fosite.ResponseModeQuery) } // Disabled because this is already handled at the authorize_request_handler diff --git a/handler/oauth2/flow_authorize_implicit.go b/handler/oauth2/flow_authorize_implicit.go index cd32fe0cf..6f287079e 100644 --- a/handler/oauth2/flow_authorize_implicit.go +++ b/handler/oauth2/flow_authorize_implicit.go @@ -52,7 +52,7 @@ func (c *AuthorizeImplicitGrantTypeHandler) HandleAuthorizeEndpointRequest(ctx c if !ar.GetResponseTypes().ExactOne("token") { return nil } - if ar.GetResponseMode() == fosite.ResponseModeNone { + if ar.GetResponseMode() == fosite.ResponseModeDefault { ar.SetResponseMode(fosite.ResponseModeFragment) } // Disabled because this is already handled at the authorize_request_handler diff --git a/handler/openid/flow_hybrid.go b/handler/openid/flow_hybrid.go index 92c95aaeb..a6b40b24b 100644 --- a/handler/openid/flow_hybrid.go +++ b/handler/openid/flow_hybrid.go @@ -54,7 +54,7 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context. if !(ar.GetResponseTypes().Matches("token", "id_token", "code") || ar.GetResponseTypes().Matches("token", "code") || ar.GetResponseTypes().Matches("id_token", "code")) { return nil } - if ar.GetResponseMode() == fosite.ResponseModeNone { + if ar.GetResponseMode() == fosite.ResponseModeDefault { ar.SetResponseMode(fosite.ResponseModeFragment) } // Disabled because this is already handled at the authorize_request_handler diff --git a/handler/openid/flow_implicit.go b/handler/openid/flow_implicit.go index d59f72a66..79708aa9c 100644 --- a/handler/openid/flow_implicit.go +++ b/handler/openid/flow_implicit.go @@ -51,7 +51,7 @@ func (c *OpenIDConnectImplicitHandler) HandleAuthorizeEndpointRequest(ctx contex return nil } - if ar.GetResponseMode() == fosite.ResponseModeNone { + if ar.GetResponseMode() == fosite.ResponseModeDefault { ar.SetResponseMode(fosite.ResponseModeFragment) } From c3c5ea7ccfbd44bcdc339967ef0d6f705faea608 Mon Sep 17 00:00:00 2001 From: Ajanthan Balachandran Date: Wed, 28 Oct 2020 22:43:08 -0700 Subject: [PATCH 07/22] refactor: intoducing function to do the fragment encoding --- authorize_helper.go | 17 +++++++++++++++++ authorize_helper_test.go | 35 +++++++++++++++++++++++++++++++++++ authorize_write.go | 14 ++------------ authorize_write_test.go | 6 +++--- 4 files changed, 57 insertions(+), 15 deletions(-) diff --git a/authorize_helper.go b/authorize_helper.go index f65a00f31..8e6a45a87 100644 --- a/authorize_helper.go +++ b/authorize_helper.go @@ -22,6 +22,7 @@ package fosite import ( + "fmt" "html/template" "io" "net/url" @@ -207,3 +208,19 @@ func WriteAuthorizeFormPostResponse(redirectURL string, parameters url.Values, r Parameters: parameters, }) } +func URLSetFragment(source *url.URL, fragment url.Values) { + var f string + for k, v := range fragment { + for _, vv := range v { + if len(f) != 0 { + f += fmt.Sprintf("&%s=%s", k, vv) + } else { + f += fmt.Sprintf("%s=%s", k, vv) + } + } + } + //f=fragment.Encode() + //f=plusMatch.ReplaceAllString(f," ") + //f=encodedPlusMatch.ReplaceAllString(f,"+") + source.Fragment = f +} diff --git a/authorize_helper_test.go b/authorize_helper_test.go index 27078d1b6..c8a056a49 100644 --- a/authorize_helper_test.go +++ b/authorize_helper_test.go @@ -25,6 +25,7 @@ import ( "bytes" "io/ioutil" "net/url" + "strings" "testing" "github.com/ory/fosite" @@ -294,3 +295,37 @@ func TestIsRedirectURISecureStrict(t *testing.T) { assert.Equal(t, !c.err, fosite.IsRedirectURISecureStrict(uu), "case %d", d) } } + +func TestURLSetFragment(t *testing.T) { + for d, c := range []struct { + u string + a string + f url.Values + }{ + {u: "http://google.com", a: "http://google.com#code=567060896", f: url.Values{"code": []string{"567060896"}}}, + {u: "http://google.com", a: "http://google.com#code=567060896&scope=read", f: url.Values{"code": []string{"567060896"}, "scope": []string{"read"}}}, + {u: "http://google.com", a: "http://google.com#code=567060896&scope=read%20mail", f: url.Values{"code": []string{"567060896j"}, "scope": []string{"read mail"}}}, + {u: "http://google.com", a: "http://google.com#code=567060896&scope=read+write", f: url.Values{"code": []string{"567060896"}, "scope": []string{"read+write"}}}, + {u: "http://google.com", a: "http://google.com#code=567060896&scope=api:*", f: url.Values{"code": []string{"567060896"}, "scope": []string{"api:*"}}}, + {u: "https://google.com?foo=bar", a: "https://google.com?foo=bar#code=567060896", f: url.Values{"code": []string{"567060896"}}}, + {u: "http://localhost?foo=bar&baz=foo", a: "http://localhost?foo=bar&baz=foo#code=567060896", f: url.Values{"code": []string{"567060896"}}}, + } { + uu, err := url.Parse(c.u) + require.NoError(t, err) + fosite.URLSetFragment(uu, c.f) + tURL, err := url.Parse(uu.String()) + require.NoError(t, err) + r := ParseURLFragment(tURL.Fragment) + assert.Equal(t, c.f.Get("code"), r.Get("code"), "case %d", d) + assert.Equal(t, c.f.Get("scope"), r.Get("scope"), "case %d", d) + } +} +func ParseURLFragment(fragment string) url.Values { + r := url.Values{} + kvs := strings.Split(fragment, "&") + for _, kv := range kvs { + kva := strings.Split(kv, "=") + r.Add(kva[0], kva[1]) + } + return r +} diff --git a/authorize_write.go b/authorize_write.go index 4e755ee3f..208bc3cde 100644 --- a/authorize_write.go +++ b/authorize_write.go @@ -23,12 +23,6 @@ package fosite import ( "net/http" - "regexp" -) - -var ( - // scopeMatch = regexp.MustCompile("scope=[^\\&]+.*$") - plusMatch = regexp.MustCompile("\\+") ) func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequester, resp AuthorizeResponder) { @@ -63,12 +57,8 @@ func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequ // Implicit grants // The endpoint URI MUST NOT include a fragment component. redir.Fragment = "" - - u := redir.String() - fr := resp.GetParameters() - u = u + "#" + fr.Encode() - u = plusMatch.ReplaceAllString(u, "%20") - sendRedirect(u, rw) + URLSetFragment(redir, resp.GetParameters()) + sendRedirect(redir.String(), rw) } } diff --git a/authorize_write_test.go b/authorize_write_test.go index 2e90b3ac8..415f3b23a 100644 --- a/authorize_write_test.go +++ b/authorize_write_test.go @@ -109,7 +109,7 @@ func TestWriteAuthorizeResponse(t *testing.T) { redir, _ := url.Parse("https://foobar.com/?foo=bar") ar.EXPECT().GetRedirectURI().Return(redir) ar.EXPECT().GetResponseMode().Return(ResponseModeFragment) - resp.EXPECT().GetParameters().Return(url.Values{"bar": {"b+az"}, "scope": {"a b"}}) + resp.EXPECT().GetParameters().Return(url.Values{"bar": {"b+az ab"}}) resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) rw.EXPECT().Header().Return(header).Times(2) @@ -118,7 +118,7 @@ func TestWriteAuthorizeResponse(t *testing.T) { expect: func() { assert.Equal(t, http.Header{ "X-Bar": {"baz"}, - "Location": {"https://foobar.com/?foo=bar#bar=b%2Baz&scope=a%20b"}, + "Location": {"https://foobar.com/?foo=bar#bar=b+az%20ab"}, "Cache-Control": []string{"no-store"}, "Pragma": []string{"no-cache"}, }, header) @@ -160,7 +160,7 @@ func TestWriteAuthorizeResponse(t *testing.T) { expect: func() { assert.Equal(t, http.Header{ "X-Bar": {"baz"}, - "Location": {"https://foobar.com/?foo=bar#bar=baz&scope=api%3A%2A"}, + "Location": {"https://foobar.com/?foo=bar#bar=baz&scope=api:*"}, "Cache-Control": []string{"no-store"}, "Pragma": []string{"no-cache"}, }, header) From e92ec55d3f08dc45b46ca4ad3804c7827eaa41ae Mon Sep 17 00:00:00 2001 From: Ajanthan Balachandran Date: Thu, 29 Oct 2020 00:04:34 -0700 Subject: [PATCH 08/22] refactor: Making formPostHTLML Template configurable --- authorize_error.go | 2 +- authorize_helper.go | 21 ++++++++---- authorize_helper_test.go | 43 ++++++++++++++++++++----- authorize_write.go | 2 +- authorize_write_test.go | 4 +-- fosite.go | 4 +++ integration/authorize_form_post_test.go | 2 +- internal/test_helpers.go | 22 +++++++------ 8 files changed, 71 insertions(+), 29 deletions(-) diff --git a/authorize_error.go b/authorize_error.go index cc7a328fe..e4fc51fc5 100644 --- a/authorize_error.go +++ b/authorize_error.go @@ -66,7 +66,7 @@ func (f *Fosite) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequest var redirectURIString string if ar.GetResponseMode() == ResponseModePost { rw.Header().Add("Content-Type", "text/html;charset=UTF-8") - WriteAuthorizeFormPostResponse(redirectURI.String(), query, rw) + WriteAuthorizeFormPostResponse(redirectURI.String(), query, GetPostFormHTMLTemplate(*f), rw) return } else if ar.GetResponseMode() == ResponseModeFragment { redirectURIString = redirectURI.String() + "#" + query.Encode() diff --git a/authorize_helper.go b/authorize_helper.go index 8e6a45a87..7c0c143f2 100644 --- a/authorize_helper.go +++ b/authorize_helper.go @@ -33,14 +33,16 @@ import ( "github.com/pkg/errors" ) -var formPostTemplate = template.Must(template.New("form_post").Parse(` +var FormPostDefaultTemplate = template.Must(template.New("form_post").Parse(` Submit This Form
{{ range $key,$value := .Parameters }} - + {{ range $parameter:= $value}} + + {{end}} {{ end }}
@@ -199,8 +201,8 @@ func IsLocalhost(redirectURI *url.URL) bool { return strings.HasSuffix(hn, ".localhost") || hn == "127.0.0.1" || hn == "::1" || hn == "localhost" } -func WriteAuthorizeFormPostResponse(redirectURL string, parameters url.Values, rw io.Writer) { - _ = formPostTemplate.Execute(rw, struct { +func WriteAuthorizeFormPostResponse(redirectURL string, parameters url.Values, template *template.Template, rw io.Writer) { + _ = template.Execute(rw, struct { RedirURL string Parameters url.Values }{ @@ -219,8 +221,13 @@ func URLSetFragment(source *url.URL, fragment url.Values) { } } } - //f=fragment.Encode() - //f=plusMatch.ReplaceAllString(f," ") - //f=encodedPlusMatch.ReplaceAllString(f,"+") source.Fragment = f } + +func GetPostFormHTMLTemplate(f Fosite) *template.Template { + formPostHTMLTemplate := f.FormPostHTMLTemplate + if formPostHTMLTemplate == nil { + formPostHTMLTemplate = FormPostDefaultTemplate + } + return formPostHTMLTemplate +} diff --git a/authorize_helper_test.go b/authorize_helper_test.go index c8a056a49..480b9ca7e 100644 --- a/authorize_helper_test.go +++ b/authorize_helper_test.go @@ -264,14 +264,41 @@ func TestIsRedirectURISecure(t *testing.T) { } func TestWriteAuthorizeFormPostResponse(t *testing.T) { - var responseBuffer bytes.Buffer - redirectURL := "https://localhost:8080/cb" - parameters := url.Values{"code": {"lshr755nsg39fgur"}, "state": {"924659540232"}} - fosite.WriteAuthorizeFormPostResponse(redirectURL, parameters, &responseBuffer) - code, state, _, _, _, err := internal.ParseFormPostResponse(redirectURL, ioutil.NopCloser(bytes.NewReader(responseBuffer.Bytes()))) - assert.NoError(t, err) - assert.Equal(t, parameters.Get("code"), code) - assert.Equal(t, parameters.Get("state"), state) + for d, c := range []struct { + parameters url.Values + check func(code string, state string, customParams url.Values, d int) + }{ + { + parameters: url.Values{"code": {"lshr755nsg39fgur"}, "state": {"924659540232"}}, + check: func(code string, state string, customParams url.Values, d int) { + assert.Equal(t, "lshr755nsg39fgur", code, "case %d", d) + assert.Equal(t, "924659540232", state, "case %d", d) + }, + }, + { + parameters: url.Values{"code": {"1234"}, "custom": {"test2", "test3"}}, + check: func(code string, state string, customParams url.Values, d int) { + assert.Equal(t, "1234", code, "case %d", d) + assert.Equal(t, []string{"test2", "test3"}, customParams["custom"], "case %d", d) + }, + }, + { + parameters: url.Values{"code": {"1234"}, "custom": {"Bold"}}, + check: func(code string, state string, customParams url.Values, d int) { + assert.Equal(t, "1234", code, "case %d", d) + assert.Equal(t, "Bold", customParams.Get("custom"), "case %d", d) + }, + }, + } { + var responseBuffer bytes.Buffer + redirectURL := "https://localhost:8080/cb" + //parameters := + fosite.WriteAuthorizeFormPostResponse(redirectURL, c.parameters, fosite.FormPostDefaultTemplate, &responseBuffer) + code, state, _, _, customParams, _, err := internal.ParseFormPostResponse(redirectURL, ioutil.NopCloser(bytes.NewReader(responseBuffer.Bytes()))) + assert.NoError(t, err, "case %d", d) + c.check(code, state, customParams, d) + + } } func TestIsRedirectURISecureStrict(t *testing.T) { diff --git a/authorize_write.go b/authorize_write.go index 208bc3cde..9d1a0f281 100644 --- a/authorize_write.go +++ b/authorize_write.go @@ -43,7 +43,7 @@ func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequ case ResponseModePost: //form_post rw.Header().Add("Content-Type", "text/html;charset=UTF-8") - WriteAuthorizeFormPostResponse(redir.String(), resp.GetParameters(), rw) + WriteAuthorizeFormPostResponse(redir.String(), resp.GetParameters(), GetPostFormHTMLTemplate(*f), rw) case ResponseModeQuery, ResponseModeDefault: // Explicit grants q := redir.Query() diff --git a/authorize_write_test.go b/authorize_write_test.go index 415f3b23a..de37b83f4 100644 --- a/authorize_write_test.go +++ b/authorize_write_test.go @@ -151,7 +151,7 @@ func TestWriteAuthorizeResponse(t *testing.T) { redir, _ := url.Parse("https://foobar.com/?foo=bar") ar.EXPECT().GetRedirectURI().Return(redir) ar.EXPECT().GetResponseMode().Return(ResponseModeFragment) - resp.EXPECT().GetParameters().Return(url.Values{"bar": {"baz"}, "scope": {"api:*"}}) + resp.EXPECT().GetParameters().Return(url.Values{"scope": {"api:*"}}) resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) rw.EXPECT().Header().Return(header).Times(2) @@ -160,7 +160,7 @@ func TestWriteAuthorizeResponse(t *testing.T) { expect: func() { assert.Equal(t, http.Header{ "X-Bar": {"baz"}, - "Location": {"https://foobar.com/?foo=bar#bar=baz&scope=api:*"}, + "Location": {"https://foobar.com/?foo=bar#scope=api:*"}, "Cache-Control": []string{"no-store"}, "Pragma": []string{"no-cache"}, }, header) diff --git a/fosite.go b/fosite.go index e3120a18c..0324cbe27 100644 --- a/fosite.go +++ b/fosite.go @@ -22,6 +22,7 @@ package fosite import ( + "html/template" "net/http" "reflect" ) @@ -105,6 +106,9 @@ type Fosite struct { // MinParameterEntropy controls the minimum size of state and nonce parameters. Defaults to fosite.MinParameterEntropy. MinParameterEntropy int + + //FormPostHTMLTemplate sets html template for rendering the authorization response when the request has response_mode=form_post. Defaults to fosite.FormPostDefaultTemplate + FormPostHTMLTemplate *template.Template } const MinParameterEntropy = 8 diff --git a/integration/authorize_form_post_test.go b/integration/authorize_form_post_test.go index 6ffc6c27c..cae789f5b 100644 --- a/integration/authorize_form_post_test.go +++ b/integration/authorize_form_post_test.go @@ -164,7 +164,7 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { resp, err := client.Get(authURL) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) - code, state, token, iDToken, errResp, err := internal.ParseFormPostResponse(fositeStore.Clients["my-client"].GetRedirectURIs()[0], resp.Body) + code, state, token, iDToken, _, errResp, err := internal.ParseFormPostResponse(fositeStore.Clients["my-client"].GetRedirectURIs()[0], resp.Body) require.NoError(t, err) c.check(t, state, code, iDToken, token, errResp) }) diff --git a/internal/test_helpers.go b/internal/test_helpers.go index f1d1757b5..42f10c20f 100644 --- a/internal/test_helpers.go +++ b/internal/test_helpers.go @@ -23,6 +23,7 @@ package internal import ( "errors" + "net/url" "io" "strconv" @@ -32,40 +33,41 @@ import ( goauth "golang.org/x/oauth2" ) -func ParseFormPostResponse(redirectURL string, resp io.ReadCloser) (authorizationCode, stateFromServer, iDToken string, token goauth.Token, rFC6749Error map[string]string, err error) { +func ParseFormPostResponse(redirectURL string, resp io.ReadCloser) (authorizationCode, stateFromServer, iDToken string, token goauth.Token, customParameters url.Values, rFC6749Error map[string]string, err error) { token = goauth.Token{} rFC6749Error = map[string]string{} + customParameters = url.Values{} doc, err := html.Parse(resp) if err != nil { - return "", "", "", token, rFC6749Error, err + return "", "", "", token, customParameters, rFC6749Error, err } //doc>html>body body := findBody(doc.FirstChild.FirstChild) if body.Data != "body" { - return "", "", "", token, rFC6749Error, errors.New("Malformed html") + return "", "", "", token, customParameters, rFC6749Error, errors.New("Malformed html") } htmlEvent := body.Attr[0].Key if htmlEvent != "onload" { - return "", "", "", token, rFC6749Error, errors.New("onload event is missing") + return "", "", "", token, customParameters, rFC6749Error, errors.New("onload event is missing") } onLoadFunc := body.Attr[0].Val if onLoadFunc != "javascript:document.forms[0].submit()" { - return "", "", "", token, rFC6749Error, errors.New("onload function is missing") + return "", "", "", token, customParameters, rFC6749Error, errors.New("onload function is missing") } form := getNextNoneTextNode(body.FirstChild) if form.Data != "form" { - return "", "", "", token, rFC6749Error, errors.New("html form is missing") + return "", "", "", token, customParameters, rFC6749Error, errors.New("html form is missing") } for _, attr := range form.Attr { if attr.Key == "method" { if attr.Val != "post" { - return "", "", "", token, rFC6749Error, errors.New("html form post method is missing") + return "", "", "", token, customParameters, rFC6749Error, errors.New("html form post method is missing") } } else { if attr.Val != redirectURL { - return "", "", "", token, rFC6749Error, errors.New("html form post url is wrong") + return "", "", "", token, customParameters, rFC6749Error, errors.New("html form post url is wrong") } } } @@ -88,7 +90,7 @@ func ParseFormPostResponse(redirectURL string, resp io.ReadCloser) (authorizatio case "expires_in": expires, err := strconv.Atoi(v) if err != nil { - return "", "", "", token, rFC6749Error, err + return "", "", "", token, customParameters, rFC6749Error, err } token.Expiry = time.Now().UTC().Add(time.Duration(expires) * time.Second) case "access_token": @@ -103,6 +105,8 @@ func ParseFormPostResponse(redirectURL string, resp io.ReadCloser) (authorizatio rFC6749Error["Description"] = v case "id_token": iDToken = v + default: + customParameters.Add(k, v) } } return From 56234b4e70d43262c41ec9093e38ed3c120f816f Mon Sep 17 00:00:00 2001 From: hackerman <3372410+aeneasr@users.noreply.github.com> Date: Thu, 29 Oct 2020 17:02:16 +0100 Subject: [PATCH 09/22] Apply suggestions from code review --- authorize_write.go | 2 ++ fosite.go | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/authorize_write.go b/authorize_write.go index 9d1a0f281..97fbab50c 100644 --- a/authorize_write.go +++ b/authorize_write.go @@ -53,12 +53,14 @@ func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequ } redir.RawQuery = q.Encode() sendRedirect(redir.String(), rw) + return case ResponseModeFragment: // Implicit grants // The endpoint URI MUST NOT include a fragment component. redir.Fragment = "" URLSetFragment(redir, resp.GetParameters()) sendRedirect(redir.String(), rw) + return } } diff --git a/fosite.go b/fosite.go index 0324cbe27..db37577ee 100644 --- a/fosite.go +++ b/fosite.go @@ -107,7 +107,7 @@ type Fosite struct { // MinParameterEntropy controls the minimum size of state and nonce parameters. Defaults to fosite.MinParameterEntropy. MinParameterEntropy int - //FormPostHTMLTemplate sets html template for rendering the authorization response when the request has response_mode=form_post. Defaults to fosite.FormPostDefaultTemplate + // FormPostHTMLTemplate sets html template for rendering the authorization response when the request has response_mode=form_post. Defaults to fosite.FormPostDefaultTemplate FormPostHTMLTemplate *template.Template } From d7c0743d03f99eac1a4b8dc1dbfd139bd992d7b3 Mon Sep 17 00:00:00 2001 From: Ajanthan Balachandran Date: Fri, 30 Oct 2020 08:30:46 -0700 Subject: [PATCH 10/22] Adding special characters test case --- authorize_helper_test.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/authorize_helper_test.go b/authorize_helper_test.go index 480b9ca7e..4612e0d68 100644 --- a/authorize_helper_test.go +++ b/authorize_helper_test.go @@ -275,6 +275,13 @@ func TestWriteAuthorizeFormPostResponse(t *testing.T) { assert.Equal(t, "924659540232", state, "case %d", d) }, }, + { + parameters: url.Values{"code": {"lshr75*ns-39f+ur"}, "state": {"9a:* <&)"}}, + check: func(code string, state string, customParams url.Values, d int) { + assert.Equal(t, "lshr75*ns-39f+ur", code, "case %d", d) + assert.Equal(t, "9a:* <&)", state, "case %d", d) + }, + }, { parameters: url.Values{"code": {"1234"}, "custom": {"test2", "test3"}}, check: func(code string, state string, customParams url.Values, d int) { From b8db770846f60fd3d37429e26e1645d5bb22a1b7 Mon Sep 17 00:00:00 2001 From: Ajanthan Balachandran Date: Fri, 30 Oct 2020 16:24:39 -0700 Subject: [PATCH 11/22] Adding ability to client to specify which response modes it allows --- authorize_request_handler.go | 29 ++++++- authorize_request_handler_test.go | 101 ++++++++++++++++++++++++ client.go | 16 ++++ client_test.go | 5 ++ errors.go | 6 ++ integration/authorize_form_post_test.go | 9 ++- 6 files changed, 162 insertions(+), 4 deletions(-) diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 19f21a50c..e94b2346e 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -210,7 +210,7 @@ func (f *Fosite) validateResponseTypes(r *http.Request, request *AuthorizeReques request.ResponseTypes = responseTypes return nil } -func (f *Fosite) validateResponseMode(r *http.Request, request *AuthorizeRequest) error { +func (f *Fosite) ParseResponseMode(r *http.Request, request *AuthorizeRequest) error { responseMode := r.Form.Get("response_mode") switch responseMode { @@ -223,7 +223,26 @@ func (f *Fosite) validateResponseMode(r *http.Request, request *AuthorizeRequest case string(ResponseModePost): request.ResponseMode = ResponseModePost default: - return errors.WithStack(ErrUnsupportedResponseType.WithHintf("Request with unsupported response_mode \"%s\".", responseMode)) + return errors.WithStack(ErrUnsupportedResponseMode.WithHintf("Request with unsupported response_mode \"%s\".", responseMode)) + } + return nil +} +func (f *Fosite) validateResponseMode(r *http.Request, request *AuthorizeRequest) error { + if request.ResponseMode != ResponseModeDefault { + var found bool + responseModeClient, ok := request.GetClient().(ResponseModeClient) + if !ok { + return errors.WithStack(ErrUnsupportedResponseMode.WithHintf("The request has response_mode \"%s\". set but registered OAuth 2.0 client doesn't support response_mode", r.Form.Get("response_mode"))) + } + for _, t := range responseModeClient.GetResponseMode() { + if request.ResponseMode == t { + found = true + break + } + } + if !found { + return errors.WithStack(ErrUnsupportedResponseMode.WithHintf("The client is not allowed to request response_mode \"%s\".", r.Form.Get("response_mode"))) + } } return nil } @@ -245,7 +264,7 @@ func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth state := request.Form.Get("state") request.State = state - if err := f.validateResponseMode(r, request); err != nil { + if err := f.ParseResponseMode(r, request); err != nil { return request, err } @@ -255,6 +274,10 @@ func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth } request.Client = client + if err := f.validateResponseMode(r, request); err != nil { + return request, err + } + if err := f.authorizeRequestParametersFromOpenIDConnectRequest(request); err != nil { return request, err } diff --git a/authorize_request_handler_test.go b/authorize_request_handler_test.go index 12e0abe4c..da94f1b75 100644 --- a/authorize_request_handler_test.go +++ b/authorize_request_handler_test.go @@ -367,6 +367,107 @@ func TestNewAuthorizeRequest(t *testing.T) { }, }, }, + /* fails because unknown response_mode*/ + { + desc: "should fail because unknown response_mode", + conf: &Fosite{Store: store, ScopeStrategy: ExactScopeStrategy, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "response_mode": {"unknown"}, + }, + mock: func() { + //store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"},ResponseTypes: []string{"code token"}}, nil) + }, + expectedError: ErrUnsupportedResponseMode, + }, + /* fails because response_mode is requested but the OAuth 2.0 client doesn't support response mode */ + { + desc: "should fail because response_mode is requested but the OAuth 2.0 client doesn't support response mode", + conf: &Fosite{Store: store, ScopeStrategy: ExactScopeStrategy, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "response_mode": {"form_post"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}}, nil) + }, + expectedError: ErrUnsupportedResponseMode, + }, + /* fails because requested response mode is not allowed */ + { + desc: "should fail because requested response mode is not allowed", + conf: &Fosite{Store: store, ScopeStrategy: ExactScopeStrategy, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "response_mode": {"form_post"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ + DefaultClient: &DefaultClient{ + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + ResponseTypes: []string{"code token"}, + }, + ResponseMode: []ResponseModeType{ResponseModeQuery}, + }, nil) + }, + expectedError: ErrUnsupportedResponseMode, + }, + /* success with response mode */ + { + desc: "success with response mode", + conf: &Fosite{Store: store, ScopeStrategy: ExactScopeStrategy, AudienceMatchingStrategy: DefaultAudienceMatchingStrategy}, + query: url.Values{ + "redirect_uri": {"https://foo.bar/cb"}, + "client_id": {"1234"}, + "response_type": {"code token"}, + "state": {"strong-state"}, + "scope": {"foo bar"}, + "response_mode": {"form_post"}, + "audience": {"https://cloud.ory.sh/api https://www.ory.sh/api"}, + }, + mock: func() { + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultResponseModeClient{ + DefaultClient: &DefaultClient{ + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + ResponseTypes: []string{"code token"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + }, + ResponseMode: []ResponseModeType{ResponseModePost}, + }, nil) + }, + expect: &AuthorizeRequest{ + RedirectURI: redir, + ResponseTypes: []string{"code", "token"}, + State: "strong-state", + Request: Request{ + Client: &DefaultResponseModeClient{ + DefaultClient: &DefaultClient{ + RedirectURIs: []string{"https://foo.bar/cb"}, + Scopes: []string{"foo", "bar"}, + ResponseTypes: []string{"code token"}, + Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + }, + ResponseMode: []ResponseModeType{ResponseModePost}, + }, + RequestedScope: []string{"foo", "bar"}, + RequestedAudience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, + }, + }, + }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { c.mock() diff --git a/client.go b/client.go index 9a4ee32e1..9f3f18c8d 100644 --- a/client.go +++ b/client.go @@ -80,6 +80,13 @@ type OpenIDConnectClient interface { GetTokenEndpointAuthSigningAlgorithm() string } +// ResponseModeClient represents a client capable of handling response_mode +type ResponseModeClient interface { + + // GetResponseMode returns the response modes that client is allowed to send + GetResponseMode() []ResponseModeType +} + // DefaultClient is a simple default implementation of the Client interface. type DefaultClient struct { ID string `json:"id"` @@ -102,6 +109,11 @@ type DefaultOpenIDConnectClient struct { TokenEndpointAuthSigningAlgorithm string `json:"token_endpoint_auth_signing_alg"` } +type DefaultResponseModeClient struct { + *DefaultClient + ResponseMode []ResponseModeType `json:"response_mode"` +} + func (c *DefaultClient) GetID() string { return c.ID } @@ -177,3 +189,7 @@ func (c *DefaultOpenIDConnectClient) GetTokenEndpointAuthMethod() string { func (c *DefaultOpenIDConnectClient) GetRequestURIs() []string { return c.RequestURIs } + +func (c *DefaultResponseModeClient) GetResponseMode() []ResponseModeType { + return c.ResponseMode +} diff --git a/client_test.go b/client_test.go index e88d70dfa..3c00cd4e0 100644 --- a/client_test.go +++ b/client_test.go @@ -49,3 +49,8 @@ func TestDefaultClient(t *testing.T) { assert.Equal(t, "code", sc.GetResponseTypes()[0]) assert.Equal(t, "authorization_code", sc.GetGrantTypes()[0]) } + +func TestDefaultResponseModeClient_GetResponseMode(t *testing.T) { + rc := &DefaultResponseModeClient{ResponseMode: []ResponseModeType{ResponseModeFragment}} + assert.Equal(t, []ResponseModeType{ResponseModeFragment}, rc.GetResponseMode()) +} diff --git a/errors.go b/errors.go index 96fbf2826..f28bffc7e 100644 --- a/errors.go +++ b/errors.go @@ -71,6 +71,11 @@ var ( Description: "The authorization server does not support obtaining a token using this method", Code: http.StatusBadRequest, } + ErrUnsupportedResponseMode = &RFC6749Error{ + Name: errUnsupportedResponseModeName, + Description: "The authorization server does not support obtaining response using this response mode", + Code: http.StatusBadRequest, + } ErrInvalidScope = &RFC6749Error{ Name: errInvalidScopeName, Description: "The requested scope is invalid, unknown, or malformed", @@ -222,6 +227,7 @@ const ( errUnauthorizedClientName = "unauthorized_client" errAccessDeniedName = "access_denied" errUnsupportedResponseTypeName = "unsupported_response_type" + errUnsupportedResponseModeName = "unsupported_response_mode" errInvalidScopeName = "invalid_scope" errServerErrorName = "server_error" errTemporarilyUnavailableName = "temporarily_unavailable" diff --git a/integration/authorize_form_post_test.go b/integration/authorize_form_post_test.go index cae789f5b..229088db4 100644 --- a/integration/authorize_form_post_test.go +++ b/integration/authorize_form_post_test.go @@ -63,7 +63,14 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { defer ts.Close() oauthClient := newOAuth2Client(ts) - fositeStore.Clients["my-client"].(*fosite.DefaultClient).RedirectURIs[0] = ts.URL + "/callback" + defaultClient := fositeStore.Clients["my-client"].(*fosite.DefaultClient) + defaultClient.RedirectURIs[0] = ts.URL + "/callback" + responseModeClient := &fosite.DefaultResponseModeClient{ + DefaultClient: defaultClient, + ResponseMode: []fosite.ResponseModeType{fosite.ResponseModePost}, + } + fositeStore.Clients["response-mode-client"] = responseModeClient + oauthClient.ClientID = "response-mode-client" var state string for k, c := range []struct { From 9ee157e55dee48de1a0885c7b72ea2f4244f26a9 Mon Sep 17 00:00:00 2001 From: Ajanthan Balachandran Date: Sun, 1 Nov 2020 10:22:22 -0800 Subject: [PATCH 12/22] Validating response_mode for insecure mode --- authorize_request.go | 12 ++++++-- authorize_response_writer.go | 4 +++ authorize_response_writer_test.go | 17 +++++++++++ handler/oauth2/flow_authorize_code_auth.go | 5 ++-- .../oauth2/flow_authorize_code_auth_test.go | 28 ------------------- handler/oauth2/flow_authorize_implicit.go | 6 ++-- handler/openid/flow_hybrid.go | 6 ++-- handler/openid/flow_implicit.go | 4 +-- integration/authorize_form_post_test.go | 26 +++++++++++++---- internal/authorize_request.go | 26 +++++++++++++++++ oauth2.go | 9 ++++-- 11 files changed, 92 insertions(+), 51 deletions(-) diff --git a/authorize_request.go b/authorize_request.go index 7319cc7dd..91a62a657 100644 --- a/authorize_request.go +++ b/authorize_request.go @@ -41,6 +41,7 @@ type AuthorizeRequest struct { State string `json:"state" gorethink:"state"` HandledResponseTypes Arguments `json:"handledResponseTypes" gorethink:"handledResponseTypes"` ResponseMode ResponseModeType `json:"ResponseMode" gorethink:"ResponseMode"` + DefaultResponseMode ResponseModeType `json:"DefaultResponseMode" gorethink:"DefaultResponseMode"` Request } @@ -102,6 +103,13 @@ func (d *AuthorizeRequest) GetResponseMode() ResponseModeType { return d.ResponseMode } -func (d *AuthorizeRequest) SetResponseMode(responseMode ResponseModeType) { - d.ResponseMode = responseMode +func (d *AuthorizeRequest) SetDefaultResponseMode(defaultResponseMode ResponseModeType) { + if d.ResponseMode == ResponseModeDefault { + d.ResponseMode = defaultResponseMode + } + d.DefaultResponseMode = defaultResponseMode +} + +func (d *AuthorizeRequest) GetDefaultResponseMode() ResponseModeType { + return d.DefaultResponseMode } diff --git a/authorize_response_writer.go b/authorize_response_writer.go index 5d9163cd9..a34334085 100644 --- a/authorize_response_writer.go +++ b/authorize_response_writer.go @@ -46,5 +46,9 @@ func (f *Fosite) NewAuthorizeResponse(ctx context.Context, ar AuthorizeRequester return nil, errors.WithStack(ErrUnsupportedResponseType) } + if ar.GetDefaultResponseMode() == ResponseModeFragment && ar.GetResponseMode() == ResponseModeQuery { + return nil, ErrUnsupportedResponseMode.WithHintf("Insecure response_mode \"%s\" for the response_type \"%s\".", ar.GetResponseMode(), ar.GetResponseTypes()) + } + return resp, nil } diff --git a/authorize_response_writer_test.go b/authorize_response_writer_test.go index 12961c47a..d9a550fa6 100644 --- a/authorize_response_writer_test.go +++ b/authorize_response_writer_test.go @@ -64,6 +64,8 @@ func TestNewAuthorizeResponse(t *testing.T) { mock: func() { handlers[0].EXPECT().HandleAuthorizeEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) ar.EXPECT().DidHandleAllResponseTypes().Return(true) + ar.EXPECT().GetDefaultResponseMode().Return(ResponseModeFragment) + ar.EXPECT().GetResponseMode().Return(ResponseModeDefault) }, isErr: false, }, @@ -73,6 +75,8 @@ func TestNewAuthorizeResponse(t *testing.T) { handlers[0].EXPECT().HandleAuthorizeEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) handlers[0].EXPECT().HandleAuthorizeEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) ar.EXPECT().DidHandleAllResponseTypes().Return(true) + ar.EXPECT().GetDefaultResponseMode().Return(ResponseModeFragment) + ar.EXPECT().GetResponseMode().Return(ResponseModeDefault) }, isErr: false, }, @@ -85,6 +89,19 @@ func TestNewAuthorizeResponse(t *testing.T) { isErr: true, expectErr: fooErr, }, + { + mock: func() { + oauth2 = duo + handlers[0].EXPECT().HandleAuthorizeEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + handlers[0].EXPECT().HandleAuthorizeEndpointRequest(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil) + ar.EXPECT().DidHandleAllResponseTypes().Return(true) + ar.EXPECT().GetDefaultResponseMode().Return(ResponseModeFragment) + ar.EXPECT().GetResponseMode().Return(ResponseModeQuery).Times(2) + ar.EXPECT().GetResponseTypes().Return([]string{"token", "code"}) + }, + isErr: true, + expectErr: ErrUnsupportedResponseMode.WithHintf("Insecure response_mode \"%s\" for the response_type \"%s\".", ResponseModeQuery, []string{"token", "code"}), + }, } { c.mock() responder, err := oauth2.NewAuthorizeResponse(ctx, ar, new(DefaultSession)) diff --git a/handler/oauth2/flow_authorize_code_auth.go b/handler/oauth2/flow_authorize_code_auth.go index 76e5538db..b3976068d 100644 --- a/handler/oauth2/flow_authorize_code_auth.go +++ b/handler/oauth2/flow_authorize_code_auth.go @@ -77,9 +77,8 @@ func (c *AuthorizeExplicitGrantHandler) HandleAuthorizeEndpointRequest(ctx conte return nil } - if ar.GetResponseMode() == fosite.ResponseModeDefault { - ar.SetResponseMode(fosite.ResponseModeQuery) - } + ar.SetDefaultResponseMode(fosite.ResponseModeQuery) + // Disabled because this is already handled at the authorize_request_handler // if !ar.GetClient().GetResponseTypes().Has("code") { // return errors.WithStack(fosite.ErrInvalidGrant) diff --git a/handler/oauth2/flow_authorize_code_auth_test.go b/handler/oauth2/flow_authorize_code_auth_test.go index 63d9916ee..5c9c20620 100644 --- a/handler/oauth2/flow_authorize_code_auth_test.go +++ b/handler/oauth2/flow_authorize_code_auth_test.go @@ -125,34 +125,6 @@ func TestAuthorizeCode_HandleAuthorizeEndpointRequest(t *testing.T) { code := aresp.GetParameters().Get("code") assert.NotEmpty(t, code) - assert.Equal(t, strings.Join(areq.GrantedScope, " "), aresp.GetParameters().Get("scope")) - assert.Equal(t, areq.State, aresp.GetParameters().Get("state")) - }, - }, - { - areq: &fosite.AuthorizeRequest{ - ResponseTypes: fosite.Arguments{"code"}, - Request: fosite.Request{ - Client: &fosite.DefaultClient{ - ResponseTypes: fosite.Arguments{"code"}, - RedirectURIs: []string{"https://asdf.de/cb"}, - Audience: []string{"https://www.ory.sh/api"}, - }, - RequestedAudience: []string{"https://www.ory.sh/api"}, - GrantedScope: fosite.Arguments{"a", "b"}, - Session: &fosite.DefaultSession{ - ExpiresAt: map[fosite.TokenType]time.Time{fosite.AccessToken: time.Now().UTC().Add(time.Hour)}, - }, - RequestedAt: time.Now().UTC(), - }, - State: "superstate", - RedirectURI: parseUrl("https://asdf.de/cb"), - }, - description: "Default responseMode check", - expect: func(t *testing.T, areq *fosite.AuthorizeRequest, aresp *fosite.AuthorizeResponse) { - code := aresp.GetParameters().Get("code") - assert.NotEmpty(t, code) - assert.Equal(t, strings.Join(areq.GrantedScope, " "), aresp.GetParameters().Get("scope")) assert.Equal(t, areq.State, aresp.GetParameters().Get("state")) assert.Equal(t, fosite.ResponseModeQuery, areq.GetResponseMode()) diff --git a/handler/oauth2/flow_authorize_implicit.go b/handler/oauth2/flow_authorize_implicit.go index 6f287079e..7a9313047 100644 --- a/handler/oauth2/flow_authorize_implicit.go +++ b/handler/oauth2/flow_authorize_implicit.go @@ -52,9 +52,9 @@ func (c *AuthorizeImplicitGrantTypeHandler) HandleAuthorizeEndpointRequest(ctx c if !ar.GetResponseTypes().ExactOne("token") { return nil } - if ar.GetResponseMode() == fosite.ResponseModeDefault { - ar.SetResponseMode(fosite.ResponseModeFragment) - } + + ar.SetDefaultResponseMode(fosite.ResponseModeFragment) + // Disabled because this is already handled at the authorize_request_handler // if !ar.GetClient().GetResponseTypes().Has("token") { // return errors.WithStack(fosite.ErrInvalidGrant.WithDebug("The client is not allowed to use response type token")) diff --git a/handler/openid/flow_hybrid.go b/handler/openid/flow_hybrid.go index a6b40b24b..4a08bff41 100644 --- a/handler/openid/flow_hybrid.go +++ b/handler/openid/flow_hybrid.go @@ -54,9 +54,9 @@ func (c *OpenIDConnectHybridHandler) HandleAuthorizeEndpointRequest(ctx context. if !(ar.GetResponseTypes().Matches("token", "id_token", "code") || ar.GetResponseTypes().Matches("token", "code") || ar.GetResponseTypes().Matches("id_token", "code")) { return nil } - if ar.GetResponseMode() == fosite.ResponseModeDefault { - ar.SetResponseMode(fosite.ResponseModeFragment) - } + + ar.SetDefaultResponseMode(fosite.ResponseModeFragment) + // Disabled because this is already handled at the authorize_request_handler //if ar.GetResponseTypes().Matches("token") && !ar.GetClient().GetResponseTypes().Has("token") { // return errors.WithStack(fosite.ErrInvalidGrant.WithDebug("The client is not allowed to use the token response type")) diff --git a/handler/openid/flow_implicit.go b/handler/openid/flow_implicit.go index 79708aa9c..2c4c36d6f 100644 --- a/handler/openid/flow_implicit.go +++ b/handler/openid/flow_implicit.go @@ -51,9 +51,7 @@ func (c *OpenIDConnectImplicitHandler) HandleAuthorizeEndpointRequest(ctx contex return nil } - if ar.GetResponseMode() == fosite.ResponseModeDefault { - ar.SetResponseMode(fosite.ResponseModeFragment) - } + ar.SetDefaultResponseMode(fosite.ResponseModeFragment) if !ar.GetClient().GetGrantTypes().Has("implicit") { return errors.WithStack(fosite.ErrInvalidGrant.WithHint("The OAuth 2.0 Client is not allowed to use the authorization grant \"implicit\".")) diff --git a/integration/authorize_form_post_test.go b/integration/authorize_form_post_test.go index 229088db4..6ee74542e 100644 --- a/integration/authorize_form_post_test.go +++ b/integration/authorize_form_post_test.go @@ -80,20 +80,34 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { responseType string }{ { - description: "implicit grant test with form_post", - responseType: "token", + description: "implicit grant #1 test with form_post", + responseType: "id_token%20token", setup: func() { state = "12345678901234567890" + oauthClient.Scopes = []string{"openid"} }, check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.EqualValues(t, state, stateFromServer) assert.NotEmpty(t, token.TokenType) assert.NotEmpty(t, token.AccessToken) assert.NotEmpty(t, token.Expiry) + assert.NotEmpty(t, iDToken) + }, + }, + { + description: "implicit grant #2 test with form_post", + responseType: "id_token", + setup: func() { + state = "12345678901234567890" + oauthClient.Scopes = []string{"openid"} + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { + assert.EqualValues(t, state, stateFromServer) + assert.NotEmpty(t, iDToken) }, }, { - description: "explicit grant test with form_post", + description: "Authorization code grant test with form_post", responseType: "code", setup: func() { state = "12345678901234567890" @@ -104,7 +118,7 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { }, }, { - description: "oidc grant test with form_post", + description: "Hybrid #1 grant test with form_post", responseType: "token%20code", setup: func() { state = "12345678901234567890" @@ -119,7 +133,7 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { }, }, { - description: "hybrid grant test with form_post", + description: "Hybrid #2 grant test with form_post", responseType: "token%20id_token%20code", setup: func() { state = "12345678901234567890" @@ -135,7 +149,7 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { }, }, { - description: "hybrid grant test with form_post", + description: "Hybrid #3 grant test with form_post", responseType: "id_token%20code", setup: func() { state = "12345678901234567890" diff --git a/internal/authorize_request.go b/internal/authorize_request.go index 7e620679e..a2b6407a0 100644 --- a/internal/authorize_request.go +++ b/internal/authorize_request.go @@ -76,6 +76,20 @@ func (mr *MockAuthorizeRequesterMockRecorder) GetClient() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClient", reflect.TypeOf((*MockAuthorizeRequester)(nil).GetClient)) } +// GetDefaultResponseMode mocks base method +func (m *MockAuthorizeRequester) GetDefaultResponseMode() fosite.ResponseModeType { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetDefaultResponseMode") + ret0, _ := ret[0].(fosite.ResponseModeType) + return ret0 +} + +// GetDefaultResponseMode indicates an expected call of GetDefaultResponseMode +func (mr *MockAuthorizeRequesterMockRecorder) GetDefaultResponseMode() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultResponseMode", reflect.TypeOf((*MockAuthorizeRequester)(nil).GetDefaultResponseMode)) +} + // GetGrantedAudience mocks base method func (m *MockAuthorizeRequester) GetGrantedAudience() fosite.Arguments { m.ctrl.T.Helper() @@ -308,6 +322,18 @@ func (mr *MockAuthorizeRequesterMockRecorder) Sanitize(arg0 interface{}) *gomock return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sanitize", reflect.TypeOf((*MockAuthorizeRequester)(nil).Sanitize), arg0) } +// SetDefaultResponseMode mocks base method +func (m *MockAuthorizeRequester) SetDefaultResponseMode(arg0 fosite.ResponseModeType) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "SetDefaultResponseMode", arg0) +} + +// SetDefaultResponseMode indicates an expected call of SetDefaultResponseMode +func (mr *MockAuthorizeRequesterMockRecorder) SetDefaultResponseMode(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetDefaultResponseMode", reflect.TypeOf((*MockAuthorizeRequester)(nil).SetDefaultResponseMode), arg0) +} + // SetID mocks base method func (m *MockAuthorizeRequester) SetID(arg0 string) { m.ctrl.T.Helper() diff --git a/oauth2.go b/oauth2.go index 05eafb168..66316d80c 100644 --- a/oauth2.go +++ b/oauth2.go @@ -268,11 +268,14 @@ type AuthorizeRequester interface { // GetState returns the request's state. GetState() (state string) - //GetResponseMode returns response_mode of the authorization request + // GetResponseMode returns response_mode of the authorization request GetResponseMode() ResponseModeType - //SetResponseMode sets response mode of the authorization request - SetResponseMode(responseMode ResponseModeType) + // SetDefaultResponseMode sets default response mode for a response type in a flow + SetDefaultResponseMode(responseMode ResponseModeType) + + // GetDefaultResponseMode gets default response mode for a response type in a flow + GetDefaultResponseMode() ResponseModeType Requester } From 9ed9df8fd03a0eae0d6e05385a6f417efab1876d Mon Sep 17 00:00:00 2001 From: Ajanthan Balachandran Date: Sun, 1 Nov 2020 15:19:29 -0800 Subject: [PATCH 13/22] Adding test cases for none default response modes --- authorize_error.go | 2 +- authorize_error_test.go | 2 +- authorize_request.go | 2 +- authorize_request_handler.go | 12 +- authorize_request_handler_test.go | 4 +- authorize_write.go | 2 +- authorize_write_test.go | 2 +- handler/openid/flow_implicit_test.go | 16 +- integration/authorize_form_post_test.go | 10 +- integration/authorize_response_mode_test.go | 257 ++++++++++++++++++++ internal/test_helpers.go | 2 + 11 files changed, 278 insertions(+), 33 deletions(-) create mode 100644 integration/authorize_response_mode_test.go diff --git a/authorize_error.go b/authorize_error.go index e4fc51fc5..819222ce2 100644 --- a/authorize_error.go +++ b/authorize_error.go @@ -64,7 +64,7 @@ func (f *Fosite) WriteAuthorizeError(rw http.ResponseWriter, ar AuthorizeRequest query.Add("state", ar.GetState()) var redirectURIString string - if ar.GetResponseMode() == ResponseModePost { + if ar.GetResponseMode() == ResponseModeFormPost { rw.Header().Add("Content-Type", "text/html;charset=UTF-8") WriteAuthorizeFormPostResponse(redirectURI.String(), query, GetPostFormHTMLTemplate(*f), rw) return diff --git a/authorize_error_test.go b/authorize_error_test.go index ae16a1631..6cf20171e 100644 --- a/authorize_error_test.go +++ b/authorize_error_test.go @@ -301,7 +301,7 @@ func TestWriteAuthorizeError(t *testing.T) { req.EXPECT().GetRedirectURI().Return(copyUrl(purls[1])) req.EXPECT().GetState().Return("foostate") req.EXPECT().GetResponseTypes().MaxTimes(2).Return(Arguments([]string{"token"})) - req.EXPECT().GetResponseMode().Return(ResponseModePost).Times(1) + req.EXPECT().GetResponseMode().Return(ResponseModeFormPost).Times(1) rw.EXPECT().Header().Times(3).Return(header) rw.EXPECT().Write(gomock.Any()).AnyTimes() }, diff --git a/authorize_request.go b/authorize_request.go index 91a62a657..791f0dc38 100644 --- a/authorize_request.go +++ b/authorize_request.go @@ -29,7 +29,7 @@ type ResponseModeType string const ( ResponseModeDefault = ResponseModeType("") - ResponseModePost = ResponseModeType("form_post") + ResponseModeFormPost = ResponseModeType("form_post") ResponseModeQuery = ResponseModeType("query") ResponseModeFragment = ResponseModeType("fragment") ) diff --git a/authorize_request_handler.go b/authorize_request_handler.go index e94b2346e..8beea3603 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -220,8 +220,8 @@ func (f *Fosite) ParseResponseMode(r *http.Request, request *AuthorizeRequest) e request.ResponseMode = ResponseModeFragment case string(ResponseModeQuery): request.ResponseMode = ResponseModeQuery - case string(ResponseModePost): - request.ResponseMode = ResponseModePost + case string(ResponseModeFormPost): + request.ResponseMode = ResponseModeFormPost default: return errors.WithStack(ErrUnsupportedResponseMode.WithHintf("Request with unsupported response_mode \"%s\".", responseMode)) } @@ -274,10 +274,6 @@ func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth } request.Client = client - if err := f.validateResponseMode(r, request); err != nil { - return request, err - } - if err := f.authorizeRequestParametersFromOpenIDConnectRequest(request); err != nil { return request, err } @@ -302,6 +298,10 @@ func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth return request, err } + if err := f.validateResponseMode(r, request); err != nil { + return request, err + } + // rfc6819 4.4.1.8. Threat: CSRF Attack against redirect-uri // The "state" parameter should be used to link the authorization // request with the redirect URI used to deliver the access token (Section 5.3.5). diff --git a/authorize_request_handler_test.go b/authorize_request_handler_test.go index da94f1b75..06eea0cce 100644 --- a/authorize_request_handler_test.go +++ b/authorize_request_handler_test.go @@ -446,7 +446,7 @@ func TestNewAuthorizeRequest(t *testing.T) { ResponseTypes: []string{"code token"}, Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, }, - ResponseMode: []ResponseModeType{ResponseModePost}, + ResponseMode: []ResponseModeType{ResponseModeFormPost}, }, nil) }, expect: &AuthorizeRequest{ @@ -461,7 +461,7 @@ func TestNewAuthorizeRequest(t *testing.T) { ResponseTypes: []string{"code token"}, Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, }, - ResponseMode: []ResponseModeType{ResponseModePost}, + ResponseMode: []ResponseModeType{ResponseModeFormPost}, }, RequestedScope: []string{"foo", "bar"}, RequestedAudience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, diff --git a/authorize_write.go b/authorize_write.go index 97fbab50c..b9c7149fc 100644 --- a/authorize_write.go +++ b/authorize_write.go @@ -40,7 +40,7 @@ func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequ redir := ar.GetRedirectURI() switch ar.GetResponseMode() { - case ResponseModePost: + case ResponseModeFormPost: //form_post rw.Header().Add("Content-Type", "text/html;charset=UTF-8") WriteAuthorizeFormPostResponse(redir.String(), resp.GetParameters(), GetPostFormHTMLTemplate(*f), rw) diff --git a/authorize_write_test.go b/authorize_write_test.go index de37b83f4..3d92d3dec 100644 --- a/authorize_write_test.go +++ b/authorize_write_test.go @@ -170,7 +170,7 @@ func TestWriteAuthorizeResponse(t *testing.T) { setup: func() { redir, _ := url.Parse("https://foobar.com/?foo=bar") ar.EXPECT().GetRedirectURI().Return(redir) - ar.EXPECT().GetResponseMode().Return(ResponseModePost) + ar.EXPECT().GetResponseMode().Return(ResponseModeFormPost) resp.EXPECT().GetHeader().Return(http.Header{"X-Bar": {"baz"}}) resp.EXPECT().GetParameters().Return(url.Values{"code": {"poz65kqoneu"}, "state": {"qm6dnsrn"}}) diff --git a/handler/openid/flow_implicit_test.go b/handler/openid/flow_implicit_test.go index b9d870563..4a88d4ae0 100644 --- a/handler/openid/flow_implicit_test.go +++ b/handler/openid/flow_implicit_test.go @@ -232,6 +232,7 @@ func TestImplicit_HandleAuthorizeEndpointRequest(t *testing.T) { assert.NotEmpty(t, aresp.GetParameters().Get("id_token")) assert.NotEmpty(t, aresp.GetParameters().Get("state")) assert.NotEmpty(t, aresp.GetParameters().Get("access_token")) + assert.Equal(t, fosite.ResponseModeFragment, areq.GetResponseMode()) }, }, { @@ -246,21 +247,6 @@ func TestImplicit_HandleAuthorizeEndpointRequest(t *testing.T) { assert.NotEmpty(t, aresp.GetParameters().Get("access_token")) }, }, - { - description: "default responseMode check", - setup: func() OpenIDConnectImplicitHandler { - areq.Form.Set("nonce", "some-random-foo-nonce-wow") - areq.ResponseTypes = fosite.Arguments{"id_token", "token"} - areq.RequestedScope = fosite.Arguments{"fosite", "openid"} - return makeOpenIDConnectImplicitHandler(fosite.MinParameterEntropy) - }, - check: func() { - assert.NotEmpty(t, aresp.GetParameters().Get("id_token")) - assert.NotEmpty(t, aresp.GetParameters().Get("state")) - assert.NotEmpty(t, aresp.GetParameters().Get("access_token")) - assert.Equal(t, fosite.ResponseModeFragment, areq.GetResponseMode()) - }, - }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { h := c.setup() diff --git a/integration/authorize_form_post_test.go b/integration/authorize_form_post_test.go index 6ee74542e..87122f41e 100644 --- a/integration/authorize_form_post_test.go +++ b/integration/authorize_form_post_test.go @@ -41,15 +41,15 @@ import ( "github.com/ory/fosite/handler/oauth2" ) -func TestAuthorizeFormPostImplicitFlow(t *testing.T) { +func TestAuthorizeFormPostResponseMode(t *testing.T) { for _, strategy := range []oauth2.AccessTokenStrategy{ hmacStrategy, } { - runTestAuthorizeFormPostImplicitGrant(t, strategy) + runTestAuthorizeFormPostResponseMode(t, strategy) } } -func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { +func runTestAuthorizeFormPostResponseMode(t *testing.T, strategy interface{}) { session := &defaultSession{ DefaultSession: &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ @@ -67,7 +67,7 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { defaultClient.RedirectURIs[0] = ts.URL + "/callback" responseModeClient := &fosite.DefaultResponseModeClient{ DefaultClient: defaultClient, - ResponseMode: []fosite.ResponseModeType{fosite.ResponseModePost}, + ResponseMode: []fosite.ResponseModeType{fosite.ResponseModeFormPost}, } fositeStore.Clients["response-mode-client"] = responseModeClient oauthClient.ClientID = "response-mode-client" @@ -185,7 +185,7 @@ func runTestAuthorizeFormPostImplicitGrant(t *testing.T, strategy interface{}) { resp, err := client.Get(authURL) require.NoError(t, err) require.Equal(t, http.StatusOK, resp.StatusCode) - code, state, token, iDToken, _, errResp, err := internal.ParseFormPostResponse(fositeStore.Clients["my-client"].GetRedirectURIs()[0], resp.Body) + code, state, token, iDToken, _, errResp, err := internal.ParseFormPostResponse(fositeStore.Clients["response-mode-client"].GetRedirectURIs()[0], resp.Body) require.NoError(t, err) c.check(t, state, code, iDToken, token, errResp) }) diff --git a/integration/authorize_response_mode_test.go b/integration/authorize_response_mode_test.go new file mode 100644 index 000000000..6a6aa16fd --- /dev/null +++ b/integration/authorize_response_mode_test.go @@ -0,0 +1,257 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +package integration_test + +import ( + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/ory/fosite/handler/openid" + "github.com/ory/fosite/internal" + "github.com/ory/fosite/token/jwt" + "github.com/pkg/errors" + + "github.com/stretchr/testify/require" + goauth "golang.org/x/oauth2" + + "github.com/ory/fosite" + "github.com/ory/fosite/compose" + "github.com/ory/fosite/handler/oauth2" +) + +func TestAuthorizeResponseMode(t *testing.T) { + for _, strategy := range []oauth2.AccessTokenStrategy{ + hmacStrategy, + } { + runTestAuthorizeResponseMode(t, strategy) + } +} + +func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { + session := &defaultSession{ + DefaultSession: &openid.DefaultSession{ + Claims: &jwt.IDTokenClaims{ + Subject: "peter", + }, + Headers: &jwt.Headers{}, + }, + } + f := compose.ComposeAllEnabled(new(compose.Config), fositeStore, []byte("some-secret-thats-random-some-secret-thats-random-"), internal.MustRSAKey()) + ts := mockServer(t, f, session) + defer ts.Close() + + oauthClient := newOAuth2Client(ts) + defaultClient := fositeStore.Clients["my-client"].(*fosite.DefaultClient) + defaultClient.RedirectURIs[0] = ts.URL + "/callback" + responseModeClient := &fosite.DefaultResponseModeClient{ + DefaultClient: defaultClient, + ResponseMode: []fosite.ResponseModeType{}, + } + fositeStore.Clients["response-mode-client"] = responseModeClient + oauthClient.ClientID = "response-mode-client" + + var state string + for k, c := range []struct { + description string + setup func() + check func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) + responseType string + responseMode string + }{ + { + description: "Should give err because implicit grant with response mode query", + responseType: "id_token%20token", + responseMode: "query", + setup: func() { + state = "12345678901234567890" + oauthClient.Scopes = []string{"openid"} + responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeQuery} + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { + assert.NotEmpty(t, err["Name"]) + assert.NotEmpty(t, err["Description"]) + assert.Equal(t, "Insecure response_mode \"query\" for the response_type \"[id_token token]\".", err["Hint"]) + + }, + }, + { + description: "Should pass implicit grant with response mode form_post", + responseType: "id_token%20token", + responseMode: "form_post", + setup: func() { + state = "12345678901234567890" + oauthClient.Scopes = []string{"openid"} + responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeFormPost} + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { + assert.EqualValues(t, state, stateFromServer) + assert.NotEmpty(t, token.TokenType) + assert.NotEmpty(t, token.AccessToken) + assert.NotEmpty(t, token.Expiry) + assert.NotEmpty(t, iDToken) + + }, + }, + { + description: "Should fail because response mode form_post is not allowed by the client", + responseType: "id_token%20token", + responseMode: "form_post", + setup: func() { + state = "12345678901234567890" + oauthClient.Scopes = []string{"openid"} + responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeQuery} + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { + assert.NotEmpty(t, err["Name"]) + assert.NotEmpty(t, err["Description"]) + assert.Equal(t, "The client is not allowed to request response_mode \"form_post\".", err["Hint"]) + }, + }, + { + description: "Should pass Authorization code grant test with response mode fragment", + responseType: "code", + responseMode: "fragment", + setup: func() { + state = "12345678901234567890" + responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeFragment} + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { + assert.EqualValues(t, state, stateFromServer) + assert.NotEmpty(t, code) + }, + }, + { + description: "Should pass Authorization code grant test with response mode form_post", + responseType: "code", + responseMode: "form_post", + setup: func() { + state = "12345678901234567890" + responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeFormPost} + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { + assert.EqualValues(t, state, stateFromServer) + assert.NotEmpty(t, code) + }, + }, + { + description: "Should fail Hybrid grant test with query", + responseType: "token%20code", + responseMode: "query", + setup: func() { + state = "12345678901234567890" + oauthClient.Scopes = []string{"openid"} + responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeQuery} + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { + //assert.EqualValues(t, state, stateFromServer) + assert.NotEmpty(t, err["Name"]) + assert.NotEmpty(t, err["Description"]) + assert.Equal(t, "Insecure response_mode \"query\" for the response_type \"[token code]\".", err["Hint"]) + }, + }, + { + description: "Should pass Hybrid grant test with form_post", + responseType: "token%20code", + responseMode: "form_post", + setup: func() { + state = "12345678901234567890" + oauthClient.Scopes = []string{"openid"} + responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeFormPost} + }, + check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { + assert.EqualValues(t, state, stateFromServer) + assert.NotEmpty(t, code) + assert.NotEmpty(t, token.TokenType) + assert.NotEmpty(t, token.AccessToken) + assert.NotEmpty(t, token.Expiry) + }, + }, + } { + t.Run(fmt.Sprintf("case=%d/description=%s", k, c.description), func(t *testing.T) { + c.setup() + authURL := strings.Replace(oauthClient.AuthCodeURL(state, goauth.SetAuthURLParam("response_mode", c.responseMode), goauth.SetAuthURLParam("nonce", "111111111")), "response_type=code", "response_type="+c.responseType, -1) + var callbackURL *url.URL + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + callbackURL = req.URL + return errors.New("Dont follow redirects") + }, + } + + var code, state, iDToken string + var token goauth.Token + var errResp map[string]string + + resp, err := client.Get(authURL) + if callbackURL != nil { + if fosite.ResponseModeType(c.responseMode) == fosite.ResponseModeFragment { + require.Error(t, err) + //fragment + fragment, err := url.ParseQuery(callbackURL.Fragment) + require.NoError(t, err) + code, state, iDToken, token, errResp = getParameters(t, fragment) + } else if fosite.ResponseModeType(c.responseMode) == fosite.ResponseModeQuery { + require.Error(t, err) + //query + query, err := url.ParseQuery(callbackURL.RawQuery) + require.NoError(t, err) + code, state, iDToken, token, errResp = getParameters(t, query) + } + } + if fosite.ResponseModeType(c.responseMode) == fosite.ResponseModeFormPost && resp.Body != nil { + //form_post + code, state, iDToken, token, _, errResp, err = internal.ParseFormPostResponse(fositeStore.Clients["response-mode-client"].GetRedirectURIs()[0], resp.Body) + } + c.check(t, state, code, token, iDToken, errResp) + }) + } +} +func getParameters(t *testing.T, param url.Values) (code, state, iDToken string, token goauth.Token, errResp map[string]string) { + errResp = make(map[string]string) + if param.Get("error") != "" { + errResp["Name"] = param.Get("error") + errResp["Description"] = param.Get("error_description") + errResp["Hint"] = param.Get("error_hint") + } else { + code = param.Get("code") + state = param.Get("state") + iDToken = param.Get("id_token") + token = goauth.Token{ + AccessToken: param.Get("access_token"), + TokenType: param.Get("token_type"), + RefreshToken: param.Get("refresh_token"), + } + if param.Get("expires_in") != "" { + expires, err := strconv.Atoi(param.Get("expires_in")) + require.NoError(t, err) + token.Expiry = time.Now().UTC().Add(time.Duration(expires) * time.Second) + } + } + return +} diff --git a/internal/test_helpers.go b/internal/test_helpers.go index 42f10c20f..91c75dcaa 100644 --- a/internal/test_helpers.go +++ b/internal/test_helpers.go @@ -101,6 +101,8 @@ func ParseFormPostResponse(redirectURL string, resp io.ReadCloser) (authorizatio token.RefreshToken = v case "error": rFC6749Error["Name"] = v + case "error_hint": + rFC6749Error["Hint"] = v case "error_description": rFC6749Error["Description"] = v case "id_token": From eee54737fb2b7aa2f0e86987ee68407f17dd04aa Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Fri, 6 Nov 2020 14:05:40 +0100 Subject: [PATCH 14/22] fix: respect request object --- authorize_request_handler.go | 16 +++++++++++----- authorize_request_handler_oidc_request_test.go | 10 +++++----- authorize_request_handler_test.go | 11 +++++++---- go.mod | 1 + 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/authorize_request_handler.go b/authorize_request_handler.go index fdd1f24f3..02e048db4 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -274,26 +274,32 @@ func (f *Fosite) NewAuthorizeRequest(ctx context.Context, r *http.Request) (Auth if err := r.ParseMultipartForm(1 << 20); err != nil && err != http.ErrNotMultipart { return request, errors.WithStack(ErrInvalidRequest.WithHint("Unable to parse HTTP body, make sure to send a properly formatted form request body.").WithCause(err).WithDebug(err.Error())) } - request.Form = r.Form // Save state to the request to be returned in error conditions (https://github.com/ory/hydra/issues/1642) request.State = request.Form.Get("state") - if err := f.ParseResponseMode(r, request); err != nil { - return request, err - } - client, err := f.Store.GetClient(ctx, request.GetRequestForm().Get("client_id")) if err != nil { return request, errors.WithStack(ErrInvalidClient.WithHint("The requested OAuth 2.0 Client does not exist.").WithCause(err).WithDebug(err.Error())) } request.Client = client + // Now that the base fields (state and client) are populated, we extract all the information + // from the request object or request object uri, if one is set. + // + // All other parse methods should come afterwards so that we ensure that the data is taken + // from the request_object if set. if err := f.authorizeRequestParametersFromOpenIDConnectRequest(request); err != nil { return request, err } + // The request context is now fully available and we can start processing the individual + // fields. + if err := f.ParseResponseMode(r, request); err != nil { + return request, err + } + if err := f.validateAuthorizeRedirectURI(r, request); err != nil { return request, err } diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index 26e129a5d..1ecc8c5bb 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -64,7 +64,6 @@ func mustGenerateNoneAssertion(t *testing.T, claims jwt.MapClaims) string { } func TestAuthorizeRequestParametersFromOpenIDConnectRequest(t *testing.T) { - key, err := rsa.GenerateKey(rand.Reader, 1024) if err != nil { panic(err) @@ -79,7 +78,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequest(t *testing.T) { }, } - validRequestObject := mustGenerateAssertion(t, jwt.MapClaims{"scope": "foo", "foo": "bar", "baz": "baz"}, key, "kid-foo") + validRequestObject := mustGenerateAssertion(t, jwt.MapClaims{"scope": "foo", "foo": "bar", "baz": "baz", "response_type": "token", "response_mode": "post_form"}, key, "kid-foo") validRequestObjectWithoutKid := mustGenerateAssertion(t, jwt.MapClaims{"scope": "foo", "foo": "bar", "baz": "baz"}, key, "") validNoneRequestObject := mustGenerateNoneAssertion(t, jwt.MapClaims{"scope": "foo", "foo": "bar", "baz": "baz", "state": "some-state"}) @@ -167,9 +166,10 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequest(t *testing.T) { }, { d: "should pass and set request parameters properly", - form: url.Values{"scope": {"openid"}, "request": {validRequestObject}}, + form: url.Values{"scope": {"openid"}, "response_type": {"code"}, "response_mode": {"none"}, "request": {validRequestObject}}, client: &DefaultOpenIDConnectClient{JSONWebKeys: jwks, RequestObjectSigningAlgorithm: "RS256"}, - expectForm: url.Values{"scope": {"foo openid"}, "request": {validRequestObject}, "foo": {"bar"}, "baz": {"baz"}}, + // The values from form are overwritten by the request object. + expectForm: url.Values{"response_type": {"token"}, "response_mode": {"post_form"}, "scope": {"foo openid"}, "request": {validRequestObject}, "foo": {"bar"}, "baz": {"baz"}}, }, { d: "should pass even if kid is unset", @@ -188,7 +188,7 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequest(t *testing.T) { d: "should pass and set request_uri parameters properly and also fetch jwk from remote", form: url.Values{"scope": {"openid"}, "request_uri": {reqTS.URL}}, client: &DefaultOpenIDConnectClient{JSONWebKeysURI: reqJWK.URL, RequestObjectSigningAlgorithm: "RS256", RequestURIs: []string{reqTS.URL}}, - expectForm: url.Values{"scope": {"foo openid"}, "request_uri": {reqTS.URL}, "foo": {"bar"}, "baz": {"baz"}}, + expectForm: url.Values{"response_type": {"token"}, "response_mode": {"post_form"}, "scope": {"foo openid"}, "request_uri": {reqTS.URL}, "foo": {"bar"}, "baz": {"baz"}}, }, { d: "should pass when request object uses algorithm none", diff --git a/authorize_request_handler_test.go b/authorize_request_handler_test.go index 06eea0cce..28623c21e 100644 --- a/authorize_request_handler_test.go +++ b/authorize_request_handler_test.go @@ -44,9 +44,7 @@ import ( // If a Response Type contains one of more space characters (%20), it is compared as a space-delimited list of // values in which the order of values does not matter. func TestNewAuthorizeRequest(t *testing.T) { - ctrl := gomock.NewController(t) - store := NewMockStorage(ctrl) - defer ctrl.Finish() + var store *MockStorage redir, _ := url.Parse("https://foo.bar/cb") specialCharRedir, _ := url.Parse("web+application://callback") @@ -380,7 +378,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "response_mode": {"unknown"}, }, mock: func() { - //store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"},ResponseTypes: []string{"code token"}}, nil) + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"},ResponseTypes: []string{"code token"}}, nil) }, expectedError: ErrUnsupportedResponseMode, }, @@ -470,6 +468,10 @@ func TestNewAuthorizeRequest(t *testing.T) { }, } { t.Run(fmt.Sprintf("case=%d", k), func(t *testing.T) { + ctrl := gomock.NewController(t) + store = NewMockStorage(ctrl) + defer ctrl.Finish() + c.mock() if c.r == nil { c.r = &http.Request{Header: http.Header{}} @@ -478,6 +480,7 @@ func TestNewAuthorizeRequest(t *testing.T) { } } + c.conf.Store = store ar, err := c.conf.NewAuthorizeRequest(context.Background(), c.r) if c.expectedError != nil { assert.EqualError(t, err, c.expectedError.Error()) diff --git a/go.mod b/go.mod index 08aebe542..38d8ae1be 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/spf13/afero v1.3.2 // indirect github.com/stretchr/testify v1.5.1 golang.org/x/crypto v0.0.0-20200709230013-948cd5f35899 + golang.org/x/net v0.0.0-20200625001655-4c5254603344 golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d golang.org/x/sys v0.0.0-20200720211630-cb9d2d5c5666 // indirect golang.org/x/text v0.3.3 // indirect From 69b2208aaca096f239fa4ac98564e8f60a3b1b86 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Fri, 6 Nov 2020 15:07:19 +0100 Subject: [PATCH 15/22] u --- authorize_helper.go | 1 + authorize_request_handler.go | 39 ++++++++++++--------- authorize_response_writer.go | 2 +- authorize_response_writer_test.go | 2 +- integration/authorize_response_mode_test.go | 6 ++-- 5 files changed, 27 insertions(+), 23 deletions(-) diff --git a/authorize_helper.go b/authorize_helper.go index d5382ad4a..79b84a371 100644 --- a/authorize_helper.go +++ b/authorize_helper.go @@ -210,6 +210,7 @@ func WriteAuthorizeFormPostResponse(redirectURL string, parameters url.Values, t Parameters: parameters, }) } + func URLSetFragment(source *url.URL, fragment url.Values) { var f string for k, v := range fragment { diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 02e048db4..6d26a023e 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -228,9 +228,7 @@ func (f *Fosite) validateResponseTypes(r *http.Request, request *AuthorizeReques return nil } func (f *Fosite) ParseResponseMode(r *http.Request, request *AuthorizeRequest) error { - responseMode := r.Form.Get("response_mode") - - switch responseMode { + switch responseMode := r.Form.Get("response_mode"); responseMode { case string(ResponseModeDefault): request.ResponseMode = ResponseModeDefault case string(ResponseModeFragment): @@ -242,25 +240,32 @@ func (f *Fosite) ParseResponseMode(r *http.Request, request *AuthorizeRequest) e default: return errors.WithStack(ErrUnsupportedResponseMode.WithHintf("Request with unsupported response_mode \"%s\".", responseMode)) } + return nil } + func (f *Fosite) validateResponseMode(r *http.Request, request *AuthorizeRequest) error { - if request.ResponseMode != ResponseModeDefault { - var found bool - responseModeClient, ok := request.GetClient().(ResponseModeClient) - if !ok { - return errors.WithStack(ErrUnsupportedResponseMode.WithHintf("The request has response_mode \"%s\". set but registered OAuth 2.0 client doesn't support response_mode", r.Form.Get("response_mode"))) - } - for _, t := range responseModeClient.GetResponseMode() { - if request.ResponseMode == t { - found = true - break - } - } - if !found { - return errors.WithStack(ErrUnsupportedResponseMode.WithHintf("The client is not allowed to request response_mode \"%s\".", r.Form.Get("response_mode"))) + if request.ResponseMode == ResponseModeDefault { + return nil + } + + responseModeClient, ok := request.GetClient().(ResponseModeClient) + if !ok { + return errors.WithStack(ErrUnsupportedResponseMode.WithHintf("The request has response_mode \"%s\". set but registered OAuth 2.0 client doesn't support response_mode", r.Form.Get("response_mode"))) + } + + var found bool + for _, t := range responseModeClient.GetResponseMode() { + if request.ResponseMode == t { + found = true + break } } + + if !found { + return errors.WithStack(ErrUnsupportedResponseMode.WithHintf("The client is not allowed to request response_mode \"%s\".", r.Form.Get("response_mode"))) + } + return nil } diff --git a/authorize_response_writer.go b/authorize_response_writer.go index a34334085..5de01b53c 100644 --- a/authorize_response_writer.go +++ b/authorize_response_writer.go @@ -47,7 +47,7 @@ func (f *Fosite) NewAuthorizeResponse(ctx context.Context, ar AuthorizeRequester } if ar.GetDefaultResponseMode() == ResponseModeFragment && ar.GetResponseMode() == ResponseModeQuery { - return nil, ErrUnsupportedResponseMode.WithHintf("Insecure response_mode \"%s\" for the response_type \"%s\".", ar.GetResponseMode(), ar.GetResponseTypes()) + return nil, ErrUnsupportedResponseMode.WithHintf("Insecure response_mode '%s' for the response_type '%s'.", ar.GetResponseMode(), ar.GetResponseTypes()) } return resp, nil diff --git a/authorize_response_writer_test.go b/authorize_response_writer_test.go index d9a550fa6..650e3fa29 100644 --- a/authorize_response_writer_test.go +++ b/authorize_response_writer_test.go @@ -100,7 +100,7 @@ func TestNewAuthorizeResponse(t *testing.T) { ar.EXPECT().GetResponseTypes().Return([]string{"token", "code"}) }, isErr: true, - expectErr: ErrUnsupportedResponseMode.WithHintf("Insecure response_mode \"%s\" for the response_type \"%s\".", ResponseModeQuery, []string{"token", "code"}), + expectErr: ErrUnsupportedResponseMode.WithHintf("Insecure response_mode '%s' for the response_type '%s'.", ResponseModeQuery, []string{"token", "code"}), }, } { c.mock() diff --git a/integration/authorize_response_mode_test.go b/integration/authorize_response_mode_test.go index 6a6aa16fd..4592565d4 100644 --- a/integration/authorize_response_mode_test.go +++ b/integration/authorize_response_mode_test.go @@ -96,8 +96,7 @@ func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.NotEmpty(t, err["Name"]) assert.NotEmpty(t, err["Description"]) - assert.Equal(t, "Insecure response_mode \"query\" for the response_type \"[id_token token]\".", err["Hint"]) - + assert.Equal(t, "Insecure response_mode 'query' for the response_type '[id_token token]'.", err["Hint"]) }, }, { @@ -115,7 +114,6 @@ func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { assert.NotEmpty(t, token.AccessToken) assert.NotEmpty(t, token.Expiry) assert.NotEmpty(t, iDToken) - }, }, { @@ -172,7 +170,7 @@ func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { //assert.EqualValues(t, state, stateFromServer) assert.NotEmpty(t, err["Name"]) assert.NotEmpty(t, err["Description"]) - assert.Equal(t, "Insecure response_mode \"query\" for the response_type \"[token code]\".", err["Hint"]) + assert.Equal(t, "Insecure response_mode 'query' for the response_type '[token code]'.", err["Hint"]) }, }, { From 4b29d153ab4f9fb0f16dde87c9c3c736cfe9b432 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Fri, 6 Nov 2020 15:08:04 +0100 Subject: [PATCH 16/22] u --- authorize_request_handler.go | 1 + 1 file changed, 1 insertion(+) diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 6d26a023e..fb89bfd6e 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -227,6 +227,7 @@ func (f *Fosite) validateResponseTypes(r *http.Request, request *AuthorizeReques request.ResponseTypes = responseTypes return nil } + func (f *Fosite) ParseResponseMode(r *http.Request, request *AuthorizeRequest) error { switch responseMode := r.Form.Get("response_mode"); responseMode { case string(ResponseModeDefault): From fefcdebdf335569e7c10878e131563ccaa90d919 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Fri, 6 Nov 2020 15:21:55 +0100 Subject: [PATCH 17/22] u --- authorize_request.go | 2 +- authorize_request_handler.go | 2 +- ...orize_request_handler_oidc_request_test.go | 6 ++-- authorize_request_handler_test.go | 8 ++--- authorize_write.go | 3 +- client.go | 9 +++--- client_test.go | 4 +-- errors.go | 2 +- handler/openid/flow_implicit.go | 1 - integration/authorize_form_post_test.go | 11 +------ integration/authorize_response_mode_test.go | 31 +++++++------------ internal/authorize_request.go | 5 +-- internal/test_helpers.go | 11 ++++++- 13 files changed, 43 insertions(+), 52 deletions(-) diff --git a/authorize_request.go b/authorize_request.go index 791f0dc38..3c047f211 100644 --- a/authorize_request.go +++ b/authorize_request.go @@ -40,7 +40,7 @@ type AuthorizeRequest struct { RedirectURI *url.URL `json:"redirectUri" gorethink:"redirectUri"` State string `json:"state" gorethink:"state"` HandledResponseTypes Arguments `json:"handledResponseTypes" gorethink:"handledResponseTypes"` - ResponseMode ResponseModeType `json:"ResponseMode" gorethink:"ResponseMode"` + ResponseMode ResponseModeType `json:"ResponseModes" gorethink:"ResponseModes"` DefaultResponseMode ResponseModeType `json:"DefaultResponseMode" gorethink:"DefaultResponseMode"` Request diff --git a/authorize_request_handler.go b/authorize_request_handler.go index fb89bfd6e..f57279dc8 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -256,7 +256,7 @@ func (f *Fosite) validateResponseMode(r *http.Request, request *AuthorizeRequest } var found bool - for _, t := range responseModeClient.GetResponseMode() { + for _, t := range responseModeClient.GetResponseModes() { if request.ResponseMode == t { found = true break diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index 1ecc8c5bb..2f59670ac 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -165,9 +165,9 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequest(t *testing.T) { expectForm: url.Values{"scope": {"openid"}}, }, { - d: "should pass and set request parameters properly", - form: url.Values{"scope": {"openid"}, "response_type": {"code"}, "response_mode": {"none"}, "request": {validRequestObject}}, - client: &DefaultOpenIDConnectClient{JSONWebKeys: jwks, RequestObjectSigningAlgorithm: "RS256"}, + d: "should pass and set request parameters properly", + form: url.Values{"scope": {"openid"}, "response_type": {"code"}, "response_mode": {"none"}, "request": {validRequestObject}}, + client: &DefaultOpenIDConnectClient{JSONWebKeys: jwks, RequestObjectSigningAlgorithm: "RS256"}, // The values from form are overwritten by the request object. expectForm: url.Values{"response_type": {"token"}, "response_mode": {"post_form"}, "scope": {"foo openid"}, "request": {validRequestObject}, "foo": {"bar"}, "baz": {"baz"}}, }, diff --git a/authorize_request_handler_test.go b/authorize_request_handler_test.go index 28623c21e..e34d7b574 100644 --- a/authorize_request_handler_test.go +++ b/authorize_request_handler_test.go @@ -378,7 +378,7 @@ func TestNewAuthorizeRequest(t *testing.T) { "response_mode": {"unknown"}, }, mock: func() { - store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"},ResponseTypes: []string{"code token"}}, nil) + store.EXPECT().GetClient(gomock.Any(), "1234").Return(&DefaultClient{RedirectURIs: []string{"https://foo.bar/cb"}, Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}}, nil) }, expectedError: ErrUnsupportedResponseMode, }, @@ -418,7 +418,7 @@ func TestNewAuthorizeRequest(t *testing.T) { Scopes: []string{"foo", "bar"}, ResponseTypes: []string{"code token"}, }, - ResponseMode: []ResponseModeType{ResponseModeQuery}, + ResponseModes: []ResponseModeType{ResponseModeQuery}, }, nil) }, expectedError: ErrUnsupportedResponseMode, @@ -444,7 +444,7 @@ func TestNewAuthorizeRequest(t *testing.T) { ResponseTypes: []string{"code token"}, Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, }, - ResponseMode: []ResponseModeType{ResponseModeFormPost}, + ResponseModes: []ResponseModeType{ResponseModeFormPost}, }, nil) }, expect: &AuthorizeRequest{ @@ -459,7 +459,7 @@ func TestNewAuthorizeRequest(t *testing.T) { ResponseTypes: []string{"code token"}, Audience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, }, - ResponseMode: []ResponseModeType{ResponseModeFormPost}, + ResponseModes: []ResponseModeType{ResponseModeFormPost}, }, RequestedScope: []string{"foo", "bar"}, RequestedAudience: []string{"https://cloud.ory.sh/api", "https://www.ory.sh/api"}, diff --git a/authorize_write.go b/authorize_write.go index b9c7149fc..cb7526fa7 100644 --- a/authorize_write.go +++ b/authorize_write.go @@ -26,7 +26,6 @@ import ( ) func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequester, resp AuthorizeResponder) { - // Set custom headers, e.g. "X-MySuperCoolCustomHeader" or "X-DONT-CACHE-ME"... wh := rw.Header() rh := resp.GetHeader() @@ -38,12 +37,12 @@ func (f *Fosite) WriteAuthorizeResponse(rw http.ResponseWriter, ar AuthorizeRequ wh.Set("Pragma", "no-cache") redir := ar.GetRedirectURI() - switch ar.GetResponseMode() { case ResponseModeFormPost: //form_post rw.Header().Add("Content-Type", "text/html;charset=UTF-8") WriteAuthorizeFormPostResponse(redir.String(), resp.GetParameters(), GetPostFormHTMLTemplate(*f), rw) + return case ResponseModeQuery, ResponseModeDefault: // Explicit grants q := redir.Query() diff --git a/client.go b/client.go index 9f3f18c8d..c7cf6c5c6 100644 --- a/client.go +++ b/client.go @@ -82,9 +82,8 @@ type OpenIDConnectClient interface { // ResponseModeClient represents a client capable of handling response_mode type ResponseModeClient interface { - // GetResponseMode returns the response modes that client is allowed to send - GetResponseMode() []ResponseModeType + GetResponseModes() []ResponseModeType } // DefaultClient is a simple default implementation of the Client interface. @@ -111,7 +110,7 @@ type DefaultOpenIDConnectClient struct { type DefaultResponseModeClient struct { *DefaultClient - ResponseMode []ResponseModeType `json:"response_mode"` + ResponseModes []ResponseModeType `json:"response_modes"` } func (c *DefaultClient) GetID() string { @@ -190,6 +189,6 @@ func (c *DefaultOpenIDConnectClient) GetRequestURIs() []string { return c.RequestURIs } -func (c *DefaultResponseModeClient) GetResponseMode() []ResponseModeType { - return c.ResponseMode +func (c *DefaultResponseModeClient) GetResponseModes() []ResponseModeType { + return c.ResponseModes } diff --git a/client_test.go b/client_test.go index 3c00cd4e0..baf07d588 100644 --- a/client_test.go +++ b/client_test.go @@ -51,6 +51,6 @@ func TestDefaultClient(t *testing.T) { } func TestDefaultResponseModeClient_GetResponseMode(t *testing.T) { - rc := &DefaultResponseModeClient{ResponseMode: []ResponseModeType{ResponseModeFragment}} - assert.Equal(t, []ResponseModeType{ResponseModeFragment}, rc.GetResponseMode()) + rc := &DefaultResponseModeClient{ResponseModes: []ResponseModeType{ResponseModeFragment}} + assert.Equal(t, []ResponseModeType{ResponseModeFragment}, rc.GetResponseModes()) } diff --git a/errors.go b/errors.go index b503e54b1..4004233e5 100644 --- a/errors.go +++ b/errors.go @@ -74,7 +74,7 @@ var ( } ErrUnsupportedResponseMode = &RFC6749Error{ Name: errUnsupportedResponseModeName, - Description: "The authorization server does not support obtaining response using this response mode", + Description: "The authorization server does not support obtaining a response using this response mode.", Code: http.StatusBadRequest, } ErrInvalidScope = &RFC6749Error{ diff --git a/handler/openid/flow_implicit.go b/handler/openid/flow_implicit.go index b8276aaa8..59968f366 100644 --- a/handler/openid/flow_implicit.go +++ b/handler/openid/flow_implicit.go @@ -100,7 +100,6 @@ func (c *OpenIDConnectImplicitHandler) HandleAuthorizeEndpointRequest(ctx contex claims.AccessTokenHash = base64.RawURLEncoding.EncodeToString([]byte(hash[:c.RS256JWTStrategy.GetSigningMethodLength()/2])) } else { - resp.AddParameter("state", ar.GetState()) } diff --git a/integration/authorize_form_post_test.go b/integration/authorize_form_post_test.go index 87122f41e..2473def6d 100644 --- a/integration/authorize_form_post_test.go +++ b/integration/authorize_form_post_test.go @@ -38,18 +38,9 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/compose" - "github.com/ory/fosite/handler/oauth2" ) func TestAuthorizeFormPostResponseMode(t *testing.T) { - for _, strategy := range []oauth2.AccessTokenStrategy{ - hmacStrategy, - } { - runTestAuthorizeFormPostResponseMode(t, strategy) - } -} - -func runTestAuthorizeFormPostResponseMode(t *testing.T, strategy interface{}) { session := &defaultSession{ DefaultSession: &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ @@ -67,7 +58,7 @@ func runTestAuthorizeFormPostResponseMode(t *testing.T, strategy interface{}) { defaultClient.RedirectURIs[0] = ts.URL + "/callback" responseModeClient := &fosite.DefaultResponseModeClient{ DefaultClient: defaultClient, - ResponseMode: []fosite.ResponseModeType{fosite.ResponseModeFormPost}, + ResponseModes: []fosite.ResponseModeType{fosite.ResponseModeFormPost}, } fositeStore.Clients["response-mode-client"] = responseModeClient oauthClient.ClientID = "response-mode-client" diff --git a/integration/authorize_response_mode_test.go b/integration/authorize_response_mode_test.go index 4592565d4..038344d76 100644 --- a/integration/authorize_response_mode_test.go +++ b/integration/authorize_response_mode_test.go @@ -32,28 +32,20 @@ import ( "github.com/stretchr/testify/assert" + "github.com/pkg/errors" + "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/internal" "github.com/ory/fosite/token/jwt" - "github.com/pkg/errors" "github.com/stretchr/testify/require" goauth "golang.org/x/oauth2" "github.com/ory/fosite" "github.com/ory/fosite/compose" - "github.com/ory/fosite/handler/oauth2" ) -func TestAuthorizeResponseMode(t *testing.T) { - for _, strategy := range []oauth2.AccessTokenStrategy{ - hmacStrategy, - } { - runTestAuthorizeResponseMode(t, strategy) - } -} - -func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { +func TestAuthorizeResponseModes(t *testing.T) { session := &defaultSession{ DefaultSession: &openid.DefaultSession{ Claims: &jwt.IDTokenClaims{ @@ -71,7 +63,7 @@ func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { defaultClient.RedirectURIs[0] = ts.URL + "/callback" responseModeClient := &fosite.DefaultResponseModeClient{ DefaultClient: defaultClient, - ResponseMode: []fosite.ResponseModeType{}, + ResponseModes: []fosite.ResponseModeType{}, } fositeStore.Clients["response-mode-client"] = responseModeClient oauthClient.ClientID = "response-mode-client" @@ -91,7 +83,7 @@ func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { setup: func() { state = "12345678901234567890" oauthClient.Scopes = []string{"openid"} - responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeQuery} + responseModeClient.ResponseModes = []fosite.ResponseModeType{fosite.ResponseModeQuery} }, check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.NotEmpty(t, err["Name"]) @@ -106,7 +98,7 @@ func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { setup: func() { state = "12345678901234567890" oauthClient.Scopes = []string{"openid"} - responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeFormPost} + responseModeClient.ResponseModes = []fosite.ResponseModeType{fosite.ResponseModeFormPost} }, check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.EqualValues(t, state, stateFromServer) @@ -123,7 +115,7 @@ func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { setup: func() { state = "12345678901234567890" oauthClient.Scopes = []string{"openid"} - responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeQuery} + responseModeClient.ResponseModes = []fosite.ResponseModeType{fosite.ResponseModeQuery} }, check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.NotEmpty(t, err["Name"]) @@ -137,7 +129,7 @@ func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { responseMode: "fragment", setup: func() { state = "12345678901234567890" - responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeFragment} + responseModeClient.ResponseModes = []fosite.ResponseModeType{fosite.ResponseModeFragment} }, check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.EqualValues(t, state, stateFromServer) @@ -150,7 +142,7 @@ func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { responseMode: "form_post", setup: func() { state = "12345678901234567890" - responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeFormPost} + responseModeClient.ResponseModes = []fosite.ResponseModeType{fosite.ResponseModeFormPost} }, check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.EqualValues(t, state, stateFromServer) @@ -164,7 +156,7 @@ func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { setup: func() { state = "12345678901234567890" oauthClient.Scopes = []string{"openid"} - responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeQuery} + responseModeClient.ResponseModes = []fosite.ResponseModeType{fosite.ResponseModeQuery} }, check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { //assert.EqualValues(t, state, stateFromServer) @@ -180,7 +172,7 @@ func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { setup: func() { state = "12345678901234567890" oauthClient.Scopes = []string{"openid"} - responseModeClient.ResponseMode = []fosite.ResponseModeType{fosite.ResponseModeFormPost} + responseModeClient.ResponseModes = []fosite.ResponseModeType{fosite.ResponseModeFormPost} }, check: func(t *testing.T, stateFromServer string, code string, token goauth.Token, iDToken string, err map[string]string) { assert.EqualValues(t, state, stateFromServer) @@ -230,6 +222,7 @@ func runTestAuthorizeResponseMode(t *testing.T, strategy interface{}) { }) } } + func getParameters(t *testing.T, param url.Values) (code, state, iDToken string, token goauth.Token, errResp map[string]string) { errResp = make(map[string]string) if param.Get("error") != "" { diff --git a/internal/authorize_request.go b/internal/authorize_request.go index a2b6407a0..91717d790 100644 --- a/internal/authorize_request.go +++ b/internal/authorize_request.go @@ -10,6 +10,7 @@ import ( time "time" gomock "github.com/golang/mock/gomock" + fosite "github.com/ory/fosite" ) @@ -205,7 +206,7 @@ func (mr *MockAuthorizeRequesterMockRecorder) GetRequestedScopes() *gomock.Call // GetResponseMode mocks base method func (m *MockAuthorizeRequester) GetResponseMode() fosite.ResponseModeType { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetResponseMode") + ret := m.ctrl.Call(m, "GetResponseModes") ret0, _ := ret[0].(fosite.ResponseModeType) return ret0 } @@ -213,7 +214,7 @@ func (m *MockAuthorizeRequester) GetResponseMode() fosite.ResponseModeType { // GetResponseMode indicates an expected call of GetResponseMode func (mr *MockAuthorizeRequesterMockRecorder) GetResponseMode() *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponseMode", reflect.TypeOf((*MockAuthorizeRequester)(nil).GetResponseMode)) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetResponseModes", reflect.TypeOf((*MockAuthorizeRequester)(nil).GetResponseMode)) } // GetResponseTypes mocks base method diff --git a/internal/test_helpers.go b/internal/test_helpers.go index 91c75dcaa..bd153bc03 100644 --- a/internal/test_helpers.go +++ b/internal/test_helpers.go @@ -34,7 +34,6 @@ import ( ) func ParseFormPostResponse(redirectURL string, resp io.ReadCloser) (authorizationCode, stateFromServer, iDToken string, token goauth.Token, customParameters url.Values, rFC6749Error map[string]string, err error) { - token = goauth.Token{} rFC6749Error = map[string]string{} customParameters = url.Values{} @@ -43,23 +42,28 @@ func ParseFormPostResponse(redirectURL string, resp io.ReadCloser) (authorizatio if err != nil { return "", "", "", token, customParameters, rFC6749Error, err } + //doc>html>body body := findBody(doc.FirstChild.FirstChild) if body.Data != "body" { return "", "", "", token, customParameters, rFC6749Error, errors.New("Malformed html") } + htmlEvent := body.Attr[0].Key if htmlEvent != "onload" { return "", "", "", token, customParameters, rFC6749Error, errors.New("onload event is missing") } + onLoadFunc := body.Attr[0].Val if onLoadFunc != "javascript:document.forms[0].submit()" { return "", "", "", token, customParameters, rFC6749Error, errors.New("onload function is missing") } + form := getNextNoneTextNode(body.FirstChild) if form.Data != "form" { return "", "", "", token, customParameters, rFC6749Error, errors.New("html form is missing") } + for _, attr := range form.Attr { if attr.Key == "method" { if attr.Val != "post" { @@ -82,6 +86,7 @@ func ParseFormPostResponse(redirectURL string, resp io.ReadCloser) (authorizatio } } + switch k { case "state": stateFromServer = v @@ -111,6 +116,7 @@ func ParseFormPostResponse(redirectURL string, resp io.ReadCloser) (authorizatio customParameters.Add(k, v) } } + return } @@ -119,8 +125,10 @@ func getNextNoneTextNode(node *html.Node) *html.Node { if nextNode != nil && nextNode.Type == html.TextNode { nextNode = getNextNoneTextNode(node.NextSibling) } + return nextNode } + func findBody(node *html.Node) *html.Node { if node != nil { if node.Data == "body" { @@ -128,5 +136,6 @@ func findBody(node *html.Node) *html.Node { } return findBody(node.NextSibling) } + return nil } From a8be9f2df5b66aa1e4f0e5789810eedc302428f9 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Fri, 6 Nov 2020 17:46:48 +0100 Subject: [PATCH 18/22] u --- handler/oauth2/flow_authorize_code_token.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/handler/oauth2/flow_authorize_code_token.go b/handler/oauth2/flow_authorize_code_token.go index 968a06a47..e140be424 100644 --- a/handler/oauth2/flow_authorize_code_token.go +++ b/handler/oauth2/flow_authorize_code_token.go @@ -180,13 +180,13 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) } else if err := c.CoreStorage.CreateAccessTokenSession(ctx, accessSignature, requester.Sanitize([]string{})); err != nil { if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil { - err = rollBackTxnErr + return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr)) } return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) } else if refreshSignature != "" { if err := c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil { if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil { - err = rollBackTxnErr + return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr)) } return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) } From da08c73d20f3fa1762f2b818fad8b297da3aa2e2 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Fri, 6 Nov 2020 17:56:08 +0100 Subject: [PATCH 19/22] u --- handler/oauth2/flow_authorize_code_token.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/handler/oauth2/flow_authorize_code_token.go b/handler/oauth2/flow_authorize_code_token.go index e140be424..2c90f998a 100644 --- a/handler/oauth2/flow_authorize_code_token.go +++ b/handler/oauth2/flow_authorize_code_token.go @@ -180,15 +180,15 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) } else if err := c.CoreStorage.CreateAccessTokenSession(ctx, accessSignature, requester.Sanitize([]string{})); err != nil { if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil { - return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr)) + return fosite.ErrServerError.WithCause(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr) } - return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) + return fosite.ErrServerError.WithCause(err).WithDebug(err.Error()) } else if refreshSignature != "" { if err := c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil { if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil { return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr)) } - return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) + return fosite.ErrServerError.WithCause(err).WithDebug(err.Error()) } } From 3e6a1bde76502f00b0a1a074cc0c59266e1fa2e8 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Fri, 6 Nov 2020 18:03:09 +0100 Subject: [PATCH 20/22] u --- handler/oauth2/flow_authorize_code_token.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/handler/oauth2/flow_authorize_code_token.go b/handler/oauth2/flow_authorize_code_token.go index 2c90f998a..065f958cf 100644 --- a/handler/oauth2/flow_authorize_code_token.go +++ b/handler/oauth2/flow_authorize_code_token.go @@ -175,14 +175,14 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex if err := c.CoreStorage.InvalidateAuthorizeCodeSession(ctx, signature); err != nil { if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil { - err = rollBackTxnErr + return fosite.ErrServerError.WithCause(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr) } - return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) + return fosite.ErrServerError.WithCause(err).WithDebug(err.Error()) } else if err := c.CoreStorage.CreateAccessTokenSession(ctx, accessSignature, requester.Sanitize([]string{})); err != nil { if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil { return fosite.ErrServerError.WithCause(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr) } - return fosite.ErrServerError.WithCause(err).WithDebug(err.Error()) + return err } else if refreshSignature != "" { if err := c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil { if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil { From 6b741ef79f57c8153ed9e77901a7322bab6b0dff Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Mon, 9 Nov 2020 12:09:20 +0100 Subject: [PATCH 21/22] u --- client_authentication_test.go | 2 +- errors.go | 24 ++++++++++++--------- handler/oauth2/flow_authorize_code_token.go | 2 +- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/client_authentication_test.go b/client_authentication_test.go index 16b08bbea..bde8a4819 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -502,7 +502,7 @@ func TestAuthenticateClient(t *testing.T) { if errors.As(err, &validationError) { t.Logf("Error is: %s", validationError.Inner) } else if errors.As(err, &rfcError) { - t.Logf("Debug is: %s", rfcError.Debug) + t.Logf("DebugField is: %s", rfcError.DebugField) } } require.NoError(t, err) diff --git a/errors.go b/errors.go index 4004233e5..c078149ab 100644 --- a/errors.go +++ b/errors.go @@ -261,7 +261,7 @@ func ErrorToRFC6749Error(err error) *RFC6749Error { return &RFC6749Error{ Name: errUnknownErrorName, Description: "The error is unrecognizable", - Debug: err.Error(), + DebugField: err.Error(), Code: http.StatusInternalServerError, cause: err, } @@ -272,7 +272,7 @@ type RFC6749Error struct { Description string Hint string Code int - Debug string + DebugField string cause error } @@ -314,9 +314,13 @@ func (e *RFC6749Error) WithHint(hint string) *RFC6749Error { return &err } +func (e *RFC6749Error) Debug() string { + return e.DebugField +} + func (e *RFC6749Error) WithDebug(debug string) *RFC6749Error { err := *e - err.Debug = debug + err.DebugField = debug return &err } @@ -338,7 +342,7 @@ func (e *RFC6749Error) WithCause(cause error) *RFC6749Error { func (e *RFC6749Error) Sanitize() *RFC6749Error { err := *e - err.Debug = "" + err.DebugField = "" return &err } @@ -348,8 +352,8 @@ func (e *RFC6749Error) GetDescription() string { if e.Hint != "" { description += " " + e.Hint } - if e.Debug != "" { - description += " " + e.Debug + if e.DebugField != "" { + description += " " + e.DebugField } return strings.ReplaceAll(description, "\"", "'") } @@ -380,7 +384,7 @@ func (e *RFC6749Error) UnmarshalJSON(b []byte) error { e.Description = data.Verbose e.Hint = data.Hint e.Code = data.Code - e.Debug = data.Debug + e.DebugField = data.Debug return nil } @@ -392,7 +396,7 @@ func (e RFC6749Error) MarshalJSON() ([]byte, error) { Description: e.GetDescription(), Hint: e.Hint, Code: e.Code, - Debug: e.Debug, + Debug: e.DebugField, } return json.Marshal(data) } @@ -404,8 +408,8 @@ func (e *RFC6749Error) ToValues() url.Values { if e.Hint != "" { values.Add("error_hint", e.Hint) } - if e.Debug != "" { - values.Add("error_debug", e.Debug) + if e.DebugField != "" { + values.Add("error_debug", e.DebugField) } return values } diff --git a/handler/oauth2/flow_authorize_code_token.go b/handler/oauth2/flow_authorize_code_token.go index 065f958cf..9000602e7 100644 --- a/handler/oauth2/flow_authorize_code_token.go +++ b/handler/oauth2/flow_authorize_code_token.go @@ -182,7 +182,7 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil { return fosite.ErrServerError.WithCause(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr) } - return err + return fosite.ErrServerError.WithCause(err).WithDebug(err.Error()) } else if refreshSignature != "" { if err := c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil { if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil { From b812014add9bb292df0c127cb6058456cd181905 Mon Sep 17 00:00:00 2001 From: aeneasr <3372410+aeneasr@users.noreply.github.com> Date: Mon, 9 Nov 2020 12:19:15 +0100 Subject: [PATCH 22/22] u --- handler/oauth2/flow_authorize_code_token.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/handler/oauth2/flow_authorize_code_token.go b/handler/oauth2/flow_authorize_code_token.go index 9000602e7..152469f7d 100644 --- a/handler/oauth2/flow_authorize_code_token.go +++ b/handler/oauth2/flow_authorize_code_token.go @@ -175,20 +175,20 @@ func (c *AuthorizeExplicitGrantHandler) PopulateTokenEndpointResponse(ctx contex if err := c.CoreStorage.InvalidateAuthorizeCodeSession(ctx, signature); err != nil { if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil { - return fosite.ErrServerError.WithCause(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr) + return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr)) } - return fosite.ErrServerError.WithCause(err).WithDebug(err.Error()) + return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) } else if err := c.CoreStorage.CreateAccessTokenSession(ctx, accessSignature, requester.Sanitize([]string{})); err != nil { if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil { - return fosite.ErrServerError.WithCause(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr) + return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr)) } - return fosite.ErrServerError.WithCause(err).WithDebug(err.Error()) + return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) } else if refreshSignature != "" { if err := c.CoreStorage.CreateRefreshTokenSession(ctx, refreshSignature, requester.Sanitize([]string{})); err != nil { if rollBackTxnErr := storage.MaybeRollbackTx(ctx, c.CoreStorage); rollBackTxnErr != nil { return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebugf("error: %s; rollback error: %s", err, rollBackTxnErr)) } - return fosite.ErrServerError.WithCause(err).WithDebug(err.Error()) + return errors.WithStack(fosite.ErrServerError.WithCause(err).WithDebug(err.Error())) } }