From 9d3ef0df9333aff2c587005df0cdd263028029f3 Mon Sep 17 00:00:00 2001 From: hackerman <3372410+aeneasr@users.noreply.github.com> Date: Mon, 19 Jun 2023 15:22:25 +0200 Subject: [PATCH] fix: identity list pagination (#3325) Resolves a pesky issue that would skip the last page. --- cmd/identities/list_test.go | 4 +- go.mod | 1 + go.sum | 2 + identity/handler_test.go | 115 +++++++++++++++++- .../sql/identity/persister_identity.go | 2 +- 5 files changed, 119 insertions(+), 5 deletions(-) diff --git a/cmd/identities/list_test.go b/cmd/identities/list_test.go index 288e0b5cd2c1..f26a69e186ce 100644 --- a/cmd/identities/list_test.go +++ b/cmd/identities/list_test.go @@ -44,8 +44,8 @@ func TestListCmd(t *testing.T) { is, ids := makeIdentities(t, reg, 6) defer deleteIdentities(t, is) - stdoutP1 := execNoErr(t, c, "1", "3") - stdoutP2 := execNoErr(t, c, "2", "3") + stdoutP1 := execNoErr(t, c, "0", "3") + stdoutP2 := execNoErr(t, c, "1", "3") for _, id := range ids { // exactly one of page 1 and 2 should contain the id diff --git a/go.mod b/go.mod index 5349d2ed25cc..744933acf89e 100644 --- a/go.mod +++ b/go.mod @@ -90,6 +90,7 @@ require ( github.com/stretchr/testify v1.8.2 github.com/tidwall/gjson v1.14.3 github.com/tidwall/sjson v1.2.5 + github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 github.com/urfave/negroni v1.0.0 github.com/zmb3/spotify/v2 v2.0.0 go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.36.4 diff --git a/go.sum b/go.sum index bf50b00e4e2a..4abedd603a17 100644 --- a/go.sum +++ b/go.sum @@ -1393,6 +1393,8 @@ github.com/tmc/grpc-websocket-proxy v0.0.0-20200427203606-3cfed13b9966/go.mod h1 github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 h1:uruHq4dN7GR16kFc5fp3d1RIYzJW5onx8Ybykw2YQFA= github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tomasen/realip v0.0.0-20180522021738-f0c99a92ddce/go.mod h1:o8v6yHRoik09Xen7gje4m9ERNah1d1PPsVq1VEx9vE4= +github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80 h1:nrZ3ySNYwJbSpD6ce9duiP+QkD3JuLCcWkdaehUS/3Y= +github.com/tomnomnom/linkheader v0.0.0-20180905144013-02ca5825eb80/go.mod h1:iFyPdL66DjUD96XmzVL3ZntbzcflLnznH0fr99w5VqE= github.com/toqueteos/webbrowser v1.2.0 h1:tVP/gpK69Fx+qMJKsLE7TD8LuGWPnEV71wBN9rrstGQ= github.com/toqueteos/webbrowser v1.2.0/go.mod h1:XWoZq4cyp9WeUeak7w7LXRUQf1F1ATJMir8RTqb4ayM= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= diff --git a/identity/handler_test.go b/identity/handler_test.go index 06b956fc63be..0ecca0a4fb00 100644 --- a/identity/handler_test.go +++ b/identity/handler_test.go @@ -23,6 +23,8 @@ import ( "github.com/stretchr/testify/require" "github.com/tidwall/gjson" + "github.com/tomnomnom/linkheader" + "github.com/ory/kratos/driver/config" "github.com/ory/kratos/hash" "github.com/ory/kratos/identity" @@ -54,7 +56,7 @@ func TestHandler(t *testing.T) { conf.MustSet(ctx, config.ViperKeyPublicBaseURL, mockServerURL.String()) - get := func(t *testing.T, base *httptest.Server, href string, expectCode int) gjson.Result { + getFull := func(t *testing.T, base *httptest.Server, href string, expectCode int) (gjson.Result, *http.Response) { t.Helper() res, err := base.Client().Get(base.URL + href) require.NoError(t, err) @@ -63,7 +65,13 @@ func TestHandler(t *testing.T) { require.NoError(t, res.Body.Close()) require.EqualValues(t, expectCode, res.StatusCode, "%s", body) - return gjson.ParseBytes(body) + return gjson.ParseBytes(body), res + } + + get := func(t *testing.T, base *httptest.Server, href string, expectCode int) gjson.Result { + t.Helper() + res, _ := getFull(t, base, href, expectCode) + return res } remove := func(t *testing.T, base *httptest.Server, href string, expectCode int) { @@ -1390,6 +1398,109 @@ func TestHandler(t *testing.T) { } } }) + + t.Run("case=should paginate all identities", func(t *testing.T) { + // Start new server + conf, reg := internal.NewFastRegistryWithMocks(t) + _, ts := testhelpers.NewKratosServerWithCSRF(t, reg) + mockServerURL := urlx.ParseOrPanic(publicTS.URL) + conf.MustSet(ctx, config.ViperKeyAdminBaseURL, ts.URL) + testhelpers.SetIdentitySchemas(t, conf, map[string]string{ + "default": "file://./stub/identity.schema.json", + "customer": "file://./stub/handler/customer.schema.json", + "multiple_emails": "file://./stub/handler/multiple_emails.schema.json", + "employee": "file://./stub/handler/employee.schema.json", + }) + conf.MustSet(ctx, config.ViperKeyPublicBaseURL, mockServerURL.String()) + + var toCreate []*identity.Identity + count := 500 + for i := 0; i < count; i++ { + i := identity.NewIdentity(config.DefaultIdentityTraitsSchemaID) + i.Traits = identity.Traits(`{"email":"` + x.NewUUID().String() + `@ory.sh"}`) + toCreate = append(toCreate, i) + } + + require.NoError(t, reg.PrivilegedIdentityPool().CreateIdentities(context.Background(), toCreate...)) + + for _, perPage := range []int{10, 50, 100, 500} { + t.Run(fmt.Sprintf("perPage=%d", perPage), func(t *testing.T) { + t.Parallel() + body, res := getFull(t, ts, fmt.Sprintf("/identities?per_page=%d", perPage), http.StatusOK) + assert.Len(t, body.Array(), perPage) + assert.Equal(t, strconv.Itoa(count), res.Header.Get("X-Total-Count")) + }) + } + + t.Run("iterate over next page", func(t *testing.T) { + perPage := 10 + pagePath := fmt.Sprintf("/identities?per_page=%d", perPage) + + run := func(t *testing.T, path string, knownIDs map[string]struct{}) (isLast bool, parsed *url.URL) { + var err error + t.Logf("Requesting %s", path) + body, res := getFull(t, ts, path, http.StatusOK) + for _, link := range linkheader.Parse(res.Header.Get("Link")) { + if link.Rel != "next" { + isLast = true + continue + } + parsed, err = url.Parse(link.URL) + require.NoError(t, err) + isLast = false + break + } + + for _, i := range body.Array() { + assert.NotContains(t, knownIDs, i.Get("id").String()) + knownIDs[i.Get("id").String()] = struct{}{} + } + return isLast, parsed + } + + t.Run("using token pagination", func(t *testing.T) { + knownIDs := make(map[string]struct{}) + var isLast bool + var pages int + path := pagePath + for !isLast { + t.Run(fmt.Sprintf("page=%d", pages), func(t *testing.T) { + var res *url.URL + pages++ + isLast, res = run(t, path, knownIDs) + if isLast { + return + } + path = fmt.Sprintf("/identities?page_size=%s&page_token=%s", res.Query().Get("page_size"), res.Query().Get("page_token")) + }) + } + + assert.Len(t, knownIDs, count) + assert.Equal(t, count/perPage, pages) + }) + + t.Run("using üage pagination", func(t *testing.T) { + knownIDs := make(map[string]struct{}) + var isLast bool + var pages int + path := pagePath + for !isLast { + t.Run(fmt.Sprintf("page=%d", pages), func(t *testing.T) { + var res *url.URL + pages++ + isLast, res = run(t, path, knownIDs) + if isLast { + return + } + path = fmt.Sprintf("/identities?per_page=%s&page=%s", res.Query().Get("per_page"), res.Query().Get("page")) + }) + } + + assert.Len(t, knownIDs, count) + assert.Equal(t, count/perPage, pages) + }) + }) + }) } func validCreateIdentityBody(prefix string, i int) *identity.CreateIdentityBody { diff --git a/persistence/sql/identity/persister_identity.go b/persistence/sql/identity/persister_identity.go index 4bba9e6faf45..d9e7a983756a 100644 --- a/persistence/sql/identity/persister_identity.go +++ b/persistence/sql/identity/persister_identity.go @@ -688,7 +688,7 @@ func (p *IdentityPersister) ListIdentities(ctx context.Context, params identity. Where("ict.name IN (?)", identity.CredentialsTypeWebAuthn, identity.CredentialsTypePassword). Limit(1) } else { - query = query.Paginate(params.Page, params.PerPage) + query = query.Paginate(params.Page+1, params.PerPage) } if err := sqlcon.HandleError(query.All(&is)); err != nil {