Skip to content

Commit

Permalink
PR Feedback: move parallelization factor to config. Use errgroup
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickduffy95 committed Jul 17, 2024
1 parent 50efbca commit 42639c0
Show file tree
Hide file tree
Showing 29 changed files with 176 additions and 410 deletions.
8 changes: 8 additions & 0 deletions embedx/config.schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,14 @@
"description": "The maximum number of tuples that will be accepted by the batch check endpoint.",
"minimum": 1,
"maximum": 65535
},
"batch_check_max_parallelization": {
"type": "integer",
"default": 5,
"title": "Max concurrent checks during batch check",
"description": "The limit for the number of tuples that will be checked concurrently during a batch check.",
"minimum": 1,
"maximum": 65535
}
},
"additionalProperties": false
Expand Down
36 changes: 13 additions & 23 deletions internal/check/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,12 @@ package check

import (
"context"
"sync"

"golang.org/x/sync/semaphore"

"github.com/ory/herodot"
"github.com/ory/x/otelx"
"github.com/pkg/errors"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"

"github.com/ory/keto/x/events"

Expand Down Expand Up @@ -270,30 +268,19 @@ func (e *Engine) astRelationFor(ctx context.Context, r *relationTuple) (*ast.Rel

// BatchCheck makes parallelized check requests for tuples. The check results are returned as slice, where the
// result index matches the tuple index of the incoming tuples array.
//
// parallelizationFactor is the max the number of checks that can happen in parallel.
func (e *Engine) BatchCheck(ctx context.Context,
tuples []*ketoapi.RelationTuple,
maxDepth, parallelizationFactor int) ([]checkgroup.Result, error) {

if parallelizationFactor <= 0 {
return nil, errors.New("invalid parallelization factor")
}
maxDepth int) ([]checkgroup.Result, error) {

wg := &sync.WaitGroup{}
sem := semaphore.NewWeighted(int64(parallelizationFactor))
eg := &errgroup.Group{}
eg.SetLimit(e.d.Config(ctx).BatchCheckParallelizationLimit())

mapper := e.d.ReadOnlyMapper()
results := make([]checkgroup.Result, len(tuples))
for outerIndex, outerTuple := range tuples {
sem.Acquire(context.Background(), 1) // Pass in background context to guarantee this won't return an error
wg.Add(1)
go func(i int, tuple *ketoapi.RelationTuple) {
defer func() {
wg.Done()
sem.Release(1)
}()

for i, tuple := range tuples {
i := i
tuple := tuple
eg.Go(func() error {
internalTuple, err := mapper.FromTuple(ctx, tuple)
if err != nil {
results[i] = checkgroup.Result{
Expand All @@ -303,10 +290,13 @@ func (e *Engine) BatchCheck(ctx context.Context,
} else {
results[i] = e.CheckRelationTuple(ctx, internalTuple[0], maxDepth)
}
}(outerIndex, outerTuple)
return nil
})
}

wg.Wait()
if err := eg.Wait(); err != nil {
return nil, err
}

return results, nil
}
13 changes: 2 additions & 11 deletions internal/check/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ func TestEngine(t *testing.T) {
}

// Batch check with low max depth
results, err := e.BatchCheck(ctx, targetTuples, 2, 5)
results, err := e.BatchCheck(ctx, targetTuples, 2)
require.NoError(t, err)

require.Equal(t, checkgroup.IsMember, results[0].Membership)
Expand All @@ -677,17 +677,8 @@ func TestEngine(t *testing.T) {
require.NoError(t, results[5].Err)

// Check with higher max depth and verify the third tuple is now shown as a member
results, err = e.BatchCheck(ctx, targetTuples, 3, 5)
results, err = e.BatchCheck(ctx, targetTuples, 3)
require.NoError(t, err)
require.Equal(t, checkgroup.IsMember, results[2].Membership)

// Check success with no parallelization
noParallelizationResults, err := e.BatchCheck(ctx, targetTuples, 3, 1)
require.NoError(t, err)
require.Equal(t, results, noParallelizationResults)

// Attempt with an invalid parallelization factor
_, err = e.BatchCheck(ctx, targetTuples, 3, 0)
require.EqualError(t, err, "invalid parallelization factor")
})
}
29 changes: 2 additions & 27 deletions internal/check/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"io"
"net/http"
"net/url"
"strconv"

"github.com/julienschmidt/httprouter"

Expand Down Expand Up @@ -56,9 +55,6 @@ const (
RouteBase = "/relation-tuples/check"
OpenAPIRouteBase = RouteBase + "/openapi"
BatchRoute = "/relation-tuples/batch/check"

parallelizationFactorQueryParam = "parallelization-factor"
defaultBatchCheckParallelizationFactor = 5
)

func (h *Handler) RegisterReadRoutes(r *x.ReadRouter) {
Expand Down Expand Up @@ -364,12 +360,6 @@ type batchCheckPermission struct {
// in: query
MaxDepth int `json:"max-depth"`

// ParallelizationFactor is the maximum number of check requests
// that can happen concurrently. Optional. Defaults to 5.
//
// in: query
ParallelizationFactor int `json:"parallelization-factor"`

// in: body
Body batchCheckPermissionBody
}
Expand Down Expand Up @@ -427,14 +417,6 @@ func (h *Handler) doBatchCheck(ctx context.Context, body io.Reader, query url.Va
if err != nil {
return nil, err
}
parallelizationFactor := defaultBatchCheckParallelizationFactor
if query.Get(parallelizationFactorQueryParam) != "" {
parallelizationFactor, err = strconv.Atoi(query.Get(parallelizationFactorQueryParam))
if err != nil || parallelizationFactor <= 0 {
return nil, herodot.ErrBadRequest.WithError("parallelization factor must be a positive integer")
}
}
h.d.Writer()
var request batchCheckPermissionBody
if err := json.NewDecoder(body).Decode(&request); err != nil {
return nil, errors.WithStack(herodot.ErrBadRequest.WithErrorf("could not unmarshal json: %s", err.Error()))
Expand All @@ -445,7 +427,7 @@ func (h *Handler) doBatchCheck(ctx context.Context, body io.Reader, query url.Va
h.d.Config(ctx).BatchCheckMaxBatchSize())
}

results, err := h.d.PermissionEngine().BatchCheck(ctx, request.Tuples, maxDepth, parallelizationFactor)
results, err := h.d.PermissionEngine().BatchCheck(ctx, request.Tuples, maxDepth)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -477,14 +459,7 @@ func (h *Handler) BatchCheck(ctx context.Context, req *rts.BatchCheckRequest) (*
ketoTuples[i] = (&ketoapi.RelationTuple{}).FromProto(tuple)
}

parallelizationFactor := defaultBatchCheckParallelizationFactor
if req.ParallelizationFactor != nil {
if *req.ParallelizationFactor <= 0 {
return nil, status.Error(codes.InvalidArgument, "parallelization factor must be a positive integer")
}
parallelizationFactor = int(*req.ParallelizationFactor)
}
results, err := h.d.PermissionEngine().BatchCheck(ctx, ketoTuples, int(req.MaxDepth), parallelizationFactor)
results, err := h.d.PermissionEngine().BatchCheck(ctx, ketoTuples, int(req.MaxDepth))
if err != nil {
return nil, err
}
Expand Down
65 changes: 9 additions & 56 deletions internal/check/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,30 +194,8 @@ func TestBatchCheckRESTHandler(t *testing.T) {
assert.Contains(t, string(body), "invalid syntax")
})

t.Run("case=returns bad request on non-int parallelization factor", func(t *testing.T) {
resp, err := ts.Client().Post(buildBatchURL(ts.URL, "5", "abc"),
"application/json", nil)
require.NoError(t, err)

assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Contains(t, string(body), "parallelization factor must be a positive integer")
})

t.Run("case=returns bad request on negative parallelization factor", func(t *testing.T) {
resp, err := ts.Client().Post(buildBatchURL(ts.URL, "5", "-1"),
"application/json", nil)
require.NoError(t, err)

assert.Equal(t, http.StatusBadRequest, resp.StatusCode)
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Contains(t, string(body), "parallelization factor must be a positive integer")
})

t.Run("case=returns bad request on invalid request body", func(t *testing.T) {
resp, err := ts.Client().Post(buildBatchURL(ts.URL, "5", "5"),
resp, err := ts.Client().Post(buildBatchURL(ts.URL, "5"),
"application/json", strings.NewReader("not-json"))
require.NoError(t, err)

Expand All @@ -241,7 +219,7 @@ func TestBatchCheckRESTHandler(t *testing.T) {
bodyBytes, err := json.Marshal(reqBody)
require.NoError(t, err)

resp, err := ts.Client().Post(buildBatchURL(ts.URL, "5", "5"),
resp, err := ts.Client().Post(buildBatchURL(ts.URL, "5"),
"application/json", bytes.NewReader(bodyBytes))
require.NoError(t, err)

Expand Down Expand Up @@ -283,7 +261,7 @@ func TestBatchCheckRESTHandler(t *testing.T) {
bodyBytes, err := json.Marshal(reqBody)
require.NoError(t, err)

resp, err := ts.Client().Post(buildBatchURL(ts.URL, "5", "5"),
resp, err := ts.Client().Post(buildBatchURL(ts.URL, "5"),
"application/json", bytes.NewReader(bodyBytes))
require.NoError(t, err)
assert.Equal(t, http.StatusOK, resp.StatusCode)
Expand Down Expand Up @@ -321,9 +299,9 @@ func TestBatchCheckRESTHandler(t *testing.T) {
})
}

func buildBatchURL(baseURL, maxDepth, parallelizationFactor string) string {
return fmt.Sprintf("%s%s?max-depth=%s&parallelization-factor=%s",
baseURL, check.BatchRoute, maxDepth, parallelizationFactor)
func buildBatchURL(baseURL, maxDepth string) string {
return fmt.Sprintf("%s%s?max-depth=%s",
baseURL, check.BatchRoute, maxDepth)
}

func TestBatchCheckGRPCHandler(t *testing.T) {
Expand Down Expand Up @@ -371,39 +349,15 @@ func TestBatchCheckGRPCHandler(t *testing.T) {
}
}
_, err := checkClient.BatchCheck(ctx, &rts.BatchCheckRequest{
Tuples: tuples,
MaxDepth: 5,
ParallelizationFactor: nil,
Tuples: tuples,
MaxDepth: 5,
})
statusErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, codes.InvalidArgument, statusErr.Code())
require.Equal(t, "batch exceeds max size of 10", statusErr.Message())
})

t.Run("case=returns bad request when batch too large", func(t *testing.T) {
_, err := checkClient.BatchCheck(ctx, &rts.BatchCheckRequest{
Tuples: []*rts.RelationTuple{
{
Namespace: "n",
Object: "o",
Relation: "r",
Subject: &rts.Subject{
Ref: &rts.Subject_Id{
Id: "s",
},
},
},
},
MaxDepth: 5,
ParallelizationFactor: pointerx.Ptr[int32](0),
})
statusErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, codes.InvalidArgument, statusErr.Code())
require.Equal(t, "parallelization factor must be a positive integer", statusErr.Message())
})

t.Run("case=batch check", func(t *testing.T) {
rt := &ketoapi.RelationTuple{
Namespace: nspaces[0].Name,
Expand Down Expand Up @@ -446,8 +400,7 @@ func TestBatchCheckGRPCHandler(t *testing.T) {
},
},
},
MaxDepth: 5,
ParallelizationFactor: pointerx.Ptr[int32](5),
MaxDepth: 5,
})
require.NoError(t, err)
require.Len(t, resp.Results, 3)
Expand Down
7 changes: 6 additions & 1 deletion internal/driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ const (
KeyLimitMaxReadDepth = "limit.max_read_depth"
KeyLimitMaxReadWidth = "limit.max_read_width"

KeyBatchCheckMaxBatchSize = "limit.max_batch_check_size"
KeyBatchCheckMaxBatchSize = "limit.max_batch_check_size"
KeyBatchCheckParallelizationLimit = "limit.batch_check_max_parallelization"

KeyReadAPIHost = "serve." + string(EndpointRead) + ".host"
KeyReadAPIPort = "serve." + string(EndpointRead) + ".port"
Expand Down Expand Up @@ -190,6 +191,10 @@ func (k *Config) BatchCheckMaxBatchSize() int {
return k.p.Int(KeyBatchCheckMaxBatchSize)
}

func (k *Config) BatchCheckParallelizationLimit() int {
return k.p.Int(KeyBatchCheckParallelizationLimit)
}

func (k *Config) CORS(iface string) (cors.Options, bool) {
switch iface {
case "read", "write", "metrics":
Expand Down
4 changes: 2 additions & 2 deletions internal/e2e/cli_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,12 @@ func (g *cliClient) check(t require.TestingT, r *ketoapi.RelationTuple) bool {
}

func (g *cliClient) batchCheckErr(t require.TestingT, requestTuples []*ketoapi.RelationTuple,
parallelizationFactor *int, expected herodot.DefaultError) {
expected herodot.DefaultError) {
if t, ok := t.(*testing.T); ok {
t.Skip("not implemented for the CLI")
}
}
func (g *cliClient) batchCheck(t require.TestingT, requestTuples []*ketoapi.RelationTuple, parallelizationFactor *int) []checkResponse {
func (g *cliClient) batchCheck(t require.TestingT, requestTuples []*ketoapi.RelationTuple) []checkResponse {
if t, ok := t.(*testing.T); ok {
t.Skip("not implemented for the CLI")
}
Expand Down
4 changes: 2 additions & 2 deletions internal/e2e/full_suit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ type (
queryTuple(t require.TestingT, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter) *ketoapi.GetResponse
queryTupleErr(t require.TestingT, expected herodot.DefaultError, q *ketoapi.RelationQuery, opts ...x.PaginationOptionSetter)
check(t require.TestingT, r *ketoapi.RelationTuple) bool
batchCheck(t require.TestingT, r []*ketoapi.RelationTuple, parallelizationFactor *int) []checkResponse
batchCheckErr(t require.TestingT, requestTuples []*ketoapi.RelationTuple, parallelizationFactor *int, expected herodot.DefaultError)
batchCheck(t require.TestingT, r []*ketoapi.RelationTuple) []checkResponse
batchCheckErr(t require.TestingT, requestTuples []*ketoapi.RelationTuple, expected herodot.DefaultError)
expand(t require.TestingT, r *ketoapi.SubjectSet, depth int) *ketoapi.Tree[*ketoapi.RelationTuple]
oplCheckSyntax(t require.TestingT, content []byte) []*ketoapi.ParseError
waitUntilLive(t require.TestingT)
Expand Down
16 changes: 5 additions & 11 deletions internal/e2e/grpc_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import (
"encoding/json"
"time"

"github.com/ory/x/pointerx"

"github.com/ory/keto/ketoapi"
opl "github.com/ory/keto/proto/ory/keto/opl/v1alpha1"

Expand Down Expand Up @@ -158,18 +156,18 @@ type checkResponse struct {
}

func (g *grpcClient) batchCheckErr(t require.TestingT, requestTuples []*ketoapi.RelationTuple,
parallelizationFactor *int, expected herodot.DefaultError) {
expected herodot.DefaultError) {

_, err := g.doBatchCheck(t, requestTuples, parallelizationFactor)
_, err := g.doBatchCheck(t, requestTuples)
require.Error(t, err)
s, ok := status.FromError(err)
require.True(t, ok)
assert.Equal(t, expected.GRPCCodeField, s.Code(), "%+v", err)
assert.Contains(t, s.Message(), expected.Reason())
}

func (g *grpcClient) batchCheck(t require.TestingT, requestTuples []*ketoapi.RelationTuple, parallelizationFactor *int) []checkResponse {
resp, err := g.doBatchCheck(t, requestTuples, parallelizationFactor)
func (g *grpcClient) batchCheck(t require.TestingT, requestTuples []*ketoapi.RelationTuple) []checkResponse {
resp, err := g.doBatchCheck(t, requestTuples)
require.NoError(t, err)

checkResponses := make([]checkResponse, len(resp.Results))
Expand All @@ -183,8 +181,7 @@ func (g *grpcClient) batchCheck(t require.TestingT, requestTuples []*ketoapi.Rel
return checkResponses
}

func (g *grpcClient) doBatchCheck(t require.TestingT, requestTuples []*ketoapi.RelationTuple,
parallelizationFactor *int) (*rts.BatchCheckResponse, error) {
func (g *grpcClient) doBatchCheck(t require.TestingT, requestTuples []*ketoapi.RelationTuple) (*rts.BatchCheckResponse, error) {

c := rts.NewCheckServiceClient(g.readConn(t))

Expand All @@ -207,9 +204,6 @@ func (g *grpcClient) doBatchCheck(t require.TestingT, requestTuples []*ketoapi.R
req := &rts.BatchCheckRequest{
Tuples: tuples,
}
if parallelizationFactor != nil {
req.ParallelizationFactor = pointerx.Ptr(int32(*parallelizationFactor))
}
return c.BatchCheck(g.ctx, req)
}

Expand Down
Loading

0 comments on commit 42639c0

Please sign in to comment.