From e20043b69d23c66b22ce0c349cd6ee9c2a7caf87 Mon Sep 17 00:00:00 2001 From: Aeneas Rekkas Date: Mon, 26 Oct 2015 11:20:08 +0100 Subject: [PATCH 1/2] tests: increased coverage --- account/handler/handler.go | 25 +++++---- account/handler/handler_test.go | 77 ++++++++++++++++++++++----- context/auth.go | 12 +---- handler/middleware/middleware_test.go | 2 +- jwt/jwt_test.go | 39 ++++++++++++-- 5 files changed, 116 insertions(+), 39 deletions(-) diff --git a/account/handler/handler.go b/account/handler/handler.go index b4ba20aa05b..f1676e232a1 100644 --- a/account/handler/handler.go +++ b/account/handler/handler.go @@ -2,7 +2,7 @@ package handler import ( "encoding/json" - valid "github.com/asaskevich/govalidator" + "github.com/asaskevich/govalidator" "github.com/gorilla/mux" . "github.com/ory-am/hydra/account" hydcon "github.com/ory-am/hydra/context" @@ -25,29 +25,32 @@ func (h *Handler) SetRoutes(r *mux.Router, extractor func(h hydcon.ContextHandle r.Handle("/users", &hydcon.ContextAdapter{Ctx: context.Background(), Handler: h.Create(extractor)}).Methods("POST") r.Handle("/users/{id}", &hydcon.ContextAdapter{Ctx: context.Background(), Handler: h.Get(extractor)}).Methods("GET") r.Handle("/users/{id}", &hydcon.ContextAdapter{Ctx: context.Background(), Handler: h.Delete(extractor)}).Methods("DELETE") + + //r.Handle("/login", &hydcon.ContextAdapter{Ctx: context.Background(), Handler: h.Login(extractor)}).Methods("GET") + //r.Handle("/logout", &hydcon.ContextAdapter{Ctx: context.Background(), Handler: h.Logout(extractor)}).Methods("GET") } func (h *Handler) Create(extractor func(h hydcon.ContextHandler) hydcon.ContextHandler) hydcon.ContextHandler { return extractor(h.m.IsAuthenticated(h.m.IsAuthorized(hydcon.ContextHandlerFunc( func(ctx context.Context, rw http.ResponseWriter, req *http.Request) { - type payload struct { - Email string `json:"email",valid:"email,required"` - Password string `json:"password",valid:"required"` - Data string `json:"data",valid:"json"` + type Payload struct { + Email string `valid:"email,required" json:"email" ` + Password string `valid:"length(6|254),required" json:"password"` + Data string `valid:"optional,json", json:"data"` } - var p payload + var p Payload decoder := json.NewDecoder(req.Body) if err := decoder.Decode(&p); err != nil { http.Error(rw, err.Error(), http.StatusBadRequest) return } - result, err := valid.ValidateStruct(p) - if err != nil { - http.Error(rw, err.Error(), http.StatusBadRequest) - return - } else if !result { + if v, err := govalidator.ValidateStruct(p); !v { + if err != nil { + http.Error(rw, err.Error(), http.StatusBadRequest) + return + } http.Error(rw, "Payload did not validate.", http.StatusBadRequest) return } diff --git a/account/handler/handler_test.go b/account/handler/handler_test.go index 0cc17200488..c8e6dea0a6e 100644 --- a/account/handler/handler_test.go +++ b/account/handler/handler_test.go @@ -13,6 +13,7 @@ import ( "github.com/ory-am/hydra/handler/middleware" "github.com/ory-am/hydra/hash" "github.com/ory-am/ladon/policy" + "github.com/pborman/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/net/context" @@ -38,7 +39,12 @@ func TestMain(m *testing.M) { if err != nil { log.Fatalf("Could not connect to database: %s", err) } - defer c.KillRemove() + defer func() { + err := c.KillRemove() + if err != nil { + panic(err.Error()) + } + }() s = hydra.New(&hash.BCrypt{10}, db) if err := s.CreateSchemas(); err != nil { @@ -51,10 +57,10 @@ func TestMain(m *testing.M) { } type payload struct { - ID string `json:"id"` - Email string `json:"email"` - Password string `json:"password"` - Data string `json:"data"` + ID string `json:"id,omitempty"` + Email string `json:"email,omitempty"` + Password string `json:"password,omitempty"` + Data string `json:"data,omitempty"` } type test struct { @@ -111,7 +117,52 @@ var cases = []*test{ []policy.Policy{ &policy.DefaultPolicy{"", "", []string{"peter"}, policy.AllowAccess, []string{"/users"}, []string{"create"}}, }, - &payload{Email: "1@bar.com", Password: "secret", Data: "{}"}, + &payload{Email: uuid.New() + "@foobar.com", Data: "{}"}, + http.StatusBadRequest, 0, 0, + }, + &test{ + &account.DefaultAccount{ID: "peter"}, + &jwt.Token{Valid: true}, + []policy.Policy{ + &policy.DefaultPolicy{"", "", []string{"peter"}, policy.AllowAccess, []string{"/users"}, []string{"create"}}, + }, + &payload{Email: uuid.New() + "@foobar.com", Password: "123", Data: "{}"}, + http.StatusBadRequest, 0, 0, + }, + &test{ + &account.DefaultAccount{ID: "peter"}, + &jwt.Token{Valid: true}, + []policy.Policy{ + &policy.DefaultPolicy{"", "", []string{"peter"}, policy.AllowAccess, []string{"/users"}, []string{"create"}}, + }, + &payload{Email: "notemail", Password: "secret", Data: "{}"}, + http.StatusBadRequest, 0, 0, + }, + &test{ + &account.DefaultAccount{ID: "peter"}, + &jwt.Token{Valid: true}, + []policy.Policy{ + &policy.DefaultPolicy{"", "", []string{"peter"}, policy.AllowAccess, []string{"/users"}, []string{"create"}}, + }, + &payload{Email: uuid.New() + "@bar.com", Password: "", Data: "{}"}, + http.StatusBadRequest, 0, 0, + }, + &test{ + &account.DefaultAccount{ID: "peter"}, + &jwt.Token{Valid: true}, + []policy.Policy{ + &policy.DefaultPolicy{"", "", []string{"peter"}, policy.AllowAccess, []string{"/users"}, []string{"create"}}, + }, + &payload{Email: uuid.New() + "@bar.com", Password: "secret", Data: "not json"}, + http.StatusBadRequest, 0, 0, + }, + &test{ + &account.DefaultAccount{ID: "peter"}, + &jwt.Token{Valid: true}, + []policy.Policy{ + &policy.DefaultPolicy{"", "", []string{"peter"}, policy.AllowAccess, []string{"/users"}, []string{"create"}}, + }, + &payload{Email: uuid.New() + "@bar.com", Password: "secret", Data: "{}"}, http.StatusOK, http.StatusForbidden, http.StatusForbidden, }, &test{ @@ -121,7 +172,7 @@ var cases = []*test{ &policy.DefaultPolicy{"", "", []string{"peter"}, policy.AllowAccess, []string{"/users"}, []string{"create"}}, &policy.DefaultPolicy{"", "", []string{"peter"}, policy.AllowAccess, []string{".*"}, []string{"get"}}, }, - &payload{Email: "2@bar.com", Password: "secret", Data: "{}"}, + &payload{Email: uuid.New() + "@bar.com", Password: "secret", Data: "{}"}, http.StatusOK, http.StatusOK, http.StatusForbidden, }, &test{ @@ -132,7 +183,7 @@ var cases = []*test{ &policy.DefaultPolicy{"", "", []string{"peter"}, policy.AllowAccess, []string{".*"}, []string{"get"}}, &policy.DefaultPolicy{"", "", []string{"peter"}, policy.AllowAccess, []string{"/users/.*"}, []string{"delete"}}, }, - &payload{Email: "3@bar.com", Password: "secret", Data: "{}"}, + &payload{Email: uuid.New() + "@bar.com", Password: "secret", Data: "{}"}, http.StatusOK, http.StatusOK, http.StatusAccepted, }, } @@ -153,9 +204,11 @@ func TestCreateGetDelete(t *testing.T) { req := request(testCase) res := httptest.NewRecorder() router.ServeHTTP(res, req) - assert.Equal(t, code, res.Code, "Case %d, %s", k, name) + assert.Equal(t, code, res.Code, `Case %d, %s: %s`, k, name, res.Body.Bytes()) if http.StatusOK == res.Code || http.StatusAccepted == res.Code { finish(testCase, res) + } else if res.Code == http.StatusNotFound { + log.Printf("404 case %d: %s", k, testCase.createData.ID) } } @@ -170,7 +223,7 @@ func TestCreateGetDelete(t *testing.T) { }, func(c *test, res *httptest.ResponseRecorder) { code = res.Code result := res.Body.Bytes() - log.Printf("POST ok /users: %s", result) + log.Printf("POST case %d /users: %s", k, result) require.Nil(t, json.Unmarshal(result, &p)) assert.Equal(t, c.createData.Email, p.Email) assert.Equal(t, c.createData.Data, p.Data) @@ -187,7 +240,7 @@ func TestCreateGetDelete(t *testing.T) { }, func(c *test, res *httptest.ResponseRecorder) { code = res.Code result := res.Body.Bytes() - log.Printf("GET ok /users/%s: %s", p.ID, result) + log.Printf("GET case %d /users/%s: %s", k, p.ID, result) require.Nil(t, json.Unmarshal(result, &p)) assert.Equal(t, c.createData.Email, p.Email) assert.Equal(t, c.createData.Data, p.Data) @@ -203,7 +256,7 @@ func TestCreateGetDelete(t *testing.T) { return req }, func(c *test, res *httptest.ResponseRecorder) { code = res.Code - log.Printf("DELETE ok /users/%s", p.ID) + log.Printf("DELETE case %d /users/%s", k, p.ID) }) if code != http.StatusAccepted { diff --git a/context/auth.go b/context/auth.go index 60ad4db4710..a9cfd2ff79f 100644 --- a/context/auth.go +++ b/context/auth.go @@ -101,20 +101,10 @@ func PoliciesFromContext(ctx context.Context) ([]policy.Policy, error) { func IsAuthenticatedFromContext(ctx context.Context) bool { a, b := ctx.Value(authKey).(*authorization) - return b && a.token.Valid + return (b && a.token != nil && a.token.Valid) } func NewContextFromAuthValues(ctx context.Context, subject account.Account, token *jwt.Token, policies []policy.Policy) context.Context { - if subject == nil { - subject = &account.DefaultAccount{} - } - if token == nil { - token = &jwt.Token{} - } - if policies == nil { - policies = []policy.Policy{} - } - return context.WithValue(ctx, authKey, &authorization{subject, token, policies}) } diff --git a/handler/middleware/middleware_test.go b/handler/middleware/middleware_test.go index 0a1b00455ae..2521732070b 100644 --- a/handler/middleware/middleware_test.go +++ b/handler/middleware/middleware_test.go @@ -81,7 +81,7 @@ var cases = []test{ &jwt.Token{Valid: true}, []policy.Policy{}, "", "", - true, false, + false, false, }, test{ &account.DefaultAccount{ID: "max"}, diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 9de5fbf77e6..adb0da6737b 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -13,6 +13,10 @@ func TestLoadCertificate(t *testing.T) { assert.Nil(t, err) assert.Equal(t, c[1], string(out)) } + _, err := LoadCertificate("") + assert.NotNil(t, err) + _, err = LoadCertificate("foobar") + assert.NotNil(t, err) } func TestSignRejectsAlgAndTypHeader(t *testing.T) { @@ -34,15 +38,33 @@ func TestSignAndVerify(t *testing.T) { header map[string]interface{} claims map[string]interface{} valid bool + signOk bool } cases := []test{ + test{ + []byte(""), + []byte(TestCertificates[1][1]), + map[string]interface{}{"foo": "bar"}, + map[string]interface{}{"nbf": time.Now().Add(time.Hour).Unix()}, + false, + false, + }, + test{ + []byte(TestCertificates[0][1]), + []byte(""), + map[string]interface{}{"foo": "bar"}, + map[string]interface{}{"nbf": time.Now().Add(time.Hour).Unix()}, + false, + true, + }, test{ []byte(TestCertificates[0][1]), []byte(TestCertificates[1][1]), map[string]interface{}{"foo": "bar"}, map[string]interface{}{"nbf": time.Now().Add(time.Hour).Unix()}, false, + true, }, test{ []byte(TestCertificates[0][1]), @@ -50,6 +72,7 @@ func TestSignAndVerify(t *testing.T) { map[string]interface{}{"foo": "bar"}, map[string]interface{}{"exp": time.Now().Add(-time.Hour).Unix()}, false, + true, }, test{ []byte(TestCertificates[0][1]), @@ -60,6 +83,7 @@ func TestSignAndVerify(t *testing.T) { "exp": time.Now().Add(time.Hour).Unix(), }, true, + true, }, test{ []byte(TestCertificates[0][1]), @@ -69,6 +93,7 @@ func TestSignAndVerify(t *testing.T) { "nbf": time.Now().Add(-time.Hour).Unix(), }, true, + true, }, test{ []byte(TestCertificates[0][1]), @@ -78,6 +103,7 @@ func TestSignAndVerify(t *testing.T) { "exp": time.Now().Add(time.Hour).Unix(), }, true, + true, }, test{ []byte(TestCertificates[0][1]), @@ -85,19 +111,24 @@ func TestSignAndVerify(t *testing.T) { map[string]interface{}{"foo": "bar"}, map[string]interface{}{}, true, + true, }, } for i, c := range cases { j := New(c.private, c.public) data, err := j.SignToken(c.claims, c.header) - require.Nil(t, err, "Case %d", i) + if c.signOk { + require.Nil(t, err, "Case %d", i) + } else { + require.NotNil(t, err, "Case %d", i) + } tok, err := j.VerifyToken([]byte(data)) if c.valid { - require.Nil(t, err) - require.Equal(t, c.valid, tok.Valid) + require.Nil(t, err, "Case %d", i) + require.Equal(t, c.valid, tok.Valid, "Case %d", i) } else { - require.NotNil(t, err) + require.NotNil(t, err, "Case %d", i) } } } From 9338754fa1433d3b1acc1b0c3206c1d671385af8 Mon Sep 17 00:00:00 2001 From: Aeneas Rekkas Date: Tue, 27 Oct 2015 00:16:29 +0100 Subject: [PATCH 2/2] oauth2: created provider handler --- account/handler/handler.go | 3 - account/postgres/store.go | 5 +- account/postgres/store_test.go | 2 + hash/bcrypt_test.go | 2 +- jwt/jwt.go | 50 ++++++-- oauth/provider/handler.go | 173 ++++++++++++++++++++++++++++ oauth/provider/handler_test.go | 204 +++++++++++++++++++++++++++++++++ 7 files changed, 423 insertions(+), 16 deletions(-) create mode 100644 oauth/provider/handler.go create mode 100644 oauth/provider/handler_test.go diff --git a/account/handler/handler.go b/account/handler/handler.go index f1676e232a1..ab972ec40eb 100644 --- a/account/handler/handler.go +++ b/account/handler/handler.go @@ -25,9 +25,6 @@ func (h *Handler) SetRoutes(r *mux.Router, extractor func(h hydcon.ContextHandle r.Handle("/users", &hydcon.ContextAdapter{Ctx: context.Background(), Handler: h.Create(extractor)}).Methods("POST") r.Handle("/users/{id}", &hydcon.ContextAdapter{Ctx: context.Background(), Handler: h.Get(extractor)}).Methods("GET") r.Handle("/users/{id}", &hydcon.ContextAdapter{Ctx: context.Background(), Handler: h.Delete(extractor)}).Methods("DELETE") - - //r.Handle("/login", &hydcon.ContextAdapter{Ctx: context.Background(), Handler: h.Login(extractor)}).Methods("GET") - //r.Handle("/logout", &hydcon.ContextAdapter{Ctx: context.Background(), Handler: h.Logout(extractor)}).Methods("GET") } func (h *Handler) Create(extractor func(h hydcon.ContextHandler) hydcon.ContextHandler) hydcon.ContextHandler { diff --git a/account/postgres/store.go b/account/postgres/store.go index 8c43fdf7975..5db06a593e0 100644 --- a/account/postgres/store.go +++ b/account/postgres/store.go @@ -2,6 +2,7 @@ package postgres import ( "database/sql" + "errors" "github.com/ory-am/hydra/account" "github.com/ory-am/hydra/hash" "log" @@ -14,6 +15,8 @@ const accountSchema = `CREATE TABLE account ( data json )` +var ErrNotFound = errors.New("Not found") + type Store struct { hasher hash.Hasher db *sql.DB @@ -103,7 +106,7 @@ func (s *Store) Delete(id string) (err error) { func (s *Store) Authenticate(email, password string) (account.Account, error) { var a account.DefaultAccount // Query account - row := s.db.QueryRow("SELECT id, email, password, data FROM account WHERE email=$1 LIMIT 1", email) + row := s.db.QueryRow("SELECT id, email, password, data FROM account WHERE email=$1", email) // Hydrate struct with data if err := row.Scan(&a.ID, &a.Email, &a.Password, &a.Data); err != nil { diff --git a/account/postgres/store_test.go b/account/postgres/store_test.go index 9c3b7b1a6ff..815b4361f28 100644 --- a/account/postgres/store_test.go +++ b/account/postgres/store_test.go @@ -136,6 +136,8 @@ func TestAuthenticate(t *testing.T) { assert.NotNil(t, err) _, err = store.Authenticate("doesnotexist@foo", "secret") assert.NotNil(t, err) + _, err = store.Authenticate("", "") + assert.NotNil(t, err) result, err := store.Authenticate("5@bar", "secret") assert.Nil(t, err) diff --git a/hash/bcrypt_test.go b/hash/bcrypt_test.go index 0d09684a43b..0fd598e8f8c 100644 --- a/hash/bcrypt_test.go +++ b/hash/bcrypt_test.go @@ -1,7 +1,7 @@ package hash import ( - "code.google.com/p/go-uuid/uuid" + "github.com/pborman/uuid" "github.com/stretchr/testify/assert" "testing" ) diff --git a/jwt/jwt.go b/jwt/jwt.go index 92dc012525d..4e3676cf351 100644 --- a/jwt/jwt.go +++ b/jwt/jwt.go @@ -3,6 +3,7 @@ package jwt import ( "errors" "fmt" + "github.com/RangelReale/osin" "github.com/dgrijalva/jwt-go" "io" "io/ioutil" @@ -115,22 +116,49 @@ func (j *JWT) SignToken(claims map[string]interface{}, header map[string]interfa return "", errors.New("You may not override the typ header key.") } - c := func(a, b map[string]interface{}) map[string]interface{} { - for k, w := range b { - if _, ok := a[k]; !ok { - continue - } - a[k] = w - } - return a - } - token := jwt.New(jwt.SigningMethodRS256) token.Claims = claims - token.Header = c(token.Header, header) + token.Header = merge(token.Header, header) ecdsaKey, err := jwt.ParseRSAPrivateKeyFromPEM(j.privateKey) if err != nil { return "", err } return token.SignedString(ecdsaKey) } + +func merge(a, b map[string]interface{}) map[string]interface{} { + for k, w := range b { + if _, ok := a[k]; !ok { + continue + } + a[k] = w + } + return a +} + +type Map struct { + Data map[string]interface{} +} + +func (j *JWT) GenerateAccessToken(data *osin.AccessData, generateRefresh bool) (accessToken string, refreshToken string, err error) { + claims := map[string]interface{}{"cid": data.Client.GetId(), "exp": data.ExpireAt().Unix()} + extra, ok := data.UserData.(*Map) + if !ok { + extra = &Map{Data: map[string]interface{}{}} + } + + claims = merge(claims, extra.Data) + accessToken, err = j.SignToken(claims, map[string]interface{}{}) + if err != nil { + return "", "", err + } else if !generateRefresh { + return + } + + claims = merge(map[string]interface{}{"at": accessToken}, claims) + refreshToken, err = j.SignToken(claims, map[string]interface{}{}) + if err != nil { + return "", "", err + } + return +} diff --git a/oauth/provider/handler.go b/oauth/provider/handler.go new file mode 100644 index 00000000000..df3b53433e8 --- /dev/null +++ b/oauth/provider/handler.go @@ -0,0 +1,173 @@ +package provider + +import ( + "fmt" + "github.com/RangelReale/osin" + "github.com/gorilla/mux" + "github.com/ory-am/hydra/account" + "github.com/ory-am/hydra/jwt" + "github.com/ory-am/ladon/guard" + "github.com/ory-am/ladon/policy" + "github.com/ory-am/osin-storage/storage" + "log" + "net/http" +) + +func configureOsin() *osin.ServerConfig { + conf := osin.NewServerConfig() + conf.AllowedAuthorizeTypes = osin.AllowedAuthorizeType{ + osin.CODE, + osin.TOKEN, + } + conf.AllowedAccessTypes = osin.AllowedAccessType{ + osin.AUTHORIZATION_CODE, + osin.REFRESH_TOKEN, + osin.PASSWORD, + osin.CLIENT_CREDENTIALS, + // TODO osin.ASSERTION, + } + //conf.AllowGetAccessRequest = true + //conf.AllowClientSecretInParams = true + return conf +} + +type Handler struct { + s storage.Storage + conf *osin.ServerConfig + server *osin.Server + account account.Storage + policy policy.Storer + guard guard.Guarder +} + +func NewHandler(s storage.Storage, j *jwt.JWT, account account.Storage, policy policy.Storer, guard guard.Guarder) *Handler { + conf := configureOsin() + server := osin.NewServer(conf, s) + server.AccessTokenGen = j + return &Handler{ + s: s, + conf: conf, + server: server, + account: account, + policy: policy, + guard: guard, + } +} + +func (h *Handler) SetRoutes(r *mux.Router) { + r.HandleFunc("/oauth2/auth", h.AuthorizeHandler(func(w http.ResponseWriter, r *http.Request) (string, string, error) { + if err := r.ParseForm(); err != nil { + return "", "", err + } + return r.FormValue("username"), r.FormValue("password"), nil + })).Headers("Content-Type", "application/x-www-form-urlencoded") + r.HandleFunc("/oauth2/token", h.TokenHandler) + r.HandleFunc("/oauth2/info", h.TokenHandler) +} + +func (h *Handler) InfoHandler(w http.ResponseWriter, r *http.Request) { + resp := h.server.NewResponse() + defer resp.Close() + + if ir := h.server.HandleInfoRequest(resp, r); ir != nil { + h.server.FinishInfoRequest(resp, r, ir) + } + osin.OutputJSON(resp, w, r) +} + +func (h *Handler) TokenHandler(w http.ResponseWriter, r *http.Request) { + resp := h.server.NewResponse() + r.ParseForm() + defer resp.Close() + if ar := h.server.HandleAccessRequest(resp, r); ar != nil { + switch ar.Type { + case osin.AUTHORIZATION_CODE: + ar.Authorized = true + case osin.REFRESH_TOKEN: + ar.Authorized = true + case osin.PASSWORD: + // TODO if !ar.Client.isAllowedToAuthenticateUser + // TODO ... check session or redirect to trusted client + // TODO ... return + // TODO } + + _, err := h.authenticate(w, ar.Username, ar.Password) + ar.Authorized = err == nil + case osin.CLIENT_CREDENTIALS: + ar.Authorized = true + // TODO ASSERTION federation workflow http://leastprivilege.com/2013/12/23/advanced-oauth2-assertion-flow-why/ + // TODO case osin.ASSERTION: + // TODO if ar.AssertionType == "urn:osin.example.complete" && ar.Assertion == "osin.data" { + // TODO ar.Authorized = true + // TODO } + } + h.server.FinishAccessRequest(resp, r, ar) + } + if resp.IsError { + log.Printf("Error in /oauth2/token: %s, %d, %s", resp.ErrorId, resp.ErrorStatusCode, resp.InternalError) + resp.StatusCode = http.StatusUnauthorized + } + osin.OutputJSON(resp, w, r) +} + +func (h *Handler) AuthorizeHandler(decoder func(w http.ResponseWriter, r *http.Request) (string, string, error)) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + resp := h.server.NewResponse() + defer resp.Close() + if ar := h.server.HandleAuthorizeRequest(resp, r); ar != nil { + + // TODO if !ar.Client.isAllowedToAuthenticateUser + // TODO ... check session or redirect to trusted client + // TODO ... return + // TODO } + + username, password, err := decoder(w, r) + if err != nil { + http.Error(w, fmt.Sprintf("Could parse payload: %s", err), http.StatusUnauthorized) + return + } + + acc, err := h.authenticate(w, username, password) + if err != nil { + log.Printf(`Authentication denied for "%s" using password "%s"`, username, password) + return + } + + ar.UserData = &jwt.Map{ + Data: map[string]interface{}{"subject": acc.GetID()}, + } + ar.Authorized = true + h.server.FinishAuthorizeRequest(resp, r, ar) + } + if resp.IsError { + log.Printf("Error in /oauth2/auth: %s, %d, %s", resp.ErrorId, resp.ErrorStatusCode, resp.InternalError) + resp.StatusCode = http.StatusUnauthorized + } + osin.OutputJSON(resp, w, r) + } +} + +func (h *Handler) authenticate(w http.ResponseWriter, email, password string) (account.Account, error) { + acc, err := h.account.Authenticate(email, password) + if err != nil { + http.Error(w, "Could not authenticate.", http.StatusUnauthorized) + return nil, err + } + + policies, err := h.policy.FindPoliciesForSubject(acc.GetID()) + if err != nil { + http.Error(w, fmt.Sprintf("Could not fetch policies: %s", err.Error()), http.StatusInternalServerError) + return nil, err + } + + if granted, err := h.guard.IsGranted("/oauth2/authorize", "authorize", acc.GetID(), policies); !granted { + err = fmt.Errorf(`Subject "%s" is not allowed to authorize.`, acc.GetID()) + http.Error(w, err.Error(), http.StatusUnauthorized) + return nil, err + } else if err != nil { + http.Error(w, fmt.Sprintf(`Authorization failed for Subject "%s": %s`, acc.GetID(), err.Error()), http.StatusInternalServerError) + return nil, err + } + + return acc, nil +} diff --git a/oauth/provider/handler_test.go b/oauth/provider/handler_test.go new file mode 100644 index 00000000000..71bb40516d0 --- /dev/null +++ b/oauth/provider/handler_test.go @@ -0,0 +1,204 @@ +package provider_test + +import ( + "bytes" + "database/sql" + "github.com/gorilla/mux" + "github.com/ory-am/dockertest" + acpg "github.com/ory-am/hydra/account/postgres" + "github.com/ory-am/hydra/hash" + "github.com/ory-am/hydra/jwt" + . "github.com/ory-am/hydra/oauth/provider" + "github.com/ory-am/ladon/guard" + "github.com/ory-am/ladon/policy" + ppg "github.com/ory-am/ladon/policy/postgres" + opg "github.com/ory-am/osin-storage/storage/postgres" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/oauth2" + "log" + "net/http" + "net/http/httptest" + "net/url" + "os" + "testing" + "time" +) + +var db *sql.DB +var accountStore *acpg.Store +var policyStore *ppg.Store +var handler *Handler +var gd = new(guard.Guard) +var osinStore *opg.Storage + +func TestMain(m *testing.M) { + var err error + var c dockertest.ContainerID + c, db, err = dockertest.OpenPostgreSQLContainerConnection(15, time.Second) + if err != nil { + log.Fatalf("Could not connect to database: %s", err) + } + defer c.KillRemove() + + accountStore = acpg.New(&hash.BCrypt{10}, db) + policyStore = ppg.New(db) + osinStore = opg.New(db) + if err := policyStore.CreateSchemas(); err != nil { + log.Fatalf("Could not set up schemas: %v", err) + } + if err := accountStore.CreateSchemas(); err != nil { + log.Fatalf("Could not set up schemas: %v", err) + } + if err := osinStore.CreateSchemas(); err != nil { + log.Fatalf("Could not set up schemas: %v", err) + } + + j := jwt.New([]byte(jwt.TestCertificates[0][1]), []byte(jwt.TestCertificates[1][1])) + handler = NewHandler(osinStore, j, accountStore, policyStore, gd) + + if _, err := osinStore.CreateClient("1", "secret", "/callback"); err != nil { + log.Fatalf("Could create client: %s", err) + } + if _, err := accountStore.Create("2", "2@bar.com", "secret", "{}"); err != nil { + log.Fatalf("Could create account: %s", err) + } + if _, err := policyStore.Create("3", "", policy.AllowAccess, []string{}, []string{"authorize"}, []string{"/oauth2/authorize"}); err != nil { + log.Fatalf("Could create client: %s", err) + } + os.Exit(m.Run()) +} + +func TestAuthorize(t *testing.T) { + oauthConfigs := []*oauth2.Config{ + &oauth2.Config{ + ClientID: "1", + ClientSecret: "secret", + Scopes: []string{}, + RedirectURL: "/callback", + Endpoint: oauth2.Endpoint{ + AuthURL: "/oauth2/auth", + TokenURL: "/oauth2/token", + }, + }, + &oauth2.Config{ + ClientID: "1", + ClientSecret: "wrongsecret", + Scopes: []string{}, + RedirectURL: "/callback", + Endpoint: oauth2.Endpoint{ + AuthURL: "/oauth2/auth", + TokenURL: "/oauth2/token", + }, + }, + &oauth2.Config{ + ClientID: "notexistent", + ClientSecret: "random", + Scopes: []string{}, + RedirectURL: "/callback", + Endpoint: oauth2.Endpoint{ + AuthURL: "/oauth2/auth", + TokenURL: "/oauth2/token", + }, + }, + } + + type userData struct { + Username string `json:"username"` + Password string `json:"password"` + } + + type test struct { + code int + state string + config *oauth2.Config + userData *userData + pass bool + } + + cases := []*test{ + &test{ + state: "foobar", + config: oauthConfigs[0], + code: http.StatusFound, + userData: &userData{"2@bar.com", "secret"}, + pass: true, + }, + &test{ + state: "foobar", + config: oauthConfigs[0], + code: http.StatusUnauthorized, + userData: &userData{"nonexistent@bar.com", "secret"}, + pass: false, + }, + &test{ + state: "foobar", + config: oauthConfigs[0], + code: http.StatusUnauthorized, + userData: &userData{"2@bar.com", "wrong secret"}, + pass: false, + }, + &test{ + state: "foobar", + config: oauthConfigs[1], + // Ok because oauth2/auth does not check client secret, only oauth2/token does. + code: http.StatusFound, + userData: &userData{"2@bar.com", "secret"}, + pass: false, + }, + &test{ + state: "foobar", + config: oauthConfigs[2], + code: http.StatusUnauthorized, + userData: &userData{"2@bar.com", "secret"}, + pass: false, + }, + } + + for k, c := range cases { + loc := "" + func() { + router := mux.NewRouter() + handler.SetRoutes(router) + + authURL := c.config.AuthCodeURL(c.state) + log.Printf("Acquired auth code url: %s", authURL) + + post := url.Values{} + post.Set("username", c.userData.Username) + post.Add("password", c.userData.Password) + req, _ := http.NewRequest("POST", authURL, bytes.NewBufferString(post.Encode())) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + res := httptest.NewRecorder() + router.ServeHTTP(res, req) + assert.Equal(t, c.code, res.Code, `Case %d, %s: %s`, k, res.Body.Bytes()) + + log.Printf("Result was: %s %s", res.Body.String(), res.Header().Get("Location")) + loc = res.Header().Get("Location") + }() + + if loc == "" { + continue + } + + func() { + router := mux.NewRouter() + handler.SetRoutes(router) + ts := httptest.NewServer(router) + defer ts.Close() + u, err := url.Parse(loc) + require.Nil(t, err) + log.Printf("Exchanging token: %s", ts.URL+"/oauth2/auth") + c.config.Endpoint = oauth2.Endpoint{AuthURL: ts.URL + "/oauth2/auth", TokenURL: ts.URL + "/oauth2/token"} + tok, err := c.config.Exchange(oauth2.NoContext, u.Query().Get("code")) + if !c.pass { + assert.NotNil(t, err, "Case %d", k) + return + } + + assert.Nil(t, err, "Case %d: %v", k) + assert.NotNil(t, tok) + }() + } +}