Skip to content

Commit

Permalink
add memory checks to VM circuit
Browse files Browse the repository at this point in the history
  • Loading branch information
govereau committed Jan 9, 2024
1 parent 3a71910 commit 69e9db0
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 75 deletions.
113 changes: 51 additions & 62 deletions prover/src/circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ pub use ark_r1cs_std::{

use nexus_riscv::{
nop_vm,
vm::trace::{trace, Trace, Block},
vm::memory::path::{poseidon_config, ParamsVar},
vm::trace::{trace, Trace, Block, Witness},
};
use nexus_riscv_circuit::{
r1cs::{ZERO, V, R1CS},
Expand Down Expand Up @@ -50,6 +51,7 @@ impl Tr {
for x in b.regs.x {
v.push(F1::from(x));
}
v.push(b.steps[0].pc_path.root);
v
}
}
Expand All @@ -60,74 +62,66 @@ pub fn nop_circuit(k: usize) -> Result<Tr, ProofError> {
Ok(Tr::new(trace))
}

// fast version
fn build_witness_partial(cs: CS, rcs: R1CS) -> Result<Vec<FpVar<F1>>, SynthesisError> {
let mut output: Vec<FpVar<F1>> = Vec::new();
fn add_paths(cs: CS, w: &Witness, vars: &[FpVar<F1>]) -> Result<(), SynthesisError> {
let params = poseidon_config();
let params = ParamsVar::new_constant(cs.clone(), params)?;

for (i, x) in rcs.w.iter().enumerate() {
if rcs.input_range().contains(&i) {
// variables already allocated in z
} else if rcs.output_range().contains(&i) {
let av = AllocatedFp::new_witness(cs.clone(), || Ok(*x))?;
output.push(FpVar::Var(av))
} else {
cs.new_witness_variable(|| Ok(*x))?;
}
}
Ok(output)
}
// TODO: fixme (constants) - see init_cs in riscv-circuit
let root_in = &vars[Tr::ARITY];
let root_out = &vars[Tr::ARITY * 2];
let mem = Tr::ARITY * 2 + 1;

fn build_witness(
cs: CS,
index: usize,
_z: &[FpVar<F1>],
tr: &Tr,
) -> Result<Vec<FpVar<F1>>, SynthesisError> {
let b = tr.block(index);
let mut v = Vec::new();
for w in b {
let rcs = big_step(&w, true);
v = build_witness_partial(cs.clone(), rcs)?;
}
Ok(v)
w.pc_path
.verify_circuit(cs.clone(), &params, root_in, &vars[mem..])?;
w.read_path
.verify_circuit(cs.clone(), &params, root_in, &vars[mem + 2..])?;
w.write_path
.verify_circuit(cs.clone(), &params, root_out, &vars[mem + 4..])?;

Ok(())
}

// slow version
fn build_constraints_partial(
cs: CS,
witness_only: bool,
z: &[FpVar<F1>],
w: &Witness,
rcs: R1CS,
) -> Result<Vec<FpVar<F1>>, SynthesisError> {
let mut vars: Vec<Variable> = Vec::new();
let mut vars: Vec<FpVar<F1>> = Vec::new();
let mut output: Vec<FpVar<F1>> = Vec::new();

for (i, x) in rcs.w.iter().enumerate() {
if rcs.input_range().contains(&i) {
if let FpVar::Var(AllocatedFp { variable, .. }) = z[i - rcs.input_range().start] {
vars.push(variable)
} else {
panic!()
}
let fp = &z[i - rcs.input_range().start];
vars.push(fp.clone())
} else if rcs.output_range().contains(&i) {
let av = AllocatedFp::new_witness(cs.clone(), || Ok(*x))?;
vars.push(av.variable);
output.push(FpVar::Var(av))
let fp = FpVar::Var(AllocatedFp::new_witness(cs.clone(), || Ok(*x))?);
vars.push(fp.clone());
output.push(fp)
} else {
vars.push(cs.new_witness_variable(|| Ok(*x))?)
let fp = FpVar::Var(AllocatedFp::new_witness(cs.clone(), || Ok(*x))?);
vars.push(fp);
}
}

add_paths(cs.clone(), w, &vars)?;

if witness_only {
return Ok(output);
}

let row = |a: &V| {
a.iter().enumerate().fold(
lc!(),
|lc, (i, x)| {
if x == &ZERO {
lc
} else {
lc + (*x, vars[i])
a.iter().enumerate().fold(lc!(), |lc, (i, x)| {
if x == &ZERO {
lc
} else {
match &vars[i] {
FpVar::Constant(f) => lc + (*x * f, Variable::One),
FpVar::Var(av) => lc + (*x, av.variable),
}
},
)
}
})
};

for i in 0..rcs.a.len() {
Expand All @@ -143,41 +137,36 @@ fn build_constraints(
z: &[FpVar<F1>],
tr: &Tr,
) -> Result<Vec<FpVar<F1>>, SynthesisError> {
let witness_only = match cs.borrow().unwrap().mode {
SynthesisMode::Setup => false,
SynthesisMode::Prove { construct_matrices } => !construct_matrices,
};

let b = tr.block(index);
let mut z = z;
let mut v = Vec::new();

for w in b {
let rcs = big_step(&w, false);
v = build_constraints_partial(cs.clone(), z, rcs)?;
v = build_constraints_partial(cs.clone(), witness_only, z, &w, rcs)?;
z = &v;
}

Ok(v)
}

impl StepCircuit<F1> for Tr {
const ARITY: usize = 33;
const ARITY: usize = 34;

fn generate_constraints(
&self,
cs: CS,
k: &FpVar<F1>,
z: &[FpVar<F1>],
) -> Result<Vec<FpVar<F1>>, SynthesisError> {
let matrices = match cs.borrow().unwrap().mode {
SynthesisMode::Setup => true,
SynthesisMode::Prove { construct_matrices } => construct_matrices,
};

let index = k.value().map_or(0, |s| match s.into_bigint() {
BigInt(l) => l[0] as usize,
});

if !matrices {
build_witness(cs, index, z, self)
} else {
build_constraints(cs, index, z, self)
}
build_constraints(cs, index, z, self)
}
}
8 changes: 6 additions & 2 deletions riscv-circuit/src/r1cs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,16 @@ impl R1CS {
n
}

pub fn set_var(&mut self, name: &str, val: u32) -> usize {
pub fn set_field_var(&mut self, name: &str, val: F) -> usize {
let j = self.new_var(name);
self.w[j] = F::from(val);
self.w[j] = val;
j
}

pub fn set_var(&mut self, name: &str, val: u32) -> usize {
self.set_field_var(name, F::from(val))
}

pub fn new_local_var(&mut self, name: &str) -> usize {
if self.vars.contains_key(name) {
panic!("local variable override {name}");
Expand Down
48 changes: 37 additions & 11 deletions riscv-circuit/src/riscv.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,46 @@
//! Circuits for the RISC-V VM (nexus-riscv)
use super::r1cs::*;

use nexus_riscv::vm::memory::path::Path;
use nexus_riscv::vm::trace::*;
use nexus_riscv::rv32::{*, parse::*};

use super::r1cs::*;

// Note: circuit generation code depends on this ordering
// (inputs: pc,x0..31 and then outputs: PC,x'0..31)

#[allow(clippy::field_reassign_with_default)]
#[allow(clippy::needless_range_loop)]
fn init_cs(pc: u32, regs: &[u32; 32]) -> R1CS {
fn init_cs(w: &Witness) -> R1CS {
let mut cs = R1CS::default();
cs.arity = 33;
cs.set_var("pc", pc);
cs.arity = 34;

// inputs
cs.set_var("pc", w.regs.pc);
for i in 0..32 {
cs.set_var(&format!("x{i}"), regs[i]);
cs.set_var(&format!("x{i}"), w.regs.x[i]);
}
cs.set_var("PC", pc);
cs.set_field_var("root", w.pc_path.root);

// outputs
cs.set_var("PC", w.PC);
for i in 0..32 {
cs.set_var(&format!("x'{i}"), regs[i]);
cs.set_var(&format!("x'{i}"), w.regs.x[i]);
}
cs.set_field_var("ROOT", w.write_path.root);

// memory contents
add_path(&mut cs, "pc_path", &w.pc_path);
add_path(&mut cs, "read_path", &w.read_path);
add_path(&mut cs, "write_path", &w.write_path);
cs
}

fn add_path(cs: &mut R1CS, prefix: &str, path: &Path) {
cs.set_field_var(&format!("{}_1", prefix), path.leaf[0]);
cs.set_field_var(&format!("{}_2", prefix), path.leaf[1]);
}

fn select_XY(cs: &mut R1CS, rs1: u32, rs2: u32) {
load_reg(cs, "rs1", "X", rs1);
load_reg(cs, "rs2", "Y", rs2);
Expand Down Expand Up @@ -458,7 +475,7 @@ fn parse_J(cs: &mut R1CS, J: u32) {
}

pub fn big_step(vm: &Witness, witness_only: bool) -> R1CS {
let mut cs = init_cs(vm.regs.pc, &vm.regs.x);
let mut cs = init_cs(vm);
cs.witness_only = witness_only;

select_XY(&mut cs, vm.rs1, vm.rs2);
Expand Down Expand Up @@ -1083,6 +1100,7 @@ fn misc(cs: &mut R1CS) {

#[cfg(test)]
mod test {
use nexus_riscv::vm::eval::Regs;
use nexus_riscv::rv32::parse::*;
use super::*;

Expand Down Expand Up @@ -1160,9 +1178,13 @@ mod test {
#[test]
fn test_select_XY() {
let regs: [u32; 32] = core::array::from_fn(|i| i as u32);
let w = Witness {
regs: Regs { pc: 0, x: regs },
..Witness::default()
};
for x in [0, 1, 2, 31] {
for y in [0, 6, 13] {
let mut cs = init_cs(0, &regs);
let mut cs = init_cs(&w);
select_XY(&mut cs, x, y);
assert!(cs.is_sat());
assert!(cs.get_var("X") == &F::from(x));
Expand All @@ -1174,8 +1196,12 @@ mod test {
#[test]
fn test_select_Z() {
let regs: [u32; 32] = core::array::from_fn(|i| i as u32);
let w = Witness {
regs: Regs { pc: 0, x: regs },
..Witness::default()
};
for i in 0..32 {
let mut cs = init_cs(0, &regs);
let mut cs = init_cs(&w);
let j = cs.new_var("Z");
let z = F::from(100);
cs.w[j] = z;
Expand Down

0 comments on commit 69e9db0

Please sign in to comment.