diff --git a/Cargo.lock b/Cargo.lock index 5c5d7fa2b..299169db4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1281,6 +1281,15 @@ dependencies = [ "serde", ] +[[package]] +name = "built" +version = "0.7.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56ed6191a7e78c36abdb16ab65341eefd73d64d303fffccdbb00d51e4205967b" +dependencies = [ + "cargo-lock", +] + [[package]] name = "bumpalo" version = "3.19.1" @@ -1380,6 +1389,18 @@ dependencies = [ "serde_yaml_ng", ] +[[package]] +name = "cargo-lock" +version = "10.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06acb4f71407ba205a07cb453211e0e6a67b21904e47f6ba1f9589e38f2e454" +dependencies = [ + "semver", + "serde", + "toml 0.8.23", + "url", +] + [[package]] name = "cargo_toml" version = "0.19.2" @@ -4526,6 +4547,21 @@ dependencies = [ "tracing", ] +[[package]] +name = "iroh-fake-store" +version = "0.1.1" +source = "git+https://github.com/IAvecilla/iroh-fake-store?branch=fake-store-update#73ca4293e23dd048890a9d269f6f1baf4ab5db21" +dependencies = [ + "anyhow", + "bao-tree", + "bytes", + "iroh-blobs", + "irpc", + "range-collections", + "ref-cast", + "tokio", +] + [[package]] name = "iroh-gossip" version = "0.96.0" @@ -4622,50 +4658,6 @@ dependencies = [ "syn 2.0.115", ] -[[package]] -name = "iroh-n0des" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c953c0ecee4e35043433855a06f7358430c6bd9bf0adb088b8198d14c2606095" -dependencies = [ - "anyhow", - "bytes", - "derive_more", - "ed25519-dalek 3.0.0-pre.1", - "futures-buffered", - "getrandom 0.3.4", - "iroh", - "iroh-metrics 0.38.2", - "iroh-n0des-macro", - "iroh-tickets", - "irpc", - "irpc-iroh", - "n0-error", - "n0-future", - "postcard", - "rand 0.9.2", - "rcan", - "serde", - "serde_json", - "strum 0.27.2", - "thiserror 2.0.18", - "tokio", - "tracing", - "tracing-subscriber", - "uuid", -] - -[[package]] -name = "iroh-n0des-macro" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e15d38b6ae3d9480e49883bea72880f80d595276e34090f5096d844e6f7f5e40" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - [[package]] name = "iroh-quinn" version = "0.16.1" @@ -4774,6 +4766,40 @@ dependencies = [ "z32", ] +[[package]] +name = "iroh-services" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80e90cec7e16813b8b29eba1d955cccc2505e29c34ceed76722d3857bf2c9072" +dependencies = [ + "anyhow", + "built", + "bytes", + "derive_more", + "ed25519-dalek 3.0.0-pre.1", + "futures-buffered", + "getrandom 0.3.4", + "iroh", + "iroh-metrics 0.38.2", + "iroh-tickets", + "irpc", + "irpc-iroh", + "n0-error", + "n0-future", + "portmapper", + "postcard", + "rand 0.9.2", + "rcan", + "serde", + "serde_json", + "strum 0.27.2", + "thiserror 2.0.18", + "tokio", + "tracing", + "tracing-subscriber", + "uuid", +] + [[package]] name = "iroh-tickets" version = "0.3.0" @@ -7342,6 +7368,7 @@ dependencies = [ "psyche-metrics", "psyche-network", "pyo3", + "reqwest 0.12.28", "serde", "serde_json", "tikv-jemallocator", @@ -7416,9 +7443,10 @@ dependencies = [ "get_if_addrs", "iroh", "iroh-blobs", + "iroh-fake-store", "iroh-gossip", - "iroh-n0des", "iroh-relay", + "iroh-services", "n0-future", "postcard", "psyche-core", @@ -7708,7 +7736,7 @@ dependencies = [ [[package]] name = "pyo3-tch" version = "0.22.0" -source = "git+https://github.com/jquesnelle/tch-rs.git?rev=dda507e05a776547a112b6854d1e611684f8c729#dda507e05a776547a112b6854d1e611684f8c729" +source = "git+https://github.com/NousResearch/tch-rs.git?rev=dda507e05a776547a112b6854d1e611684f8c729#dda507e05a776547a112b6854d1e611684f8c729" dependencies = [ "pyo3", "tch", @@ -8726,6 +8754,10 @@ name = "semver" version = "1.0.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" +dependencies = [ + "serde", + "serde_core", +] [[package]] name = "send_wrapper" @@ -12060,7 +12092,7 @@ dependencies = [ [[package]] name = "tch" version = "0.22.0" -source = "git+https://github.com/jquesnelle/tch-rs.git?rev=dda507e05a776547a112b6854d1e611684f8c729#dda507e05a776547a112b6854d1e611684f8c729" +source = "git+https://github.com/NousResearch/tch-rs.git?rev=dda507e05a776547a112b6854d1e611684f8c729#dda507e05a776547a112b6854d1e611684f8c729" dependencies = [ "half", "lazy_static", @@ -12558,7 +12590,7 @@ dependencies = [ [[package]] name = "torch-sys" version = "0.22.0" -source = "git+https://github.com/jquesnelle/tch-rs.git?rev=dda507e05a776547a112b6854d1e611684f8c729#dda507e05a776547a112b6854d1e611684f8c729" +source = "git+https://github.com/NousResearch/tch-rs.git?rev=dda507e05a776547a112b6854d1e611684f8c729#dda507e05a776547a112b6854d1e611684f8c729" dependencies = [ "anyhow", "cc", diff --git a/Cargo.toml b/Cargo.toml index fbf803459..a63967dbe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,9 +82,9 @@ indicatif = "0.17.5" tokenizers = { version = "0.20.0", default-features = false, features = [ "onig", ] } -tch = { git = "https://github.com/jquesnelle/tch-rs.git", rev = "dda507e05a776547a112b6854d1e611684f8c729" } -torch-sys = { git = "https://github.com/jquesnelle/tch-rs.git", rev = "dda507e05a776547a112b6854d1e611684f8c729" } -pyo3-tch = { git = "https://github.com/jquesnelle/tch-rs.git", rev = "dda507e05a776547a112b6854d1e611684f8c729" } +tch = { git = "https://github.com/NousResearch/tch-rs.git", rev = "dda507e05a776547a112b6854d1e611684f8c729" } +torch-sys = { git = "https://github.com/NousResearch/tch-rs.git", rev = "dda507e05a776547a112b6854d1e611684f8c729" } +pyo3-tch = { git = "https://github.com/NousResearch/tch-rs.git", rev = "dda507e05a776547a112b6854d1e611684f8c729" } #tch = { path = "../tch-rs" } #torch-sys = { path = "../tch-rs/torch-sys" } #pyo3-tch = { path = "../tch-rs/pyo3-tch" } diff --git a/architectures/decentralized/solana-client/src/app.rs b/architectures/decentralized/solana-client/src/app.rs index fbc4b1f91..c541d5e74 100644 --- a/architectures/decentralized/solana-client/src/app.rs +++ b/architectures/decentralized/solana-client/src/app.rs @@ -47,7 +47,6 @@ pub struct App { update_tui_interval: Interval, tx_tui_state: Option>, authorizer: Option, - claimer: Option, metrics: Arc, allowlist: allowlist::AllowDynamic, p2p: NC, @@ -62,7 +61,6 @@ pub struct AppParams { pub backup_clusters: Vec, pub tx_tui_state: Option>, pub authorizer: Option, - pub claimer: Option, pub train_args: TrainArgs, } @@ -74,7 +72,6 @@ pub async fn build_app( backup_clusters, tx_tui_state, authorizer, - claimer, train_args: p, }: AppParams, ) -> Result { @@ -154,7 +151,6 @@ pub async fn build_app( tx_tui_state, update_tui_interval: interval(Duration::from_millis(150)), authorizer, - claimer, allowlist, metrics, p2p, @@ -242,9 +238,8 @@ impl App { .join_run( coordinator_instance_pubkey, coordinator_account, - self.authorizer, psyche_core::NodeIdentity::new(signer.to_bytes(), *p2p_identity.as_bytes()), - self.claimer, + self.authorizer, ) .await?; info!( @@ -360,9 +355,8 @@ impl App { .join_run( coordinator_instance_pubkey, coordinator_account, - self.authorizer, id, - self.claimer, + self.authorizer, ) .await?; info!( diff --git a/architectures/decentralized/solana-client/src/main.rs b/architectures/decentralized/solana-client/src/main.rs index 7bbf07c28..73eb7f019 100644 --- a/architectures/decentralized/solana-client/src/main.rs +++ b/architectures/decentralized/solana-client/src/main.rs @@ -79,11 +79,8 @@ enum Commands { rpc_3: String, #[clap(long, env, default_value_t = String::from(""))] ws_rpc_3: String, - #[clap(long, env)] authorizer: Option, - #[clap(long, env)] - claimer: Option, }, Predownload { #[clap(flatten)] @@ -173,7 +170,6 @@ async fn async_main() -> Result<()> { rpc_3, ws_rpc_3, authorizer, - claimer, } => { psyche_client::prepare_environment(); info!( @@ -240,7 +236,6 @@ async fn async_main() -> Result<()> { cluster: cluster.into(), backup_clusters, authorizer, - claimer, train_args: args, }) .await?; diff --git a/architectures/decentralized/solana-common/src/backend.rs b/architectures/decentralized/solana-common/src/backend.rs index 6794113a5..bd66d06b7 100644 --- a/architectures/decentralized/solana-common/src/backend.rs +++ b/architectures/decentralized/solana-common/src/backend.rs @@ -260,9 +260,8 @@ impl SolanaBackend { &self, coordinator_instance: Pubkey, coordinator_account: Pubkey, - authorizer: Option, id: psyche_core::NodeIdentity, - claimer: Option, + authorizer: Option, ) -> Result { let coordinator_instance_state = self.get_coordinator_instance(&coordinator_instance).await?; @@ -273,7 +272,6 @@ impl SolanaBackend { &coordinator_account, &authorization, id, - &claimer.unwrap_or(self.get_payer()), ); // TODO (vbrunet) - what was the point of doing specifically a timeout here but not the other TXs ? // We timeout the transaction at 5s max, since internally send() polls Solana until the diff --git a/architectures/decentralized/solana-common/src/instructions.rs b/architectures/decentralized/solana-common/src/instructions.rs index 45ee12fd8..c6c0e28cc 100644 --- a/architectures/decentralized/solana-common/src/instructions.rs +++ b/architectures/decentralized/solana-common/src/instructions.rs @@ -102,7 +102,6 @@ pub fn coordinator_join_run( coordinator_account: &Pubkey, authorization: &Pubkey, client_id: psyche_core::NodeIdentity, - claimer: &Pubkey, ) -> Instruction { anchor_instruction( psyche_solana_coordinator::ID, @@ -113,10 +112,7 @@ pub fn coordinator_join_run( coordinator_account: *coordinator_account, }, psyche_solana_coordinator::instruction::JoinRun { - params: psyche_solana_coordinator::logic::JoinRunParams { - client_id, - claimer: *claimer, - }, + params: psyche_solana_coordinator::logic::JoinRunParams { client_id }, }, ) } @@ -318,31 +314,31 @@ pub fn treasurer_participant_create( payer: *payer, run, participant, + user: *user, system_program: system_program::ID, }, psyche_solana_treasurer::instruction::ParticipantCreate { - params: psyche_solana_treasurer::logic::ParticipantCreateParams { user: *user }, + params: psyche_solana_treasurer::logic::ParticipantCreateParams {}, }, ) } pub fn treasurer_participant_claim( treasurer_index: u64, - claimer: &Pubkey, - claimer_collateral: &Pubkey, collateral_mint: &Pubkey, coordinator_account: &Pubkey, user: &Pubkey, claim_earned_points: u64, ) -> Instruction { + let user_collateral = associated_token::get_associated_token_address(user, collateral_mint); let run = psyche_solana_treasurer::find_run(treasurer_index); let run_collateral = associated_token::get_associated_token_address(&run, collateral_mint); let participant = psyche_solana_treasurer::find_participant(&run, user); anchor_instruction( psyche_solana_treasurer::ID, psyche_solana_treasurer::accounts::ParticipantClaimAccounts { - claimer: *claimer, - claimer_collateral: *claimer_collateral, + user: *user, + user_collateral, run, run_collateral, participant, @@ -351,7 +347,6 @@ pub fn treasurer_participant_claim( }, psyche_solana_treasurer::instruction::ParticipantClaim { params: psyche_solana_treasurer::logic::ParticipantClaimParams { - user: *user, claim_earned_points, }, }, diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/client.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/client.rs index 49fbe7022..e47ebffb4 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/client.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/client.rs @@ -19,15 +19,13 @@ use ts_rs::TS; AnchorDeserialize, Serialize, Deserialize, - PartialEq, TS, )] #[repr(C)] #[ts(rename = "SolanaClient")] pub struct Client { pub id: NodeIdentity, - #[ts(type = "number[]")] - pub claimer: Pubkey, + pub _unused: [u8; 8], pub earned: u64, pub slashed: u64, pub active: u64, @@ -37,7 +35,6 @@ impl Debug for Client { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Client") .field("id", &self.id) - .field("claimer", &self.claimer) .field("earned", &self.earned) .field("slashed", &self.slashed) .field("active", &self.active) diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/clients_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/clients_state.rs index b053373d6..b4b1da544 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/clients_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/clients_state.rs @@ -21,7 +21,6 @@ use crate::program_error::ProgramError; AnchorDeserialize, Serialize, Deserialize, - PartialEq, TS, )] #[repr(C)] @@ -41,7 +40,6 @@ pub struct ClientsState { AnchorDeserialize, Serialize, Deserialize, - PartialEq, TS, )] #[repr(C)] diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs index a438bf55d..c4010acb7 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/instance_state.rs @@ -56,7 +56,6 @@ impl RunMetadata {} Zeroable, AnchorSerialize, AnchorDeserialize, - PartialEq, Serialize, Deserialize, TS, @@ -333,11 +332,7 @@ impl CoordinatorInstanceState { Ok(()) } - pub fn join_run( - &mut self, - id: NodeIdentity, - claimer: Pubkey, - ) -> Result<()> { + pub fn join_run(&mut self, id: NodeIdentity) -> Result<()> { let existing = match self.clients_state.clients.iter_mut().find(|x| x.id == id) { Some(client) => { @@ -362,10 +357,10 @@ impl CoordinatorInstanceState { let new_client = Client { id, - claimer, earned: 0, slashed: 0, active: self.clients_state.next_active, + _unused: Default::default(), }; if self.clients_state.clients.push(new_client).is_err() { diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs index 37f44c49c..9cb58b948 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/lib.rs @@ -7,10 +7,7 @@ mod program_error; use anchor_lang::prelude::*; pub use client::Client; -pub use clients_state::ClientsEpochRates; -pub use clients_state::ClientsState; pub use instance_state::CoordinatorInstanceState; -pub use instance_state::RunMetadata; use logic::*; pub use program_error::ProgramError; use psyche_coordinator::Committee; @@ -30,6 +27,8 @@ use serde::Deserialize; use serde::Serialize; use ts_rs::TS; +pub use crate::instance_state::RunMetadata; + declare_id!("4SHugWqSXwKE5fqDchkJcPEqnoZE22VYKtSTVm7axbT7"); pub const SOLANA_MAX_NUM_PENDING_CLIENTS: usize = SOLANA_MAX_NUM_CLIENTS; @@ -126,7 +125,7 @@ pub fn coordinator_account_from_bytes_mut( #[account(zero_copy)] #[repr(C)] -#[derive(Serialize, Deserialize, PartialEq, TS)] +#[derive(Serialize, Deserialize, TS)] pub struct CoordinatorAccount { pub version: u64, pub state: CoordinatorInstanceState, @@ -134,7 +133,7 @@ pub struct CoordinatorAccount { } impl CoordinatorAccount { - pub const VERSION: u64 = 2; + pub const VERSION: u64 = 1; pub fn space_with_discriminator() -> usize { CoordinatorAccount::DISCRIMINATOR.len() diff --git a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/join_run.rs b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/join_run.rs index a973eb5e2..201a5f6e4 100644 --- a/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/join_run.rs +++ b/architectures/decentralized/solana-coordinator/programs/solana-coordinator/src/logic/join_run.rs @@ -44,7 +44,6 @@ pub struct JoinRunAccounts<'info> { #[derive(AnchorSerialize, AnchorDeserialize, Clone)] pub struct JoinRunParams { pub client_id: NodeIdentity, - pub claimer: Pubkey, } pub fn join_run_processor( @@ -56,5 +55,5 @@ pub fn join_run_processor( } let mut account = context.accounts.coordinator_account.load_mut()?; account.increment_nonce(); - account.state.join_run(params.client_id, params.claimer) + account.state.join_run(params.client_id) } diff --git a/architectures/decentralized/solana-tooling/src/process_coordinator_instructions.rs b/architectures/decentralized/solana-tooling/src/process_coordinator_instructions.rs index cb925ebaf..765ad81a8 100644 --- a/architectures/decentralized/solana-tooling/src/process_coordinator_instructions.rs +++ b/architectures/decentralized/solana-tooling/src/process_coordinator_instructions.rs @@ -114,7 +114,6 @@ pub async fn process_update( Ok(()) } -#[allow(clippy::too_many_arguments)] pub async fn process_coordinator_join_run( endpoint: &mut ToolboxEndpoint, payer: &Keypair, @@ -123,7 +122,6 @@ pub async fn process_coordinator_join_run( coordinator_instance: &Pubkey, coordinator_account: &Pubkey, client_id: NodeIdentity, - claimer: &Pubkey, ) -> Result<()> { let accounts = JoinRunAccounts { user: user.pubkey(), @@ -134,10 +132,7 @@ pub async fn process_coordinator_join_run( let instruction = Instruction { accounts: accounts.to_account_metas(None), data: JoinRun { - params: JoinRunParams { - client_id, - claimer: *claimer, - }, + params: JoinRunParams { client_id }, } .data(), program_id: psyche_solana_coordinator::ID, @@ -172,7 +167,7 @@ pub async fn process_coordinator_set_paused( Ok(()) } -pub async fn process_coordinator_set_future_epoch_rates( +pub async fn process_coordiantor_set_future_epoch_rates( endpoint: &mut ToolboxEndpoint, payer: &Keypair, authority: &Keypair, diff --git a/architectures/decentralized/solana-tooling/src/process_treasurer_instructions.rs b/architectures/decentralized/solana-tooling/src/process_treasurer_instructions.rs index 61f782e13..ed8da52fe 100644 --- a/architectures/decentralized/solana-tooling/src/process_treasurer_instructions.rs +++ b/architectures/decentralized/solana-tooling/src/process_treasurer_instructions.rs @@ -89,12 +89,13 @@ pub async fn process_treasurer_run_update( pub async fn process_treasurer_participant_create( endpoint: &mut ToolboxEndpoint, payer: &Keypair, + user: &Keypair, run: &Pubkey, - user: &Pubkey, ) -> Result<()> { - let participant = find_participant(run, user); + let participant = find_participant(run, &user.pubkey()); let accounts = ParticipantCreateAccounts { payer: payer.pubkey(), + user: user.pubkey(), run: *run, participant, system_program: system_program::ID, @@ -102,13 +103,13 @@ pub async fn process_treasurer_participant_create( let instruction = Instruction { accounts: accounts.to_account_metas(None), data: ParticipantCreate { - params: ParticipantCreateParams { user: *user }, + params: ParticipantCreateParams {}, } .data(), program_id: psyche_solana_treasurer::ID, }; endpoint - .process_instruction_with_signers(payer, instruction, &[]) + .process_instruction_with_signers(payer, instruction, &[user]) .await?; Ok(()) } @@ -117,11 +118,10 @@ pub async fn process_treasurer_participant_create( pub async fn process_treasurer_participant_claim( endpoint: &mut ToolboxEndpoint, payer: &Keypair, - claimer: &Keypair, - claimer_collateral: &Pubkey, + user: &Keypair, + user_collateral: &Pubkey, collateral_mint: &Pubkey, run: &Pubkey, - user: &Pubkey, coordinator_account: &Pubkey, claim_earned_points: u64, ) -> Result<()> { @@ -129,10 +129,10 @@ pub async fn process_treasurer_participant_claim( run, collateral_mint, ); - let participant = find_participant(run, user); + let participant = find_participant(run, &user.pubkey()); let accounts = ParticipantClaimAccounts { - claimer: claimer.pubkey(), - claimer_collateral: *claimer_collateral, + user: user.pubkey(), + user_collateral: *user_collateral, run: *run, run_collateral, coordinator_account: *coordinator_account, @@ -143,7 +143,6 @@ pub async fn process_treasurer_participant_claim( accounts: accounts.to_account_metas(None), data: ParticipantClaim { params: ParticipantClaimParams { - user: *user, claim_earned_points, }, } @@ -151,7 +150,7 @@ pub async fn process_treasurer_participant_claim( program_id: psyche_solana_treasurer::ID, }; endpoint - .process_instruction_with_signers(payer, instruction, &[claimer]) + .process_instruction_with_signers(payer, instruction, &[user]) .await?; Ok(()) } diff --git a/architectures/decentralized/solana-tooling/tests/fixtures/coordinator-account-v0.so b/architectures/decentralized/solana-tooling/tests/fixtures/coordinator-account-v0.so new file mode 100644 index 000000000..665fd7ef2 Binary files /dev/null and b/architectures/decentralized/solana-tooling/tests/fixtures/coordinator-account-v0.so differ diff --git a/architectures/decentralized/solana-tooling/tests/fixtures/coordinator-account-v1.so b/architectures/decentralized/solana-tooling/tests/fixtures/coordinator-account-v1.so new file mode 100644 index 000000000..89ace7a38 Binary files /dev/null and b/architectures/decentralized/solana-tooling/tests/fixtures/coordinator-account-v1.so differ diff --git a/architectures/decentralized/solana-tooling/tests/fixtures/coordinator-account.so b/architectures/decentralized/solana-tooling/tests/fixtures/coordinator-account.so deleted file mode 100644 index b576d4f55..000000000 Binary files a/architectures/decentralized/solana-tooling/tests/fixtures/coordinator-account.so and /dev/null differ diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_data_layout.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_data_layout.rs index 70f1ca6c5..8257cbbd1 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_data_layout.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_data_layout.rs @@ -1,204 +1,159 @@ -use anchor_lang::Discriminator; -use psyche_coordinator::ClientState; -use psyche_coordinator::Coordinator; -use psyche_coordinator::CoordinatorConfig; -use psyche_coordinator::CoordinatorEpochState; -use psyche_coordinator::CoordinatorProgress; use psyche_coordinator::Round; use psyche_coordinator::RunState; -use psyche_coordinator::Witness; -use psyche_coordinator::WitnessProof; use psyche_coordinator::model::Checkpoint; -use psyche_coordinator::model::HttpLLMTrainingDataLocation; use psyche_coordinator::model::HttpTrainingDataLocation; -use psyche_coordinator::model::HubRepo; -use psyche_coordinator::model::LLM; use psyche_coordinator::model::LLMArchitecture; use psyche_coordinator::model::LLMTrainingDataLocation; use psyche_coordinator::model::LLMTrainingDataType; use psyche_coordinator::model::Model; -use psyche_core::Bloom; use psyche_core::CosineLR; use psyche_core::FixedString; use psyche_core::FixedVec; use psyche_core::LearningRateSchedule; -use psyche_core::MerkleRoot; -use psyche_core::NodeIdentity; use psyche_core::OptimizerDefinition; use psyche_core::Shuffle; use psyche_core::SmallBoolean; use psyche_core::TokenSize; -use psyche_solana_coordinator::ClientsEpochRates; -use psyche_solana_coordinator::ClientsState; use psyche_solana_coordinator::CoordinatorAccount; -use psyche_solana_coordinator::CoordinatorInstanceState; -use psyche_solana_coordinator::RunMetadata; use psyche_solana_coordinator::coordinator_account_from_bytes; -use solana_sdk::pubkey::Pubkey; #[tokio::test] pub async fn run() { - let coordinator_account_from_reference = CoordinatorAccount { - version: CoordinatorAccount::VERSION, - state: CoordinatorInstanceState { - metadata: RunMetadata { - name: fixed_str("my-name"), - description: fixed_str("my-description"), - num_parameters: 1100000000, - vocab_size: 4242_32768, - }, - coordinator: Coordinator { - run_id: fixed_str("my-run-id"), - run_state: RunState::RoundTrain, - model: Model::LLM(LLM { - max_seq_len: 2048, - cold_start_warmup_steps: 999, - architecture: LLMArchitecture::HfAuto, - checkpoint: Checkpoint::Hub(HubRepo { - repo_id: fixed_str("my-repo-id"), - revision: Some(fixed_str("my-revision")), - }), - data_type: LLMTrainingDataType::Finetuning, - data_location: LLMTrainingDataLocation::Http( - HttpLLMTrainingDataLocation { - location: HttpTrainingDataLocation::Gcp { - bucket_name: fixed_str("my-bucket-name"), - filter_directory: fixed_str( - "my-filter-directory", - ), - }, - token_size_in_bytes: TokenSize::FourBytes, - shuffle: Shuffle::Seeded([55; 32]), - }, - ), - lr_schedule: LearningRateSchedule::Cosine(CosineLR::new( - 0.0004, 250, 0.666, 25000, 0.00004, - )), - optimizer: OptimizerDefinition::Distro { - clip_grad_norm: Some(1.0), - weight_decay: Some(42.42), - compression_decay: 0.999, - compression_topk: 2, - compression_chunk: 64, - quantize_1bit: true, - }, - }), - config: CoordinatorConfig { - warmup_time: 15, - cooldown_time: 30, - max_round_train_time: 15, - round_witness_time: 1, - global_batch_size_warmup_tokens: 34, - epoch_time: 60, - total_steps: 25000, - init_min_clients: 1, - min_clients: 1, - witness_nodes: 88, - global_batch_size_start: 2048, - global_batch_size_end: 2048, - verification_percent: 42, - waiting_for_members_extra_time: 3, - }, - progress: CoordinatorProgress { - epoch: 8989, - step: 777, - epoch_start_data_index: 574842891, + let coordinator_bytes = + include_bytes!("../fixtures/coordinator-account-v1.so").to_vec(); + let coordinator_account = + coordinator_account_from_bytes(&coordinator_bytes).unwrap(); + eprintln!("coordinator_account.state:{:#?}", coordinator_account.state); + // Check the general layout for corruption + assert_eq!(coordinator_account.version, CoordinatorAccount::VERSION); + assert_eq!(coordinator_account.nonce, 2); + let state = coordinator_account.state; + assert_eq!(state.is_warmup_first_tick, SmallBoolean::FALSE); + assert_eq!(state.is_training_first_tick, SmallBoolean::FALSE); + assert_eq!(state.client_version, fixed_str("test")); + // Check infos on the coordinator run metadata + let metadata = state.metadata; + assert_eq!(metadata.name, fixed_str("")); + assert_eq!(metadata.description, fixed_str("")); + assert_eq!(metadata.num_parameters, 1100000000); + assert_eq!(metadata.vocab_size, 32768); + // Check on the on the coordinator datastructure + let coordinator = state.coordinator; + assert_eq!(coordinator.run_id, fixed_str("test")); + assert_eq!(coordinator.run_state, RunState::Uninitialized); + assert_eq!(coordinator.run_state_start_unix_timestamp, 0); + assert_eq!(coordinator.pending_pause, SmallBoolean::FALSE); + // Coordinator model + match coordinator.model { + Model::LLM(llm) => { + assert_eq!(llm.max_seq_len, 2048); + assert_eq!(llm.cold_start_warmup_steps, 0); + assert_eq!(llm.architecture, LLMArchitecture::HfLlama); + match llm.checkpoint { + Checkpoint::Hub(hub) => { + assert_eq!( + hub.repo_id, + fixed_str("emozilla/llama2-1.1b-gqa-init") + ); + assert_eq!(hub.revision, None); }, - epoch_state: CoordinatorEpochState { - rounds: [Round { - witnesses: fixed_vec_repeat(Witness { - proof: WitnessProof { - position: 42, - index: 32, - witness: SmallBoolean::TRUE, - }, - participant_bloom: Bloom::new(4, &[7; 8]), - broadcast_bloom: Bloom::new(4, &[6; 8]), - broadcast_merkle: MerkleRoot { inner: [77; 32] }, - }), - data_index: 893322, - random_seed: 871, - height: 1002, - clients_len: 21, - tie_breaker_tasks: 34, - }; 4], - clients: fixed_vec_repeat(psyche_coordinator::Client { - id: NodeIdentity::from_single_key([77; 32]), - state: ClientState::Dropped, - exited_height: 42, - }), - exited_clients: fixed_vec_repeat( - psyche_coordinator::Client { - id: NodeIdentity::from_single_key([99; 32]), - state: ClientState::Dropped, - exited_height: 48, + _ => panic!("Expected Hub checkpoint"), + }; + assert_eq!(llm.data_type, LLMTrainingDataType::Pretraining); + match llm.data_location { + LLMTrainingDataLocation::Http(http) => { + match http.location { + HttpTrainingDataLocation::Gcp { + bucket_name, + filter_directory, + } => { + assert_eq!( + bucket_name, + fixed_str("nous-pretraining-public-us") + ); + assert_eq!( + filter_directory, + fixed_str("fineweb-edu-tokenized-llama2") + ); }, - ), - rounds_head: 77, - start_step: 88, - last_step: 99, - start_timestamp: 33, - first_round: SmallBoolean::TRUE, - cold_start_epoch: SmallBoolean::TRUE, + _ => panic!("Expected Gcp data location"), + }; + assert_eq!(http.token_size_in_bytes, TokenSize::TwoBytes); + assert_eq!(http.shuffle, Shuffle::DontShuffle); }, - pending_pause: SmallBoolean::TRUE, - run_state_start_unix_timestamp: 55_55_555_555, - }, - clients_state: ClientsState { - clients: fixed_vec_repeat(psyche_solana_coordinator::Client { - id: NodeIdentity::from_single_key([33; 32]), - active: 63473857845, - earned: 424242, - slashed: 7878, - claimer: Pubkey::from([88; 32]), - }), - next_active: 63473857845, - current_epoch_rates: ClientsEpochRates { - earning_rate_total_shared: 727272, - slashing_rate_per_client: 7272, + _ => panic!("Expected Http data location"), + }; + match llm.lr_schedule { + LearningRateSchedule::Cosine(learning_rate) => { + assert_eq!( + learning_rate, + CosineLR::new(0.0004, 250, 0.0, 25000, 0.00004) + ); }, - future_epoch_rates: ClientsEpochRates { - earning_rate_total_shared: 424242, - slashing_rate_per_client: 4242, + _ => panic!("Expected Constant LR schedule"), + }; + match llm.optimizer { + OptimizerDefinition::Distro { + clip_grad_norm, + weight_decay, + compression_decay, + compression_topk, + compression_chunk, + quantize_1bit, + } => { + assert_eq!(clip_grad_norm, Some(1.0)); + assert_eq!(weight_decay, None); + assert_eq!(compression_decay, 0.999); + assert_eq!(compression_topk, 2); + assert_eq!(compression_chunk, 64); + assert_eq!(quantize_1bit, false); }, - }, - is_warmup_first_tick: SmallBoolean::TRUE, - is_training_first_tick: SmallBoolean::TRUE, - client_version: fixed_str("my-client-version"), + _ => panic!("Expected Distro optimizer"), + } }, - nonce: 78787878, }; - /* - std::fs::write( - "./tests/fixtures/coordinator-account.so", - bytemuck::bytes_of(&coordinator_account_from_reference), - ) - .unwrap(); - */ - let coordinator_account_snapshot_bytes = &[ - CoordinatorAccount::DISCRIMINATOR, - include_bytes!("../fixtures/coordinator-account.so"), - ] - .concat(); - let coordinator_account_from_snapshot = - coordinator_account_from_bytes(coordinator_account_snapshot_bytes) - .unwrap(); - assert!( - &coordinator_account_from_reference - == coordinator_account_from_snapshot - ); + // Coordinator config + assert_eq!(coordinator.config.warmup_time, 15); + assert_eq!(coordinator.config.cooldown_time, 30); + assert_eq!(coordinator.config.max_round_train_time, 15); + assert_eq!(coordinator.config.round_witness_time, 1); + assert_eq!(coordinator.config.global_batch_size_warmup_tokens, 0); + assert_eq!(coordinator.config.epoch_time, 60); + assert_eq!(coordinator.config.total_steps, 25000); + assert_eq!(coordinator.config.init_min_clients, 1); + assert_eq!(coordinator.config.min_clients, 1); + assert_eq!(coordinator.config.witness_nodes, 0); + assert_eq!(coordinator.config.global_batch_size_start, 2048); + assert_eq!(coordinator.config.global_batch_size_end, 2048); + assert_eq!(coordinator.config.verification_percent, 0); + assert_eq!(coordinator.config.waiting_for_members_extra_time, 3); + // Coordinator progress + assert_eq!(coordinator.progress.epoch, 0); + assert_eq!(coordinator.progress.step, 0); + assert_eq!(coordinator.progress.epoch_start_data_index, 0); + // Coordinator epoch state + let epoch_state = coordinator.epoch_state; + assert_eq!(epoch_state.rounds, [Round::default(); 4]); + assert_eq!(epoch_state.clients, FixedVec::default()); + assert_eq!(epoch_state.exited_clients, FixedVec::default()); + assert_eq!(epoch_state.rounds_head, 0); + assert_eq!(epoch_state.start_step, 0); + assert_eq!(epoch_state.last_step, 0); + assert_eq!(epoch_state.start_timestamp, 0); + assert_eq!(epoch_state.first_round, SmallBoolean::FALSE); + assert_eq!(epoch_state.cold_start_epoch, SmallBoolean::FALSE); + // Coordinator clients state + let clients_state = state.clients_state; + assert_eq!(clients_state.clients.len(), 0); + assert_eq!(clients_state.next_active, 0); + let current_epoch_rates = clients_state.current_epoch_rates; + assert_eq!(current_epoch_rates.earning_rate_total_shared, 0); + assert_eq!(current_epoch_rates.slashing_rate_per_client, 0); + let future_epoch_rates = clients_state.future_epoch_rates; + assert_eq!(future_epoch_rates.earning_rate_total_shared, 1000000); + assert_eq!(future_epoch_rates.slashing_rate_per_client, 0); } fn fixed_str(value: &str) -> FixedString { FixedString::from_str_truncated(value) } - -fn fixed_vec_repeat( - value: T, -) -> FixedVec { - let mut vec = FixedVec::new(); - for _ in 0..N { - vec.push(value).unwrap(); - } - vec -} diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs index 0c03fdf4d..c078e084c 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_full_round.rs @@ -28,7 +28,6 @@ use psyche_solana_tooling::process_coordinator_instructions::process_coordinator use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_tick; use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_witness; use psyche_solana_tooling::process_coordinator_instructions::process_update; -use solana_sdk::pubkey::Pubkey; use solana_sdk::signature::Keypair; use solana_sdk::signer::Signer; @@ -46,7 +45,6 @@ pub async fn run() { // Run constants let main_authority = Keypair::new(); let join_authority = Keypair::new(); - let claimer = Pubkey::new_unique(); let client = Keypair::new(); let ticker = Keypair::new(); let warmup_time = 10; @@ -159,7 +157,8 @@ pub async fn run() { ); // Generate the client key - let client_id = NodeIdentity::from_single_key(client.pubkey().to_bytes()); + let client_id = + NodeIdentity::new(client.pubkey().to_bytes(), Default::default()); // Add client to whitelist let authorization = process_authorizer_authorization_create( @@ -190,7 +189,6 @@ pub async fn run() { &coordinator_instance, &coordinator_account, client_id, - &claimer, ) .await .unwrap_err(); @@ -204,7 +202,6 @@ pub async fn run() { &coordinator_instance, &coordinator_account, client_id, - &claimer, ) .await .unwrap(); @@ -253,7 +250,6 @@ pub async fn run() { &coordinator_instance, &coordinator_account, client_id, - &claimer, ) .await .unwrap(); diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs index f254ba8cc..db33b9304 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_coordinator_rewards.rs @@ -22,14 +22,13 @@ use psyche_solana_tooling::create_memnet_endpoint::create_memnet_endpoint; use psyche_solana_tooling::get_accounts::get_coordinator_account_state; use psyche_solana_tooling::process_authorizer_instructions::process_authorizer_authorization_create; use psyche_solana_tooling::process_authorizer_instructions::process_authorizer_authorization_grantor_update; +use psyche_solana_tooling::process_coordinator_instructions::process_coordiantor_set_future_epoch_rates; use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_init; use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_join_run; -use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_set_future_epoch_rates; use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_set_paused; use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_tick; use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_witness; use psyche_solana_tooling::process_coordinator_instructions::process_update; -use solana_sdk::pubkey::Pubkey; use solana_sdk::signature::Keypair; use solana_sdk::signer::Signer; @@ -47,7 +46,6 @@ pub async fn run() { // Run constants let main_authority = Keypair::new(); let join_authority = Keypair::new(); - let claimer = Pubkey::new_unique(); let mut clients = vec![]; for _ in 0..240 { clients.push(Keypair::new()); @@ -132,7 +130,7 @@ pub async fn run() { .unwrap(); // Set the reward rate for the epoch - process_coordinator_set_future_epoch_rates( + process_coordiantor_set_future_epoch_rates( &mut endpoint, &payer, &main_authority, @@ -185,8 +183,7 @@ pub async fn run() { &authorization, &coordinator_instance, &coordinator_account, - NodeIdentity::from_single_key(client.pubkey().to_bytes()), - &claimer, + NodeIdentity::new(client.pubkey().to_bytes(), Default::default()), ) .await .unwrap(); diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_claim.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_claim.rs index 92607e1d2..ebcf43156 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_claim.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_create_claim.rs @@ -1,18 +1,9 @@ -use psyche_coordinator::CoordinatorConfig; -use psyche_core::NodeIdentity; -use psyche_solana_authorizer::logic::AuthorizationGrantorUpdateParams; use psyche_solana_coordinator::CoordinatorAccount; -use psyche_solana_coordinator::logic::JOIN_RUN_AUTHORIZATION_SCOPE; use psyche_solana_tooling::create_memnet_endpoint::create_memnet_endpoint; -use psyche_solana_tooling::process_authorizer_instructions::process_authorizer_authorization_create; -use psyche_solana_tooling::process_authorizer_instructions::process_authorizer_authorization_grantor_update; -use psyche_solana_tooling::process_coordinator_instructions::process_coordinator_join_run; use psyche_solana_tooling::process_treasurer_instructions::process_treasurer_participant_claim; use psyche_solana_tooling::process_treasurer_instructions::process_treasurer_participant_create; use psyche_solana_tooling::process_treasurer_instructions::process_treasurer_run_create; -use psyche_solana_tooling::process_treasurer_instructions::process_treasurer_run_update; use psyche_solana_treasurer::logic::RunCreateParams; -use psyche_solana_treasurer::logic::RunUpdateParams; use solana_sdk::pubkey::Pubkey; use solana_sdk::signature::Keypair; use solana_sdk::signer::Signer; @@ -35,8 +26,6 @@ pub async fn run() { let join_authority = Keypair::new(); let client1 = Keypair::new(); let client2 = Keypair::new(); - let claimer1 = Keypair::new(); - let claimer2 = Keypair::new(); // Prepare the collateral mints let collateral1_mint = endpoint @@ -67,7 +56,7 @@ pub async fn run() { .unwrap(); // Create the runs (it should init the underlying coordinators) - let (run1, coordinator1_instance) = process_treasurer_run_create( + let (run1, _) = process_treasurer_run_create( &mut endpoint, &payer, &collateral1_mint, @@ -82,7 +71,7 @@ pub async fn run() { ) .await .unwrap(); - let (run2, coordinator2_instance) = process_treasurer_run_create( + let (run2, _) = process_treasurer_run_create( &mut endpoint, &payer, &collateral2_mint, @@ -98,64 +87,6 @@ pub async fn run() { .await .unwrap(); - // Update the runs' coordinator configs - let dummy_config = CoordinatorConfig { - warmup_time: 10, - cooldown_time: 20, - max_round_train_time: 888, - round_witness_time: 42, - min_clients: 1, - init_min_clients: 1, - global_batch_size_start: 1, - global_batch_size_end: 42, - global_batch_size_warmup_tokens: 0, - verification_percent: 0, - witness_nodes: 0, - epoch_time: 999, - total_steps: 100, - waiting_for_members_extra_time: 3, - }; - process_treasurer_run_update( - &mut endpoint, - &payer, - &main_authority, - &run1, - &coordinator1_instance, - &coordinator1_account, - RunUpdateParams { - metadata: None, - config: Some(dummy_config), - model: None, - progress: None, - epoch_earning_rate_total_shared: None, - epoch_slashing_rate_per_client: None, - paused: None, - client_version: None, - }, - ) - .await - .unwrap(); - process_treasurer_run_update( - &mut endpoint, - &payer, - &main_authority, - &run2, - &coordinator2_instance, - &coordinator2_account, - RunUpdateParams { - metadata: None, - config: Some(dummy_config), - model: None, - progress: None, - epoch_earning_rate_total_shared: None, - epoch_slashing_rate_per_client: None, - paused: None, - client_version: None, - }, - ) - .await - .unwrap(); - // Get the run's collateral vaults let run1_collateral1 = endpoint .process_spl_associated_token_account_get_or_init( @@ -196,35 +127,35 @@ pub async fn run() { .await .unwrap(); - // Create the claimers ATA - let claimer1_collateral1 = endpoint + // Create the clients ATA + let client1_collateral1 = endpoint .process_spl_associated_token_account_get_or_init( &payer, - &claimer1.pubkey(), + &client1.pubkey(), &collateral1_mint, ) .await .unwrap(); - let claimer1_collateral2 = endpoint + let client1_collateral2 = endpoint .process_spl_associated_token_account_get_or_init( &payer, - &claimer1.pubkey(), + &client1.pubkey(), &collateral2_mint, ) .await .unwrap(); - let claimer2_collateral1 = endpoint + let client2_collateral1 = endpoint .process_spl_associated_token_account_get_or_init( &payer, - &claimer2.pubkey(), + &client2.pubkey(), &collateral1_mint, ) .await .unwrap(); - let claimer2_collateral2 = endpoint + let client2_collateral2 = endpoint .process_spl_associated_token_account_get_or_init( &payer, - &claimer2.pubkey(), + &client2.pubkey(), &collateral2_mint, ) .await @@ -234,117 +165,32 @@ pub async fn run() { process_treasurer_participant_create( &mut endpoint, &payer, + &client1, &run1, - &client1.pubkey(), ) .await .unwrap(); process_treasurer_participant_create( &mut endpoint, &payer, + &client1, &run2, - &client1.pubkey(), ) .await .unwrap(); process_treasurer_participant_create( &mut endpoint, &payer, + &client2, &run1, - &client2.pubkey(), ) .await .unwrap(); process_treasurer_participant_create( - &mut endpoint, - &payer, - &run2, - &client2.pubkey(), - ) - .await - .unwrap(); - - // Try claiming before joining, it should fail - process_treasurer_participant_claim( - &mut endpoint, - &payer, - &claimer1, - &claimer1_collateral1, - &collateral1_mint, - &run1, - &client1.pubkey(), - &coordinator1_account, - 0, - ) - .await - .unwrap_err(); - - // Create and activate the join authorization for everyone - let authorization = process_authorizer_authorization_create( - &mut endpoint, - &payer, - &join_authority, - &Pubkey::default(), - &JOIN_RUN_AUTHORIZATION_SCOPE, - ) - .await - .unwrap(); - process_authorizer_authorization_grantor_update( - &mut endpoint, - &payer, - &join_authority, - &authorization, - AuthorizationGrantorUpdateParams { active: true }, - ) - .await - .unwrap(); - - // Joining the runs - process_coordinator_join_run( - &mut endpoint, - &payer, - &client1, - &authorization, - &coordinator1_instance, - &coordinator1_account, - NodeIdentity::from_single_key(client1.pubkey().to_bytes()), - &claimer1.pubkey(), - ) - .await - .unwrap(); - process_coordinator_join_run( - &mut endpoint, - &payer, - &client2, - &authorization, - &coordinator1_instance, - &coordinator1_account, - NodeIdentity::from_single_key(client2.pubkey().to_bytes()), - &claimer2.pubkey(), - ) - .await - .unwrap(); - process_coordinator_join_run( - &mut endpoint, - &payer, - &client1, - &authorization, - &coordinator2_instance, - &coordinator2_account, - NodeIdentity::from_single_key(client1.pubkey().to_bytes()), - &claimer1.pubkey(), - ) - .await - .unwrap(); - process_coordinator_join_run( &mut endpoint, &payer, &client2, - &authorization, - &coordinator2_instance, - &coordinator2_account, - NodeIdentity::from_single_key(client2.pubkey().to_bytes()), - &claimer2.pubkey(), + &run2, ) .await .unwrap(); @@ -353,11 +199,10 @@ pub async fn run() { process_treasurer_participant_claim( &mut endpoint, &payer, - &claimer1, - &claimer1_collateral1, + &client1, + &client1_collateral1, &collateral1_mint, &run1, - &client1.pubkey(), &coordinator1_account, 0, ) @@ -366,11 +211,10 @@ pub async fn run() { process_treasurer_participant_claim( &mut endpoint, &payer, - &claimer2, - &claimer2_collateral1, + &client2, + &client2_collateral1, &collateral1_mint, &run1, - &client2.pubkey(), &coordinator1_account, 0, ) @@ -379,11 +223,10 @@ pub async fn run() { process_treasurer_participant_claim( &mut endpoint, &payer, - &claimer1, - &claimer1_collateral2, + &client1, + &client1_collateral2, &collateral2_mint, &run2, - &client1.pubkey(), &coordinator2_account, 0, ) @@ -392,11 +235,10 @@ pub async fn run() { process_treasurer_participant_claim( &mut endpoint, &payer, - &claimer2, - &claimer2_collateral2, + &client2, + &client2_collateral2, &collateral2_mint, &run2, - &client2.pubkey(), &coordinator2_account, 0, ) @@ -407,41 +249,24 @@ pub async fn run() { process_treasurer_participant_claim( &mut endpoint, &payer, - &claimer1, - &claimer1_collateral1, + &client1, + &client1_collateral1, &collateral1_mint, &run1, - &client1.pubkey(), &coordinator1_account, 1, ) .await .unwrap_err(); - // Try claiming using the wrong client, it should fail - process_treasurer_participant_claim( - &mut endpoint, - &payer, - &claimer1, - &claimer1_collateral1, - &collateral1_mint, - &run1, - &client2.pubkey(), // Wrong client - &coordinator1_account, - 0, - ) - .await - .unwrap_err(); - - // Try claiming using the wrong claimer, it should fail + // Try claiming using the wrong owner, it should fail process_treasurer_participant_claim( &mut endpoint, &payer, - &claimer2, // Wrong claimer - &claimer1_collateral1, + &client2, + &client1_collateral1, &collateral1_mint, &run1, - &client1.pubkey(), &coordinator1_account, 0, ) @@ -452,11 +277,10 @@ pub async fn run() { process_treasurer_participant_claim( &mut endpoint, &payer, - &claimer1, - &claimer2_collateral1, // Wrong ATA + &client1, + &client2_collateral1, &collateral1_mint, &run1, - &client1.pubkey(), &coordinator1_account, 0, ) @@ -465,11 +289,10 @@ pub async fn run() { process_treasurer_participant_claim( &mut endpoint, &payer, - &claimer1, - &claimer1_collateral2, // Wrong ATA + &client1, + &client1_collateral2, &collateral1_mint, &run1, - &client1.pubkey(), &coordinator1_account, 0, ) @@ -480,11 +303,10 @@ pub async fn run() { process_treasurer_participant_claim( &mut endpoint, &payer, - &claimer1, - &claimer1_collateral1, - &collateral2_mint, // Wrong mint + &client1, + &client1_collateral1, + &collateral2_mint, &run1, - &client1.pubkey(), &coordinator1_account, 0, ) @@ -495,11 +317,10 @@ pub async fn run() { process_treasurer_participant_claim( &mut endpoint, &payer, - &claimer1, - &claimer1_collateral1, + &client1, + &client1_collateral1, &collateral1_mint, - &run2, // Wrong run - &client1.pubkey(), + &run2, &coordinator1_account, 0, ) @@ -510,22 +331,21 @@ pub async fn run() { process_treasurer_participant_claim( &mut endpoint, &payer, - &claimer1, - &claimer1_collateral1, + &client1, + &client1_collateral1, &collateral1_mint, &run1, - &client1.pubkey(), - &coordinator2_account, // Wrong coordinator account + &coordinator2_account, 0, ) .await .unwrap_err(); // Noone should have been able to claim anything yet - assert_amount(&mut endpoint, &claimer1_collateral1, 0).await; - assert_amount(&mut endpoint, &claimer2_collateral2, 0).await; - assert_amount(&mut endpoint, &claimer1_collateral1, 0).await; - assert_amount(&mut endpoint, &claimer2_collateral2, 0).await; + assert_amount(&mut endpoint, &client1_collateral1, 0).await; + assert_amount(&mut endpoint, &client2_collateral2, 0).await; + assert_amount(&mut endpoint, &client1_collateral1, 0).await; + assert_amount(&mut endpoint, &client2_collateral2, 0).await; // All the runs collateral should still be intact assert_amount(&mut endpoint, &run1_collateral1, 1_000_000_000_000).await; diff --git a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs index 209b1315a..3db8e64ee 100644 --- a/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs +++ b/architectures/decentralized/solana-tooling/tests/suites/memnet_treasurer_full_epoch.rs @@ -148,18 +148,61 @@ pub async fn run() { .await .unwrap(); + // Create the clients ATAs + let mut clients_collateral = vec![]; + for client in &clients { + clients_collateral.push( + endpoint + .process_spl_associated_token_account_get_or_init( + &payer, + &client.pubkey(), + &collateral_mint, + ) + .await + .unwrap(), + ); + } + // Create the participations accounts for client in &clients { process_treasurer_participant_create( &mut endpoint, &payer, + client, &run, - &client.pubkey(), ) .await .unwrap(); } + // Try claiming nothing, it should work, but we earned nothing + process_treasurer_participant_claim( + &mut endpoint, + &payer, + &clients[0], + &clients_collateral[0], + &collateral_mint, + &run, + &coordinator_account, + 0, + ) + .await + .unwrap(); + + // Claiming with the wrong collateral should fail + process_treasurer_participant_claim( + &mut endpoint, + &payer, + &clients[0], + &clients_collateral[1], + &collateral_mint, + &run, + &coordinator_account, + 0, + ) + .await + .unwrap_err(); + // Prepare the coordinator's config process_treasurer_run_update( &mut endpoint, @@ -251,55 +294,16 @@ pub async fn run() { .await .unwrap(); - // Create the clients's claimers - let mut claimers = vec![]; - for _ in &clients { - claimers.push(Keypair::new()); - } - // The clients can now join the run - for i in 0..clients.len() { + for client in &clients { process_coordinator_join_run( &mut endpoint, &payer, - &clients[i], + client, &authorization, &coordinator_instance, &coordinator_account, - NodeIdentity::from_single_key(clients[i].pubkey().to_bytes()), - &claimers[i].pubkey(), - ) - .await - .unwrap(); - } - - // Create the clients's claimers's ATA - let mut claimers_collateral = vec![]; - for claimer in &claimers { - claimers_collateral.push( - endpoint - .process_spl_associated_token_account_get_or_init( - &payer, - &claimer.pubkey(), - &collateral_mint, - ) - .await - .unwrap(), - ); - } - - // Try claiming nothing, it should work, but we earned nothing yet - for i in 0..clients.len() { - process_treasurer_participant_claim( - &mut endpoint, - &payer, - &claimers[i], - &claimers_collateral[i], - &collateral_mint, - &run, - &clients[i].pubkey(), - &coordinator_account, - 0, + NodeIdentity::new(client.pubkey().to_bytes(), Default::default()), ) .await .unwrap(); @@ -396,21 +400,18 @@ pub async fn run() { } // Not yet earned the credit, claiming anything should fail - for i in 0..clients.len() { - process_treasurer_participant_claim( - &mut endpoint, - &payer, - &claimers[i], - &claimers_collateral[i], - &collateral_mint, - &coordinator_instance, - &clients[i].pubkey(), - &coordinator_account, - 1, - ) - .await - .unwrap_err(); - } + process_treasurer_participant_claim( + &mut endpoint, + &payer, + &clients[0], + &clients_collateral[0], + &collateral_mint, + &coordinator_instance, + &coordinator_account, + 1, + ) + .await + .unwrap_err(); // Tick from cooldown to new epoch (should increment the earned points) endpoint @@ -428,21 +429,18 @@ pub async fn run() { .unwrap(); // We can claim earned points now, but it should fail because run isnt funded - for i in 0..clients.len() { - process_treasurer_participant_claim( - &mut endpoint, - &payer, - &claimers[i], - &claimers_collateral[i], - &collateral_mint, - &run, - &clients[i].pubkey(), - &coordinator_account, - earned_point_per_epoch_per_client, - ) - .await - .unwrap_err(); - } + process_treasurer_participant_claim( + &mut endpoint, + &payer, + &clients[0], + &clients_collateral[0], + &collateral_mint, + &run, + &coordinator_account, + earned_point_per_epoch_per_client, + ) + .await + .unwrap_err(); // We should be able to top-up run treasury at any time endpoint @@ -458,14 +456,15 @@ pub async fn run() { // Now that a new epoch has started, we can claim our earned point for i in 0..clients.len() { + let client = &clients[i]; + let client_collateral = &clients_collateral[i]; process_treasurer_participant_claim( &mut endpoint, &payer, - &claimers[i], - &claimers_collateral[i], + client, + client_collateral, &collateral_mint, &run, - &clients[i].pubkey(), &coordinator_account, earned_point_per_epoch_per_client, ) @@ -474,27 +473,24 @@ pub async fn run() { } // Can't claim anything past the earned points - for i in 0..clients.len() { - process_treasurer_participant_claim( - &mut endpoint, - &payer, - &claimers[i], - &claimers_collateral[i], - &collateral_mint, - &run, - &clients[i].pubkey(), - &coordinator_account, - 1, - ) - .await - .unwrap_err(); - } + process_treasurer_participant_claim( + &mut endpoint, + &payer, + &clients[0], + &clients_collateral[0], + &collateral_mint, + &run, + &coordinator_account, + 1, + ) + .await + .unwrap_err(); // Check that we could claim only exactly the right amount - for claimer_collateral in &claimers_collateral { + for client_collateral in &clients_collateral { assert_eq!( endpoint - .get_spl_token_account(claimer_collateral) + .get_spl_token_account(client_collateral) .await .unwrap() .unwrap() diff --git a/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/lib.rs b/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/lib.rs index b6f204d32..2888e214c 100644 --- a/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/lib.rs +++ b/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/lib.rs @@ -61,15 +61,9 @@ pub mod psyche_solana_treasurer { #[error_code] pub enum ProgramError { + #[msg("Invalid parameter")] + InvalidParameter, + #[msg("run_id must be 32 bytes or less")] RunIdInvalidLength, - - #[msg("Participant's client not found")] - ParticipantClientNotFound, - - #[msg("Claimer signer does not match the expected signer")] - ClaimerSignerMismatch, - - #[msg("Claimed points exceed earned points")] - ClaimedPointsExceedEarnedPoints, } diff --git a/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/participant_claim.rs b/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/participant_claim.rs index d80d5108f..7afebccf3 100644 --- a/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/participant_claim.rs +++ b/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/participant_claim.rs @@ -13,15 +13,15 @@ use crate::state::Run; #[instruction(params: ParticipantClaimParams)] pub struct ParticipantClaimAccounts<'info> { #[account()] - pub claimer: Signer<'info>, + pub user: Signer<'info>, #[account( mut, - constraint = claimer_collateral.mint == run.collateral_mint, - constraint = claimer_collateral.owner == claimer.key(), - constraint = claimer_collateral.delegate == None.into(), + constraint = user_collateral.mint == run.collateral_mint, + constraint = user_collateral.owner == user.key(), + constraint = user_collateral.delegate == None.into(), )] - pub claimer_collateral: Box>, + pub user_collateral: Box>, #[account( mut, @@ -46,7 +46,7 @@ pub struct ParticipantClaimAccounts<'info> { seeds = [ Participant::SEEDS_PREFIX, run.key().as_ref(), - params.user.as_ref() + user.key().as_ref() ], bump = participant.bump )] @@ -58,7 +58,6 @@ pub struct ParticipantClaimAccounts<'info> { #[derive(AnchorSerialize, AnchorDeserialize, Clone)] pub struct ParticipantClaimParams { - pub user: Pubkey, pub claim_earned_points: u64, } @@ -66,23 +65,21 @@ pub fn participant_claim_processor( context: Context, params: ParticipantClaimParams, ) -> Result<()> { - let user_bytes = params.user.as_ref(); - let coordinator_account = context.accounts.coordinator_account.load()?; - let client_state = match coordinator_account + let mut participant_earned_points = 0; + for client in context + .accounts + .coordinator_account + .load()? .state .clients_state .clients .iter() - .find(|client| client.id.signer() == user_bytes) { - Some(info) => info, - None => return err!(ProgramError::ParticipantClientNotFound), - }; - - if context.accounts.claimer.key() != client_state.claimer { - return err!(ProgramError::ClaimerSignerMismatch); + if *client.id.signer() == context.accounts.user.key().to_bytes() { + participant_earned_points = client.earned; + break; + } } - let participant_earned_points = client_state.earned; let participant = &mut context.accounts.participant; let run = &mut context.accounts.run; @@ -90,7 +87,7 @@ pub fn participant_claim_processor( let participant_unclaimed_earned_points = participant_earned_points - participant.claimed_earned_points; if params.claim_earned_points > participant_unclaimed_earned_points { - return err!(ProgramError::ClaimedPointsExceedEarnedPoints); + return err!(ProgramError::InvalidParameter); } // We distribute 1 collateral per point and let the coordinator decide the point reward rate @@ -109,7 +106,7 @@ pub fn participant_claim_processor( context.accounts.token_program.to_account_info(), Transfer { from: context.accounts.run_collateral.to_account_info(), - to: context.accounts.claimer_collateral.to_account_info(), + to: context.accounts.user_collateral.to_account_info(), authority: context.accounts.run.to_account_info(), }, ) diff --git a/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/participant_create.rs b/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/participant_create.rs index b73f8fbae..2dff39d45 100644 --- a/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/participant_create.rs +++ b/architectures/decentralized/solana-treasurer/programs/solana-treasurer/src/logic/participant_create.rs @@ -9,6 +9,9 @@ pub struct ParticipantCreateAccounts<'info> { #[account(mut)] pub payer: Signer<'info>, + #[account()] + pub user: Signer<'info>, + #[account()] pub run: Box>, @@ -19,7 +22,7 @@ pub struct ParticipantCreateAccounts<'info> { seeds = [ Participant::SEEDS_PREFIX, run.key().as_ref(), - params.user.as_ref() + user.key().as_ref() ], bump )] @@ -30,9 +33,7 @@ pub struct ParticipantCreateAccounts<'info> { } #[derive(AnchorSerialize, AnchorDeserialize, Clone)] -pub struct ParticipantCreateParams { - pub user: Pubkey, -} +pub struct ParticipantCreateParams {} pub fn participant_create_processor( context: Context, diff --git a/architectures/decentralized/testing/src/docker_setup.rs b/architectures/decentralized/testing/src/docker_setup.rs index 4886eb6f1..ad1c46981 100644 --- a/architectures/decentralized/testing/src/docker_setup.rs +++ b/architectures/decentralized/testing/src/docker_setup.rs @@ -61,7 +61,14 @@ pub async fn e2e_testing_setup( docker_client: Arc, init_num_clients: usize, ) -> DockerTestCleanup { - e2e_testing_setup_with_min(docker_client, init_num_clients, init_num_clients, None).await + e2e_testing_setup_with_min( + docker_client, + init_num_clients, + init_num_clients, + None, + None, + ) + .await } /// Setup with explicit min_clients value and optional owner keypair path. @@ -72,10 +79,17 @@ pub async fn e2e_testing_setup_with_min( init_num_clients: usize, min_clients: usize, owner_keypair_path: Option<&Path>, + waiting_for_members_extra_time: Option, ) -> DockerTestCleanup { remove_old_client_containers(docker_client).await; - spawn_psyche_network_with_min(init_num_clients, min_clients, owner_keypair_path).unwrap(); + spawn_psyche_network_with_min( + init_num_clients, + min_clients, + owner_keypair_path, + waiting_for_members_extra_time, + ) + .unwrap(); spawn_ctrl_c_task(); @@ -270,7 +284,7 @@ pub async fn spawn_new_client_with_monitoring( // Updated spawn function pub fn spawn_psyche_network(init_num_clients: usize) -> Result<(), DockerWatcherError> { - spawn_psyche_network_with_min(init_num_clients, init_num_clients, None) + spawn_psyche_network_with_min(init_num_clients, init_num_clients, None, None) } /// Spawn the psyche network with explicit min_clients and optional owner keypair. @@ -278,19 +292,24 @@ pub fn spawn_psyche_network_with_min( init_num_clients: usize, min_clients: usize, owner_keypair_path: Option<&Path>, + waiting_for_members_extra_time: Option, ) -> Result<(), DockerWatcherError> { #[cfg(not(feature = "python"))] - let config_file_path = ConfigBuilder::new() + let mut builder = ConfigBuilder::new() .with_num_clients(init_num_clients) - .with_min_clients(min_clients) - .build(); + .with_min_clients(min_clients); #[cfg(feature = "python")] - let config_file_path = ConfigBuilder::new() + let mut builder = ConfigBuilder::new() .with_num_clients(init_num_clients) .with_min_clients(min_clients) .with_architecture("HfAuto") - .with_batch_size(8 * std::cmp::max(init_num_clients, 1) as u32) - .build(); + .with_batch_size(8 * std::cmp::max(init_num_clients, 1) as u32); + + if let Some(time) = waiting_for_members_extra_time { + builder = builder.with_waiting_for_members_extra_time(time); + } + + let config_file_path = builder.build(); println!("[+] Config file written to: {}", config_file_path.display()); diff --git a/architectures/decentralized/testing/src/utils.rs b/architectures/decentralized/testing/src/utils.rs index df2afdd82..96a180465 100644 --- a/architectures/decentralized/testing/src/utils.rs +++ b/architectures/decentralized/testing/src/utils.rs @@ -159,6 +159,7 @@ pub struct ConfigBuilder { min_clients: Option, batch_size: u32, architecture: String, + waiting_for_members_extra_time: Option, } impl Default for ConfigBuilder { @@ -187,6 +188,7 @@ impl ConfigBuilder { min_clients: None, batch_size: 4, architecture: String::from("HfLlama"), + waiting_for_members_extra_time: None, } } @@ -211,6 +213,11 @@ impl ConfigBuilder { self } + pub fn with_waiting_for_members_extra_time(mut self, time: u32) -> Self { + self.waiting_for_members_extra_time = Some(time); + self + } + pub fn build(mut self) -> PathBuf { // Use min_clients if set, otherwise default to num_clients let min_clients = self.min_clients.unwrap_or(self.num_clients); @@ -229,6 +236,10 @@ impl ConfigBuilder { #[cfg(feature = "python")] self.set_value("config.warmup_time", 100); + if let Some(time) = self.waiting_for_members_extra_time { + self.set_value("config.waiting_for_members_extra_time", time); + } + let config_content = toml::to_string(&self.base_config).unwrap(); let config_file_path = PathBuf::from("../../../config/solana-test/test-config.toml"); fs::write(&config_file_path, config_content).unwrap(); diff --git a/architectures/decentralized/testing/tests/integration_tests.rs b/architectures/decentralized/testing/tests/integration_tests.rs index eae0a7a0b..87971591a 100644 --- a/architectures/decentralized/testing/tests/integration_tests.rs +++ b/architectures/decentralized/testing/tests/integration_tests.rs @@ -486,13 +486,13 @@ async fn disconnect_client() { #[serial] async fn drop_a_client_waitingformembers_then_reconnect() { let n_clients = 2; - let num_of_epochs_to_run = 3; - let mut current_epoch = -1; let run_id = "test".to_string(); let docker = Arc::new(Docker::connect_with_socket_defaults().unwrap()); let mut watcher = DockerWatcher::new(docker.clone()); - let _cleanup = e2e_testing_setup(docker.clone(), 2).await; + // Use extra WFM time so we have a window to kill a client during WaitingForMembers + let _cleanup = + e2e_testing_setup_with_min(docker.clone(), n_clients, n_clients, None, Some(30)).await; let solana_client = SolanaTestClient::new(run_id, None).await; // Monitor clients @@ -501,56 +501,49 @@ async fn drop_a_client_waitingformembers_then_reconnect() { .monitor_container( &format!("{CLIENT_CONTAINER_PREFIX}-{i}"), vec![ - IntegrationTestLogMarker::Loss, IntegrationTestLogMarker::StateChange, - IntegrationTestLogMarker::LoadedModel, IntegrationTestLogMarker::Error, ], ) .unwrap(); } - let mut train_reached = false; + // Wait for both clients to reach WaitingForMembers, then kill client-2 + let mut killed_client = false; + let mut clients_in_wfm: Vec = Vec::new(); while let Some(response) = watcher.log_rx.recv().await { match response { Response::StateChange(_timestamp, client, old_state, new_state, _epoch, _step) => { let coordinator_state = solana_client.get_run_state().await; println!("state change client {client} - {old_state}=>{new_state}"); - // Once warmup starts, kill client 2's container - if new_state == RunState::RoundTrain.to_string() && !train_reached { - println!( - "Train started, killing container {}...", - &format!("{CLIENT_CONTAINER_PREFIX}-2") - ); - - let options = Some(KillContainerOptions { signal: "SIGKILL" }); - docker - .kill_container(&format!("{CLIENT_CONTAINER_PREFIX}-2"), options) - .await - .unwrap(); - - tokio::time::sleep(Duration::from_secs(2)).await; - train_reached = true; + // Track clients reaching WaitingForMembers and kill client-2 once both are in WFM + if new_state == RunState::WaitingForMembers.to_string() + && !clients_in_wfm.contains(&client) + && !killed_client + { + clients_in_wfm.push(client.clone()); + if clients_in_wfm.len() >= n_clients { + println!( + "Both clients in WaitingForMembers. Killing container {CLIENT_CONTAINER_PREFIX}-2..." + ); + let options = Some(KillContainerOptions { signal: "SIGKILL" }); + docker + .kill_container(&format!("{CLIENT_CONTAINER_PREFIX}-2"), options) + .await + .unwrap(); + tokio::time::sleep(Duration::from_secs(2)).await; + killed_client = true; + } } - // After killing client, verify we get stuck in WaitingForMembers - if train_reached && coordinator_state == RunState::WaitingForMembers { - println!("WaitingForMembers seen"); + // After killing client, wait for coordinator to return to WaitingForMembers + // (it may first advance to Warmup, detect dead client, then revert) + if killed_client && coordinator_state == RunState::WaitingForMembers { + println!("WaitingForMembers seen after kill"); break; } } - Response::Loss(client, epoch, step, loss) => { - println!("client: {client:?}, epoch: {epoch}, step: {step}, Loss: {loss:?}"); - - if epoch as i64 > current_epoch { - current_epoch = epoch as i64; - if epoch == num_of_epochs_to_run { - println!("Epoch {epoch} reached. Stopping"); - break; - } - } - } _ => {} } } @@ -567,7 +560,7 @@ async fn drop_a_client_waitingformembers_then_reconnect() { // Wait for state to change back to Warmup assert!( - solana_client.wait_for_run_state(RunState::Warmup, 30).await, + solana_client.wait_for_run_state(RunState::Warmup, 60).await, "System should have returned to Warmup state after client reconnection" ); println!("Successfully returned to Warmup state after client reconnection"); @@ -849,7 +842,7 @@ async fn test_pause_and_resume_run() { // Setup with min_clients=1 but init_num_clients=0 (we spawn manually) // Pass owner keypair to setup script let _cleanup = - e2e_testing_setup_with_min(docker.clone(), 0, 1, Some(owner_path.as_path())).await; + e2e_testing_setup_with_min(docker.clone(), 0, 1, Some(owner_path.as_path()), None).await; // Create SolanaTestClient with owner keypair for set_paused let solana_client = SolanaTestClient::new(run_id.clone(), Some(owner_keypair.clone())).await; diff --git a/architectures/inference-only/inference-node/Cargo.toml b/architectures/inference-only/inference-node/Cargo.toml index 0e99cbc20..044fc0f7f 100644 --- a/architectures/inference-only/inference-node/Cargo.toml +++ b/architectures/inference-only/inference-node/Cargo.toml @@ -36,6 +36,7 @@ iroh.workspace = true iroh-blobs.workspace = true iroh-gossip.workspace = true +reqwest = { version = "0.12", features = ["json"] } serde_json.workspace = true serde.workspace = true uuid = { version = "1", features = ["v4"] } diff --git a/architectures/inference-only/inference-node/src/bin/gateway-node.rs b/architectures/inference-only/inference-node/src/bin/gateway-node.rs index 5e0a2f2fc..201df83e0 100644 --- a/architectures/inference-only/inference-node/src/bin/gateway-node.rs +++ b/architectures/inference-only/inference-node/src/bin/gateway-node.rs @@ -14,9 +14,10 @@ use axum::{ extract::State, http::StatusCode, response::{IntoResponse, Response}, - routing::post, + routing::{get, post}, }; use clap::Parser; +use iroh::EndpointAddr; use psyche_inference::{ INFERENCE_ALPN, InferenceGossipMessage, InferenceMessage, InferenceRequest, InferenceResponse, }; @@ -68,6 +69,7 @@ struct GatewayState { pending_requests: RwLock>>, network_tx: mpsc::Sender<(EndpointId, InferenceMessage)>, gossip_tx: mpsc::Sender, + endpoint_addr: EndpointAddr, } #[derive(serde::Deserialize, serde::Serialize, Clone, Debug)] @@ -271,6 +273,15 @@ async fn handle_load_model( )) } +#[axum::debug_handler] +async fn handle_bootstrap(State(state): State>) -> Json { + info!( + "Bootstrap request: returning endpoint addr {}", + state.endpoint_addr.id.fmt_short() + ); + Json(state.endpoint_addr.clone()) +} + #[derive(Debug)] enum AppError { NoNodesAvailable, @@ -435,11 +446,15 @@ async fn run_gateway() -> Result<()> { let (network_tx, mut network_rx) = mpsc::channel::<(EndpointId, InferenceMessage)>(100); let (gossip_tx, mut gossip_rx) = mpsc::channel::(100); + + let endpoint_addr = network.router().endpoint().addr(); + let state = Arc::new(GatewayState { available_nodes: RwLock::new(HashMap::new()), pending_requests: RwLock::new(HashMap::new()), network_tx, gossip_tx, + endpoint_addr, }); info!("Gateway ready! Listening on http://{}", args.listen_addr); @@ -594,6 +609,7 @@ async fn run_gateway() -> Result<()> { let app = Router::new() .route("/v1/chat/completions", post(handle_inference)) .route("/admin/load-model", post(handle_load_model)) + .route("/bootstrap", get(handle_bootstrap)) .with_state(state.clone()); let listener = tokio::net::TcpListener::bind(&args.listen_addr) diff --git a/architectures/inference-only/inference-node/src/lib.rs b/architectures/inference-only/inference-node/src/lib.rs index 9c56c128b..ab07a98bc 100644 --- a/architectures/inference-only/inference-node/src/lib.rs +++ b/architectures/inference-only/inference-node/src/lib.rs @@ -3,6 +3,22 @@ use iroh::EndpointAddr; use std::{fs, path::PathBuf}; use tracing::info; +/// Fetch the gateway's endpoint address via its HTTP `/bootstrap` endpoint. +pub async fn fetch_bootstrap_peer(gateway_url: &str) -> Result { + let url = format!("{}/bootstrap", gateway_url.trim_end_matches('/')); + info!("Fetching bootstrap info from {}", url); + let addr: EndpointAddr = reqwest::get(&url) + .await + .context("Failed to reach gateway bootstrap endpoint")? + .error_for_status() + .context("Gateway returned error for /bootstrap")? + .json() + .await + .context("Failed to parse bootstrap response as EndpointAddr")?; + info!("Got bootstrap peer: {}", addr.id.fmt_short()); + Ok(addr) +} + pub fn load_bootstrap_peers( bootstrap_peer_file: Option<&PathBuf>, fallback_message: &str, diff --git a/architectures/inference-only/inference-node/src/main.rs b/architectures/inference-only/inference-node/src/main.rs index 097770477..fe286cfe4 100644 --- a/architectures/inference-only/inference-node/src/main.rs +++ b/architectures/inference-only/inference-node/src/main.rs @@ -84,6 +84,10 @@ struct RunArgs { #[arg(long, default_value = "")] capabilities: String, + /// gateway HTTP URL to fetch bootstrap peer from + #[arg(long, env = "PSYCHE_GATEWAY_URL")] + bootstrap_url: Option, + /// bootstrap peer file (JSON file with gateway endpoint address) #[arg(long)] bootstrap_peer_file: Option, @@ -140,11 +144,25 @@ async fn main() -> Result<()> { info!("Relay kind: {:?}", run_args.relay_kind); info!("Capabilities: {:?}", capabilities); - let bootstrap_peers = psyche_inference_node::load_bootstrap_peers( + let mut bootstrap_peers = psyche_inference_node::load_bootstrap_peers( run_args.bootstrap_peer_file.as_ref(), "No bootstrap peers configured (no env vars or CLI args)", )?; + if bootstrap_peers.is_empty() { + if let Some(ref url) = run_args.bootstrap_url { + match psyche_inference_node::fetch_bootstrap_peer(url).await { + Ok(peer) => { + info!("Fetched bootstrap peer from {}", url); + bootstrap_peers.push(peer); + } + Err(e) => { + warn!("Failed to fetch bootstrap peer from {}: {:#}", url, e); + } + } + } + } + let cancel = CancellationToken::new(); info!("Initializing Python interpreter..."); @@ -250,6 +268,11 @@ async fn main() -> Result<()> { let mut heartbeat_interval = tokio::time::interval(std::time::Duration::from_secs(30)); heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + // re-bootstrap every 20 heartbeats (10 min) + let mut rebootstrap_interval = tokio::time::interval(std::time::Duration::from_secs(600)); + rebootstrap_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + rebootstrap_interval.tick().await; + loop { tokio::select! { _ = tokio::signal::ctrl_c() => { @@ -285,6 +308,20 @@ async fn main() -> Result<()> { } } + _ = rebootstrap_interval.tick() => { + if let Some(ref url) = run_args.bootstrap_url { + match psyche_inference_node::fetch_bootstrap_peer(url).await { + Ok(peer) => { + network.add_peers(vec![peer.id]); + debug!("Re-bootstrapped from {}: peer {}", url, peer.id.fmt_short()); + } + Err(e) => { + warn!("Re-bootstrap from {} failed: {:#}", url, e); + } + } + } + } + event = network.poll_next() => { match event { Ok(Some(NetworkEvent::MessageReceived((peer_id, msg)))) => { diff --git a/justfile b/justfile index 8cc465682..476805662 100644 --- a/justfile +++ b/justfile @@ -216,70 +216,49 @@ inference-stack model="gpt2": #!/usr/bin/env bash set -euo pipefail - # Check if tmux is available if ! command -v tmux &> /dev/null; then echo "Error: tmux is required but not installed" exit 1 fi SESSION="psyche-inference" - GATEWAY_PEER_FILE="/tmp/psyche-gateway-peer.json" - - # Clean up old peer file - rm -f "$GATEWAY_PEER_FILE" + GATEWAY_URL="http://localhost:8000" - # Kill existing session if it exists tmux kill-session -t $SESSION 2>/dev/null || true - echo "building gateway and inference node..." + echo "Building gateway and inference node..." nix build .#bin-psyche-inference-node-gateway-node .#psyche-inference-node - echo "Starting gateway node (bootstrap node)..." - - # Create new session with gateway (starts first to be bootstrap node) + echo "Starting gateway node..." tmux new-session -d -s $SESSION -n gateway - tmux send-keys -t $SESSION:gateway "PSYCHE_GATEWAY_ENDPOINT_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#bin-psyche-inference-node-gateway-node -- --discovery-mode n0 --relay-kind n0" C-m + tmux send-keys -t $SESSION:gateway "RUST_LOG=info,psyche_network=debug nix run .#bin-psyche-inference-node-gateway-node -- --discovery-mode local" C-m - # Wait for gateway to start and write peer file - echo "Waiting for gateway to initialize and write endpoint..." + echo "Waiting for gateway HTTP server to be ready..." for i in $(seq 1 30); do - if [ -f "$GATEWAY_PEER_FILE" ]; then - echo "Gateway peer file created" + if curl -sf "$GATEWAY_URL/bootstrap" > /dev/null 2>&1; then + echo "Gateway ready" break fi sleep 1 done - if [ ! -f "$GATEWAY_PEER_FILE" ]; then - echo "Error: Gateway failed to create peer file" + if ! curl -sf "$GATEWAY_URL/bootstrap" > /dev/null 2>&1; then + echo "Error: Gateway failed to start" exit 1 fi - # Wait a bit more for gateway HTTP server - sleep 2 - echo "Gateway ready" - echo "" - echo "Starting inference node..." - - # Create window for inference node (bootstraps from gateway) + echo "Starting inference node (bootstrapping from $GATEWAY_URL)..." tmux new-window -t $SESSION -n inference - tmux send-keys -t $SESSION:inference "PSYCHE_GATEWAY_BOOTSTRAP_FILE=$GATEWAY_PEER_FILE RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --model-name {{ model }} --discovery-mode n0 --relay-kind n0" C-m - - # Wait for inference node to start - sleep 3 - echo "Inference node started" - echo "" + tmux send-keys -t $SESSION:inference "RUST_LOG=info,psyche_network=debug nix run .#psyche-inference-node -- --model-name {{ model }} --discovery-mode local --bootstrap-url $GATEWAY_URL" C-m - # Create window for testing tmux new-window -t $SESSION -n test tmux send-keys -t $SESSION:test "echo 'Test inference with:'; echo 'curl -X POST http://127.0.0.1:8000/v1/chat/completions -H \"Content-Type: application/json\" -d '\"'\"'{\"messages\": [{\"role\": \"user\", \"content\": \"Hello, world!\"}], \"max_tokens\": 50}'\"'\"''" C-m - # Attach to session - echo "Starting inference stack in tmux session '$SESSION'" - echo "Windows: inference (node), gateway (HTTP API), test (for curl commands)" + echo "Inference stack running in tmux session '$SESSION'" + echo "Windows: gateway, inference, test" echo "" echo "To attach: tmux attach -t $SESSION" - echo "To kill: tmux kill-session -t $SESSION" + echo "To kill: tmux kill-session -t $SESSION" echo "" tmux attach -t $SESSION diff --git a/shared/client/src/state/init.rs b/shared/client/src/state/init.rs index 03c7a9be4..ec5d38d69 100644 --- a/shared/client/src/state/init.rs +++ b/shared/client/src/state/init.rs @@ -504,6 +504,7 @@ impl RunInitConfigAndIO { { let dp = init_config.data_parallelism; let tp = init_config.tensor_parallelism; + let num_local_ranks = init_config.device.size() as i64; tokio::task::spawn_blocking(move || { if tp != 1 || dp != 1 { @@ -515,7 +516,7 @@ impl RunInitConfigAndIO { psyche_modeling::ParallelismConfig { dp, tp }, Some(llm.max_seq_len as usize), init_config.sidecar_port, - None, + Some(num_local_ranks), ) .map(RawLoadedModelType::PythonDistributed) .map_err(InitRunError::PythonDistributedError) diff --git a/shared/coordinator/src/coordinator.rs b/shared/coordinator/src/coordinator.rs index 29e635d21..57dde7710 100644 --- a/shared/coordinator/src/coordinator.rs +++ b/shared/coordinator/src/coordinator.rs @@ -168,7 +168,6 @@ pub struct Witness { AnchorDeserialize, Serialize, Deserialize, - PartialEq, TS, Default, Debug, @@ -193,7 +192,6 @@ pub struct WitnessMetadata { AnchorDeserialize, Serialize, Deserialize, - PartialEq, TS, Default, Debug, @@ -239,16 +237,7 @@ pub type HealthChecks = Vec<(NodeIdentity, CommitteeProof)>; pub const NUM_STORED_ROUNDS: usize = 4; #[derive( - Clone, - Debug, - Zeroable, - Copy, - Serialize, - Deserialize, - AnchorDeserialize, - AnchorSerialize, - PartialEq, - TS, + Clone, Debug, Zeroable, Copy, Serialize, Deserialize, AnchorDeserialize, AnchorSerialize, TS, )] #[repr(C)] pub struct CoordinatorConfig { @@ -274,16 +263,7 @@ pub struct CoordinatorConfig { } #[derive( - Clone, - Debug, - Zeroable, - Copy, - Serialize, - Deserialize, - AnchorSerialize, - AnchorDeserialize, - PartialEq, - TS, + Clone, Debug, Zeroable, Copy, Serialize, Deserialize, AnchorSerialize, AnchorDeserialize, TS, )] #[repr(C)] pub struct CoordinatorEpochState { @@ -305,16 +285,7 @@ pub struct CoordinatorEpochState { } #[derive( - Clone, - Debug, - Zeroable, - Copy, - Serialize, - Deserialize, - AnchorSerialize, - AnchorDeserialize, - PartialEq, - TS, + Clone, Debug, Zeroable, Copy, Serialize, Deserialize, AnchorSerialize, AnchorDeserialize, TS, )] #[repr(C)] pub struct CoordinatorProgress { @@ -324,16 +295,7 @@ pub struct CoordinatorProgress { } #[derive( - Clone, - Debug, - Zeroable, - Copy, - Serialize, - Deserialize, - AnchorSerialize, - AnchorDeserialize, - PartialEq, - TS, + Clone, Debug, Zeroable, Copy, Serialize, Deserialize, AnchorSerialize, AnchorDeserialize, TS, )] #[repr(C)] pub struct Coordinator { diff --git a/shared/coordinator/src/model.rs b/shared/coordinator/src/model.rs index 46c0ae111..3176f276e 100644 --- a/shared/coordinator/src/model.rs +++ b/shared/coordinator/src/model.rs @@ -13,16 +13,7 @@ use serde::{Deserialize, Serialize}; use ts_rs::TS; #[derive( - Clone, - Debug, - Copy, - Zeroable, - AnchorDeserialize, - AnchorSerialize, - Serialize, - Deserialize, - PartialEq, - TS, + Clone, Debug, Copy, Zeroable, AnchorDeserialize, AnchorSerialize, Serialize, Deserialize, TS, )] #[repr(C)] pub enum Model { @@ -92,7 +83,6 @@ pub enum LLMTrainingDataType { Debug, Zeroable, Copy, - PartialEq, TS, )] #[repr(C)] @@ -119,7 +109,6 @@ pub enum LLMTrainingDataLocation { Debug, Zeroable, Copy, - PartialEq, TS, )] #[repr(C)] @@ -174,7 +163,6 @@ impl LLMTrainingDataLocationAndWeight { Debug, Zeroable, Copy, - PartialEq, TS, )] #[repr(C)] @@ -196,16 +184,7 @@ pub enum HttpTrainingDataLocation { } #[derive( - AnchorSerialize, - AnchorDeserialize, - Serialize, - Deserialize, - Clone, - Debug, - Zeroable, - Copy, - PartialEq, - TS, + AnchorSerialize, AnchorDeserialize, Serialize, Deserialize, Clone, Debug, Zeroable, Copy, TS, )] #[repr(C)] pub struct LLM { @@ -296,7 +275,6 @@ impl GcsRepo { Debug, Zeroable, Copy, - PartialEq, TS, )] #[repr(C)] diff --git a/shared/core/src/definitions.rs b/shared/core/src/definitions.rs index d1914ec72..9a5dfd157 100644 --- a/shared/core/src/definitions.rs +++ b/shared/core/src/definitions.rs @@ -20,7 +20,6 @@ pub trait LearningRateScheduler: Send + Sync { Debug, Zeroable, Copy, - PartialEq, TS, )] #[repr(C)] @@ -70,7 +69,6 @@ impl LearningRateScheduler for ConstantLR { Debug, Zeroable, Copy, - PartialEq, TS, )] #[repr(C)] @@ -197,7 +195,6 @@ impl LearningRateScheduler for CosineLR { Debug, Zeroable, Copy, - PartialEq, TS, )] #[repr(C)] @@ -286,7 +283,6 @@ impl LearningRateScheduler for WarmupStableDecayLR { Debug, Zeroable, Copy, - PartialEq, TS, )] #[repr(C)] @@ -354,7 +350,6 @@ impl From for LearningRateSchedule { Debug, Zeroable, Copy, - PartialEq, TS, )] #[repr(C)] diff --git a/shared/modeling/examples/inference.rs b/shared/modeling/examples/inference.rs index e77579d41..1bc6380dd 100644 --- a/shared/modeling/examples/inference.rs +++ b/shared/modeling/examples/inference.rs @@ -203,7 +203,7 @@ fn inference( psyche_modeling::ParallelismConfig { dp: tp, tp: 1 }, None, None, - None, + Some(args.device.size() as i64), )?) as Box } } diff --git a/shared/modeling/examples/train.rs b/shared/modeling/examples/train.rs index 2c02b695c..21d92949b 100644 --- a/shared/modeling/examples/train.rs +++ b/shared/modeling/examples/train.rs @@ -318,7 +318,7 @@ async fn main() -> Result<()> { psyche_modeling::ParallelismConfig { dp, tp }, Some(args.sequence_length), None, - None, + Some(args.device.size() as i64), )?; Ok(psyche_modeling::PythonDistributedTrainer::new( diff --git a/shared/network/Cargo.toml b/shared/network/Cargo.toml index 91082a332..62db75ef2 100644 --- a/shared/network/Cargo.toml +++ b/shared/network/Cargo.toml @@ -34,11 +34,16 @@ tokenizers.workspace = true get_if_addrs = "0.5.3" n0-future = "0.3.2" url = { version = "2.5", features = ["serde"] } -iroh-n0des = "0.9.0" +iroh-services = { version = "0.11", features = [ + "client_host", + "net_diagnostics", +] } [dev-dependencies] # for examples clap.workspace = true clap-markdown.workspace = true +iroh-fake-store = { git = "https://github.com/IAvecilla/iroh-fake-store", branch = "fake-store-update" } +n0-future = "0.3.2" test-log.workspace = true diff --git a/shared/network/examples/docker-isolated-test/Dockerfile b/shared/network/examples/docker-isolated-test/Dockerfile new file mode 100644 index 000000000..bfd15ca2f --- /dev/null +++ b/shared/network/examples/docker-isolated-test/Dockerfile @@ -0,0 +1,44 @@ +FROM python:3.12-bookworm AS builder + +# Install Rust nightly +RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain nightly +ENV PATH="/root/.cargo/bin:${PATH}" + +# Install build dependencies +RUN apt-get update && apt-get install -y \ + pkg-config \ + libssl-dev \ + protobuf-compiler \ + cmake \ + && rm -rf /var/lib/apt/lists/* + +# Install PyTorch (CPU only, for libtorch) +RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu + +# Set libtorch to use PyTorch's installed copy +ENV LIBTORCH_USE_PYTORCH=1 +ENV LIBTORCH_BYPASS_VERSION_CHECK=1 + +WORKDIR /build +COPY . . + +RUN cargo build --release --example model_sharing_test -p psyche-network + +# Find and copy all libtorch shared libs for the runtime image +RUN mkdir -p /libtorch-libs && \ + TORCH_DIR=$(python3 -c "import torch; print(torch.__path__[0])") && \ + cp ${TORCH_DIR}/lib/lib*.so* /libtorch-libs/ + +FROM debian:bookworm-slim + +RUN apt-get update && apt-get install -y \ + ca-certificates \ + libssl3 \ + libgomp1 \ + && rm -rf /var/lib/apt/lists/* + +COPY --from=builder /libtorch-libs/ /usr/local/lib/ +COPY --from=builder /build/target/release/examples/model_sharing_test /usr/local/bin/model_sharing_test +RUN ldconfig + +ENTRYPOINT ["model_sharing_test"] diff --git a/shared/network/examples/docker-isolated-test/docker-compose.direct.yml b/shared/network/examples/docker-isolated-test/docker-compose.direct.yml new file mode 100644 index 000000000..9c0aeb15e --- /dev/null +++ b/shared/network/examples/docker-isolated-test/docker-compose.direct.yml @@ -0,0 +1,72 @@ +# Docker Compose for direct-connection model sharing test (single sharer). +# +# Same setup as docker-compose.yml but sharer and downloaders share a single +# network, so holepunching succeeds and data flows via direct UDP connections. +# +# For multi-sharer tests, use run.sh --direct instead. +# +# Usage: +# docker compose -f shared/network/examples/docker-isolated-test/docker-compose.direct.yml up --build + +services: + sharer: + build: + context: ../../../.. + dockerfile: shared/network/examples/docker-isolated-test/Dockerfile + entrypoint: ['/bin/sh', '-c'] + command: + - | + model_sharing_test --mode=sharer \ + --sharer-id=0 \ + --num-parameters=${NUM_PARAMS:-10} \ + --parameter-size-mb=${PARAM_SIZE_MB:-200} \ + --discovery-mode=${DISCOVERY_MODE:-n0} \ + --relay-kind=${RELAY_KIND:-psyche} & + PID=$$! + while [ ! -f /tmp/sharer_addr_0.json ]; do sleep 0.5; done + cp /tmp/sharer_addr_0.json /shared/sharer_addr_0.json + wait $$PID + volumes: + - sharer-addr:/shared + environment: + - RUST_LOG=${RUST_LOG:-info} + networks: + - shared-net + healthcheck: + test: ['CMD-SHELL', 'test -f /shared/sharer_addr_0.json'] + interval: 2s + timeout: 2s + retries: 60 + start_period: 5s + + downloader: + build: + context: ../../../.. + dockerfile: shared/network/examples/docker-isolated-test/Dockerfile + command: + - '--mode=downloader' + - '--num-downloaders=1' + - '--num-parameters=${NUM_PARAMS:-10}' + - '--parameter-size-mb=${PARAM_SIZE_MB:-200}' + - '--max-concurrent-downloads=${MAX_CONCURRENT:-4}' + - '--discovery-mode=${DISCOVERY_MODE:-n0}' + - '--relay-kind=${RELAY_KIND:-psyche}' + - '--sharer-addr=/shared' + volumes: + - sharer-addr:/shared:ro + environment: + - RUST_LOG=${RUST_LOG:-info} + networks: + - shared-net + depends_on: + sharer: + condition: service_healthy + deploy: + replicas: ${NUM_DOWNLOADERS:-5} + +volumes: + sharer-addr: + +networks: + shared-net: + driver: bridge diff --git a/shared/network/examples/docker-isolated-test/docker-compose.yml b/shared/network/examples/docker-isolated-test/docker-compose.yml new file mode 100644 index 000000000..0d28a220e --- /dev/null +++ b/shared/network/examples/docker-isolated-test/docker-compose.yml @@ -0,0 +1,84 @@ +# Docker Compose for isolated network model sharing test (single sharer). +# +# Sharer and downloaders run on separate Docker networks that cannot reach +# each other directly. Both have internet access for the psyche relay, but +# holepunching between them fails naturally — forcing organic relay fallback. +# +# For multi-sharer tests, use run.sh instead (it generates a compose file +# with individual sharer services). +# +# Usage: +# docker compose -f shared/network/examples/docker-isolated-test/docker-compose.yml up --build +# +# Environment variables: +# NUM_PARAMS - number of parameters (default: 10) +# PARAM_SIZE_MB - size per parameter in MB (default: 200) +# NUM_DOWNLOADERS - downloader replicas (default: 5) +# MAX_CONCURRENT - max concurrent downloads (default: 4) +# DISCOVERY_MODE - local, n0 (default: n0) +# RELAY_KIND - disabled, psyche, n0 (default: psyche) + +services: + sharer: + build: + context: ../../../.. + dockerfile: shared/network/examples/docker-isolated-test/Dockerfile + entrypoint: ['/bin/sh', '-c'] + command: + - | + model_sharing_test --mode=sharer \ + --sharer-id=0 \ + --num-parameters=${NUM_PARAMS:-10} \ + --parameter-size-mb=${PARAM_SIZE_MB:-200} \ + --discovery-mode=${DISCOVERY_MODE:-n0} \ + --relay-kind=${RELAY_KIND:-psyche} & + PID=$$! + while [ ! -f /tmp/sharer_addr_0.json ]; do sleep 0.5; done + cp /tmp/sharer_addr_0.json /shared/sharer_addr_0.json + wait $$PID + volumes: + - sharer-addr:/shared + environment: + - RUST_LOG=${RUST_LOG:-info} + networks: + - sharer-net + healthcheck: + test: ['CMD-SHELL', 'test -f /shared/sharer_addr_0.json'] + interval: 2s + timeout: 2s + retries: 60 + start_period: 5s + + downloader: + build: + context: ../../../.. + dockerfile: shared/network/examples/docker-isolated-test/Dockerfile + command: + - '--mode=downloader' + - '--num-downloaders=1' + - '--num-parameters=${NUM_PARAMS:-10}' + - '--parameter-size-mb=${PARAM_SIZE_MB:-200}' + - '--max-concurrent-downloads=${MAX_CONCURRENT:-4}' + - '--discovery-mode=${DISCOVERY_MODE:-n0}' + - '--relay-kind=${RELAY_KIND:-psyche}' + - '--sharer-addr=/shared' + volumes: + - sharer-addr:/shared:ro + environment: + - RUST_LOG=${RUST_LOG:-info} + networks: + - downloader-net + depends_on: + sharer: + condition: service_healthy + deploy: + replicas: ${NUM_DOWNLOADERS:-5} + +volumes: + sharer-addr: + +networks: + sharer-net: + driver: bridge + downloader-net: + driver: bridge diff --git a/shared/network/examples/docker-isolated-test/run.sh b/shared/network/examples/docker-isolated-test/run.sh new file mode 100755 index 000000000..82caf4877 --- /dev/null +++ b/shared/network/examples/docker-isolated-test/run.sh @@ -0,0 +1,230 @@ +#!/usr/bin/env bash +# Run the Docker-based model sharing test. +# +# Generates a docker-compose file with individual sharer services so each +# can have a unique ID and optional bandwidth throttling. +# +# Usage: +# ./run.sh # 1 sharer, 5 downloaders +# ./run.sh --sharers 3 --slow-sharers 1 # 3 sharers, last 1 throttled +# ./run.sh --sharers 2 --slow-sharers 1 --slow-rate 50 +# ./run.sh --direct # single shared network (no relay) +# ./run.sh --params 20 --size 100 --downloaders 3 + +set -euo pipefail +cd "$(dirname "$0")" + +NUM_PARAMS=10 +PARAM_SIZE_MB=200 +NUM_SHARERS=1 +NUM_DOWNLOADERS=5 +MAX_CONCURRENT=4 +DISCOVERY_MODE=n0 +RELAY_KIND=psyche +SLOW_SHARERS=0 +SLOW_RATE_KB=100 +DIRECT=false +RELAY_ONLY=false + +while [[ $# -gt 0 ]]; do + case $1 in + --params) NUM_PARAMS="$2"; shift 2 ;; + --size) PARAM_SIZE_MB="$2"; shift 2 ;; + --sharers) NUM_SHARERS="$2"; shift 2 ;; + --downloaders) NUM_DOWNLOADERS="$2"; shift 2 ;; + --max-concurrent) MAX_CONCURRENT="$2"; shift 2 ;; + --discovery) DISCOVERY_MODE="$2"; shift 2 ;; + --relay) RELAY_KIND="$2"; shift 2 ;; + --slow-sharers) SLOW_SHARERS="$2"; shift 2 ;; + --slow-rate) SLOW_RATE_KB="$2"; shift 2 ;; + --direct) DIRECT=true; shift ;; + --relay-only) RELAY_ONLY=true; shift ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +if [ "$SLOW_SHARERS" -gt "$NUM_SHARERS" ]; then + echo "Error: --slow-sharers ($SLOW_SHARERS) cannot exceed --sharers ($NUM_SHARERS)" + exit 1 +fi + +LOG="/tmp/model_sharing_test_$(date +%s).log" +COMPOSE_FILE="/tmp/model_sharing_compose_$$.yml" + +if [ "$DIRECT" = true ]; then + NET_MODE="direct (single shared network)" +else + NET_MODE="isolated (separate networks, relay-only)" +fi + +echo "=== Model Sharing Test ===" +echo " Network: $NET_MODE" +echo " Sharers: $NUM_SHARERS (${SLOW_SHARERS} slow at ${SLOW_RATE_KB} KB/s)" +echo " Parameters: $NUM_PARAMS x ${PARAM_SIZE_MB}MB" +echo " Downloaders: $NUM_DOWNLOADERS" +echo " Max concurrent: $MAX_CONCURRENT" +echo " Discovery: $DISCOVERY_MODE" +echo " Relay: $RELAY_KIND" +echo " Relay-only: $RELAY_ONLY" +echo " Log file: $LOG" +echo "" + +# --------------------------------------------------------------------------- +# Generate docker-compose YAML with individual sharer services +# --------------------------------------------------------------------------- +generate_compose() { + local slow_start=$((NUM_SHARERS - SLOW_SHARERS)) + + cat <<'HEADER' +# Auto-generated by run.sh — do not edit manually. +services: +HEADER + + # --- Sharer services --- + for i in $(seq 0 $((NUM_SHARERS - 1))); do + local slow_flags="" + if [ "$i" -ge "$slow_start" ] && [ "$SLOW_SHARERS" -gt 0 ]; then + slow_flags="--slow --slow-rate-kb=${SLOW_RATE_KB}" + fi + + local relay_only_flag="" + if [ "$RELAY_ONLY" = true ]; then + relay_only_flag="--relay-only" + fi + + local network="sharer-net" + if [ "$DIRECT" = true ]; then + network="shared-net" + fi + + cat < "$COMPOSE_FILE" + +# Project directory so relative paths (build context) resolve correctly +PROJ_DIR="$(pwd)" +DC="docker compose -f $COMPOSE_FILE --project-directory $PROJ_DIR" + +cleanup() { + echo "" + echo "Cleaning up..." + $DC down -v 2>/dev/null || true + rm -f "$COMPOSE_FILE" +} +trap cleanup EXIT + +echo "Building Docker image (first run will compile from source)..." +$DC build 2>&1 | tee -a "$LOG" + +echo "" +echo "Starting test..." +$DC up --abort-on-container-exit 2>&1 | tee -a "$LOG" + +echo "" +echo "Done. Full log at: $LOG" diff --git a/shared/network/examples/model_sharing_test.rs b/shared/network/examples/model_sharing_test.rs new file mode 100644 index 000000000..6fca08710 --- /dev/null +++ b/shared/network/examples/model_sharing_test.rs @@ -0,0 +1,854 @@ +use anyhow::Result; +use clap::{Parser, ValueEnum}; +use iroh::EndpointAddr; +use iroh_blobs::BlobFormat; +use iroh_blobs::Hash; +use iroh_blobs::api::Tag; +use iroh_blobs::ticket::BlobTicket; +use iroh_fake_store::FakeStore; +use postcard; +use psyche_metrics::ClientMetrics; +use psyche_network::{ + ConnectionMonitor, DiscoveryMode, DownloadType, EndpointId, ModelRequestType, + NetworkConnection, NetworkEvent, PeerBandwidth, PeerManagerHandle, PublicKey, RelayKind, + TransmittableDownload, TransmittableModelConfig, allowlist, blob_ticket_param_request_task, +}; +use psyche_tui::LogOutput; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::{Duration, Instant}; +use tokio::select; +use tokio::sync::oneshot; +use tokio_util::sync::CancellationToken; +use tracing::{error, info, warn}; + +#[derive(Debug, Clone, ValueEnum)] +enum RunMode { + /// Run as a sharer only — prints endpoint address JSON to stdout for downloaders + Sharer, + /// Run as a downloader only — requires --sharer-addr + Downloader, +} + +#[derive(Parser, Debug)] +#[command(name = "model_sharing_test")] +#[command(about = "Test harness for P2P model sharing flow")] +struct CliArgs { + /// Run mode: "sharer" or "downloader" + #[clap(long, value_enum)] + mode: RunMode, + + /// JSON-encoded sharer endpoint address, or path to a file/directory containing address JSON. + /// In downloader mode with multiple sharers, point to the directory containing all address files. + #[clap(long)] + sharer_addr: Option, + + #[clap(long, default_value_t = 2)] + num_downloaders: usize, + + #[clap(long, default_value_t = 300)] + num_parameters: usize, + + /// Size of each parameter in MB + #[clap(long, default_value_t = 1000)] + parameter_size_mb: usize, + + #[clap(long, default_value_t = 4)] + max_concurrent_downloads: usize, + + /// Discovery mode: "local" or "n0" + #[clap(long, default_value = "local")] + discovery_mode: String, + + /// Relay kind: "disabled", "psyche", or "n0" + #[clap(long, default_value = "psyche")] + relay_kind: String, + + /// Force all blob downloads to go through the relay (strip direct IP addresses from tickets) + #[clap(long, default_value_t = false)] + relay_only: bool, + + /// Sharer instance ID (used to write unique address files when running multiple sharers) + #[clap(long, default_value_t = 0)] + sharer_id: usize, + + /// Whether this sharer should be slow (throttled) + #[clap(long, default_value_t = false)] + slow: bool, + + /// Bandwidth limit for slow sharers in KB/s + #[clap(long, default_value_t = 100)] + slow_rate_kb: u64, +} + +/// Required by NetworkConnection generics but unused in this example. +#[derive(Debug, Serialize, Deserialize)] +enum TestMessage { + Noop, +} + +type NC = NetworkConnection; + +fn generate_parameter_names(count: usize) -> Vec { + (0..count) + .map(|i| format!("model.layers.{i}.weight")) + .collect() +} + +async fn create_peer( + label: &str, + discovery_mode: DiscoveryMode, + relay_kind: RelayKind, + fake_store: Option<&FakeStore>, + relay_only: bool, +) -> Result { + let metrics = Arc::new(ClientMetrics::new(None, None)); + let network = if let Some(store) = fake_store { + let store: iroh_blobs::api::Store = std::ops::Deref::deref(store).clone(); + NC::init_with_blobs_store( + "model-sharing-test", + None, + None, + discovery_mode, + relay_kind, + vec![], + None, + allowlist::AllowAll, + metrics, + None, + store, + relay_only, + ) + .await? + } else { + NC::init( + "model-sharing-test", + None, + None, + discovery_mode, + relay_kind, + vec![], + None, + allowlist::AllowAll, + metrics, + None, + ) + .await? + }; + + info!("{label} initialized: {}", network.endpoint_id()); + Ok(network) +} + +fn format_bandwidth(bw: &PeerBandwidth) -> String { + match bw { + PeerBandwidth::NotMeasured => "not measured".to_string(), + PeerBandwidth::Measured(bytes_per_sec) => { + format!("{:.2} MB/s", bytes_per_sec / (1024.0 * 1024.0)) + } + } +} + +fn print_peer_status(monitor: &ConnectionMonitor, context: &str) { + let connections = monitor.get_all_connections(); + println!("\n [Peer Status: {context}] ({} peers)", connections.len()); + for conn in &connections { + let latency = conn + .latency() + .map(|d| format!("{d:?}")) + .unwrap_or_else(|| "n/a".to_string()); + println!( + " {} | bw: {} | latency: {}", + conn.endpoint_id, + format_bandwidth(&conn.bandwidth), + latency, + ); + } +} + +async fn run_sharer( + mut network: NC, + fake_store: FakeStore, + param_names: Vec, + param_size_bytes: usize, + relay_only: bool, + cancel: CancellationToken, +) -> Result<()> { + let endpoint_id = network.endpoint_id(); + info!( + "Sharer {endpoint_id}: preparing {} params ({} MB each)", + param_names.len(), + param_size_bytes / (1024 * 1024) + ); + + let mut endpoint_addr = network.endpoint_addr().await; + if relay_only { + // Strip direct IP addresses so downloaders must connect via relay + endpoint_addr.addrs.retain(|a| a.is_relay()); + info!( + "Sharer {endpoint_id}: relay-only mode, stripped direct addresses. Addr: {endpoint_addr:?}" + ); + } + let fake_hashes = fake_store.blobs().list().hashes().await?; + let param_tickets: HashMap = param_names + .iter() + .zip(fake_hashes.iter()) + .map(|(name, hash)| { + let ticket = BlobTicket::new(endpoint_addr.clone(), *hash, BlobFormat::Raw); + (name.clone(), ticket) + }) + .collect(); + + // Model config blob stored in FakeStore so BlobsProtocol can serve it + let config = TransmittableModelConfig::new( + r#"{"model_type": "test", "num_layers": 10}"#.to_string(), + r#"{"version":"1.0","truncation":null,"padding":null,"added_tokens":[],"normalizer":null,"pre_tokenizer":null,"post_processor":null,"decoder":null,"model":{"type":"BPE","dropout":null,"unk_token":null,"continuing_subword_prefix":null,"end_of_word_suffix":null,"fuse_unk":false,"byte_fallback":false,"ignore_merges":false,"vocab":{},"merges":[]}}"#.to_string(), + param_names.clone(), + ); + let config_bytes = postcard::to_allocvec(&TransmittableDownload::ModelConfig(config))?; + let config_blob = fake_store + .blobs() + .add_bytes(config_bytes) + .with_named_tag(Tag::from("model-config-share")) + .await?; + let config_ticket = + BlobTicket::new(endpoint_addr.clone(), config_blob.hash, config_blob.format); + + info!("Sharer {endpoint_id}: ready to serve requests"); + loop { + select! { + _ = cancel.cancelled() => { + info!("Sharer {endpoint_id}: shutting down"); + break; + } + event = network.poll_next() => { + match event { + Ok(Some(NetworkEvent::ParameterRequest(name, reply_tx))) => { + match param_tickets.get(&name) { + Some(ticket) => { + let _ = reply_tx.send(Ok(ticket.clone())); + } + None => { + warn!("Sharer {endpoint_id}: unknown parameter '{name}'"); + let _ = reply_tx.send(Err( + psyche_network::SharableModelError::ParameterUnknown(name), + )); + } + } + } + Ok(Some(NetworkEvent::ModelConfigRequest(reply_tx))) => { + let _ = reply_tx.send(Ok(config_ticket.clone())); + } + Ok(Some(_)) => {} + Ok(None) => {} + Err(e) => { + error!("Sharer {endpoint_id}: network error: {e:#}"); + break; + } + } + } + } + } + + Ok(()) +} + +struct DownloadRequest { + name: String, + ticket: BlobTicket, + peer_id: EndpointId, + start_time: Instant, + done_tx: oneshot::Sender, +} + +struct InFlightDownload { + name: String, + tag_name: String, + peer_id: EndpointId, + start_time: Instant, + done_tx: oneshot::Sender, +} + +/// Records a completed download and signals the worker. +/// Returns the tag name for cleanup. +fn record_completion( + flight: InFlightDownload, + endpoint_id: EndpointId, + completed: &mut usize, + param_reports: &mut Vec, +) -> String { + let duration = flight.start_time.elapsed(); + info!( + "Downloader {endpoint_id}: '{}' from {} in {duration:?}", + flight.name, + flight.peer_id.fmt_short(), + ); + + *completed += 1; + param_reports.push(ParamDownloadReport { + name: flight.name, + duration, + from_peer: flight.peer_id, + }); + + let _ = flight.done_tx.send(true); + flight.tag_name +} + +async fn run_downloader( + mut network: NC, + sharer_ids: Vec, + expected_param_count: usize, + max_concurrent: usize, + param_size_bytes: usize, + cancel: CancellationToken, +) -> Result { + let endpoint_id = network.endpoint_id(); + let connection_monitor = network.connection_monitor(); + let router = network.router(); + + let peer_manager = Arc::new(PeerManagerHandle::new( + 3, + cancel.clone(), + connection_monitor.clone(), + )); + peer_manager.set_peers(sharer_ids.clone()); + + let overall_start = Instant::now(); + let mut param_reports: Vec = Vec::new(); + + // Step 1: Download model config + let config_start = Instant::now(); + let (config_ticket, _) = blob_ticket_param_request_task( + ModelRequestType::Config, + router.clone(), + peer_manager.clone(), + cancel.clone(), + ) + .await?; + + network.start_download( + config_ticket, + Tag::from("model-config"), + DownloadType::ModelSharing(ModelRequestType::Config), + ); + + let param_names = loop { + select! { + _ = cancel.cancelled() => { + return Err(anyhow::anyhow!("Cancelled while downloading config")); + } + event = network.poll_next() => { + match event { + Ok(Some(NetworkEvent::DownloadComplete(result))) => { + if let TransmittableDownload::ModelConfig(config) = result.data { + info!("Downloader {endpoint_id}: config downloaded with {} params in {:?}", + config.parameter_names.len(), config_start.elapsed()); + break config.parameter_names; + } + } + Ok(Some(NetworkEvent::DownloadFailed(f))) => { + return Err(anyhow::anyhow!("Config download failed: {}", f.error)); + } + Ok(_) => {} + Err(e) => return Err(e.into()), + } + } + } + }; + let config_request_time = config_start.elapsed(); + + assert_eq!( + param_names.len(), + expected_param_count, + "Config parameter count mismatch" + ); + + // Step 2: Download parameters via worker pool + main download loop + info!( + "Downloader {endpoint_id}: downloading {} params from {} sharers...", + param_names.len(), + sharer_ids.len(), + ); + + let work_queue: Arc>> = + Arc::new(tokio::sync::Mutex::new(param_names.clone())); + let (request_tx, mut request_rx) = + tokio::sync::mpsc::channel::(max_concurrent * 2); + + // Workers: pick param -> request ticket from peer -> send to main loop + for worker_id in 0..max_concurrent { + let work_q = work_queue.clone(); + let pm = peer_manager.clone(); + let rtr = router.clone(); + let tx = request_tx.clone(); + let cancel = cancel.clone(); + let ep_id = endpoint_id; + + tokio::spawn(async move { + loop { + let name = { + let mut q = work_q.lock().await; + if q.is_empty() { + break; + } + q.remove(0) + }; + + if cancel.is_cancelled() { + break; + } + + let request_type = ModelRequestType::Parameter(name.clone()); + match blob_ticket_param_request_task( + request_type, + rtr.clone(), + pm.clone(), + cancel.clone(), + ) + .await + { + Ok((ticket, _)) => { + let (done_tx, done_rx) = oneshot::channel(); + let req = DownloadRequest { + peer_id: ticket.addr().id, + name, + ticket, + start_time: Instant::now(), + done_tx, + }; + if tx.send(req).await.is_err() { + break; + } + let _ = done_rx.await; + } + Err(e) => { + error!( + "Downloader {ep_id} worker-{worker_id}: ticket request failed for '{name}': {e}" + ); + break; + } + } + } + }); + } + drop(request_tx); + + // Main loop: owns &mut network for start_download + poll_next. + // Keyed by hash -> Vec because FakeStore can produce identical hashes for same-size blobs. + let mut in_flight: HashMap> = HashMap::new(); + let mut completed = 0usize; + let mut tag_counter = 0u64; + + loop { + if completed >= expected_param_count { + break; + } + + select! { + _ = cancel.cancelled() => { + return Err(anyhow::anyhow!("Cancelled during parameter downloads")); + } + req = request_rx.recv() => { + match req { + Some(dl_req) => { + tag_counter += 1; + let tag_name = format!("param-dl-{tag_counter}"); + let hash = dl_req.ticket.hash(); + + network.start_download( + dl_req.ticket, + Tag::from(tag_name.clone()), + DownloadType::ModelSharing(ModelRequestType::Parameter(dl_req.name.clone())), + ); + + in_flight.entry(hash).or_default().push(InFlightDownload { + name: dl_req.name, + tag_name, + peer_id: dl_req.peer_id, + start_time: dl_req.start_time, + done_tx: dl_req.done_tx, + }); + } + None => { + in_flight.retain(|_, v| !v.is_empty()); + if in_flight.is_empty() { + break; + } + } + } + } + event = network.poll_next() => { + match &event { + Ok(Some(NetworkEvent::DownloadComplete(r))) => { + let hash = r.hash; + if let Some(flight) = in_flight.get_mut(&hash).and_then(|v| v.pop()) { + let tag_to_delete = record_completion( + flight, endpoint_id, + &mut completed, &mut param_reports, + ); + if let Err(e) = network.delete_tag(&tag_to_delete).await { + warn!("Failed to delete tag {tag_to_delete}: {e}"); + } + } + } + Ok(Some(NetworkEvent::DownloadFailed(f))) => { + let hash = f.blob_ticket.hash(); + let tag = f.tag.to_string(); + if let Some(flight) = in_flight.get_mut(&hash).and_then(|v| v.pop()) { + if f.transfer_failed { + // Real failure (timeout, network error) — re-queue for retry + warn!( + "Downloader {endpoint_id}: transfer failed for '{}' from {}, re-queuing", + flight.name, flight.peer_id.fmt_short(), + ); + // Push to work queue BEFORE signaling done to avoid race where + // the worker wakes up, sees empty queue, and exits. + work_queue.lock().await.push(flight.name); + let _ = flight.done_tx.send(true); + } else { + // Deserialization error — transfer succeeded (FakeStore data can't deserialize). + // The core sets bandwidth to 0 on any DownloadFailed, but the + // BandwidthTracker had the real value from Progress events. + // Restore it using the tracker's last reading. + let peer_id = f.blob_ticket.addr().id; + let tracker_bw = network.bandwidth_tracker_peer_bandwidth(&peer_id); + let elapsed = flight.start_time.elapsed().as_secs_f64(); + let manual_bw = if elapsed > 0.0 { + param_size_bytes as f64 / elapsed + } else { + 0.0 + }; + info!( + "Bandwidth for '{}' from {}: tracker={}, manual={:.1} KB/s", + flight.name, + peer_id.fmt_short(), + format_bandwidth(&tracker_bw), + manual_bw / 1024.0, + ); + // Restore from tracker (preferred), fallback to manual + let restore_bw = match tracker_bw { + PeerBandwidth::Measured(bw) => PeerBandwidth::Measured(bw), + PeerBandwidth::NotMeasured if manual_bw > 0.0 => { + PeerBandwidth::Measured(manual_bw) + } + _ => PeerBandwidth::NotMeasured, + }; + network.connection_monitor().update_peer_bandwidth( + &peer_id, + restore_bw, + ); + record_completion( + flight, endpoint_id, + &mut completed, &mut param_reports, + ); + } + if let Err(e) = network.delete_tag(&tag).await { + warn!("Failed to delete tag {tag}: {e}"); + } + } + } + Ok(Some(_)) => {} + Ok(None) => {} + Err(e) => { + error!("Downloader {endpoint_id}: network error: {e:#}"); + } + } + + if completed % 10 == 0 || completed == expected_param_count { + if completed > 0 { + print_peer_status( + &connection_monitor, + &format!("Downloader {endpoint_id} {completed}/{expected_param_count}"), + ); + } + } + } + } + } + + let total_duration = overall_start.elapsed(); + info!( + "Downloader {endpoint_id}: {completed}/{} params in {total_duration:?}", + param_names.len() + ); + + Ok(DownloaderReport { + endpoint_id, + total_duration, + config_request_time, + param_reports, + }) +} + +#[derive(Debug)] +struct ParamDownloadReport { + name: String, + duration: Duration, + from_peer: PublicKey, +} + +#[derive(Debug)] +struct DownloaderReport { + endpoint_id: EndpointId, + total_duration: Duration, + config_request_time: Duration, + param_reports: Vec, +} + +fn print_report(reports: &[DownloaderReport], param_size_bytes: usize) { + let separator = "=".repeat(70); + println!("\n{separator}"); + println!(" MODEL SHARING TEST RESULTS"); + println!("{separator}"); + + for report in reports { + println!("\n--- Downloader {} ---", report.endpoint_id); + println!(" Config request time: {:?}", report.config_request_time); + println!(" Total download time: {:?}", report.total_duration); + println!(" Parameters downloaded: {}", report.param_reports.len()); + + if report.param_reports.is_empty() { + continue; + } + + let total_bytes = report.param_reports.len() as f64 * param_size_bytes as f64; + let avg_bw = total_bytes / report.total_duration.as_secs_f64(); + println!( + " Average bandwidth: {:.2} MB/s", + avg_bw / (1024.0 * 1024.0) + ); + + let mut per_peer: HashMap = HashMap::new(); + for pr in &report.param_reports { + let entry = per_peer + .entry(pr.from_peer.fmt_short().to_string()) + .or_insert((0, Duration::ZERO)); + entry.0 += 1; + entry.1 += pr.duration; + } + + println!(" Per-peer breakdown:"); + for (peer, (count, total_time)) in &per_peer { + let peer_bw = (*count as f64 * param_size_bytes as f64) / total_time.as_secs_f64(); + println!( + " {peer}: {count} params, total {total_time:?}, avg {:.2} MB/s", + peer_bw / (1024.0 * 1024.0) + ); + } + + let mut sorted_params: Vec<_> = report.param_reports.iter().collect(); + sorted_params.sort_by_key(|p| p.duration); + if let (Some(fastest), Some(slowest)) = (sorted_params.first(), sorted_params.last()) { + println!(" Fastest: '{}' in {:?}", fastest.name, fastest.duration); + println!(" Slowest: '{}' in {:?}", slowest.name, slowest.duration); + } + } + println!("\n{separator}"); +} + +/// Parse a single sharer address from a JSON string or a file path containing JSON. +fn parse_sharer_addr(raw: &str) -> Result { + // Try parsing as JSON first + if let Ok(addr) = serde_json::from_str::(raw) { + return Ok(addr); + } + // Try reading as a file path + let content = std::fs::read_to_string(raw) + .map_err(|e| anyhow::anyhow!("Failed to read sharer-addr file '{raw}': {e}"))?; + serde_json::from_str::(content.trim()) + .map_err(|e| anyhow::anyhow!("Failed to parse sharer-addr JSON: {e}")) +} + +/// Parse one or more sharer addresses. Accepts a JSON string, a single file, or a directory +/// containing multiple `sharer_addr_*.json` files. +fn parse_sharer_addrs(raw: &str) -> Result> { + // Try as JSON first + if let Ok(addr) = serde_json::from_str::(raw) { + return Ok(vec![addr]); + } + + let path = std::path::Path::new(raw); + + // If it's a directory, read all sharer_addr_*.json files in it + if path.is_dir() { + let mut addrs = Vec::new(); + let mut entries: Vec<_> = std::fs::read_dir(path)? + .filter_map(|e| e.ok()) + .filter(|e| { + e.file_name() + .to_str() + .is_some_and(|n| n.starts_with("sharer_addr_") && n.ends_with(".json")) + }) + .collect(); + entries.sort_by_key(|e| e.file_name()); + + for entry in entries { + let content = std::fs::read_to_string(entry.path())?; + let addr: EndpointAddr = serde_json::from_str(content.trim()) + .map_err(|e| anyhow::anyhow!("Failed to parse {:?}: {e}", entry.path()))?; + addrs.push(addr); + } + + if addrs.is_empty() { + return Err(anyhow::anyhow!( + "No sharer_addr_*.json files found in directory '{raw}'" + )); + } + Ok(addrs) + } else { + // Single file + Ok(vec![parse_sharer_addr(raw)?]) + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let args = CliArgs::parse(); + let _logger = psyche_tui::logging() + .with_output(LogOutput::Console) + .init()?; + + let discovery_mode: DiscoveryMode = args + .discovery_mode + .parse() + .map_err(|e: String| anyhow::anyhow!(e))?; + let relay_kind: RelayKind = args + .relay_kind + .parse() + .map_err(|e: String| anyhow::anyhow!(e))?; + + let param_names = generate_parameter_names(args.num_parameters); + let param_size_bytes = args.parameter_size_mb * 1024 * 1024; + let relay_only = args.relay_only; + + println!("Model Sharing Test Configuration:"); + println!(" Mode: {:?}", args.mode); + println!(" Parameters: {}", args.num_parameters); + println!(" Parameter size: {} MB", args.parameter_size_mb); + println!(" Max concurrent DLs: {}", args.max_concurrent_downloads); + println!(" Discovery mode: {discovery_mode:?}"); + println!(" Relay kind: {relay_kind:?}"); + if relay_only { + println!( + " Relay-only mode: ENABLED (direct IP transports disabled, all traffic via relay)" + ); + } + println!( + " Total data: {:.1} GB", + (args.num_parameters * args.parameter_size_mb) as f64 / 1024.0 + ); + println!(); + + let cancel = CancellationToken::new(); + + match args.mode { + RunMode::Sharer => { + let sharer_id = args.sharer_id; + let mut builder = FakeStore::builder() + .with_unique_blobs(args.num_parameters, param_size_bytes as u64); + if args.slow { + builder = builder.with_throttle( + std::num::NonZeroU64::new(args.slow_rate_kb * 1024) + .expect("slow_rate_kb must be > 0"), + ); + info!( + "Sharer-{sharer_id}: throttled to {} KB/s", + args.slow_rate_kb + ); + } + let store = builder.build(); + let label = if args.slow { + format!("Sharer-{sharer_id}-SLOW") + } else { + format!("Sharer-{sharer_id}") + }; + let network = + create_peer(&label, discovery_mode, relay_kind, Some(&store), relay_only).await?; + + let endpoint_addr = network.endpoint_addr().await; + let addr_json = serde_json::to_string(&endpoint_addr)?; + + // Print with marker so it can be parsed from logs + println!("SHARER_ADDR_JSON:{addr_json}"); + + // Write to /tmp/sharer_addr_{id}.json for Docker volume sharing + let addr_file = format!("/tmp/sharer_addr_{sharer_id}.json"); + if let Err(e) = std::fs::write(&addr_file, &addr_json) { + warn!("Could not write sharer addr to {addr_file}: {e}"); + } + + info!("{label} running, endpoint: {}", network.endpoint_id()); + run_sharer( + network, + store, + param_names, + param_size_bytes, + relay_only, + cancel, + ) + .await?; + } + + RunMode::Downloader => { + let sharer_addr_raw = args + .sharer_addr + .as_deref() + .ok_or_else(|| anyhow::anyhow!("--sharer-addr is required in downloader mode"))?; + + // Load sharer addresses: supports a single file/JSON or a directory of address files + let sharer_addrs = parse_sharer_addrs(sharer_addr_raw)?; + let sharer_ids: Vec = sharer_addrs.iter().map(|a| a.id).collect(); + info!( + "Downloader targeting {} sharer(s): {}", + sharer_ids.len(), + sharer_ids + .iter() + .map(|id| id.fmt_short().to_string()) + .collect::>() + .join(", ") + ); + + let mut downloader_handles = Vec::new(); + for i in 0..args.num_downloaders { + let network = create_peer( + &format!("Downloader-{i}"), + discovery_mode, + relay_kind, + None, + relay_only, + ) + .await?; + let sharer_ids = sharer_ids.clone(); + let cancel = cancel.clone(); + let expected = args.num_parameters; + let max_concurrent = args.max_concurrent_downloads; + downloader_handles.push(tokio::spawn(async move { + run_downloader( + network, + sharer_ids, + expected, + max_concurrent, + param_size_bytes, + cancel, + ) + .await + })); + } + + let mut reports = Vec::new(); + for handle in downloader_handles { + match handle.await? { + Ok(report) => reports.push(report), + Err(e) => error!("Downloader failed: {e:#}"), + } + } + + cancel.cancel(); + print_report(&reports, param_size_bytes); + } + } + + Ok(()) +} diff --git a/shared/network/src/download/manager.rs b/shared/network/src/download/manager.rs index 461ff9b5b..c75e49331 100644 --- a/shared/network/src/download/manager.rs +++ b/shared/network/src/download/manager.rs @@ -110,6 +110,9 @@ pub struct DownloadFailed { pub tag: Tag, pub error: anyhow::Error, pub download_type: DownloadType, + /// True when the network transfer itself failed (peer should be penalized). + /// False when the transfer succeeded but post-processing (e.g. deserialization) failed. + pub transfer_failed: bool, } impl Debug for DownloadComplete { @@ -350,6 +353,7 @@ impl DownloadManager { error: anyhow!("Download error"), tag, download_type: download.download_type.clone(), + transfer_failed: true, })) } DownloadProgressItem::Error(e) => { @@ -358,6 +362,7 @@ impl DownloadManager { error: e.into(), tag, download_type: download.download_type.clone(), + transfer_failed: true, })) } DownloadProgressItem::ProviderFailed { @@ -378,6 +383,7 @@ impl DownloadManager { error: err, tag, download_type: download.download_type.clone(), + transfer_failed: true, })), }; match &event { @@ -422,6 +428,7 @@ impl DownloadManager { tag: downloader.tag, error: err.into(), download_type: downloader.download_type.clone(), + transfer_failed: false, })), }, Err(e) => Some(DownloadManagerEvent::Failed(DownloadFailed { @@ -429,6 +436,7 @@ impl DownloadManager { tag: downloader.tag, error: e, download_type: downloader.download_type.clone(), + transfer_failed: true, })), } } diff --git a/shared/network/src/lib.rs b/shared/network/src/lib.rs index 4f6f60d97..9ba2adfd5 100644 --- a/shared/network/src/lib.rs +++ b/shared/network/src/lib.rs @@ -17,7 +17,8 @@ use iroh_gossip::{ net::Gossip, proto::{HyparviewConfig, PlumtreeConfig}, }; -use iroh_n0des::ApiSecret; +use iroh_services::{API_SECRET_ENV_VAR_NAME, ApiSecret, caps::NetDiagnosticsCap}; +use n0_future::task::AbortOnDropHandle; pub use p2p_model_sharing::{ MODEL_REQUEST_TIMEOUT_SECS, ModelConfigSharingMessage, ParameterSharingMessage, PeerManagerHandle, @@ -37,12 +38,10 @@ use std::{ use tokio::{ io::AsyncReadExt, select, + sync::mpsc, sync::{mpsc::UnboundedReceiver, oneshot}, task::JoinError, time::timeout, -}; -use tokio::{ - sync::mpsc, time::{Interval, interval}, }; use tokio_util::sync::CancellationToken; @@ -86,6 +85,7 @@ use iroh_relay::{RelayMap, RelayQuicConfig}; pub use latency_sorted::LatencySorted; pub use p2p_model_sharing::{ ALPN, ModelRequestType, SharableModel, SharableModelError, TransmittableModelConfig, + TransmittableModelParameter, }; pub use serde::Networkable; pub use serialized_distro::{ @@ -193,7 +193,8 @@ where metrics: Arc, endpoint: Endpoint, connection_monitor: ConnectionMonitor, - _iroh_metrics: Option, + _iroh_services_client: Option, + _iroh_diagnostics_task: Option>, } impl Debug for NetworkConnection @@ -244,6 +245,8 @@ where metrics, cancel, None, + None, + false, ) .await } @@ -277,6 +280,44 @@ where metrics, cancel, Some(additional_protocol), + None, + false, + ) + .await + } + + /// Initialize with a custom external blob store (e.g. FakeStore for testing). + /// The external store will be used for BlobsProtocol (serving blobs) while a + /// MemStore is still created internally for downloads and tag management. + #[allow(clippy::too_many_arguments)] + pub async fn init_with_blobs_store( + run_id: &str, + port: Option, + interface: Option, + discovery_mode: DiscoveryMode, + relay_kind: RelayKind, + bootstrap_peers: Vec, + secret_key: Option, + allowlist: A, + metrics: Arc, + cancel: Option, + external_blobs_store: iroh_blobs::api::Store, + relay_only: bool, + ) -> Result { + Self::init_internal::( + run_id, + port, + interface, + discovery_mode, + relay_kind, + bootstrap_peers, + secret_key, + allowlist, + metrics, + cancel, + None, + Some(external_blobs_store), + relay_only, ) .await } @@ -297,6 +338,8 @@ where metrics: Arc, cancel: Option, additional_protocol: Option<(&'static [u8], P)>, + external_blobs_store: Option, + relay_only: bool, ) -> Result { let secret_key = match secret_key { None => SecretKey::generate(&mut rand::rng()), @@ -352,7 +395,7 @@ where }; debug!("Using relay servers: {}", fmt_relay_mode(&relay_mode)); - let endpoint = Endpoint::builder() + let mut endpoint = Endpoint::builder() .secret_key(secret_key) .relay_mode(relay_mode) .transport_config(transport_config) @@ -361,6 +404,11 @@ where .hooks(allowlist_hook.clone()) .hooks(connection_monitor.clone()); + if relay_only { + info!("Relay-only mode: disabling direct IP transports"); + endpoint = endpoint.clear_ip_transports(); + } + let endpoint = match discovery_mode { DiscoveryMode::Local => { endpoint.address_lookup(local_discovery::LocalTestDiscovery::new(public_key)) @@ -395,25 +443,34 @@ where info!("Our endpoint ID: {}", endpoint_addr.id); - let iroh_metrics = { - let builder = iroh_n0des::Client::builder(&endpoint); + let iroh_services_client = { + let builder = iroh_services::Client::builder(&endpoint); let allowlist = allowlist.clone(); + (async move { - let client = builder.api_secret_from_env()?.build().await?; - const API_SECRET_ENV_VAR_NAME: &str = "N0DES_API_SECRET"; - - match std::env::var(API_SECRET_ENV_VAR_NAME) { - Ok(ticket_string) => { - let ticket = ApiSecret::from_str(&ticket_string) - .context(format!("invalid {API_SECRET_ENV_VAR_NAME}"))?; - let endpoint_id = ticket.remote.id; - allowlist.force_allow(endpoint_id); - } - Err(e) => unreachable!("{e:?}"), - } + let secret = ApiSecret::from_env_var(API_SECRET_ENV_VAR_NAME) + .context("failed to get API secret")?; + + let remote_id = secret.addr().id; + allowlist.force_allow(remote_id); + + let client = builder + .api_secret(secret)? + .build() + .await + .context("failed to build metrics client")?; + + timeout( + Duration::from_secs(10), + client.grant_capability(remote_id, vec![NetDiagnosticsCap::GetAny]), + ) + .await + .context("timed out while granting capability")? + .context("failed to grant capability")?; + Ok(client) }) - .await as anyhow::Result + .await as anyhow::Result } .map_or_else( |e| { @@ -465,14 +522,24 @@ where trace!("model parameter sharing created!"); trace!("creating router..."); - let blobs_protocol = BlobsProtocol::new(&store.clone(), None); + let blobs_protocol = match &external_blobs_store { + Some(ext_store) => BlobsProtocol::new(ext_store, None), + None => BlobsProtocol::new(&store.clone(), None), + }; let router = spawn_router( endpoint.clone(), SupportedProtocols::new(gossip.clone(), blobs_protocol, model_parameter_sharing), additional_protocol, + iroh_services_client + .as_ref() + .map(|_| iroh_services::ClientHost::new(&endpoint)), )?; trace!("router created!"); + let iroh_diagnostics_task = iroh_services_client + .as_ref() + .map(|client| spawn_network_diagnostics_loop(client.clone())); + let (gossip_tx, gossip_rx) = gossip .subscribe(gossip_topic(run_id), bootstrap_endpoint_ids) .await? @@ -500,7 +567,8 @@ where _download: Default::default(), endpoint, connection_monitor, - _iroh_metrics: iroh_metrics, + _iroh_services_client: iroh_services_client, + _iroh_diagnostics_task: iroh_diagnostics_task, }) } @@ -689,6 +757,12 @@ where Ok(()) } + pub async fn delete_tag(&self, tag_name: &str) -> anyhow::Result<()> { + let store = self.blobs_store.as_ref().clone(); + store.tags().delete(tag_name).await?; + Ok(()) + } + pub async fn endpoint_addr(&self) -> EndpointAddr { self.router.endpoint().addr() } @@ -777,6 +851,25 @@ where self.connection_monitor .update_peer_bandwidth(&peer_id, peer_bw); + // Log bandwidth every ~10 MB to track mid-download evolution without being noisy. + if update.downloaded_size_delta > 0 { + let prev = update.downloaded_size - update.downloaded_size_delta; + let interval = 10 * 1024 * 1024; // 10 MB + if update.downloaded_size / interval > prev / interval || update.all_done { + let bw_str = match &peer_bw { + PeerBandwidth::Measured(bw) => format!("{:.2} MB/s", bw / (1024.0 * 1024.0)), + PeerBandwidth::NotMeasured => "not measured".to_string(), + }; + info!( + "Download progress: {} bytes from {} | bandwidth: {} | {:?}", + update.downloaded_size, + peer_id.fmt_short(), + bw_str, + update.download_type, + ); + } + } + let hash = update.blob_ticket.hash(); if update.all_done { @@ -812,6 +905,10 @@ where self.connection_monitor.clear_all_bandwidth(); } + pub fn bandwidth_tracker_peer_bandwidth(&self, peer: &EndpointId) -> PeerBandwidth { + self.state.bandwidth_tracker.get_peer_bandwidth(peer) + } + pub fn connection_monitor(&self) -> ConnectionMonitor { self.connection_monitor.clone() } @@ -967,6 +1064,23 @@ fn hash_bytes(bytes: &Bytes) -> u64 { hasher.finish() } +fn spawn_network_diagnostics_loop(client: iroh_services::Client) -> AbortOnDropHandle<()> { + AbortOnDropHandle::new(tokio::spawn(async move { + let mut diagnostics_interval = tokio::time::interval(Duration::from_secs(60 * 60)); + diagnostics_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip); + + loop { + diagnostics_interval.tick().await; + + match timeout(Duration::from_secs(10), client.net_diagnostics(true)).await { + Ok(Ok(report)) => info!("Network diagnostics report: {report:?}"), + Ok(Err(e)) => warn!("Failed to run network diagnostics: {e:#}"), + Err(_) => warn!("Timed out while running network diagnostics"), + } + } + })) +} + // Simplified param_request_task pub async fn blob_ticket_param_request_task( model_request_type: ModelRequestType, diff --git a/shared/network/src/p2p_model_sharing.rs b/shared/network/src/p2p_model_sharing.rs index 2e51942ce..69e9e023f 100644 --- a/shared/network/src/p2p_model_sharing.rs +++ b/shared/network/src/p2p_model_sharing.rs @@ -357,7 +357,7 @@ pub struct TransmittableModelParameter { } impl TransmittableModelParameter { - fn new(param_name_bytes: Vec, param_value_bytes: Vec) -> Self { + pub fn new(param_name_bytes: Vec, param_value_bytes: Vec) -> Self { Self { param_name_bytes, param_value_bytes, diff --git a/shared/network/src/router.rs b/shared/network/src/router.rs index ef53dc067..4c2f63ec9 100644 --- a/shared/network/src/router.rs +++ b/shared/network/src/router.rs @@ -27,6 +27,7 @@ pub(crate) fn spawn_router( endpoint: Endpoint, protocols: SupportedProtocols, additional_protocol: Option<(&'static [u8], P)>, + iroh_services_host: Option, ) -> Result> { let mut builder = Router::builder(endpoint.clone()) .accept(iroh_gossip::ALPN, protocols.0) @@ -38,6 +39,10 @@ pub(crate) fn spawn_router( builder = builder.accept(alpn, handler); } + if let Some(host) = iroh_services_host { + builder = builder.accept(iroh_services::CLIENT_HOST_ALPN, host); + } + let router = Arc::new(builder.spawn()); Ok(router) @@ -79,6 +84,7 @@ mod tests { endpoint.clone(), SupportedProtocols::new(gossip.clone(), blobs_protocol, p2p_model_sharing), None, + None, )?; assert!(!router.is_shutdown()); diff --git a/shared/network/src/state.rs b/shared/network/src/state.rs index 989ef429c..17f0ba0f6 100644 --- a/shared/network/src/state.rs +++ b/shared/network/src/state.rs @@ -50,20 +50,18 @@ impl BandwidthTracker { } pub fn add_event(&mut self, from: EndpointId, num_bytes: u64) { + // Only track events with actual bytes transferred. + // Zero-byte events (TryProvider, PartComplete, ProviderFailed) are noise. + if num_bytes == 0 { + return; + } let now = Instant::now(); let events = self.events.entry(from).or_default(); events.push_back(DownloadEvent { timestamp: now, num_bytes, }); - - while let Some(event) = events.front() { - if now.duration_since(event.timestamp) > Duration::from_secs(self.average_period_secs) { - events.pop_front(); - } else { - break; - } - } + Self::prune_stale(events, now, self.average_period_secs); } pub fn clear(&mut self) { @@ -71,28 +69,86 @@ impl BandwidthTracker { } pub fn get_total_bandwidth(&self) -> f64 { - self.events.values().map(endpoint_bandwidth).sum() + let max_age = Duration::from_secs(self.average_period_secs); + let now = Instant::now(); + self.events + .values() + .map(|events| endpoint_bandwidth(events, now, max_age)) + .sum() } pub fn get_peer_bandwidth(&self, peer: &EndpointId) -> PeerBandwidth { + let max_age = Duration::from_secs(self.average_period_secs); + let now = Instant::now(); match self.events.get(peer) { None => PeerBandwidth::NotMeasured, Some(events) if events.is_empty() => PeerBandwidth::NotMeasured, - Some(events) => PeerBandwidth::Measured(endpoint_bandwidth(events)), + Some(events) => { + // If the newest event is older than the window, all data is stale + if now.duration_since(events.back().unwrap().timestamp) > max_age { + return PeerBandwidth::NotMeasured; + } + let bw = endpoint_bandwidth(events, now, max_age); + if bw > 0.0 { + PeerBandwidth::Measured(bw) + } else { + PeerBandwidth::NotMeasured + } + } + } + } + + fn prune_stale(events: &mut VecDeque, now: Instant, max_age_secs: u64) { + let max_age = Duration::from_secs(max_age_secs); + while let Some(event) = events.front() { + if now.duration_since(event.timestamp) > max_age { + events.pop_front(); + } else { + break; + } } } } -fn endpoint_bandwidth(val: &VecDeque) -> f64 { - if val.is_empty() { +/// Compute bandwidth in bytes/sec for events within the time window. +/// +/// Uses bytes transferred *after* the first event divided by the elapsed time +/// from the first event to `now`. The first event's bytes are excluded from the +/// numerator because no time has elapsed when it arrives (fencepost correction). +/// +/// Using `now` (instead of the last event's timestamp) as the time endpoint +/// ensures that congestion pauses between bursts are reflected in the +/// measurement. Without this, bursty relay traffic would report burst-rate +/// bandwidth (e.g. 8 MB/s) even when effective sustained throughput is far +/// lower (e.g. 500 KB/s). +fn endpoint_bandwidth(val: &VecDeque, now: Instant, max_age: Duration) -> f64 { + // Need at least 2 events to compute a rate between them + if val.len() < 2 { return 0.0; } - let duration = Instant::now().duration_since(val.front().unwrap().timestamp); - let total_bytes: u64 = val.iter().map(|v| v.num_bytes).sum(); - let seconds = duration.as_secs_f64(); + + // Only consider events within the window + let cutoff = now - max_age; + let mut in_window = val.iter().filter(|e| e.timestamp >= cutoff).peekable(); + + let first_in_window = match in_window.peek() { + Some(e) => *e, + None => return 0.0, + }; + + // Sum bytes from all events AFTER the first (exclude the first event's bytes) + let first_timestamp = first_in_window.timestamp; + let bytes_after_first: u64 = in_window.skip(1).map(|e| e.num_bytes).sum(); + + if bytes_after_first == 0 { + return 0.0; + } + + // Use `now` as the end point so congestion pauses dilute the measurement + let seconds = now.duration_since(first_timestamp).as_secs_f64(); if seconds > 0.0 { - total_bytes as f64 / seconds + bytes_after_first as f64 / seconds } else { 0.0 } diff --git a/tools/rust-tools/run-manager/src/commands/treasury/claim_rewards.rs b/tools/rust-tools/run-manager/src/commands/treasury/claim_rewards.rs index 3b0334c72..6f5ff63c5 100644 --- a/tools/rust-tools/run-manager/src/commands/treasury/claim_rewards.rs +++ b/tools/rust-tools/run-manager/src/commands/treasury/claim_rewards.rs @@ -1,5 +1,4 @@ use crate::commands::Command; -use anchor_lang::prelude::Pubkey; use anchor_spl::{associated_token, token}; use anyhow::{Context, Result}; use async_trait::async_trait; @@ -15,8 +14,6 @@ pub struct CommandTreasurerClaimRewards { pub run_id: String, #[clap(long, env)] pub treasurer_index: Option, - #[clap(long, env)] - pub user: Option, } #[async_trait] @@ -25,7 +22,6 @@ impl Command for CommandTreasurerClaimRewards { let Self { run_id, treasurer_index, - user, } = self; let treasurer_index = backend @@ -59,40 +55,35 @@ impl Command for CommandTreasurerClaimRewards { native_amount_to_ui_amount(treasurer_run_collateral_amount, collateral_mint_decimals) ); - let claimer = backend.get_payer(); - println!("Claimer: {claimer}"); + let user = backend.get_payer(); + println!("User: {user}"); - let claimer_collateral_address = associated_token::get_associated_token_address( - &claimer, + let user_collateral_address = associated_token::get_associated_token_address( + &user, &treasurer_run_state.collateral_mint, ); - if backend.get_balance(&claimer_collateral_address).await? == 0 { + if backend.get_balance(&user_collateral_address).await? == 0 { let instruction = associated_token::spl_associated_token_account::instruction::create_associated_token_account_idempotent( &backend.get_payer(), - &claimer, + &user, &treasurer_run_state.collateral_mint, &token::ID, ); let signature = backend - .send_and_retry("Create claimer ATA", &[instruction], &[]) + .send_and_retry("Create user ATA", &[instruction], &[]) .await?; - println!( - "Created associated token account for claimer during transaction: {signature}" - ); + println!("Created associated token account for user during transaction: {signature}"); } - let claimer_collateral_amount = backend - .get_token_account(&claimer_collateral_address) + let user_collateral_amount = backend + .get_token_account(&user_collateral_address) .await? .amount; println!( - "Claimer collateral amount: {}", - native_amount_to_ui_amount(claimer_collateral_amount, collateral_mint_decimals) + "User collateral amount: {}", + native_amount_to_ui_amount(user_collateral_amount, collateral_mint_decimals) ); - let user = user.unwrap_or(backend.get_payer()); - println!("User: {user}"); - let treasurer_participant_address = psyche_solana_treasurer::find_participant(&treasurer_run_address, &user); if backend.get_balance(&treasurer_participant_address).await? == 0 { @@ -152,8 +143,6 @@ impl Command for CommandTreasurerClaimRewards { let instruction = instructions::treasurer_participant_claim( treasurer_index, - &claimer, - &claimer_collateral_address, &treasurer_run_state.collateral_mint, &treasurer_run_state.coordinator_account, &user,