diff --git a/abci/src/signatures.rs b/abci/src/signatures.rs index b55400c..97e050a 100644 --- a/abci/src/signatures.rs +++ b/abci/src/signatures.rs @@ -318,20 +318,24 @@ impl SignBytes for CanonicalVote { impl SignBytes for VoteExtension { fn sign_bytes(&self, chain_id: &str, height: i64, round: i32) -> Result, Error> { - if self.r#type() != VoteExtensionType::ThresholdRecover { - return Err(Error::Canonical(String::from( - "only ThresholdRecover vote extensions can be signed", - ))); + match self.r#type() { + VoteExtensionType::ThresholdRecover => { + let ve = CanonicalVoteExtension { + chain_id: chain_id.to_string(), + extension: self.extension.clone(), + height, + round: round as i64, + r#type: self.r#type, + }; + + Ok(ve.encode_length_delimited_to_vec()) + }, + VoteExtensionType::ThresholdRecoverRaw => Ok(self.extension.to_vec()), + _ => Err(Error::Canonical(format!( + "unimplemented: vote extension of type {:?} cannot be signed", + self.r#type() + ))), } - let ve = CanonicalVoteExtension { - chain_id: chain_id.to_string(), - extension: self.extension.clone(), - height, - round: round as i64, - r#type: self.r#type, - }; - - Ok(ve.encode_length_delimited_to_vec()) } } @@ -416,11 +420,12 @@ pub mod tests { } #[test] - fn vote_extension_sign_bytes() { + fn vote_extension_threshold_sign_bytes() { let ve = VoteExtension { extension: Vec::from([1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8]), r#type: VoteExtensionType::ThresholdRecover.into(), signature: Default::default(), + sign_request_id: None, }; let chain_id = "some-chain".to_string(); @@ -437,6 +442,27 @@ pub mod tests { assert_eq!(expect_sign_bytes, actual); } + #[test] + fn vote_extension_threshold_raw_sign_bytes() { + const EXTENSION: &[u8] = &[1u8, 2u8, 3u8, 4u8, 5u8, 6u8, 7u8, 8u8]; + let ve = VoteExtension { + extension: EXTENSION.to_vec(), + r#type: VoteExtensionType::ThresholdRecoverRaw.into(), + signature: Default::default(), + sign_request_id: Some("dpe-sign-request-id".as_bytes().to_vec()), + }; + + let chain_id = "some-chain".to_string(); + let height = 1; + let round = 2; + + let expect_sign_bytes = EXTENSION.to_vec(); + + let actual = ve.sign_bytes(&chain_id, height, round).unwrap(); + + assert_eq!(expect_sign_bytes, actual); + } + #[test] fn test_sign_digest() { let quorum_hash: [u8; 32] = @@ -459,4 +485,32 @@ pub mod tests { let sign_id = super::sign_digest(100, &quorum_hash, request_id, &sign_bytes_hash); assert_eq!(expect_sign_id, sign_id); // 194,4 } + + #[test] + fn test_raw_extension_sign_digest() { + const QUORUM_TYPE: u8 = 106; + + let quorum_hash: [u8; 32] = + hex::decode("dddabfe1c883dd8a2c71c4281a4212c3715a61f87d62a99aaed0f65a0506c053") + .unwrap() + .try_into() + .unwrap(); + + let request_id = + hex::decode("922a8fc39b6e265ca761eaaf863387a5e2019f4795a42260805f5562699fd9fa") + .unwrap(); + let request_id = request_id[..].try_into().unwrap(); + + let sign_bytes_hash = + hex::decode("7dfb2432d37f004c4eb2b9aebf601ba4ad59889b81d2e8c7029dce3e0bf8381c") + .unwrap(); + + let mut expect_sign_id = + hex::decode("6d98f773cef8484432c4946c6b96e04aab39fd119c77de2f21d668dd17d5d2f6") + .unwrap(); + expect_sign_id.reverse(); + + let sign_id = super::sign_digest(QUORUM_TYPE, &quorum_hash, request_id, &sign_bytes_hash); + assert_eq!(expect_sign_id, sign_id); + } } diff --git a/abci/tests/kvstore.rs b/abci/tests/kvstore.rs index a243363..bcd4109 100644 --- a/abci/tests/kvstore.rs +++ b/abci/tests/kvstore.rs @@ -320,6 +320,7 @@ impl Application for KVStoreABCI<'_> { vote_extensions: vec![proto::abci::ExtendVoteExtension { r#type: proto::types::VoteExtensionType::ThresholdRecover as i32, extension: height, + sign_request_id: None, }], }) }