Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hint ite statements #266

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions examples/hint.no
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ hint fn div(lhs: Field, rhs: Field) -> Field {
}

hint fn ite(lhs: Field, rhs: Field) -> Field {
return if lhs != rhs { lhs } else { rhs };
let mut var = 0;
if lhs != rhs {
return lhs;
} else {
return rhs;
}
}

hint fn exp(const EXP: Field, val: Field) -> Field {
Expand Down Expand Up @@ -55,9 +60,9 @@ fn main(pub public_input: Field, private_input: Field) -> Field {
assert_eq(public_input, 2);
assert_eq(private_input, 2);

let xx = unsafe add_mul_2(public_input, private_input);
let xx = unsafe add_mul_2(public_input, private_input);
let yy = unsafe mul(public_input, private_input);
assert_eq(xx, yy * 2);
assert_eq(xx, yy * 2);

let zz = unsafe div(xx, public_input);
assert_eq(zz, yy);
Expand All @@ -83,4 +88,4 @@ fn main(pub public_input: Field, private_input: Field) -> Field {
assert(oo[2]);

return xx;
}
}
22 changes: 19 additions & 3 deletions examples/if_else.no
Original file line number Diff line number Diff line change
@@ -1,6 +1,22 @@
hint fn ite(xx: Field , arr: [Field;LEN]) -> Field {
let mut var = 0;
if xx == 10 {
var = xx;
for idx in 0..LEN {
var = var + arr[idx];
}
return xx * var;
} else {
return xx * xx;
}
}


fn main(pub xx: Field) {
let plus = xx + 1;
let mut plus = xx + 1;
let cond = xx == 1;
let yy = if cond { plus } else { xx };
assert_eq(yy, 2);
let ans = unsafe ite(plus, [1,2,3]);
log(ans);
assert_eq(ans, 4);
assert_eq(plus, 2);
}
28 changes: 28 additions & 0 deletions src/circuit_writer/ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,34 @@ impl<B: Backend> IRWriter<B> {
return Ok(Some(var));
}
StmtKind::Comment(_) => (),
StmtKind::Ite {
condition,
then_branch,
else_branch,
} => {
fn_env.nest();
let cond_var = self
.compute_expr(fn_env, condition)?
.unwrap()
.value(self, fn_env)
.cvars[0]
.clone();
let ret = self.compile_block(fn_env, then_branch)?.unwrap().cvars[0].clone();
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why then_branch block isn't handled the same way as the else_branch block?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't the fn_env.nest() and fn_env.pop() be called for each compile_block?

let mut else_ret = None;
match else_branch {
Some(stmts) => {
//start another block
fn_env.nest();
else_ret = self.compile_block(fn_env, stmts)?;
fn_env.pop();
}
None => (),
}
let else_ret_ = else_ret.unwrap().cvars[0].clone();
let ite_ir = term![Op::Ite; cond_var, ret, else_ret_];
let res = Var::new_cvar(ite_ir, condition.span);
return Ok(Some(VarOrRef::Var(res)));
}
}

Ok(None)
Expand Down
5 changes: 5 additions & 0 deletions src/circuit_writer/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,11 @@ impl<B: Backend> CircuitWriter<B> {
return Ok(Some(var));
}
StmtKind::Comment(_) => (),
StmtKind::Ite {
condition,
then_branch,
else_branch,
} => unreachable!("Ite should be only hint function"),
}

Ok(None)
Expand Down
6 changes: 6 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,4 +374,10 @@ pub enum ErrorKind {

#[error("Not enough variables provided to fill placeholders in the formatted string")]
InsufficientVariables,

#[error("The preceding if-else block already handles all possible return paths. Consider removing this statement or restructuring your control flow")]
UnreachableStatement,

#[error("if statements are only allowed in hint functions either mark your function hint or use if expressions instead (e.g. `x = if cond {{ 1 }} else {{ 2 }};`)")]
IfElseBlocksOnlyAllowedInHintFn,
}
42 changes: 41 additions & 1 deletion src/mast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ pub struct MonomorphizedFnEnv {
current_scope: usize,

vars: HashMap<String, (usize, MTypeInfo)>,

// storing the monomorphised return type for early return in Ite blocks
return_typed: Option<Ty>,
}

impl MonomorphizedFnEnv {
Expand Down Expand Up @@ -1213,7 +1216,6 @@ pub fn monomorphize_block<B: Backend>(
if let Some((stmt, expr_mono)) = monomorphize_stmt(ctx, mono_fn_env, stmt)? {
stmts_mono.push(stmt);

// only return stmt can return `ExprMonoInfo` which contains propagated constants
if expr_mono.is_some() {
ret_expr_mono = expr_mono;
}
Expand Down Expand Up @@ -1359,6 +1361,43 @@ pub fn monomorphize_stmt<B: Backend>(
Some((stmt_mono, Some(expr_mono)))
}
StmtKind::Comment(_) => None,
StmtKind::Ite {
condition,
then_branch,
else_branch,
} => {
// enter the ite statement
mono_fn_env.nest();
let expected_return_ty =
if matches!(then_branch.last().unwrap().kind, StmtKind::Return(..)) {
mono_fn_env.return_typed.clone()
} else {
None
};
let expr_mono = monomorphize_expr(ctx, condition, mono_fn_env)?;
let (then_branch_mono, ret_mono) =
monomorphize_block(ctx, mono_fn_env, then_branch, expected_return_ty.as_ref())?;
let else_branch_mono = else_branch
.as_ref()
.map(|stmts| {
let (mono_branch, _) =
monomorphize_block(ctx, mono_fn_env, stmts, expected_return_ty.as_ref())?;
Ok(mono_branch)
})
.transpose()?;
let ite_stmt = Stmt {
kind: StmtKind::Ite {
condition: Box::new(expr_mono.expr),
then_branch: then_branch_mono,
else_branch: else_branch_mono,
},
span: stmt.span,
};

//exit the ite
mono_fn_env.pop();
Some((ite_stmt, ret_mono))
}
};

Ok(res)
Expand Down Expand Up @@ -1442,6 +1481,7 @@ pub fn instantiate_fn_call<B: Backend>(
None,
),
FnKind::Native(fn_def) => {
mono_fn_env.return_typed = ret_typed.clone();
let (stmts_typed, mono_info) =
monomorphize_block(ctx, mono_fn_env, &fn_def.body, ret_typed.as_ref())?;

Expand Down
15 changes: 15 additions & 0 deletions src/name_resolution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,21 @@ impl NameResCtx {
self.resolve_stmt(stmt)?;
}
}
StmtKind::Ite {
condition,
then_branch,
else_branch,
} => {
self.resolve_expr(condition);
for stmt in then_branch {
self.resolve_stmt(stmt)?;
}
if let Some(stmts) = else_branch {
for stmt in stmts {
self.resolve_stmt(stmt)?;
}
}
}
};

Ok(())
Expand Down
1 change: 1 addition & 0 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub mod types;
pub use expr::{Expr, ExprKind, Op2};
use serde::Serialize;
pub use structs::{CustomType, StructDef};
use types::FnSig;

//~
//~ # Grammar
Expand Down
97 changes: 95 additions & 2 deletions src/parser/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ pub fn is_numeric(typ: &TyKind) -> bool {
matches!(typ, TyKind::Field { .. })
}

pub fn is_bool(typ: &TyKind) -> bool {
matches!(typ, TyKind::Bool)
}

//~
//~ ## Type
//~
Expand Down Expand Up @@ -1423,6 +1427,16 @@ pub enum StmtKind {
argument: ForLoopArgument,
body: Vec<Stmt>,
},

// if <condition> {<body>} else {<body>}
// it contains an extra field to keep the return type of the function from where the ite is called
// if the ite contains early returns then this is used to type check the return from the ite to that
// of the function
Ite {
condition: Box<Expr>,
then_branch: Vec<Stmt>,
else_branch: Option<Vec<Stmt>>,
},
}

impl Stmt {
Expand Down Expand Up @@ -1576,8 +1590,87 @@ impl Stmt {
kind: TokenKind::Keyword(Keyword::If),
span,
}) => {
// TODO: wait, this should be implemented as an expression! not a statement
Err(Error::new("parse", ErrorKind::UnexpectedError("if statements are not implemented yet. Use if expressions instead (e.g. `x = if cond {{ 1 }} else {{ 2 }};`)"), span))?
// if else blocks are only allowed for hint functions
tokens.bump(ctx);

// if cond {
// ^^
let condition = Expr::parse(ctx, tokens)?;

// if cond {
// ^^
tokens.bump_expected(ctx, TokenKind::LeftCurlyBracket)?;

let mut then_branch = vec![];
let mut is_there_else = false;
loop {
// if cond { <body> }
// ^^
let next_token = tokens.peek();
if matches!(
next_token,
Some(Token {
kind: TokenKind::RightCurlyBracket,
..
})
) {
tokens.bump(ctx);
let next_token = tokens.peek();
// if cond { <body> } else { <body> }
// ^^
is_there_else = matches!(
next_token,
Some(Token {
kind: TokenKind::Keyword(Keyword::Else),
..
})
);
break;
}
let stmt = Stmt::parse(ctx, tokens)?;
then_branch.push(stmt);
}
let mut else_branch = vec![];
if is_there_else {
// if cond {<body>} else {<body>}
// ^^
tokens.bump(ctx);
// if cond {<body>} else {<body>}
// ^^
tokens.bump_expected(ctx, TokenKind::LeftCurlyBracket)?;

loop {
let next_token = tokens.peek();
// if cond {<body>} else { <body> }
// ^^
if matches!(
next_token,
Some(Token {
kind: TokenKind::RightCurlyBracket,
..
})
) {
tokens.bump(ctx);
break;
}
let stmt = Stmt::parse(ctx, tokens)?;
else_branch.push(stmt);
}
}
let else_branch = if (else_branch.is_empty()) {
None
} else {
Some(else_branch)
};

Ok(Stmt {
kind: StmtKind::Ite {
condition: Box::new(condition),
then_branch,
else_branch,
},
span,
})
}

// return
Expand Down
Loading