diff --git a/orange-sdk/src/trusted_wallet/spark.rs b/orange-sdk/src/trusted_wallet/spark.rs index 495d397..ec62051 100644 --- a/orange-sdk/src/trusted_wallet/spark.rs +++ b/orange-sdk/src/trusted_wallet/spark.rs @@ -25,7 +25,7 @@ use spark_wallet::{ WalletTransfer, }; -use tokio::sync::watch; +use tokio::sync::{RwLock, mpsc, watch}; use std::future::Future; use std::pin::Pin; @@ -38,7 +38,9 @@ use uuid::Uuid; /// A wallet implementation using the Breez Spark SDK. #[derive(Clone)] pub(crate) struct Spark { - spark_wallet: Arc, + spark_wallets: Arc>>>, + spark_config: SparkWalletConfig, + seed: Seed, store: Arc, event_queue: Arc, tx_metadata: TxMetadataStore, @@ -46,6 +48,8 @@ pub(crate) struct Spark { shutdown_receiver: watch::Receiver<()>, logger: Arc, runtime: Arc, + // Channel to notify the event processor about new wallets + new_wallet_sender: mpsc::UnboundedSender<(Arc, usize)>, } impl TrustedWalletInterface for Spark { @@ -53,7 +57,8 @@ impl TrustedWalletInterface for Spark { &self, ) -> Pin> + Send + '_>> { Box::pin(async move { - let sats = self.spark_wallet.get_balance().await?; + let (spark_wallet, _) = self.get_current_wallet().await; + let sats = spark_wallet.get_balance().await?; Amount::from_sats(sats).map_err(|_| TrustedError::AmountError) }) } @@ -85,7 +90,8 @@ impl TrustedWalletInterface for Spark { ) })?; - let res = self.spark_wallet.create_lightning_invoice(sats, None, None).await?; + let (spark_wallet, _) = self.get_current_wallet().await; + let res = spark_wallet.create_lightning_invoice(sats, None, None).await?; Bolt11Invoice::from_str(&res.invoice) .map_err(|e| TrustedError::Other(format!("Failed to parse invoice: {e}"))) @@ -160,8 +166,9 @@ impl TrustedWalletInterface for Spark { ) })?; - let fee_sats = self - .spark_wallet + let (spark_wallet, _) = self.get_current_wallet().await; + + let fee_sats = spark_wallet .fetch_lightning_send_fee_estimate(&invoice.to_string(), Some(sats)) .await?; @@ -186,8 +193,8 @@ impl TrustedWalletInterface for 
Spark { ) })?; - let res = self - .spark_wallet + let (spark_wallet, idx) = self.get_current_wallet().await; + let res = spark_wallet .pay_lightning_invoice( &invoice.to_string(), Some(sats), @@ -220,9 +227,11 @@ impl TrustedWalletInterface for Spark { // about their status. if !is_rebalance { self.poll_lightning_payment( + spark_wallet, pay.id, id, PaymentHash(invoice.payment_hash().to_byte_array()), + idx, ); } @@ -271,6 +280,7 @@ const SPARK_PRIMARY_NAMESPACE: &str = "spark"; const SPARK_SYNC_NAMESPACE: &str = "sync_info"; const SPARK_PAYMENTS_NAMESPACE: &str = "payment"; const SPARK_SYNC_OFFSET_KEY: &str = "sync_offset"; +const SPARK_WALLET_INDEX_KEY: &str = "wallet_index"; impl Spark { /// Initialize a new Spark wallet instance with the given configuration. @@ -283,10 +293,116 @@ impl Spark { Err(TrustedError::InvalidNetwork)? } - let signer = match &config.seed { + // Load existing wallet count from storage + let latest_index = + match store.read(SPARK_PRIMARY_NAMESPACE, SPARK_SYNC_NAMESPACE, SPARK_WALLET_INDEX_KEY) + { + Ok(data) => u64::from_be_bytes(data.try_into().map_err(|e| { + log_error!(logger, "Failed to convert wallet index: {e:?}"); + InitFailure::TrustedFailure(TrustedError::Other(format!( + "Failed to convert wallet index: {e:?}" + ))) + })?), + Err(e) => { + if e.kind() == io::ErrorKind::NotFound { + log_info!(logger, "No wallet index found, starting with single wallet"); + 0 + } else { + log_error!(logger, "Failed to read wallet index: {e:?}"); + return Err(InitFailure::TrustedFailure(TrustedError::IOError(e))); + } + }, + }; + + // Since we are 0 indexed, and always want to have at least one wallet, we add 1 here + let wallet_count = latest_index + 1; + + log_info!(logger, "Initializing {wallet_count} Spark wallet(s)"); + + let mut wallets = Vec::with_capacity(wallet_count as usize); + for wallet_index in 0..wallet_count { + let spark_wallet = + Self::create_wallet_at_index(wallet_index, &config.seed, &spark_config, &logger) + .await + 
.map_err(InitFailure::TrustedFailure)?; + + wallets.push(spark_wallet); + } + + let (shutdown_sender, shutdown_receiver) = watch::channel::<()>(()); + let (new_wallet_sender, new_wallet_receiver) = mpsc::unbounded_channel(); + + // Start unified event processing for all wallets + Self::start_unified_event_processing( + wallets.clone(), + new_wallet_receiver, + shutdown_receiver.clone(), + Arc::clone(&store), + Arc::clone(&event_queue), + Arc::clone(&logger), + Arc::clone(&runtime), + ); + + let spark_wallets = Arc::new(RwLock::new(wallets)); + + log_info!(logger, "Spark wallet initialized"); + + // Check if we need to rotate the wallet on startup + let wallets_clone = Arc::clone(&spark_wallets); + let seed_clone = config.seed.clone(); + let config_clone = spark_config.clone(); + let store_clone = Arc::clone(&store); + let logger_clone = Arc::clone(&logger); + let new_wallet_sender_clone = new_wallet_sender.clone(); + runtime.spawn(async move { + if let Err(e) = Self::rotate_wallet( + &seed_clone, + config_clone, + &store_clone, + wallets_clone, + &logger_clone, + &new_wallet_sender_clone, + ) + .await + { + log_error!(logger_clone, "Failed to check wallet rotation on startup: {e:?}"); + } + }); + + Ok(Spark { + spark_wallets, + spark_config, + seed: config.seed.clone(), + store, + event_queue, + tx_metadata, + shutdown_sender, + shutdown_receiver, + logger, + runtime, + new_wallet_sender, + }) + } + + async fn get_current_wallet(&self) -> (Arc, usize) { + let wallets = self.spark_wallets.read().await; + let (idx, w) = + wallets.iter().enumerate().next_back().expect("must have at least one wallet"); + (Arc::clone(w), idx) + } + + /// Create a new SparkWallet for the given wallet index + async fn create_wallet_at_index( + wallet_index: u64, seed: &Seed, spark_config: &SparkWalletConfig, logger: &Logger, + ) -> Result, TrustedError> { + let signer = match seed { Seed::Seed64(bytes) => { // hash the seed to make sure it does not conflict with the lightning keys - let 
seed = Sha256::hash(bytes);
+                let mut seed = Sha256::hash(bytes);
+                // hash the seed `wallet_index` times to derive a new seed for each wallet
+                for _i in 0..wallet_index {
+                    seed = Sha256::hash(seed.as_byte_array());
+                }
                 DefaultSigner::new(&seed[..], spark_config.network)
                     .map_err(|e| TrustedError::Other(format!("Failed to create signer: {e}")))?
             },
@@ -294,7 +410,18 @@ impl Spark {
                 // We don't hash the seed here, as mnemonics are meant to be easily recoverable
                 // and if we hashed them, then you could not recover your spark coins from the mnemonic
                 // in separate wallets.
-                let seed = mnemonic.to_seed(passphrase.as_deref().unwrap_or(""));
+                let mnemonic_seed = mnemonic.to_seed(passphrase.as_deref().unwrap_or(""));
+                let seed: Vec<u8> = if wallet_index == 0 {
+                    // For the first wallet, keep the raw (unhashed) 64-byte seed so existing mnemonic wallets derive the same identity key as before rotation support existed
+                    mnemonic_seed.to_vec()
+                } else {
+                    // For subsequent wallets, derive unique deterministic key material via iterated hashing (wallet_index + 1 rounds total)
+                    let mut seed = Sha256::hash(&mnemonic_seed);
+                    for _i in 0..wallet_index {
+                        seed = Sha256::hash(seed.as_byte_array());
+                    }
+                    seed.to_byte_array().to_vec()
+                };
                 DefaultSigner::new(&seed[..], spark_config.network)
                     .map_err(|e| TrustedError::Other(format!("Failed to create signer: {e}")))?
}, @@ -302,153 +429,318 @@ impl Spark { let pk = signer.get_identity_public_key().map_err(|e| TrustedError::Other(format!("{e:?}")))?; - log_info!(logger, "Starting Spark wallet with public key: {pk}"); + log_info!(logger, "Starting Spark wallet {wallet_index} with public key: {pk}"); - let spark_wallet = - Arc::new(SparkWallet::connect(spark_config, Arc::new(signer)).await.map_err(|e| { - log_error!(logger, "Failed to connect to Spark wallet: {e:?}"); - InitFailure::TrustedFailure(e.into()) - })?); + let spark_wallet = Arc::new( + SparkWallet::connect(spark_config.clone(), Arc::new(signer)).await.map_err(|e| { + log_error!(logger, "Failed to connect to Spark wallet {wallet_index}: {e:?}"); + TrustedError::from(e) + })?, + ); - let (shutdown_sender, shutdown_receiver) = watch::channel::<()>(()); - let mut events = spark_wallet.subscribe_events(); - let l = Arc::clone(&logger); - let w = Arc::clone(&spark_wallet); - let s = Arc::clone(&store); - let eq = Arc::clone(&event_queue); - let mut shutdown_recv = shutdown_receiver.clone(); + Ok(spark_wallet) + } + + /// Start unified event processing that handles events from all wallets in a single thread + fn start_unified_event_processing( + initial_wallets: Vec>, + mut new_wallet_receiver: mpsc::UnboundedReceiver<(Arc, usize)>, + mut shutdown_receiver: watch::Receiver<()>, store: Arc, + event_queue: Arc, logger: Arc, runtime: Arc, + ) { + let runtime_clone = Arc::clone(&runtime); + runtime.spawn(async move { + log_info!( + logger, + "Starting unified event processing for {} wallets", + initial_wallets.len() + ); + + // Create a unified channel for all wallet events + let (unified_sender, mut unified_receiver) = + mpsc::unbounded_channel::<(usize, WalletEvent, Arc)>(); + + // Spawn forwarder tasks for initial wallets + for (index, wallet) in initial_wallets.into_iter().enumerate() { + Self::spawn_wallet_forwarder( + index, + wallet, + unified_sender.clone(), + shutdown_receiver.clone(), + Arc::clone(&runtime_clone), + 
Arc::clone(&logger), + ); + } + + // Main event processing loop + loop { + tokio::select! { + _ = shutdown_receiver.changed() => { + log_info!(logger, "Unified event processing shutdown signal received"); + return; + } + // Handle new wallets being added + new_wallet_opt = new_wallet_receiver.recv() => { + if let Some((new_wallet, index)) = new_wallet_opt { + log_info!(logger, "Adding new wallet {index} to unified event processing"); + Self::spawn_wallet_forwarder( + index, + new_wallet, + unified_sender.clone(), + shutdown_receiver.clone(), + Arc::clone(&runtime_clone), + Arc::clone(&logger), + ); + } + } + // Handle events from any wallet + event_opt = unified_receiver.recv() => { + if let Some((wallet_index, event, wallet)) = event_opt { + log_debug!(logger, "Spark wallet {wallet_index} event: {event:?}"); + if let Err(e) = Self::handle_wallet_event( + event, + wallet_index, + wallet, + &store, + &event_queue, + &logger, + ).await { + log_error!(logger, "Failed to handle wallet {wallet_index} event: {e:?}"); + } + } + } + } + } + }); + } + + /// Spawn a forwarder task for a single wallet that forwards its events to the unified channel + fn spawn_wallet_forwarder( + wallet_index: usize, wallet: Arc, + unified_sender: mpsc::UnboundedSender<(usize, WalletEvent, Arc)>, + mut shutdown_receiver: watch::Receiver<()>, runtime: Arc, logger: Arc, + ) -> tokio::task::JoinHandle<()> { + let wallet_clone = Arc::clone(&wallet); runtime.spawn(async move { + log_debug!(logger, "Starting event forwarder for wallet {wallet_index}"); + let mut events = wallet.subscribe_events(); + loop { tokio::select! 
{ - _ = shutdown_recv.changed() => { - log_info!(l, "Deposit tracking loop shutdown signal received"); + _ = shutdown_receiver.changed() => { + log_debug!(logger, "Event forwarder for wallet {wallet_index} shutdown"); return; } event = events.recv() => { match event { Ok(event) => { - log_debug!(l, "Spark event: {event:?}"); - match event { - WalletEvent::DepositConfirmed(node_id) => { - if let Ok(transfers) = w.list_transfers(None).await { - if let Some(transfer) = transfers - .into_iter() - .find(|t| t.leaves.iter().any(|l| l.leaf.id == node_id)) - { - eq.add_event(Event::OnchainPaymentReceived { - payment_id: PaymentId::Trusted( - convert_from_transfer_id(transfer.id.to_bytes()), - ), - // todo this is kinda hacky, maybe we should make this optional - txid: transfer - .leaves - .iter() - .find(|t| t.leaf.id == node_id) - .map(|t| t.leaf - .node_tx - .compute_txid()) - .unwrap_or(Txid::all_zeros()), - amount_sat: transfer.total_value_sat, - status: ConfirmationStatus::Unconfirmed, // fixme dont have block height - }) - .unwrap(); - } - } - }, - WalletEvent::StreamConnected => { - log_debug!(l, "Spark wallet stream connected"); - }, - WalletEvent::StreamDisconnected => { - log_debug!(l, "Spark wallet stream connected"); - }, - WalletEvent::Synced => { - log_debug!(l, "Spark wallet synced"); - if let Err(e) = Self::sync_payments_to_storage(w.as_ref(), &s, l.as_ref()).await { - log_error!(l, "Failed to sync payments to storage: {e:?}"); - } else { - log_info!(l, "Payments synced to storage"); - } - }, - WalletEvent::TransferClaimed(transfer) => { - if let Err(e) = Self::sync_payments_to_storage(w.as_ref(), &s, l.as_ref()).await { - log_error!(l, "Failed to sync payments to storage: {e:?}"); - } else { - log_info!(l, "Payments synced to storage"); - } - - match transfer.user_request { - None => { - log_debug!(l, "Transfer claimed without user request: {transfer:?}"); - }, - Some(SspUserRequest::LightningReceiveRequest(req)) => { - if let Ok(hash) = 
FromHex::from_hex(&req.invoice.payment_hash) { - eq.add_event(Event::PaymentReceived { - payment_id: PaymentId::Trusted(convert_from_transfer_id(transfer.id.to_bytes())), - payment_hash: PaymentHash(hash), - amount_msat: transfer.total_value_sat * 1_000, // convert to msats - custom_records: vec![], - lsp_fee_msats: None, - }) - .unwrap(); - } - }, - Some(req) => { - log_debug!(l, "Transfer claimed with user request: {req:?}"); - } - } - }, + if let Err(e) = unified_sender.send((wallet_index, event, Arc::clone(&wallet_clone))) { + log_error!(logger, "Failed to forward event from wallet {wallet_index}: {e:?}"); + return; // Channel closed, exit forwarder } }, Err(e) => { - log_debug!(l, "Spark event error: {e:?}"); + log_debug!(logger, "Wallet {wallet_index} event receiver error: {e:?}"); + // Continue listening for events despite this error }, } } } } - }); + }) + } - log_info!(logger, "Spark wallet initialized"); + /// Handle a wallet event from any wallet + async fn handle_wallet_event( + event: WalletEvent, wallet_index: usize, wallet: Arc, + store: &Arc, event_queue: &Arc, logger: &Logger, + ) -> Result<(), TrustedError> { + match event { + WalletEvent::DepositConfirmed(node_id) => { + if let Ok(transfers) = wallet.list_transfers(None).await { + if let Some(transfer) = transfers + .into_iter() + .find(|t| t.leaves.iter().any(|l| l.leaf.id == node_id)) + { + event_queue + .add_event(Event::OnchainPaymentReceived { + payment_id: PaymentId::Trusted(convert_from_transfer_id( + transfer.id.to_bytes(), + )), + // todo this is kinda hacky, maybe we should make this optional + txid: transfer + .leaves + .iter() + .find(|t| t.leaf.id == node_id) + .map(|t| t.leaf.node_tx.compute_txid()) + .unwrap_or(Txid::all_zeros()), + amount_sat: transfer.total_value_sat, + status: ConfirmationStatus::Unconfirmed, // fixme dont have block height + }) + .unwrap(); + } + } + }, + WalletEvent::StreamConnected => { + log_debug!(logger, "Spark wallet {wallet_index} stream connected"); + 
}, + WalletEvent::StreamDisconnected => { + log_debug!(logger, "Spark wallet {wallet_index} stream disconnected"); + }, + WalletEvent::Synced => { + log_debug!(logger, "Spark wallet {wallet_index} synced"); + if let Err(e) = + Self::sync_payments_to_storage(wallet.as_ref(), store, logger, wallet_index) + .await + { + log_error!( + logger, + "Failed to sync payments to storage for wallet {wallet_index}: {e:?}" + ); + } else { + log_info!(logger, "Payments synced to storage for wallet {wallet_index}"); + } + }, + WalletEvent::TransferClaimed(transfer) => { + if let Err(e) = + Self::sync_payments_to_storage(wallet.as_ref(), store, logger, wallet_index) + .await + { + log_error!( + logger, + "Failed to sync payments to storage for wallet {wallet_index}: {e:?}" + ); + } else { + log_info!(logger, "Payments synced to storage for wallet {wallet_index}"); + } - Ok(Spark { - spark_wallet, - store, - event_queue, - tx_metadata, - shutdown_sender, - shutdown_receiver, - logger, - runtime, - }) + match transfer.user_request { + None => { + log_debug!( + logger, + "Transfer claimed without user request for wallet {wallet_index}: {transfer:?}" + ); + }, + Some(SspUserRequest::LightningReceiveRequest(req)) => { + if let Ok(hash) = FromHex::from_hex(&req.invoice.payment_hash) { + event_queue + .add_event(Event::PaymentReceived { + payment_id: PaymentId::Trusted(convert_from_transfer_id( + transfer.id.to_bytes(), + )), + payment_hash: PaymentHash(hash), + amount_msat: transfer.total_value_sat * 1_000, // convert to msats + custom_records: vec![], + lsp_fee_msats: None, + }) + .unwrap(); + + // If this payment is from an old wallet and above a threshold, we should + // move the funds to the latest wallet. 
+ } + }, + Some(req) => { + log_debug!( + logger, + "Transfer claimed with user request for wallet {wallet_index}: {req:?}" + ); + }, + } + }, + } + Ok(()) + } + + async fn rotate_wallet( + seed: &Seed, spark_config: SparkWalletConfig, store: &Arc, + spark_wallets: Arc>>>, logger: &Logger, + new_wallet_sender: &mpsc::UnboundedSender<(Arc, usize)>, + ) -> Result<(), TrustedError> { + // make sure we lock during the whole rotation process + let mut wallets = spark_wallets.write().await; + + // check if we need to rotate + let current_wallet = wallets.iter().last().expect("must have at least one wallet"); + // balance > 1 sat means we have funds, so no need to rotate + let bal = current_wallet.get_balance().await?; + if bal > 1 { + return Ok(()); + } + // check we actually have txs in the latest wallet + let txs = current_wallet.list_transfers(None).await?; + if txs.is_empty() { + // no txs, no need to rotate + return Ok(()); + } + + let next_index = wallets.len(); + log_info!(logger, "Rotating Spark wallet, new index: {next_index}"); + + let new_wallet = + Self::create_wallet_at_index(next_index as u64, seed, &spark_config, logger).await?; + + // The new wallet index is the current length (0-based indexing) + wallets.push(Arc::clone(&new_wallet)); + + // Add the new wallet to the unified event processing + if let Err(e) = new_wallet_sender.send((Arc::clone(&new_wallet), next_index)) { + log_error!(logger, "Failed to notify event processor about new wallet: {e:?}"); + } else { + log_info!(logger, "New wallet {next_index} added to event processing"); + } + + // convert to u64 for storage + let index = next_index as u64; + store + .write( + SPARK_PRIMARY_NAMESPACE, + SPARK_SYNC_NAMESPACE, + SPARK_WALLET_INDEX_KEY, + &index.to_be_bytes(), + ) + .map_err(TrustedError::IOError)?; + + log_info!(logger, "Spark wallet rotated successfully"); + Ok(()) } - /// Synchronizes payments from transfers to persistent storage + /// Generate a unique sync offset key for each wallet 
using the wallet index + fn get_sync_offset_key_for_wallet(wallet_index: usize) -> String { + if wallet_index == 0 { + // For backwards compatibility, the first wallet uses the original key + SPARK_SYNC_OFFSET_KEY.to_string() + } else { + format!("{}_{}", SPARK_SYNC_OFFSET_KEY, wallet_index) + } + } + + /// Synchronizes payments from transfers to persistent storage with a custom offset key async fn sync_payments_to_storage( spark_wallet: &SparkWallet, store: &Arc, logger: &Logger, + wallet_index: usize, ) -> Result<(), TrustedError> { // sync payments const BATCH_SIZE: u64 = 50; + let offset_key = Self::get_sync_offset_key_for_wallet(wallet_index); + // Get the last offset we processed from storage - let current_offset = match store.read( - SPARK_PRIMARY_NAMESPACE, - SPARK_SYNC_NAMESPACE, - SPARK_SYNC_OFFSET_KEY, - ) { - Ok(data) => u64::from_be_bytes(data.try_into().map_err(|e| { - TrustedError::Other(format!("Failed to convert sync offset: {e:?}")) - })?), - Err(e) => { - if e.kind() == io::ErrorKind::NotFound { - // If not found, start from the beginning - log_info!(logger, "No sync info found, starting from offset 0"); - 0 - } else { - log_error!(logger, "Failed to read sync info: {e:?}"); - return Err(TrustedError::IOError(e)); - } - }, - }; + let current_offset = + match store.read(SPARK_PRIMARY_NAMESPACE, SPARK_SYNC_NAMESPACE, &offset_key) { + Ok(data) => u64::from_be_bytes(data.try_into().map_err(|e| { + TrustedError::Other(format!("Failed to convert sync offset: {e:?}")) + })?), + Err(e) => { + if e.kind() == io::ErrorKind::NotFound { + // If not found, start from the beginning + log_info!(logger, "No sync info found, starting from offset 0"); + 0 + } else { + log_error!(logger, "Failed to read sync info: {e:?}"); + return Err(TrustedError::IOError(e)); + } + }, + }; // We'll keep querying in batches until we have all transfers let mut next_offset = current_offset; @@ -500,7 +792,7 @@ impl Spark { let save_res = store.write( SPARK_PRIMARY_NAMESPACE, 
SPARK_SYNC_NAMESPACE, - SPARK_SYNC_OFFSET_KEY, + &offset_key, &saved_offset.to_be_bytes(), ); @@ -515,16 +807,20 @@ impl Spark { /// Pools the lightning payment until it is in completed state. fn poll_lightning_payment( - &self, spark_id: String, payment_id: [u8; 32], payment_hash: PaymentHash, + &self, spark_wallet: Arc, spark_id: String, payment_id: [u8; 32], + payment_hash: PaymentHash, index: usize, ) { const MAX_POLL_ATTEMPTS: u64 = 10; log_info!(self.logger, "Polling lightning send payment {spark_id}"); let mut shutdown = self.shutdown_receiver.clone(); - let spark_wallet = Arc::clone(&self.spark_wallet); let event_queue = Arc::clone(&self.event_queue); let store = Arc::clone(&self.store); let logger = Arc::clone(&self.logger); + let wallets = Arc::clone(&self.spark_wallets); + let seed = self.seed.clone(); + let config = self.spark_config.clone(); + let new_wallet_sender = self.new_wallet_sender.clone(); self.runtime.spawn(async move { for i in 0..MAX_POLL_ATTEMPTS { log_info!(logger, "Polling lightning send payment {spark_id} attempt {i}",); @@ -555,11 +851,16 @@ impl Spark { }) .unwrap(); - if let Err(e) = Self::sync_payments_to_storage(spark_wallet.as_ref(), &store, logger.as_ref()).await { + if let Err(e) = Self::sync_payments_to_storage(spark_wallet.as_ref(), &store, logger.as_ref(), index).await { log_error!(logger, "Failed to sync payments to storage: {e:?}"); } else { log_info!(logger, "Payments synced to storage"); } + + // on success check if we should rotate the wallet + if let Err(e) = Self::rotate_wallet(&seed, config, &store, wallets, &logger, &new_wallet_sender).await { + log_error!(logger, "Failed to rotate wallet: {e:?}"); + } return; } else { log_debug!(logger, "Polling payment completed but no preimage yet");