diff --git a/repl/repl.go b/repl/repl.go index 524418483d..ef736825ac 100644 --- a/repl/repl.go +++ b/repl/repl.go @@ -23,7 +23,6 @@ import ( "github.com/open-policy-agent/opa/topdown" "github.com/open-policy-agent/opa/topdown/explain" "github.com/open-policy-agent/opa/version" - "github.com/peterh/liner" ) @@ -1081,7 +1080,7 @@ func singleValue(body ast.Body) bool { } func dumpStorage(store *storage.Storage, txn storage.Transaction, w io.Writer) error { - data, err := store.Read(txn, ast.Ref{ast.DefaultRootDocument}) + data, err := store.Read(txn, storage.Path{}) if err != nil { return err } @@ -1105,8 +1104,13 @@ func mangleEvent(store *storage.Storage, txn storage.Transaction, event *topdown var err error event.Locals.Iter(func(k, v ast.Value) bool { if r, ok := v.(ast.Ref); ok { + var path storage.Path + path, err = storage.NewPathForRef(r) + if err != nil { + return true + } var doc interface{} - doc, err = store.Read(txn, r) + doc, err = store.Read(txn, path) if err != nil { return true } diff --git a/runtime/runtime.go b/runtime/runtime.go index cd6fda985c..e2ea60502d 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -122,8 +122,7 @@ func (rt *Runtime) init(params *Params) error { defer store.Close(txn) - ref := ast.Ref{ast.DefaultRootDocument} - if err := store.Write(txn, storage.AddOp, ref, loaded.Documents); err != nil { + if err := store.Write(txn, storage.AddOp, storage.Path{}, loaded.Documents); err != nil { return errors.Wrapf(err, "storage error") } @@ -240,8 +239,7 @@ func (rt *Runtime) processWatcherUpdate(paths []string) error { defer rt.Store.Close(txn) - ref := ast.Ref{ast.DefaultRootDocument} - if err := rt.Store.Write(txn, storage.AddOp, ref, loaded.Documents); err != nil { + if err := rt.Store.Write(txn, storage.AddOp, storage.Path{}, loaded.Documents); err != nil { return err } diff --git a/runtime/runtime_test.go b/runtime/runtime_test.go index ee8014dd54..5831823067 100644 --- a/runtime/runtime_test.go +++ b/runtime/runtime_test.go @@ -97,7 +97,7 @@ func TestInit(t *testing.T) { txn := storage.NewTransactionOrDie(rt.Store) - node, err := rt.Store.Read(txn, ast.MustParseRef("data.foo")) + node, err := rt.Store.Read(txn, storage.MustParsePath("/foo")) if util.Compare(node, "bar") != 0 || err != nil { t.Errorf("Expected %v but got %v (err: %v)", "bar", node, err) return diff --git a/server/server.go b/server/server.go index d4026ad314..05d1ffc127 100644 --- a/server/server.go +++ b/server/server.go @@ -59,7 +59,7 @@ const compileQueryErrMsg = "error(s) occurred while compiling query, see Errors" // attempts to modify a virtual document or create a document at a path that // conflicts with an existing document. type WriteConflictError struct { - path ast.Ref + path storage.Path } func (err WriteConflictError) Error() string { @@ -72,6 +72,26 @@ func IsWriteConflict(err error) bool { return ok } +type badRequestError string + +// isBadRequest reqturns true if the error indicates a badly formatted request. +func isBadRequest(err error) bool { + _, ok := err.(badRequestError) + return ok +} + +func (err badRequestError) Error() string { + return string(err) +} + +func badPatchOperationError(op string) badRequestError { + return badRequestError(fmt.Sprintf("bad patch operation: %v", op)) +} + +func badPatchPathError(path string) badRequestError { + return badRequestError(fmt.Sprintf("bad patch path: %v", path)) +} + // patchV1 models a single patch operation against a document. type patchV1 struct { Op string `json:"op"` @@ -246,6 +266,12 @@ func newBindingsV1(locals *ast.ValueMap) (result []*bindingV1) { return result } +type patchImpl struct { + path storage.Path + op storage.PatchOp + value interface{} +} + // Server represents an instance of OPA running in server mode. type Server struct { Handler http.Handler @@ -504,7 +530,6 @@ func (s *Server) v1DataGet(w http.ResponseWriter, r *http.Request) { func (s *Server) v1DataPatch(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) - root := stringPathToDataRef(vars["path"]) ops := []patchV1{} if err := json.NewDecoder(r.Body).Decode(&ops); err != nil { @@ -520,32 +545,14 @@ func (s *Server) v1DataPatch(w http.ResponseWriter, r *http.Request) { defer s.store.Close(txn) - for i := range ops { - - var op storage.PatchOp - - // TODO this could be refactored for failure handling - switch ops[i].Op { - case "add": - op = storage.AddOp - case "remove": - op = storage.RemoveOp - case "replace": - op = storage.ReplaceOp - default: - handleErrorf(w, 400, "bad patch operation: %v", ops[i].Op) - return - } - - path := root - path = append(path, stringPathToRef(ops[i].Path)...) - - if err := s.writeConflict(op, path); err != nil { - handleErrorAuto(w, err) - return - } + patches, err := s.prepareV1PatchSlice(vars["path"], ops) + if err != nil { + handleErrorAuto(w, err) + return + } - if err := s.store.Write(txn, op, path, ops[i].Value); err != nil { + for _, patch := range patches { + if err := s.store.Write(txn, patch.op, patch.path, patch.value); err != nil { handleErrorAuto(w, err) return } @@ -571,10 +578,11 @@ func (s *Server) v1DataPut(w http.ResponseWriter, r *http.Request) { defer s.store.Close(txn) - // The path route variable contains the path portion *after* /v1/data so we - // prepend the global root document here. - path := ast.Ref{ast.DefaultRootDocument} - path = append(path, stringPathToRef(vars["path"])...) + path, ok := storage.ParsePath("/" + strings.Trim(vars["path"], "/")) + if !ok { + handleErrorf(w, 400, "bad path format %v", vars["path"]) + return + } _, err = s.store.Read(txn, path) @@ -825,7 +833,7 @@ func (s *Server) setCompiler(compiler *ast.Compiler) { s.compiler = compiler } -func (s *Server) makeDir(txn storage.Transaction, path ast.Ref) error { +func (s *Server) makeDir(txn storage.Transaction, path storage.Path) error { node, err := s.store.Read(txn, path) if err == nil { @@ -850,14 +858,61 @@ func (s *Server) makeDir(txn storage.Transaction, path ast.Ref) error { return s.store.Write(txn, storage.AddOp, path, map[string]interface{}{}) } +func (s *Server) prepareV1PatchSlice(root string, ops []patchV1) (result []patchImpl, err error) { + + root = "/" + strings.Trim(root, "/") + + for _, op := range ops { + impl := patchImpl{ + value: op.Value, + } + + // Map patch operation. + switch op.Op { + case "add": + impl.op = storage.AddOp + case "remove": + impl.op = storage.RemoveOp + case "replace": + impl.op = storage.ReplaceOp + default: + return nil, badPatchOperationError(op.Op) + } + + // Construct patch path. + path := strings.Trim(op.Path, "/") + if len(path) > 0 { + path = root + "/" + path + } else { + path = root + } + + var ok bool + impl.path, ok = storage.ParsePath(path) + if !ok { + return nil, badPatchPathError(op.Path) + } + + if err := s.writeConflict(impl.op, impl.path); err != nil { + return nil, err + } + + result = append(result, impl) + } + + return result, nil +} + // TODO(tsandall): this ought to be enforced by the storage layer. -func (s *Server) writeConflict(op storage.PatchOp, path ast.Ref) error { +func (s *Server) writeConflict(op storage.PatchOp, path storage.Path) error { - if op == storage.AddOp && path[len(path)-1].Value.Equal(ast.String("-")) { + if op == storage.AddOp && len(path) > 0 && path[len(path)-1] == "-" { path = path[:len(path)-1] } - if rs := s.Compiler().GetRulesForVirtualDocument(path); rs != nil { + ref := path.Ref(ast.DefaultRootDocument) + + if rs := s.Compiler().GetRulesForVirtualDocument(ref); rs != nil { return WriteConflictError{path} } @@ -908,6 +963,10 @@ func handleErrorAuto(w http.ResponseWriter, err error) { handleError(w, 404, err) return } + if isBadRequest(curr) { + handleError(w, http.StatusBadRequest, err) + return + } if storage.IsInvalidPatch(curr) { handleError(w, 400, err) return diff --git a/server/server_test.go b/server/server_test.go index 9d9e50b014..f63d0bdd99 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -142,14 +142,14 @@ func TestDataV1(t *testing.T) { tr{"PUT", "/data/a/b", `[1,2,3,4]`, 204, ""}, tr{"PUT", "/data/a/b/c/d", "0", 404, `{ "Code": 404, - "Message": "write conflict: data.a.b" + "Message": "write conflict: /a/b" }`}, }}, {"put virtual write conflict", []tr{ tr{"PUT", "/policies/test", testMod2, 200, ""}, tr{"PUT", "/data/testmod/q/x", "0", 404, `{ "Code": 404, - "Message": "write conflict: data.testmod.q" + "Message": "write conflict: /testmod/q" }`}, }}, {"get virtual", []tr{ @@ -161,7 +161,7 @@ func TestDataV1(t *testing.T) { tr{"PUT", "/policies/test", testMod1, 200, ""}, tr{"PATCH", "/data/testmod/p", `[{"op": "add", "path": "-", "value": 1}]`, 404, `{ "Code": 404, - "Message": "write conflict: data.testmod.p" + "Message": "write conflict: /testmod/p" }`}, }}, {"get with global", []tr{ @@ -665,7 +665,7 @@ func (queryBindingErrStore) ID() string { return "mock" } -func (s *queryBindingErrStore) Read(txn storage.Transaction, ref ast.Ref) (interface{}, error) { +func (s *queryBindingErrStore) Read(txn storage.Transaction, path storage.Path) (interface{}, error) { // At this time, the store will receive two reads: // - The first during evaluation // - The second when the server tries to accumulate the bindings @@ -676,7 +676,7 @@ func (s *queryBindingErrStore) Read(txn storage.Transaction, ref ast.Ref) (inter return "", nil } -func (queryBindingErrStore) Begin(txn storage.Transaction, refs []ast.Ref) error { +func (queryBindingErrStore) Begin(txn storage.Transaction, params storage.TransactionParams) error { return nil } @@ -689,7 +689,7 @@ func TestQueryBindingIterationError(t *testing.T) { store := storage.New(storage.InMemoryConfig()) mock := &queryBindingErrStore{} - if err := store.Mount(mock, ast.MustParseRef("data.foo.bar")); err != nil { + if err := store.Mount(mock, storage.MustParsePath("/foo/bar")); err != nil { panic(err) } diff --git a/storage/datastore.go b/storage/datastore.go index 2fa2785ae1..dc409fe363 100644 --- a/storage/datastore.go +++ b/storage/datastore.go @@ -9,22 +9,20 @@ import ( "fmt" "io" - "github.com/open-policy-agent/opa/ast" + "strconv" ) -// DataStore is the backend containing rule references and data. +// DataStore is a simple in-memory data store that implements the storage.Store interface. type DataStore struct { - mountPath ast.Ref - data map[string]interface{} - triggers map[string]TriggerConfig + data map[string]interface{} + triggers map[string]TriggerConfig } // NewDataStore returns an empty DataStore. func NewDataStore() *DataStore { return &DataStore{ - data: map[string]interface{}{}, - triggers: map[string]TriggerConfig{}, - mountPath: ast.Ref{ast.DefaultRootDocument}, + data: map[string]interface{}{}, + triggers: map[string]TriggerConfig{}, } } @@ -33,7 +31,7 @@ func NewDataStore() *DataStore { func NewDataStoreFromJSONObject(data map[string]interface{}) *DataStore { ds := NewDataStore() for k, v := range data { - if err := ds.patch(AddOp, []interface{}{k}, v); err != nil { + if err := ds.patch(AddOp, Path{k}, v); err != nil { panic(err) } } @@ -51,19 +49,13 @@ func NewDataStoreFromReader(r io.Reader) *DataStore { return NewDataStoreFromJSONObject(data) } -// SetMountPath updates the data store's mount path. This is the path the data -// store expects all references to be prefixed with. -func (ds *DataStore) SetMountPath(ref ast.Ref) { - ds.mountPath = ref -} - // ID returns a unique identifier for the in-memory store. func (ds *DataStore) ID() string { return "org.openpolicyagent/in-memory" } // Begin is called when a new transaction is started. -func (ds *DataStore) Begin(txn Transaction, refs []ast.Ref) error { +func (ds *DataStore) Begin(txn Transaction, params TransactionParams) error { // TODO(tsandall): return nil } @@ -85,64 +77,20 @@ func (ds *DataStore) Unregister(id string) { } // Read fetches a value from the in-memory store. -func (ds *DataStore) Read(txn Transaction, path ast.Ref) (interface{}, error) { - return ds.getRef(path) +func (ds *DataStore) Read(txn Transaction, path Path) (interface{}, error) { + return get(ds.data, path) } // Write modifies a document referred to by path. -func (ds *DataStore) Write(txn Transaction, op PatchOp, path ast.Ref, value interface{}) error { - p, err := path.Underlying() - if err != nil { - return err - } - // TODO(tsandall): Patch() assumes that paths in writes are relative to - // "data" so drop the head here. - return ds.patch(op, p[1:], value) +func (ds *DataStore) Write(txn Transaction, op PatchOp, path Path, value interface{}) error { + return ds.patch(op, path, value) } func (ds *DataStore) String() string { return fmt.Sprintf("%v", ds.data) } -func (ds *DataStore) get(path []interface{}) (interface{}, error) { - return get(ds.data, path) -} - -func (ds *DataStore) getRef(ref ast.Ref) (interface{}, error) { - - ref = ref[len(ds.mountPath):] - path := make([]interface{}, len(ref)) - - for i, x := range ref { - switch v := x.Value.(type) { - case ast.Ref: - n, err := ds.getRef(v) - if err != nil { - return nil, err - } - path[i] = n - case ast.String: - path[i] = string(v) - case ast.Number: - path[i] = float64(v) - case ast.Boolean: - path[i] = bool(v) - case ast.Null: - path[i] = nil - default: - return nil, fmt.Errorf("illegal reference element: %v", x) - } - } - return ds.get(path) -} - -func (ds *DataStore) mustPatch(op PatchOp, path []interface{}, value interface{}) { - if err := ds.patch(op, path, value); err != nil { - panic(err) - } -} - -func (ds *DataStore) patch(op PatchOp, path []interface{}, value interface{}) error { +func (ds *DataStore) patch(op PatchOp, path Path, value interface{}) error { if len(path) == 0 { if op == AddOp || op == ReplaceOp { @@ -155,15 +103,11 @@ func (ds *DataStore) patch(op PatchOp, path []interface{}, value interface{}) er return invalidPatchErr(rootCannotBeRemovedMsg) } - _, isString := path[0].(string) - if !isString { - return notFoundError(path, stringHeadMsg) - } - for _, t := range ds.triggers { if t.Before != nil { // TODO(tsandall): use correct transaction. - if err := t.Before(invalidTXN, op, path, value); err != nil { + // TODO(tsandall): fix path + if err := t.Before(invalidTXN, op, nil, value); err != nil { return err } } @@ -187,7 +131,8 @@ func (ds *DataStore) patch(op PatchOp, path []interface{}, value interface{}) er for _, t := range ds.triggers { if t.After != nil { // TODO(tsandall): use correct transaction. - if err := t.After(invalidTXN, op, path, value); err != nil { + // TODO(tsandall): fix path + if err := t.After(invalidTXN, op, nil, value); err != nil { return err } } @@ -196,19 +141,16 @@ func (ds *DataStore) patch(op PatchOp, path []interface{}, value interface{}) er return nil } -func add(data map[string]interface{}, path []interface{}, value interface{}) error { +func add(data map[string]interface{}, path Path, value interface{}) error { // Special case for adding a new root. if len(path) == 1 { - return addRoot(data, path[0].(string), value) + return addRoot(data, path[0], value) } // Special case for appending to an array. - switch v := path[len(path)-1].(type) { - case string: - if v == "-" { - return addAppend(data, path[:len(path)-1], value) - } + if path[len(path)-1] == "-" { + return addAppend(data, path[:len(path)-1], value) } node, err := get(data, path[:len(path)-1]) @@ -222,12 +164,12 @@ func add(data map[string]interface{}, path []interface{}, value interface{}) err case []interface{}: return addInsertArray(data, path, node, value) default: - return notFoundError(path, nonCollectionMsg(path[len(path)-2])) + return notFoundError(path, doesNotExistMsg) } } -func addAppend(data map[string]interface{}, path []interface{}, value interface{}) error { +func addAppend(data map[string]interface{}, path Path, value interface{}) error { var parent interface{} = data @@ -244,29 +186,31 @@ func addAppend(data map[string]interface{}, path []interface{}, value interface{ return err } - a, ok := n.([]interface{}) + node, ok := n.([]interface{}) if !ok { - return notFoundError(path, nonArrayMsg(path[len(path)-1])) + return notFoundError(path, doesNotExistMsg) } - a = append(a, value) + node = append(node, value) e := path[len(path)-1] switch parent := parent.(type) { case []interface{}: - i := int(e.(float64)) - parent[i] = a + i, err := strconv.ParseInt(e, 10, 64) + if err != nil { + return notFoundError(path, "array index must be integer") + } + parent[i] = node case map[string]interface{}: - k := e.(string) - parent[k] = a + parent[e] = node default: - panic(fmt.Sprintf("illegal value: %v %v", parent, path)) // "node" exists, therefore this is not reachable. + panic("illegal value") // node exists, therefore parent must be collection. } return nil } -func addInsertArray(data map[string]interface{}, path []interface{}, node []interface{}, value interface{}) error { +func addInsertArray(data map[string]interface{}, path Path, node []interface{}, value interface{}) error { i, err := checkArrayIndex(path, node, path[len(path)-1]) if err != nil { @@ -286,29 +230,22 @@ func addInsertArray(data map[string]interface{}, path []interface{}, node []inte switch parent := parent.(type) { case map[string]interface{}: - k := e.(string) - parent[k] = node + parent[e] = node case []interface{}: - i = int(e.(float64)) + i, err := strconv.ParseInt(e, 10, 64) + if err != nil { + return notFoundError(path, "array index must be integer") + } parent[i] = node default: - panic(fmt.Sprintf("illegal value: %v %v", parent, path)) // "node" exists, therefore this is not reachable. + panic("illegal value") // node exists, therefore parent must be collection. } return nil } -func addInsertObject(data map[string]interface{}, path []interface{}, node map[string]interface{}, value interface{}) error { - - var k string - - switch last := path[len(path)-1].(type) { - case string: - k = last - default: - return notFoundError(path, objectKeyTypeMsg(last)) - } - +func addInsertObject(data map[string]interface{}, path Path, node map[string]interface{}, value interface{}) error { + k := path[len(path)-1] node[k] = value return nil } @@ -318,16 +255,12 @@ func addRoot(data map[string]interface{}, key string, value interface{}) error { return nil } -func get(data map[string]interface{}, path []interface{}) (interface{}, error) { +func get(data map[string]interface{}, path Path) (interface{}, error) { if len(path) == 0 { return data, nil } - head, ok := path[0].(string) - if !ok { - return nil, notFoundError(path, stringHeadMsg) - } - + head := path[0] node, ok := data[head] if !ok { return nil, notFoundError(path, doesNotExistMsg) @@ -352,14 +285,14 @@ func get(data map[string]interface{}, path []interface{}) (interface{}, error) { node = n[idx] default: - return nil, notFoundError(path, nonCollectionMsg(v)) + return nil, notFoundError(path, doesNotExistMsg) } } return node, nil } -func mustGet(data map[string]interface{}, path []interface{}) interface{} { +func mustGet(data map[string]interface{}, path Path) interface{} { r, err := get(data, path) if err != nil { panic(err) @@ -367,7 +300,7 @@ func mustGet(data map[string]interface{}, path []interface{}) interface{} { return r } -func remove(data map[string]interface{}, path []interface{}) error { +func remove(data map[string]interface{}, path Path) error { if _, err := get(data, path); err != nil { return err @@ -375,7 +308,7 @@ func remove(data map[string]interface{}, path []interface{}) error { // Special case for removing a root. if len(path) == 1 { - return removeRoot(data, path[0].(string)) + return removeRoot(data, path[0]) } node := mustGet(data, path[:len(path)-1]) @@ -386,11 +319,11 @@ func remove(data map[string]interface{}, path []interface{}) error { case map[string]interface{}: return removeObject(data, path, node) default: - return notFoundError(path, nonCollectionMsg(path[len(path)-2])) + return notFoundError(path, doesNotExistMsg) } } -func removeArray(data map[string]interface{}, path []interface{}, node []interface{}) error { +func removeArray(data map[string]interface{}, path Path, node []interface{}) error { i, err := checkArrayIndex(path, node, path[len(path)-1]) if err != nil { @@ -408,10 +341,12 @@ func removeArray(data map[string]interface{}, path []interface{}, node []interfa switch parent := parent.(type) { case map[string]interface{}: - k := e.(string) - parent[k] = node + parent[e] = node case []interface{}: - i = int(e.(float64)) + i, err := strconv.ParseInt(e, 10, 64) + if err != nil { + return notFoundError(path, "array index must be integer") + } parent[i] = node default: panic(fmt.Sprintf("illegal value: %v %v", parent, path)) // "node" exists, therefore this is not reachable. @@ -420,7 +355,7 @@ func removeArray(data map[string]interface{}, path []interface{}, node []interfa return nil } -func removeObject(data map[string]interface{}, path []interface{}, node map[string]interface{}) error { +func removeObject(data map[string]interface{}, path Path, node map[string]interface{}) error { k, err := checkObjectKey(path, node, path[len(path)-1]) if err != nil { return err @@ -435,7 +370,7 @@ func removeRoot(data map[string]interface{}, root string) error { return nil } -func replace(data map[string]interface{}, path []interface{}, value interface{}) error { +func replace(data map[string]interface{}, path Path, value interface{}) error { if _, err := get(data, path); err != nil { return err @@ -453,12 +388,12 @@ func replace(data map[string]interface{}, path []interface{}, value interface{}) case []interface{}: return replaceArray(data, path, node, value) default: - return notFoundError(path, nonCollectionMsg(path[len(path)-2])) + return notFoundError(path, doesNotExistMsg) } } -func replaceObject(data map[string]interface{}, path []interface{}, node map[string]interface{}, value interface{}) error { +func replaceObject(data map[string]interface{}, path Path, node map[string]interface{}, value interface{}) error { k, err := checkObjectKey(path, node, path[len(path)-1]) if err != nil { return err @@ -468,13 +403,13 @@ func replaceObject(data map[string]interface{}, path []interface{}, node map[str return nil } -func replaceRoot(data map[string]interface{}, path []interface{}, value interface{}) error { - root := path[0].(string) +func replaceRoot(data map[string]interface{}, path Path, value interface{}) error { + root := path[0] data[root] = value return nil } -func replaceArray(data map[string]interface{}, path []interface{}, node []interface{}, value interface{}) error { +func replaceArray(data map[string]interface{}, path Path, node []interface{}, value interface{}) error { i, err := checkArrayIndex(path, node, path[len(path)-1]) if err != nil { return err @@ -484,27 +419,19 @@ func replaceArray(data map[string]interface{}, path []interface{}, node []interf return nil } -func checkObjectKey(path []interface{}, node map[string]interface{}, v interface{}) (string, error) { - k, ok := v.(string) - if !ok { - return "", notFoundError(path, objectKeyTypeMsg(v)) - } - _, ok = node[string(k)] - if !ok { +func checkObjectKey(path Path, node map[string]interface{}, v string) (string, error) { + if _, ok := node[v]; !ok { return "", notFoundError(path, doesNotExistMsg) } - return string(k), nil + return v, nil } -func checkArrayIndex(path []interface{}, node []interface{}, v interface{}) (int, error) { - f, isFloat := v.(float64) - if !isFloat { - return 0, notFoundError(path, arrayIndexTypeMsg(v)) - } - i := int(f) - if float64(i) != f { - return 0, notFoundError(path, arrayIndexTypeMsg(v)) +func checkArrayIndex(path Path, node []interface{}, v string) (int, error) { + i64, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return 0, notFoundError(path, "array index must be integer") } + i := int(i64) if i >= len(node) { return 0, notFoundError(path, outOfRangeMsg) } else if i < 0 { diff --git a/storage/datastore_test.go b/storage/datastore_test.go index 1c4a76bd0a..829a44b68b 100644 --- a/storage/datastore_test.go +++ b/storage/datastore_test.go @@ -9,55 +9,48 @@ import ( "fmt" "reflect" "testing" - - "github.com/open-policy-agent/opa/ast" ) -func TestStorageGet(t *testing.T) { +func TestDataStoreGet(t *testing.T) { data := loadSmallTestData() var tests = []struct { - ref string + path string expected interface{} }{ - {"a[0]", float64(1)}, - {"a[3]", float64(4)}, - {"b.v1", "hello"}, - {"b.v2", "goodbye"}, - {"c[0].x[1]", false}, - {"c[0].y[0]", nil}, - {"c[0].y[1]", 3.14159}, - {"d.e[1]", "baz"}, - {"d.e", []interface{}{"bar", "baz"}}, - {"c[0].z", map[string]interface{}{"p": true, "q": false}}, - {"d[100]", notFoundError(path("d[100]"), objectKeyTypeMsg(float64(100)))}, - {"dead.beef", notFoundError(path("dead.beef"), doesNotExistMsg)}, - {"a.str", notFoundError(path("a.str"), arrayIndexTypeMsg("str"))}, - {"a[100]", notFoundError(path("a[100]"), outOfRangeMsg)}, - {"a[-1]", notFoundError(path("a[-1]"), outOfRangeMsg)}, - {"b.vdeadbeef", notFoundError(path("b.vdeadbeef"), doesNotExistMsg)}, + {"/a/0", float64(1)}, + {"/a/3", float64(4)}, + {"/b/v1", "hello"}, + {"/b/v2", "goodbye"}, + {"/c/0/x/1", false}, + {"/c/0/y/0", nil}, + {"/c/0/y/1", 3.14159}, + {"/d/e/1", "baz"}, + {"/d/e", []interface{}{"bar", "baz"}}, + {"/c/0/z", map[string]interface{}{"p": true, "q": false}}, + {"/d/100", notFoundError(MustParsePath("/d/100"), doesNotExistMsg)}, + {"/dead/beef", notFoundError(MustParsePath("/dead/beef"), doesNotExistMsg)}, + {"/a/str", notFoundError(MustParsePath("/a/str"), arrayIndexTypeMsg)}, + {"/a/100", notFoundError(MustParsePath("/a/100"), outOfRangeMsg)}, + {"/a/-1", notFoundError(MustParsePath("/a/-1"), outOfRangeMsg)}, + {"/b/vdeadbeef", notFoundError(MustParsePath("/b/vdeadbeef"), doesNotExistMsg)}, } ds := NewDataStoreFromJSONObject(data) for idx, tc := range tests { - ref := ast.MustParseRef(tc.ref) - path, err := ref.Underlying() - if err != nil { - panic(err) - } - result, err := ds.get(path) + result, err := ds.Read(nil, MustParsePath(tc.path)) switch e := tc.expected.(type) { case error: if err == nil { - t.Errorf("Test case %d: expected error for %v but got %v", idx+1, ref, result) + t.Errorf("Test case %d: expected error for %v but got %v", idx+1, tc.path, result) } else if !reflect.DeepEqual(err, tc.expected) { - t.Errorf("Test case %d: unexpected error for %v: %v, expected: %v", idx+1, ref, err, e) + t.Errorf("Test case %d: unexpected error for %v: %v, expected: %v", idx+1, tc.path, err, e) } default: if err != nil { - t.Errorf("Test case %d: expected success for %v but got %v", idx+1, ref, err) + t.Errorf("Test case %d: expected success for %v but got %v", idx+1, tc.path, err) } if !reflect.DeepEqual(result, tc.expected) { t.Errorf("Test case %d: expected %f but got %f", idx+1, tc.expected, result) @@ -67,60 +60,58 @@ func TestStorageGet(t *testing.T) { } -func TestStoragePatch(t *testing.T) { +func TestDataStorePatch(t *testing.T) { tests := []struct { note string op string - path interface{} + path string value string expected error - getPath interface{} + getPath string getExpected interface{} }{ - {"add root", "add", path([]interface{}{}), `{"a": [1]}`, nil, path([]interface{}{}), `{"a": [1]}`}, - {"add", "add", path("newroot"), `{"a": [[1]]}`, nil, path("newroot"), `{"a": [[1]]}`}, - {"add arr", "add", path("a[1]"), `"x"`, nil, path("a"), `[1,"x",2,3,4]`}, - {"add arr/arr", "add", path("h[1][2]"), `"x"`, nil, path("h"), `[[1,2,3], [2,3,"x",4]]`}, - {"add obj/arr", "add", path("d.e[1]"), `"x"`, nil, path("d"), `{"e": ["bar", "x", "baz"]}`}, - {"add obj", "add", path("b.vNew"), `"x"`, nil, path("b"), `{"v1": "hello", "v2": "goodbye", "vNew": "x"}`}, - {"add obj (existing)", "add", path("b.v2"), `"x"`, nil, path("b"), `{"v1": "hello", "v2": "x"}`}, - - {"append arr", "add", path(`a["-"]`), `"x"`, nil, path("a"), `[1,2,3,4,"x"]`}, - {"append obj/arr", "add", path(`c[0].x["-"]`), `"x"`, nil, path("c[0].x"), `[true,false,"foo","x"]`}, - {"append arr/arr", "add", path(`h[0]["-"]`), `"x"`, nil, path(`h[0][3]`), `"x"`}, - - {"remove", "remove", path("a"), "", nil, path("a"), notFoundError(path("a"), doesNotExistMsg)}, - {"remove arr", "remove", path("a[1]"), "", nil, path("a"), "[1,3,4]"}, - {"remove obj/arr", "remove", path("c[0].x[1]"), "", nil, path("c[0].x"), `[true,"foo"]`}, - {"remove arr/arr", "remove", path("h[0][1]"), "", nil, path("h[0]"), "[1,3]"}, - {"remove obj", "remove", path("b.v2"), "", nil, path("b"), `{"v1": "hello"}`}, - - {"replace root", "replace", path([]interface{}{}), `{"a": [1]}`, nil, path([]interface{}{}), `{"a": [1]}`}, - {"replace", "replace", path("a"), "1", nil, path("a"), "1"}, - {"replace obj", "replace", path("b.v1"), "1", nil, path("b"), `{"v1": 1, "v2": "goodbye"}`}, - {"replace array", "replace", path("a[1]"), "999", nil, path("a"), "[1,999,3,4]"}, - - {"err: bad root type", "add", []interface{}{}, "[1,2,3]", invalidPatchErr(rootMustBeObjectMsg), nil, nil}, - {"err: remove root", "remove", []interface{}{}, "", invalidPatchErr(rootCannotBeRemovedMsg), nil, nil}, - {"err: non-string head", "add", []interface{}{float64(1)}, "", notFoundError([]interface{}{float64(1)}, stringHeadMsg), nil, nil}, - {"err: add arr (non-integer)", "add", path("a.foo"), "1", notFoundError(path("a.foo"), arrayIndexTypeMsg("xxx")), nil, nil}, - {"err: add arr (non-integer)", "add", path("a[3.14]"), "1", notFoundError(path("a[3.14]"), arrayIndexTypeMsg(3.14)), nil, nil}, - {"err: add arr (out of range)", "add", path("a[5]"), "1", notFoundError(path("a[5]"), outOfRangeMsg), nil, nil}, - {"err: add arr (out of range)", "add", path("a[-1]"), "1", notFoundError(path("a[-1]"), outOfRangeMsg), nil, nil}, - {"err: add arr (missing root)", "add", path("dead.beef[0]"), "1", notFoundError(path("dead.beef"), doesNotExistMsg), nil, nil}, - {"err: add obj (non-string)", "add", path("b[100]"), "1", notFoundError(path("b[100]"), objectKeyTypeMsg(float64(100))), nil, nil}, - {"err: add non-coll", "add", path("a[1][2]"), "1", notFoundError(path("a[1][2]"), nonCollectionMsg(float64(1))), nil, nil}, - {"err: append (missing)", "add", path(`dead.beef["-"]`), "1", notFoundError(path("dead"), doesNotExistMsg), nil, nil}, - {"err: append obj/arr", "add", path(`c[0].deadbeef["-"]`), `"x"`, notFoundError(path("c[0].deadbeef"), doesNotExistMsg), nil, nil}, - {"err: append arr/arr (out of range)", "add", path(`h[9999]["-"]`), `"x"`, notFoundError(path("h[9999]"), outOfRangeMsg), nil, nil}, - {"err: append append+add", "add", path(`a["-"].b["-"]`), `"x"`, notFoundError(path(`a["-"]`), arrayIndexTypeMsg("-")), nil, nil}, - {"err: append arr/arr (non-array)", "add", path(`b.v1["-"]`), "1", notFoundError(path("b.v1"), nonArrayMsg("v1")), nil, nil}, - {"err: remove missing", "remove", path("dead.beef[0]"), "", notFoundError(path("dead.beef[0]"), doesNotExistMsg), nil, nil}, - {"err: remove obj (non string)", "remove", path("b[100]"), "", notFoundError(path("b[100]"), objectKeyTypeMsg(float64(100))), nil, nil}, - {"err: remove obj (missing)", "remove", path("b.deadbeef"), "", notFoundError(path("b.deadbeef"), doesNotExistMsg), nil, nil}, - {"err: replace root (missing)", "replace", path("deadbeef"), "1", notFoundError(path("deadbeef"), doesNotExistMsg), nil, nil}, - {"err: replace missing", "replace", "dead.beef[1]", "1", notFoundError(path("dead.beef[1]"), doesNotExistMsg), nil, nil}, + {"add root", "add", "/", `{"a": [1]}`, nil, "/", `{"a": [1]}`}, + {"add", "add", "/newroot", `{"a": [[1]]}`, nil, "/newroot", `{"a": [[1]]}`}, + {"add arr", "add", "/a/1", `"x"`, nil, "/a", `[1,"x",2,3,4]`}, + {"add arr/arr", "add", "/h/1/2", `"x"`, nil, "/h", `[[1,2,3], [2,3,"x",4]]`}, + {"add obj/arr", "add", "/d/e/1", `"x"`, nil, "/d", `{"e": ["bar", "x", "baz"]}`}, + {"add obj", "add", "/b/vNew", `"x"`, nil, "/b", `{"v1": "hello", "v2": "goodbye", "vNew": "x"}`}, + {"add obj (existing)", "add", "/b/v2", `"x"`, nil, "/b", `{"v1": "hello", "v2": "x"}`}, + + {"append arr", "add", "/a/-", `"x"`, nil, "/a", `[1,2,3,4,"x"]`}, + {"append obj/arr", "add", `/c/0/x/-`, `"x"`, nil, "/c/0/x", `[true,false,"foo","x"]`}, + {"append arr/arr", "add", `/h/0/-`, `"x"`, nil, `/h/0/3`, `"x"`}, + + {"remove", "remove", "/a", "", nil, "/a", notFoundError(MustParsePath("/a"), doesNotExistMsg)}, + {"remove arr", "remove", "/a/1", "", nil, "/a", "[1,3,4]"}, + {"remove obj/arr", "remove", "/c/0/x/1", "", nil, "/c/0/x", `[true,"foo"]`}, + {"remove arr/arr", "remove", "/h/0/1", "", nil, "/h/0", "[1,3]"}, + {"remove obj", "remove", "/b/v2", "", nil, "/b", `{"v1": "hello"}`}, + + {"replace root", "replace", "/", `{"a": [1]}`, nil, "/", `{"a": [1]}`}, + {"replace", "replace", "/a", "1", nil, "/a", "1"}, + {"replace obj", "replace", "/b/v1", "1", nil, "/b", `{"v1": 1, "v2": "goodbye"}`}, + {"replace array", "replace", "/a/1", "999", nil, "/a", "[1,999,3,4]"}, + + {"err: bad root type", "add", "/", "[1,2,3]", invalidPatchErr(rootMustBeObjectMsg), "", nil}, + {"err: remove root", "remove", "/", "", invalidPatchErr(rootCannotBeRemovedMsg), "", nil}, + {"err: add arr (non-integer)", "add", "/a/foo", "1", notFoundError(MustParsePath("/a/foo"), arrayIndexTypeMsg), "", nil}, + {"err: add arr (non-integer)", "add", "/a/3.14", "1", notFoundError(MustParsePath("/a/3.14"), arrayIndexTypeMsg), "", nil}, + {"err: add arr (out of range)", "add", "/a/5", "1", notFoundError(MustParsePath("/a/5"), outOfRangeMsg), "", nil}, + {"err: add arr (out of range)", "add", "/a/-1", "1", notFoundError(MustParsePath("/a/-1"), outOfRangeMsg), "", nil}, + {"err: add arr (missing root)", "add", "/dead/beef/0", "1", notFoundError(MustParsePath("/dead/beef"), doesNotExistMsg), "", nil}, + {"err: add non-coll", "add", "/a/1/2", "1", notFoundError(MustParsePath("/a/1/2"), doesNotExistMsg), "", nil}, + {"err: append (missing)", "add", `/dead/beef/-`, "1", notFoundError(MustParsePath("/dead"), doesNotExistMsg), "", nil}, + {"err: append obj/arr", "add", `/c/0/deadbeef/-`, `"x"`, notFoundError(MustParsePath("/c/0/deadbeef"), doesNotExistMsg), "", nil}, + {"err: append arr/arr (out of range)", "add", `/h/9999/-`, `"x"`, notFoundError(MustParsePath("/h/9999"), outOfRangeMsg), "", nil}, + {"err: append append+add", "add", `/a/-/b/-`, `"x"`, notFoundError(MustParsePath(`/a/-`), arrayIndexTypeMsg), "", nil}, + {"err: append arr/arr (non-array)", "add", `/b/v1/-`, "1", notFoundError(MustParsePath("/b/v1"), doesNotExistMsg), "", nil}, + {"err: remove missing", "remove", "/dead/beef/0", "", notFoundError(MustParsePath("/dead/beef/0"), doesNotExistMsg), "", nil}, + {"err: remove obj (non string)", "remove", "/b/100", "", notFoundError(MustParsePath("/b/100"), doesNotExistMsg), "", nil}, + {"err: remove obj (missing)", "remove", "/b/deadbeef", "", notFoundError(MustParsePath("/b/deadbeef"), doesNotExistMsg), "", nil}, + {"err: replace root (missing)", "replace", "/deadbeef", "1", notFoundError(MustParsePath("/deadbeef"), doesNotExistMsg), "", nil}, + {"err: replace missing", "replace", "/dead/beef/1", "1", notFoundError(MustParsePath("/dead/beef/1"), doesNotExistMsg), "", nil}, } for i, tc := range tests { @@ -142,7 +133,7 @@ func TestStoragePatch(t *testing.T) { panic(fmt.Sprintf("illegal value: %v", tc.op)) } - err := ds.patch(op, path(tc.path), value) + err := ds.patch(op, MustParsePath(tc.path), value) if tc.expected == nil { if err != nil { @@ -160,12 +151,12 @@ func TestStoragePatch(t *testing.T) { } } - if tc.getPath == nil { + if tc.getPath == "" { continue } // Perform get and verify result - result, err := ds.get(path(tc.getPath)) + result, err := ds.Read(nil, MustParsePath(tc.getPath)) switch expected := tc.getExpected.(type) { case error: if err == nil { @@ -258,22 +249,3 @@ func loadSmallTestData() map[string]interface{} { } return data } - -func path(input interface{}) []interface{} { - switch input := input.(type) { - case []interface{}: - return input - case string: - switch v := ast.MustParseTerm(input).Value.(type) { - case ast.Var: - return []interface{}{string(v)} - case ast.Ref: - path, err := v.Underlying() - if err != nil { - panic(err) - } - return path - } - } - panic(fmt.Sprintf("illegal value: %v", input)) -} diff --git a/storage/errors.go b/storage/errors.go index 7fdd96def2..2d63ecb099 100644 --- a/storage/errors.go +++ b/storage/errors.go @@ -75,27 +75,11 @@ func IsInvalidPatch(err error) bool { return false } +var doesNotExistMsg = "document does not exist" var rootMustBeObjectMsg = "root must be object" var rootCannotBeRemovedMsg = "root cannot be removed" -var doesNotExistMsg = "document does not exist" var outOfRangeMsg = "array index out of range" -var stringHeadMsg = "path must begin with string" - -func arrayIndexTypeMsg(v interface{}) string { - return fmt.Sprintf("array index must be integer, not %T", v) -} - -func objectKeyTypeMsg(v interface{}) string { - return fmt.Sprintf("object key must be string, not %v (%T)", v, v) -} - -func nonCollectionMsg(v interface{}) string { - return fmt.Sprintf("path refers to non-collection document with element %v", v) -} - -func nonArrayMsg(v interface{}) string { - return fmt.Sprintf("path refers to non-array document with element %v", v) -} +var arrayIndexTypeMsg = "array index must be integer" func indexNotFoundError() *Error { return &Error{ @@ -136,7 +120,7 @@ func mountConflictError() *Error { } } -func notFoundError(path []interface{}, f string, a ...interface{}) *Error { +func notFoundError(path Path, f string, a ...interface{}) *Error { msg := fmt.Sprintf("bad path: %v", path) if len(f) > 0 { msg += ", " + fmt.Sprintf(f, a...) @@ -144,14 +128,6 @@ func notFoundError(path []interface{}, f string, a ...interface{}) *Error { return notFoundErrorf(msg) } -func notFoundErrorf(f string, a ...interface{}) *Error { - msg := fmt.Sprintf(f, a...) - return &Error{ - Code: NotFoundErr, - Message: msg, - } -} - func notFoundRefError(ref ast.Ref, f string, a ...interface{}) *Error { msg := fmt.Sprintf("bad path: %v", ref) if len(f) > 0 { @@ -160,6 +136,14 @@ func notFoundRefError(ref ast.Ref, f string, a ...interface{}) *Error { return notFoundErrorf(msg) } +func notFoundErrorf(f string, a ...interface{}) *Error { + msg := fmt.Sprintf(f, a...) + return &Error{ + Code: NotFoundErr, + Message: msg, + } +} + func triggersNotSupportedError() *Error { return &Error{ Code: TriggersNotSupportedErr, diff --git a/storage/example_test.go b/storage/example_test.go index 4dc7e85626..265d2e9615 100644 --- a/storage/example_test.go +++ b/storage/example_test.go @@ -11,7 +11,6 @@ import ( "os" "path/filepath" - "github.com/open-policy-agent/opa/ast" "github.com/open-policy-agent/opa/storage" ) @@ -50,8 +49,8 @@ func ExampleStorage_Read() { defer store.Close(txn) // Read values out of storage. - v1, err1 := store.Read(txn, ast.MustParseRef("data.users[1].likes[1]")) - v2, err2 := store.Read(txn, ast.MustParseRef("data.users[0].age")) + v1, err1 := store.Read(txn, storage.MustParsePath("/users/1/likes/1")) + v2, err2 := store.Read(txn, storage.MustParsePath("/users/0/age")) // Inspect the return values. fmt.Println("v1:", v1) @@ -64,7 +63,7 @@ func ExampleStorage_Read() { // v1: wine // err1: // v2: - // err2: storage error (code: 1): bad path: [users 0 age], document does not exist + // err2: storage error (code: 1): bad path: /users/0/age, document does not exist // err2 is not found: true } @@ -115,9 +114,9 @@ func ExampleStorage_Write() { defer store.Close(txn) // Write values into storage and read result. - err0 := store.Write(txn, storage.AddOp, ast.MustParseRef("data.users[0].location"), patch) - v1, err1 := store.Read(txn, ast.MustParseRef("data.users[0].location.latitude")) - err2 := store.Write(txn, storage.ReplaceOp, ast.MustParseRef("data.users[1].color"), "red") + err0 := store.Write(txn, storage.AddOp, storage.MustParsePath("/users/0/location"), patch) + v1, err1 := store.Read(txn, storage.MustParsePath("/users/0/location/latitude")) + err2 := store.Write(txn, storage.ReplaceOp, storage.MustParsePath("/users/1/color"), "red") // Inspect the return values. fmt.Println("err0:", err0) @@ -131,7 +130,7 @@ func ExampleStorage_Write() { // err0: // v1: -62.338889 // err1: - // err2: storage error (code: 1): bad path: [users 1 color], document does not exist + // err2: storage error (code: 1): bad path: /users/1/color, document does not exist } diff --git a/storage/index.go b/storage/index.go index 6fe7e3e034..352e2964b1 100644 --- a/storage/index.go +++ b/storage/index.go @@ -118,7 +118,7 @@ func (ind *indices) String() string { return "{" + strings.Join(buf, ", ") + "}" } -func (ind *indices) dropAll(Transaction, PatchOp, []interface{}, interface{}) error { +func (ind *indices) dropAll(Transaction, PatchOp, Path, interface{}) error { ind.table = map[int]*indicesNode{} return nil } @@ -304,9 +304,13 @@ func hash(v interface{}) int { panic(fmt.Sprintf("illegal argument: %v (%T)", v, v)) } -func iterStorage(store Store, txn Transaction, ref ast.Ref, path ast.Ref, bindings *ast.ValueMap, iter func(*ast.ValueMap, interface{})) error { +func iterStorage(store Store, txn Transaction, nonGround, ground ast.Ref, bindings *ast.ValueMap, iter func(*ast.ValueMap, interface{})) error { - if len(ref) == 0 { + if len(nonGround) == 0 { + path, err := NewPathForRef(ground) + if err != nil { + return err + } node, err := store.Read(txn, path) if err != nil { if IsNotFound(err) { @@ -314,19 +318,23 @@ func iterStorage(store Store, txn Transaction, ref ast.Ref, path ast.Ref, bindin } return err } - iter(bindings, node) return nil } - head := ref[0] - tail := ref[1:] + head := nonGround[0] + tail := nonGround[1:] headVar, isVar := head.Value.(ast.Var) - if !isVar || len(path) == 0 { - path = append(path, head) - return iterStorage(store, txn, tail, path, bindings, iter) + if !isVar || len(ground) == 0 { + ground = append(ground, head) + return iterStorage(store, txn, tail, ground, bindings, iter) + } + + path, err := NewPathForRef(ground) + if err != nil { + return err } node, err := store.Read(txn, path) @@ -340,25 +348,25 @@ func iterStorage(store Store, txn Transaction, ref ast.Ref, path ast.Ref, bindin switch node := node.(type) { case map[string]interface{}: for key := range node { - path = append(path, ast.StringTerm(key)) + ground = append(ground, ast.StringTerm(key)) cpy := bindings.Copy() cpy.Put(headVar, ast.String(key)) - err := iterStorage(store, txn, tail, path, cpy, iter) + err := iterStorage(store, txn, tail, ground, cpy, iter) if err != nil { return err } - path = path[:len(path)-1] + ground = ground[:len(ground)-1] } case []interface{}: for i := range node { - path = append(path, ast.NumberTerm(float64(i))) + ground = append(ground, ast.NumberTerm(float64(i))) cpy := bindings.Copy() cpy.Put(headVar, ast.Number(float64(i))) - err := iterStorage(store, txn, tail, path, cpy, iter) + err := iterStorage(store, txn, tail, ground, cpy, iter) if err != nil { return err } - path = path[:len(path)-1] + ground = ground[:len(ground)-1] } } diff --git a/storage/interface.go b/storage/interface.go index 47c2ba39c8..03d05cbe75 100644 --- a/storage/interface.go +++ b/storage/interface.go @@ -4,8 +4,6 @@ package storage -import "github.com/open-policy-agent/opa/ast" - // Store defines the interface for the storage layer's backend. Users can // implement their own stores and mount them into the storage layer to provide // the policy engine access to external data sources. @@ -19,21 +17,40 @@ type Store interface { // Begin is called to indicate that a new transaction has started. The store // can use the call to initialize any resources that may be required for the - // transaction. The caller will provide refs hinting the paths that may be - // read during the transaction. - Begin(txn Transaction, refs []ast.Ref) error + // transaction. + Begin(txn Transaction, params TransactionParams) error - // Read is called to fetch a document referred to by ref. - Read(txn Transaction, ref ast.Ref) (interface{}, error) + // Read is called to fetch a document referred to by path. + Read(txn Transaction, path Path) (interface{}, error) - // Write is called to modify a document referred to by ref. - Write(txn Transaction, op PatchOp, ref ast.Ref, value interface{}) error + // Write is called to modify a document referred to by path. + Write(txn Transaction, op PatchOp, path Path, value interface{}) error // Close indicates a transaction has finished. The store can use the call to // release any resources temporarily allocated for the transaction. Close(txn Transaction) } +// TransactionParams describes a new transaction. +type TransactionParams struct { + + // Paths represents a set of document paths that may be read during the + // transaction. The paths may be provided by the caller to hint to the + // storage layer that certain documents could be pre-loaded. + Paths []Path +} + +// NewTransactionParams returns a new TransactionParams object. +func NewTransactionParams() TransactionParams { + return TransactionParams{} +} + +// WithPaths returns a new TransactionParams object with the paths set. +func (params TransactionParams) WithPaths(paths []Path) TransactionParams { + params.Paths = paths + return params +} + // PatchOp is the enumeration of supposed modifications. type PatchOp int @@ -48,6 +65,6 @@ const ( // interface which may be used if the backend does not support writes. type WritesNotSupported struct{} -func (WritesNotSupported) Write(txn Transaction, op PatchOp, ref ast.Ref, value interface{}) error { +func (WritesNotSupported) Write(txn Transaction, op PatchOp, path Path, value interface{}) error { return writesNotSupportedError() } diff --git a/storage/storage.go b/storage/storage.go index 783efe6485..ec77352acd 100644 --- a/storage/storage.go +++ b/storage/storage.go @@ -56,8 +56,7 @@ type Storage struct { } type mount struct { - path ast.Ref - strpath []string + path Path backend Store } @@ -122,7 +121,7 @@ func (s *Storage) DeletePolicy(txn Transaction, id string) error { // Mount adds a store into the storage layer at the given path. If the path // conflicts with an existing mount, an error is returned. -func (s *Storage) Mount(backend Store, path ast.Ref) error { +func (s *Storage) Mount(backend Store, path Path) error { s.mtx.Lock() defer s.mtx.Unlock() @@ -132,29 +131,19 @@ func (s *Storage) Mount(backend Store, path ast.Ref) error { return mountConflictError() } } - spath := make([]string, len(path)) - for i, x := range path { - switch v := x.Value.(type) { - case ast.String: - spath[i] = string(v) - case ast.Var: - spath[i] = string(v) - default: - return internalError("bad mount path: %v", path) - } - } + m := &mount{ path: path, - strpath: spath, backend: backend, } + s.mounts = append(s.mounts, m) return nil } // Unmount removes a store from the storage layer. If the path does not locate // an existing mount, an error is returned. -func (s *Storage) Unmount(path ast.Ref) error { +func (s *Storage) Unmount(path Path) error { s.mtx.Lock() defer s.mtx.Unlock() @@ -165,23 +154,19 @@ func (s *Storage) Unmount(path ast.Ref) error { return nil } } - return notFoundRefError(path, "unmount") + return notFoundError(path, "unmount") } // Read fetches the value in storage referred to by path. The path may refer to // multiple stores in which case the storage layer will fetch the values from // each store and then stitch together the result. -func (s *Storage) Read(txn Transaction, path ast.Ref) (interface{}, error) { +func (s *Storage) Read(txn Transaction, path Path) (interface{}, error) { type hole struct { path []string doc interface{} } - if !path.IsGround() { - return nil, internalError("non-ground reference: %v", path) - } - holes := []hole{} for _, mount := range s.mounts { @@ -191,7 +176,7 @@ func (s *Storage) Read(txn Transaction, path ast.Ref) (interface{}, error) { if err := s.lazyActivate(mount.backend, txn, nil); err != nil { return nil, err } - return mount.backend.Read(txn, path) + return mount.backend.Read(txn, path[len(mount.path):]) } // Check if read is over this mount (and possibly others) @@ -199,11 +184,11 @@ func (s *Storage) Read(txn Transaction, path ast.Ref) (interface{}, error) { if err := s.lazyActivate(mount.backend, txn, nil); err != nil { return nil, err } - node, err := mount.backend.Read(txn, mount.path) + node, err := mount.backend.Read(txn, Path{}) if err != nil { return nil, err } - prefix := mount.strpath[len(path):] + prefix := mount.path[len(path):] holes = append(holes, hole{prefix, node}) } } @@ -241,23 +226,28 @@ func (s *Storage) Read(txn Transaction, path ast.Ref) (interface{}, error) { } // Write updates a value in storage. -func (s *Storage) Write(txn Transaction, op PatchOp, ref ast.Ref, value interface{}) error { +func (s *Storage) Write(txn Transaction, op PatchOp, path Path, value interface{}) error { + if err := s.lazyActivate(s.builtin, txn, nil); err != nil { return err } - return s.builtin.Write(txn, op, ref, value) + + return s.builtin.Write(txn, op, path, value) } -// NewTransaction returns a new transcation that can be used to perform reads -// and writes against a consistent snapshot of the storage layer. The caller can -// provide a slice of references that may be read during the transaction. -func (s *Storage) NewTransaction(refs ...ast.Ref) (Transaction, error) { +// NewTransaction returns a new Transaction with default parameters. +func (s *Storage) NewTransaction() (Transaction, error) { + return s.NewTransactionWithParams(TransactionParams{}) +} + +// NewTransactionWithParams returns a new Transaction. +func (s *Storage) NewTransactionWithParams(params TransactionParams) (Transaction, error) { s.mtx.Lock() s.txn++ txn := s.txn - if err := s.notifyStoresBegin(txn, refs); err != nil { + if err := s.notifyStoresBegin(txn, params.Paths); err != nil { return nil, err } @@ -274,20 +264,17 @@ func (s *Storage) Close(txn Transaction) { // reference over the snapshot identified by the transaction. func (s *Storage) BuildIndex(txn Transaction, ref ast.Ref) error { + path, err := NewPathForRef(ref.GroundPrefix()) + if err != nil { + return indexingNotSupportedError() + } + // TODO(tsandall): for now we prevent indexing against stores other than the // built-in. This will be revisited in the future. To determine the // reference touches an external store, we collect the ground portion of // the reference and see if it matches any mounts. - ground := ast.Ref{ref[0]} - - for _, x := range ref[1:] { - if x.IsGround() { - ground = append(ground, x) - } - } - for _, mount := range s.mounts { - if ground.HasPrefix(mount.path) { + if path.HasPrefix(mount.path) || mount.path.HasPrefix(path) { return indexingNotSupportedError() } } @@ -325,14 +312,15 @@ func (s *Storage) getStoreByID(id string) Store { return nil } -func (s *Storage) lazyActivate(store Store, txn Transaction, refs []ast.Ref) error { +func (s *Storage) lazyActivate(store Store, txn Transaction, paths []Path) error { id := store.ID() if _, ok := s.active[id]; ok { return nil } - if err := store.Begin(txn, refs); err != nil { + params := TransactionParams{} + if err := store.Begin(txn, params); err != nil { return err } @@ -340,7 +328,7 @@ func (s *Storage) lazyActivate(store Store, txn Transaction, refs []ast.Ref) err return nil } -func (s *Storage) notifyStoresBegin(txn Transaction, refs []ast.Ref) error { +func (s *Storage) notifyStoresBegin(txn Transaction, paths []Path) error { builtinID := s.builtin.ID() @@ -349,15 +337,18 @@ func (s *Storage) notifyStoresBegin(txn Transaction, refs []ast.Ref) error { // closed, the set is consulted to determine which stores to notify. s.active = map[string]struct{}{} - mounts := map[string]ast.Ref{} + mounts := map[string]Path{} for _, mount := range s.mounts { mounts[mount.backend.ID()] = mount.path } - grouped := groupRefsByStore(builtinID, mounts, refs) + grouped := groupPathsByStore(builtinID, mounts, paths) - for id, refs := range grouped { - if err := s.getStoreByID(id).Begin(txn, refs); err != nil { + for id, groupedPaths := range grouped { + params := TransactionParams{ + Paths: groupedPaths, + } + if err := s.getStoreByID(id).Begin(txn, params); err != nil { return err } s.active[id] = struct{}{} @@ -403,73 +394,37 @@ func GetPolicy(store *Storage, id string) (*ast.Module, []byte, error) { return store.GetPolicy(txn, id) } -// ReadOrDie is a helper function to read the path from storage. If the read -// fails for any reason, this function will panic. This function should only be -// used for tests. -func ReadOrDie(store *Storage, path ast.Ref) interface{} { - txn, err := store.NewTransaction() - if err != nil { - panic(err) - } - defer store.Close(txn) - node, err := store.Read(txn, path) - if err != nil { - panic(err) - } - return node -} - // NewTransactionOrDie is a helper function to create a new transaction. If the // storage layer cannot create a new transaction, this function will panic. This // function should only be used for tests. -func NewTransactionOrDie(store *Storage, refs ...ast.Ref) Transaction { - txn, err := store.NewTransaction(refs...) +func NewTransactionOrDie(store *Storage) Transaction { + txn, err := store.NewTransaction() if err != nil { panic(err) } return txn } -func groupRefsByStore(builtinID string, mounts map[string]ast.Ref, refs []ast.Ref) map[string][]ast.Ref { +func groupPathsByStore(builtinID string, mounts map[string]Path, paths []Path) map[string][]Path { - r := map[string][]ast.Ref{} + r := map[string][]Path{} - for _, ref := range refs { - prefix := ref.GroundPrefix() + for _, path := range paths { sole := false - - // TODO(tsandall): if number of mounts is large this will be costly; - // consider replacing with a trie. - for id, path := range mounts { - - if prefix.HasPrefix(path) { - // This store is solely responsible for the ref. - r[id] = append(r[id], ref) + for id, mountPath := range mounts { + if path.HasPrefix(mountPath) { + r[id] = append(r[id], path[len(mountPath):]) sole = true break } - - if path.HasPrefix(prefix) { - // This store is partially responsible for the ref. If the ref - // is shorter than the mount path, then the entire content of - // the mounted store may be read. Otherwise, replace prefix of - // ref with mount path as the references passed to the store are - // always prefixed with the mount path of the store. - if len(ref) <= len(path) { - r[id] = append(r[id], path) - } else { - tmp := make(ast.Ref, len(ref)) - copy(tmp, path) - copy(tmp[len(path):], ref[len(path):]) - r[id] = append(r[id], tmp) - } + if mountPath.HasPrefix(path) { + r[id] = append(r[id], Path{}) } } - if !sole { // Read may span multiple stores, so by definition, built-in store // will be read. - r[builtinID] = append(r[builtinID], ref) + r[builtinID] = append(r[builtinID], path) } } diff --git a/storage/storage_test.go b/storage/storage_test.go index e91fa883ac..85b17fe86e 100644 --- a/storage/storage_test.go +++ b/storage/storage_test.go @@ -12,21 +12,6 @@ import ( "github.com/open-policy-agent/opa/ast" ) -func TestStorageReadNonGroundRef(t *testing.T) { - store := New(InMemoryConfig()) - txn := NewTransactionOrDie(store) - defer store.Close(txn) - ref := ast.MustParseRef("data.foo[i]") - _, e := store.Read(txn, ref) - err, ok := e.(*Error) - if !ok { - t.Fatalf("Expected storage error but got: %v", err) - } - if err.Code != InternalErr { - t.Fatalf("Expected internal error but got: %v", err) - } -} - func TestStorageReadPlugin(t *testing.T) { mem1 := NewDataStoreFromReader(strings.NewReader(` @@ -44,8 +29,7 @@ func TestStorageReadPlugin(t *testing.T) { } `)) - mountPath := ast.MustParseRef("data.foo.bar.qux") - mem2.SetMountPath(mountPath) + mountPath := MustParsePath("/foo/bar/qux") store := New(Config{ Builtin: mem1, }) @@ -63,21 +47,19 @@ func TestStorageReadPlugin(t *testing.T) { path string expected string }{ - {"plugin", "data.foo.bar.qux.corge[1]", "6"}, - {"multiple", "data.foo.bar", `{"baz": [1,2,3,4], "qux": {"corge": [5,6,7,8]}}`}, + {"plugin", "/foo/bar/qux/corge/1", "6"}, + {"multiple", "/foo/bar", `{"baz": [1,2,3,4], "qux": {"corge": [5,6,7,8]}}`}, } for i, tc := range tests { - result, err := store.Read(txn, ast.MustParseRef(tc.path)) - if err != nil { - t.Errorf("Test #%d (%v): Unexpected read error: %v", i+1, tc.note, err) - } - expected := loadExpectedResult(tc.expected) + result, err := store.Read(txn, MustParsePath(tc.path)) - if !reflect.DeepEqual(result, expected) { - t.Fatalf("Test #%d (%v): Expected %v from built-in store but got: %v", i+1, tc.note, expected, result) + if err != nil { + t.Errorf("Test #%d (%v): Unexpected read error: %v", i+1, tc.note, err) + } else if !reflect.DeepEqual(result, expected) { + t.Errorf("Test #%d (%v): Expected %v from built-in store but got: %v", i+1, tc.note, expected, result) } } @@ -89,7 +71,7 @@ func TestStorageIndexingBasicUpdate(t *testing.T) { refA := ast.MustParseRef("data.a[i]") refB := ast.MustParseRef("data.b[x]") store, ds := newStorageWithIndices(refA, refB) - ds.mustPatch(AddOp, path(`a["-"]`), float64(100)) + ds.Write(nil, AddOp, MustParsePath("/a/-"), float64(100)) if store.IndexExists(refA) { t.Errorf("Expected index to be removed after patch") @@ -111,12 +93,14 @@ func TestStorageTransactionManagement(t *testing.T) { mock := mockStore{} - if err := store.Mount(mock, ast.MustParseRef("data.foo.bar.qux")); err != nil { + mountPath := MustParsePath("/foo/bar/qux") + if err := store.Mount(mock, mountPath); err != nil { t.Fatalf("Unexpected mount error: %v", err) } - txn, err := store.NewTransaction(ast.MustParseRef("data.foo.bar.qux.corge[x]")) - + params := NewTransactionParams(). + WithPaths([]Path{Path{"foo", "bar", "qux", "corge"}}) + txn, err := store.NewTransactionWithParams(params) if err != nil { t.Fatalf("Unexpected error: %v", err) } @@ -143,11 +127,11 @@ func (mockStore) ID() string { return "mock-store" } -func (mockStore) Read(txn Transaction, ref ast.Ref) (interface{}, error) { +func (mockStore) Read(txn Transaction, path Path) (interface{}, error) { return nil, nil } -func (mockStore) Begin(txn Transaction, refs []ast.Ref) error { +func (mockStore) Begin(txn Transaction, params TransactionParams) error { return nil } @@ -155,52 +139,71 @@ func (mockStore) Close(txn Transaction) { } -func TestGroupStoresByRef(t *testing.T) { +func TestGroupPathsByStore(t *testing.T) { - mounts := map[string]ast.Ref{ - "mount-1": ast.MustParseRef("data.foo.bar.qux"), - "mount-2": ast.MustParseRef("data.foo.baz"), - "mount-3": ast.MustParseRef("data.corge"), - } + root := MustParsePath("/") + foo := MustParsePath("/foo") + fooBarQux := MustParsePath("/foo/bar/qux") + fooBarQuxGrault := MustParsePath("/foo/bar/qux/grault") + fooBaz := MustParsePath("/foo/baz") + corge := MustParsePath("/corge") + grault := MustParsePath("/grault") - result := groupRefsByStore("built-in", mounts, []ast.Ref{ - ast.MustParseRef("data[x]"), - ast.MustParseRef("data.foo.bar.qux.grault"), - ast.MustParseRef("data.foo[x][y][z]"), - }) + mounts := map[string]Path{ + "mount-1": fooBarQux, + "mount-2": fooBaz, + "mount-3": corge, + } - expected := map[string][]ast.Ref{ - "built-in": []ast.Ref{ - ast.MustParseRef("data[x]"), - ast.MustParseRef("data.foo[x][y][z]"), + result := groupPathsByStore("built-in", mounts, []Path{root}) + expected := map[string][]Path{ + "built-in": { + root, }, - "mount-1": []ast.Ref{ - ast.MustParseRef("data.foo.bar.qux"), - ast.MustParseRef("data.foo.bar.qux.grault"), - ast.MustParseRef("data.foo.bar.qux[z]"), + "mount-1": { + root, }, - "mount-2": []ast.Ref{ - ast.MustParseRef("data.foo.baz"), - ast.MustParseRef("data.foo.baz[y][z]"), + "mount-2": { + root, }, - "mount-3": []ast.Ref{ - ast.MustParseRef("data.corge"), + "mount-3": { + root, }, } - if len(result) != len(expected) { - t.Fatalf("Expected %v but got: %v", expected, result) + if !reflect.DeepEqual(expected, result) { + t.Errorf("Expected:\n%v\n\nGot:\n%v", expected, result) } - for id := range result { - if len(result[id]) != len(expected[id]) { - t.Fatalf("Expected %v but got: %v", expected[id], result[id]) - } - for i := range result[id] { - if !result[id][i].Equal(expected[id][i]) { - t.Fatalf("Expected %v but got: %v", expected[id], result[id]) - } - } + result = groupPathsByStore("built-in", mounts, []Path{foo}) + expected = map[string][]Path{ + "built-in": { + foo, + }, + "mount-1": { + root, + }, + "mount-2": { + root, + }, + } + + if !reflect.DeepEqual(expected, result) { + t.Errorf("Expected:\n%v\n\nGot:\n%v", expected, result) + } + + result = groupPathsByStore("built-in", mounts, []Path{fooBarQuxGrault, corge}) + expected = map[string][]Path{ + "mount-1": { + grault, + }, + "mount-3": { + root, + }, + } + + if !reflect.DeepEqual(expected, result) { + t.Errorf("Expected:\n%v\n\nGot:\n%v", expected, result) } } diff --git a/storage/trigger.go b/storage/trigger.go index 1a12a2e5b5..ccf95918d6 100644 --- a/storage/trigger.go +++ b/storage/trigger.go @@ -6,7 +6,7 @@ package storage // TriggerCallback defines the interface that callers can implement to handle // changes in the stores. -type TriggerCallback func(txn Transaction, op PatchOp, path []interface{}, value interface{}) error +type TriggerCallback func(txn Transaction, op PatchOp, path Path, value interface{}) error // TriggerConfig contains the trigger registration configuration. type TriggerConfig struct { diff --git a/test/scheduler/scheduler_bench_test.go b/test/scheduler/scheduler_bench_test.go index 6ae3d49a3b..a56281d0f2 100644 --- a/test/scheduler/scheduler_bench_test.go +++ b/test/scheduler/scheduler_bench_test.go @@ -98,7 +98,7 @@ func setupNodes(store *storage.Storage, txn storage.Transaction, n int) { if err != nil { panic(err) } - if err := store.Write(txn, storage.AddOp, ast.MustParseRef("data.nodes"), map[string]interface{}{}); err != nil { + if err := store.Write(txn, storage.AddOp, storage.MustParsePath("/nodes"), map[string]interface{}{}); err != nil { panic(err) } for i := 0; i < n; i++ { @@ -106,8 +106,8 @@ func setupNodes(store *storage.Storage, txn storage.Transaction, n int) { Name: fmt.Sprintf("node%v", i), } v := runTemplate(tmpl, input) - ref := ast.MustParseRef(fmt.Sprintf("data.nodes.%v", input.Name)) - if err := store.Write(txn, storage.AddOp, ref, v); err != nil { + path := storage.MustParsePath(fmt.Sprintf("/nodes/%v", input.Name)) + if err := store.Write(txn, storage.AddOp, path, v); err != nil { panic(err) } } @@ -118,8 +118,8 @@ func setupRCs(store *storage.Storage, txn storage.Transaction, n int) { if err != nil { panic(err) } - ref := ast.MustParseRef("data.replicationcontrollers") - if err := store.Write(txn, storage.AddOp, ref, map[string]interface{}{}); err != nil { + path := storage.MustParsePath("/replicationcontrollers") + if err := store.Write(txn, storage.AddOp, path, map[string]interface{}{}); err != nil { panic(err) } for i := 0; i < n; i++ { @@ -127,8 +127,8 @@ func setupRCs(store *storage.Storage, txn storage.Transaction, n int) { Name: fmt.Sprintf("rc%v", i), } v := runTemplate(tmpl, input) - ref = ast.MustParseRef(fmt.Sprintf("data.replicationcontrollers.%v", input.Name)) - if err := store.Write(txn, storage.AddOp, ref, v); err != nil { + path = storage.MustParsePath(fmt.Sprintf("/replicationcontrollers/%v", input.Name)) + if err := store.Write(txn, storage.AddOp, path, v); err != nil { panic(err) } } @@ -139,8 +139,8 @@ func setupPods(store *storage.Storage, txn storage.Transaction, n int, numNodes if err != nil { panic(err) } - ref := ast.MustParseRef("data.pods") - if err := store.Write(txn, storage.AddOp, ref, map[string]interface{}{}); err != nil { + path := storage.MustParsePath("/pods") + if err := store.Write(txn, storage.AddOp, path, map[string]interface{}{}); err != nil { panic(err) } for i := 0; i < n; i++ { @@ -149,8 +149,8 @@ func setupPods(store *storage.Storage, txn storage.Transaction, n int, numNodes NodeName: fmt.Sprintf("node%v", i%numNodes), } v := runTemplate(tmpl, input) - ref = ast.MustParseRef(fmt.Sprintf("data.pods.%v", input.Name)) - if err := store.Write(txn, storage.AddOp, ref, v); err != nil { + path = storage.MustParsePath(fmt.Sprintf("/pods/%v", input.Name)) + if err := store.Write(txn, storage.AddOp, path, v); err != nil { panic(err) } } diff --git a/topdown/eq.go b/topdown/eq.go index 8920fd7779..de29340fce 100644 --- a/topdown/eq.go +++ b/topdown/eq.go @@ -100,8 +100,7 @@ func evalEqUnifyArray(ctx *Context, a ast.Array, b ast.Value, prev *Undo, iter I func evalEqUnifyArrayRef(ctx *Context, a ast.Array, b ast.Ref, prev *Undo, iter Iterator) (*Undo, error) { - // TODO(tsandall): should not be accessing txn here? - r, err := ctx.Store.Read(ctx.txn, b) + r, err := ctx.Resolve(b) if err != nil { return prev, err } @@ -180,8 +179,7 @@ func evalEqUnifyObject(ctx *Context, a ast.Object, b ast.Value, prev *Undo, iter func evalEqUnifyObjectRef(ctx *Context, a ast.Object, b ast.Ref, prev *Undo, iter Iterator) (*Undo, error) { - // TODO(tsandall): should not be accessing txn here? - r, err := ctx.Store.Read(ctx.txn, b) + r, err := ctx.Resolve(b) if err != nil { return prev, err diff --git a/topdown/topdown.go b/topdown/topdown.go index 0d1461f969..48461e9e64 100644 --- a/topdown/topdown.go +++ b/topdown/topdown.go @@ -25,9 +25,6 @@ type Context struct { Store *storage.Storage Tracer Tracer - // TODO(tsandall): make the transaction public and lazily create one in - // Eval(). This way callers do not have to provide the transaction unless - // they want to run evaluation multiple times against the same snapshot. txn storage.Transaction cache *contextcache qid uint64 @@ -140,7 +137,30 @@ func (ctx *Context) Current() *ast.Expr { // Resolve returns the native Go value referred to by the ref. func (ctx *Context) Resolve(ref ast.Ref) (interface{}, error) { - return ctx.Store.Read(ctx.txn, ref) + + if ref.IsNested() { + cpy := make(ast.Ref, len(ref)) + for i := range ref { + switch v := ref[i].Value.(type) { + case ast.Ref: + r, err := lookupValue(ctx, v) + if err != nil { + return nil, err + } + cpy[i] = ast.NewTerm(r) + default: + cpy[i] = ref[i] + } + } + ref = cpy + } + + path, err := storage.NewPathForRef(ref) + if err != nil { + return nil, err + } + + return ctx.Store.Read(ctx.txn, path) } // Step returns a new context to evaluate the next expression. @@ -692,7 +712,11 @@ type resolver struct { } func (r resolver) Resolve(ref ast.Ref) (interface{}, error) { - return r.store.Read(r.txn, ref) + path, err := storage.NewPathForRef(ref) + if err != nil { + return nil, err + } + return r.store.Read(r.txn, path) } // ResolveRefs returns the AST value obtained by resolving references to base @@ -1010,12 +1034,11 @@ func evalRef(ctx *Context, ref, path ast.Ref, iter Iterator) error { func evalRefRec(ctx *Context, ref ast.Ref, iter Iterator) error { // Obtain ground prefix of the reference. - var plugged ast.Ref var prefix ast.Ref + switch v := PlugValue(ref, ctx.Binding).(type) { case ast.Ref: - plugged = v - prefix = plugged.GroundPrefix() + prefix = v.GroundPrefix() default: // Fast-path? TODO test case. return iter(ctx) @@ -1045,7 +1068,7 @@ func evalRefRec(ctx *Context, ref ast.Ref, iter Iterator) error { // result before continuing. func evalRefRecGround(ctx *Context, ref, prefix ast.Ref, iter Iterator) error { - doc, readErr := ctx.Store.Read(ctx.txn, prefix) + doc, readErr := ctx.Resolve(prefix) if readErr != nil { if !storage.IsNotFound(readErr) { return readErr @@ -1170,7 +1193,7 @@ func evalRefRecNonGround(ctx *Context, ref, prefix ast.Ref, iter Iterator) error variable := ref[len(prefix)].Value - doc, err := ctx.Store.Read(ctx.txn, prefix) + doc, err := ctx.Resolve(prefix) if err != nil { if !storage.IsNotFound(err) { return err @@ -2079,19 +2102,8 @@ func indexingAllowed(ref ast.Ref, term *ast.Term) bool { return true } -func lookupExists(ctx *Context, ref ast.Ref) (bool, error) { - _, err := ctx.Store.Read(ctx.txn, ref) - if err != nil { - if storage.IsNotFound(err) { - return false, nil - } - return false, err - } - return true, nil -} - func lookupValue(ctx *Context, ref ast.Ref) (ast.Value, error) { - r, err := ctx.Store.Read(ctx.txn, ref) + r, err := ctx.Resolve(ref) if err != nil { return nil, err } diff --git a/topdown/topdown_test.go b/topdown/topdown_test.go index 09e33cdc2c..dda760c105 100644 --- a/topdown/topdown_test.go +++ b/topdown/topdown_test.go @@ -59,7 +59,7 @@ func TestEvalRef(t *testing.T) { store := storage.New(storage.InMemoryWithJSONConfig(loadSmallTestData())) - txn := storage.NewTransactionOrDie(store, nil) + txn := storage.NewTransactionOrDie(store) defer store.Close(txn) ctx := NewContext(nil, compiler, store, txn) @@ -151,7 +151,7 @@ func TestEvalTerms(t *testing.T) { store := storage.New(storage.InMemoryWithJSONConfig(loadSmallTestData())) - txn := storage.NewTransactionOrDie(store, nil) + txn := storage.NewTransactionOrDie(store) defer store.Close(txn) for _, tc := range tests { @@ -1245,8 +1245,7 @@ func TestTopDownStoragePlugin(t *testing.T) { store := storage.New(storage.InMemoryWithJSONConfig(loadSmallTestData())) plugin := storage.NewDataStoreFromReader(strings.NewReader(`{"b": [1,3,5,6]}`)) - mountPath := ast.MustParseRef("data.plugin") - plugin.SetMountPath(mountPath) + mountPath, _ := storage.ParsePath("/plugin") if err := store.Mount(plugin, mountPath); err != nil { t.Fatalf("Unexpected mount error: %v", err)