Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More OPA conformance #77

Merged
merged 1 commit into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion src/builtins/objects.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::ast::{Expr, Ref};
use crate::builtins;
use crate::builtins::utils::{ensure_args_count, ensure_object};
use crate::builtins::utils::{ensure_args_count, ensure_array, ensure_object};
use crate::lexer::Span;
use crate::value::Value;

Expand All @@ -21,6 +21,8 @@ pub fn register(m: &mut HashMap<&'static str, builtins::BuiltinFcn>) {
m.insert("object.keys", (keys, 1));
m.insert("object.remove", (remove, 2));
m.insert("object.subset", (subset, 2));
m.insert("object.union", (object_union, 2));
m.insert("object.union_n", (object_union_n, 1));
}

fn json_filter_impl(v: &Value, filter: &Value) -> Value {
Expand Down Expand Up @@ -324,3 +326,59 @@ fn subset(span: &Span, params: &[Ref<Expr>], args: &[Value], _strict: bool) -> R

Ok(Value::Bool(is_subset(&args[0], &args[1])))
}

fn union(obj1: &Value, obj2: &Value) -> Result<Value> {
match (obj1, obj2) {
(Value::Object(m1), Value::Object(m2)) => {
let mut u = obj1.clone();
let um = u.as_object_mut()?;

for (key2, value2) in m2.iter() {
let vm = match m1.get(key2) {
Some(value1) => union(value1, value2)?,
_ => value2.clone(),
};
um.insert(key2.clone(), vm);
}
Ok(u)
}
_ => Ok(obj2.clone()),
}
}

fn object_union(span: &Span, params: &[Ref<Expr>], args: &[Value], _strict: bool) -> Result<Value> {
let name = "object.union";
ensure_args_count(span, name, params, args, 2)?;

let _ = ensure_object(name, &params[0], args[0].clone())?;
let _ = ensure_object(name, &params[1], args[1].clone())?;

union(&args[0], &args[1])
}

fn object_union_n(
span: &Span,
params: &[Ref<Expr>],
args: &[Value],
strict: bool,
) -> Result<Value> {
let name = "object.union_n";
ensure_args_count(span, name, params, args, 1)?;

let arr = ensure_array(name, &params[0], args[0].clone())?;

let mut u = Value::new_object();
for (idx, a) in arr.iter().enumerate() {
if a.as_object().is_err() {
if strict {
bail!(params[0]
.span()
.error(&format!("item at index {idx} is not an object")));
}
return Ok(Value::Undefined);
}
u = union(&u, a)?;
}

Ok(u)
}
27 changes: 15 additions & 12 deletions src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,23 +164,26 @@ impl Engine {
pub fn eval_query(&mut self, query: String, enable_tracing: bool) -> Result<QueryResults> {
self.eval_modules(false)?;

let query_module = {
let source = Source::new(
"<query_module.rego>".to_owned(),
"package __internal_query_module".to_owned(),
);
Ref::new(Parser::new(&source)?.parse()?)
};

// Parse the query.
let query_len = query.len();
let query_source = Source::new("<query.rego>".to_string(), query);
let query_span = Span {
source: query_source.clone(),
line: 1,
col: 1,
start: 0,
end: query_len as u16,
};
let mut parser = Parser::new(&query_source)?;
let query_node = Ref::new(parser.parse_query(query_span, "")?);
let query_node = parser.parse_user_query()?;
let query_schedule = Analyzer::new().analyze_query_snippet(&self.modules, &query_node)?;

let results =
self.interpreter
.eval_user_query(&query_node, &query_schedule, enable_tracing)?;
let results = self.interpreter.eval_user_query(
&query_module,
&query_node,
&query_schedule,
enable_tracing,
)?;
Ok(results)
}
}
11 changes: 10 additions & 1 deletion src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2727,6 +2727,7 @@ impl Interpreter {

pub fn eval_user_query(
&mut self,
module: &Ref<Module>,
query: &Ref<Query>,
schedule: &Schedule,
enable_tracing: bool,
Expand Down Expand Up @@ -2754,7 +2755,7 @@ impl Interpreter {
is_compr: false,
});

let prev_module = self.set_current_module(self.modules.last().cloned())?;
let prev_module = self.set_current_module(Some(module.clone()))?;

// Eval the query.
let query_r = self.eval_query(query);
Expand Down Expand Up @@ -2795,6 +2796,14 @@ impl Interpreter {

self.set_current_module(prev_module)?;

if let Some(r) = results.result.last() {
if r.bindings.is_empty_object()
&& r.expressions.iter().any(|e| e.value == Value::Bool(false))
{
results = QueryResults::default();
}
}

match query_r {
Ok(_) => Ok(results),
Err(e) => Err(e),
Expand Down
28 changes: 24 additions & 4 deletions src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -620,10 +620,21 @@ impl<'source> Parser<'source> {
let op = match self.token_text() {
"+" => ArithOp::Add,
"-" => ArithOp::Sub,
n if n.starts_with('-') && self.tok.0 == TokenKind::Number => ArithOp::Sub,
_ => return Ok(expr),
};
self.next_token()?;
let right = self.parse_mul_div_mod_expr()?;
let right = if self.token_text().len() > 1 {
// Treat the - as a separate token
let mut rhs_span = self.tok.1.clone();
rhs_span.start += 1;
rhs_span.col += 1;

self.next_token()?;
Expr::Number(rhs_span)
} else {
self.next_token()?;
self.parse_mul_div_mod_expr()?
};
span.end = self.end;
expr = Expr::ArithExpr {
span,
Expand Down Expand Up @@ -953,7 +964,7 @@ impl<'source> Parser<'source> {
})
}

pub fn parse_query(&mut self, mut span: Span, end_delim: &str) -> Result<Query> {
fn parse_query(&mut self, mut span: Span, end_delim: &str) -> Result<Query> {
let state = self.clone();
let is_definite_query = matches!(self.token_text(), "some" | "every");

Expand Down Expand Up @@ -1485,7 +1496,7 @@ impl<'source> Parser<'source> {
Ok(Rule::Spec { span, head, bodies })
}

fn parse_package(&mut self) -> Result<Package> {
pub fn parse_package(&mut self) -> Result<Package> {
let mut span = self.tok.1.clone();
self.expect("package", "Missing package declaration.")?;
let name = self.parse_path_ref()?;
Expand Down Expand Up @@ -1609,4 +1620,13 @@ impl<'source> Parser<'source> {
policy,
})
}

pub fn parse_user_query(&mut self) -> Result<Ref<Query>> {
let span = self.tok.1.clone();
let query = Ref::new(self.parse_query(span, "")?);
if self.tok.0 != TokenKind::Eof {
bail!(self.tok.1.error("expecting EOF"));
}
Ok(query)
}
}
21 changes: 21 additions & 0 deletions tests/interpreter/cases/arithmetic/tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
cases:
- note: negative-integer-literal-in-arithmetic-expressions
data: {}
modules:
- |
# In the following, the negative integer must be broken into a - and an integer tokens
# when in arighmetic expression contexts
package test
a = 1+1-1
b = 1 +1 -1
c = 1 + 1 - 1
d = -1 -1
query: data.test
want_result:
a: 1
b: 1
c: 1
d: -2

44 changes: 44 additions & 0 deletions tests/interpreter/cases/engine/tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
cases:
- note: trailing whitespace in query
data: {}
modules: ["package test"]
query: "1 + 1 -2 "
want_result: 0

- note: trailing chars in query
data: {}
modules: ["package test"]
query: "[1]]"
error: expecting EOF

- note: trailing expressions in query
data: {}
modules: ["package test"]
query: "1 2"
error: expecting EOF

- note: multiple statements in query
data: {}
modules: ["package test"]
query: |
a = [1, 2, 3]
true
y = 1 + 1
want_result:
a: [1, 2, 3]
y: 2

- note: comprehensions in query
data: {}
modules: ["package test"]
query: |
true
[1, 2, 3][_]
want_result:
many!:
- [true, 1]
- [true, 2]
- [true, 3]

25 changes: 20 additions & 5 deletions tests/interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ fn match_values(computed: &Value, expected: &Value) -> Result<()> {

pub fn check_output(computed_results: &[Value], expected_results: &[Value]) -> Result<()> {
if computed_results.len() != expected_results.len() {
dbg!((&computed_results, &expected_results));
bail!(
"the number of computed results ({}) and expected results ({}) is not equal",
computed_results.len(),
Expand All @@ -108,11 +109,25 @@ pub fn check_output(computed_results: &[Value], expected_results: &[Value]) -> R
}

fn push_query_results(query_results: QueryResults, results: &mut Vec<Value>) {
if let Some(query_result) = query_results.result.last() {
if !query_result.bindings.is_empty_object() {
results.push(query_result.bindings.clone());
} else if let Some(v) = query_result.expressions.last() {
results.push(v.value.clone());
if query_results.result.len() == 1 {
if let Some(query_result) = query_results.result.last() {
if !query_result.bindings.is_empty_object() {
results.push(query_result.bindings.clone());
} else {
for e in query_result.expressions.iter() {
results.push(e.value.clone());
}
}
}
} else {
for r in query_results.result.iter() {
if !r.bindings.is_empty_object() {
results.push(r.bindings.clone());
} else {
results.push(Value::from_array(
r.expressions.iter().map(|e| e.value.clone()).collect(),
));
}
}
}
}
Expand Down
5 changes: 5 additions & 0 deletions tests/opa.passing
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ indirectreferences
inputvalues
intersection
invalidkeyerror
jsonbuiltins
jsonfilter
jsonfilteridempotent
jsonremove
Expand All @@ -67,7 +68,11 @@ objectkeys
objectremove
objectremoveidempotent
objectremovenonstringkey
objectunion
objectunionn
partialdocconstants
partialiter
partialobjectdoc
partialsetdoc
planner-ir
rand
Expand Down
6 changes: 5 additions & 1 deletion tests/opa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,11 @@ fn run_opa_tests(opa_tests_dir: String, folders: &[String]) -> Result<()> {
entry.0 += 1;
}
// TODO: Handle tests that specify both want_result and strict_error
(Err(_), _) if case.want_error.is_some() => {
(Err(_), _)
if case.want_error.is_some()
|| case.strict_error == Some(true)
|| case.want_error_code.is_some() =>
{
// Expected failure.
entry.0 += 1;
}
Expand Down