diff --git a/.github/workflows/combine-prs.yml b/.github/workflows/combine-prs.yml
new file mode 100644
index 000000000..cc3d2a986
--- /dev/null
+++ b/.github/workflows/combine-prs.yml
@@ -0,0 +1,24 @@
+name: Combine PRs
+
+on:
+  schedule:
+    - cron: "0 1 * * MON"
+  workflow_dispatch: # allows manual triggering of the workflow
+
+# The minimum permissions required to run this Action
+permissions:
+  contents: write
+  pull-requests: write
+  checks: read
+
+jobs:
+  combine-prs:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: combine-prs
+        id: combine-prs
+        uses: github/combine-prs@v3.1.1
+        with:
+          github_token: ${{ secrets.ORG_GITHUB_PAT }}
+          labels: "dependabot,combined-pr"
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1f8f4e68b..75573a698 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -33,6 +33,7 @@ and follow [semantic versioning](https://semver.org/) for our releases.
 - [#238](https://github.com/EspressoSystems/jellyfish/pull/238) add public keys into signature aggregation APIs
 - [#251](https://github.com/EspressoSystems/jellyfish/pull/251) add sign_key_ref api for BLSKeyPair
 - [#297](https://github.com/EspressoSystems/jellyfish/pull/297) Updated `tagged-base64` dependency to the `crates.io` package
+- [#299](https://github.com/EspressoSystems/jellyfish/pull/299) For Merkle tree, `DigestAlgorithm` now returns a `Result` type.
 
 ### Removed
diff --git a/primitives/src/aead.rs b/primitives/src/aead.rs
index 5bf1534b5..66989bde7 100644
--- a/primitives/src/aead.rs
+++ b/primitives/src/aead.rs
@@ -11,6 +11,7 @@
 //! independent of RustCrypto's upstream changes.
 
 use crate::errors::PrimitivesError;
+use ark_serialize::*;
 use ark_std::{
     fmt, format,
     ops::{Deref, DerefMut},
@@ -124,7 +125,9 @@ impl fmt::Debug for DecKey {
 }
 
 /// Keypair for Authenticated Encryption with Associated Data
-#[derive(Clone, Debug, Default, Serialize, Deserialize)]
+#[derive(
+    Clone, Debug, Default, Serialize, Deserialize, CanonicalSerialize, CanonicalDeserialize,
+)]
 pub struct KeyPair {
     enc_key: EncKey,
     dec_key: DecKey,
@@ -240,13 +243,127 @@ impl DerefMut for Nonce {
 }
 
 /// The ciphertext produced by AEAD encryption
-#[derive(Clone, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
+#[derive(
+    Clone,
+    Debug,
+    PartialEq,
+    Eq,
+    Hash,
+    Serialize,
+    Deserialize,
+    CanonicalSerialize,
+    CanonicalDeserialize,
+)]
 pub struct Ciphertext {
     nonce: Nonce,
     ct: Vec<u8>,
     ephemeral_pk: EncKey,
 }
 
+// TODO: (alex) Temporarily add CanonicalSerde back to these structs due to the
+// limitations of `tagged` proc macro and requests from downstream usage.
+// Tracking issue:
+mod canonical_serde {
+    use super::*;
+
+    impl CanonicalSerialize for EncKey {
+        fn serialize_with_mode<W: Write>(
+            &self,
+            mut writer: W,
+            _compress: Compress,
+        ) -> Result<(), SerializationError> {
+            let bytes: [u8; crypto_kx::PublicKey::BYTES] = self.clone().into();
+            writer.write_all(&bytes)?;
+            Ok(())
+        }
+        fn serialized_size(&self, _compress: Compress) -> usize {
+            crypto_kx::PublicKey::BYTES
+        }
+    }
+
+    impl CanonicalDeserialize for EncKey {
+        fn deserialize_with_mode<R: Read>(
+            mut reader: R,
+            _compress: Compress,
+            _validate: Validate,
+        ) -> Result<Self, SerializationError> {
+            let mut result = [0u8; crypto_kx::PublicKey::BYTES];
+            reader.read_exact(&mut result)?;
+            Ok(EncKey(crypto_kx::PublicKey::from(result)))
+        }
+    }
+
+    impl Valid for EncKey {
+        fn check(&self) -> Result<(), SerializationError> {
+            Ok(())
+        }
+    }
+
+    impl CanonicalSerialize for DecKey {
+        fn serialize_with_mode<W: Write>(
+            &self,
+            mut writer: W,
+            _compress: Compress,
+        ) -> Result<(), SerializationError> {
+            let bytes: [u8; crypto_kx::SecretKey::BYTES] = self.clone().into();
+            writer.write_all(&bytes)?;
+            Ok(())
+        }
+        fn serialized_size(&self, _compress: Compress) -> usize {
+            crypto_kx::SecretKey::BYTES
+        }
+    }
+
+    impl CanonicalDeserialize for DecKey {
+        fn deserialize_with_mode<R: Read>(
+            mut reader: R,
+            _compress: Compress,
+            _validate: Validate,
+        ) -> Result<Self, SerializationError> {
+            let mut result = [0u8; crypto_kx::SecretKey::BYTES];
+            reader.read_exact(&mut result)?;
+            Ok(DecKey(crypto_kx::SecretKey::from(result)))
+        }
+    }
+    impl Valid for DecKey {
+        fn check(&self) -> Result<(), SerializationError> {
+            Ok(())
+        }
+    }
+
+    impl CanonicalSerialize for Nonce {
+        fn serialize_with_mode<W: Write>(
+            &self,
+            mut writer: W,
+            _compress: Compress,
+        ) -> Result<(), SerializationError> {
+            writer.write_all(self.0.as_slice())?;
+            Ok(())
+        }
+        fn serialized_size(&self, _compress: Compress) -> usize {
+            // see
+            24
+        }
+    }
+
+    impl CanonicalDeserialize for Nonce {
+        fn deserialize_with_mode<R: Read>(
+            mut reader: R,
+            _compress: Compress,
+            _validate: Validate,
+        ) -> Result<Self, SerializationError> {
+            let mut result = [0u8; 24];
+            reader.read_exact(&mut result)?;
+            Ok(Nonce(XNonce::from(result)))
+        }
+    }
+    impl Valid for Nonce {
+        fn check(&self) -> Result<(), SerializationError> {
+            Ok(())
+        }
+    }
+}
+
 #[cfg(test)]
 mod test {
     use super::*;
@@ -322,4 +439,30 @@ mod test {
         // wrong byte length
         assert!(bincode::deserialize::<KeyPair>(&bytes[1..]).is_err());
     }
+
+    #[test]
+    fn test_canonical_serde() {
+        let mut rng = jf_utils::test_rng();
+        let keypair = KeyPair::generate(&mut rng);
+        let msg = b"The quick brown fox jumps over the lazy dog".to_vec();
+        let aad = b"my associated data".to_vec();
+        let ciphertext = keypair.enc_key.encrypt(&mut rng, &msg, &aad).unwrap();
+
+        // testing the keypair already exercises serde on the pk and sk
+        let mut bytes = Vec::new();
+        CanonicalSerialize::serialize_compressed(&keypair, &mut bytes).unwrap();
+        assert_eq!(
+            keypair,
+            KeyPair::deserialize_compressed(&bytes[..]).unwrap()
+        );
+        assert!(KeyPair::deserialize_compressed(&bytes[1..]).is_err());
+
+        let mut bytes = Vec::new();
+        CanonicalSerialize::serialize_compressed(&ciphertext, &mut bytes).unwrap();
+        assert_eq!(
+            ciphertext,
+            Ciphertext::deserialize_compressed(&bytes[..]).unwrap()
+        );
+        assert!(Ciphertext::deserialize_compressed(&bytes[1..]).is_err());
+    }
 }
diff --git a/primitives/src/merkle_tree/examples.rs b/primitives/src/merkle_tree/examples.rs
index d4e4db274..00b032dde 100644
--- a/primitives/src/merkle_tree/examples.rs
+++ b/primitives/src/merkle_tree/examples.rs
@@ -8,12 +8,16 @@
 //! E.g. Sparse merkle tree with BigUInt index.
 
 use super::{append_only::MerkleTree, prelude::RescueHash, DigestAlgorithm, Element, Index};
-use crate::rescue::{sponge::RescueCRHF, RescueParameter};
+use crate::{
+    errors::PrimitivesError,
+    rescue::{sponge::RescueCRHF, RescueParameter},
+};
 use ark_ff::Field;
 use ark_serialize::{
     CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid,
     Validate, Write,
 };
+use ark_std::vec::Vec;
 use sha3::{Digest, Sha3_256};
 use typenum::U3;
 
@@ -23,13 +27,13 @@ pub struct Interval<F: Field>(pub F, pub F);
 
 // impl<F: Field> Element for Interval<F> {}
 
 impl<F: RescueParameter> DigestAlgorithm<Interval<F>, u64, F> for RescueHash<F> {
-    fn digest(data: &[F]) -> F {
-        RescueCRHF::<F>::sponge_no_padding(data, 1).unwrap()[0]
+    fn digest(data: &[F]) -> Result<F, PrimitivesError> {
+        Ok(RescueCRHF::<F>::sponge_no_padding(data, 1)?[0])
     }
 
-    fn digest_leaf(pos: &u64, elem: &Interval<F>) -> F {
+    fn digest_leaf(pos: &u64, elem: &Interval<F>) -> Result<F, PrimitivesError> {
         let data = [F::from(*pos), elem.0, elem.1];
-        RescueCRHF::<F>::sponge_no_padding(&data, 1).unwrap()[0]
+        Ok(RescueCRHF::<F>::sponge_no_padding(&data, 1)?[0])
     }
 }
@@ -39,7 +43,7 @@ pub type IntervalMerkleTree<F> = MerkleTree<Interval<F>, RescueHash<F>, u64, U3, F>;
 
 /// Update the array length here
 #[derive(Default, Eq, PartialEq, Clone, Copy, Debug, Ord, PartialOrd, Hash)]
-pub struct Sha3Node([u8; 32]);
+pub struct Sha3Node(pub(crate) [u8; 32]);
 
 impl AsRef<[u8]> for Sha3Node {
     fn as_ref(&self) -> &[u8] {
@@ -82,18 +86,21 @@ impl Valid for Sha3Node {
 
 /// Wrapper for the SHA3_256 hash function
 pub struct Sha3Digest();
 
-impl<E: Element, I: Index> DigestAlgorithm<E, I, Sha3Node> for Sha3Digest {
-    fn digest(data: &[Sha3Node]) -> Sha3Node {
+impl<E: Element + CanonicalSerialize, I: Index> DigestAlgorithm<E, I, Sha3Node> for Sha3Digest {
+    fn digest(data: &[Sha3Node]) -> Result<Sha3Node, PrimitivesError> {
         let mut hasher = Sha3_256::new();
         for value in data {
             hasher.update(value);
         }
-        Sha3Node(hasher.finalize().into())
+        Ok(Sha3Node(hasher.finalize().into()))
     }
 
-    fn digest_leaf(_pos: &I, _elem: &E) -> Sha3Node {
-        // Serialize and hash
-        todo!()
+    fn digest_leaf(_pos: &I, elem: &E) -> Result<Sha3Node, PrimitivesError> {
+        let mut writer = Vec::new();
+        elem.serialize_compressed(&mut writer).unwrap();
+        let mut hasher = Sha3_256::new();
+        hasher.update(writer);
+        Ok(Sha3Node(hasher.finalize().into()))
     }
 }
diff --git a/primitives/src/merkle_tree/hasher.rs b/primitives/src/merkle_tree/hasher.rs
index 2248f4959..e97d9a36a 100644
--- a/primitives/src/merkle_tree/hasher.rs
+++ b/primitives/src/merkle_tree/hasher.rs
@@ -35,6 +35,8 @@
 //! Use [`GenericHasherMerkleTree`] if you prefer to specify your own `Arity`
 //! and node [`Index`] types.
 
+use crate::errors::PrimitivesError;
+
 use super::{append_only::MerkleTree, DigestAlgorithm, Element, Index};
 use ark_serialize::{
     CanonicalDeserialize, CanonicalSerialize, Compress, Read, SerializationError, Valid,
     Validate, Write,
 };
@@ -75,21 +77,19 @@ where
     H: Digest + Write,
     <<H as Digest>::OutputSize as ArrayLength<u8>>::ArrayType: Copy,
 {
-    fn digest(data: &[HasherNode<H>]) -> HasherNode<H> {
+    fn digest(data: &[HasherNode<H>]) -> Result<HasherNode<H>, PrimitivesError> {
         let mut hasher = H::new();
         for value in data {
             hasher.update(value.as_ref());
         }
-        HasherNode(hasher.finalize())
+        Ok(HasherNode(hasher.finalize()))
     }
 
-    fn digest_leaf(pos: &I, elem: &E) -> HasherNode<H> {
+    fn digest_leaf(pos: &I, elem: &E) -> Result<HasherNode<H>, PrimitivesError> {
         let mut hasher = H::new();
-        pos.serialize_uncompressed(&mut hasher)
-            .expect("serialize should succeed");
-        elem.serialize_uncompressed(&mut hasher)
-            .expect("serialize should succeed");
-        HasherNode(hasher.finalize())
+        pos.serialize_uncompressed(&mut hasher)?;
+        elem.serialize_uncompressed(&mut hasher)?;
+        Ok(HasherNode(hasher.finalize()))
     }
 }
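The two hunks above establish the migration pattern for every `DigestAlgorithm` implementor in this PR: return `Result<T, PrimitivesError>` and replace `unwrap`/`expect` with `?`. A minimal sketch of what a custom implementation looks like under the new signatures; `MyDigest` is a hypothetical name, the code is written as if it lived inside `primitives/src/merkle_tree/` (the `Sha3Node` constructor is `pub(crate)`), and the `?` on serialization relies on the crate's existing `From<SerializationError>` conversion for `PrimitivesError`, the same conversion the `hasher.rs` hunk uses:

```rust
use super::{examples::Sha3Node, DigestAlgorithm, Element, Index};
use crate::errors::PrimitivesError;
use ark_serialize::CanonicalSerialize;
use ark_std::vec::Vec;
use sha3::{Digest, Sha3_256};

/// Hypothetical alternative hasher reusing the `Sha3Node` node type above.
pub struct MyDigest;

impl<E: Element + CanonicalSerialize, I: Index> DigestAlgorithm<E, I, Sha3Node> for MyDigest {
    fn digest(data: &[Sha3Node]) -> Result<Sha3Node, PrimitivesError> {
        // Hashing child nodes cannot fail here, but the fallible signature
        // lets other backends (e.g. the Rescue sponge) report errors.
        let mut hasher = Sha3_256::new();
        for value in data {
            hasher.update(value.as_ref());
        }
        Ok(Sha3Node(hasher.finalize().into()))
    }

    fn digest_leaf(pos: &I, elem: &E) -> Result<Sha3Node, PrimitivesError> {
        // A failed serialization now surfaces as `Err` for the caller,
        // where the old `-> T` API had to panic via `unwrap`/`expect`.
        let mut bytes = Vec::new();
        pos.serialize_compressed(&mut bytes)?;
        elem.serialize_compressed(&mut bytes)?;
        let mut hasher = Sha3_256::new();
        hasher.update(&bytes);
        Ok(Sha3Node(hasher.finalize().into()))
    }
}
```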
diff --git a/primitives/src/merkle_tree/internal.rs b/primitives/src/merkle_tree/internal.rs
index b944c2d4e..c48ce0812 100644
--- a/primitives/src/merkle_tree/internal.rs
+++ b/primitives/src/merkle_tree/internal.rs
@@ -203,20 +203,20 @@ where
             let children = chunk
                 .map(|(pos, elem)| {
                     let pos = I::from(pos as u64);
-                    Box::new(MerkleNode::Leaf {
-                        value: H::digest_leaf(&pos, elem.borrow()),
+                    Ok(Box::new(MerkleNode::Leaf {
+                        value: H::digest_leaf(&pos, elem.borrow())?,
                         pos,
                         elem: elem.borrow().clone(),
-                    })
+                    }))
                })
-                .pad_using(Arity::to_usize(), |_| Box::new(MerkleNode::Empty))
-                .collect_vec();
-            Box::new(MerkleNode::<E, I, T>::Branch {
-                value: digest_branch::<E, H, T, I>(&children),
+                .pad_using(Arity::to_usize(), |_| Ok(Box::new(MerkleNode::Empty)))
+                .collect::<Result<Vec<_>, PrimitivesError>>()?;
+            Ok(Box::new(MerkleNode::<E, I, T>::Branch {
+                value: digest_branch::<E, H, T, I>(&children)?,
                 children,
-            })
+            }))
         })
-        .collect_vec();
+        .collect::<Result<Vec<_>, PrimitivesError>>()?;
     for _ in 1..height {
         cur_nodes = cur_nodes
             .into_iter()
@@ -227,13 +227,13 @@ where
                 .pad_using(Arity::to_usize(), |_| {
                     Box::new(MerkleNode::<E, I, T>::Empty)
                 })
-                .collect_vec();
-            Box::new(MerkleNode::<E, I, T>::Branch {
-                value: digest_branch::<E, H, T, I>(&children),
+                .collect::<Vec<_>>();
+            Ok(Box::new(MerkleNode::<E, I, T>::Branch {
+                value: digest_branch::<E, H, T, I>(&children)?,
                 children,
-            })
+            }))
         })
-        .collect_vec();
+        .collect::<Result<Vec<_>, PrimitivesError>>()?;
     }
     Ok((cur_nodes[0].clone(), num_leaves))
 } else {
@@ -272,27 +272,27 @@ where
         .map(|chunk| {
             let children = chunk
                 .map(|(pos, elem)| {
-                    if (pos as u64) < num_leaves - 1 {
+                    Ok(if (pos as u64) < num_leaves - 1 {
                         Box::new(MerkleNode::ForgettenSubtree {
-                            value: H::digest_leaf(&I::from(pos as u64), elem.borrow()),
+                            value: H::digest_leaf(&I::from(pos as u64), elem.borrow())?,
                         })
                     } else {
                         let pos = I::from(pos as u64);
                         Box::new(MerkleNode::Leaf {
-                            value: H::digest_leaf(&pos, elem.borrow()),
+                            value: H::digest_leaf(&pos, elem.borrow())?,
                             pos,
                             elem: elem.borrow().clone(),
                         })
-                    }
+                    })
                })
-                .pad_using(Arity::to_usize(), |_| Box::new(MerkleNode::Empty))
-                .collect_vec();
-            Box::new(MerkleNode::<E, I, T>::Branch {
-                value: digest_branch::<E, H, T, I>(&children),
+                .pad_using(Arity::to_usize(), |_| Ok(Box::new(MerkleNode::Empty)))
+                .collect::<Result<Vec<_>, PrimitivesError>>()?;
+            Ok(Box::new(MerkleNode::<E, I, T>::Branch {
+                value: digest_branch::<E, H, T, I>(&children)?,
                 children,
-            })
+            }))
         })
-        .collect_vec();
+        .collect::<Result<Vec<_>, PrimitivesError>>()?;
     for i in 1..cur_nodes.len() - 1 {
         cur_nodes[i] = Box::new(MerkleNode::ForgettenSubtree {
             value: cur_nodes[i].value(),
@@ -308,13 +308,13 @@ where
                 .pad_using(Arity::to_usize(), |_| {
                     Box::new(MerkleNode::<E, I, T>::Empty)
                 })
-                .collect_vec();
-            Box::new(MerkleNode::<E, I, T>::Branch {
-                value: digest_branch::<E, H, T, I>(&children),
+                .collect::<Vec<_>>();
+            Ok(Box::new(MerkleNode::<E, I, T>::Branch {
+                value: digest_branch::<E, H, T, I>(&children)?,
                 children,
-            })
+            }))
        })
-        .collect_vec();
+        .collect::<Result<Vec<_>, PrimitivesError>>()?;
     for i in 1..cur_nodes.len() - 1 {
         cur_nodes[i] = Box::new(MerkleNode::ForgettenSubtree {
             value: cur_nodes[i].value(),
@@ -327,7 +327,9 @@ where
     }
 }
 
-pub(crate) fn digest_branch<E, H, T, I>(data: &[Box<MerkleNode<E, I, T>>]) -> T
+pub(crate) fn digest_branch<E, H, T, I>(
+    data: &[Box<MerkleNode<E, I, T>>],
+) -> Result<T, PrimitivesError>
 where
     E: Element,
     H: DigestAlgorithm<E, I, T>,
@@ -335,7 +337,7 @@ where
     T: NodeValue,
 {
     // Question(Chengyu): any more efficient implementation?
-    let data = data.iter().map(|node| node.value()).collect_vec();
+    let data = data.iter().map(|node| node.value()).collect::<Vec<_>>();
     H::digest(&data)
 }
 
@@ -372,7 +374,7 @@ where
                         })
                     }
                 })
-                .collect_vec(),
+                .collect::<Vec<_>>(),
         });
         if children.iter().all(|child| {
             matches!(
@@ -399,7 +401,7 @@ where
                         })
                     }
                 })
-                .collect_vec(),
+                .collect::<Vec<_>>(),
         });
         LookupResult::NotFound(non_membership_proof)
     },
@@ -506,7 +508,7 @@ where
                         })
                     }
                 })
-                .collect_vec(),
+                .collect::<Vec<_>>(),
         });
         LookupResult::Ok(elem, proof)
     },
@@ -525,7 +527,7 @@ where
                        })
                     }
                 })
-                .collect_vec(),
+                .collect::<Vec<_>>(),
         });
         LookupResult::NotFound(non_membership_proof)
     },
@@ -547,7 +549,7 @@ where
         pos: impl Borrow<I>,
         traversal_path: &[usize],
         elem: impl Borrow<E>,
-    ) -> LookupResult<E, (), ()>
+    ) -> Result<LookupResult<E, (), ()>, PrimitivesError>
     where
         H: DigestAlgorithm<E, I, T>,
         Arity: Unsigned,
@@ -561,8 +563,8 @@ where
             pos,
         } => {
             let ret = ark_std::mem::replace(node_elem, elem.clone());
-            *value = H::digest_leaf(pos, elem);
-            LookupResult::Ok(ret, ())
+            *value = H::digest_leaf(pos, elem)?;
+            Ok(LookupResult::Ok(ret, ()))
         },
         MerkleNode::Branch { value, children } => {
             let res = (*children[traversal_path[height - 1]]).update_internal::<H, Arity>(
@@ -570,22 +572,22 @@ where
                 pos,
                 traversal_path,
                 elem,
-            );
+            )?;
             // If the branch containing the update was not in memory, the update failed and
             // nothing was changed, so we can short-circuit without recomputing this node's
             // value.
             if res == LookupResult::NotInMemory {
-                return res;
+                return Ok(res);
             }
             // Otherwise, an entry has been updated and the value of one of our children has
             // changed, so we must recompute our own value.
-            *value = digest_branch::<E, H, T, I>(children);
-            res
+            *value = digest_branch::<E, H, T, I>(children)?;
+            Ok(res)
         },
         MerkleNode::Empty => {
             *self = if height == 0 {
                 MerkleNode::Leaf {
-                    value: H::digest_leaf(pos, elem),
+                    value: H::digest_leaf(pos, elem)?,
                     pos: pos.clone(),
                     elem: elem.clone(),
                 }
@@ -596,15 +598,15 @@ where
                     pos,
                     traversal_path,
                     elem,
-                );
+                )?;
                 MerkleNode::Branch {
-                    value: digest_branch::<E, H, T, I>(&children),
+                    value: digest_branch::<E, H, T, I>(&children)?,
                     children,
                 }
             };
-            LookupResult::NotFound(())
+            Ok(LookupResult::NotFound(()))
         },
-        MerkleNode::ForgettenSubtree { .. } => LookupResult::NotInMemory,
+        MerkleNode::ForgettenSubtree { .. } => Ok(LookupResult::NotInMemory),
     }
 }
@@ -647,7 +649,7 @@ where
                 cur_pos += I::from(increment);
                 frontier += 1;
             }
-            *value = digest_branch::<E, H, T, I>(children);
+            *value = digest_branch::<E, H, T, I>(children)?;
             Ok(cnt)
         },
         MerkleNode::Empty => {
@@ -655,7 +657,7 @@ where
                 let elem = data.next().unwrap();
                 let elem = elem.borrow();
                 *self = MerkleNode::Leaf {
-                    value: H::digest_leaf(pos, elem),
+                    value: H::digest_leaf(pos, elem)?,
                     pos: pos.clone(),
                     elem: elem.clone(),
                 };
@@ -682,7 +684,7 @@ where
                     frontier += 1;
                 }
                 *self = MerkleNode::Branch {
-                    value: digest_branch::<E, H, T, I>(&children),
+                    value: digest_branch::<E, H, T, I>(&children)?,
                     children,
                 };
                 Ok(cnt)
@@ -743,7 +745,7 @@ where
                 cur_pos += I::from(increment);
                 frontier += 1;
             }
-            *value = digest_branch::<E, H, T, I>(children);
+            *value = digest_branch::<E, H, T, I>(children)?;
             Ok(cnt)
         },
         MerkleNode::Empty => {
@@ -751,7 +753,7 @@ where
                 let elem = data.next().unwrap();
                 let elem = elem.borrow();
                 *self = MerkleNode::Leaf {
-                    value: H::digest_leaf(pos, elem),
+                    value: H::digest_leaf(pos, elem)?,
                     pos: pos.clone(),
                     elem: elem.clone(),
                 };
@@ -784,7 +786,7 @@ where
                     frontier += 1;
                 }
                 *self = MerkleNode::Branch {
-                    value: digest_branch::<E, H, T, I>(&children),
+                    value: digest_branch::<E, H, T, I>(&children)?,
                     children,
                 };
                 Ok(cnt)
@@ -823,7 +825,7 @@ where
             elem,
         } = &self.proof[0]
         {
-            let init = H::digest_leaf(pos, elem);
+            let init = H::digest_leaf(pos, elem)?;
             let computed_root = self
                 .pos
                 .to_traversal_path(self.tree_height() - 1)
@@ -835,10 +837,12 @@ where
                     match result {
                         Ok(val) => match node {
                             MerkleNode::Branch { value: _, children } => {
-                                let mut data =
-                                    children.iter().map(|node| node.value()).collect_vec();
+                                let mut data = children
+                                    .iter()
+                                    .map(|node| node.value())
+                                    .collect::<Vec<_>>();
                                 data[*branch] = val;
-                                Ok(H::digest(&data))
+                                H::digest(&data)
                             },
                             _ => Err(PrimitivesError::ParameterError(
                                 "Incompatible proof for this merkle tree".to_string(),
@@ -882,10 +886,12 @@ where
                     match result {
                         Ok(val) => match node {
                             MerkleNode::Branch { value: _, children } => {
-                                let mut data =
-                                    children.iter().map(|node| node.value()).collect_vec();
+                                let mut data = children
+                                    .iter()
+                                    .map(|node| node.value())
+                                    .collect::<Vec<_>>();
                                 data[*branch] = val;
-                                Ok(H::digest(&data))
+                                H::digest(&data)
                             },
                             MerkleNode::Empty => Ok(init),
                             _ => Err(PrimitivesError::ParameterError(
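Nearly all of the churn in `internal.rs` above is one mechanical pattern: `collect_vec()` over an iterator of nodes becomes `collect::<Result<Vec<_>, PrimitivesError>>()?` over an iterator of `Result`s, which short-circuits on the first failing digest. A self-contained illustration of that standard-library behavior (toy function using only std types, not from this PR):

```rust
use std::num::ParseIntError;

// `collect` into `Result<Vec<_>, E>` stops at the first `Err`, exactly how
// the tree builders above abort on the first failing `digest_leaf`.
fn parse_all(inputs: &[&str]) -> Result<Vec<u64>, ParseIntError> {
    inputs
        .iter()
        .map(|s| s.parse::<u64>()) // Iterator<Item = Result<u64, ParseIntError>>
        .collect::<Result<Vec<_>, ParseIntError>>()
}

fn main() {
    assert_eq!(parse_all(&["1", "2"]).unwrap(), vec![1, 2]);
    assert!(parse_all(&["1", "oops"]).is_err());
}
```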
diff --git a/primitives/src/merkle_tree/macros.rs b/primitives/src/merkle_tree/macros.rs
index 9f0c6e13a..9f122636c 100644
--- a/primitives/src/merkle_tree/macros.rs
+++ b/primitives/src/merkle_tree/macros.rs
@@ -167,7 +167,7 @@ macro_rules! impl_forgetable_merkle_tree_scheme {
                     "Element does not match the proof.".to_string(),
                 ));
             }
-            let proof_leaf_value = H::digest_leaf(pos, elem);
+            let proof_leaf_value = H::digest_leaf(pos, elem)?;
             let mut path_values = vec![proof_leaf_value];
             traversal_path.iter().zip(proof.proof.iter().skip(1)).fold(
                 Ok(proof_leaf_value),
@@ -178,7 +178,7 @@ macro_rules! impl_forgetable_merkle_tree_scheme {
                             let mut data: Vec<_> =
                                 children.iter().map(|node| node.value()).collect();
                             data[*branch] = val;
-                            let digest = H::digest(&data);
+                            let digest = H::digest(&data)?;
                             path_values.push(digest);
                             Ok(digest)
                         },
diff --git a/primitives/src/merkle_tree/mod.rs b/primitives/src/merkle_tree/mod.rs
index a322d433e..46b4d4e65 100644
--- a/primitives/src/merkle_tree/mod.rs
+++ b/primitives/src/merkle_tree/mod.rs
@@ -10,6 +10,7 @@ pub mod examples;
 pub mod hasher;
 pub mod light_weight;
 pub mod macros;
+pub mod namespaced_merkle_tree;
 pub mod universal_merkle_tree;
 
 pub(crate) mod internal;
@@ -129,10 +130,10 @@ where
     T: NodeValue,
 {
     /// Digest a list of values
-    fn digest(data: &[T]) -> T;
+    fn digest(data: &[T]) -> Result<T, PrimitivesError>;
 
     /// Digest an indexed element
-    fn digest_leaf(pos: &I, elem: &E) -> T;
+    fn digest_leaf(pos: &I, elem: &E) -> Result<T, PrimitivesError>;
 }
 
 /// A trait for Merkle tree index type.
@@ -290,11 +291,13 @@ pub trait UniversalMerkleTreeScheme: MerkleTreeScheme {
     /// Update the leaf value at a given position
     /// * `pos` - zero-based index of the leaf in the tree
     /// * `elem` - newly updated element
+    /// * `returns` - `Ok(elem)` if the update succeeds, where `elem` is the
+    ///   original element at the given `pos`; `Err()` if the update fails.
     fn update(
         &mut self,
         pos: impl Borrow<Self::Index>,
         elem: impl Borrow<Self::Element>,
-    ) -> LookupResult<Self::Element, (), ()>;
+    ) -> Result<LookupResult<Self::Element, (), ()>, PrimitivesError>;
 
     /// Returns the leaf value given a position
     /// * `pos` - zero-based index of the leaf in the tree
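With the `mod.rs` change above, callers of `UniversalMerkleTreeScheme::update` now handle two layers of outcome: the outer `Result` reports digest failures, while the inner `LookupResult` says what was previously at the position. A hypothetical caller sketch against the `RescueSparseMerkleTree` alias exercised in the tests further down; the import paths and the `<F, F>` alias parameters (index and element both `F`) are assumptions, not part of the diff:

```rust
use jf_primitives::{
    errors::PrimitivesError,
    merkle_tree::{prelude::RescueSparseMerkleTree, LookupResult, UniversalMerkleTreeScheme},
    rescue::RescueParameter,
};

// Hypothetical helper: overwrite-or-insert a leaf, treating a forgotten
// subtree as an error for this application.
fn upsert<F: RescueParameter>(
    mt: &mut RescueSparseMerkleTree<F, F>,
    pos: F,
    elem: F,
) -> Result<Option<F>, PrimitivesError> {
    match mt.update(pos, elem)? {
        // The position held `old`; it has been replaced.
        LookupResult::Ok(old, _) => Ok(Some(old)),
        // The position was empty; a fresh leaf was inserted.
        LookupResult::NotFound(_) => Ok(None),
        // The branch was forgotten; nothing was changed.
        LookupResult::NotInMemory => Err(PrimitivesError::ParameterError(
            "subtree not in memory".to_string(),
        )),
    }
}
```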
diff --git a/primitives/src/merkle_tree/namespaced_merkle_tree.rs b/primitives/src/merkle_tree/namespaced_merkle_tree.rs
new file mode 100644
index 000000000..e6bf4ba03
--- /dev/null
+++ b/primitives/src/merkle_tree/namespaced_merkle_tree.rs
@@ -0,0 +1,279 @@
+// Copyright (c) 2022 Espresso Systems (espressosys.com)
+// This file is part of the Jellyfish library.
+
+// You should have received a copy of the MIT License
+// along with the Jellyfish library. If not, see <https://mit-license.org/>.
+
+//! Implementation of a Namespaced Merkle Tree.
+use ark_serialize::{CanonicalDeserialize, CanonicalSerialize};
+use ark_std::{string::ToString, vec, vec::Vec};
+use core::{fmt::Debug, hash::Hash, marker::PhantomData};
+
+use crate::errors::PrimitivesError;
+
+use super::{AppendableMerkleTreeScheme, DigestAlgorithm, Element, Index, NodeValue};
+
+/// Namespaced Merkle Tree where leaves are sorted by a namespace identifier.
+/// The data structure supports namespace inclusion proofs.
+pub trait NamespacedMerkleTreeScheme: AppendableMerkleTreeScheme
+where
+    Self::Element: Namespaced,
+{
+    /// Namespace proof type
+    type NamespaceProof: Clone;
+    /// Namespace type
+    type NamespaceId: Namespace;
+
+    /// Returns the entire set of leaves corresponding to a given namespace and
+    /// a completeness proof.
+    fn get_namespace_leaves_and_proof(
+        &self,
+        namespace: Self::NamespaceId,
+    ) -> (Vec<Self::Element>, Self::NamespaceProof);
+
+    /// Verifies the completeness proof for a given set of leaves and a
+    /// namespace.
+    fn verify_namespace_proof(
+        leaves: &[Self::Element],
+        proof: Self::NamespaceProof,
+        namespace: Self::NamespaceId,
+    ) -> Result<(), PrimitivesError>;
+}
+
+/// NamespacedHasher wraps a standard hash function (implementer of
+/// DigestAlgorithm), turning it into a hash function that tags internal nodes
+/// with namespace ranges.
+pub struct NamespacedHasher<H, E, I, T, N>
+where
+    H: DigestAlgorithm<E, I, T>,
+    E: Element + Namespaced<Namespace = N>,
+    N: Namespace,
+    I: Index,
+    T: NodeValue,
+{
+    phantom1: PhantomData<H>,
+    phantom2: PhantomData<E>,
+    phantom3: PhantomData<I>,
+    phantom4: PhantomData<T>,
+    phantom5: PhantomData<N>,
+}
+
+/// Trait indicating that a leaf has a namespace.
+pub trait Namespaced {
+    /// Namespace type
+    type Namespace: Namespace;
+    /// Returns the namespace of the leaf
+    fn get_namespace(&self) -> Self::Namespace;
+}
+
+/// Trait indicating that a digest algorithm can commit to
+/// a namespace range.
+pub trait BindNamespace<E, I, T, N>: DigestAlgorithm<E, I, T>
+where
+    E: Element,
+    N: Namespace,
+    T: NodeValue,
+    I: Index,
+{
+    /// Generate a commitment that binds a node to a namespace range
+    fn generate_namespaced_commitment(namespaced_hash: NamespacedHash<T, N>) -> T;
+}
+
+/// Trait indicating that a struct can act as an orderable namespace
+pub trait Namespace:
+    Debug + Clone + CanonicalDeserialize + CanonicalSerialize + Default + Copy + Hash + Ord
+{
+    /// Returns the minimum possible namespace
+    fn min() -> Self;
+    /// Returns the maximum possible namespace
+    fn max() -> Self;
+}
+
+impl Namespace for u64 {
+    fn min() -> u64 {
+        u64::MIN
+    }
+    fn max() -> u64 {
+        u64::MAX
+    }
+}
+
+#[derive(
+    CanonicalSerialize,
+    CanonicalDeserialize,
+    Hash,
+    Copy,
+    Clone,
+    Debug,
+    Default,
+    Ord,
+    Eq,
+    PartialEq,
+    PartialOrd,
+)]
+/// Represents a namespaced internal tree node
+pub struct NamespacedHash<T, N>
+where
+    N: Namespace,
+    T: NodeValue,
+{
+    min_namespace: N,
+    max_namespace: N,
+    hash: T,
+}
+
+impl<T, N> NamespacedHash<T, N>
+where
+    N: Namespace,
+    T: NodeValue,
+{
+    /// Constructs a new NamespacedHash
+    pub fn new(min_namespace: N, max_namespace: N, hash: T) -> Self {
+        Self {
+            min_namespace,
+            max_namespace,
+            hash,
+        }
+    }
+}
+
+impl<E, H, T, I, N> DigestAlgorithm<E, I, NamespacedHash<T, N>> for NamespacedHasher<H, E, I, T, N>
+where
+    E: Element + Namespaced<Namespace = N>,
+    I: Index,
+    N: Namespace,
+    T: NodeValue,
+    H: DigestAlgorithm<E, I, T> + BindNamespace<E, I, T, N>,
+{
+    // Assumes that data is sorted by namespace, which will be enforced by
+    // "append"
+    fn digest(data: &[NamespacedHash<T, N>]) -> Result<NamespacedHash<T, N>, PrimitivesError> {
+        if data.is_empty() {
+            return Ok(NamespacedHash::default());
+        }
+        let first_node = data[0];
+        let min_namespace = first_node.min_namespace;
+        let mut max_namespace = first_node.max_namespace;
+        let mut nodes = vec![H::generate_namespaced_commitment(first_node)];
+        for node in &data[1..] {
+            // Ensure that namespaced nodes are sorted
+            if node.min_namespace < max_namespace {
+                return Err(PrimitivesError::InternalError(
+                    "Namespace Merkle tree leaves are out of order".to_string(),
+                ));
+            }
+            max_namespace = node.max_namespace;
+            nodes.push(H::generate_namespaced_commitment(*node));
+        }
+
+        let inner_hash = H::digest(&nodes)?;
+
+        Ok(NamespacedHash::new(
+            min_namespace,
+            max_namespace,
+            inner_hash,
+        ))
+    }
+
+    fn digest_leaf(pos: &I, elem: &E) -> Result<NamespacedHash<T, N>, PrimitivesError> {
+        let namespace = elem.get_namespace();
+        let hash = H::digest_leaf(pos, elem)?;
+        Ok(NamespacedHash::new(namespace, namespace, hash))
+    }
+}
+
+#[cfg(test)]
+mod nmt_tests {
+    use digest::Digest;
+    use sha3::Sha3_256;
+
+    use super::*;
+    use crate::merkle_tree::examples::{Sha3Digest, Sha3Node};
+
+    type NamespaceId = u64;
+    type Hasher = NamespacedHasher<Sha3Digest, Leaf, u64, Sha3Node, NamespaceId>;
+
+    #[derive(
+        Default,
+        Eq,
+        PartialEq,
+        Hash,
+        Ord,
+        PartialOrd,
+        Copy,
+        Clone,
+        Debug,
+        CanonicalSerialize,
+        CanonicalDeserialize,
+    )]
+    struct Leaf {
+        namespace: NamespaceId,
+    }
+
+    impl Leaf {
+        pub fn new(namespace: NamespaceId) -> Self {
+            Leaf { namespace }
+        }
+    }
+
+    impl Namespaced for Leaf {
+        type Namespace = NamespaceId;
+        fn get_namespace(&self) -> NamespaceId {
+            self.namespace
+        }
+    }
+
+    impl<E, I, N> BindNamespace<E, I, Sha3Node, N> for Sha3Digest
+    where
+        E: Element + CanonicalSerialize,
+        I: Index,
+        N: Namespace,
+    {
+        // TODO ensure the hashing of (min,max,hash) is collision resistant
+        fn generate_namespaced_commitment(
+            namespaced_hash: NamespacedHash<Sha3Node, N>,
+        ) -> Sha3Node {
+            let mut hasher = Sha3_256::new();
+            let mut writer = Vec::new();
+            namespaced_hash
+                .min_namespace
+                .serialize_compressed(&mut writer)
+                .unwrap();
+            namespaced_hash
+                .max_namespace
+                .serialize_compressed(&mut writer)
+                .unwrap();
+            namespaced_hash
+                .hash
+                .serialize_compressed(&mut writer)
+                .unwrap();
+            hasher.update(&mut writer);
+            Sha3Node(hasher.finalize().into())
+        }
+    }
+
+    #[test]
+    fn test_namespaced_hash() {
+        let num_leaves = 5;
+        let leaves: Vec<Leaf> = (0..num_leaves).map(Leaf::new).collect();
+
+        // Ensure that leaves are digested correctly
+        let mut hashes = leaves
+            .iter()
+            .enumerate()
+            .map(|(idx, leaf)| Hasher::digest_leaf(&(idx as u64), leaf))
+            .collect::<Result<Vec<_>, PrimitivesError>>()
+            .unwrap();
+        assert_eq!((hashes[0].min_namespace, hashes[0].max_namespace), (0, 0));
+
+        // Ensure that sorted internal nodes are digested correctly
+        let hash = Hasher::digest(&hashes).unwrap();
+        assert_eq!(
+            (hash.min_namespace, hash.max_namespace),
+            (0, num_leaves - 1)
+        );
+
+        // Ensure that digest errors when internal nodes are not sorted by namespace
+        hashes[0] = hashes[hashes.len() - 1];
+        assert!(Hasher::digest(&hashes).is_err());
+    }
+}
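For intuition on the new file: `digest_leaf` tags a leaf with the degenerate range `[namespace, namespace]`, and `digest` widens a parent's range to cover its sorted children, so the root ends up binding `[min, max]` over all leaves. In the test's `generate_namespaced_commitment`, each child commitment is SHA3-256 over the concatenated canonical encodings of `(min_namespace, max_namespace, hash)`. A standalone byte-level equivalent, assuming `ark-serialize` encodes `u64` little-endian (illustrative only, not part of the new file):

```rust
use sha3::{Digest, Sha3_256};

// Mirrors the assumed preimage layout ser(min) || ser(max) || ser(hash)
// from the test above, with the u64 namespaces written little-endian.
fn bind_namespace(min_ns: u64, max_ns: u64, inner_hash: &[u8; 32]) -> [u8; 32] {
    let mut hasher = Sha3_256::new();
    hasher.update(min_ns.to_le_bytes());
    hasher.update(max_ns.to_le_bytes());
    hasher.update(inner_hash);
    hasher.finalize().into()
}
```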
diff --git a/primitives/src/merkle_tree/prelude.rs b/primitives/src/merkle_tree/prelude.rs
index 365f24544..81601ffbc 100644
--- a/primitives/src/merkle_tree/prelude.rs
+++ b/primitives/src/merkle_tree/prelude.rs
@@ -16,7 +16,10 @@ pub use crate::{
     },
 };
 
-use crate::rescue::{sponge::RescueCRHF, RescueParameter};
+use crate::{
+    errors::PrimitivesError,
+    rescue::{sponge::RescueCRHF, RescueParameter},
+};
 use ark_std::marker::PhantomData;
 use num_bigint::BigUint;
 use typenum::U3;
@@ -30,13 +33,13 @@ pub struct RescueHash<F: RescueParameter> {
     phantom_f: PhantomData<F>,
 }
 
 impl<F: RescueParameter> DigestAlgorithm<F, u64, F> for RescueHash<F> {
-    fn digest(data: &[F]) -> F {
-        RescueCRHF::<F>::sponge_no_padding(data, 1).unwrap()[0]
+    fn digest(data: &[F]) -> Result<F, PrimitivesError> {
+        Ok(RescueCRHF::<F>::sponge_no_padding(data, 1)?[0])
     }
 
-    fn digest_leaf(pos: &u64, elem: &F) -> F {
+    fn digest_leaf(pos: &u64, elem: &F) -> Result<F, PrimitivesError> {
         let data = [F::zero(), F::from(*pos), *elem];
-        RescueCRHF::<F>::sponge_no_padding(&data, 1).unwrap()[0]
+        Ok(RescueCRHF::<F>::sponge_no_padding(&data, 1)?[0])
     }
 }
@@ -47,24 +50,24 @@ pub type RescueMerkleTree<F> = MerkleTree<F, RescueHash<F>, u64, U3, F>;
 pub type RescueLightWeightMerkleTree<F> = LightWeightMerkleTree<F, RescueHash<F>, u64, U3, F>;
 
 impl<F: RescueParameter> DigestAlgorithm<F, BigUint, F> for RescueHash<F> {
-    fn digest(data: &[F]) -> F {
-        RescueCRHF::<F>::sponge_no_padding(data, 1).unwrap()[0]
+    fn digest(data: &[F]) -> Result<F, PrimitivesError> {
+        Ok(RescueCRHF::<F>::sponge_no_padding(data, 1)?[0])
     }
 
-    fn digest_leaf(pos: &BigUint, elem: &F) -> F {
+    fn digest_leaf(pos: &BigUint, elem: &F) -> Result<F, PrimitivesError> {
         let data = [F::zero(), F::from(pos.clone()), *elem];
-        RescueCRHF::<F>::sponge_no_padding(&data, 1).unwrap()[0]
+        Ok(RescueCRHF::<F>::sponge_no_padding(&data, 1)?[0])
     }
 }
 
 impl<F: RescueParameter> DigestAlgorithm<F, F, F> for RescueHash<F> {
-    fn digest(data: &[F]) -> F {
-        RescueCRHF::<F>::sponge_no_padding(data, 1).unwrap()[0]
+    fn digest(data: &[F]) -> Result<F, PrimitivesError> {
+        Ok(RescueCRHF::<F>::sponge_no_padding(data, 1)?[0])
     }
 
-    fn digest_leaf(pos: &F, elem: &F) -> F {
+    fn digest_leaf(pos: &F, elem: &F) -> Result<F, PrimitivesError> {
         let data = [F::zero(), *pos, *elem];
-        RescueCRHF::<F>::sponge_no_padding(&data, 1).unwrap()[0]
+        Ok(RescueCRHF::<F>::sponge_no_padding(&data, 1)?[0])
     }
 }
diff --git a/primitives/src/merkle_tree/universal_merkle_tree.rs b/primitives/src/merkle_tree/universal_merkle_tree.rs
index d6bd05864..97ecb1cf4 100644
--- a/primitives/src/merkle_tree/universal_merkle_tree.rs
+++ b/primitives/src/merkle_tree/universal_merkle_tree.rs
@@ -38,17 +38,21 @@ where
     type NonMembershipProof = MerkleProof<E, I, T, Arity>;
     type BatchNonMembershipProof = ();
 
-    fn update(&mut self, pos: impl Borrow<I>, elem: impl Borrow<E>) -> LookupResult<E, (), ()> {
+    fn update(
+        &mut self,
+        pos: impl Borrow<I>,
+        elem: impl Borrow<E>,
+    ) -> Result<LookupResult<E, (), ()>, PrimitivesError> {
         let pos = pos.borrow();
         let elem = elem.borrow();
         let traversal_path = pos.to_traversal_path(self.height);
         let ret = self
             .root
-            .update_internal::<H, Arity>(self.height, pos, &traversal_path, elem);
+            .update_internal::<H, Arity>(self.height, pos, &traversal_path, elem)?;
         if let LookupResult::NotFound(_) = ret {
             self.num_leaves += 1;
         }
-        ret
+        Ok(ret)
     }
 
     fn from_kv_set(
@@ -62,7 +66,7 @@ where
         let mut mt = Self::from_elems(height, [] as [&Self::Element; 0])?;
         for tuple in data.into_iter() {
             let (key, value) = tuple.borrow();
-            UniversalMerkleTreeScheme::update(&mut mt, key.borrow(), value.borrow());
+            UniversalMerkleTreeScheme::update(&mut mt, key.borrow(), value.borrow())?;
         }
         Ok(mt)
     }
@@ -138,14 +142,16 @@ where
         let mut path_values = vec![empty_value];
         traversal_path.iter().zip(proof.proof.iter().skip(1)).fold(
             Ok(empty_value),
-            |result, (branch, node)| -> Result<T, PrimitivesError> {
+            |result: Result<T, PrimitivesError>,
+             (branch, node)|
+             -> Result<T, PrimitivesError> {
                 match result {
                     Ok(val) => match node {
                         MerkleNode::Branch { value: _, children } => {
                             let mut data: Vec<_> =
                                 children.iter().map(|node| node.value()).collect();
                             data[*branch] = val;
-                            let digest = H::digest(&data);
+                            let digest = H::digest(&data)?;
                             path_values.push(digest);
                             Ok(digest)
                         },
@@ -267,7 +273,7 @@ mod mt_tests {
         let mut mt =
             RescueSparseMerkleTree::<F, F>::from_kv_set(10, HashMap::<F, F>::new()).unwrap();
         for i in 0..2 {
-            mt.update(F::from(i as u64), F::from(i as u64));
+            mt.update(F::from(i as u64), F::from(i as u64)).unwrap();
         }
         for i in 0..2 {
             let (val, proof) = mt.universal_lookup(F::from(i as u64)).expect_ok().unwrap();
diff --git a/relation/src/gadgets/emulated.rs b/relation/src/gadgets/emulated.rs
index 19f60e8fd..a14ad0d72 100644
--- a/relation/src/gadgets/emulated.rs
+++ b/relation/src/gadgets/emulated.rs
@@ -175,8 +175,8 @@ impl<F: PrimeField> PlonkCircuit<F> {
         let mut val_carry_out = (a_limbs[0] * b_limbs[0] + k_limbs[0] * neg_modulus[0]
             - val_expected_limbs[0]) / b_pow;
         let mut carry_out = self.create_variable(val_carry_out)?;
-        // checking that the carry_out has at most [`E::B`] bits
-        self.enforce_in_range(carry_out, E::B)?;
+        // checking that the carry_out has at most [`E::B`] + 1 bits
+        self.enforce_in_range(carry_out, E::B + 1)?;
         // enforcing that a0 * b0 - k0 * modulus[0] - carry_out * 2^E::B = c0
         self.general_arithmetic_gate(
             &[a.0[0], b.0[0], k.0[0], carry_out, c.0[0]],
@@ -198,7 +198,9 @@ impl<F: PrimeField> PlonkCircuit<F> {
             let next_carry_out = self.create_variable(val_next_carry_out)?;
 
             // range checking for this carry out.
-            let num_vals = 2u64 * (i as u64) + 1;
+            // let a = 2^B - 1. The maximum possible value of `next_carry_out` is
+            // ((i + 1) * 2 * a^2 + a) / 2^B.
+            let num_vals = 2u64 * (i as u64) + 2;
             let log_num_vals = (u64::BITS - num_vals.leading_zeros()) as usize;
             self.enforce_in_range(next_carry_out, E::B + log_num_vals)?;
 
@@ -320,7 +322,7 @@ impl<F: PrimeField> PlonkCircuit<F> {
             (a_limbs[0] * b_limbs[0] + k_limbs[0] * neg_modulus[0] - val_expected_limbs[0])
                 / b_pow;
         let mut carry_out = self.create_variable(val_carry_out)?;
-        // checking that the carry_out has at most [`E::B`] bits
-        self.enforce_in_range(carry_out, E::B)?;
+        // checking that the carry_out has at most [`E::B`] + 1 bits
+        self.enforce_in_range(carry_out, E::B + 1)?;
         // enforcing that a0 * b0 - k0 * modulus[0] - carry_out * 2^E::B = c0
         self.lc_gate(
             &[a.0[0], k.0[0], carry_out, self.zero(), c.0[0]],
@@ -340,7 +342,7 @@ impl<F: PrimeField> PlonkCircuit<F> {
             let next_carry_out = self.create_variable(val_next_carry_out)?;
 
             // range checking for this carry out.
-            let num_vals = 2u64 * (i as u64) + 1;
+            let num_vals = 2u64 * (i as u64) + 2;
             let log_num_vals = (u64::BITS - num_vals.leading_zeros()) as usize;
             self.enforce_in_range(next_carry_out, E::B + log_num_vals)?;
 
@@ -599,7 +601,7 @@ mod tests {
     use crate::{gadgets::from_emulated_field, Circuit, PlonkCircuit};
     use ark_bls12_377::Fq as Fq377;
     use ark_bn254::{Fq as Fq254, Fr as Fr254};
-    use ark_ff::PrimeField;
+    use ark_ff::{MontFp, PrimeField};
 
     #[test]
     fn test_basics() {
@@ -656,6 +658,15 @@ mod tests {
     fn test_emulated_mul() {
         test_emulated_mul_helper::<Fq377, Fr254>();
         test_emulated_mul_helper::<Fq254, Fr254>();
+
+        // test for issue (https://github.com/EspressoSystems/jellyfish/issues/306)
+        let x: Fq377 = MontFp!("218393408942992446968589193493746660101651787560689350338764189588519393175121782177906966561079408675464506489966");
+        let y: Fq377 = MontFp!("122268283598675559488486339158635529096981886914877139579534153582033676785385790730042363341236035746924960903179");
+
+        let mut circuit = PlonkCircuit::<Fr254>::new_turbo_plonk();
+        let var_x = circuit.create_emulated_variable(x).unwrap();
+        let _ = circuit.emulated_mul_constant(&var_x, y).unwrap();
+        assert!(circuit.check_circuit_satisfiability(&[]).is_ok());
     }
 
     fn test_emulated_mul_helper<E, F>()
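The range-check constants changed in `emulated.rs` follow from a worst-case bound on the limb convolution. A sketch of the arithmetic, writing $B$ for `E::B` and assuming `enforce_in_range(x, n)` constrains $x < 2^n$:

```latex
\text{Let } a = 2^B - 1 \text{ be the largest single-limb value.}

% Limb 0: one product from a*b plus one from k*(-m), then division by 2^B.
\mathrm{carry}_0 \le \frac{2a^2}{2^B} < 2^{B+1}
\quad\Longrightarrow\quad B + 1 \text{ bits, not } B.

% Limb i: (i+1) products from each of the two convolutions, plus the
% carry-in, giving exactly the bound quoted in the new comment.
\mathrm{carry}_i \le \frac{(i+1) \cdot 2a^2 + a}{2^B}
< (2i + 2) \cdot 2^B = \mathtt{num\_vals} \cdot 2^B,

\text{so } \mathtt{num\_vals} = 2i + 2 \text{ and } B + \mathtt{log\_num\_vals}
\text{ bits suffice, where }
\mathtt{log\_num\_vals} = \lfloor \log_2 \mathtt{num\_vals} \rfloor + 1.
```

The old constants ($B$ bits for the first carry, $2i + 1$ values for the rest) missed one product term, so an honest carry could exceed its range check; that is what the `Fq377` regression inputs from issue #306 in the test above exercise, which now must satisfy the circuit.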