From c8b1fd464399c8dd195f5261f5547280ab0a1a43 Mon Sep 17 00:00:00 2001 From: Joey Yandle Date: Tue, 9 Jan 2024 08:26:57 -0500 Subject: [PATCH] calculate dkg sizes in helper fn; calculate aggregate public key from signers who sent DkgPublicShares and also DkgEnd; check dkg end size if timeout and continue if threshold met --- src/state_machine/coordinator/fire.rs | 95 +++++++++++++++++++-------- 1 file changed, 68 insertions(+), 27 deletions(-) diff --git a/src/state_machine/coordinator/fire.rs b/src/state_machine/coordinator/fire.rs index 0476ffa6..7b8458b9 100644 --- a/src/state_machine/coordinator/fire.rs +++ b/src/state_machine/coordinator/fire.rs @@ -7,8 +7,8 @@ use crate::{ compute, curve::point::Point, net::{ - DkgBegin, DkgPrivateBegin, DkgPublicShares, Message, NonceRequest, NonceResponse, Packet, - Signable, SignatureShareRequest, + DkgBegin, DkgEnd, DkgPrivateBegin, DkgPublicShares, Message, NonceRequest, NonceResponse, + Packet, Signable, SignatureShareRequest, }, state_machine::{ coordinator::{Config, Coordinator as CoordinatorTrait, Error, State}, @@ -30,6 +30,7 @@ pub struct Coordinator { /// current signing iteration ID current_sign_iter_id: u64, dkg_public_shares: BTreeMap, + dkg_end_messages: BTreeMap, party_polynomials: BTreeMap, public_nonces: BTreeMap, signature_shares: BTreeMap>, @@ -70,11 +71,8 @@ impl Coordinator { if let Some(timeout) = self.config.dkg_public_timeout { if now.duration_since(start) > timeout { // check dkg_threshold to determine if we can continue - let dkg_size = self - .dkg_public_shares - .keys() - .map(|signer_id| self.config.signer_key_ids[signer_id].len() as u32) - .sum(); + let dkg_size = self.compute_dkg_public_size(); + if self.config.dkg_threshold > dkg_size { error!("Timeout gathering DkgPublicShares for dkg round {} signing round {} iteration {}, dkg_threshold not met ({}/{}), unable to continue", self.current_dkg_id, self.current_sign_id, self.current_sign_iter_id, dkg_size, self.config.dkg_threshold); let wait = self.dkg_wait_signer_ids.iter().copied().collect(); @@ -100,12 +98,26 @@ impl Coordinator { if let Some(start) = self.dkg_end_start { if let Some(timeout) = self.config.dkg_end_timeout { if now.duration_since(start) > timeout { - error!("Timeout gathering DkgEnd for dkg round {} signing round {} iteration {}, unable to continue", self.current_dkg_id, self.current_sign_id, self.current_sign_iter_id); - let wait = self.dkg_wait_signer_ids.iter().copied().collect(); - return Ok(( - None, - Some(OperationResult::DkgError(DkgError::DkgEndTimeout(wait))), - )); + let dkg_size = self.compute_dkg_end_size(); + + if self.config.dkg_threshold > dkg_size { + error!("Timeout gathering DkgEnd for dkg round {} signing round {} iteration {}, unable to continue", self.current_dkg_id, self.current_sign_id, self.current_sign_iter_id); + let wait = self.dkg_wait_signer_ids.iter().copied().collect(); + return Ok(( + None, + Some(OperationResult::DkgError(DkgError::DkgEndTimeout(wait))), + )); + } else { + warn!("Timeout gathering DkgEnd for dkg round {} signing round {} iteration {}, dkg_threshold was met ({}/{}), ", self.current_dkg_id, self.current_sign_id, self.current_sign_iter_id, dkg_size, self.config.dkg_threshold); + self.dkg_end_gathered()?; + return Ok(( + None, + Some(OperationResult::Dkg( + self.aggregate_public_key + .ok_or(Error::MissingAggregatePublicKey)?, + )), + )); + } } } } @@ -369,14 +381,6 @@ impl Coordinator { } fn public_shares_gathered(&mut self) -> Result<(), Error> { - // Calculate the aggregate public key - let key = self - .party_polynomials - .iter() - .fold(Point::default(), |s, (_, comm)| s + comm.poly[0]); - - info!("Aggregate public key: {}", key); - self.aggregate_public_key = Some(key); self.move_to(State::DkgPrivateDistribute)?; Ok(()) } @@ -390,19 +394,41 @@ impl Coordinator { if dkg_end.dkg_id != self.current_dkg_id { return Err(Error::BadDkgId(dkg_end.dkg_id, self.current_dkg_id)); } - self.dkg_wait_signer_ids.remove(&dkg_end.signer_id); - debug!( - "DKG_End round {} from signer {}. Waiting on {:?}", - dkg_end.dkg_id, dkg_end.signer_id, self.dkg_wait_signer_ids - ); + if self.dkg_wait_signer_ids.contains(&dkg_end.signer_id) { + self.dkg_wait_signer_ids.remove(&dkg_end.signer_id); + self.dkg_end_messages + .insert(dkg_end.signer_id, dkg_end.clone()); + debug!( + "DKG_End round {} from signer {}. Waiting on {:?}", + dkg_end.dkg_id, dkg_end.signer_id, self.dkg_wait_signer_ids + ); + } else { + warn!( + "Got DkgEnd from signer {} who we weren't waiting on", + &dkg_end.signer_id + ); + } } if self.dkg_wait_signer_ids.is_empty() { - self.move_to(State::Idle)?; + self.dkg_end_gathered()?; } Ok(()) } + fn dkg_end_gathered(&mut self) -> Result<(), Error> { + // Calculate the aggregate public key + let key = self + .dkg_end_messages + .keys() + .flat_map(|signer_id| self.dkg_public_shares[signer_id].comms.clone()) + .fold(Point::default(), |s, (_, comm)| s + comm.poly[0]); + + info!("Aggregate public key: {}", key); + self.aggregate_public_key = Some(key); + self.move_to(State::Idle) + } + fn request_nonces( &mut self, is_taproot: bool, @@ -663,6 +689,20 @@ impl Coordinator { R } + + fn compute_dkg_public_size(&self) -> u32 { + self.dkg_public_shares + .keys() + .map(|signer_id| self.config.signer_key_ids[signer_id].len() as u32) + .sum() + } + + fn compute_dkg_end_size(&self) -> u32 { + self.dkg_end_messages + .keys() + .map(|signer_id| self.config.signer_key_ids[signer_id].len() as u32) + .sum() + } } impl StateMachine for Coordinator { @@ -725,6 +765,7 @@ impl CoordinatorTrait for Coordinator { current_sign_id: 0, current_sign_iter_id: 0, dkg_public_shares: Default::default(), + dkg_end_messages: Default::default(), party_polynomials: Default::default(), public_nonces: Default::default(), signature_shares: Default::default(),