Enable policy files greater than 64KB in size
fixes #214

Signed-off-by: Anand Krishnamoorthi <[email protected]>
anakrish committed Apr 26, 2024
1 parent 7fde338 commit 466dbc3
Showing 8 changed files with 2,711 additions and 69 deletions.
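The old ceiling came from the lexer's `u16` offsets: `u16::MAX` is 65,535, so byte positions, line numbers, and columns could not address a policy much larger than 64 KiB. This commit widens those fields to `u32` across `Source`, `SourceStr`, `Span`, `Lexer`, `Location`, and `Parser`, and swaps the infallible `Source::new` for a fallible `Source::from_contents` that rejects contents too large even for `u32` indexing. The arithmetic behind the two bounds, as a minimal standalone sketch (the constants come from the diff below; nothing else about the crate is assumed):

```rust
fn main() {
    // Old representation: byte offsets, line and column numbers were u16.
    let old_cap = u16::MAX as usize; // 65_535 bytes, just under 64 KiB
    // New representation: u32 offsets, minus a small margin because rows and
    // columns start at 1 and an end-of-file position is tracked past the end.
    let new_cap = u32::MAX as usize - 2; // bound checked by Source::from_contents
    println!("old cap: {old_cap} bytes, new cap: {new_cap} bytes");
    assert!(old_cap < 64 * 1024);
    assert!(new_cap > 4_000_000_000usize);
}
```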
14 changes: 7 additions & 7 deletions src/engine.rs
@@ -65,7 +65,7 @@ impl Engine {
/// ```
///
pub fn add_policy(&mut self, path: String, rego: String) -> Result<()> {
let source = Source::new(path, rego);
let source = Source::from_contents(path, rego)?;
let mut parser = Parser::new(&source)?;
self.modules.push(Ref::new(parser.parse()?));
// if policies change, interpreter needs to be prepared again
@@ -294,15 +294,15 @@ impl Engine {

self.interpreter.create_rule_prefixes()?;
let query_module = {
let source = Source::new(
let source = Source::from_contents(
"<query_module.rego>".to_owned(),
"package __internal_query_module".to_owned(),
);
)?;
Ref::new(Parser::new(&source)?.parse()?)
};

// Parse the query.
let query_source = Source::new("<query.rego>".to_string(), query);
let query_source = Source::from_contents("<query.rego>".to_string(), query)?;
let mut parser = Parser::new(&query_source)?;
let query_node = parser.parse_user_query()?;
if query_node.span.text() == "data" {
@@ -411,15 +411,15 @@ impl Engine {
self.eval_modules(enable_tracing)?;

let query_module = {
let source = Source::new(
let source = Source::from_contents(
"<query_module.rego>".to_owned(),
"package __internal_query_module".to_owned(),
);
)?;
Ref::new(Parser::new(&source)?.parse()?)
};

// Parse the query.
let query_source = Source::new("<query.rego>".to_string(), query);
let query_source = Source::from_contents("<query.rego>".to_string(), query)?;
let mut parser = Parser::new(&query_source)?;
let query_node = parser.parse_user_query()?;
let query_schedule = Analyzer::new().analyze_query_snippet(&self.modules, &query_node)?;
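`Engine::add_policy` keeps its shape (it still returns `Result<()>`), but because it now goes through `Source::from_contents`, an oversized policy is reported as an error instead of having its offsets silently truncated to `u16`. A hedged caller-side sketch; it assumes the crate is consumed as `regorus` and that `Engine::new()` constructs an empty engine, neither of which appears in this diff:

```rust
use anyhow::Result;
use regorus::Engine; // assumed crate name and re-export; not shown in this diff

fn load_policy(path: &str) -> Result<()> {
    let rego = std::fs::read_to_string(path)?;
    let mut engine = Engine::new(); // assumed default constructor
    // add_policy builds its Source via Source::from_contents, so a policy that
    // exceeds the new u32-based bound surfaces here as an Err.
    match engine.add_policy(path.to_string(), rego) {
        Ok(()) => println!("loaded {path}"),
        Err(e) => eprintln!("rejected {path}: {e}"),
    }
    Ok(())
}
```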
103 changes: 52 additions & 51 deletions src/lexer.rs
@@ -18,7 +18,7 @@ use anyhow::{anyhow, bail, Result};
struct SourceInternal {
pub file: String,
pub contents: String,
pub lines: Vec<(u16, u16)>,
pub lines: Vec<(u32, u32)>,
}

#[derive(Clone)]
@@ -61,8 +61,8 @@ impl Debug for Source {
#[derive(Clone)]
pub struct SourceStr {
source: Source,
start: u16,
end: u16,
start: u32,
end: u32,
}

impl Debug for SourceStr {
@@ -78,7 +78,7 @@ impl std::fmt::Display for SourceStr {
}

impl SourceStr {
pub fn new(source: Source, start: u16, end: u16) -> Self {
pub fn new(source: Source, start: u32, end: u32) -> Self {
Self { source, start, end }
}

@@ -116,39 +116,43 @@ impl std::cmp::Ord for SourceStr {
}

impl Source {
pub fn new(file: String, contents: String) -> Source {
pub fn from_contents(file: String, contents: String) -> Result<Source> {
let max_size = u32::MAX as usize - 2; // Account for rows, cols possibly starting at 1, EOF etc.
if contents.len() > max_size {
bail!("{file} exceeds maximum allowed policy file size {max_size}");
}
let mut lines = vec![];
let mut prev_ch = ' ';
let mut prev_pos = 0u16;
let mut start = 0u16;
let mut prev_pos = 0u32;
let mut start = 0u32;
for (i, ch) in contents.char_indices() {
if ch == '\n' {
let end = match prev_ch {
'\r' => prev_pos,
_ => i as u16,
_ => i as u32,
};
lines.push((start, end));
start = i as u16 + 1;
start = i as u32 + 1;
}
prev_ch = ch;
prev_pos = i as u16;
prev_pos = i as u32;
}

if (start as usize) < contents.len() {
lines.push((start, contents.len() as u16));
lines.push((start, contents.len() as u32));
} else if contents.is_empty() {
lines.push((0, 0));
} else {
let s = (contents.len() - 1) as u16;
let s = (contents.len() - 1) as u32;
lines.push((s, s));
}
Self {
Ok(Self {
src: Rc::new(SourceInternal {
file,
contents,
lines,
}),
}
})
}

pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Source> {
@@ -157,10 +161,7 @@ impl Source {
Err(e) => bail!("Failed to read {}. {e}", path.as_ref().display()),
};
// TODO: retain path instead of converting to string
Ok(Self::new(
path.as_ref().to_string_lossy().to_string(),
contents,
))
Self::from_contents(path.as_ref().to_string_lossy().to_string(), contents)
}

pub fn file(&self) -> &String {
@@ -169,7 +170,7 @@ impl Source {
pub fn contents(&self) -> &String {
&self.src.contents
}
pub fn line(&self, idx: u16) -> &str {
pub fn line(&self, idx: u32) -> &str {
let idx = idx as usize;
if idx < self.src.lines.len() {
let (start, end) = self.src.lines[idx];
@@ -179,7 +180,7 @@ impl Source {
}
}

pub fn message(&self, line: u16, col: u16, kind: &str, msg: &str) -> String {
pub fn message(&self, line: u32, col: u32, kind: &str, msg: &str) -> String {
if line as usize > self.src.lines.len() {
return format!("{}: invalid line {} specified", self.src.file, line);
}
@@ -206,18 +207,18 @@ impl Source {
)
}

pub fn error(&self, line: u16, col: u16, msg: &str) -> anyhow::Error {
pub fn error(&self, line: u32, col: u32, msg: &str) -> anyhow::Error {
anyhow!(self.message(line, col, "error", msg))
}
}

#[derive(Clone)]
pub struct Span {
pub source: Source,
pub line: u16,
pub col: u16,
pub start: u16,
pub end: u16,
pub line: u32,
pub col: u32,
pub start: u32,
pub end: u32,
}

impl Span {
@@ -272,8 +273,8 @@ pub struct Token(pub TokenKind, pub Span);
pub struct Lexer<'source> {
source: Source,
iter: Peekable<CharIndices<'source>>,
line: u16,
col: u16,
line: u32,
col: u32,
}

impl<'source> Lexer<'source> {
@@ -312,15 +313,15 @@ impl<'source> Lexer<'source> {
}
}
let end = self.peek().0;
self.col += (end - start) as u16;
self.col += (end - start) as u32;
Ok(Token(
TokenKind::Ident,
Span {
source: self.source.clone(),
line: self.line,
col,
start: start as u16,
end: end as u16,
start: start as u32,
end: end as u32,
},
))
}
@@ -363,7 +364,7 @@ impl<'source> Lexer<'source> {
}

let end = self.peek().0;
self.col += (end - start) as u16;
self.col += (end - start) as u32;

// Check for invalid number.Valid number cannot be followed by
// these characters:
@@ -403,8 +404,8 @@ impl<'source> Lexer<'source> {
source: self.source.clone(),
line: self.line,
col,
start: start as u16,
end: end as u16,
start: start as u32,
end: end as u32,
},
))
}
Expand Down Expand Up @@ -440,8 +441,8 @@ impl<'source> Lexer<'source> {
source: self.source.clone(),
line,
col,
start: start as u16,
end: end as u16 - 1,
start: start as u32,
end: end as u32 - 1,
},
))
}
@@ -453,7 +454,7 @@ impl<'source> Lexer<'source> {
let (start, _) = self.peek();
loop {
let (offset, ch) = self.peek();
let col = self.col + (offset - start) as u16;
let col = self.col + (offset - start) as u32;
match ch {
'"' | '#' | '\x00' => {
break;
@@ -468,7 +469,7 @@ impl<'source> Lexer<'source> {
'u' => {
for _i in 0..4 {
let (offset, ch) = self.peek();
let col = self.col + (offset - start) as u16;
let col = self.col + (offset - start) as u32;
if !ch.is_ascii_hexdigit() {
return Err(self.source.error(
line,
@@ -484,7 +485,7 @@ impl<'source> Lexer<'source> {
}
_ => {
// check for valid json chars
let col = self.col + (offset - start) as u16;
let col = self.col + (offset - start) as u32;
if !('\u{0020}'..='\u{10FFFF}').contains(&ch) {
return Err(self.source.error(line, col, "invalid character in string"));
}
@@ -499,7 +500,7 @@ impl<'source> Lexer<'source> {

self.iter.next();
let end = self.peek().0;
self.col += (end - start) as u16;
self.col += (end - start) as u32;

// Ensure that the string is parsable in Rust.
match serde_json::from_str::<String>(&self.source.contents()[start - 1..end]) {
@@ -522,8 +523,8 @@ impl<'source> Lexer<'source> {
source: self.source.clone(),
line,
col: col + 1,
start: start as u16,
end: end as u16 - 1,
start: start as u32,
end: end as u32 - 1,
},
))
}
@@ -593,14 +594,14 @@ impl<'source> Lexer<'source> {
source: self.source.clone(),
line: self.line,
col,
start: start as u16,
end: start as u16 + 1,
start: start as u32,
end: start as u32 + 1,
}))
}
':' => {
self.col += 1;
self.iter.next();
let mut end = start as u16 + 1;
let mut end = start as u32 + 1;
if self.peek().1 == '=' {
self.col += 1;
self.iter.next();
@@ -610,7 +611,7 @@ impl<'source> Lexer<'source> {
source: self.source.clone(),
line: self.line,
col,
start: start as u16,
start: start as u32,
end
}))
}
@@ -626,8 +627,8 @@ impl<'source> Lexer<'source> {
source: self.source.clone(),
line: self.line,
col,
start: start as u16,
end: self.peek().0 as u16,
start: start as u32,
end: self.peek().0 as u32,
}))
}
'!' if self.peekahead(1).1 == '=' => {
@@ -638,8 +639,8 @@ impl<'source> Lexer<'source> {
source: self.source.clone(),
line: self.line,
col,
start: start as u16,
end: self.peek().0 as u16,
start: start as u32,
end: self.peek().0 as u32,
}))
}
'"' => self.read_string(),
@@ -648,8 +649,8 @@ impl<'source> Lexer<'source> {
source: self.source.clone(),
line:self.line,
col,
start: start as u16,
end: start as u16
start: start as u32,
end: start as u32
})),
_ if chr.is_ascii_digit() => self.read_number(),
_ if chr.is_ascii_alphabetic() || chr == '_' => {
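Most of the lexer diff is the mechanical `u16` to `u32` widening; the substantive piece is `Source::from_contents`, which checks the size up front and then records each line as a `(start, end)` pair of `u32` byte offsets, handling both LF and CRLF endings. The loop below re-creates just that bookkeeping as a free function so it can be exercised in isolation; it mirrors the logic in the diff but leaves out the empty-input and trailing-line special cases, and it is not part of the crate's API:

```rust
// Standalone re-creation of the line-boundary bookkeeping in Source::from_contents.
fn line_spans(contents: &str) -> Vec<(u32, u32)> {
    let mut lines = Vec::new();
    let (mut prev_ch, mut prev_pos, mut start) = (' ', 0u32, 0u32);
    for (i, ch) in contents.char_indices() {
        if ch == '\n' {
            // A preceding '\r' is excluded from the line, so CRLF and LF both work.
            let end = if prev_ch == '\r' { prev_pos } else { i as u32 };
            lines.push((start, end));
            start = i as u32 + 1;
        }
        prev_ch = ch;
        prev_pos = i as u32;
    }
    if (start as usize) < contents.len() {
        lines.push((start, contents.len() as u32));
    }
    lines
}

fn main() {
    let s = "a\r\nb\nc";
    let spans = line_spans(s);
    assert_eq!(spans, vec![(0, 1), (3, 4), (5, 6)]);
    // Each span excludes the line terminator, so slicing recovers the line text.
    assert_eq!(&s[spans[0].0 as usize..spans[0].1 as usize], "a");
    println!("{spans:?}");
}
```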
4 changes: 2 additions & 2 deletions src/lib.rs
@@ -47,9 +47,9 @@ use std::rc::Rc;
#[derive(Debug, Clone, Serialize, Eq, PartialEq)]
pub struct Location {
/// Line number. Starts at 1.
pub row: u16,
pub row: u32,
/// Column number. Starts at 1.
pub col: u16,
pub col: u32,
}

/// An expression in a Rego query.
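`Location` is part of the crate's public results surface, so its `row` and `col` fields widen along with the lexer, and positions past the old 65,535 limit become representable. Illustrative only, assuming `Location` is re-exported at the crate root with the public fields shown above:

```rust
fn main() {
    // Assumed path `regorus::Location`; only the field names and types come
    // from the diff above.
    let loc = regorus::Location { row: 100_000, col: 42 };
    assert!(loc.row > u16::MAX as u32); // would not fit in the old u16 field
    println!("row {}, col {}", loc.row, loc.col);
}
```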
6 changes: 3 additions & 3 deletions src/parser.rs
@@ -16,8 +16,8 @@ pub struct Parser<'source> {
source: Source,
lexer: Lexer<'source>,
tok: Token,
line: u16,
end: u16,
line: u32,
end: u32,
future_keywords: BTreeMap<String, Span>,
rego_v1: bool,
}
@@ -753,7 +753,7 @@ impl<'source> Parser<'source> {

fn parse_membership_tail(
&mut self,
start: u16,
start: u32,
mut expr1: Expr,
mut expr2: Option<Expr>,
) -> Result<Expr> {
5 changes: 4 additions & 1 deletion src/tests/scheduler/analyzer/mod.rs
@@ -38,7 +38,10 @@ fn analyze_file(regos: &[String], expected_scopes: &[Scope]) -> Result<()> {
let mut sources = vec![];
let mut modules = vec![];
for (idx, _) in regos.iter().enumerate() {
sources.push(Source::new(format!("rego_{idx}"), regos[idx].clone()));
sources.push(Source::from_contents(
format!("rego_{idx}"),
regos[idx].clone(),
)?);
}

for source in &sources {