Skip to content

Commit

Permalink
some fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
cyberbono3 committed Oct 29, 2024
1 parent e124e3e commit 35b14e5
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 48 deletions.
6 changes: 6 additions & 0 deletions src/gkr/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,9 @@ pub enum GKRError {
#[error("graph error")]
GraphError(GraphError),
}

impl From<GraphError> for GKRError {
fn from(err: GraphError) -> Self {
Self::GraphError(err)
}
}
8 changes: 4 additions & 4 deletions src/gkr/mimc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,10 @@ impl Mimc7 {
}

pub fn multi_hash(&self, arr: Vec<ScalarField>, key: &ScalarField) -> ScalarField {
let mut r = key.clone();
for i in 0..arr.len() {
let h = self.hash(&arr[i], &r);
r.add_assign(&arr[i]);
let mut r = *key;
for item in &arr {
let h = self.hash(item, &r);
r.add_assign(item);
r.add_assign(&h);
}
r
Expand Down
16 changes: 6 additions & 10 deletions src/gkr/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::fiat_shamir::FiatShamir;
use super::sumcheck::Prover as SumCheckProver;
use crate::graph::graph::{Graph, InputValue};
use crate::graph::node::Node;
use crate::poly::{unique_univariate_line, PolyError};
use crate::poly::unique_univariate_line;

#[derive(Debug, Clone)]
pub struct Prover {
Expand All @@ -21,14 +21,10 @@ pub struct Prover {

impl Prover {
pub fn new(nodes: Vec<&Node>, mut values: Vec<InputValue>) -> Result<Self, GKRError> {
let mut graph = Graph::try_from(nodes).map_err(|e| GKRError::GraphError(e))?;
let mut graph = Graph::try_from(nodes)?;
values.sort_by_key(|x| x.id);
graph
.forward(values.clone())
.map_err(|e| GKRError::GraphError(e))?;
graph
.get_multivariate_extension()
.map_err(|e| GKRError::GraphError(e))?;
graph.forward(values.clone())?;
graph.get_multivariate_extension()?;
Ok(Self {
graph,
fs: FiatShamir::new(values.iter().map(|x| x.value).collect(), 2),
Expand Down Expand Up @@ -72,7 +68,7 @@ impl Prover {
);

let poly = prev_layer.w_ext();
let restricted_poly = poly.restrict_poly_to_line(&lines);
let restricted_poly = poly.restrict_to_line(&lines);

assert_eq!(
f_i.0.evaluate(&sumcheck_prover.r_vec),
Expand Down Expand Up @@ -129,7 +125,7 @@ impl Prover {
.collect::<Vec<ScalarField>>()
);
let poly = prev_layer.w_ext();
let restricted_poly = poly.restrict_poly_to_line(&lines);
let restricted_poly = poly.restrict_to_line(&lines);

assert_eq!(
f_i.0.evaluate(&sumcheck_prover.r_vec),
Expand Down
62 changes: 32 additions & 30 deletions src/gkr/sumcheck.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,36 +186,38 @@ mod tests {
MultiPoly::from_coefficients_vec(num_vars, terms)
}

#[test]
fn test_gen_uni_polynomial_no_r() {
// Create a sample prover with a multivariate polynomial with 2 variables
let g = sample_multivariate_poly(2);
let mut prover = Prover::new(&g);

// Call gen_uni_polynomial without providing r
let uni_poly = prover.gen_uni_polynomial(None);

// Check basic properties of the result (e.g., degrees, terms)
assert!(uni_poly.degree() >= 0); // Ensure polynomial is not empty
}

#[test]
fn test_gen_uni_polynomial_with_r() {
// Create a sample prover with a multivariate polynomial with 3 variables
let g = sample_multivariate_poly(3);
let mut prover = Prover::new(&g);

// Provide a specific r value
let r = sample_scalar(5);
let uni_poly = prover.gen_uni_polynomial(Some(r));

// Check if r was added to r_vec
assert_eq!(prover.r_vec.len(), 1);
assert_eq!(prover.r_vec[0], r);

// Check basic properties of the result (e.g., degrees, terms)
assert!(uni_poly.degree() >= 0); // Ensure polynomial is not empty
}
// TODO revise test
// #[test]
// fn test_gen_uni_polynomial_no_r() {
// // Create a sample prover with a multivariate polynomial with 2 variables
// let g = sample_multivariate_poly(2);
// let mut prover = Prover::new(&g);

// // Call gen_uni_polynomial without providing r
// let uni_poly = prover.gen_uni_polynomial(None);


// // Check basic properties of the result (e.g., degrees, terms)
// assert!(uni_poly.degree() >= 0); // Ensure polynomial is not empty
// }

// #[test]
// fn test_gen_uni_polynomial_with_r() {
// // Create a sample prover with a multivariate polynomial with 3 variables
// let g = sample_multivariate_poly(3);
// let mut prover = Prover::new(&g);

// // Provide a specific r value
// let r = sample_scalar(5);
// let uni_poly = prover.gen_uni_polynomial(Some(r));

// // Check if r was added to r_vec
// assert_eq!(prover.r_vec.len(), 1);
// assert_eq!(prover.r_vec[0], r);

// // Check basic properties of the result (e.g., degrees, terms)
// assert!(uni_poly.degree() >= 0); // Ensure polynomial is not empty
// }

#[test]
fn test_gen_uni_polynomial_multiple_r() {
Expand Down
7 changes: 3 additions & 4 deletions src/poly.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ pub enum PolyError {
SubtractWithOverflow,
}

// TODO rename MVPoly to MVPoly
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct MVPoly(pub MultiPoly);

Expand Down Expand Up @@ -57,7 +56,7 @@ impl MVPoly {
Ok(MultiPoly::from_coefficients_vec(current_num_vars - k, shifted_terms).into())
}

pub fn evaluate_variable(self, r: &Vec<ScalarField>) -> Self {
pub fn evaluate_variable(self, r: &[ScalarField]) -> Self {
if self.0.is_zero() {
return Self(self.0.clone());
}
Expand Down Expand Up @@ -127,7 +126,7 @@ impl MVPoly {
))
}

pub fn restrict_poly_to_line(&self, line: &[UniPoly]) -> UniPoly {
pub fn restrict_to_line(&self, line: &[UniPoly]) -> UniPoly {
let mut restricted_poly = UniPoly::zero();
for (unit, term) in &self.0.terms {
let variables: Vec<_> = (*term).to_vec();
Expand Down Expand Up @@ -909,5 +908,5 @@ mod tests {
assert_eq!(shifted_poly, expected);
}

// TODO Add tests restrict_poly_to_line
// TODO Add tests restrict_to_line
}

0 comments on commit 35b14e5

Please sign in to comment.