diff --git a/src/state_machine/coordinator/fire.rs b/src/state_machine/coordinator/fire.rs index c427c89c..b16bdb6c 100644 --- a/src/state_machine/coordinator/fire.rs +++ b/src/state_machine/coordinator/fire.rs @@ -65,8 +65,8 @@ pub struct Coordinator { } impl Coordinator { - /// Process the message inside the passed packet - pub fn process_timeout(&mut self) -> Result<(Option, Option), Error> { + /// Check the timeout + pub fn check_timeout(&mut self) -> Result<(Option, Option), Error> { let now = Instant::now(); match self.state.clone() { State::Idle => {} @@ -209,179 +209,6 @@ impl Coordinator { } Ok((None, None)) } - /// Process the message inside the passed packet - pub fn process_message( - &mut self, - packet: &Packet, - ) -> Result<(Option, Option), Error> { - loop { - match self.state.clone() { - State::Idle => { - // Did we receive a coordinator message? - if let Message::DkgBegin(dkg_begin) = &packet.msg { - if self.current_dkg_id >= dkg_begin.dkg_id { - // We have already processed this DKG round - return Ok((None, None)); - } - // Set the current sign id to one before the current message to ensure - // that we start the next round at the correct id. (Do this rather - // than overwriting afterwards to ensure logging is accurate) - self.current_dkg_id = dkg_begin.dkg_id.wrapping_sub(1); - let packet = self.start_dkg_round()?; - return Ok((Some(packet), None)); - } else if let Message::NonceRequest(nonce_request) = &packet.msg { - if self.current_sign_id >= nonce_request.sign_id { - // We have already processed this sign round - return Ok((None, None)); - } - // Set the current sign id to one before the current message to ensure - // that we start the next round at the correct id. (Do this rather - // than overwriting afterwards to ensure logging is accurate) - self.current_sign_id = nonce_request.sign_id.wrapping_sub(1); - self.current_sign_iter_id = nonce_request.sign_iter_id.wrapping_sub(1); - let packet = self.start_signing_round( - nonce_request.message.as_slice(), - nonce_request.signature_type, - )?; - return Ok((Some(packet), None)); - } - return Ok((None, None)); - } - State::DkgPublicDistribute => { - let packet = self.start_public_shares()?; - return Ok((Some(packet), None)); - } - State::DkgPublicGather => { - self.gather_public_shares(packet)?; - if self.state == State::DkgPublicGather { - // We need more data - return Ok((None, None)); - } - } - State::DkgPrivateDistribute => { - let packet = self.start_private_shares()?; - return Ok((Some(packet), None)); - } - State::DkgPrivateGather => { - self.gather_private_shares(packet)?; - if self.state == State::DkgPrivateGather { - // We need more data - return Ok((None, None)); - } - } - State::DkgEndDistribute => { - let packet = self.start_dkg_end()?; - return Ok((Some(packet), None)); - } - State::DkgEndGather => { - if let Err(error) = self.gather_dkg_end(packet) { - if let Error::DkgFailure(dkg_failures) = error { - return Ok(( - None, - Some(OperationResult::DkgError(DkgError::DkgEndFailure( - dkg_failures, - ))), - )); - } else { - return Err(error); - } - } - if self.state == State::DkgEndGather { - // We need more data - return Ok((None, None)); - } else if self.state == State::Idle { - // We are done with the DKG round! Return the operation result - return Ok(( - None, - Some(OperationResult::Dkg( - self.aggregate_public_key - .ok_or(Error::MissingAggregatePublicKey)?, - )), - )); - } - } - State::NonceRequest(signature_type) => { - let packet = self.request_nonces(signature_type)?; - return Ok((Some(packet), None)); - } - State::NonceGather(signature_type) => { - self.gather_nonces(packet, signature_type)?; - if self.state == State::NonceGather(signature_type) { - // We need more data - return Ok((None, None)); - } - } - State::SigShareRequest(signature_type) => { - let packet = self.request_sig_shares(signature_type)?; - return Ok((Some(packet), None)); - } - State::SigShareGather(signature_type) => { - if let Err(e) = self.gather_sig_shares(packet, signature_type) { - return Ok(( - None, - Some(OperationResult::SignError(SignError::Coordinator(e))), - )); - } - if self.state == State::SigShareGather(signature_type) { - // We need more data - return Ok((None, None)); - } else if self.state == State::Idle { - // We are done with the DKG round! Return the operation result - if let SignatureType::Taproot(_) = signature_type { - if let Some(schnorr_proof) = &self.schnorr_proof { - return Ok(( - None, - Some(OperationResult::SignTaproot(SchnorrProof { - r: schnorr_proof.r, - s: schnorr_proof.s, - })), - )); - } else { - return Ok(( - None, - Some(OperationResult::SignError(SignError::Coordinator( - Error::MissingSchnorrProof, - ))), - )); - } - } else if let SignatureType::Schnorr = signature_type { - if let Some(schnorr_proof) = &self.schnorr_proof { - return Ok(( - None, - Some(OperationResult::SignSchnorr(SchnorrProof { - r: schnorr_proof.r, - s: schnorr_proof.s, - })), - )); - } else { - return Ok(( - None, - Some(OperationResult::SignError(SignError::Coordinator( - Error::MissingSchnorrProof, - ))), - )); - } - } else if let Some(signature) = &self.signature { - return Ok(( - None, - Some(OperationResult::Sign(Signature { - R: signature.R, - z: signature.z, - })), - )); - } else { - return Ok(( - None, - Some(OperationResult::SignError(SignError::Coordinator( - Error::MissingSignature, - ))), - )); - } - } - } - } - } - } /// Ask signers to send DKG public shares pub fn start_public_shares(&mut self) -> Result { @@ -1180,34 +1007,6 @@ impl CoordinatorTrait for Coordinator { self.config.clone() } - /// Process inbound messages - fn process_inbound_messages( - &mut self, - packets: &[Packet], - ) -> Result<(Vec, Vec), Error> { - let mut outbound_packets = vec![]; - let mut operation_results = vec![]; - for packet in packets { - let (outbound_packet, operation_result) = self.process_message(packet)?; - if let Some(outbound_packet) = outbound_packet { - outbound_packets.push(outbound_packet); - } - if let Some(operation_result) = operation_result { - operation_results.push(operation_result); - } - } - - let (outbound_packet, operation_result) = self.process_timeout()?; - if let Some(outbound_packet) = outbound_packet { - outbound_packets.push(outbound_packet); - } - if let Some(operation_result) = operation_result { - operation_results.push(operation_result); - } - - Ok((outbound_packets, operation_results)) - } - /// Retrieve the aggregate public key fn get_aggregate_public_key(&self) -> Option { self.aggregate_public_key @@ -1228,6 +1027,188 @@ impl CoordinatorTrait for Coordinator { self.state.clone() } + /// Check timeout then process a message if passed one + fn process( + &mut self, + packet: Option<&Packet>, + ) -> Result<(Option, Option), Error> { + let (outbound_packet, operation_result) = self.check_timeout()?; + if outbound_packet.is_some() || operation_result.is_some() { + return Ok((outbound_packet, operation_result)); + } + + if let Some(packet) = packet { + loop { + match self.state.clone() { + State::Idle => { + // Did we receive a coordinator message? + if let Message::DkgBegin(dkg_begin) = &packet.msg { + if self.current_dkg_id >= dkg_begin.dkg_id { + // We have already processed this DKG round + return Ok((None, None)); + } + // Set the current sign id to one before the current message to ensure + // that we start the next round at the correct id. (Do this rather + // than overwriting afterwards to ensure logging is accurate) + self.current_dkg_id = dkg_begin.dkg_id.wrapping_sub(1); + let packet = self.start_dkg_round()?; + return Ok((Some(packet), None)); + } else if let Message::NonceRequest(nonce_request) = &packet.msg { + if self.current_sign_id >= nonce_request.sign_id { + // We have already processed this sign round + return Ok((None, None)); + } + // Set the current sign id to one before the current message to ensure + // that we start the next round at the correct id. (Do this rather + // than overwriting afterwards to ensure logging is accurate) + self.current_sign_id = nonce_request.sign_id.wrapping_sub(1); + self.current_sign_iter_id = nonce_request.sign_iter_id.wrapping_sub(1); + let packet = self.start_signing_round( + nonce_request.message.as_slice(), + nonce_request.signature_type, + )?; + return Ok((Some(packet), None)); + } + return Ok((None, None)); + } + State::DkgPublicDistribute => { + let packet = self.start_public_shares()?; + return Ok((Some(packet), None)); + } + State::DkgPublicGather => { + self.gather_public_shares(packet)?; + if self.state == State::DkgPublicGather { + // We need more data + return Ok((None, None)); + } + } + State::DkgPrivateDistribute => { + let packet = self.start_private_shares()?; + return Ok((Some(packet), None)); + } + State::DkgPrivateGather => { + self.gather_private_shares(packet)?; + if self.state == State::DkgPrivateGather { + // We need more data + return Ok((None, None)); + } + } + State::DkgEndDistribute => { + let packet = self.start_dkg_end()?; + return Ok((Some(packet), None)); + } + State::DkgEndGather => { + if let Err(error) = self.gather_dkg_end(packet) { + if let Error::DkgFailure(dkg_failures) = error { + return Ok(( + None, + Some(OperationResult::DkgError(DkgError::DkgEndFailure( + dkg_failures, + ))), + )); + } else { + return Err(error); + } + } + if self.state == State::DkgEndGather { + // We need more data + return Ok((None, None)); + } else if self.state == State::Idle { + // We are done with the DKG round! Return the operation result + return Ok(( + None, + Some(OperationResult::Dkg( + self.aggregate_public_key + .ok_or(Error::MissingAggregatePublicKey)?, + )), + )); + } + } + State::NonceRequest(signature_type) => { + let packet = self.request_nonces(signature_type)?; + return Ok((Some(packet), None)); + } + State::NonceGather(signature_type) => { + self.gather_nonces(packet, signature_type)?; + if self.state == State::NonceGather(signature_type) { + // We need more data + return Ok((None, None)); + } + } + State::SigShareRequest(signature_type) => { + let packet = self.request_sig_shares(signature_type)?; + return Ok((Some(packet), None)); + } + State::SigShareGather(signature_type) => { + if let Err(e) = self.gather_sig_shares(packet, signature_type) { + return Ok(( + None, + Some(OperationResult::SignError(SignError::Coordinator(e))), + )); + } + if self.state == State::SigShareGather(signature_type) { + // We need more data + return Ok((None, None)); + } else if self.state == State::Idle { + // We are done with the DKG round! Return the operation result + if let SignatureType::Taproot(_) = signature_type { + if let Some(schnorr_proof) = &self.schnorr_proof { + return Ok(( + None, + Some(OperationResult::SignTaproot(SchnorrProof { + r: schnorr_proof.r, + s: schnorr_proof.s, + })), + )); + } else { + return Ok(( + None, + Some(OperationResult::SignError(SignError::Coordinator( + Error::MissingSchnorrProof, + ))), + )); + } + } else if let SignatureType::Schnorr = signature_type { + if let Some(schnorr_proof) = &self.schnorr_proof { + return Ok(( + None, + Some(OperationResult::SignSchnorr(SchnorrProof { + r: schnorr_proof.r, + s: schnorr_proof.s, + })), + )); + } else { + return Ok(( + None, + Some(OperationResult::SignError(SignError::Coordinator( + Error::MissingSchnorrProof, + ))), + )); + } + } else if let Some(signature) = &self.signature { + return Ok(( + None, + Some(OperationResult::Sign(Signature { + R: signature.R, + z: signature.z, + })), + )); + } else { + return Ok(( + None, + Some(OperationResult::SignError(SignError::Coordinator( + Error::MissingSignature, + ))), + )); + } + } + } + } + } + } + Ok((None, None)) + } + /// Start a DKG round fn start_dkg_round(&mut self) -> Result { self.current_dkg_id = self.current_dkg_id.wrapping_add(1); @@ -1591,11 +1572,11 @@ pub mod test { let (outbound_messages, operation_results) = minimum_coordinators .first_mut() .unwrap() - .process_inbound_messages(&[]) + .process(None) .unwrap(); - assert_eq!(outbound_messages.len(), 1); - assert_eq!(operation_results.len(), 0); + assert!(outbound_messages.is_some()); + assert!(operation_results.is_none()); assert_eq!( minimum_coordinators.first().unwrap().state, State::DkgPrivateGather, @@ -1665,11 +1646,11 @@ pub mod test { let (outbound_messages, operation_results) = minimum_coordinators .first_mut() .unwrap() - .process_inbound_messages(&[]) + .process(None) .unwrap(); - assert_eq!(outbound_messages.len(), 1); - assert_eq!(operation_results.len(), 0); + assert!(outbound_messages.is_some()); + assert!(operation_results.is_none()); assert_eq!( minimum_coordinators.first().unwrap().state, State::DkgPrivateGather, @@ -1717,15 +1698,15 @@ pub mod test { // Sleep long enough to hit the timeout thread::sleep(expire); - let (outbound_messages, operation_results) = minimum_coordinator + let (outbound_message, operation_result) = minimum_coordinator .first_mut() .unwrap() - .process_inbound_messages(&[]) + .process(None) .unwrap(); - assert_eq!(outbound_messages.len(), 1); - assert_eq!(operation_results.len(), 0); - match &outbound_messages[0].msg { + assert!(outbound_message.is_some()); + assert!(operation_result.is_none()); + match outbound_message.unwrap().msg { Message::DkgEndBegin(_) => {} _ => { panic!("Expected DkgEndBegin message"); @@ -1815,19 +1796,19 @@ pub mod test { // Sleep long enough to hit the timeout thread::sleep(expire); - let (outbound_messages, operation_results) = insufficient_coordinators + let (outbound_message, operation_result) = insufficient_coordinators .first_mut() .unwrap() - .process_inbound_messages(&[]) + .process(None) .unwrap(); - assert!(outbound_messages.is_empty()); - assert_eq!(operation_results.len(), 1); + assert!(outbound_message.is_none()); + assert!(operation_result.is_some()); assert_eq!( insufficient_coordinators.first().unwrap().state, State::DkgPublicGather, ); - match &operation_results[0] { + match operation_result.unwrap() { OperationResult::DkgError(dkg_error) => match dkg_error { DkgError::DkgPublicTimeout(_) => {} _ => panic!("Expected DkgError::DkgPublicTimeout"), @@ -1881,14 +1862,14 @@ pub mod test { // Sleep long enough to hit the timeout thread::sleep(expire); - let (outbound_messages, operation_results) = insufficient_coordinator + let (outbound_message, operation_result) = insufficient_coordinator .first_mut() .unwrap() - .process_inbound_messages(&[]) + .process(None) .unwrap(); - assert!(outbound_messages.is_empty()); - assert_eq!(operation_results.len(), 1); + assert!(outbound_message.is_none()); + assert!(operation_result.is_some()); assert_eq!( insufficient_coordinator.first().unwrap().state, State::DkgPrivateGather, @@ -2495,18 +2476,18 @@ pub mod test { // Sleep long enough to hit the timeout thread::sleep(Duration::from_millis(256)); - let (outbound_messages, operation_results) = insufficient_coordinators + let (outbound_message, operation_result) = insufficient_coordinators .first_mut() .unwrap() - .process_inbound_messages(&[]) + .process(None) .unwrap(); - assert!(outbound_messages.is_empty()); - assert_eq!(operation_results.len(), 1); + assert!(outbound_message.is_none()); + assert!(operation_result.is_some()); for coordinator in &insufficient_coordinators { assert_eq!(coordinator.state, State::NonceGather(signature_type)); } - match &operation_results[0] { + match &operation_result.unwrap() { OperationResult::SignError(sign_error) => match sign_error { SignError::NonceTimeout(_, _) => {} _ => panic!("Expected SignError::NonceTimeout"), @@ -2563,14 +2544,14 @@ pub mod test { // Sleep long enough to hit the timeout thread::sleep(Duration::from_millis(256)); - let (outbound_messages, operation_results) = insufficient_coordinators + let (outbound_message, operation_result) = insufficient_coordinators .first_mut() .unwrap() - .process_inbound_messages(&[]) + .process(None) .unwrap(); - assert_eq!(outbound_messages.len(), 1); - assert_eq!(operation_results.len(), 0); + assert!(outbound_message.is_some()); + assert!(operation_result.is_none()); assert_eq!( insufficient_coordinators.first().unwrap().state, State::NonceGather(signature_type) @@ -2615,14 +2596,14 @@ pub mod test { // Sleep long enough to hit the timeout thread::sleep(Duration::from_millis(256)); - let (outbound_messages, operation_results) = insufficient_coordinators + let (outbound_message, operation_result) = insufficient_coordinators .first_mut() .unwrap() - .process_inbound_messages(&[]) + .process(None) .unwrap(); - assert_eq!(outbound_messages.len(), 0); - assert_eq!(operation_results.len(), 1); + assert!(outbound_message.is_none()); + assert!(operation_result.is_some()); assert_eq!( insufficient_coordinators.first_mut().unwrap().state, State::SigShareGather(signature_type) @@ -2753,31 +2734,31 @@ pub mod test { coordinator.current_sign_id = id; // Attempt to start an old DKG round let (packets, results) = coordinator - .process_inbound_messages(&[Packet { + .process(Some(&Packet { sig: vec![], msg: Message::DkgBegin(DkgBegin { dkg_id: old_id }), - }]) + })) .unwrap(); - assert!(packets.is_empty()); - assert!(results.is_empty()); + assert!(packets.is_none()); + assert!(results.is_none()); assert_eq!(coordinator.state, State::Idle); assert_eq!(coordinator.current_dkg_id, id); // Attempt to start the same DKG round let (packets, results) = coordinator - .process_inbound_messages(&[Packet { + .process(Some(&Packet { sig: vec![], msg: Message::DkgBegin(DkgBegin { dkg_id: id }), - }]) + })) .unwrap(); - assert!(packets.is_empty()); - assert!(results.is_empty()); + assert!(packets.is_none()); + assert!(results.is_none()); assert_eq!(coordinator.state, State::Idle); assert_eq!(coordinator.current_dkg_id, id); // Attempt to start an old Sign round let (packets, results) = coordinator - .process_inbound_messages(&[Packet { + .process(Some(&Packet { sig: vec![], msg: Message::NonceRequest(NonceRequest { dkg_id: id, @@ -2786,16 +2767,16 @@ pub mod test { sign_iter_id: id, signature_type: SignatureType::Frost, }), - }]) + })) .unwrap(); - assert!(packets.is_empty()); - assert!(results.is_empty()); + assert!(packets.is_none()); + assert!(results.is_none()); assert_eq!(coordinator.state, State::Idle); assert_eq!(coordinator.current_sign_id, id); // Attempt to start the same Sign round let (packets, results) = coordinator - .process_inbound_messages(&[Packet { + .process(Some(&Packet { sig: vec![], msg: Message::NonceRequest(NonceRequest { dkg_id: id, @@ -2804,10 +2785,10 @@ pub mod test { sign_iter_id: id, signature_type: SignatureType::Frost, }), - }]) + })) .unwrap(); - assert!(packets.is_empty()); - assert!(results.is_empty()); + assert!(packets.is_none()); + assert!(results.is_none()); assert_eq!(coordinator.state, State::Idle); assert_eq!(coordinator.current_sign_id, id); } diff --git a/src/state_machine/coordinator/frost.rs b/src/state_machine/coordinator/frost.rs index c4b6145c..fbdb4a2d 100644 --- a/src/state_machine/coordinator/frost.rs +++ b/src/state_machine/coordinator/frost.rs @@ -51,154 +51,6 @@ pub struct Coordinator { } impl Coordinator { - /// Process the message inside the passed packet - pub fn process_message( - &mut self, - packet: &Packet, - ) -> Result<(Option, Option), Error> { - loop { - match self.state.clone() { - State::Idle => { - // Did we receive a coordinator message? - if let Message::DkgBegin(dkg_begin) = &packet.msg { - if self.current_dkg_id >= dkg_begin.dkg_id { - // We have already processed this DKG round - return Ok((None, None)); - } - // Set the current sign id to one before the current message to ensure - // that we start the next round at the correct id. (Do this rather - // then overwriting afterwards to ensure logging is accurate) - self.current_dkg_id = dkg_begin.dkg_id.wrapping_sub(1); - let packet = self.start_dkg_round()?; - return Ok((Some(packet), None)); - } else if let Message::NonceRequest(nonce_request) = &packet.msg { - if self.current_sign_id >= nonce_request.sign_id { - // We have already processed this sign round - return Ok((None, None)); - } - // Set the current sign id to one before the current message to ensure - // that we start the next round at the correct id. (Do this rather - // then overwriting afterwards to ensure logging is accurate) - self.current_sign_id = nonce_request.sign_id.wrapping_sub(1); - self.current_sign_iter_id = nonce_request.sign_iter_id; - let packet = self.start_signing_round( - nonce_request.message.as_slice(), - nonce_request.signature_type, - )?; - return Ok((Some(packet), None)); - } - return Ok((None, None)); - } - State::DkgPublicDistribute => { - let packet = self.start_public_shares()?; - return Ok((Some(packet), None)); - } - State::DkgPublicGather => { - self.gather_public_shares(packet)?; - if self.state == State::DkgPublicGather { - // We need more data - return Ok((None, None)); - } - } - State::DkgPrivateDistribute => { - let packet = self.start_private_shares()?; - return Ok((Some(packet), None)); - } - State::DkgPrivateGather => { - self.gather_private_shares(packet)?; - if self.state == State::DkgPrivateGather { - // We need more data - return Ok((None, None)); - } - } - State::DkgEndDistribute => { - let packet = self.start_dkg_end()?; - return Ok((Some(packet), None)); - } - State::DkgEndGather => { - self.gather_dkg_end(packet)?; - if self.state == State::DkgEndGather { - // We need more data - return Ok((None, None)); - } else if self.state == State::Idle { - // We are done with the DKG round! Return the operation result - return Ok(( - None, - Some(OperationResult::Dkg( - self.aggregate_public_key - .ok_or(Error::MissingAggregatePublicKey)?, - )), - )); - } - } - State::NonceRequest(signature_type) => { - let packet = self.request_nonces(signature_type)?; - return Ok((Some(packet), None)); - } - State::NonceGather(signature_type) => { - self.gather_nonces(packet, signature_type)?; - if self.state == State::NonceGather(signature_type) { - // We need more data - return Ok((None, None)); - } - } - State::SigShareRequest(signature_type) => { - let packet = self.request_sig_shares(signature_type)?; - return Ok((Some(packet), None)); - } - State::SigShareGather(signature_type) => { - if let Err(e) = self.gather_sig_shares(packet, signature_type) { - return Ok(( - None, - Some(OperationResult::SignError(SignError::Coordinator(e))), - )); - } - if self.state == State::SigShareGather(signature_type) { - // We need more data - return Ok((None, None)); - } else if self.state == State::Idle { - // We are done with the DKG round! Return the operation result - if let SignatureType::Taproot(_) = signature_type { - let schnorr_proof = self - .schnorr_proof - .as_ref() - .ok_or(Error::MissingSchnorrProof)?; - return Ok(( - None, - Some(OperationResult::SignTaproot(SchnorrProof { - r: schnorr_proof.r, - s: schnorr_proof.s, - })), - )); - } else if let SignatureType::Schnorr = signature_type { - let schnorr_proof = self - .schnorr_proof - .as_ref() - .ok_or(Error::MissingSchnorrProof)?; - return Ok(( - None, - Some(OperationResult::SignSchnorr(SchnorrProof { - r: schnorr_proof.r, - s: schnorr_proof.s, - })), - )); - } else { - let signature = - self.signature.as_ref().ok_or(Error::MissingSignature)?; - return Ok(( - None, - Some(OperationResult::Sign(Signature { - R: signature.R, - z: signature.z, - })), - )); - } - } - } - } - } - } - /// Ask signers to send DKG public shares pub fn start_public_shares(&mut self) -> Result { self.dkg_public_shares.clear(); @@ -710,23 +562,155 @@ impl CoordinatorTrait for Coordinator { self.config.clone() } - /// Process inbound messages - fn process_inbound_messages( + /// Process the message inside the passed packet + fn process( &mut self, - packets: &[Packet], - ) -> Result<(Vec, Vec), Error> { - let mut outbound_packets = vec![]; - let mut operation_results = vec![]; - for packet in packets { - let (outbound_packet, operation_result) = self.process_message(packet)?; - if let Some(outbound_packet) = outbound_packet { - outbound_packets.push(outbound_packet); - } - if let Some(operation_result) = operation_result { - operation_results.push(operation_result); + packet: Option<&Packet>, + ) -> Result<(Option, Option), Error> { + if let Some(packet) = packet { + loop { + match self.state.clone() { + State::Idle => { + // Did we receive a coordinator message? + if let Message::DkgBegin(dkg_begin) = &packet.msg { + if self.current_dkg_id >= dkg_begin.dkg_id { + // We have already processed this DKG round + return Ok((None, None)); + } + // Set the current sign id to one before the current message to ensure + // that we start the next round at the correct id. (Do this rather + // then overwriting afterwards to ensure logging is accurate) + self.current_dkg_id = dkg_begin.dkg_id.wrapping_sub(1); + let packet = self.start_dkg_round()?; + return Ok((Some(packet), None)); + } else if let Message::NonceRequest(nonce_request) = &packet.msg { + if self.current_sign_id >= nonce_request.sign_id { + // We have already processed this sign round + return Ok((None, None)); + } + // Set the current sign id to one before the current message to ensure + // that we start the next round at the correct id. (Do this rather + // then overwriting afterwards to ensure logging is accurate) + self.current_sign_id = nonce_request.sign_id.wrapping_sub(1); + self.current_sign_iter_id = nonce_request.sign_iter_id; + let packet = self.start_signing_round( + nonce_request.message.as_slice(), + nonce_request.signature_type, + )?; + return Ok((Some(packet), None)); + } + return Ok((None, None)); + } + State::DkgPublicDistribute => { + let packet = self.start_public_shares()?; + return Ok((Some(packet), None)); + } + State::DkgPublicGather => { + self.gather_public_shares(packet)?; + if self.state == State::DkgPublicGather { + // We need more data + return Ok((None, None)); + } + } + State::DkgPrivateDistribute => { + let packet = self.start_private_shares()?; + return Ok((Some(packet), None)); + } + State::DkgPrivateGather => { + self.gather_private_shares(packet)?; + if self.state == State::DkgPrivateGather { + // We need more data + return Ok((None, None)); + } + } + State::DkgEndDistribute => { + let packet = self.start_dkg_end()?; + return Ok((Some(packet), None)); + } + State::DkgEndGather => { + self.gather_dkg_end(packet)?; + if self.state == State::DkgEndGather { + // We need more data + return Ok((None, None)); + } else if self.state == State::Idle { + // We are done with the DKG round! Return the operation result + return Ok(( + None, + Some(OperationResult::Dkg( + self.aggregate_public_key + .ok_or(Error::MissingAggregatePublicKey)?, + )), + )); + } + } + State::NonceRequest(signature_type) => { + let packet = self.request_nonces(signature_type)?; + return Ok((Some(packet), None)); + } + State::NonceGather(signature_type) => { + self.gather_nonces(packet, signature_type)?; + if self.state == State::NonceGather(signature_type) { + // We need more data + return Ok((None, None)); + } + } + State::SigShareRequest(signature_type) => { + let packet = self.request_sig_shares(signature_type)?; + return Ok((Some(packet), None)); + } + State::SigShareGather(signature_type) => { + if let Err(e) = self.gather_sig_shares(packet, signature_type) { + return Ok(( + None, + Some(OperationResult::SignError(SignError::Coordinator(e))), + )); + } + if self.state == State::SigShareGather(signature_type) { + // We need more data + return Ok((None, None)); + } else if self.state == State::Idle { + // We are done with the DKG round! Return the operation result + if let SignatureType::Taproot(_) = signature_type { + let schnorr_proof = self + .schnorr_proof + .as_ref() + .ok_or(Error::MissingSchnorrProof)?; + return Ok(( + None, + Some(OperationResult::SignTaproot(SchnorrProof { + r: schnorr_proof.r, + s: schnorr_proof.s, + })), + )); + } else if let SignatureType::Schnorr = signature_type { + let schnorr_proof = self + .schnorr_proof + .as_ref() + .ok_or(Error::MissingSchnorrProof)?; + return Ok(( + None, + Some(OperationResult::SignSchnorr(SchnorrProof { + r: schnorr_proof.r, + s: schnorr_proof.s, + })), + )); + } else { + let signature = + self.signature.as_ref().ok_or(Error::MissingSignature)?; + return Ok(( + None, + Some(OperationResult::Sign(Signature { + R: signature.R, + z: signature.z, + })), + )); + } + } + } + } } } - Ok((outbound_packets, operation_results)) + Ok((None, None)) } /// Retrieve the aggregate public key @@ -981,32 +965,32 @@ pub mod test { coordinator.current_dkg_id = id; coordinator.current_sign_id = id; // Attempt to start an old DKG round - let (packets, results) = coordinator - .process_inbound_messages(&[Packet { + let (packet, result) = coordinator + .process(Some(&Packet { sig: vec![], msg: Message::DkgBegin(DkgBegin { dkg_id: old_id }), - }]) + })) .unwrap(); - assert!(packets.is_empty()); - assert!(results.is_empty()); + assert!(packet.is_none()); + assert!(result.is_none()); assert_eq!(coordinator.state, State::Idle); assert_eq!(coordinator.current_dkg_id, id); // Attempt to start the same DKG round - let (packets, results) = coordinator - .process_inbound_messages(&[Packet { + let (packet, result) = coordinator + .process(Some(&Packet { sig: vec![], msg: Message::DkgBegin(DkgBegin { dkg_id: id }), - }]) + })) .unwrap(); - assert!(packets.is_empty()); - assert!(results.is_empty()); + assert!(packet.is_none()); + assert!(result.is_none()); assert_eq!(coordinator.state, State::Idle); assert_eq!(coordinator.current_dkg_id, id); // Attempt to start an old Sign round - let (packets, results) = coordinator - .process_inbound_messages(&[Packet { + let (packet, result) = coordinator + .process(Some(&Packet { sig: vec![], msg: Message::NonceRequest(NonceRequest { dkg_id: id, @@ -1015,16 +999,16 @@ pub mod test { sign_iter_id: id, signature_type: SignatureType::Frost, }), - }]) + })) .unwrap(); - assert!(packets.is_empty()); - assert!(results.is_empty()); + assert!(packet.is_none()); + assert!(result.is_none()); assert_eq!(coordinator.state, State::Idle); assert_eq!(coordinator.current_sign_id, id); // Attempt to start the same Sign round - let (packets, results) = coordinator - .process_inbound_messages(&[Packet { + let (packet, result) = coordinator + .process(Some(&Packet { sig: vec![], msg: Message::NonceRequest(NonceRequest { dkg_id: id, @@ -1033,10 +1017,10 @@ pub mod test { sign_iter_id: id, signature_type: SignatureType::Frost, }), - }]) + })) .unwrap(); - assert!(packets.is_empty()); - assert!(results.is_empty()); + assert!(packet.is_none()); + assert!(result.is_none()); assert_eq!(coordinator.state, State::Idle); assert_eq!(coordinator.current_sign_id, id); } diff --git a/src/state_machine/coordinator/mod.rs b/src/state_machine/coordinator/mod.rs index 11f1b8d4..783d0634 100644 --- a/src/state_machine/coordinator/mod.rs +++ b/src/state_machine/coordinator/mod.rs @@ -261,11 +261,11 @@ pub trait Coordinator: Clone + Debug + PartialEq { /// Retrieve the config fn get_config(&self) -> Config; - /// Process inbound messages - fn process_inbound_messages( + /// Check for timeout and maybe process a single message + fn process( &mut self, - packets: &[Packet], - ) -> Result<(Vec, Vec), Error>; + packets: Option<&Packet>, + ) -> Result<(Option, Option), Error>; /// Retrieve the aggregate public key fn get_aggregate_public_key(&self) -> Option; @@ -568,18 +568,21 @@ pub mod test { } for coordinator in coordinators.iter_mut() { // Process all coordinator messages, but don't bother with propogating these results - let _ = coordinator.process_inbound_messages(messages).unwrap(); + for message in messages { + let _ = coordinator.process(Some(&message)).unwrap(); + } } let mut results = vec![]; let mut messages = vec![]; for (i, coordinator) in coordinators.iter_mut().enumerate() { - let (outbound_messages, outbound_results) = coordinator - .process_inbound_messages(&inbound_messages) - .unwrap(); - // Only propogate a single coordinator's messages and results - if i == 0 { - messages.extend(outbound_messages); - results.extend(outbound_results); + for inbound_message in &inbound_messages { + let (outbound_message, outbound_result) = + coordinator.process(Some(&inbound_message)).unwrap(); + // Only propogate a single coordinator's messages and results + if i == 0 { + messages.extend(outbound_message); + results.extend(outbound_result); + } } } (messages, results)