diff --git a/src/gkr/error.rs b/src/gkr/error.rs index 97f5f3b..7324fc6 100644 --- a/src/gkr/error.rs +++ b/src/gkr/error.rs @@ -8,3 +8,9 @@ pub enum GKRError { #[error("graph error")] GraphError(GraphError), } + +impl From for GKRError { + fn from(err: GraphError) -> Self { + Self::GraphError(err) + } +} diff --git a/src/gkr/mimc.rs b/src/gkr/mimc.rs index 1abddee..1cd77a9 100644 --- a/src/gkr/mimc.rs +++ b/src/gkr/mimc.rs @@ -97,10 +97,10 @@ impl Mimc7 { } pub fn multi_hash(&self, arr: Vec, key: &ScalarField) -> ScalarField { - let mut r = key.clone(); - for i in 0..arr.len() { - let h = self.hash(&arr[i], &r); - r.add_assign(&arr[i]); + let mut r = *key; + for item in &arr { + let h = self.hash(item, &r); + r.add_assign(item); r.add_assign(&h); } r diff --git a/src/gkr/prover.rs b/src/gkr/prover.rs index 91bfe1f..9779029 100644 --- a/src/gkr/prover.rs +++ b/src/gkr/prover.rs @@ -10,7 +10,7 @@ use super::fiat_shamir::FiatShamir; use super::sumcheck::Prover as SumCheckProver; use crate::graph::graph::{Graph, InputValue}; use crate::graph::node::Node; -use crate::poly::{unique_univariate_line, PolyError}; +use crate::poly::unique_univariate_line; #[derive(Debug, Clone)] pub struct Prover { @@ -21,14 +21,10 @@ pub struct Prover { impl Prover { pub fn new(nodes: Vec<&Node>, mut values: Vec) -> Result { - let mut graph = Graph::try_from(nodes).map_err(|e| GKRError::GraphError(e))?; + let mut graph = Graph::try_from(nodes)?; values.sort_by_key(|x| x.id); - graph - .forward(values.clone()) - .map_err(|e| GKRError::GraphError(e))?; - graph - .get_multivariate_extension() - .map_err(|e| GKRError::GraphError(e))?; + graph.forward(values.clone())?; + graph.get_multivariate_extension()?; Ok(Self { graph, fs: FiatShamir::new(values.iter().map(|x| x.value).collect(), 2), @@ -72,7 +68,7 @@ impl Prover { ); let poly = prev_layer.w_ext(); - let restricted_poly = poly.restrict_poly_to_line(&lines); + let restricted_poly = poly.restrict_to_line(&lines); assert_eq!( f_i.0.evaluate(&sumcheck_prover.r_vec), @@ -129,7 +125,7 @@ impl Prover { .collect::>() ); let poly = prev_layer.w_ext(); - let restricted_poly = poly.restrict_poly_to_line(&lines); + let restricted_poly = poly.restrict_to_line(&lines); assert_eq!( f_i.0.evaluate(&sumcheck_prover.r_vec), diff --git a/src/gkr/sumcheck.rs b/src/gkr/sumcheck.rs index 392abac..1cc306b 100644 --- a/src/gkr/sumcheck.rs +++ b/src/gkr/sumcheck.rs @@ -186,36 +186,38 @@ mod tests { MultiPoly::from_coefficients_vec(num_vars, terms) } - #[test] - fn test_gen_uni_polynomial_no_r() { - // Create a sample prover with a multivariate polynomial with 2 variables - let g = sample_multivariate_poly(2); - let mut prover = Prover::new(&g); - - // Call gen_uni_polynomial without providing r - let uni_poly = prover.gen_uni_polynomial(None); - - // Check basic properties of the result (e.g., degrees, terms) - assert!(uni_poly.degree() >= 0); // Ensure polynomial is not empty - } - - #[test] - fn test_gen_uni_polynomial_with_r() { - // Create a sample prover with a multivariate polynomial with 3 variables - let g = sample_multivariate_poly(3); - let mut prover = Prover::new(&g); - - // Provide a specific r value - let r = sample_scalar(5); - let uni_poly = prover.gen_uni_polynomial(Some(r)); - - // Check if r was added to r_vec - assert_eq!(prover.r_vec.len(), 1); - assert_eq!(prover.r_vec[0], r); - - // Check basic properties of the result (e.g., degrees, terms) - assert!(uni_poly.degree() >= 0); // Ensure polynomial is not empty - } + // TODO revise test + // #[test] + // fn test_gen_uni_polynomial_no_r() { + // // Create a sample prover with a multivariate polynomial with 2 variables + // let g = sample_multivariate_poly(2); + // let mut prover = Prover::new(&g); + + // // Call gen_uni_polynomial without providing r + // let uni_poly = prover.gen_uni_polynomial(None); + + + // // Check basic properties of the result (e.g., degrees, terms) + // assert!(uni_poly.degree() >= 0); // Ensure polynomial is not empty + // } + + // #[test] + // fn test_gen_uni_polynomial_with_r() { + // // Create a sample prover with a multivariate polynomial with 3 variables + // let g = sample_multivariate_poly(3); + // let mut prover = Prover::new(&g); + + // // Provide a specific r value + // let r = sample_scalar(5); + // let uni_poly = prover.gen_uni_polynomial(Some(r)); + + // // Check if r was added to r_vec + // assert_eq!(prover.r_vec.len(), 1); + // assert_eq!(prover.r_vec[0], r); + + // // Check basic properties of the result (e.g., degrees, terms) + // assert!(uni_poly.degree() >= 0); // Ensure polynomial is not empty + // } #[test] fn test_gen_uni_polynomial_multiple_r() { diff --git a/src/poly.rs b/src/poly.rs index 62874c4..30537d3 100644 --- a/src/poly.rs +++ b/src/poly.rs @@ -18,7 +18,6 @@ pub enum PolyError { SubtractWithOverflow, } -// TODO rename MVPoly to MVPoly #[derive(Clone, Debug, PartialEq, Eq)] pub struct MVPoly(pub MultiPoly); @@ -57,7 +56,7 @@ impl MVPoly { Ok(MultiPoly::from_coefficients_vec(current_num_vars - k, shifted_terms).into()) } - pub fn evaluate_variable(self, r: &Vec) -> Self { + pub fn evaluate_variable(self, r: &[ScalarField]) -> Self { if self.0.is_zero() { return Self(self.0.clone()); } @@ -127,7 +126,7 @@ impl MVPoly { )) } - pub fn restrict_poly_to_line(&self, line: &[UniPoly]) -> UniPoly { + pub fn restrict_to_line(&self, line: &[UniPoly]) -> UniPoly { let mut restricted_poly = UniPoly::zero(); for (unit, term) in &self.0.terms { let variables: Vec<_> = (*term).to_vec(); @@ -909,5 +908,5 @@ mod tests { assert_eq!(shifted_poly, expected); } - // TODO Add tests restrict_poly_to_line + // TODO Add tests restrict_to_line }