diff --git a/cmd/keys_create.go b/cmd/keys_create.go index c0759ddd9c7..8b9fb8f269c 100644 --- a/cmd/keys_create.go +++ b/cmd/keys_create.go @@ -27,6 +27,6 @@ var keysCreateCmd = &cobra.Command{ func init() { keysCmd.AddCommand(keysCreateCmd) - keysCreateCmd.Flags().StringP("alg", "a", "", "REQUIRED name that identifies the algorithm intended for use with the key. Supports: RS256, ES521, HS256") + keysCreateCmd.Flags().StringP("alg", "a", "", "REQUIRED name that identifies the algorithm intended for use with the key. Supports: RS256, ES256, ES521, HS256") } diff --git a/docs/api.swagger.json b/docs/api.swagger.json index b80bf8bd0eb..a3d9b031162 100644 --- a/docs/api.swagger.json +++ b/docs/api.swagger.json @@ -1487,8 +1487,8 @@ "tags": [ "warden" ], - "summary": "Find groups by member", - "operationId": "findGroupsByMember", + "summary": "List groups", + "operationId": "listGroups", "security": [ { "oauth2": [ @@ -1504,11 +1504,27 @@ "name": "member", "in": "query", "required": true + }, + { + "type": "integer", + "format": "int64", + "x-go-name": "Offset", + "description": "The offset from where to start looking if member isn't specified.", + "name": "offset", + "in": "query" + }, + { + "type": "integer", + "format": "int64", + "x-go-name": "Limit", + "description": "The maximum amount of policies returned if member isn't specified.", + "name": "limit", + "in": "query" } ], "responses": { "200": { - "$ref": "#/responses/findGroupsByMemberResponse" + "$ref": "#/responses/listGroupsResponse" }, "401": { "$ref": "#/responses/genericError" @@ -2167,7 +2183,7 @@ ], "properties": { "alg": { - "description": "The algorithm to be used for creating the key. Supports \"RS256\", \"ES521\" and \"HS256\"", + "description": "The algorithm to be used for creating the key. Supports \"RS256\", \"ES256\", \"ES521\" and \"HS256\"", "type": "string", "x-go-name": "Algorithm" }, @@ -3015,15 +3031,6 @@ "emptyResponse": { "description": "An empty response" }, - "findGroupsByMemberResponse": { - "description": "A list of groups the member is belonging to", - "schema": { - "type": "array", - "items": { - "$ref": "#/definitions/group" - } - } - }, "genericError": { "description": "The standard error format", "schema": { @@ -3088,6 +3095,15 @@ "$ref": "#/definitions/oAuth2TokenIntrospection" } }, + "listGroupsResponse": { + "description": "A list of groups the member is belonging to", + "schema": { + "type": "array", + "items": { + "$ref": "#/definitions/group" + } + } + }, "oAuth2ClientList": { "description": "A list of clients.", "schema": { diff --git a/jwk/generator_hs256.go b/jwk/generator_hs256.go index 65fb8386338..e0d8138254d 100644 --- a/jwk/generator_hs256.go +++ b/jwk/generator_hs256.go @@ -19,6 +19,10 @@ func (g *HS256Generator) Generate(id string) (*jose.JSONWebKeySet, error) { return nil, errors.WithStack(err) } + if id == "" { + id = "shared" + } + var sliceKey = key[:] return &jose.JSONWebKeySet{ diff --git a/jwk/handler.go b/jwk/handler.go index fb5c525eeb8..53aba115d49 100644 --- a/jwk/handler.go +++ b/jwk/handler.go @@ -28,6 +28,7 @@ func (h *Handler) GetGenerators() map[string]KeyGenerator { if h.Generators == nil || len(h.Generators) == 0 { h.Generators = map[string]KeyGenerator{ "RS256": &RS256Generator{}, + "ES256": &ECDSA256Generator{}, "ES521": &ECDSA521Generator{}, "HS256": &HS256Generator{}, "HS512": &HS512Generator{}, @@ -52,7 +53,7 @@ func (h *Handler) SetRoutes(r *httprouter.Router) { // swagger:model jsonWebKeySetGeneratorRequest type createRequest struct { - // The algorithm to be used for creating the key. Supports "RS256", "ES521", "HS512", and "HS256" + // The algorithm to be used for creating the key. Supports "RS256", "ES256", "ES521", "HS256" and "HS512" // required: true // in: body Algorithm string `json:"alg"` diff --git a/jwk/handler_test.go b/jwk/handler_test.go index 3d1f2ec7e6f..6a52df3970f 100644 --- a/jwk/handler_test.go +++ b/jwk/handler_test.go @@ -38,7 +38,7 @@ func init() { }, ) router := httprouter.New() - IDKS, _ = testGenerator.Generate("") + IDKS, _ = testGenerators["RS256"].Generate("") h := Handler{ Manager: &MemoryManager{}, diff --git a/jwk/manager_test.go b/jwk/manager_test.go index 3f4e3ec3844..978a72375ea 100644 --- a/jwk/manager_test.go +++ b/jwk/manager_test.go @@ -15,7 +15,7 @@ var managers = map[string]Manager{ "memory": new(MemoryManager), } -var testGenerator = &RS256Generator{} +var testGenerators = (&Handler{}).GetGenerators() var encryptionKey, _ = RandomBytes(32) @@ -53,22 +53,37 @@ func connectToMySQL() { } func TestManagerKey(t *testing.T) { - ks, _ := testGenerator.Generate("") - - for name, m := range managers { - t.Run(fmt.Sprintf("case=%s", name), func(t *testing.T) { - TestHelperManagerKey(m, ks)(t) - }) + for algo, testGenerator := range testGenerators { + if algo == "HS256" { + // this is a symmetrical algorithm + continue + } + + ks, err := testGenerator.Generate("") + if err != nil { + t.Fatal(err) + } + + for name, m := range managers { + t.Run(fmt.Sprintf("case=%s/%s", algo, name), func(t *testing.T) { + TestHelperManagerKey(m, algo, ks)(t) + }) + } } } func TestManagerKeySet(t *testing.T) { - ks, _ := testGenerator.Generate("") - ks.Key("private") - - for name, m := range managers { - t.Run(fmt.Sprintf("case=%s", name), func(t *testing.T) { - TestHelperManagerKeySet(m, ks)(t) - }) + for algo, testGenerator := range testGenerators { + ks, err := testGenerator.Generate("") + if err != nil { + t.Fatal(err) + } + ks.Key("private") + + for name, m := range managers { + t.Run(fmt.Sprintf("case=%s/%s", algo, name), func(t *testing.T) { + TestHelperManagerKeySet(m, algo, ks)(t) + }) + } } } diff --git a/jwk/manager_test_helpers.go b/jwk/manager_test_helpers.go index b8bc4526488..51beee4c490 100644 --- a/jwk/manager_test_helpers.go +++ b/jwk/manager_test_helpers.go @@ -19,57 +19,57 @@ func RandomBytes(n int) ([]byte, error) { return bytes, nil } -func TestHelperManagerKey(m Manager, keys *jose.JSONWebKeySet) func(t *testing.T) { +func TestHelperManagerKey(m Manager, name string, keys *jose.JSONWebKeySet) func(t *testing.T) { pub := keys.Key("public") priv := keys.Key("private") return func(t *testing.T) { - _, err := m.GetKey("faz", "baz") + _, err := m.GetKey(name+"faz", "baz") assert.NotNil(t, err) - err = m.AddKey("faz", First(priv)) + err = m.AddKey(name+"faz", First(priv)) assert.Nil(t, err) - got, err := m.GetKey("faz", "private") + got, err := m.GetKey(name+"faz", "private") assert.Nil(t, err) assert.Equal(t, priv, got.Keys) - err = m.AddKey("faz", First(pub)) + err = m.AddKey(name+"faz", First(pub)) assert.Nil(t, err) - got, err = m.GetKey("faz", "private") + got, err = m.GetKey(name+"faz", "private") assert.Nil(t, err) assert.Equal(t, priv, got.Keys) - got, err = m.GetKey("faz", "public") + got, err = m.GetKey(name+"faz", "public") assert.Nil(t, err) assert.Equal(t, pub, got.Keys) - err = m.DeleteKey("faz", "public") + err = m.DeleteKey(name+"faz", "public") assert.Nil(t, err) - _, err = m.GetKey("faz", "public") + _, err = m.GetKey(name+"faz", "public") assert.NotNil(t, err) } } -func TestHelperManagerKeySet(m Manager, keys *jose.JSONWebKeySet) func(t *testing.T) { +func TestHelperManagerKeySet(m Manager, name string, keys *jose.JSONWebKeySet) func(t *testing.T) { return func(t *testing.T) { - _, err := m.GetKeySet("foo") + _, err := m.GetKeySet(name + "foo") require.Error(t, err) - err = m.AddKeySet("bar", keys) + err = m.AddKeySet(name+"bar", keys) assert.Nil(t, err) - got, err := m.GetKeySet("bar") + got, err := m.GetKeySet(name + "bar") assert.Nil(t, err) assert.Equal(t, keys.Key("public"), got.Key("public")) assert.Equal(t, keys.Key("private"), got.Key("private")) - err = m.DeleteKeySet("bar") + err = m.DeleteKeySet(name + "bar") assert.Nil(t, err) - _, err = m.GetKeySet("bar") + _, err = m.GetKeySet(name + "bar") assert.NotNil(t, err) } } diff --git a/warden/group/doc.go b/warden/group/doc.go index 6758ed73a66..493a119258a 100644 --- a/warden/group/doc.go +++ b/warden/group/doc.go @@ -4,19 +4,27 @@ package group // A list of groups the member is belonging to -// swagger:response findGroupsByMemberResponse -type swaggerFindGroupsByMemberResponse struct { +// swagger:response listGroupsResponse +type swaggerListGroupsResponse struct { // in: body // type: array Body []Group } -// swagger:parameters findGroupsByMember -type swaggerFindGroupsByMemberParameters struct { +// swagger:parameters listGroups +type swaggerListGroupsParameters struct { // The id of the member to look up. // in: query // required: true Member string `json:"member"` + + // The offset from where to start looking if member isn't specified. + // in: query + Offset int `json:"offset"` + + // The maximum amount of policies returned if member isn't specified. + // in: query + Limit int `json:"limit"` } // swagger:parameters createGroup diff --git a/warden/group/handler.go b/warden/group/handler.go index d2f2e6d7d92..e57194c78b4 100644 --- a/warden/group/handler.go +++ b/warden/group/handler.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "strconv" "github.com/julienschmidt/httprouter" "github.com/ory/herodot" @@ -41,9 +42,9 @@ func (h *Handler) SetRoutes(r *httprouter.Router) { r.DELETE(GroupsHandlerPath+"/:id/members", h.RemoveGroupMembers) } -// swagger:route GET /warden/groups warden findGroupsByMember +// swagger:route GET /warden/groups warden listGroups // -// Find groups by member +// List groups // // The subject making the request needs to be assigned to a policy containing: // @@ -67,7 +68,7 @@ func (h *Handler) SetRoutes(r *httprouter.Router) { // oauth2: hydra.groups // // Responses: -// 200: findGroupsByMemberResponse +// 200: listGroupsResponse // 401: genericError // 403: genericError // 500: genericError @@ -75,15 +76,42 @@ func (h *Handler) FindGroupNames(w http.ResponseWriter, r *http.Request, _ httpr var ctx = r.Context() var member = r.URL.Query().Get("member") - if _, err := h.W.TokenAllowed(ctx, h.W.TokenFromRequest(r), &firewall.TokenAccessRequest{ + accessReq := &firewall.TokenAccessRequest{ Resource: GroupsResource, Action: "list", - }, Scope); err != nil { + } + + if _, err := h.W.TokenAllowed(ctx, h.W.TokenFromRequest(r), accessReq, Scope); err != nil { h.H.WriteError(w, r, err) return } - groups, err := h.Manager.FindGroupsByMember(member) + if member != "" { + groups, err := h.Manager.FindGroupsByMember(member) + + if err != nil { + h.H.WriteError(w, r, err) + return + } + + h.H.Write(w, r, groups) + return + } + + limit, err := intFromQuery(r, "limit", 500) + if err != nil { + h.H.WriteError(w, r, errors.WithStack(err)) + return + } + + offset, err := intFromQuery(r, "offset", 0) + if err != nil { + h.H.WriteError(w, r, errors.WithStack(err)) + return + } + + groups, err := h.Manager.ListGroups(limit, offset) + if err != nil { h.H.WriteError(w, r, err) return @@ -92,6 +120,15 @@ func (h *Handler) FindGroupNames(w http.ResponseWriter, r *http.Request, _ httpr h.H.Write(w, r, groups) } +func intFromQuery(r *http.Request, key string, def int64) (int64, error) { + val := r.URL.Query().Get(key) + if val == "" { + return def, nil + } + + return strconv.ParseInt(val, 10, 64) +} + // swagger:route POST /warden/groups warden createGroup // // Create a group diff --git a/warden/group/manager.go b/warden/group/manager.go index 593dfd042f6..a489a573376 100644 --- a/warden/group/manager.go +++ b/warden/group/manager.go @@ -19,5 +19,6 @@ type Manager interface { AddGroupMembers(group string, members []string) error RemoveGroupMembers(group string, members []string) error + ListGroups(limit, offset int64) ([]Group, error) FindGroupsByMember(subject string) ([]Group, error) } diff --git a/warden/group/manager_memory.go b/warden/group/manager_memory.go index d5bae4b2c83..b50fa1d4bd9 100644 --- a/warden/group/manager_memory.go +++ b/warden/group/manager_memory.go @@ -1,6 +1,7 @@ package group import ( + "sort" "sync" "github.com/ory/hydra/pkg" @@ -19,6 +20,8 @@ type MemoryManager struct { sync.RWMutex } +var _ Manager = (*MemoryManager)(nil) + func (m *MemoryManager) CreateGroup(g *Group) error { if g.ID == "" { g.ID = uuid.New() @@ -77,6 +80,40 @@ func (m *MemoryManager) RemoveGroupMembers(group string, subjects []string) erro return m.CreateGroup(g) } +func (m *MemoryManager) ListGroups(limit, offset int64) ([]Group, error) { + if limit <= 0 { + limit = 500 + } + + if offset < 0 { + offset = 0 + } + + if offset >= int64(len(m.Groups)) { + return []Group{}, nil + } + + ids := []string{} + for id := range m.Groups { + ids = append(ids, id) + } + + sort.Strings(ids) + + res := make([]Group, len(ids)) + for i, id := range ids { + res[i] = m.Groups[id] + } + + res = res[offset:] + + if limit < int64(len(res)) { + res = res[:limit] + } + + return res, nil +} + func (m *MemoryManager) FindGroupsByMember(subject string) ([]Group, error) { if m.Groups == nil { m.Groups = map[string]Group{} diff --git a/warden/group/manager_sql.go b/warden/group/manager_sql.go index 86907ddf3ad..675c9cd17d1 100644 --- a/warden/group/manager_sql.go +++ b/warden/group/manager_sql.go @@ -34,6 +34,8 @@ type SQLManager struct { DB *sqlx.DB } +var _ Manager = (*SQLManager)(nil) + func (m *SQLManager) CreateSchemas() (int, error) { migrate.SetTable("hydra_groups_migration") n, err := migrate.Exec(m.DB.DB, m.DB.DriverName(), migrations, migrate.Up) @@ -128,14 +130,7 @@ func (m *SQLManager) RemoveGroupMembers(group string, subjects []string) error { return nil } -func (m *SQLManager) FindGroupsByMember(subject string) ([]Group, error) { - var ids []string - if err := m.DB.Select(&ids, m.DB.Rebind("SELECT group_id from hydra_warden_group_member WHERE member = ? GROUP BY group_id"), subject); err == sql.ErrNoRows { - return nil, errors.WithStack(pkg.ErrNotFound) - } else if err != nil { - return nil, errors.WithStack(err) - } - +func (m *SQLManager) idsToGroups(ids []string) ([]Group, error) { var groups = make([]Group, len(ids)) for k, id := range ids { group, err := m.GetGroup(id) @@ -148,3 +143,31 @@ func (m *SQLManager) FindGroupsByMember(subject string) ([]Group, error) { return groups, nil } + +func (m *SQLManager) ListGroups(limit, offset int64) ([]Group, error) { + if limit <= 0 { + limit = 500 + } + + if offset < 0 { + offset = 0 + } + + var ids []string + if err := m.DB.Select(&ids, m.DB.Rebind("SELECT id from hydra_warden_group ORDER BY id LIMIT ? OFFSET ?"), limit, offset); err != nil { + return nil, errors.WithStack(err) + } + + return m.idsToGroups(ids) +} + +func (m *SQLManager) FindGroupsByMember(subject string) ([]Group, error) { + var ids []string + if err := m.DB.Select(&ids, m.DB.Rebind("SELECT group_id from hydra_warden_group_member WHERE member = ? GROUP BY group_id"), subject); err == sql.ErrNoRows { + return nil, errors.WithStack(pkg.ErrNotFound) + } else if err != nil { + return nil, errors.WithStack(err) + } + + return m.idsToGroups(ids) +} diff --git a/warden/group/manager_test.go b/warden/group/manager_test.go index 01af2da2d99..aaab24959e3 100644 --- a/warden/group/manager_test.go +++ b/warden/group/manager_test.go @@ -52,6 +52,8 @@ func connectToPG() { } func TestManagers(t *testing.T) { + t.Parallel() + for k, m := range clientManagers { t.Run(fmt.Sprintf("case=%s", k), TestHelperManagers(m)) } diff --git a/warden/group/manager_test_helper.go b/warden/group/manager_test_helper.go index 98d116876cc..0be40102cc9 100644 --- a/warden/group/manager_test_helper.go +++ b/warden/group/manager_test_helper.go @@ -9,7 +9,27 @@ import ( func TestHelperManagers(m Manager) func(t *testing.T) { return func(t *testing.T) { - _, err := m.GetGroup("4321") + ds, err := m.ListGroups(0, 0) + assert.NoError(t, err) + assert.Empty(t, ds) + assert.NotNil(t, ds) + + ds, err = m.ListGroups(-1, 0) + assert.NoError(t, err) + assert.Empty(t, ds) + assert.NotNil(t, ds) + + ds, err = m.ListGroups(0, -1) + assert.NoError(t, err) + assert.Empty(t, ds) + assert.NotNil(t, ds) + + ds, err = m.ListGroups(-1, -1) + assert.NoError(t, err) + assert.Empty(t, ds) + assert.NotNil(t, ds) + + _, err = m.GetGroup("4321") assert.NotNil(t, err) c := &Group{ @@ -17,33 +37,94 @@ func TestHelperManagers(m Manager) func(t *testing.T) { Members: []string{"bar", "foo"}, } assert.NoError(t, m.CreateGroup(c)) + ds, err = m.ListGroups(0, 0) + require.NoError(t, err) + assert.Len(t, ds, 1) + assert.NotNil(t, ds) + + ds, err = m.ListGroups(0, 1) + require.NoError(t, err) + assert.Len(t, ds, 0) + assert.NotNil(t, ds) + assert.NoError(t, m.CreateGroup(&Group{ ID: "2", Members: []string{"foo"}, })) + ds, err = m.ListGroups(0, 0) + require.NoError(t, err) + assert.Len(t, ds, 2) + assert.NotNil(t, ds) + + ds, err = m.ListGroups(0, 1) + require.NoError(t, err) + assert.Len(t, ds, 1) + assert.NotNil(t, ds) + + ds, err = m.ListGroups(0, 2) + require.NoError(t, err) + assert.Len(t, ds, 0) + assert.NotNil(t, ds) + + ds, err = m.ListGroups(1, 0) + require.NoError(t, err) + assert.Len(t, ds, 1) + assert.NotNil(t, ds) + + ds, err = m.ListGroups(2, 0) + require.NoError(t, err) + assert.Len(t, ds, 2) + assert.NotNil(t, ds) + + ds, err = m.ListGroups(1, 1) + require.NoError(t, err) + assert.Len(t, ds, 1) + assert.NotNil(t, ds) + + ds, err = m.ListGroups(0, 2) + require.NoError(t, err) + assert.Len(t, ds, 0) + assert.NotNil(t, ds) + + assert.NoError(t, m.CreateGroup(&Group{ + ID: "3", + Members: []string{"bar"}, + })) + ds, err = m.ListGroups(0, 0) + require.NoError(t, err) + assert.Len(t, ds, 3) + assert.NotNil(t, ds) d, err := m.GetGroup("1") require.NoError(t, err) assert.EqualValues(t, c.Members, d.Members) assert.EqualValues(t, c.ID, d.ID) - ds, err := m.FindGroupsByMember("foo") + ds, err = m.FindGroupsByMember("foo") require.NoError(t, err) assert.Len(t, ds, 2) + assert.NotNil(t, ds) assert.NoError(t, m.AddGroupMembers("1", []string{"baz"})) ds, err = m.FindGroupsByMember("baz") require.NoError(t, err) assert.Len(t, ds, 1) + assert.NotNil(t, ds) assert.NoError(t, m.RemoveGroupMembers("1", []string{"baz"})) ds, err = m.FindGroupsByMember("baz") require.NoError(t, err) assert.Len(t, ds, 0) + assert.NotNil(t, ds) assert.NoError(t, m.DeleteGroup("1")) _, err = m.GetGroup("1") require.NotNil(t, err) + + ds, err = m.ListGroups(0, 0) + require.NoError(t, err) + assert.Len(t, ds, 2) + assert.NotNil(t, ds) } }