From e0d560d8c3249d0e89eb022dde037792079da3fa Mon Sep 17 00:00:00 2001 From: Shridhar Panigrahi Date: Sun, 12 Apr 2026 00:39:01 +0530 Subject: [PATCH] Implement key share refresh protocol for CGGMP24 Add key refresh protocol that updates ECDSA secret shares without changing the underlying shared secret key. Implements both non-threshold (Figure 7) and threshold (Appendix F.1.1) variants from the CGGMP24 paper, focusing only on share updates (aux data refresh is separate). Signed-off-by: Shridhar Panigrahi --- .../src/key_refresh_non_threshold.rs | 455 +++++++++++++++++ cggmp24-keygen/src/key_refresh_threshold.rs | 467 ++++++++++++++++++ cggmp24-keygen/src/lib.rs | 276 +++++++++++ cggmp24/src/lib.rs | 12 +- tests/tests/it/main.rs | 1 + tests/tests/it/share_refresh.rs | 199 ++++++++ 6 files changed, 1409 insertions(+), 1 deletion(-) create mode 100644 cggmp24-keygen/src/key_refresh_non_threshold.rs create mode 100644 cggmp24-keygen/src/key_refresh_threshold.rs create mode 100644 tests/tests/it/share_refresh.rs diff --git a/cggmp24-keygen/src/key_refresh_non_threshold.rs b/cggmp24-keygen/src/key_refresh_non_threshold.rs new file mode 100644 index 00000000..ded435c8 --- /dev/null +++ b/cggmp24-keygen/src/key_refresh_non_threshold.rs @@ -0,0 +1,455 @@ +use alloc::vec::Vec; + +use digest::Digest; +use generic_ec::{Curve, NonZero, Point, Scalar, SecretScalar}; +use generic_ec_zkp::schnorr_pok; +use rand_core::{CryptoRng, RngCore}; +use round_based::{ + rounds_router::simple_store::RoundInput, rounds_router::RoundsRouter, Delivery, Mpc, MpcParty, + Outgoing, ProtocolMessage, SinkExt, +}; +use serde::{Deserialize, Serialize}; + +use crate::progress::Tracer; +use crate::{ + errors::IoError, + key_share::{CoreKeyShare, DirtyCoreKeyShare, DirtyKeyInfo, Validate}, + security_level::SecurityLevel, + utils, ExecutionId, +}; + +use super::{Bug, KeyRefreshAborted, KeyRefreshError}; + +macro_rules! prefixed { + ($name:tt) => { + concat!("dfns.cggmp24.key_refresh.non_threshold.", $name) + }; +} + +/// Message of non-threshold key refresh protocol +#[derive(ProtocolMessage, Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub enum Msg { + /// Round 1 message + Round1(MsgRound1), + /// Reliability check message (optional additional round) + ReliabilityCheck(MsgReliabilityCheck), + /// Round 2 broadcast message + Round2Broad(MsgRound2Broad), + /// Round 2 unicast message + Round2Uni(MsgRound2Uni), + /// Round 3 message + Round3(MsgRound3), +} + +/// Message from round 1 +#[derive(Clone, Serialize, Deserialize, udigest::Digestable)] +#[serde(bound = "")] +#[udigest(bound = "")] +#[udigest(tag = prefixed!("round1"))] +pub struct MsgRound1 { + /// $V_i$ + #[udigest(as_bytes)] + pub commitment: digest::Output, +} +/// Message from round 2 broadcasted to everyone +#[serde_with::serde_as] +#[derive(Clone, Serialize, Deserialize, udigest::Digestable)] +#[serde(bound = "")] +#[udigest(bound = "")] +#[udigest(tag = prefixed!("round2_broad"))] +pub struct MsgRound2Broad { + /// `rid_i` + #[serde_as(as = "utils::HexOrBin")] + #[udigest(as_bytes)] + pub rid: L::KappaBytes, + /// $X_{i,j}$ — public commitments to per-party share updates + pub X_updates: Vec>, + /// $A_i$ + pub sch_commit: schnorr_pok::Commit, + /// $u_i$ + #[serde(with = "hex::serde")] + #[udigest(as_bytes)] + pub decommit: L::KappaBytes, +} +/// Message from round 2 unicasted to each party +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct MsgRound2Uni { + /// $x_{i,j}$ — secret share update for party $j$ + pub x_update: Scalar, +} +/// Message from round 3 +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct MsgRound3 { + /// $\psi_i$ + pub sch_proof: schnorr_pok::Proof, +} +/// Message parties exchange to ensure reliability of broadcast channel +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct MsgReliabilityCheck(pub digest::Output); + +mod unambiguous { + use crate::{ExecutionId, SecurityLevel}; + use generic_ec::Curve; + + #[derive(udigest::Digestable)] + #[udigest(tag = prefixed!("hash_commitment"))] + #[udigest(bound = "")] + pub struct HashCom<'a, E: Curve, L: SecurityLevel> { + pub sid: ExecutionId<'a>, + pub party_index: u16, + pub decommitment: &'a super::MsgRound2Broad, + } + + #[derive(udigest::Digestable)] + #[udigest(tag = prefixed!("schnorr_pok"))] + #[udigest(bound = "")] + pub struct SchnorrPok<'a, E: Curve> { + pub sid: ExecutionId<'a>, + pub prover: u16, + #[udigest(as_bytes)] + pub rid: &'a [u8], + pub X: &'a generic_ec::Point, + pub sch_commit: &'a generic_ec_zkp::schnorr_pok::Commit, + } + + #[derive(udigest::Digestable)] + #[udigest(tag = prefixed!("echo_round"))] + #[udigest(bound = "")] + pub struct Echo<'a, D: digest::Digest> { + pub sid: ExecutionId<'a>, + pub commitment: &'a super::MsgRound1, + } +} + +/// Run non-threshold key refresh protocol (Figure 7 of CGGMP24, share-update only) +/// +/// Each party i generates random values $x_{i,j}$ for all j with $\sum_j x_{i,j} = 0$, +/// sends $x_{i,j}$ privately to party j, and broadcasts commitments $X_{i,j} = g^{x_{i,j}}$. +/// Party j's new share is $x_j' = x_j + \sum_i x_{i,j}$. +/// The shared secret key is preserved since $\sum_j x_j' = \sum_j x_j + \sum_j \sum_i x_{i,j} = x$. +pub async fn run_key_refresh( + mut tracer: Option<&mut dyn Tracer>, + i: u16, + n: u16, + reliable_broadcast_enforced: bool, + sid: ExecutionId<'_>, + rng: &mut R, + party: M, + current_key_share: &CoreKeyShare, +) -> Result, KeyRefreshError> +where + E: Curve, + L: SecurityLevel, + D: Digest + Clone + 'static, + R: RngCore + CryptoRng, + M: Mpc>, +{ + tracer.protocol_begins(); + + tracer.stage("Setup networking"); + let MpcParty { delivery, .. } = party.into_party(); + let (incomings, mut outgoings) = delivery.split(); + + let mut rounds = RoundsRouter::>::builder(); + let round1 = rounds.add_round(RoundInput::>::broadcast(i, n)); + let round1_sync = rounds.add_round(RoundInput::>::broadcast(i, n)); + let round2_broad = rounds.add_round(RoundInput::>::broadcast(i, n)); + let round2_uni = rounds.add_round(RoundInput::>::p2p(i, n)); + let round3 = rounds.add_round(RoundInput::>::broadcast(i, n)); + let mut rounds = rounds.listen(incomings); + + // Round 1 + tracer.round_begins(); + + tracer.stage("Sample per-party share updates x_{i,j}"); + // Each party i picks random x_{i,j} for j != i, then sets x_{i,i} = -sum_{j!=i} x_{i,j} + // so that sum_j x_{i,j} = 0 + let mut x_updates = Vec::with_capacity(usize::from(n)); + let mut sum = Scalar::::zero(); + for j in 0..n { + if j == i { + x_updates.push(Scalar::::zero()); // placeholder + } else { + let val = Scalar::::random(rng); + sum = sum + val; + x_updates.push(val); + } + } + x_updates[usize::from(i)] = -sum; + + // Compute public commitments X_{i,j} = g^{x_{i,j}} + let X_updates: Vec> = x_updates + .iter() + .map(|x| Point::generator() * x) + .collect(); + + // The sum of X_{i,j} for my row should be the identity (zero point) + debug_assert!(X_updates.iter().copied().sum::>().is_zero()); + + // For the Schnorr proof, we prove knowledge of x_{i,i} (our self-update) + let my_x_self_update = x_updates[usize::from(i)]; + let my_X_self_update = X_updates[usize::from(i)]; + + let mut rid = L::KappaBytes::default(); + rng.fill_bytes(rid.as_mut()); + + tracer.stage("Sample schnorr commitment"); + let (sch_secret, sch_commit) = schnorr_pok::prover_commits_ephemeral_secret::(rng); + + tracer.stage("Commit to public data"); + let my_decommitment = MsgRound2Broad { + rid, + X_updates: X_updates.clone(), + sch_commit: sch_commit.clone(), + decommit: { + let mut nonce = L::KappaBytes::default(); + rng.fill_bytes(nonce.as_mut()); + nonce + }, + }; + let hash_commit = udigest::hash::(&unambiguous::HashCom { + sid, + party_index: i, + decommitment: &my_decommitment, + }); + let my_commitment = MsgRound1 { + commitment: hash_commit, + }; + + tracer.send_msg(); + outgoings + .send(Outgoing::broadcast(Msg::Round1(my_commitment.clone()))) + .await + .map_err(IoError::send_message)?; + tracer.msg_sent(); + + // Round 2 + tracer.round_begins(); + + tracer.receive_msgs(); + let commitments = rounds + .complete(round1) + .await + .map_err(IoError::receive_message)?; + tracer.msgs_received(); + + // Optional reliability check + if reliable_broadcast_enforced { + tracer.stage("Hash received msgs (reliability check)"); + let h_i = udigest::hash_iter::( + commitments + .iter_including_me(&my_commitment) + .map(|commitment| unambiguous::Echo { sid, commitment }), + ); + + tracer.send_msg(); + outgoings + .send(Outgoing::broadcast(Msg::ReliabilityCheck( + MsgReliabilityCheck(h_i.clone()), + ))) + .await + .map_err(IoError::send_message)?; + tracer.msg_sent(); + + tracer.round_begins(); + + tracer.receive_msgs(); + let round1_hashes = rounds + .complete(round1_sync) + .await + .map_err(IoError::receive_message)?; + tracer.msgs_received(); + + tracer.stage("Assert other parties hashed messages (reliability check)"); + let parties_have_different_hashes = round1_hashes + .into_iter_indexed() + .filter(|(_j, _msg_id, hash_j)| hash_j.0 != h_i) + .map(|(j, msg_id, _)| (j, msg_id)) + .collect::>(); + if !parties_have_different_hashes.is_empty() { + return Err( + KeyRefreshAborted::Round1NotReliable(parties_have_different_hashes).into(), + ); + } + } + + tracer.send_msg(); + outgoings + .feed(Outgoing::broadcast(Msg::Round2Broad( + my_decommitment.clone(), + ))) + .await + .map_err(IoError::send_message)?; + + // Send p2p share updates x_{i,j} to each party j + let messages = utils::iter_peers(i, n).map(|j| { + let message = MsgRound2Uni { + x_update: x_updates[usize::from(j)], + }; + Outgoing::p2p(j, Msg::Round2Uni(message)) + }); + outgoings + .send_all(&mut futures_util::stream::iter(messages.map(Ok))) + .await + .map_err(IoError::send_message)?; + tracer.msg_sent(); + + // Round 3 + tracer.round_begins(); + + tracer.receive_msgs(); + let decommitments = rounds + .complete(round2_broad) + .await + .map_err(IoError::receive_message)?; + let x_updates_received = rounds + .complete(round2_uni) + .await + .map_err(IoError::receive_message)?; + tracer.msgs_received(); + + tracer.stage("Validate decommitments"); + let blame = utils::collect_blame(&commitments, &decommitments, |j, com, decom| { + let com_expected = udigest::hash::(&unambiguous::HashCom { + sid, + party_index: j, + decommitment: decom, + }); + com.commitment != com_expected + }); + if !blame.is_empty() { + return Err(KeyRefreshAborted::InvalidDecommitment(blame).into()); + } + + tracer.stage("Validate data sizes"); + let blame = decommitments + .iter_indexed() + .filter(|(_, _, d)| d.X_updates.len() != usize::from(n)) + .map(|t| t.0) + .collect::>(); + if !blame.is_empty() { + return Err(KeyRefreshAborted::InvalidDataSize { parties: blame }.into()); + } + + tracer.stage("Verify row sums of X_{j,*} are zero (updates preserve public key)"); + let blame = decommitments + .iter_indexed() + .filter(|(_, _, d)| !d.X_updates.iter().copied().sum::>().is_zero()) + .map(|t| t.0) + .collect::>(); + if !blame.is_empty() { + return Err(KeyRefreshAborted::UpdatesDontSumToZero { parties: blame }.into()); + } + + tracer.stage("Verify received share updates against public commitments"); + // Each received x_{j,i} should satisfy g^{x_{j,i}} == X_{j,i} + let blame = decommitments + .iter_indexed() + .zip(x_updates_received.iter()) + .filter(|((_, _, d), msg)| { + d.X_updates[usize::from(i)] != Point::generator() * msg.x_update + }) + .map(|t| t.0 .0) + .collect::>(); + if !blame.is_empty() { + return Err(KeyRefreshAborted::ShareUpdateVerificationFailed { parties: blame }.into()); + } + + tracer.stage("Calculate challenge rid"); + let rid = decommitments + .iter_including_me(&my_decommitment) + .map(|d| &d.rid) + .fold(L::KappaBytes::default(), utils::xor_array); + let challenge = Scalar::from_hash::(&unambiguous::SchnorrPok { + sid, + prover: i, + rid: rid.as_ref(), + X: &my_X_self_update, + sch_commit: &sch_commit, + }); + let challenge = schnorr_pok::Challenge { nonce: challenge }; + + tracer.stage("Prove knowledge of self-update x_{i,i}"); + let sch_proof = schnorr_pok::prove(&sch_secret, &challenge, &my_x_self_update); + + tracer.send_msg(); + let my_sch_proof = MsgRound3 { sch_proof }; + outgoings + .send(Outgoing::broadcast(Msg::Round3(my_sch_proof.clone()))) + .await + .map_err(IoError::send_message)?; + tracer.msg_sent(); + + // Round 4 + tracer.round_begins(); + + tracer.receive_msgs(); + let sch_proofs = rounds + .complete(round3) + .await + .map_err(IoError::receive_message)?; + tracer.msgs_received(); + + tracer.stage("Validate schnorr proofs"); + let blame = utils::collect_blame(&decommitments, &sch_proofs, |j, decom, sch_proof| { + let X_j_self = decom.X_updates[usize::from(j)]; + let challenge = Scalar::from_hash::(&unambiguous::SchnorrPok { + sid, + prover: j, + rid: rid.as_ref(), + X: &X_j_self, + sch_commit: &decom.sch_commit, + }); + let challenge = schnorr_pok::Challenge { nonce: challenge }; + sch_proof + .sch_proof + .verify(&decom.sch_commit, &challenge, &X_j_self) + .is_err() + }); + if !blame.is_empty() { + return Err(KeyRefreshAborted::InvalidSchnorrProof(blame).into()); + } + + tracer.stage("Compute updated key share"); + // New secret share: x_i' = x_i + sum_j x_{j,i} + // = x_i + x_{i,i} + sum_{j!=i} x_{j,i} + let total_update: Scalar = x_updates_received.iter().map(|msg| msg.x_update).sum(); + let old_x: &Scalar = current_key_share.x.as_ref(); + let mut new_x = *old_x + my_x_self_update + total_update; + let new_x = + NonZero::from_secret_scalar(SecretScalar::new(&mut new_x)).ok_or(Bug::ZeroShare)?; + + // New public shares: X_j' = X_j + sum_i X_{i,j} + let new_public_shares = (0..n) + .map(|j| { + let j_idx = usize::from(j); + let old_X_j = current_key_share.public_shares[j_idx]; + let update_sum: Point = decommitments + .iter_including_me(&my_decommitment) + .map(|d| d.X_updates[j_idx]) + .sum(); + NonZero::from_point(*old_X_j + update_sum).ok_or(Bug::ZeroShare) + }) + .collect::, _>>()?; + + tracer.protocol_ends(); + + Ok(DirtyCoreKeyShare { + i, + key_info: DirtyKeyInfo { + curve: Default::default(), + shared_public_key: current_key_share.shared_public_key, + public_shares: new_public_shares, + vss_setup: None, + #[cfg(feature = "hd-wallet")] + chain_code: current_key_share.chain_code, + }, + x: new_x, + } + .validate() + .map_err(|e| Bug::InvalidKeyShare(e.into_error()))?) +} diff --git a/cggmp24-keygen/src/key_refresh_threshold.rs b/cggmp24-keygen/src/key_refresh_threshold.rs new file mode 100644 index 00000000..3022079d --- /dev/null +++ b/cggmp24-keygen/src/key_refresh_threshold.rs @@ -0,0 +1,467 @@ +use alloc::vec::Vec; + +use digest::Digest; +use generic_ec::{Curve, NonZero, Point, Scalar, SecretScalar}; +use generic_ec_zkp::{polynomial::Polynomial, schnorr_pok}; +use rand_core::{CryptoRng, RngCore}; +use round_based::{ + rounds_router::simple_store::RoundInput, rounds_router::RoundsRouter, Delivery, Mpc, MpcParty, + Outgoing, ProtocolMessage, SinkExt, +}; +use serde::{Deserialize, Serialize}; +use serde_with::serde_as; + +use crate::progress::Tracer; +use crate::{ + errors::IoError, + key_share::{CoreKeyShare, DirtyCoreKeyShare, DirtyKeyInfo, Validate, VssSetup}, + security_level::SecurityLevel, + utils, ExecutionId, +}; + +use super::{Bug, KeyRefreshAborted, KeyRefreshError}; + +macro_rules! prefixed { + ($name:tt) => { + concat!("dfns.cggmp24.key_refresh.threshold.", $name) + }; +} + +/// Message of threshold key refresh protocol +#[derive(ProtocolMessage, Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub enum Msg { + /// Round 1 message + Round1(MsgRound1), + /// Round 2a broadcast message + Round2Broad(MsgRound2Broad), + /// Round 2b unicast message + Round2Uni(MsgRound2Uni), + /// Round 3 message + Round3(MsgRound3), + /// Reliability check message (optional additional round) + ReliabilityCheck(MsgReliabilityCheck), +} + +/// Message from round 1 +#[derive(Clone, Serialize, Deserialize, udigest::Digestable)] +#[serde(bound = "")] +#[udigest(bound = "")] +#[udigest(tag = prefixed!("round1"))] +pub struct MsgRound1 { + /// $V_i$ + #[udigest(as_bytes)] + pub commitment: digest::Output, +} +/// Message from round 2 broadcasted to everyone +#[serde_as] +#[derive(Clone, Serialize, Deserialize, udigest::Digestable)] +#[serde(bound = "")] +#[udigest(bound = "")] +#[udigest(tag = prefixed!("round2_broad"))] +pub struct MsgRound2Broad { + /// `rid_i` + #[serde_as(as = "utils::HexOrBin")] + #[udigest(as_bytes)] + pub rid: L::KappaBytes, + /// $\vec F_i$ — commitment polynomial (public coefficients of update polynomial with zero constant term) + pub F: Polynomial>, + /// $A_i$ + pub sch_commit: schnorr_pok::Commit, + /// $u_i$ + #[serde(with = "hex::serde")] + #[udigest(as_bytes)] + pub decommit: L::KappaBytes, +} +/// Message from round 2 unicasted to each party +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct MsgRound2Uni { + /// $\sigma_{i,j}$ — evaluation of update polynomial at party j's index + pub sigma: Scalar, +} +/// Message from round 3 +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct MsgRound3 { + /// $\psi_i$ + pub sch_proof: schnorr_pok::Proof, +} +/// Message parties exchange to ensure reliability of broadcast channel +#[derive(Clone, Serialize, Deserialize)] +#[serde(bound = "")] +pub struct MsgReliabilityCheck(pub digest::Output); + +mod unambiguous { + use generic_ec::{Curve, Point}; + + use crate::{ExecutionId, SecurityLevel}; + + #[derive(udigest::Digestable)] + #[udigest(tag = prefixed!("hash_commitment"))] + #[udigest(bound = "")] + pub struct HashCom<'a, E: Curve, L: SecurityLevel> { + pub sid: ExecutionId<'a>, + pub party_index: u16, + pub decommitment: &'a super::MsgRound2Broad, + } + + #[derive(udigest::Digestable)] + #[udigest(tag = prefixed!("schnorr_pok"))] + #[udigest(bound = "")] + pub struct SchnorrPok<'a, E: Curve> { + pub sid: ExecutionId<'a>, + pub prover: u16, + #[udigest(as_bytes)] + pub rid: &'a [u8], + pub y: Point, + pub h: Point, + } + + #[derive(udigest::Digestable)] + #[udigest(tag = prefixed!("echo_round"))] + #[udigest(bound = "")] + pub struct Echo<'a, D: digest::Digest> { + pub sid: ExecutionId<'a>, + pub commitment: &'a super::MsgRound1, + } +} + +/// Run threshold key refresh protocol (CGGMP24 Appendix F.1.1, share-update only) +/// +/// Each party i generates a random degree-$(t-1)$ polynomial $f_i(x)$ with $f_i(0) = 0$, +/// shares evaluations $f_i(j)$ with each party j via Feldman VSS, and proves knowledge. +/// Party j's new share is $x_j' = x_j + \sum_i f_i(j)$. +/// The shared secret key is preserved since $F(0) = \sum_i f_i(0) = 0$. +pub async fn run_threshold_key_refresh( + mut tracer: Option<&mut dyn Tracer>, + i: u16, + t: u16, + n: u16, + reliable_broadcast_enforced: bool, + sid: ExecutionId<'_>, + rng: &mut R, + party: M, + current_key_share: &CoreKeyShare, +) -> Result, KeyRefreshError> +where + E: Curve, + L: SecurityLevel, + D: Digest + Clone + 'static, + R: RngCore + CryptoRng, + M: Mpc>, +{ + tracer.protocol_begins(); + + tracer.stage("Setup networking"); + let MpcParty { delivery, .. } = party.into_party(); + let (incomings, mut outgoings) = delivery.split(); + + let mut rounds = RoundsRouter::>::builder(); + let round1 = rounds.add_round(RoundInput::>::broadcast(i, n)); + let round1_sync = rounds.add_round(RoundInput::>::broadcast(i, n)); + let round2_broad = rounds.add_round(RoundInput::>::broadcast(i, n)); + let round2_uni = rounds.add_round(RoundInput::>::p2p(i, n)); + let round3 = rounds.add_round(RoundInput::>::broadcast(i, n)); + let mut rounds = rounds.listen(incomings); + + // Round 1 + tracer.round_begins(); + + tracer.stage("Sample update polynomial with zero constant term, rid_i, schnorr commitment"); + let mut rid = L::KappaBytes::default(); + rng.fill_bytes(rid.as_mut()); + + let (r, h) = schnorr_pok::prover_commits_ephemeral_secret::(rng); + + // Generate degree-(t-1) polynomial f_i with f_i(0) = 0 + // That means the constant term is zero, and we sample t-1 random coefficients + let f = Polynomial::>::sample_with_const_term( + rng, + usize::from(t) - 1, + SecretScalar::new(&mut Scalar::::zero()), + ); + let F = &f * &Point::generator(); + + // Verify constant term is zero (public commitment) + debug_assert!(F.coefs()[0].is_zero()); + + // Evaluate f_i at each party's index + let sigmas: Vec> = (0..n) + .map(|j| { + let x = Scalar::from(j + 1); + f.value(&x) + }) + .collect(); + + tracer.stage("Commit to public data"); + let my_decommitment = MsgRound2Broad { + rid, + F: F.clone(), + sch_commit: h, + decommit: { + let mut nonce = L::KappaBytes::default(); + rng.fill_bytes(nonce.as_mut()); + nonce + }, + }; + let hash_commit = udigest::hash::(&unambiguous::HashCom { + sid, + party_index: i, + decommitment: &my_decommitment, + }); + + tracer.send_msg(); + let my_commitment = MsgRound1 { + commitment: hash_commit, + }; + outgoings + .send(Outgoing::broadcast(Msg::Round1(my_commitment.clone()))) + .await + .map_err(IoError::send_message)?; + tracer.msg_sent(); + + // Round 2 + tracer.round_begins(); + + tracer.receive_msgs(); + let commitments = rounds + .complete(round1) + .await + .map_err(IoError::receive_message)?; + tracer.msgs_received(); + + // Optional reliability check + if reliable_broadcast_enforced { + tracer.stage("Hash received msgs (reliability check)"); + let h_i = udigest::hash_iter::( + commitments + .iter_including_me(&my_commitment) + .map(|commitment| unambiguous::Echo { sid, commitment }), + ); + + tracer.send_msg(); + outgoings + .send(Outgoing::broadcast(Msg::ReliabilityCheck( + MsgReliabilityCheck(h_i.clone()), + ))) + .await + .map_err(IoError::send_message)?; + tracer.msg_sent(); + + tracer.round_begins(); + + tracer.receive_msgs(); + let hashes = rounds + .complete(round1_sync) + .await + .map_err(IoError::receive_message)?; + tracer.msgs_received(); + + tracer.stage("Assert other parties hashed messages (reliability check)"); + let parties_have_different_hashes = hashes + .into_iter_indexed() + .filter(|(_j, _msg_id, h_j)| h_i != h_j.0) + .map(|(j, msg_id, _)| (j, msg_id)) + .collect::>(); + if !parties_have_different_hashes.is_empty() { + return Err( + KeyRefreshAborted::Round1NotReliable(parties_have_different_hashes).into(), + ); + } + } + + tracer.send_msg(); + outgoings + .feed(Outgoing::broadcast(Msg::Round2Broad( + my_decommitment.clone(), + ))) + .await + .map_err(IoError::send_message)?; + + let messages = utils::iter_peers(i, n).map(|j| { + let message = MsgRound2Uni { + sigma: sigmas[usize::from(j)], + }; + Outgoing::p2p(j, Msg::Round2Uni(message)) + }); + outgoings + .send_all(&mut futures_util::stream::iter(messages.map(Ok))) + .await + .map_err(IoError::send_message)?; + tracer.msg_sent(); + + // Round 3 + tracer.round_begins(); + + tracer.receive_msgs(); + let decommitments = rounds + .complete(round2_broad) + .await + .map_err(IoError::receive_message)?; + let sigmas_msg = rounds + .complete(round2_uni) + .await + .map_err(IoError::receive_message)?; + tracer.msgs_received(); + + tracer.stage("Validate decommitments"); + let blame = utils::collect_blame(&commitments, &decommitments, |j, com, decom| { + let com_expected = udigest::hash::(&unambiguous::HashCom { + sid, + party_index: j, + decommitment: decom, + }); + com.commitment != com_expected + }); + if !blame.is_empty() { + return Err(KeyRefreshAborted::InvalidDecommitment(blame).into()); + } + + tracer.stage("Validate polynomial degrees"); + let blame = decommitments + .iter_indexed() + .filter(|(_, _, d)| d.F.degree() + 1 != usize::from(t)) + .map(|t| t.0) + .collect::>(); + if !blame.is_empty() { + return Err(KeyRefreshAborted::InvalidDataSize { parties: blame }.into()); + } + + tracer.stage("Verify constant terms are zero (updates must preserve secret key)"); + let blame = decommitments + .iter_indexed() + .filter(|(_, _, d)| !d.F.coefs()[0].is_zero()) + .map(|t| t.0) + .collect::>(); + if !blame.is_empty() { + return Err(KeyRefreshAborted::NonZeroConstantTerm { parties: blame }.into()); + } + + tracer.stage("Validate Feldman VSS for update polynomials"); + let blame = decommitments + .iter_indexed() + .zip(sigmas_msg.iter()) + .filter(|((_, _, d), s)| { + d.F.value::<_, Point<_>>(&Scalar::from(i + 1)) != Point::generator() * s.sigma + }) + .map(|t| t.0 .0) + .collect::>(); + if !blame.is_empty() { + return Err(KeyRefreshAborted::FeldmanVerificationFailed { parties: blame }.into()); + } + + tracer.stage("Compute rid"); + let rid = decommitments + .iter_including_me(&my_decommitment) + .map(|d| &d.rid) + .fold(L::KappaBytes::default(), utils::xor_array); + + tracer.stage("Compute updated public shares"); + // The sum of update polynomials (public): F_sum = sum_i F_i + let polynomial_update_sum = decommitments + .iter_including_me(&my_decommitment) + .map(|d| &d.F) + .sum::>(); + + // New public shares: Y_j' = Y_j + F_sum(j+1) + let new_public_shares = (0..n) + .map(|l| { + let old_Y_l = current_key_share.public_shares[usize::from(l)]; + let update = polynomial_update_sum.value::<_, Point>(&Scalar::from(l + 1)); + NonZero::from_point(*old_Y_l + update).ok_or(Bug::ZeroShare) + }) + .collect::, _>>()?; + + tracer.stage("Compute updated secret share"); + // sigma_update_i = sum_j sigma_{j,i} (received from others) + sigma_{i,i} (own) + let sigma_update: Scalar = sigmas_msg.iter().map(|msg| msg.sigma).sum(); + let mut new_x = + { + let old_x: &Scalar = current_key_share.x.as_ref(); + *old_x + sigma_update + sigmas[usize::from(i)] + }; + let new_x = + NonZero::from_secret_scalar(SecretScalar::new(&mut new_x)).ok_or(Bug::ZeroShare)?; + debug_assert_eq!( + Point::generator() * &new_x, + new_public_shares[usize::from(i)] + ); + + tracer.stage("Calculate challenge"); + let challenge = Scalar::from_hash::(&unambiguous::SchnorrPok { + sid, + prover: i, + rid: rid.as_ref(), + y: *new_public_shares[usize::from(i)], + h: my_decommitment.sch_commit.0, + }); + let challenge = schnorr_pok::Challenge { nonce: challenge }; + + tracer.stage("Prove knowledge of updated share"); + let z = schnorr_pok::prove(&r, &challenge, &new_x); + + tracer.send_msg(); + let my_sch_proof = MsgRound3 { sch_proof: z }; + outgoings + .send(Outgoing::broadcast(Msg::Round3(my_sch_proof.clone()))) + .await + .map_err(IoError::send_message)?; + tracer.msg_sent(); + + // Output round + tracer.round_begins(); + + tracer.receive_msgs(); + let sch_proofs = rounds + .complete(round3) + .await + .map_err(IoError::receive_message)?; + tracer.msgs_received(); + + tracer.stage("Validate schnorr proofs"); + let blame = utils::collect_blame(&decommitments, &sch_proofs, |j, decom, sch_proof| { + let challenge = Scalar::from_hash::(&unambiguous::SchnorrPok { + sid, + prover: j, + rid: rid.as_ref(), + y: *new_public_shares[usize::from(j)], + h: decom.sch_commit.0, + }); + let challenge = schnorr_pok::Challenge { nonce: challenge }; + let y_j: Point = *new_public_shares[usize::from(j)]; + sch_proof + .sch_proof + .verify(&decom.sch_commit, &challenge, &y_j) + .is_err() + }); + if !blame.is_empty() { + return Err(KeyRefreshAborted::InvalidSchnorrProof(blame).into()); + } + + tracer.stage("Construct updated key share"); + let key_shares_indexes = (1..=n) + .map(|idx| NonZero::from_scalar(Scalar::from(idx))) + .collect::>>() + .ok_or(Bug::NonZeroScalar)?; + + tracer.protocol_ends(); + + Ok(DirtyCoreKeyShare { + i, + key_info: DirtyKeyInfo { + curve: Default::default(), + shared_public_key: current_key_share.shared_public_key, + public_shares: new_public_shares, + vss_setup: Some(VssSetup { + min_signers: t, + I: key_shares_indexes, + }), + #[cfg(feature = "hd-wallet")] + chain_code: current_key_share.chain_code, + }, + x: new_x, + } + .validate() + .map_err(|err| Bug::InvalidKeyShare(err.into_error()))?) +} diff --git a/cggmp24-keygen/src/lib.rs b/cggmp24-keygen/src/lib.rs index c834ccf5..cdc82f33 100644 --- a/cggmp24-keygen/src/lib.rs +++ b/cggmp24-keygen/src/lib.rs @@ -21,6 +21,11 @@ mod non_threshold; /// Threshold DKG specific types mod threshold; +/// Non-threshold key refresh specific types +mod key_refresh_non_threshold; +/// Threshold key refresh specific types +mod key_refresh_threshold; + mod errors; mod execution_id; mod utils; @@ -45,6 +50,11 @@ use crate::{ pub use self::execution_id::ExecutionId; #[doc(no_inline)] pub use self::msg::{non_threshold::Msg as NonThresholdMsg, threshold::Msg as ThresholdMsg}; +#[doc(no_inline)] +pub use self::msg::key_refresh::{ + NonThresholdMsg as KeyRefreshNonThresholdMsg, + ThresholdMsg as KeyRefreshThresholdMsg, +}; /// Defines default choice for digest and security level used across the crate mod default_choice { @@ -64,6 +74,23 @@ pub mod msg { Msg, MsgReliabilityCheck, MsgRound1, MsgRound2Broad, MsgRound2Uni, MsgRound3, }; } + /// Messages types related to key refresh protocols + pub mod key_refresh { + /// Messages for non-threshold key refresh + pub mod non_threshold { + pub use crate::key_refresh_non_threshold::{ + Msg, MsgReliabilityCheck, MsgRound1, MsgRound2Broad, MsgRound2Uni, MsgRound3, + }; + } + /// Messages for threshold key refresh + pub mod threshold { + pub use crate::key_refresh_threshold::{ + Msg, MsgReliabilityCheck, MsgRound1, MsgRound2Broad, MsgRound2Uni, MsgRound3, + }; + } + pub use non_threshold::Msg as NonThresholdMsg; + pub use threshold::Msg as ThresholdMsg; + } } /// Key generation entry point. You can call [`set_threshold`] to make it into a @@ -372,3 +399,252 @@ enum Bug { pub fn keygen(eid: ExecutionId, i: u16, n: u16) -> KeygenBuilder { KeygenBuilder::new(eid, i, n) } + +/// Key refresh entry point for non-threshold keys +pub type KeyRefreshBuilder< + 'a, + E, + L = crate::default_choice::SecurityLevel, + D = crate::default_choice::Digest, +> = GenericKeyRefreshBuilder<'a, E, NonThreshold, L, D>; + +/// Key refresh entry point for threshold keys +pub type ThresholdKeyRefreshBuilder< + 'a, + E, + L = crate::default_choice::SecurityLevel, + D = crate::default_choice::Digest, +> = GenericKeyRefreshBuilder<'a, E, WithThreshold, L, D>; + +/// Key refresh builder with choice for threshold or non-threshold variant +pub struct GenericKeyRefreshBuilder<'a, E: Curve, M, L: SecurityLevel, D: Digest> { + i: u16, + n: u16, + reliable_broadcast_enforced: bool, + optional_t: M, + execution_id: ExecutionId<'a>, + tracer: Option<&'a mut dyn Tracer>, + current_key_share: &'a CoreKeyShare, + _params: core::marker::PhantomData<(E, L, D)>, +} + +impl<'a, E, L, D> GenericKeyRefreshBuilder<'a, E, NonThreshold, L, D> +where + E: Curve, + L: SecurityLevel, + D: Digest + Clone + 'static, +{ + /// Constructs [KeyRefreshBuilder] + /// + /// Takes local party index $i$, number of parties $n$, and the current key share + pub fn new(eid: ExecutionId<'a>, i: u16, n: u16, key_share: &'a CoreKeyShare) -> Self { + Self { + i, + n, + optional_t: NonThreshold, + reliable_broadcast_enforced: true, + execution_id: eid, + tracer: None, + current_key_share: key_share, + _params: core::marker::PhantomData, + } + } +} + +impl<'a, E, L, D, M> GenericKeyRefreshBuilder<'a, E, M, L, D> +where + E: Curve, + L: SecurityLevel, + D: Digest + Clone + 'static, +{ + /// Specifies to refresh key shares for a threshold scheme + pub fn set_threshold(self, t: u16) -> GenericKeyRefreshBuilder<'a, E, WithThreshold, L, D> { + GenericKeyRefreshBuilder { + i: self.i, + n: self.n, + optional_t: WithThreshold(t), + reliable_broadcast_enforced: self.reliable_broadcast_enforced, + execution_id: self.execution_id, + tracer: self.tracer, + current_key_share: self.current_key_share, + _params: core::marker::PhantomData, + } + } + + /// Specifies another hash function to use + pub fn set_digest(self) -> GenericKeyRefreshBuilder<'a, E, M, L, D2> + where + D2: Digest + Clone + 'static, + { + GenericKeyRefreshBuilder { + i: self.i, + n: self.n, + optional_t: self.optional_t, + reliable_broadcast_enforced: self.reliable_broadcast_enforced, + execution_id: self.execution_id, + tracer: self.tracer, + current_key_share: self.current_key_share, + _params: core::marker::PhantomData, + } + } + + /// Specifies [security level](crate::security_level) + pub fn set_security_level(self) -> GenericKeyRefreshBuilder<'a, E, M, L2, D> + where + L2: SecurityLevel, + { + GenericKeyRefreshBuilder { + i: self.i, + n: self.n, + optional_t: self.optional_t, + reliable_broadcast_enforced: self.reliable_broadcast_enforced, + execution_id: self.execution_id, + tracer: self.tracer, + current_key_share: self.current_key_share, + _params: core::marker::PhantomData, + } + } + + /// Sets a tracer that tracks progress of protocol execution + pub fn set_progress_tracer(mut self, tracer: &'a mut dyn Tracer) -> Self { + self.tracer = Some(tracer); + self + } + + #[doc = include_str!("../docs/enforce_reliable_broadcast.md")] + pub fn enforce_reliable_broadcast(self, enforce: bool) -> Self { + Self { + reliable_broadcast_enforced: enforce, + ..self + } + } +} + +impl<'a, E, L, D> GenericKeyRefreshBuilder<'a, E, NonThreshold, L, D> +where + E: Curve, + L: SecurityLevel, + D: Digest + Clone + 'static, +{ + /// Starts non-threshold key refresh + pub async fn start( + self, + rng: &mut R, + party: M, + ) -> Result, KeyRefreshError> + where + R: RngCore + CryptoRng, + M: Mpc>, + { + key_refresh_non_threshold::run_key_refresh( + self.tracer, + self.i, + self.n, + self.reliable_broadcast_enforced, + self.execution_id, + rng, + party, + self.current_key_share, + ) + .await + } +} + +impl<'a, E, L, D> GenericKeyRefreshBuilder<'a, E, WithThreshold, L, D> +where + E: Curve, + L: SecurityLevel, + D: Digest + Clone + 'static, +{ + /// Starts threshold key refresh + pub async fn start( + self, + rng: &mut R, + party: M, + ) -> Result, KeyRefreshError> + where + R: RngCore + CryptoRng, + M: Mpc>, + { + key_refresh_threshold::run_threshold_key_refresh( + self.tracer, + self.i, + self.optional_t.0, + self.n, + self.reliable_broadcast_enforced, + self.execution_id, + rng, + party, + self.current_key_share, + ) + .await + } +} + +/// Key refresh protocol error +#[derive(Debug, displaydoc::Display)] +#[cfg_attr(feature = "std", derive(thiserror::Error))] +#[displaydoc("key refresh protocol failed to complete")] +pub struct KeyRefreshError(#[cfg_attr(feature = "std", source)] KeyRefreshReason); + +crate::errors::impl_from! { + impl From for KeyRefreshError { + err: KeyRefreshAborted => KeyRefreshError(KeyRefreshReason::Aborted(err)), + err: IoError => KeyRefreshError(KeyRefreshReason::IoError(err)), + err: Bug => KeyRefreshError(KeyRefreshReason::Bug(err)), + } +} + +#[derive(Debug, displaydoc::Display)] +#[cfg_attr(feature = "std", derive(thiserror::Error))] +enum KeyRefreshReason { + /// Protocol was maliciously aborted by another party + #[displaydoc("protocol was aborted by malicious party")] + Aborted(#[cfg_attr(feature = "std", source)] KeyRefreshAborted), + #[displaydoc("i/o error")] + IoError(#[cfg_attr(feature = "std", source)] IoError), + /// Bug occurred + #[displaydoc("bug occurred")] + Bug(Bug), +} + +impl From for KeyRefreshReason { + fn from(err: KeyRefreshAborted) -> Self { + KeyRefreshReason::Aborted(err) + } +} + +/// Error indicating that key refresh was aborted by malicious party +#[derive(Debug, displaydoc::Display)] +#[cfg_attr(feature = "std", derive(thiserror::Error))] +enum KeyRefreshAborted { + #[displaydoc("party decommitment doesn't match commitment: {0:?}")] + InvalidDecommitment(Vec), + #[displaydoc("party provided invalid schnorr proof: {0:?}")] + InvalidSchnorrProof(Vec), + #[displaydoc("party secret share update is not consistent: {parties:?}")] + FeldmanVerificationFailed { parties: Vec }, + #[displaydoc("party data size is not suitable for parameters: {parties:?}")] + InvalidDataSize { parties: Vec }, + #[displaydoc("party update polynomial has non-zero constant term: {parties:?}")] + NonZeroConstantTerm { parties: Vec }, + #[displaydoc("party share update doesn't match public commitment: {parties:?}")] + ShareUpdateVerificationFailed { parties: Vec }, + #[displaydoc("party update public keys don't sum to zero: {parties:?}")] + UpdatesDontSumToZero { parties: Vec }, + #[displaydoc("round1 wasn't reliable")] + Round1NotReliable(Vec<(PartyIndex, MsgId)>), +} + +/// Key share refresh protocol for non-threshold keys +/// +/// Each party of the protocol should have uniquely assigned index $i$ such that $0 \le i < n$. +/// All parties must provide their current key share from a previous keygen or key refresh. +pub fn key_refresh<'a, E: Curve>( + eid: ExecutionId<'a>, + i: u16, + n: u16, + key_share: &'a CoreKeyShare, +) -> KeyRefreshBuilder<'a, E> { + KeyRefreshBuilder::new(eid, i, n, key_share) +} diff --git a/cggmp24/src/lib.rs b/cggmp24/src/lib.rs index 68f638f4..369c8890 100644 --- a/cggmp24/src/lib.rs +++ b/cggmp24/src/lib.rs @@ -25,7 +25,6 @@ //! * [Trusted dealer](crate::trusted_dealer) (importing key into TSS) //! //! This crate **does not** (currently) support: -//! * Key refresh for both threshold (i.e., t-out-of-n) and non-threshold (i.e., n-out-of-n) keys //! * Identifiable abort //! //! Our implementation has been audited by Kudelski. Report can be found [here][report]. @@ -334,6 +333,8 @@ pub use { #[doc(inline)] pub use cggmp24_keygen::{keygen, progress, ExecutionId}; +/// Re-export of the key share refresh entry point from cggmp24-keygen +pub use cggmp24_keygen::key_refresh as share_refresh; use generic_ec::{coords::HasAffineX, Curve, Point}; use round_based::PartyIndex; @@ -373,6 +374,15 @@ pub mod keygen { pub use msg::non_threshold::Msg as NonThresholdMsg; pub use msg::threshold::Msg as ThresholdMsg; + + /// Key share refresh protocols + pub mod share_refresh { + pub use cggmp24_keygen::{ + key_refresh, GenericKeyRefreshBuilder, KeyRefreshBuilder, KeyRefreshError, + ThresholdKeyRefreshBuilder, + }; + pub use cggmp24_keygen::msg::key_refresh as msg; + } } pub use self::{ diff --git a/tests/tests/it/main.rs b/tests/tests/it/main.rs index aa357215..981e6824 100644 --- a/tests/tests/it/main.rs +++ b/tests/tests/it/main.rs @@ -2,6 +2,7 @@ mod key_refresh; mod keygen; mod old_shares; mod pipeline; +mod share_refresh; mod signing; mod stark_prehashed; mod trusted_dealer; diff --git a/tests/tests/it/share_refresh.rs b/tests/tests/it/share_refresh.rs new file mode 100644 index 00000000..d458166f --- /dev/null +++ b/tests/tests/it/share_refresh.rs @@ -0,0 +1,199 @@ +use generic_ec::{Curve, Point}; +use rand::{seq::SliceRandom, Rng}; +use rand_dev::DevRng; + +use cggmp24::{key_share::reconstruct_secret_key, ExecutionId}; + +cggmp24_tests::test_suite! { + test: non_threshold_share_refresh_works, + generics: all_curves, + suites: { + n3: (3, false), + n5: (5, false), + n5_reliable: (5, true), + } +} +fn non_threshold_share_refresh_works(n: u16, reliable_broadcast: bool) +where + E: Curve + cggmp24_tests::CurveParams, +{ + let mut rng = DevRng::new(); + + // First, generate key shares via keygen + let eid: [u8; 32] = rng.gen(); + let eid = ExecutionId::new(&eid); + + let key_shares = round_based::sim::run(n, |i, party| { + let party = cggmp24_tests::buffer_outgoing(party); + let mut party_rng = rng.fork(); + + async move { + cggmp24::keygen::(eid, i, n) + .set_security_level::() + .set_digest::() + .enforce_reliable_broadcast(false) + .start(&mut party_rng, party) + .await + } + }) + .unwrap() + .expect_ok() + .into_vec(); + + let original_pk = key_shares[0].shared_public_key; + let original_sk = reconstruct_secret_key(&key_shares).unwrap(); + + // Now run key share refresh + let eid: [u8; 32] = rng.gen(); + let eid = ExecutionId::new(&eid); + + let refreshed_shares = round_based::sim::run_with_setup( + key_shares.iter(), + |i, party, key_share| { + let party = cggmp24_tests::buffer_outgoing(party); + let mut party_rng = rng.fork(); + async move { + cggmp24::share_refresh::(eid, i, n, key_share) + .set_security_level::() + .set_digest::() + .enforce_reliable_broadcast(reliable_broadcast) + .start(&mut party_rng, party) + .await + } + }, + ) + .unwrap() + .expect_ok() + .into_vec(); + + // Verify: public key is preserved + for share in &refreshed_shares { + assert_eq!(share.shared_public_key, original_pk); + } + + // Verify: each party's public share matches their secret share + for (i, share) in refreshed_shares.iter().enumerate() { + assert_eq!( + Point::::generator() * &share.x, + share.public_shares[i] + ); + } + + // Verify: secret shares changed (with overwhelming probability) + let shares_changed = key_shares + .iter() + .zip(&refreshed_shares) + .any(|(old, new)| Point::::generator() * &old.x != Point::::generator() * &new.x); + assert!(shares_changed, "secret shares should change after refresh"); + + // Verify: reconstructed secret key is the same + let refreshed_sk = reconstruct_secret_key(&refreshed_shares).unwrap(); + assert_eq!( + Point::::generator() * &original_sk, + Point::::generator() * &refreshed_sk, + "reconstructed secret key should be preserved" + ); +} + +cggmp24_tests::test_suite! { + test: threshold_share_refresh_works, + generics: all_curves, + suites: { + t2n3: (2, 3, false), + t3n5: (3, 5, false), + t3n5_reliable: (3, 5, true), + } +} +fn threshold_share_refresh_works(t: u16, n: u16, reliable_broadcast: bool) +where + E: Curve + cggmp24_tests::CurveParams, +{ + let mut rng = DevRng::new(); + + // First, generate threshold key shares via keygen + let eid: [u8; 32] = rng.gen(); + let eid = ExecutionId::new(&eid); + + let key_shares = round_based::sim::run(n, |i, party| { + let party = cggmp24_tests::buffer_outgoing(party); + let mut party_rng = rng.fork(); + + async move { + cggmp24::keygen::(eid, i, n) + .set_security_level::() + .set_digest::() + .enforce_reliable_broadcast(false) + .set_threshold(t) + .start(&mut party_rng, party) + .await + } + }) + .unwrap() + .expect_ok() + .into_vec(); + + let original_pk = key_shares[0].shared_public_key; + + // Reconstruct secret key from t shares + let t_shares: Vec<_> = key_shares + .choose_multiple(&mut rng, t.into()) + .cloned() + .collect(); + let original_sk = reconstruct_secret_key(&t_shares).unwrap(); + + // Now run threshold key share refresh + let eid: [u8; 32] = rng.gen(); + let eid = ExecutionId::new(&eid); + + let refreshed_shares = round_based::sim::run_with_setup( + key_shares.iter(), + |i, party, key_share| { + let party = cggmp24_tests::buffer_outgoing(party); + let mut party_rng = rng.fork(); + async move { + cggmp24::share_refresh::(eid, i, n, key_share) + .set_security_level::() + .set_digest::() + .enforce_reliable_broadcast(reliable_broadcast) + .set_threshold(t) + .start(&mut party_rng, party) + .await + } + }, + ) + .unwrap() + .expect_ok() + .into_vec(); + + // Verify: public key is preserved + for share in &refreshed_shares { + assert_eq!(share.shared_public_key, original_pk); + } + + // Verify: each party's public share matches their secret share + for (i, share) in refreshed_shares.iter().enumerate() { + assert_eq!( + Point::::generator() * &share.x, + share.public_shares[i] + ); + } + + // Verify: secret shares changed + let shares_changed = key_shares + .iter() + .zip(&refreshed_shares) + .any(|(old, new)| Point::::generator() * &old.x != Point::::generator() * &new.x); + assert!(shares_changed, "secret shares should change after refresh"); + + // Verify: reconstructed secret key is the same (using t random shares) + let t_refreshed: Vec<_> = refreshed_shares + .choose_multiple(&mut rng, t.into()) + .cloned() + .collect(); + let refreshed_sk = reconstruct_secret_key(&t_refreshed).unwrap(); + assert_eq!( + Point::::generator() * &original_sk, + Point::::generator() * &refreshed_sk, + "reconstructed secret key should be preserved after threshold refresh" + ); +}