diff --git a/proto/src/signatures.rs b/proto/src/signatures.rs index 7db86715..0c40ee95 100644 --- a/proto/src/signatures.rs +++ b/proto/src/signatures.rs @@ -1,16 +1,22 @@ +//! Digital signature processing use alloc::{string::String, vec::Vec}; use bytes::BufMut; use prost::Message; use crate::{ - types::{BlockId, CanonicalBlockId, Commit, SignedMsgType, StateId, Vote}, + types::{ + BlockId, CanonicalBlockId, CanonicalVoteExtension, Commit, SignedMsgType, StateId, Vote, + VoteExtension, VoteExtensionType, + }, Error, }; #[derive(Clone, Debug)] pub struct SignContext { pub chain_id: String, + height: i64, + round: i32, } impl SignContext {} @@ -59,6 +65,12 @@ impl SignBytes for BlockId { impl SignBytes for Vote { fn sign_bytes(&self, ctx: &SignContext) -> Result, Error> { + if ctx.height != self.height || ctx.round != self.round { + return Err(Error::create_canonical(String::from( + "vote height/round mismatch", + ))); + } + let block_id = self .block_id .clone() @@ -66,13 +78,18 @@ impl SignBytes for Vote { "missing vote.block id", )))?; - vote_sign_bytes(ctx, block_id, self.r#type, self.height, self.round as i64) + vote_sign_bytes(ctx, block_id, self.r#type()) } } impl SignBytes for Commit { fn sign_bytes(&self, ctx: &SignContext) -> Result, Error> { - // we just use some rough guesstimate of intial capacity + if ctx.height != self.height || ctx.round != self.round { + return Err(Error::create_canonical(String::from( + "commit height/round mismatch", + ))); + } + let block_id = self .block_id .clone() @@ -80,13 +97,26 @@ impl SignBytes for Commit { "missing vote.block id", )))?; - vote_sign_bytes( - ctx, - block_id, - SignedMsgType::Precommit.into(), - self.height, - self.round as i64, - ) + vote_sign_bytes(ctx, block_id, SignedMsgType::Precommit) + } +} + +impl SignBytes for VoteExtension { + fn sign_bytes(&self, ctx: &SignContext) -> Result, Error> { + if self.r#type() != VoteExtensionType::ThresholdRecover { + return Err(Error::create_canonical(String::from( + "only ThresholdRecover vote extensions can be signed", + ))); + } + let ve = CanonicalVoteExtension { + chain_id: ctx.chain_id.clone(), + extension: self.extension.clone(), + height: ctx.height, + round: ctx.round as i64, + r#type: self.r#type, + }; + + Ok(ve.encode_length_delimited_to_vec()) } } @@ -96,9 +126,7 @@ impl SignBytes for Commit { fn vote_sign_bytes( ctx: &SignContext, block_id: BlockId, - vote_type: i32, - height: i64, - round: i64, + vote_type: SignedMsgType, ) -> Result, Error> { // we just use some rough guesstimate of intial capacity let mut buf = Vec::with_capacity(80); @@ -106,9 +134,9 @@ fn vote_sign_bytes( let state_id = block_id.state_id.clone(); let block_id = block_id.sha256(ctx)?; - buf.put_i32_le(vote_type); - buf.put_i64_le(height); - buf.put_i64_le(round); + buf.put_i32_le(vote_type.into()); + buf.put_i64_le(ctx.height); + buf.put_i64_le(ctx.round as i64); buf.extend(block_id); buf.extend(state_id); @@ -119,10 +147,13 @@ fn vote_sign_bytes( #[cfg(test)] pub mod tests { - use alloc::string::ToString; + use alloc::{string::ToString, vec::Vec}; use super::SignBytes; - use crate::types::{Commit, PartSetHeader, SignedMsgType, Vote}; + use crate::{ + signatures::SignContext, + types::{Commit, PartSetHeader, SignedMsgType, Vote, VoteExtension, VoteExtensionType}, + }; #[test] /// Compare sign bytes for Vote with sign bytes generated by Tenderdash and @@ -154,6 +185,8 @@ pub mod tests { }; let ctx = super::SignContext { chain_id: "some-chain".to_string(), + height: vote.height, + round: vote.round, }; let actual = vote.sign_bytes(&ctx).unwrap(); @@ -188,10 +221,36 @@ pub mod tests { }; let ctx = super::SignContext { chain_id: "some-chain".to_string(), + height: commit.height, + round: commit.round, }; let actual = commit.sign_bytes(&ctx).unwrap(); assert_eq!(expect_sign_bytes, actual); } + + #[test] + fn vote_extension_sign_bytes() { + let ve = VoteExtension { + extension: Vec::from([1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8]), + r#type: VoteExtensionType::ThresholdRecover.into(), + signature: Default::default(), + }; + + let ctx = SignContext { + chain_id: "some-chain".to_string(), + height: 1, + round: 2, + }; + + let expect_sign_bytes = hex::decode( + "2a0a080102030405060708110100000000000000190200000000000000220a736f6d652d636861696e2801", + ) + .unwrap(); + + let actual = ve.sign_bytes(&ctx).unwrap(); + + assert_eq!(expect_sign_bytes, actual); + } }