diff --git a/Cargo.toml b/Cargo.toml index 71d85f0..493717e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,9 +16,9 @@ codegen-units = 1 # Reduce number of codegen units to increase optimizations. panic = 'abort' # Abort on panic [workspace.dependencies] -bitcoin-payment-instructions = { git = "https://github.com/benthecarman/bitcoin-payment-instructions.git", branch = "orange-fork", features = ["http"] } -lightning = { git = "https://github.com/tnull/rust-lightning", branch = "2025-08-bump-electrum-client-0.1" } -lightning-invoice = { git = "https://github.com/tnull/rust-lightning", branch = "2025-08-bump-electrum-client-0.1" } +bitcoin-payment-instructions = { version = "0.6.0" } +lightning = { version = "0.2.0" } +lightning-invoice = { version = "0.34.0" } [profile.release] panic = "abort" diff --git a/examples/cli/src/main.rs b/examples/cli/src/main.rs index 5b00aad..7821335 100644 --- a/examples/cli/src/main.rs +++ b/examples/cli/src/main.rs @@ -16,7 +16,6 @@ use std::path::PathBuf; use std::str::FromStr; use std::sync::Arc; use std::sync::atomic::{AtomicBool, Ordering}; -use tokio::runtime::Runtime; use tokio::signal; const NETWORK: Network = Network::Bitcoin; // Supports Bitcoin and Regtest @@ -58,7 +57,6 @@ enum Commands { struct WalletState { wallet: Wallet, - _runtime: Arc, // Keep runtime alive shutdown: Arc, } @@ -128,20 +126,20 @@ fn get_config(network: Network) -> Result { } impl WalletState { - async fn new(runtime: Arc) -> Result { + async fn new() -> Result { let shutdown = Arc::new(AtomicBool::new(false)); let config = get_config(NETWORK) .with_context(|| format!("Failed to get wallet config for network: {NETWORK:?}"))?; println!("{} Initializing wallet...", "⚡".bright_yellow()); - match Wallet::new_with_runtime(runtime.clone(), config).await { + match Wallet::new(config).await { Ok(wallet) => { println!("{} Wallet initialized successfully!", "✅".bright_green()); println!("Network: {}", NETWORK.to_string().bright_cyan()); let w = wallet.clone(); - 
runtime.spawn(async move { + tokio::spawn(async move { let event = w.next_event_async().await; match event { Event::PaymentSuccessful { payment_id, .. } => { @@ -193,12 +191,19 @@ impl WalletState { fee_msat ); }, + Event::SplicePending { new_funding_txo, .. } => { + println!( + "{} Splice pending: {}", + "🔄".bright_yellow(), + new_funding_txo + ); + }, } w.event_handled().unwrap(); }); - Ok(WalletState { wallet, _runtime: runtime, shutdown }) + Ok(WalletState { wallet, shutdown }) }, Err(e) => Err(anyhow::anyhow!("Failed to initialize wallet: {:?}", e)), } @@ -213,23 +218,21 @@ impl WalletState { } } -fn main() -> Result<()> { +#[tokio::main(flavor = "multi_thread")] +async fn main() -> Result<()> { let cli = Cli::parse(); println!("{}", "🟠 Orange CLI Wallet".bright_yellow().bold()); println!("{}", "Type 'help' for available commands or 'exit' to quit".dimmed()); println!(); - // Create runtime outside async context to avoid drop issues - let runtime = Arc::new(Runtime::new().context("Failed to create tokio runtime")?); - // Initialize wallet once at startup - let mut state = runtime.block_on(WalletState::new(runtime.clone()))?; + let mut state = WalletState::new().await?; // Set up signal handling for graceful shutdown let shutdown_state = state.shutdown.clone(); let shutdown_wallet = state.wallet.clone(); - runtime.spawn(async move { + tokio::task::spawn(async move { if let Ok(()) = signal::ctrl_c().await { println!("\n{} Shutdown signal received, stopping wallet...", "âšī¸".bright_yellow()); shutdown_state.store(true, Ordering::Relaxed); @@ -241,12 +244,12 @@ fn main() -> Result<()> { // If a command was provided via command line, execute it and start interactive mode if let Some(command) = cli.command { - runtime.block_on(execute_command(command, &mut state))?; + execute_command(command, &mut state).await?; println!(); } // Start interactive mode - runtime.block_on(start_interactive_mode(state)) + start_interactive_mode(state).await } async fn 
start_interactive_mode(mut state: WalletState) -> Result<()> { diff --git a/graduated-rebalancer/src/lib.rs b/graduated-rebalancer/src/lib.rs index 6c68b41..827d0ba 100644 --- a/graduated-rebalancer/src/lib.rs +++ b/graduated-rebalancer/src/lib.rs @@ -108,6 +108,9 @@ pub trait LightningWallet: Send + Sync { &self, payment_hash: [u8; 32], ) -> Pin> + Send + '_>>; + /// Check if we already have a channel with the LSP + fn has_channel_with_lsp(&self) -> bool; + /// Open a channel with the LSP using on-chain funds fn open_channel_with_lsp( &self, amt: Amount, @@ -117,6 +120,16 @@ pub trait LightningWallet: Send + Sync { fn await_channel_pending( &self, channel_id: u128, ) -> Pin + Send + '_>>; + + /// Splice funds from on-chain to an existing channel with the LSP + fn splice_to_lsp_channel( + &self, amt: Amount, + ) -> Pin> + Send + '_>>; + + /// Wait for a splice pending notification, returns the splice outpoint + fn await_splice_pending( + &self, channel_id: u128, + ) -> Pin + Send + '_>>; } /// Represents a payment from the lightning wallet @@ -176,7 +189,8 @@ pub enum RebalancerEvent { /// Trait for handling rebalancer events pub trait EventHandler: Send + Sync { /// Handle a rebalancer event - fn handle_event(&self, event: RebalancerEvent); + fn handle_event(&self, event: RebalancerEvent) + -> Pin + Send + '_>>; } /// A no-op event handler that discards all events @@ -184,7 +198,10 @@ pub trait EventHandler: Send + Sync { pub struct IgnoringEventHandler; impl EventHandler for IgnoringEventHandler { - fn handle_event(&self, _event: RebalancerEvent) { + fn handle_event( + &self, _event: RebalancerEvent, + ) -> Pin + Send + '_>> { + Box::pin(async move {}) // Do nothing } } @@ -260,11 +277,13 @@ where rebalance_id.as_hex() ); - self.event_handler.handle_event(RebalancerEvent::RebalanceInitiated { - trigger_id: params.id, - trusted_rebalance_payment_id: rebalance_id, - amount_msat: transfer_amt.milli_sats(), - }); + self.event_handler + 
.handle_event(RebalancerEvent::RebalanceInitiated { + trigger_id: params.id, + trusted_rebalance_payment_id: rebalance_id, + amount_msat: transfer_amt.milli_sats(), + }) + .await; let ln_payment = match self .ln_wallet @@ -297,14 +316,16 @@ where ln_payment.id.as_hex(), ); - self.event_handler.handle_event(RebalancerEvent::RebalanceSuccessful { - trigger_id: params.id, - trusted_rebalance_payment_id: rebalance_id, - ln_rebalance_payment_id: ln_payment.id, - amount_msat: transfer_amt.milli_sats(), - fee_msat: ln_payment.fee_paid_msat.unwrap_or_default() - + trusted_payment.fee_paid_msat.unwrap_or_default(), - }); + self.event_handler + .handle_event(RebalancerEvent::RebalanceSuccessful { + trigger_id: params.id, + trusted_rebalance_payment_id: rebalance_id, + ln_rebalance_payment_id: ln_payment.id, + amount_msat: transfer_amt.milli_sats(), + fee_msat: ln_payment.fee_paid_msat.unwrap_or_default() + + trusted_payment.fee_paid_msat.unwrap_or_default(), + }) + .await; }, Err(e) => { log_info!(self.logger, "Rebalance trusted transaction failed with {e:?}",); @@ -313,34 +334,55 @@ where } } - /// Perform on-chain to lightning rebalance by opening a channel + /// Perform on-chain to lightning rebalance by opening a channel or splicing into an existing one async fn do_onchain_rebalance(&self, params: TriggerParams) { - // This should open a channel with the LSP using available on-chain funds - let _ = self.balance_mutex.lock().await; - log_info!(self.logger, "Opening channel with LSP with on-chain funds"); + let (channel_outpoint, user_channel_id) = if self.ln_wallet.has_channel_with_lsp() { + log_info!(self.logger, "Splicing into channel with LSP with on-chain funds"); - // todo for now we can only open a channel, eventually move to splicing - let user_chan_id = match self.ln_wallet.open_channel_with_lsp(params.amount).await { - Ok(chan_id) => chan_id, - Err(e) => { - log_error!(self.logger, "Failed to open channel with LSP: {e:?}"); - return; - }, - }; + let user_chan_id 
= match self.ln_wallet.splice_to_lsp_channel(params.amount).await { + Ok(chan_id) => chan_id, + Err(e) => { + log_error!(self.logger, "Failed to open channel with LSP: {e:?}"); + return; + }, + }; + + log_info!(self.logger, "Initiated splice opened with LSP"); + + let channel_outpoint = self.ln_wallet.await_splice_pending(user_chan_id).await; + + log_info!(self.logger, "Splice initiated at: {channel_outpoint}"); + + (channel_outpoint, user_chan_id) + } else { + log_info!(self.logger, "Opening channel with LSP with on-chain funds"); - log_info!(self.logger, "Initiated channel opened with LSP"); + let user_chan_id = match self.ln_wallet.open_channel_with_lsp(params.amount).await { + Ok(chan_id) => chan_id, + Err(e) => { + log_error!(self.logger, "Failed to open channel with LSP: {e:?}"); + return; + }, + }; + + log_info!(self.logger, "Initiated channel opened with LSP"); - let channel_outpoint = self.ln_wallet.await_channel_pending(user_chan_id).await; + let channel_outpoint = self.ln_wallet.await_channel_pending(user_chan_id).await; - log_info!(self.logger, "Channel open succeeded at: {channel_outpoint}",); + log_info!(self.logger, "Channel open succeeded at: {channel_outpoint}"); + + (channel_outpoint, user_chan_id) + }; - self.event_handler.handle_event(RebalancerEvent::OnChainRebalanceInitiated { - trigger_id: params.id, - user_channel_id: user_chan_id, - channel_outpoint, - }); + self.event_handler + .handle_event(RebalancerEvent::OnChainRebalanceInitiated { + trigger_id: params.id, + user_channel_id, + channel_outpoint, + }) + .await; } /// Stops the rebalancer, waits for any active rebalances to complete diff --git a/justfile b/justfile index 027a201..e55c86f 100644 --- a/justfile +++ b/justfile @@ -2,10 +2,16 @@ default: @just --list test *args: - cargo test {{ args }} --features _test-utils -p orange-sdk + #!/usr/bin/env bash + THREADS=$(($(nproc) / 2)) + if [ $THREADS -lt 1 ]; then THREADS=1; fi + cargo test {{ args }} --features _test-utils -p orange-sdk 
-- --test-threads=$THREADS test-cashu *args: - cargo test {{ args }} --features _cashu-tests -p orange-sdk + #!/usr/bin/env bash + THREADS=$(($(nproc) / 2)) + if [ $THREADS -lt 1 ]; then THREADS=1; fi + cargo test {{ args }} --features _cashu-tests -p orange-sdk -- --test-threads=$THREADS cli: cd examples/cli && cargo run diff --git a/orange-sdk/Cargo.toml b/orange-sdk/Cargo.toml index 6133bb5..29a7356 100644 --- a/orange-sdk/Cargo.toml +++ b/orange-sdk/Cargo.toml @@ -17,29 +17,33 @@ default = ["spark"] uniffi = ["dep:uniffi", "spark", "cashu"] spark = ["breez-sdk-spark", "uuid", "serde_json"] cashu = ["cdk", "serde_json"] -_test-utils = ["corepc-node", "cashu", "uuid/v7", "rand"] +_test-utils = ["corepc-node", 'electrsd', "cashu", "uuid/v7", "rand"] _cashu-tests = ["_test-utils", "cdk-ldk-node", "cdk/mint", "cdk-sqlite", "cdk-axum", "axum"] [dependencies] graduated-rebalancer = { path = "../graduated-rebalancer", version = "0.1.0" } -ldk-node = { git = "https://github.com/benthecarman/ldk-node.git", branch = "esplora-auth" } -bitcoin-payment-instructions = { workspace = true } +ldk-node = { version = "0.7.0" } +lightning-macros = "0.2.0" +bitcoin-payment-instructions = { workspace = true, features = ["http"] } chrono = { version = "0.4", default-features = false } rand = { version = "0.8.5", optional = true } -reqwest = { version = "0.12.23", default-features = false, features = ["rustls-tls"] } breez-sdk-spark = { git = "https://github.com/breez/spark-sdk.git", rev = "1f2e9995230cd582d6b4aa7d06d76b99defb635e", default-features = false, features = ["rustls-tls"], optional = true } tokio = { version = "1.0", default-features = false, features = ["rt-multi-thread", "sync"] } uuid = { version = "1.0", default-features = false, optional = true } -cdk = { git = "https://github.com/benthecarman/cdk.git", rev = "39c1206a4a1dda2adc1f3e23628136ef645f6c6b", default-features = false, features = ["wallet"], optional = true } +cdk = { version = "0.14.2", default-features = 
false, features = ["wallet"], optional = true } serde_json = { version = "1.0", optional = true } async-trait = "0.1" log = "0.4.28" -corepc-node = { version = "0.8.0", features = ["29_0", "download"], optional = true } -cdk-ldk-node = { git = "https://github.com/benthecarman/cdk.git", rev = "39c1206a4a1dda2adc1f3e23628136ef645f6c6b", optional = true } -cdk-sqlite = { git = "https://github.com/benthecarman/cdk.git", rev = "39c1206a4a1dda2adc1f3e23628136ef645f6c6b", optional = true } -cdk-axum = { git = "https://github.com/benthecarman/cdk.git", rev = "39c1206a4a1dda2adc1f3e23628136ef645f6c6b", optional = true } +corepc-node = { version = "0.10.1", features = ["29_0", "download"], optional = true } +electrsd = { version = "0.36.1", default-features = false, features = ["esplora_a33e97e1", "corepc-node_29_0"], optional = true } +cdk-ldk-node = { version = "0.14.2", optional = true } +cdk-sqlite = { version = "0.14.2", optional = true } +cdk-axum = { version = "0.14.2", optional = true } axum = { version = "0.8.1", optional = true } uniffi = { version = "0.29", features = ["cli", "tokio"], optional = true } + +[dev-dependencies] +test-log = "0.2.18" diff --git a/orange-sdk/src/event.rs b/orange-sdk/src/event.rs index 7332123..a815fd1 100644 --- a/orange-sdk/src/event.rs +++ b/orange-sdk/src/event.rs @@ -11,12 +11,12 @@ use ldk_node::lightning::util::ser::{Writeable, Writer}; use ldk_node::lightning::{impl_writeable_tlv_based_enum, log_debug, log_error, log_warn}; use ldk_node::lightning_types::payment::{PaymentHash, PaymentPreimage}; use ldk_node::payment::{ConfirmationStatus, PaymentKind}; -use ldk_node::{CustomTlvRecord, UserChannelId}; +use ldk_node::{CustomTlvRecord, DynStore, UserChannelId}; use std::collections::VecDeque; -use std::sync::{Arc, Condvar, Mutex}; +use std::sync::Arc; use std::task::{Poll, Waker}; -use tokio::sync::watch; +use tokio::sync::{Mutex, watch}; /// The event queue will be persisted under this key. 
pub(crate) const EVENT_QUEUE_PERSISTENCE_PRIMARY_NAMESPACE: &str = ""; @@ -131,6 +131,17 @@ pub enum Event { /// The fee paid, in msats, for the rebalance payment. fee_msat: u64, }, + /// We have initiated a splice and are waiting for it to confirm. + SplicePending { + /// The `channel_id` of the channel. + channel_id: ChannelId, + /// The `user_channel_id` of the channel. + user_channel_id: UserChannelId, + /// The `node_id` of the channel counterparty. + counterparty_node_id: PublicKey, + /// The outpoint of the channel's splice funding transaction. + new_funding_txo: OutPoint, + }, } impl_writeable_tlv_based_enum!(Event, @@ -182,6 +193,12 @@ impl_writeable_tlv_based_enum!(Event, (6, amount_msat, required), (8, fee_msat, required), }, + (8, SplicePending) => { + (1, channel_id, required), + (3, counterparty_node_id, required), + (5, user_channel_id, required), + (7, new_funding_txo, required), + }, ); /// A queue for events emitted by the [`Wallet`]. @@ -190,36 +207,34 @@ impl_writeable_tlv_based_enum!(Event, pub struct EventQueue { queue: Arc>>, waker: Arc>>, - notifier: Condvar, - kv_store: Arc, + kv_store: Arc, logger: Arc, } impl EventQueue { - pub(crate) fn new(kv_store: Arc, logger: Arc) -> Self { + pub(crate) fn new(kv_store: Arc, logger: Arc) -> Self { let queue = Arc::new(Mutex::new(VecDeque::new())); let waker = Arc::new(Mutex::new(None)); - let notifier = Condvar::new(); - Self { queue, waker, notifier, kv_store, logger } + Self { queue, waker, kv_store, logger } } - pub(crate) fn add_event(&self, event: Event) -> Result<(), ldk_node::lightning::io::Error> { + pub(crate) async fn add_event( + &self, event: Event, + ) -> Result<(), ldk_node::lightning::io::Error> { { - let mut locked_queue = self.queue.lock().unwrap(); + let mut locked_queue = self.queue.lock().await; locked_queue.push_back(event); - self.persist_queue(&locked_queue)?; + self.persist_queue(&locked_queue).await?; } - self.notifier.notify_one(); - - if let Some(waker) = 
self.waker.lock().unwrap().take() { + if let Some(waker) = self.waker.lock().await.take() { waker.wake(); } Ok(()) } - pub(crate) fn next_event(&self) -> Option { - let locked_queue = self.queue.lock().unwrap(); + pub(crate) async fn next_event(&self) -> Option { + let locked_queue = self.queue.lock().await; locked_queue.front().cloned() } @@ -227,48 +242,42 @@ impl EventQueue { EventFuture { event_queue: Arc::clone(&self.queue), waker: Arc::clone(&self.waker) }.await } - pub(crate) fn wait_next_event(&self) -> Event { - let locked_queue = - self.notifier.wait_while(self.queue.lock().unwrap(), |queue| queue.is_empty()).unwrap(); - locked_queue.front().unwrap().clone() - } - - pub(crate) fn event_handled(&self) -> Result<(), ldk_node::lightning::io::Error> { + pub(crate) async fn event_handled(&self) -> Result<(), ldk_node::lightning::io::Error> { { - let mut locked_queue = self.queue.lock().unwrap(); + let mut locked_queue = self.queue.lock().await; locked_queue.pop_front(); - self.persist_queue(&locked_queue)?; + self.persist_queue(&locked_queue).await?; } - self.notifier.notify_one(); - if let Some(waker) = self.waker.lock().unwrap().take() { + if let Some(waker) = self.waker.lock().await.take() { waker.wake(); } Ok(()) } - fn persist_queue( + async fn persist_queue( &self, locked_queue: &VecDeque, ) -> Result<(), ldk_node::lightning::io::Error> { let data = EventQueueSerWrapper(locked_queue).encode(); - self.kv_store - .write( + KVStore::write( + self.kv_store.as_ref(), + EVENT_QUEUE_PERSISTENCE_PRIMARY_NAMESPACE, + EVENT_QUEUE_PERSISTENCE_SECONDARY_NAMESPACE, + EVENT_QUEUE_PERSISTENCE_KEY, + data, + ) + .await + .map_err(|e| { + log_error!( + self.logger.as_ref(), + "Write for key {}/{}/{} failed due to: {}", EVENT_QUEUE_PERSISTENCE_PRIMARY_NAMESPACE, EVENT_QUEUE_PERSISTENCE_SECONDARY_NAMESPACE, EVENT_QUEUE_PERSISTENCE_KEY, - &data, - ) - .map_err(|e| { - log_error!( - self.logger.as_ref(), - "Write for key {}/{}/{} failed due to: {}", - 
EVENT_QUEUE_PERSISTENCE_PRIMARY_NAMESPACE, - EVENT_QUEUE_PERSISTENCE_SECONDARY_NAMESPACE, - EVENT_QUEUE_PERSISTENCE_KEY, - e - ); e - })?; + ); + e + })?; Ok(()) } } @@ -296,10 +305,12 @@ impl Future for EventFuture { fn poll( self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>, ) -> Poll { - if let Some(event) = self.event_queue.lock().unwrap().front() { - Poll::Ready(event.clone()) + if let Some(event) = self.event_queue.try_lock().ok().and_then(|q| q.front().cloned()) { + Poll::Ready(event) } else { - *self.waker.lock().unwrap() = Some(cx.waker().clone()); + if let Ok(mut waker) = self.waker.try_lock() { + *waker = Some(cx.waker().clone()); + } Poll::Pending } } @@ -312,11 +323,12 @@ pub(crate) struct LdkEventHandler { pub(crate) tx_metadata: store::TxMetadataStore, pub(crate) payment_receipt_sender: watch::Sender<()>, pub(crate) channel_pending_sender: watch::Sender, + pub(crate) splice_pending_sender: watch::Sender, pub(crate) logger: Arc, } impl LdkEventHandler { - pub(crate) fn handle_ldk_node_event(&self, event: ldk_node::Event) { + pub(crate) async fn handle_ldk_node_event(&self, event: ldk_node::Event) { match event { ldk_node::Event::PaymentSuccessful { payment_id, @@ -331,22 +343,30 @@ impl LdkEventHandler { log_error!(self.logger, "Failed to set preimage for payment {payment_id:?}"); } - if let Err(e) = self.event_queue.add_event(Event::PaymentSuccessful { - payment_id, - payment_hash, - payment_preimage: preimage, - fee_paid_msat, - }) { + if let Err(e) = self + .event_queue + .add_event(Event::PaymentSuccessful { + payment_id, + payment_hash, + payment_preimage: preimage, + fee_paid_msat, + }) + .await + { log_error!(self.logger, "Failed to add PaymentSuccessful event: {e:?}"); return; } }, ldk_node::Event::PaymentFailed { payment_id, payment_hash, reason } => { - if let Err(e) = self.event_queue.add_event(Event::PaymentFailed { - payment_id: PaymentId::SelfCustodial(payment_id.unwrap().0), // safe - payment_hash, - reason, - }) { + if 
let Err(e) = self + .event_queue + .add_event(Event::PaymentFailed { + payment_id: PaymentId::SelfCustodial(payment_id.unwrap().0), // safe + payment_hash, + reason, + }) + .await + { log_error!(self.logger, "Failed to add PaymentFailed event: {e:?}"); return; } @@ -366,13 +386,17 @@ impl LdkEventHandler { } }); - if let Err(e) = self.event_queue.add_event(Event::PaymentReceived { - payment_id: PaymentId::SelfCustodial(payment_id.0), - payment_hash, - amount_msat, - custom_records, - lsp_fee_msats, - }) { + if let Err(e) = self + .event_queue + .add_event(Event::PaymentReceived { + payment_id: PaymentId::SelfCustodial(payment_id.0), + payment_hash, + amount_msat, + custom_records, + lsp_fee_msats, + }) + .await + { log_error!(self.logger, "Failed to add PaymentReceived event: {e:?}"); } let _ = self.payment_receipt_sender.send(()); @@ -387,21 +411,24 @@ impl LdkEventHandler { ldk_node::Event::ChannelPending { .. } => { log_debug!(self.logger, "Received ChannelPending event"); }, - ldk_node::Event::ChannelReady { channel_id, user_channel_id, counterparty_node_id } => { - let funding_txo = self - .ldk_node - .list_channels() - .iter() - .find(|c| c.user_channel_id == user_channel_id) - .and_then(|c| c.funding_txo) - .unwrap(); + ldk_node::Event::ChannelReady { + channel_id, + user_channel_id, + counterparty_node_id, + funding_txo, + } => { + let funding_txo = funding_txo.unwrap(); // safe - if let Err(e) = self.event_queue.add_event(Event::ChannelOpened { - channel_id, - user_channel_id, - counterparty_node_id: counterparty_node_id.unwrap(), // safe - funding_txo, - }) { + if let Err(e) = self + .event_queue + .add_event(Event::ChannelOpened { + channel_id, + user_channel_id, + counterparty_node_id: counterparty_node_id.unwrap(), // safe + funding_txo, + }) + .await + { log_error!(self.logger, "Failed to add ChannelOpened event: {e:?}"); return; } @@ -417,16 +444,46 @@ impl LdkEventHandler { // try to reopen the channel. 
store::set_rebalance_enabled(self.event_queue.kv_store.as_ref(), false); - if let Err(e) = self.event_queue.add_event(Event::ChannelClosed { - channel_id, - user_channel_id, - counterparty_node_id: counterparty_node_id.unwrap(), // safe - reason, - }) { + if let Err(e) = self + .event_queue + .add_event(Event::ChannelClosed { + channel_id, + user_channel_id, + counterparty_node_id: counterparty_node_id.unwrap(), // safe + reason, + }) + .await + { log_error!(self.logger, "Failed to add ChannelClosed event: {e:?}"); return; } }, + ldk_node::Event::SplicePending { + channel_id, + user_channel_id, + counterparty_node_id, + new_funding_txo, + } => { + log_debug!(self.logger, "Received SplicePending event {event:?}"); + let _ = self.splice_pending_sender.send(user_channel_id.0); + + if let Err(e) = self + .event_queue + .add_event(Event::SplicePending { + channel_id, + user_channel_id, + counterparty_node_id, + new_funding_txo, + }) + .await + { + log_error!(self.logger, "Failed to add SplicePending event: {e:?}"); + return; + } + }, + ldk_node::Event::SpliceFailed { .. 
} => { + log_warn!(self.logger, "Received SpliceFailed event: {event:?}"); + }, } if let Err(e) = self.ldk_node.event_handled() { diff --git a/orange-sdk/src/ffi/bitcoin_payment_instructions.rs b/orange-sdk/src/ffi/bitcoin_payment_instructions.rs index 2a49b09..97dbe08 100644 --- a/orange-sdk/src/ffi/bitcoin_payment_instructions.rs +++ b/orange-sdk/src/ffi/bitcoin_payment_instructions.rs @@ -85,6 +85,7 @@ pub enum ParseError { UnknownRequiredParameter, HrnResolutionError(String), InstructionsExpired, + InvalidLnurl(String), } impl Display for ParseError { @@ -108,6 +109,7 @@ impl Display for ParseError { write!(f, "Human readable name resolution error: {}", e) }, ParseError::InstructionsExpired => write!(f, "Payment instructions have expired"), + ParseError::InvalidLnurl(e) => write!(f, "Invalid LNURL: {}", e), } } } @@ -131,6 +133,7 @@ impl From for ParseError { ParseError::HrnResolutionError(msg.to_string()) }, BPIParseError::InstructionsExpired => ParseError::InstructionsExpired, + BPIParseError::InvalidLnurl(msg) => ParseError::InvalidLnurl(msg.to_string()), } } } diff --git a/orange-sdk/src/ffi/orange/mod.rs b/orange-sdk/src/ffi/orange/mod.rs index a064012..af22e65 100644 --- a/orange-sdk/src/ffi/orange/mod.rs +++ b/orange-sdk/src/ffi/orange/mod.rs @@ -262,6 +262,17 @@ pub enum Event { /// The fee paid, in msats, for the rebalance payment. fee_msat: u64, }, + /// We have initiated a splice and are waiting for it to confirm. + SplicePending { + /// The `channel_id` of the channel. + channel_id: Vec, + /// The `user_channel_id` of the channel. + user_channel_id: Vec, + /// The `node_id` of the channel counterparty. + counterparty_node_id: Vec, + /// The outpoint of the channel's splice funding transaction. 
+ new_funding_txo: String, + }, } impl From for Event { @@ -349,6 +360,17 @@ impl From for Event { amount_msat, fee_msat, }, + OrangeEvent::SplicePending { + channel_id, + user_channel_id, + counterparty_node_id, + new_funding_txo, + } => Event::SplicePending { + channel_id: channel_id.0.to_vec(), + user_channel_id: user_channel_id.0.to_be_bytes().to_vec(), + counterparty_node_id: counterparty_node_id.serialize().to_vec(), + new_funding_txo: new_funding_txo.to_string(), + }, } } } diff --git a/orange-sdk/src/ffi/orange/wallet.rs b/orange-sdk/src/ffi/orange/wallet.rs index f52e4f3..9b308a9 100644 --- a/orange-sdk/src/ffi/orange/wallet.rs +++ b/orange-sdk/src/ffi/orange/wallet.rs @@ -88,6 +88,7 @@ impl_into_core_type!(SingleUseReceiveUri, OrangeSingleUseReceiveUri); #[derive(Clone, uniffi::Object)] pub struct Wallet { inner: Arc, + _rt: Arc, } #[uniffi::export(async_runtime = "tokio")] @@ -98,11 +99,9 @@ impl Wallet { let config: OrangeWalletConfig = config.try_into()?; - let rt_clone = rt.clone(); - let inner = - rt.block_on(async move { OrangeWallet::new_with_runtime(rt_clone, config).await })?; + let inner = rt.block_on(async move { OrangeWallet::new(config).await })?; - Ok(Wallet { inner: Arc::new(inner) }) + Ok(Wallet { inner: Arc::new(inner), _rt: rt }) } pub fn node_id(&self) -> String { @@ -198,10 +197,10 @@ impl Wallet { /// Authenticates the user via [LNURL-auth] for the given LNURL string. /// /// [LNURL-auth]: https://github.com/lnurl/luds/blob/luds/04.md - pub fn lnurl_auth(&self, lnurl: &str) -> Result<(), WalletError> { - self.inner.lnurl_auth(lnurl)?; - Ok(()) - } + // pub fn lnurl_auth(&self, lnurl: &str) -> Result<(), WalletError> { + // self.inner.lnurl_auth(lnurl)?; + // Ok(()) + // } /// Returns the wallet's configured tunables. 
pub fn get_tunables(&self) -> Arc { diff --git a/orange-sdk/src/lib.rs b/orange-sdk/src/lib.rs index 5c00033..31b57ad 100644 --- a/orange-sdk/src/lib.rs +++ b/orange-sdk/src/lib.rs @@ -24,13 +24,10 @@ use ldk_node::bitcoin::secp256k1::PublicKey; use ldk_node::io::sqlite_store::SqliteStore; use ldk_node::lightning::ln::msgs::SocketAddress; use ldk_node::lightning::util::logger::Logger as _; -use ldk_node::lightning::util::persist::KVStore; use ldk_node::lightning::{log_debug, log_error, log_info, log_trace, log_warn}; use ldk_node::lightning_invoice::Bolt11Invoice; -use ldk_node::payment::PaymentKind; -use ldk_node::{BuildError, ChannelDetails, NodeError}; - -use tokio::runtime::Runtime; +use ldk_node::payment::{PaymentDirection, PaymentKind}; +use ldk_node::{BuildError, ChannelDetails, DynStore, NodeError}; use std::collections::HashMap; use std::fmt::{self, Debug, Write}; @@ -43,6 +40,7 @@ mod ffi; mod lightning_wallet; pub(crate) mod logging; mod rebalancer; +mod runtime; mod store; pub mod trusted_wallet; @@ -51,6 +49,7 @@ use logging::Logger; use trusted_wallet::TrustedError; pub use crate::logging::LoggerType; +use crate::runtime::Runtime; #[cfg(feature = "cashu")] pub use crate::trusted_wallet::cashu::CashuConfig; #[cfg(feature = "spark")] @@ -113,7 +112,7 @@ struct WalletImpl { /// Metadata store for tracking transactions. tx_metadata: TxMetadataStore, /// Key-value store for persistent storage. - store: Arc, + store: Arc, /// Logger for logging wallet operations. logger: Arc, /// The Tokio runtime for asynchronous operations. @@ -508,28 +507,18 @@ impl Wallet { /// Recovery ensures trusted wallet funds can be restored when reconstructed from the same seed /// across different devices or installations. 
pub async fn new(config: WalletConfig) -> Result { - let rt = tokio::runtime::Builder::new_multi_thread() - .enable_all() - .build() - .map_err(|e| InitFailure::IoError(e.into()))?; - - Self::new_with_runtime(Arc::new(rt), config).await - } - /// Constructs a new Wallet with a runtime. - /// - /// `runtime` must be a reference to the running `tokio` runtime which we are currently - /// operating in. - // TODO: WOW that is a terrible API lol - pub async fn new_with_runtime( - runtime: Arc, config: WalletConfig, - ) -> Result { let tunables = config.tunables; let network = config.network; let logger = Arc::new(Logger::new(&config.logger_type).expect("Failed to open log file")); log_info!(logger, "Initializing orange on network: {network}"); - let store: Arc = match &config.storage_config { + let runtime = Arc::new(Runtime::new(Arc::clone(&logger)).map_err(|e| { + log_error!(logger, "Failed to set up tokio runtime: {e}"); + BuildError::RuntimeSetupFailed + })?); + + let store: Arc = match &config.storage_config { StorageConfig::LocalSQLite(path) => { Arc::new(SqliteStore::new(path.into(), Some("orange.sqlite".to_owned()), None)?) }, @@ -537,7 +526,7 @@ impl Wallet { let event_queue = Arc::new(EventQueue::new(Arc::clone(&store), Arc::clone(&logger))); - let tx_metadata = TxMetadataStore::new(Arc::clone(&store)); + let tx_metadata = TxMetadataStore::new(Arc::clone(&store)).await; let trusted: Arc> = match &config.extra_config { #[cfg(feature = "spark")] @@ -621,7 +610,7 @@ impl Wallet { // Spawn a background thread that every second, we see if we should initiate a rebalance // This will withdraw from the trusted balance to our LN balance, possibly opening a channel. let rb = Arc::clone(&rebalancer); - runtime.spawn(async move { + runtime.spawn_cancellable_background_task(async move { loop { rb.do_rebalance_if_needed().await; @@ -647,7 +636,7 @@ impl Wallet { /// Sets whether the wallet should automatically rebalance from trusted/onchain to lightning. 
pub fn set_rebalance_enabled(&self, value: bool) { - store::set_rebalance_enabled(self.inner.store.as_ref(), value); + store::set_rebalance_enabled(self.inner.store.as_ref(), value) } /// Whether the wallet should automatically rebalance from trusted/onchain to lightning. @@ -673,9 +662,14 @@ impl Wallet { /// Lists the transactions which have been made. pub async fn list_transactions(&self) -> Result, WalletError> { let trusted_payments = self.inner.trusted.list_payments().await?; - let lightning_payments = self.inner.ln_wallet.list_payments(); + let mut lightning_payments = self.inner.ln_wallet.list_payments(); + lightning_payments.sort_by_key(|l| l.latest_update_timestamp); + + let splice_outs = store::read_splice_outs(self.inner.store.as_ref()); - let mut res = Vec::with_capacity(trusted_payments.len() + lightning_payments.len()); + let mut res = Vec::with_capacity( + trusted_payments.len() + lightning_payments.len() + splice_outs.len(), + ); let tx_metadata = self.inner.tx_metadata.read(); let mut internal_transfers = HashMap::new(); @@ -791,13 +785,17 @@ impl Wallet { }, _ => None, }; - let fee = match payment.fee_paid_msat { - None => lightning_receive_fee, - Some(fee) => Some( - Amount::from_milli_sats(fee) - .unwrap() - .saturating_add(lightning_receive_fee.unwrap_or(Amount::ZERO)), - ), + let fee = if payment.direction == PaymentDirection::Outbound { + match payment.fee_paid_msat { + None => Some(lightning_receive_fee.unwrap_or(Amount::ZERO)), + Some(fee) => Some( + Amount::from_milli_sats(fee) + .unwrap() + .saturating_add(lightning_receive_fee.unwrap_or(Amount::ZERO)), + ), + } + } else { + Some(lightning_receive_fee.unwrap_or(Amount::ZERO)) }; if let Some(tx_metadata) = tx_metadata.get(&PaymentId::SelfCustodial(payment.id.0)) { match &tx_metadata.ty { @@ -894,6 +892,18 @@ impl Wallet { } } + for details in splice_outs { + res.push(Transaction { + id: PaymentId::SelfCustodial(details.id.0), + status: details.status.into(), + outbound: details.direction 
== PaymentDirection::Outbound, + amount: details.amount_msat.map(|a| Amount::from_milli_sats(a).unwrap()), + fee: details.fee_paid_msat.map(|fee| Amount::from_milli_sats(fee).unwrap()), + payment_type: (&details).into(), + time_since_epoch: Duration::from_secs(details.latest_update_timestamp), + }); + } + for (id, tx_info) in internal_transfers { debug_assert!( tx_info.send_fee.is_some(), @@ -1023,7 +1033,8 @@ impl Wallet { pub async fn parse_payment_instructions( &self, instructions: &str, ) -> Result { - PaymentInstructions::parse(instructions, self.inner.network, &HTTPHrnResolver, true).await + PaymentInstructions::parse(instructions, self.inner.network, &HTTPHrnResolver::new(), true) + .await } // /// Verifies instructions which allow us to claim funds given as: @@ -1109,7 +1120,9 @@ impl Wallet { let mut pay_lightning = async |method, ty: fn() -> PaymentType| { let typ = ty(); let balance = if matches!(typ, PaymentType::OutgoingOnChain { .. }) { - ln_balance.onchain + // if we are paying on-chain, we can either use the on-chain balance or the + // lightning balance with a splice. Use the larger of the two. + ln_balance.onchain.max(ln_balance.lightning) } else { ln_balance.lightning }; @@ -1134,7 +1147,7 @@ impl Wallet { }, ); let inner_ref = Arc::clone(&self.inner); - self.inner.runtime.spawn(async move { + self.inner.runtime.spawn_cancellable_background_task(async move { inner_ref.rebalancer.do_rebalance_if_needed().await; }); return Ok(()); @@ -1152,7 +1165,8 @@ impl Wallet { let methods = match &instructions.instructions { PaymentInstructions::ConfigurableAmount(conf) => { - let res = conf.clone().set_amount(instructions.amount, &HTTPHrnResolver).await; + let res = + conf.clone().set_amount(instructions.amount, &HTTPHrnResolver::new()).await; let fixed_instr = res.map_err(|e| { log_error!( self.inner.logger, @@ -1263,13 +1277,13 @@ impl Wallet { Ok(()) } - /// Authenticates the user via [LNURL-auth] for the given LNURL string. 
- /// - /// [LNURL-auth]: https://github.com/lnurl/luds/blob/luds/04.md - pub fn lnurl_auth(&self, lnurl: &str) -> Result<(), WalletError> { - self.inner.ln_wallet.inner.ldk_node.lnurl_auth(lnurl)?; - Ok(()) - } + // Authenticates the user via [LNURL-auth] for the given LNURL string. + // + // [LNURL-auth]: https://github.com/lnurl/luds/blob/luds/04.md + // pub fn lnurl_auth(&self, _lnurl: &str) -> Result<(), WalletError> { + // // todo wait for merge, self.inner.ln_wallet.inner.ldk_node.lnurl_auth(lnurl)?; + // Ok(()) + // } /// Returns the wallet's configured tunables. pub fn get_tunables(&self) -> Tunables { @@ -1285,7 +1299,7 @@ impl Wallet { /// **Caution:** Users must handle events as quickly as possible to prevent a large event backlog, /// which can increase the memory footprint of [`Wallet`]. pub fn next_event(&self) -> Option { - self.inner.event_queue.next_event() + self.inner.runtime.block_on(self.inner.event_queue.next_event()) } /// Returns the next event in the event queue. @@ -1309,14 +1323,17 @@ impl Wallet { /// **Caution:** Users must handle events as quickly as possible to prevent a large event backlog, /// which can increase the memory footprint of [`Wallet`]. pub fn wait_next_event(&self) -> Event { - self.inner.event_queue.wait_next_event() + let fut = self.inner.event_queue.next_event_async(); + // We use our runtime for the sync variant to ensure `tokio::task::block_in_place` is + // always called if we'd ever hit this in an outer runtime context. + self.inner.runtime.block_on(fut) } /// Confirm the last retrieved event handled. /// /// **Note:** This **MUST** be called after each event has been handled. 
pub fn event_handled(&self) -> Result<(), ()> { - self.inner.event_queue.event_handled().map_err(|e| { + self.inner.runtime.block_on(self.inner.event_queue.event_handled()).map_err(|e| { log_error!( self.inner.logger, "Couldn't mark event handled due to persistence failure: {e}" @@ -1337,5 +1354,20 @@ impl Wallet { log_debug!(self.inner.logger, "Stopping ln wallet..."); self.inner.ln_wallet.stop(); + + // Cancel cancellable background tasks + self.inner.runtime.abort_cancellable_background_tasks(); + + // Wait until non-cancellable background tasks (mod LDK's background processor) are done. + self.inner.runtime.wait_on_background_tasks(); + } + + /// Manually sync the LDK and BDK wallets with the current chain state and update the fee rate cache. + /// + /// This is done automatically in the background, but can be triggered manually if needed. Often useful for + /// testing purposes. + pub fn sync_ln_wallet(&self) -> Result<(), WalletError> { + self.inner.ln_wallet.inner.ldk_node.sync_wallets()?; + Ok(()) } } diff --git a/orange-sdk/src/lightning_wallet.rs b/orange-sdk/src/lightning_wallet.rs index 472d6af..b09e07c 100644 --- a/orange-sdk/src/lightning_wallet.rs +++ b/orange-sdk/src/lightning_wallet.rs @@ -1,8 +1,10 @@ use crate::bitcoin::OutPoint; +use crate::bitcoin::hashes::Hash; use crate::event::{EventQueue, LdkEventHandler}; use crate::logging::Logger; +use crate::runtime::Runtime; use crate::store::{TxMetadataStore, TxStatus}; -use crate::{ChainSource, InitFailure, PaymentType, Seed, WalletConfig}; +use crate::{ChainSource, InitFailure, PaymentType, Seed, WalletConfig, store}; use bitcoin_payment_instructions::PaymentMethod; use bitcoin_payment_instructions::amount::Amount; @@ -10,15 +12,17 @@ use bitcoin_payment_instructions::amount::Amount; use ldk_node::bitcoin::base64::Engine; use ldk_node::bitcoin::base64::prelude::BASE64_STANDARD; use ldk_node::bitcoin::secp256k1::PublicKey; -use ldk_node::bitcoin::{Address, Network, Script}; +use 
ldk_node::bitcoin::{Address, Network}; +use ldk_node::config::{AsyncPaymentsRole, BackgroundSyncConfig}; use ldk_node::lightning::ln::channelmanager::PaymentId; use ldk_node::lightning::ln::msgs::SocketAddress; use ldk_node::lightning::util::logger::Logger as _; -use ldk_node::lightning::util::persist::KVStore; use ldk_node::lightning::{log_debug, log_error, log_info}; use ldk_node::lightning_invoice::{Bolt11Invoice, Bolt11InvoiceDescription, Description}; -use ldk_node::payment::{PaymentDetails, PaymentDirection, PaymentKind, PaymentStatus}; -use ldk_node::{NodeError, UserChannelId, lightning}; +use ldk_node::payment::{ + ConfirmationStatus, PaymentDetails, PaymentDirection, PaymentKind, PaymentStatus, +}; +use ldk_node::{DynStore, NodeError, UserChannelId}; use graduated_rebalancer::{LightningBalance, ReceivedLightningPayment}; @@ -26,8 +30,7 @@ use std::collections::HashMap; use std::fmt::Debug; use std::pin::Pin; use std::sync::Arc; - -use tokio::runtime::Runtime; +use std::time::SystemTime; use tokio::sync::watch; #[derive(Debug, Clone, Copy)] @@ -38,8 +41,11 @@ pub(crate) struct LightningWalletBalance { pub(crate) struct LightningWalletImpl { pub(crate) ldk_node: Arc, + logger: Arc, + store: Arc, payment_receipt_flag: watch::Receiver<()>, channel_pending_receipt_flag: watch::Receiver, + splice_pending_receipt_flag: watch::Receiver, lsp_node_id: PublicKey, lsp_socket_addr: SocketAddress, } @@ -52,7 +58,7 @@ const DEFAULT_INVOICE_EXPIRY_SECS: u32 = 86_400; // 24 hours impl LightningWallet { pub(super) async fn init( - runtime: Arc, config: WalletConfig, store: Arc, + runtime: Arc, config: WalletConfig, store: Arc, event_queue: Arc, tx_metadata: TxMetadataStore, logger: Arc, ) -> Result { log_info!(logger, "Creating LDK node..."); @@ -104,22 +110,42 @@ impl LightningWallet { let (lsp_socket_addr, lsp_node_id, lsp_token) = config.lsp; builder.set_liquidity_source_lsps2(lsp_node_id, lsp_socket_addr.clone(), lsp_token); match config.chain_source { - 
ChainSource::Esplora { url, username, password } => match (&username, &password) { - (Some(username), Some(password)) => { - let mut headers = HashMap::with_capacity(1); - headers.insert( - "Authorization".to_string(), - format!( - "Basic {}", - BASE64_STANDARD.encode(format!("{}:{}", username, password)) - ), - ); - builder.set_chain_source_esplora_with_headers(url, headers, None) - }, - (None, None) => builder.set_chain_source_esplora(url, None), - _ => { - return Err(InitFailure::LdkNodeStartFailure(NodeError::WalletOperationFailed)); - }, + ChainSource::Esplora { url, username, password } => { + let sync_config = if config.network == Network::Regtest { + ldk_node::config::EsploraSyncConfig { + background_sync_config: Some(BackgroundSyncConfig { + onchain_wallet_sync_interval_secs: 2, + lightning_wallet_sync_interval_secs: 2, + fee_rate_cache_update_interval_secs: 30, + }), + } + } else { + ldk_node::config::EsploraSyncConfig::default() + }; + + match (&username, &password) { + (Some(username), Some(password)) => { + let mut headers = HashMap::with_capacity(1); + headers.insert( + "Authorization".to_string(), + format!( + "Basic {}", + BASE64_STANDARD.encode(format!("{username}:{password}")) + ), + ); + builder.set_chain_source_esplora_with_headers( + url, + headers, + Some(sync_config), + ) + }, + (None, None) => builder.set_chain_source_esplora(url, Some(sync_config)), + _ => { + return Err(InitFailure::LdkNodeStartFailure( + NodeError::WalletOperationFailed, + )); + }, + } }, ChainSource::Electrum(url) => builder.set_chain_source_electrum(url, None), ChainSource::BitcoindRPC { host, port, user, password } => { @@ -127,61 +153,47 @@ impl LightningWallet { }, }; + builder.set_async_payments_role(Some(AsyncPaymentsRole::Client))?; + builder.set_custom_logger(Arc::clone(&logger) as Arc); - // download scorer and write to storage - // todo switch to https://github.com/lightningdevkit/ldk-node/pull/449 once available + builder.set_runtime(runtime.get_handle()); + 
if let Some(url) = config.scorer_url { - let fetch = tokio::time::timeout(std::time::Duration::from_secs(10), reqwest::get(url)); - let res = fetch.await.map_err(|e| { - log_error!(logger, "Timed out downloading scorer: {e}"); - InitFailure::LdkNodeStartFailure(NodeError::InvalidUri) - })?; - - let req = res.map_err(|e| { - log_error!(logger, "Failed to download scorer: {e}"); - InitFailure::LdkNodeStartFailure(NodeError::InvalidUri) - })?; - - let bytes = req.bytes().await.map_err(|e| { - log_debug!(logger, "Failed to read scorer bytes: {e}"); - InitFailure::LdkNodeStartFailure(NodeError::InvalidUri) - })?; - - store.write( - lightning::util::persist::SCORER_PERSISTENCE_PRIMARY_NAMESPACE, - lightning::util::persist::SCORER_PERSISTENCE_SECONDARY_NAMESPACE, - lightning::util::persist::SCORER_PERSISTENCE_KEY, - bytes.as_ref(), - )?; + builder.set_pathfinding_scores_source(url); } let ldk_node = Arc::new(builder.build_with_store(Arc::clone(&store))?); let (payment_receipt_sender, payment_receipt_flag) = watch::channel(()); let (channel_pending_sender, channel_pending_receipt_flag) = watch::channel(0); + let (splice_pending_sender, splice_pending_receipt_flag) = watch::channel(0); let ev_handler = Arc::new(LdkEventHandler { event_queue, ldk_node: Arc::clone(&ldk_node), tx_metadata, payment_receipt_sender, channel_pending_sender, - logger, + splice_pending_sender, + logger: Arc::clone(&logger), }); let inner = Arc::new(LightningWalletImpl { ldk_node, + logger, + store, payment_receipt_flag, channel_pending_receipt_flag, + splice_pending_receipt_flag, lsp_node_id, lsp_socket_addr, }); - inner.ldk_node.start_with_runtime(Arc::clone(&runtime))?; + inner.ldk_node.start()?; - runtime.spawn(async move { + runtime.spawn_cancellable_background_task(async move { loop { let event = ev_handler.ldk_node.next_event_async().await; - log_debug!(ev_handler.logger, "Got ldk-node event {:?}", event); - ev_handler.handle_ldk_node_event(event); + log_debug!(ev_handler.logger, "Got ldk-node 
event {event:?}"); + ev_handler.handle_ldk_node_event(event).await; } }); @@ -200,6 +212,12 @@ impl LightningWallet { flag.wait_for(|t| t == &channel_id).await.expect("channel pending not received"); } + pub(crate) async fn await_splice_pending(&self, channel_id: u128) { + let mut flag = self.inner.splice_pending_receipt_flag.clone(); + flag.mark_unchanged(); + flag.wait_for(|t| t == &channel_id).await.expect("splice pending not received"); + } + pub(crate) fn get_on_chain_address(&self) -> Result { self.inner.ldk_node.onchain_payment().new_address() } @@ -281,13 +299,113 @@ impl LightningWallet { .inner .ldk_node .bolt12_payment() - .send_using_amount(offer, amount.milli_sats(), None, None), - PaymentMethod::OnChain(address) => self - .inner - .ldk_node - .onchain_payment() - .send_to_address(address, amount.sats_rounding_up(), None) - .map(|txid| PaymentId(*txid.as_ref())), + .send_using_amount(offer, amount.milli_sats(), None, None, None), + PaymentMethod::OnChain(address) => { + let amount_sats = amount.sats().map_err(|_| NodeError::InvalidAmount)?; + + let balance = self.inner.ldk_node.list_balances(); + + // if we have enough onchain balance, send onchain + if balance.spendable_onchain_balance_sats > amount_sats { + self.inner + .ldk_node + .onchain_payment() + .send_to_address(address, amount_sats, None) + .map(|txid| PaymentId(*txid.as_ref())) + } else { + // otherwise try to pay via splice out + + // find existing channel to splice out of + let channels = self.inner.ldk_node.list_channels(); + let channel = + channels.iter().find(|c| c.counterparty_node_id == self.inner.lsp_node_id); + + match channel { + None => { + log_error!(self.inner.logger, "No existing channel to splice out of"); + Err(NodeError::InsufficientFunds) + }, + Some(chan) => { + self.inner.ldk_node.splice_out( + &chan.user_channel_id, + chan.counterparty_node_id, + address, + amount_sats, + )?; + + loop { + self.await_splice_pending(chan.user_channel_id.0).await; + let channels = 
self.inner.ldk_node.list_channels(); + let new_chan = channels + .iter() + .find(|c| c.user_channel_id == chan.user_channel_id); + match new_chan { + Some(c) => { + if c.funding_txo + .is_some_and(|f| f != chan.funding_txo.unwrap()) + { + let funding_txo = c.funding_txo.unwrap(); + + let id = PaymentId(funding_txo.txid.to_byte_array()); + let details = PaymentDetails { + id, + kind: PaymentKind::Onchain { + txid: funding_txo.txid, + status: ConfirmationStatus::Unconfirmed, // todo how do we update this? + }, + amount_msat: Some(amount_sats * 1_000), + fee_paid_msat: Some(69), // todo get real fee + direction: PaymentDirection::Outbound, + status: PaymentStatus::Succeeded, + latest_update_timestamp: SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(), + }; + + store::write_splice_out( + self.inner.store.as_ref(), + &details, + ); + return Ok(id); + } + }, + None => { + log_error!( + self.inner.logger, + "Channel disappeared while awaiting splice out" + ); + return Err(NodeError::WalletOperationFailed); + }, + } + } + }, + } + } + }, + } + } + + pub(crate) async fn splice_balance_into_channel( + &self, amount: Amount, + ) -> Result { + // find existing channel to splice into + let channels = self.inner.ldk_node.list_channels(); + let channel = channels.iter().find(|c| c.counterparty_node_id == self.inner.lsp_node_id); + + match channel { + Some(chan) => { + self.inner.ldk_node.splice_in( + &chan.user_channel_id, + chan.counterparty_node_id, + amount.sats_rounding_up(), + )?; + Ok(chan.user_channel_id) + }, + None => { + log_error!(self.inner.logger, "No existing channel to splice into"); + Err(NodeError::WalletOperationFailed) + }, } } @@ -295,18 +413,20 @@ impl LightningWallet { let bal = self.inner.ldk_node.list_balances().spendable_onchain_balance_sats; // need a dummy p2wsh address to estimate the fee, p2wsh is used for LN channels - let fake_addr = Address::p2wsh(Script::new(), self.inner.ldk_node.config().network); - - let fee = 
self - .inner - .ldk_node - .onchain_payment() - .estimate_send_all_to_address(&fake_addr, true, None)?; + // let fake_addr = Address::p2wsh(Script::new(), self.inner.ldk_node.config().network); + // + // let fee = self + // .inner + // .ldk_node + // .onchain_payment() + // .estimate_send_all_to_address(&fake_addr, true, None)?; + // todo get real fee + let fee = 1000; let id = self.inner.ldk_node.open_channel( self.inner.lsp_node_id, self.inner.lsp_socket_addr.clone(), - bal - fee.to_sat(), + bal - fee, None, None, )?; @@ -401,6 +521,11 @@ impl graduated_rebalancer::LightningWallet for LightningWallet { }) } + fn has_channel_with_lsp(&self) -> bool { + let channels = self.inner.ldk_node.list_channels(); + channels.iter().any(|c| c.counterparty_node_id == self.inner.lsp_node_id) + } + fn open_channel_with_lsp( &self, _amt: Amount, ) -> Pin> + Send + '_>> { @@ -429,6 +554,49 @@ impl graduated_rebalancer::LightningWallet for LightningWallet { } }) } + + fn splice_to_lsp_channel( + &self, amt: Amount, + ) -> Pin> + Send + '_>> { + let bal = self.inner.ldk_node.list_balances(); + // if we don't have enough onchain balance, return error + // if we are within 1,000 sats of the amount, reduce the amount to account for fees + if bal.spendable_onchain_balance_sats < amt.sats_rounding_up() { + return Box::pin(async move { Err(NodeError::InsufficientFunds) }); + } else if bal.spendable_onchain_balance_sats < amt.sats_rounding_up() + 1_000 { + let reduced_amt = amt.saturating_sub(Amount::from_sats(1_000).expect("valid amount")); + return Box::pin(async move { + self.splice_balance_into_channel(reduced_amt).await.map(|c| c.0) + }); + } + + Box::pin(async move { self.splice_balance_into_channel(amt).await.map(|c| c.0) }) + } + + fn await_splice_pending( + &self, channel_id: u128, + ) -> Pin + Send + '_>> { + Box::pin(async move { + // todo since we can't see if we have any active splices, we just await the next splice pending event + // this is kinda race-y hopefully we can fix 
+ self.await_splice_pending(channel_id).await; + loop { + let channels = self.inner.ldk_node.list_channels(); + let chan = channels + .into_iter() + .find(|c| c.user_channel_id.0 == channel_id && c.funding_txo.is_some()); + match chan { + Some(c) => { + return c.funding_txo.expect("channel has no funding txo"); + }, + None => { + self.await_splice_pending(channel_id).await; + // Wait for the next channel pending event + }, + } + } + }) + } } impl From for TxStatus { diff --git a/orange-sdk/src/rebalancer.rs b/orange-sdk/src/rebalancer.rs index 2d7881a..84771d5 100644 --- a/orange-sdk/src/rebalancer.rs +++ b/orange-sdk/src/rebalancer.rs @@ -8,11 +8,12 @@ use crate::trusted_wallet::DynTrustedWalletInterface; use crate::{Event, EventQueue, PaymentType, Tunables, store}; use bitcoin_payment_instructions::amount::Amount; use graduated_rebalancer::{RebalanceTrigger, RebalancerEvent, TriggerParams}; +use ldk_node::DynStore; use ldk_node::lightning::util::logger::Logger as _; -use ldk_node::lightning::util::persist::KVStore; use ldk_node::lightning::{log_error, log_info, log_trace, log_warn}; use ldk_node::payment::{ConfirmationStatus, PaymentDirection, PaymentKind, PaymentStatus}; use std::cmp; +use std::pin::Pin; use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::{Duration, SystemTime}; @@ -29,7 +30,7 @@ pub(crate) struct OrangeTrigger { /// The event handler for processing wallet events. event_queue: Arc, /// Key-value store for persistent storage. - store: Arc, + store: Arc, /// Time of the last on-chain sync, used to determine when to trigger rebalances. onchain_sync_time: AtomicU64, /// Logger for logging events and errors. 
@@ -41,7 +42,7 @@ impl OrangeTrigger { pub(crate) fn new( ln_wallet: Arc, trusted: Arc>, tunables: Tunables, tx_metadata: TxMetadataStore, event_queue: Arc, - store: Arc, logger: Arc, + store: Arc, logger: Arc, ) -> Self { let start = ln_wallet.inner.ldk_node.status().latest_onchain_wallet_sync_timestamp.unwrap_or(0); @@ -183,7 +184,7 @@ impl RebalanceTrigger for OrangeTrigger { }; log_trace!(self.logger, "Generated OnchainPaymentReceived event: {event:?}"); - if let Err(e) = self.event_queue.add_event(event) { + if let Err(e) = self.event_queue.add_event(event).await { log_error!( self.logger, "Failed to add OnchainPaymentReceived event: {e:?}" @@ -279,76 +280,92 @@ impl OrangeRebalanceEventHandler { } impl graduated_rebalancer::EventHandler for OrangeRebalanceEventHandler { - fn handle_event(&self, event: RebalancerEvent) { - match event { - RebalancerEvent::RebalanceInitiated { - trigger_id, - trusted_rebalance_payment_id, - amount_msat, - } => { - let metadata = TxMetadata { - ty: TxType::PendingRebalance {}, - time: SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap(), - }; - self.tx_metadata.insert(PaymentId::Trusted(trusted_rebalance_payment_id), metadata); - if let Err(e) = self.event_queue.add_event(Event::RebalanceInitiated { - trigger_payment_id: PaymentId::Trusted(trigger_id), + fn handle_event( + &self, event: RebalancerEvent, + ) -> Pin + Send + '_>> { + Box::pin(async move { + match event { + RebalancerEvent::RebalanceInitiated { + trigger_id, trusted_rebalance_payment_id, amount_msat, - }) { - log_error!(self.logger, "Failed to add RebalanceSuccessful event: {e:?}"); - } - }, - RebalancerEvent::RebalanceSuccessful { - trigger_id, - trusted_rebalance_payment_id: rebalance_id, - ln_rebalance_payment_id: lightning_id, - amount_msat, - fee_msat, - } => { - let triggering_transaction_id = PaymentId::Trusted(trigger_id); - self.tx_metadata - .set_tx_caused_rebalance(&triggering_transaction_id) - .expect("Failed to write metadata for rebalance 
transaction"); - let metadata = TxMetadata { - ty: TxType::TrustedToLightning { - trusted_payment: rebalance_id, - lightning_payment: lightning_id, - payment_triggering_transfer: triggering_transaction_id, - }, - time: SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap(), - }; - self.tx_metadata.upsert(PaymentId::Trusted(rebalance_id), metadata); - self.tx_metadata.insert(PaymentId::SelfCustodial(lightning_id), metadata); - - if let Err(e) = self.event_queue.add_event(Event::RebalanceSuccessful { - trigger_payment_id: triggering_transaction_id, + } => { + let metadata = TxMetadata { + ty: TxType::PendingRebalance {}, + time: SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap(), + }; + self.tx_metadata + .insert(PaymentId::Trusted(trusted_rebalance_payment_id), metadata); + if let Err(e) = self + .event_queue + .add_event(Event::RebalanceInitiated { + trigger_payment_id: PaymentId::Trusted(trigger_id), + trusted_rebalance_payment_id, + amount_msat, + }) + .await + { + log_error!(self.logger, "Failed to add RebalanceSuccessful event: {e:?}"); + } + }, + RebalancerEvent::RebalanceSuccessful { + trigger_id, trusted_rebalance_payment_id: rebalance_id, ln_rebalance_payment_id: lightning_id, amount_msat, fee_msat, - }) { - log_error!(self.logger, "Failed to add RebalanceSuccessful event: {e:?}"); - } - }, - RebalancerEvent::OnChainRebalanceInitiated { - trigger_id, - channel_outpoint, - user_channel_id: _, - } => { - let chan_txid = channel_outpoint.txid; - let triggering_txid = Txid::from_byte_array(trigger_id); - let trigger_id = PaymentId::SelfCustodial(triggering_txid.to_byte_array()); - self.tx_metadata - .set_tx_caused_rebalance(&trigger_id) - .expect("Failed to write metadata for onchain rebalance transaction"); - let metadata = TxMetadata { - ty: TxType::OnchainToLightning { channel_txid: chan_txid, triggering_txid }, - time: SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap(), - }; - self.tx_metadata - 
.insert(PaymentId::SelfCustodial(chan_txid.to_byte_array()), metadata); - }, - } + } => { + let triggering_transaction_id = PaymentId::Trusted(trigger_id); + self.tx_metadata + .set_tx_caused_rebalance(&triggering_transaction_id) + .expect("Failed to write metadata for rebalance transaction"); + let metadata = TxMetadata { + ty: TxType::TrustedToLightning { + trusted_payment: rebalance_id, + lightning_payment: lightning_id, + payment_triggering_transfer: triggering_transaction_id, + }, + time: SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap(), + }; + self.tx_metadata.upsert(PaymentId::Trusted(rebalance_id), metadata); + self.tx_metadata.insert(PaymentId::SelfCustodial(lightning_id), metadata); + + let event_queue = Arc::clone(&self.event_queue); + let logger = Arc::clone(&self.logger); + tokio::spawn(async move { + if let Err(e) = event_queue + .add_event(Event::RebalanceSuccessful { + trigger_payment_id: triggering_transaction_id, + trusted_rebalance_payment_id: rebalance_id, + ln_rebalance_payment_id: lightning_id, + amount_msat, + fee_msat, + }) + .await + { + log_error!(logger, "Failed to add RebalanceSuccessful event: {e:?}"); + } + }); + }, + RebalancerEvent::OnChainRebalanceInitiated { + trigger_id, + channel_outpoint, + user_channel_id: _, + } => { + let chan_txid = channel_outpoint.txid; + let triggering_txid = Txid::from_byte_array(trigger_id); + let trigger_id = PaymentId::SelfCustodial(triggering_txid.to_byte_array()); + self.tx_metadata + .set_tx_caused_rebalance(&trigger_id) + .expect("Failed to write metadata for onchain rebalance transaction"); + let metadata = TxMetadata { + ty: TxType::OnchainToLightning { channel_txid: chan_txid, triggering_txid }, + time: SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap(), + }; + self.tx_metadata + .insert(PaymentId::SelfCustodial(chan_txid.to_byte_array()), metadata); + }, + } + }) } } diff --git a/orange-sdk/src/runtime.rs b/orange-sdk/src/runtime.rs new file mode 100644 index 
0000000..f26e8da --- /dev/null +++ b/orange-sdk/src/runtime.rs @@ -0,0 +1,152 @@ +// This file is Copyright its original authors, visible in version control history. +// +// This file is licensed under the Apache License, Version 2.0 or the MIT license , at your option. You may not use this file except in +// accordance with one or both of these licenses. + +use ldk_node::lightning::util::logger::Logger as _; +use ldk_node::lightning::{log_debug, log_error, log_trace, log_warn}; +use std::future::Future; +use std::sync::{Arc, Mutex}; +use std::time::Duration; +use tokio::task::{JoinHandle, JoinSet}; + +use crate::logging::Logger; + +// The timeout after which we give up waiting on a background task to exit on shutdown. +pub(crate) const BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS: u64 = 5; + +pub(crate) struct Runtime { + mode: RuntimeMode, + background_tasks: Mutex>, + cancellable_background_tasks: Mutex>, + logger: Arc, +} + +impl Runtime { + pub fn new(logger: Arc) -> Result { + let mode = match tokio::runtime::Handle::try_current() { + Ok(handle) => RuntimeMode::Handle(handle), + Err(_) => { + let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build()?; + RuntimeMode::Owned(rt) + }, + }; + let background_tasks = Mutex::new(JoinSet::new()); + let cancellable_background_tasks = Mutex::new(JoinSet::new()); + + Ok(Self { mode, background_tasks, cancellable_background_tasks, logger }) + } + + pub fn get_handle(&self) -> tokio::runtime::Handle { + match &self.mode { + RuntimeMode::Owned(rt) => rt.handle().clone(), + RuntimeMode::Handle(h) => h.clone(), + } + } + + #[allow(unused)] + pub fn with_handle(handle: tokio::runtime::Handle, logger: Arc) -> Self { + let mode = RuntimeMode::Handle(handle); + let background_tasks = Mutex::new(JoinSet::new()); + let cancellable_background_tasks = Mutex::new(JoinSet::new()); + + Self { mode, background_tasks, cancellable_background_tasks, logger } + } + + pub fn spawn_background_task(&self, future: F) + where + F: Future 
+ Send + 'static, + { + let mut background_tasks = self.background_tasks.lock().unwrap(); + let runtime_handle = self.handle(); + background_tasks.spawn_on(future, runtime_handle); + } + + pub fn spawn_cancellable_background_task(&self, future: F) + where + F: Future + Send + 'static, + { + let mut cancellable_background_tasks = self.cancellable_background_tasks.lock().unwrap(); + let runtime_handle = self.handle(); + cancellable_background_tasks.spawn_on(future, runtime_handle); + } + + #[allow(unused)] + pub fn spawn_blocking(&self, func: F) -> JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let handle = self.handle(); + handle.spawn_blocking(func) + } + + pub fn block_on(&self, future: F) -> F::Output { + // While we generally decided not to overthink via which call graph users would enter our + // runtime context, we'd still try to reuse whatever current context would be present + // during `block_on`, as this is the context `block_in_place` would operate on. So we try + // to detect the outer context here, and otherwise use whatever was set during + // initialization. 
+ let handle = tokio::runtime::Handle::try_current().unwrap_or(self.handle().clone()); + tokio::task::block_in_place(move || handle.block_on(future)) + } + + pub fn abort_cancellable_background_tasks(&self) { + let mut tasks = core::mem::take(&mut *self.cancellable_background_tasks.lock().unwrap()); + if tasks.is_empty() { + log_warn!(self.logger, "Stopping cancellable background tasks with no tasks"); + return; + } + tasks.abort_all(); + self.block_on(async { while tasks.join_next().await.is_some() {} }) + } + + pub fn wait_on_background_tasks(&self) { + let mut tasks = core::mem::take(&mut *self.background_tasks.lock().unwrap()); + if tasks.is_empty() { + log_warn!(self.logger, "Stopping background tasks with no tasks"); + return; + } + self.block_on(async { + loop { + let timeout_fut = tokio::time::timeout( + Duration::from_secs(BACKGROUND_TASK_SHUTDOWN_TIMEOUT_SECS), + tasks.join_next_with_id(), + ); + match timeout_fut.await { + Ok(Some(Ok((id, _)))) => { + log_trace!(self.logger, "Stopped background task with id {id}"); + }, + Ok(Some(Err(e))) => { + tasks.abort_all(); + log_trace!(self.logger, "Stopping background task failed: {e}"); + break; + }, + Ok(None) => { + log_debug!(self.logger, "Stopped all background tasks"); + break; + }, + Err(e) => { + tasks.abort_all(); + log_error!(self.logger, "Stopping background task timed out: {e}"); + break; + }, + } + } + }) + } + + fn handle(&self) -> &tokio::runtime::Handle { + match &self.mode { + RuntimeMode::Owned(rt) => rt.handle(), + RuntimeMode::Handle(handle) => handle, + } + } +} + +enum RuntimeMode { + Owned(tokio::runtime::Runtime), + Handle(tokio::runtime::Handle), +} diff --git a/orange-sdk/src/store.rs b/orange-sdk/src/store.rs index a305cab..2908e25 100644 --- a/orange-sdk/src/store.rs +++ b/orange-sdk/src/store.rs @@ -13,14 +13,16 @@ use bitcoin_payment_instructions::amount::Amount; +use ldk_node::DynStore; use ldk_node::bitcoin::Txid; use ldk_node::bitcoin::hex::{DisplayHex, FromHex}; use 
ldk_node::lightning::io; use ldk_node::lightning::ln::msgs::DecodeError; use ldk_node::lightning::types::payment::PaymentPreimage; -use ldk_node::lightning::util::persist::KVStore; +use ldk_node::lightning::util::persist::{KVStore, KVStoreSync}; use ldk_node::lightning::util::ser::{Readable, Writeable, Writer}; use ldk_node::lightning::{impl_writeable_tlv_based, impl_writeable_tlv_based_enum}; +use ldk_node::payment::PaymentDetails; use std::collections::HashMap; use std::fmt; @@ -30,6 +32,7 @@ use std::time::Duration; const STORE_PRIMARY_KEY: &str = "orange_sdk"; const STORE_SECONDARY_KEY: &str = "payment_store"; +const SPLICE_OUT_SECONDARY_KEY: &str = "splice_out"; /// The status of a transaction. This is used to track the state of a transaction #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -292,19 +295,20 @@ impl_writeable_tlv_based!(TxMetadata, { (0, ty, required), (2, time, required) } #[derive(Clone)] pub(crate) struct TxMetadataStore { tx_metadata: Arc>>, - store: Arc, + store: Arc, } impl TxMetadataStore { - pub fn new(store: Arc) -> TxMetadataStore { - let keys = store - .list(STORE_PRIMARY_KEY, STORE_SECONDARY_KEY) + pub async fn new(store: Arc) -> TxMetadataStore { + let keys = KVStore::list(store.as_ref(), STORE_PRIMARY_KEY, STORE_SECONDARY_KEY) + .await .expect("We do not allow reads to fail"); let mut tx_metadata = HashMap::with_capacity(keys.len()); for key in keys { - let data_bytes = store - .read(STORE_PRIMARY_KEY, STORE_SECONDARY_KEY, &key) - .expect("We do not allow reads to fail"); + let data_bytes = + KVStore::read(store.as_ref(), STORE_PRIMARY_KEY, STORE_SECONDARY_KEY, &key) + .await + .expect("We do not allow reads to fail"); let key = PaymentId::from_str(&key).expect("Invalid key in transaction metadata storage"); let data = Readable::read(&mut &data_bytes[..]) @@ -323,9 +327,14 @@ impl TxMetadataStore { let key_str = key.to_string(); let ser = value.encode(); let old = tx_metadata.insert(key, value); - self.store - 
.write(STORE_PRIMARY_KEY, STORE_SECONDARY_KEY, &key_str, &ser) - .expect("We do not allow writes to fail"); + KVStoreSync::write( + self.store.as_ref(), + STORE_PRIMARY_KEY, + STORE_SECONDARY_KEY, + &key_str, + ser, + ) + .expect("We do not allow writes to fail"); old.is_some() } @@ -345,9 +354,14 @@ impl TxMetadataStore { metadata.ty = TxType::PaymentTriggeringTransferLightning { ty: *ty }; let key_str = payment_id.to_string(); let ser = metadata.encode(); - self.store - .write(STORE_PRIMARY_KEY, STORE_SECONDARY_KEY, &key_str, &ser) - .expect("We do not allow writes to fail"); + KVStoreSync::write( + self.store.as_ref(), + STORE_PRIMARY_KEY, + STORE_SECONDARY_KEY, + &key_str, + ser, + ) + .expect("We do not allow writes to fail"); Ok(()) } else { eprintln!("payment_id {payment_id} is not a payment, cannot set rebalance"); @@ -377,14 +391,14 @@ impl TxMetadataStore { }, }; - self.store - .write( - STORE_PRIMARY_KEY, - STORE_SECONDARY_KEY, - &payment_id.to_string(), - &metadata.encode(), - ) - .expect("We do not allow writes to fail"); + KVStoreSync::write( + self.store.as_ref(), + STORE_PRIMARY_KEY, + STORE_SECONDARY_KEY, + &payment_id.to_string(), + metadata.encode(), + ) + .expect("We do not allow writes to fail"); Ok(()) } }, @@ -398,14 +412,14 @@ impl TxMetadataStore { }, }; - self.store - .write( - STORE_PRIMARY_KEY, - STORE_SECONDARY_KEY, - &payment_id.to_string(), - &metadata.encode(), - ) - .expect("We do not allow writes to fail"); + KVStoreSync::write( + self.store.as_ref(), + STORE_PRIMARY_KEY, + STORE_SECONDARY_KEY, + &payment_id.to_string(), + metadata.encode(), + ) + .expect("We do not allow writes to fail"); Ok(()) } }, @@ -431,8 +445,8 @@ impl TxMetadataStore { const REBALANCE_ENABLED_KEY: &str = "rebalance_enabled"; -pub(crate) fn get_rebalance_enabled(store: &dyn KVStore) -> bool { - match store.read(STORE_PRIMARY_KEY, "", REBALANCE_ENABLED_KEY) { +pub(crate) fn get_rebalance_enabled(store: &DynStore) -> bool { + match KVStoreSync::read(store, 
STORE_PRIMARY_KEY, "", REBALANCE_ENABLED_KEY) { Ok(bytes) => Readable::read(&mut &bytes[..]).expect("Invalid data in rebalance_enabled"), Err(e) if e.kind() == io::ErrorKind::NotFound => { // if rebalance_enabled is not found, default to true @@ -447,13 +461,38 @@ pub(crate) fn get_rebalance_enabled(store: &dyn KVStore) -> bool { } } -pub(crate) fn set_rebalance_enabled(store: &dyn KVStore, enabled: bool) { +pub(crate) fn set_rebalance_enabled(store: &DynStore, enabled: bool) { let bytes = enabled.encode(); - store - .write(STORE_PRIMARY_KEY, "", REBALANCE_ENABLED_KEY, &bytes) + KVStoreSync::write(store, STORE_PRIMARY_KEY, "", REBALANCE_ENABLED_KEY, bytes) .expect("Failed to write rebalance_enabled"); } +pub(crate) fn write_splice_out(store: &DynStore, details: &PaymentDetails) { + KVStoreSync::write( + store, + STORE_PRIMARY_KEY, + SPLICE_OUT_SECONDARY_KEY, + &details.id.0.to_lower_hex_string(), + details.encode(), + ) + .expect("Failed to write splice out txid"); +} + +pub(crate) fn read_splice_outs(store: &DynStore) -> Vec { + let keys = KVStoreSync::list(store, STORE_PRIMARY_KEY, SPLICE_OUT_SECONDARY_KEY) + .expect("We do not allow reads to fail"); + let mut splice_outs = Vec::with_capacity(keys.len()); + for key in keys { + let data_bytes = + KVStoreSync::read(store, STORE_PRIMARY_KEY, SPLICE_OUT_SECONDARY_KEY, &key) + .expect("We do not allow reads to fail"); + let data = + Readable::read(&mut &data_bytes[..]).expect("Invalid data in splice out storage"); + splice_outs.push(data); + } + splice_outs +} + #[cfg(test)] mod tests { use super::*; diff --git a/orange-sdk/src/trusted_wallet/cashu/cashu_store.rs b/orange-sdk/src/trusted_wallet/cashu/cashu_store.rs index 170b284..50f3407 100644 --- a/orange-sdk/src/trusted_wallet/cashu/cashu_store.rs +++ b/orange-sdk/src/trusted_wallet/cashu/cashu_store.rs @@ -5,9 +5,11 @@ use std::sync::{Arc, RwLock}; use async_trait::async_trait; use cdk::cdk_database::WalletDatabase; +use ldk_node::DynStore; use 
ldk_node::lightning::io; use ldk_node::lightning::util::persist::KVStore; +use crate::trusted_wallet::TrustedError; use cdk::mint_url::MintUrl; use cdk::nuts::{ CurrencyUnit, Id, KeySet, KeySetInfo, Keys, MintInfo, PublicKey, SpendingConditions, State, @@ -21,8 +23,6 @@ use cdk::wallet::{ use std::collections::hash_map::DefaultHasher; use std::hash::{Hash, Hasher}; -use crate::trusted_wallet::TrustedError; - // Constants for organizing data in the KV store const CASHU_PRIMARY_KEY: &str = "cashu_wallet"; @@ -110,7 +110,7 @@ impl From for cdk::cdk_database::Error { /// A KV store-based implementation of the Cashu WalletDatabase trait pub struct CashuKvDatabase { - store: Arc, + store: Arc, // In-memory caches for frequently accessed data mints_cache: Arc>>>, proofs_cache: Arc>>, @@ -140,7 +140,7 @@ impl CashuKvDatabase { /// /// Returns a `Result` containing the initialized database or a `DatabaseError` if /// initialization fails. - pub fn new(store: Arc) -> Result { + pub async fn new(store: Arc) -> Result { let database = Self { store, mints_cache: Arc::new(RwLock::new(HashMap::new())), @@ -148,20 +148,20 @@ impl CashuKvDatabase { }; // Initialize caches from persistent storage - database.load_caches()?; + database.load_caches().await?; Ok(database) } - fn load_caches(&self) -> Result<(), DatabaseError> { + async fn load_caches(&self) -> Result<(), DatabaseError> { // Load mints cache - if let Ok(mints) = self.load_mints_from_store() { + if let Ok(mints) = self.load_mints_from_store().await { let mut cache = self.mints_cache.write().unwrap(); *cache = mints; } // Load proofs cache - if let Ok(proofs) = self.load_proofs_from_store() { + if let Ok(proofs) = self.load_proofs_from_store().await { let mut cache = self.proofs_cache.write().unwrap(); *cache = proofs; } @@ -169,20 +169,25 @@ impl CashuKvDatabase { Ok(()) } - fn load_mints_from_store(&self) -> Result>, DatabaseError> { - let keys = self.store.list(CASHU_PRIMARY_KEY, MINTS_KEY).map_err(DatabaseError::Io)?; 
+ async fn load_mints_from_store( + &self, + ) -> Result>, DatabaseError> { + let keys = KVStore::list(self.store.as_ref(), CASHU_PRIMARY_KEY, MINTS_KEY) + .await + .map_err(DatabaseError::Io)?; - let mut mints = HashMap::new(); + let mut mints = HashMap::with_capacity(keys.len()); for key in keys { - let data = - self.store.read(CASHU_PRIMARY_KEY, MINTS_KEY, &key).map_err(DatabaseError::Io)?; + let data = KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, MINTS_KEY, &key) + .await + .map_err(DatabaseError::Io)?; if !data.is_empty() { let mint_url: MintUrl = serde_json::from_slice(&data) .map_err(|e| DatabaseError::Serialization(e.to_string()))?; // Try to load mint info - let mint_info = self.load_mint_info(&mint_url).ok().flatten(); + let mint_info = self.load_mint_info(&mint_url).await.ok().flatten(); mints.insert(mint_url, mint_info); } } @@ -190,13 +195,16 @@ impl CashuKvDatabase { Ok(mints) } - fn load_proofs_from_store(&self) -> Result, DatabaseError> { - let keys = self.store.list(CASHU_PRIMARY_KEY, PROOFS_KEY).map_err(DatabaseError::Io)?; + async fn load_proofs_from_store(&self) -> Result, DatabaseError> { + let keys = KVStore::list(self.store.as_ref(), CASHU_PRIMARY_KEY, PROOFS_KEY) + .await + .map_err(DatabaseError::Io)?; - let mut proofs = Vec::new(); + let mut proofs = Vec::with_capacity(keys.len()); for key in keys { - let data = - self.store.read(CASHU_PRIMARY_KEY, PROOFS_KEY, &key).map_err(DatabaseError::Io)?; + let data = KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, PROOFS_KEY, &key) + .await + .map_err(DatabaseError::Io)?; if !data.is_empty() { let proof: ProofInfo = serde_json::from_slice(&data) @@ -209,9 +217,9 @@ impl CashuKvDatabase { Ok(proofs) } - fn load_mint_info(&self, mint_url: &MintUrl) -> Result, DatabaseError> { + async fn load_mint_info(&self, mint_url: &MintUrl) -> Result, DatabaseError> { let key = Self::generate_mint_info_key(mint_url); - match self.store.read(CASHU_PRIMARY_KEY, MINTS_KEY, &key) { + match 
KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, MINTS_KEY, &key).await { Ok(data) => { if data.is_empty() { return Ok(None); @@ -225,14 +233,16 @@ impl CashuKvDatabase { } } - fn save_mint_info( + async fn save_mint_info( &self, mint_url: &MintUrl, mint_info: &MintInfo, ) -> Result<(), DatabaseError> { let key = Self::generate_mint_info_key(mint_url); let data = serde_json::to_vec(mint_info) .map_err(|e| DatabaseError::Serialization(e.to_string()))?; - self.store.write(CASHU_PRIMARY_KEY, MINTS_KEY, &key, &data).map_err(DatabaseError::Io) + KVStore::write(self.store.as_ref(), CASHU_PRIMARY_KEY, MINTS_KEY, &key, data) + .await + .map_err(DatabaseError::Io) } fn generate_proof_key(proof: &ProofInfo) -> String { @@ -277,13 +287,13 @@ impl WalletDatabase for CashuKvDatabase { let mint_data = serde_json::to_vec(&mint_url) .map_err(|e| DatabaseError::Serialization(e.to_string()))?; - self.store - .write(CASHU_PRIMARY_KEY, MINTS_KEY, &mint_key, &mint_data) + KVStore::write(self.store.as_ref(), CASHU_PRIMARY_KEY, MINTS_KEY, &mint_key, mint_data) + .await .map_err(DatabaseError::Io)?; // Save mint info if provided if let Some(info) = &mint_info { - self.save_mint_info(&mint_url, info)?; + self.save_mint_info(&mint_url, info).await?; } // Update cache @@ -299,21 +309,27 @@ impl WalletDatabase for CashuKvDatabase { let mint_key = Self::generate_mint_key(&mint_url); // Remove mint URL by writing empty data - self.store - .write(CASHU_PRIMARY_KEY, MINTS_KEY, &mint_key, &[]) + KVStore::remove(self.store.as_ref(), CASHU_PRIMARY_KEY, MINTS_KEY, &mint_key, false) + .await .map_err(DatabaseError::Io)?; // Remove mint info let info_key = Self::generate_mint_info_key(&mint_url); - self.store - .write(CASHU_PRIMARY_KEY, MINTS_KEY, &info_key, &[]) + KVStore::remove(self.store.as_ref(), CASHU_PRIMARY_KEY, MINTS_KEY, &info_key, false) + .await .map_err(DatabaseError::Io)?; // Remove mint keysets let keysets_key = Self::generate_mint_keysets_key(&mint_url); - self.store - 
.write(CASHU_PRIMARY_KEY, MINT_KEYSETS_KEY, &keysets_key, &[]) - .map_err(DatabaseError::Io)?; + KVStore::remove( + self.store.as_ref(), + CASHU_PRIMARY_KEY, + MINT_KEYSETS_KEY, + &keysets_key, + false, + ) + .await + .map_err(DatabaseError::Io)?; // Update cache { @@ -334,7 +350,7 @@ impl WalletDatabase for CashuKvDatabase { } // Load from storage - self.load_mint_info(&mint_url).map_err(Into::into) + self.load_mint_info(&mint_url).await.map_err(Into::into) } async fn get_mints(&self) -> Result>, Self::Err> { @@ -374,19 +390,32 @@ impl WalletDatabase for CashuKvDatabase { for keyset in keysets { // Check if keyset already exists in individual keysets table let keyset_key = format!("keyset_{}", keyset.id); - let existing_keyset = - match self.store.read(CASHU_PRIMARY_KEY, KEYSETS_TABLE_KEY, &keyset_key) { - Ok(data) if !data.is_empty() => { - let existing: KeySetInfo = serde_json::from_slice(&data) - .map_err(|e| DatabaseError::Serialization(e.to_string()))?; - Some(existing) - }, - _ => None, - }; + let existing_keyset = match KVStore::read( + self.store.as_ref(), + CASHU_PRIMARY_KEY, + KEYSETS_TABLE_KEY, + &keyset_key, + ) + .await + { + Ok(data) if !data.is_empty() => { + let existing: KeySetInfo = serde_json::from_slice(&data) + .map_err(|e| DatabaseError::Serialization(e.to_string()))?; + Some(existing) + }, + _ => None, + }; // Check u32 mapping for conflicts let u32_key = format!("u32_{}", u32::from(keyset.id)); - match self.store.read(CASHU_PRIMARY_KEY, KEYSET_U32_MAPPING_KEY, &u32_key) { + match KVStore::read( + self.store.as_ref(), + CASHU_PRIMARY_KEY, + KEYSET_U32_MAPPING_KEY, + &u32_key, + ) + .await + { Ok(data) if !data.is_empty() => { let existing_id_str = String::from_utf8(data.to_vec()) .map_err(|e| DatabaseError::Serialization(e.to_string()))?; @@ -401,9 +430,15 @@ impl WalletDatabase for CashuKvDatabase { _ => { // No existing mapping, create one let id_data = keyset.id.to_string().as_bytes().to_vec(); - self.store - .write(CASHU_PRIMARY_KEY, 
KEYSET_U32_MAPPING_KEY, &u32_key, &id_data) - .map_err(DatabaseError::Io)?; + KVStore::write( + self.store.as_ref(), + CASHU_PRIMARY_KEY, + KEYSET_U32_MAPPING_KEY, + &u32_key, + id_data, + ) + .await + .map_err(DatabaseError::Io)?; }, } @@ -420,9 +455,15 @@ impl WalletDatabase for CashuKvDatabase { // Store individual keyset let keyset_data = serde_json::to_vec(&final_keyset) .map_err(|e| DatabaseError::Serialization(e.to_string()))?; - self.store - .write(CASHU_PRIMARY_KEY, KEYSETS_TABLE_KEY, &keyset_key, &keyset_data) - .map_err(DatabaseError::Io)?; + KVStore::write( + self.store.as_ref(), + CASHU_PRIMARY_KEY, + KEYSETS_TABLE_KEY, + &keyset_key, + keyset_data, + ) + .await + .map_err(DatabaseError::Io)?; updated_keysets.push(final_keyset); } @@ -453,8 +494,8 @@ impl WalletDatabase for CashuKvDatabase { let data = serde_json::to_vec(&all_mint_keysets) .map_err(|e| DatabaseError::Serialization(e.to_string()))?; - self.store - .write(CASHU_PRIMARY_KEY, MINT_KEYSETS_KEY, &key, &data) + KVStore::write(self.store.as_ref(), CASHU_PRIMARY_KEY, MINT_KEYSETS_KEY, &key, data) + .await .map_err(DatabaseError::Io)?; Ok(()) @@ -465,7 +506,7 @@ impl WalletDatabase for CashuKvDatabase { ) -> Result>, Self::Err> { let key = Self::generate_mint_keysets_key(&mint_url); - match self.store.read(CASHU_PRIMARY_KEY, MINT_KEYSETS_KEY, &key) { + match KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, MINT_KEYSETS_KEY, &key).await { Ok(data) => { if data.is_empty() { return Ok(None); @@ -482,7 +523,7 @@ impl WalletDatabase for CashuKvDatabase { async fn get_keyset_by_id(&self, keyset_id: &Id) -> Result, Self::Err> { // Read directly from the dedicated KEYSETS_TABLE keyed by the keyset ID for efficiency let key = format!("keyset_{}", keyset_id); - match self.store.read(CASHU_PRIMARY_KEY, KEYSETS_TABLE_KEY, &key) { + match KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, KEYSETS_TABLE_KEY, &key).await { Ok(data) if !data.is_empty() => { let keyset: KeySetInfo = 
serde_json::from_slice(&data) .map_err(|e| DatabaseError::Serialization(e.to_string()))?; @@ -499,15 +540,16 @@ impl WalletDatabase for CashuKvDatabase { let data = serde_json::to_vec("e).map_err(|e| DatabaseError::Serialization(e.to_string()))?; - self.store - .write(CASHU_PRIMARY_KEY, MINT_QUOTES_KEY, &key, &data) + KVStore::write(self.store.as_ref(), CASHU_PRIMARY_KEY, MINT_QUOTES_KEY, &key, data) + .await .map_err(DatabaseError::Io)?; Ok(()) } async fn get_mint_quote(&self, quote_id: &str) -> Result, Self::Err> { - match self.store.read(CASHU_PRIMARY_KEY, MINT_QUOTES_KEY, quote_id) { + match KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, MINT_QUOTES_KEY, quote_id).await + { Ok(data) => { if data.is_empty() { return Ok(None); @@ -522,14 +564,14 @@ impl WalletDatabase for CashuKvDatabase { } async fn get_mint_quotes(&self) -> Result, Self::Err> { - let keys = - self.store.list(CASHU_PRIMARY_KEY, MINT_QUOTES_KEY).map_err(DatabaseError::Io)?; + let keys = KVStore::list(self.store.as_ref(), CASHU_PRIMARY_KEY, MINT_QUOTES_KEY) + .await + .map_err(DatabaseError::Io)?; - let mut quotes = Vec::new(); + let mut quotes = Vec::with_capacity(keys.len()); for key in keys { - let data = self - .store - .read(CASHU_PRIMARY_KEY, MINT_QUOTES_KEY, &key) + let data = KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, MINT_QUOTES_KEY, &key) + .await .map_err(DatabaseError::Io)?; if !data.is_empty() { @@ -544,8 +586,8 @@ impl WalletDatabase for CashuKvDatabase { async fn remove_mint_quote(&self, quote_id: &str) -> Result<(), Self::Err> { // Mark as removed by writing empty data - self.store - .write(CASHU_PRIMARY_KEY, MINT_QUOTES_KEY, quote_id, &[]) + KVStore::remove(self.store.as_ref(), CASHU_PRIMARY_KEY, MINT_QUOTES_KEY, quote_id, false) + .await .map_err(DatabaseError::Io)?; Ok(()) @@ -556,15 +598,16 @@ impl WalletDatabase for CashuKvDatabase { let data = serde_json::to_vec("e).map_err(|e| DatabaseError::Serialization(e.to_string()))?; - self.store - 
.write(CASHU_PRIMARY_KEY, MELT_QUOTES_KEY, &key, &data) + KVStore::write(self.store.as_ref(), CASHU_PRIMARY_KEY, MELT_QUOTES_KEY, &key, data) + .await .map_err(DatabaseError::Io)?; Ok(()) } async fn get_melt_quote(&self, quote_id: &str) -> Result, Self::Err> { - match self.store.read(CASHU_PRIMARY_KEY, MELT_QUOTES_KEY, quote_id) { + match KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, MELT_QUOTES_KEY, quote_id).await + { Ok(data) => { if data.is_empty() { return Ok(None); @@ -579,14 +622,14 @@ impl WalletDatabase for CashuKvDatabase { } async fn get_melt_quotes(&self) -> Result, Self::Err> { - let keys = - self.store.list(CASHU_PRIMARY_KEY, MELT_QUOTES_KEY).map_err(DatabaseError::Io)?; + let keys = KVStore::list(self.store.as_ref(), CASHU_PRIMARY_KEY, MELT_QUOTES_KEY) + .await + .map_err(DatabaseError::Io)?; - let mut quotes = Vec::new(); + let mut quotes = Vec::with_capacity(keys.len()); for key in keys { - let data = self - .store - .read(CASHU_PRIMARY_KEY, MELT_QUOTES_KEY, &key) + let data = KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, MELT_QUOTES_KEY, &key) + .await .map_err(DatabaseError::Io)?; if !data.is_empty() { @@ -600,8 +643,8 @@ impl WalletDatabase for CashuKvDatabase { } async fn remove_melt_quote(&self, quote_id: &str) -> Result<(), Self::Err> { - self.store - .write(CASHU_PRIMARY_KEY, MELT_QUOTES_KEY, quote_id, &[]) + KVStore::remove(self.store.as_ref(), CASHU_PRIMARY_KEY, MELT_QUOTES_KEY, quote_id, false) + .await .map_err(DatabaseError::Io)?; Ok(()) @@ -618,7 +661,9 @@ impl WalletDatabase for CashuKvDatabase { let data = serde_json::to_vec(&keyset).map_err(|e| DatabaseError::Serialization(e.to_string()))?; - self.store.write(CASHU_PRIMARY_KEY, KEYS_KEY, &key, &data).map_err(DatabaseError::Io)?; + KVStore::write(self.store.as_ref(), CASHU_PRIMARY_KEY, KEYS_KEY, &key, data) + .await + .map_err(DatabaseError::Io)?; Ok(()) } @@ -626,7 +671,7 @@ impl WalletDatabase for CashuKvDatabase { async fn get_keys(&self, id: &Id) -> Result, 
Self::Err> { let key = id.to_string(); - match self.store.read(CASHU_PRIMARY_KEY, KEYS_KEY, &key) { + match KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, KEYS_KEY, &key).await { Ok(data) => { if data.is_empty() { return Ok(None); @@ -643,7 +688,9 @@ impl WalletDatabase for CashuKvDatabase { async fn remove_keys(&self, id: &Id) -> Result<(), Self::Err> { let key = id.to_string(); - self.store.write(CASHU_PRIMARY_KEY, KEYS_KEY, &key, &[]).map_err(DatabaseError::Io)?; + KVStore::remove(self.store.as_ref(), CASHU_PRIMARY_KEY, KEYS_KEY, &key, false) + .await + .map_err(DatabaseError::Io)?; Ok(()) } @@ -657,8 +704,8 @@ impl WalletDatabase for CashuKvDatabase { let data = serde_json::to_vec(proof) .map_err(|e| DatabaseError::Serialization(e.to_string()))?; - self.store - .write(CASHU_PRIMARY_KEY, PROOFS_KEY, &key, &data) + KVStore::write(self.store.as_ref(), CASHU_PRIMARY_KEY, PROOFS_KEY, &key, data) + .await .map_err(DatabaseError::Io)?; } @@ -666,8 +713,8 @@ impl WalletDatabase for CashuKvDatabase { for y in &removed_ys { let key = format!("proof_{}", hex::encode(y.serialize())); - self.store - .write(CASHU_PRIMARY_KEY, PROOFS_KEY, &key, &[]) + KVStore::remove(self.store.as_ref(), CASHU_PRIMARY_KEY, PROOFS_KEY, &key, false) + .await .map_err(DatabaseError::Io)?; } @@ -718,13 +765,20 @@ impl WalletDatabase for CashuKvDatabase { Ok(filtered_proofs) } + async fn get_balance( + &self, mint_url: Option, unit: Option, state: Option>, + ) -> Result { + let proofs = self.get_proofs(mint_url, unit, state, None).await?; + Ok(proofs.iter().map(|p| u64::from(p.proof.amount)).sum()) + } + async fn update_proofs_state(&self, ys: Vec, state: State) -> Result<(), Self::Err> { // Update proofs in storage and cache for y in &ys { let key = format!("proof_{}", hex::encode(y.serialize())); // Read existing proof - match self.store.read(CASHU_PRIMARY_KEY, PROOFS_KEY, &key) { + match KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, PROOFS_KEY, &key).await { Ok(data) if 
!data.is_empty() => { let mut proof: ProofInfo = serde_json::from_slice(&data) .map_err(|e| DatabaseError::Serialization(e.to_string()))?; @@ -736,9 +790,15 @@ impl WalletDatabase for CashuKvDatabase { let updated_data = serde_json::to_vec(&proof) .map_err(|e| DatabaseError::Serialization(e.to_string()))?; - self.store - .write(CASHU_PRIMARY_KEY, PROOFS_KEY, &key, &updated_data) - .map_err(DatabaseError::Io)?; + KVStore::write( + self.store.as_ref(), + CASHU_PRIMARY_KEY, + PROOFS_KEY, + &key, + updated_data, + ) + .await + .map_err(DatabaseError::Io)?; }, _ => continue, // Proof not found, skip } @@ -761,11 +821,14 @@ impl WalletDatabase for CashuKvDatabase { let key = keyset_id.to_string(); // Read current counter - let current_count = match self.store.read(CASHU_PRIMARY_KEY, KEYSET_COUNTERS_KEY, &key) { - Ok(data) if !data.is_empty() => serde_json::from_slice::(&data) - .map_err(|e| DatabaseError::Serialization(e.to_string()))?, - _ => 0, // Default to 0 if not found - }; + let current_count = + match KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, KEYSET_COUNTERS_KEY, &key) + .await + { + Ok(data) if !data.is_empty() => serde_json::from_slice::(&data) + .map_err(|e| DatabaseError::Serialization(e.to_string()))?, + _ => 0, // Default to 0 if not found + }; let new_count = current_count + count; @@ -773,8 +836,8 @@ impl WalletDatabase for CashuKvDatabase { let data = serde_json::to_vec(&new_count) .map_err(|e| DatabaseError::Serialization(e.to_string()))?; - self.store - .write(CASHU_PRIMARY_KEY, KEYSET_COUNTERS_KEY, &key, &data) + KVStore::write(self.store.as_ref(), CASHU_PRIMARY_KEY, KEYSET_COUNTERS_KEY, &key, data) + .await .map_err(DatabaseError::Io)?; Ok(new_count) @@ -785,8 +848,8 @@ impl WalletDatabase for CashuKvDatabase { let data = serde_json::to_vec(&transaction) .map_err(|e| DatabaseError::Serialization(e.to_string()))?; - self.store - .write(CASHU_PRIMARY_KEY, TRANSACTIONS_KEY, &key, &data) + KVStore::write(self.store.as_ref(), 
CASHU_PRIMARY_KEY, TRANSACTIONS_KEY, &key, data) + .await .map_err(DatabaseError::Io)?; Ok(()) @@ -797,7 +860,7 @@ impl WalletDatabase for CashuKvDatabase { ) -> Result, Self::Err> { let key = transaction_id.to_string(); - match self.store.read(CASHU_PRIMARY_KEY, TRANSACTIONS_KEY, &key) { + match KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, TRANSACTIONS_KEY, &key).await { Ok(data) => { if data.is_empty() { return Ok(None); @@ -815,15 +878,16 @@ impl WalletDatabase for CashuKvDatabase { &self, mint_url: Option, direction: Option, unit: Option, ) -> Result, Self::Err> { - let keys = - self.store.list(CASHU_PRIMARY_KEY, TRANSACTIONS_KEY).map_err(DatabaseError::Io)?; + let keys = KVStore::list(self.store.as_ref(), CASHU_PRIMARY_KEY, TRANSACTIONS_KEY) + .await + .map_err(DatabaseError::Io)?; - let mut transactions = Vec::new(); + let mut transactions = Vec::with_capacity(keys.len()); for key in keys { - let data = self - .store - .read(CASHU_PRIMARY_KEY, TRANSACTIONS_KEY, &key) - .map_err(DatabaseError::Io)?; + let data = + KVStore::read(self.store.as_ref(), CASHU_PRIMARY_KEY, TRANSACTIONS_KEY, &key) + .await + .map_err(DatabaseError::Io)?; if !data.is_empty() { let transaction: Transaction = serde_json::from_slice(&data) @@ -862,18 +926,16 @@ impl WalletDatabase for CashuKvDatabase { async fn remove_transaction(&self, transaction_id: TransactionId) -> Result<(), Self::Err> { let key = transaction_id.to_string(); - self.store - .write(CASHU_PRIMARY_KEY, TRANSACTIONS_KEY, &key, &[]) + KVStore::remove(self.store.as_ref(), CASHU_PRIMARY_KEY, TRANSACTIONS_KEY, &key, false) + .await .map_err(DatabaseError::Io)?; Ok(()) } } -pub(super) fn read_has_recovered( - store: &Arc, -) -> Result { - match store.read(CASHU_PRIMARY_KEY, "", HAS_RECOVERED_KEY) { +pub(super) async fn read_has_recovered(store: &Arc) -> Result { + match KVStore::read(store.as_ref(), CASHU_PRIMARY_KEY, "", HAS_RECOVERED_KEY).await { Ok(data) => { if data.is_empty() { return Ok(false); @@ -885,10 
+947,12 @@ pub(super) fn read_has_recovered( } } -pub(super) fn write_has_recovered( - store: &Arc, has_recovered: bool, +pub(super) async fn write_has_recovered( + store: &Arc, has_recovered: bool, ) -> Result<(), TrustedError> { let data = vec![if has_recovered { 1 } else { 0 }]; - store.write(CASHU_PRIMARY_KEY, "", HAS_RECOVERED_KEY, &data).map_err(TrustedError::IOError) + KVStore::write(store.as_ref(), CASHU_PRIMARY_KEY, "", HAS_RECOVERED_KEY, data) + .await + .map_err(TrustedError::IOError) } diff --git a/orange-sdk/src/trusted_wallet/cashu/mod.rs b/orange-sdk/src/trusted_wallet/cashu/mod.rs index d507a15..091e766 100644 --- a/orange-sdk/src/trusted_wallet/cashu/mod.rs +++ b/orange-sdk/src/trusted_wallet/cashu/mod.rs @@ -2,15 +2,16 @@ use crate::bitcoin::hex::DisplayHex; use crate::logging::Logger; +use crate::runtime::Runtime; use crate::store::{PaymentId, TxMetadataStore, TxStatus}; use crate::trusted_wallet::{Payment, TrustedError, TrustedWalletInterface}; use crate::{Event, EventQueue, InitFailure, Seed, WalletConfig}; +use ldk_node::DynStore; use ldk_node::bitcoin::hashes::Hash; use ldk_node::bitcoin::hashes::sha256::Hash as Sha256; use ldk_node::bitcoin::hex::FromHex; use ldk_node::lightning::util::logger::Logger as _; -use ldk_node::lightning::util::persist::KVStore; use ldk_node::lightning::{log_error, log_info}; use ldk_node::lightning_invoice::Bolt11Invoice; use ldk_node::lightning_types::payment::{PaymentHash, PaymentPreimage}; @@ -37,7 +38,6 @@ use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; -use tokio::runtime::Runtime; /// Cashu KV store implementation pub mod cashu_store; @@ -298,7 +298,7 @@ impl TrustedWalletInterface for Cashu { let tx_metadata = self.tx_metadata.clone(); let quote_id = quote.id.clone(); let payment_success_sender = self.payment_success_sender.clone(); - self.runtime.spawn(async move { + self.runtime.spawn_background_task(async move { let mut metadata = HashMap::new(); if let Some(hash) = 
&payment_hash { metadata.insert(PAYMENT_HASH_METADATA_KEY.to_string(), hash.to_string()); @@ -377,12 +377,14 @@ impl TrustedWalletInterface for Cashu { } let fee_paid_sat: u64 = res.fee_paid.into(); - let _ = event_queue.add_event(Event::PaymentSuccessful { - payment_id, - payment_hash: hash, - payment_preimage, - fee_paid_msat: Some(fee_paid_sat * 1_000), // convert to msats - }); + let _ = event_queue + .add_event(Event::PaymentSuccessful { + payment_id, + payment_hash: hash, + payment_preimage, + fee_paid_msat: Some(fee_paid_sat * 1_000), // convert to msats + }) + .await; payment_success_sender.send(()).unwrap(); }, @@ -395,11 +397,13 @@ impl TrustedWalletInterface for Cashu { }; if !is_rebalance { - let _ = event_queue.add_event(Event::PaymentFailed { - payment_id, - payment_hash, - reason: None, - }); + let _ = event_queue + .add_event(Event::PaymentFailed { + payment_id, + payment_hash, + reason: None, + }) + .await; } }, state => { @@ -420,11 +424,13 @@ impl TrustedWalletInterface for Cashu { }; if !is_rebalance { - let _ = event_queue.add_event(Event::PaymentFailed { - payment_id, - payment_hash, - reason: None, - }); + let _ = event_queue + .add_event(Event::PaymentFailed { + payment_id, + payment_hash, + reason: None, + }) + .await; } }, } @@ -475,7 +481,7 @@ const PAYMENT_HASH_METADATA_KEY: &str = "payment_hash"; impl Cashu { pub(crate) async fn init( - config: &WalletConfig, cashu_config: CashuConfig, store: Arc, + config: &WalletConfig, cashu_config: CashuConfig, store: Arc, event_queue: Arc, tx_metadata: TxMetadataStore, logger: Arc, runtime: Arc, ) -> Result { @@ -505,7 +511,7 @@ impl Cashu { }, }; - let db = Arc::new(CashuKvDatabase::new(Arc::clone(&store)).map_err(|e| { + let db = Arc::new(CashuKvDatabase::new(Arc::clone(&store)).await.map_err(|e| { InitFailure::TrustedFailure(TrustedError::Other(format!( "Failed to create Cashu database: {e}" ))) @@ -542,7 +548,7 @@ impl Cashu { let logger_for_monitoring = Arc::clone(&logger); let 
eq_for_monitoring = Arc::clone(&event_queue); let rt_for_monitoring = Arc::clone(&runtime); - runtime.spawn(async move { + runtime.spawn_cancellable_background_task(async move { loop { tokio::select! { _ = shutdown_receiver.changed() => { @@ -556,7 +562,7 @@ impl Cashu { let wallet = Arc::clone(&wallet_for_monitoring); let event_queue = Arc::clone(&eq_for_monitoring); let logger = Arc::clone(&logger_for_monitoring); - rt_for_monitoring.spawn(async move { + rt_for_monitoring.spawn_cancellable_background_task(async move { if let Err(e) = Self::monitor_mint_quote(wallet, event_queue, &logger, mint_quote).await { log_error!(logger, "Failed to monitor mint quote: {e:?}"); } @@ -579,18 +585,18 @@ impl Cashu { } // spawn background task to recover funds if first time initializing - let has_recovered = read_has_recovered(&store)?; + let has_recovered = read_has_recovered(&store).await?; if !has_recovered { let w = Arc::clone(&cashu_wallet); let l = Arc::clone(&logger); - runtime.spawn(async move { + runtime.spawn_background_task(async move { match w.restore().await { Err(e) => log_error!(l, "Failed to restore cashu mint: {e}"), Ok(amt) => { if amt > cdk::Amount::ZERO { log_info!(l, "Restored cashu mint: {}, amt: {amt}", w.mint_url); } - if let Err(e) = write_has_recovered(&store, true) { + if let Err(e) = write_has_recovered(&store, true).await { log_error!(l, "Failed to write has_recovered flag: {e:?}"); } }, @@ -688,6 +694,7 @@ impl Cashu { custom_records: vec![], lsp_fee_msats: None, }) + .await .map_err(|e| TrustedError::Other(format!("Failed to add event: {e}")))?; log_info!(logger, "Sent PaymentReceived event for mint quote: {}", mint_quote.id); diff --git a/orange-sdk/src/trusted_wallet/dummy.rs b/orange-sdk/src/trusted_wallet/dummy.rs index 9d080be..0dcd527 100644 --- a/orange-sdk/src/trusted_wallet/dummy.rs +++ b/orange-sdk/src/trusted_wallet/dummy.rs @@ -2,6 +2,7 @@ use crate::EventQueue; use crate::bitcoin::hashes::Hash; +use crate::runtime::Runtime; use 
crate::store::{PaymentId, TxMetadataStore, TxStatus}; use crate::trusted_wallet::{Payment, TrustedError, TrustedWalletInterface}; use bitcoin_payment_instructions::PaymentMethod; @@ -17,11 +18,10 @@ use ldk_node::{Event, Node}; use rand::RngCore; use std::env::temp_dir; use std::pin::Pin; +use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, RwLock}; use std::time::Duration; -use tokio::runtime::Runtime; -use tokio::sync::watch; +use tokio::sync::{RwLock, watch}; use uuid::Uuid; /// A dummy implementation of `TrustedWalletInterface` for testing purposes. @@ -44,8 +44,6 @@ pub struct DummyTrustedWalletExtraConfig { pub lsp: Arc, /// The Bitcoind node to connect to pub bitcoind: Arc, - /// The runtime to use for async tasks - pub rt: Arc, } impl DummyTrustedWallet { @@ -78,7 +76,7 @@ impl DummyTrustedWallet { let ldk_node = Arc::new(builder.build().unwrap()); - ldk_node.start_with_runtime(Arc::clone(&rt)).unwrap(); + ldk_node.start().unwrap(); let current_bal_msats = Arc::new(AtomicU64::new(0)); let payments: Arc>> = Arc::new(RwLock::new(vec![])); @@ -88,7 +86,7 @@ impl DummyTrustedWallet { let events_ref = Arc::clone(&ldk_node); let bal = Arc::clone(¤t_bal_msats); let pays = Arc::clone(&payments); - rt.spawn(async move { + rt.spawn_cancellable_background_task(async move { loop { let event = events_ref.next_event_async().await; match event { @@ -101,7 +99,7 @@ impl DummyTrustedWallet { // convert id let id = mangle_payment_id(payment_id.unwrap().0); - let mut payments = pays.write().unwrap(); + let mut payments = pays.write().await; let item = payments.iter_mut().find(|p| p.id == id); if let Some(payment) = item { payment.status = TxStatus::Completed; @@ -131,6 +129,7 @@ impl DummyTrustedWallet { payment_preimage: payment_preimage.unwrap(), // safe fee_paid_msat, }) + .await .unwrap(); } @@ -140,7 +139,7 @@ impl DummyTrustedWallet { // convert id let id = mangle_payment_id(payment_id.unwrap().0); - let mut payments = 
pays.write().unwrap(); + let mut payments = pays.write().await; let item = payments.iter().cloned().enumerate().find(|(_, p)| p.id == id); if let Some((idx, payment)) = item { // remove from list and refund balance @@ -162,6 +161,7 @@ impl DummyTrustedWallet { payment_hash, reason, }) + .await .unwrap(); } }, @@ -169,7 +169,7 @@ impl DummyTrustedWallet { // convert id let id = mangle_payment_id(payment_id.unwrap().0); - let mut payments = pays.write().unwrap(); + let mut payments = pays.write().await; // We create invoices on the fly without adding the payment to our list // We need to insert it into our payments list @@ -198,6 +198,7 @@ impl DummyTrustedWallet { custom_records: vec![], lsp_fee_msats: None, }) + .await .unwrap(); }, Event::PaymentForwarded { .. } => {}, @@ -205,6 +206,8 @@ impl DummyTrustedWallet { Event::ChannelPending { .. } => {}, Event::ChannelReady { .. } => {}, Event::ChannelClosed { .. } => {}, + Event::SplicePending { .. } => {}, + Event::SpliceFailed { .. } => {}, } println!("dummy: {event:?}"); if let Err(e) = events_ref.event_handled() { @@ -216,7 +219,7 @@ impl DummyTrustedWallet { // wait for ldk to be ready let iterations = if std::env::var("CI").is_ok() { 120 } else { 10 }; for _ in 0..iterations { - if ldk_node.status().is_listening { + if ldk_node.status().is_running { break; } tokio::time::sleep(Duration::from_secs(1)).await; @@ -307,7 +310,7 @@ impl TrustedWalletInterface for DummyTrustedWallet { fn list_payments( &self, ) -> Pin, TrustedError>> + Send + '_>> { - Box::pin(async move { Ok(self.payments.read().unwrap().clone()) }) + Box::pin(async move { Ok(self.payments.read().await.clone()) }) } fn estimate_fee( @@ -335,7 +338,7 @@ impl TrustedWalletInterface for DummyTrustedWallet { let id = self .ldk_node .bolt12_payment() - .send_using_amount(&offer, amount.milli_sats(), None, None) + .send_using_amount(&offer, amount.milli_sats(), None, None, None) .unwrap() .0; @@ -360,7 +363,7 @@ impl TrustedWalletInterface for 
DummyTrustedWallet { .as_secs(); // add to payments - let mut list = self.payments.write().unwrap(); + let mut list = self.payments.write().await; list.push(Payment { id, amount, diff --git a/orange-sdk/src/trusted_wallet/spark/mod.rs b/orange-sdk/src/trusted_wallet/spark/mod.rs index eb9b55f..af6f4a3 100644 --- a/orange-sdk/src/trusted_wallet/spark/mod.rs +++ b/orange-sdk/src/trusted_wallet/spark/mod.rs @@ -9,8 +9,8 @@ use crate::store::{PaymentId, TxMetadataStore, TxStatus}; use crate::trusted_wallet::{Payment, TrustedError, TrustedWalletInterface}; use crate::{Event, EventQueue, InitFailure, Seed, WalletConfig}; +use ldk_node::DynStore; use ldk_node::lightning::util::logger::Logger as _; -use ldk_node::lightning::util::persist::KVStore; use ldk_node::lightning::{log_debug, log_error, log_info, log_warn}; use ldk_node::lightning_invoice::Bolt11Invoice; use ldk_node::lightning_types::payment::{PaymentHash, PaymentPreimage}; @@ -28,12 +28,12 @@ use graduated_rebalancer::ReceivedLightningPayment; use tokio::sync::watch; +use crate::runtime::Runtime; use std::future::Future; use std::pin::Pin; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; -use tokio::runtime::Runtime; use uuid::Uuid; /// Configuration options for the Spark wallet. @@ -256,9 +256,9 @@ impl TrustedWalletInterface for Spark { impl Spark { /// Initialize a new Spark wallet instance with the given configuration. 
pub(crate) async fn init( - config: &WalletConfig, spark_config: SparkWalletConfig, - store: Arc, event_queue: Arc, - tx_metadata: TxMetadataStore, logger: Arc, runtime: Arc, + config: &WalletConfig, spark_config: SparkWalletConfig, store: Arc, + event_queue: Arc, tx_metadata: TxMetadataStore, logger: Arc, + runtime: Arc, ) -> Result { let spark_config: breez_sdk_spark::Config = spark_config.to_breez_config(config.network)?; @@ -294,7 +294,7 @@ impl Spark { log_info!(logger, "Added Spark event listener with ID: {listener_id}"); let w = Arc::clone(&spark_wallet); let mut shutdown_recv = shutdown_receiver.clone(); - runtime.spawn(async move { + runtime.spawn_background_task(async move { let _ = shutdown_recv.changed().await; w.remove_event_listener(&listener_id).await; }); @@ -341,12 +341,12 @@ impl EventListener for SparkEventHandler { log_info!(self.logger, "Spark wallet claimed deposits! {claimed_deposits:?}"); }, SdkEvent::PaymentSucceeded { payment } => { - if let Err(e) = self.handle_payment_succeeded(payment) { + if let Err(e) = self.handle_payment_succeeded(payment).await { log_error!(self.logger, "Failed to handle payment succeeded: {e:?}"); } }, SdkEvent::PaymentFailed { payment } => { - if let Err(e) = self.handle_payment_failed(payment) { + if let Err(e) = self.handle_payment_failed(payment).await { log_error!(self.logger, "Failed to handle payment succeeded: {e:?}"); } }, @@ -361,7 +361,7 @@ impl EventListener for SparkEventHandler { } impl SparkEventHandler { - fn handle_payment_succeeded( + async fn handle_payment_succeeded( &self, payment: breez_sdk_spark::Payment, ) -> Result<(), TrustedError> { log_info!(self.logger, "Spark payment succeeded: {payment:?}"); @@ -409,12 +409,14 @@ impl SparkEventHandler { ); } - self.event_queue.add_event(Event::PaymentSuccessful { - payment_id, - payment_hash: PaymentHash(payment_hash), - payment_preimage: PaymentPreimage(preimage), - fee_paid_msat: Some((payment.fees * 1_000) as u64), // convert to msats - })?; + 
self.event_queue + .add_event(Event::PaymentSuccessful { + payment_id, + payment_hash: PaymentHash(payment_hash), + payment_preimage: PaymentPreimage(preimage), + fee_paid_msat: Some((payment.fees * 1_000) as u64), // convert to msats + }) + .await?; self.payment_success_sender.send(()).unwrap(); }, @@ -437,13 +439,15 @@ impl SparkEventHandler { Some((payment.fees * 1_000) as u64) // convert to msats }; - self.event_queue.add_event(Event::PaymentReceived { - payment_id: PaymentId::Trusted(id), - payment_hash: PaymentHash(payment_hash), - amount_msat: (payment.amount * 1_000) as u64, // convert to msats - custom_records: vec![], - lsp_fee_msats, - })?; + self.event_queue + .add_event(Event::PaymentReceived { + payment_id: PaymentId::Trusted(id), + payment_hash: PaymentHash(payment_hash), + amount_msat: (payment.amount * 1_000) as u64, // convert to msats + custom_records: vec![], + lsp_fee_msats, + }) + .await?; }, _ => { log_debug!( @@ -458,7 +462,9 @@ impl SparkEventHandler { Ok(()) } - fn handle_payment_failed(&self, payment: breez_sdk_spark::Payment) -> Result<(), TrustedError> { + async fn handle_payment_failed( + &self, payment: breez_sdk_spark::Payment, + ) -> Result<(), TrustedError> { log_info!(self.logger, "Spark payment failed: {payment:?}"); let id = parse_payment_id(&payment.id)?; @@ -484,11 +490,13 @@ impl SparkEventHandler { TrustedError::Other(format!("Invalid payment_hash hex: {e:?}")) })?; - self.event_queue.add_event(Event::PaymentFailed { - payment_id, - payment_hash: Some(PaymentHash(payment_hash)), - reason: None, - })?; + self.event_queue + .add_event(Event::PaymentFailed { + payment_id, + payment_hash: Some(PaymentHash(payment_hash)), + reason: None, + }) + .await?; }, _ => { log_debug!(self.logger, "Unsupported payment details for Send: {payment:?}") diff --git a/orange-sdk/src/trusted_wallet/spark/spark_store.rs b/orange-sdk/src/trusted_wallet/spark/spark_store.rs index fe4197b..5524b59 100644 --- 
a/orange-sdk/src/trusted_wallet/spark/spark_store.rs +++ b/orange-sdk/src/trusted_wallet/spark/spark_store.rs @@ -2,13 +2,15 @@ use std::sync::Arc; -use crate::{KVStore, io}; +use crate::io; use breez_sdk_spark::{ DepositInfo, ListPaymentsRequest, Payment, PaymentDetails, PaymentMetadata, StorageError, UpdateDepositPayload, }; +use ldk_node::DynStore; use ldk_node::lightning::util::persist::KVSTORE_NAMESPACE_KEY_MAX_LEN; +use ldk_node::lightning::util::persist::KVStore; const SPARK_PRIMARY_NAMESPACE: &str = "spark"; const SPARK_CACHE_NAMESPACE: &str = "cache"; @@ -16,7 +18,7 @@ const SPARK_PAYMENTS_NAMESPACE: &str = "payment"; const SPARK_DEPOSITS_NAMESPACE: &str = "deposit"; #[derive(Clone)] -pub(crate) struct SparkStore(pub(crate) Arc); +pub(crate) struct SparkStore(pub(crate) Arc); /// The Spark sdk can produce keys that are too long, we just truncate them here fn sanitize_key(key: String) -> String { @@ -31,15 +33,23 @@ fn sanitize_key(key: String) -> String { impl breez_sdk_spark::Storage for SparkStore { async fn delete_cached_item(&self, key: String) -> Result<(), StorageError> { let key = sanitize_key(key); - self.0 - .remove(SPARK_PRIMARY_NAMESPACE, SPARK_CACHE_NAMESPACE, &key, false) - .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; + KVStore::remove( + self.0.as_ref(), + SPARK_PRIMARY_NAMESPACE, + SPARK_CACHE_NAMESPACE, + &key, + false, + ) + .await + .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; Ok(()) } async fn get_cached_item(&self, key: String) -> Result, StorageError> { let key = sanitize_key(key); - match self.0.read(SPARK_PRIMARY_NAMESPACE, SPARK_CACHE_NAMESPACE, &key) { + match KVStore::read(self.0.as_ref(), SPARK_PRIMARY_NAMESPACE, SPARK_CACHE_NAMESPACE, &key) + .await + { Ok(bytes) => Ok(Some(String::from_utf8(bytes).map_err(|e| { StorageError::Serialization(format!("Invalid UTF-8 in cached item: {e:?}")) })?)), @@ -55,26 +65,36 @@ impl breez_sdk_spark::Storage for SparkStore { async fn set_cached_item(&self, 
key: String, value: String) -> Result<(), StorageError> { let key = sanitize_key(key); - self.0 - .write(SPARK_PRIMARY_NAMESPACE, SPARK_CACHE_NAMESPACE, &key, value.as_bytes()) - .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; + KVStore::write( + self.0.as_ref(), + SPARK_PRIMARY_NAMESPACE, + SPARK_CACHE_NAMESPACE, + &key, + value.into_bytes(), + ) + .await + .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; Ok(()) } async fn list_payments( &self, request: ListPaymentsRequest, ) -> Result, StorageError> { - let keys = self - .0 - .list(SPARK_PRIMARY_NAMESPACE, SPARK_PAYMENTS_NAMESPACE) - .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; + let keys = + KVStore::list(self.0.as_ref(), SPARK_PRIMARY_NAMESPACE, SPARK_PAYMENTS_NAMESPACE) + .await + .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; let mut payments = Vec::with_capacity(keys.len()); for key in keys { - let data = self - .0 - .read(SPARK_PRIMARY_NAMESPACE, SPARK_PAYMENTS_NAMESPACE, &key) - .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; + let data = KVStore::read( + self.0.as_ref(), + SPARK_PRIMARY_NAMESPACE, + SPARK_PAYMENTS_NAMESPACE, + &key, + ) + .await + .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; let payment: breez_sdk_spark::Payment = serde_json::from_slice(&data) .map_err(|e| StorageError::Serialization(format!("{e:?}")))?; @@ -105,9 +125,15 @@ impl breez_sdk_spark::Storage for SparkStore { let data = serde_json::to_vec(&payment) .map_err(|e| StorageError::Serialization(format!("{e:?}")))?; - self.0 - .write(SPARK_PRIMARY_NAMESPACE, SPARK_PAYMENTS_NAMESPACE, &payment.id, &data) - .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; + KVStore::write( + self.0.as_ref(), + SPARK_PRIMARY_NAMESPACE, + SPARK_PAYMENTS_NAMESPACE, + &payment.id, + data, + ) + .await + .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; Ok(()) } @@ -121,10 +147,10 @@ impl breez_sdk_spark::Storage for SparkStore { async 
fn get_payment_by_id( &self, id: String, ) -> Result { - let data = self - .0 - .read(SPARK_PRIMARY_NAMESPACE, SPARK_PAYMENTS_NAMESPACE, &id) - .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; + let data = + KVStore::read(self.0.as_ref(), SPARK_PRIMARY_NAMESPACE, SPARK_PAYMENTS_NAMESPACE, &id) + .await + .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; let payment: breez_sdk_spark::Payment = serde_json::from_slice(&data) .map_err(|e| StorageError::Serialization(format!("{e:?}")))?; @@ -176,33 +202,49 @@ impl breez_sdk_spark::Storage for SparkStore { let data = serde_json::to_vec(&info).map_err(|e| StorageError::Serialization(format!("{e:?}")))?; - self.0 - .write(SPARK_PRIMARY_NAMESPACE, SPARK_DEPOSITS_NAMESPACE, &id, &data) - .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; + KVStore::write( + self.0.as_ref(), + SPARK_PRIMARY_NAMESPACE, + SPARK_DEPOSITS_NAMESPACE, + &id, + data, + ) + .await + .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; Ok(()) } async fn delete_deposit(&self, txid: String, vout: u32) -> Result<(), StorageError> { let id = format!("{txid}:{vout}"); - self.0 - .remove(SPARK_PRIMARY_NAMESPACE, SPARK_DEPOSITS_NAMESPACE, &id, false) - .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; + KVStore::remove( + self.0.as_ref(), + SPARK_PRIMARY_NAMESPACE, + SPARK_DEPOSITS_NAMESPACE, + &id, + false, + ) + .await + .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; Ok(()) } async fn list_deposits(&self) -> Result, StorageError> { - let keys = self - .0 - .list(SPARK_PRIMARY_NAMESPACE, SPARK_DEPOSITS_NAMESPACE) - .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; + let keys = + KVStore::list(self.0.as_ref(), SPARK_PRIMARY_NAMESPACE, SPARK_DEPOSITS_NAMESPACE) + .await + .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; let mut deposits = Vec::with_capacity(keys.len()); for key in keys { - let data = self - .0 - .read(SPARK_PRIMARY_NAMESPACE, 
SPARK_DEPOSITS_NAMESPACE, &key) - .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; + let data = KVStore::read( + self.0.as_ref(), + SPARK_PRIMARY_NAMESPACE, + SPARK_DEPOSITS_NAMESPACE, + &key, + ) + .await + .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; let deposit: DepositInfo = serde_json::from_slice(&data) .map_err(|e| StorageError::Serialization(format!("{e:?}")))?; @@ -217,7 +259,14 @@ impl breez_sdk_spark::Storage for SparkStore { ) -> Result<(), StorageError> { let id = format!("{txid}:{vout}"); - let data = match self.0.read(SPARK_PRIMARY_NAMESPACE, SPARK_DEPOSITS_NAMESPACE, &id) { + let data = match KVStore::read( + self.0.as_ref(), + SPARK_PRIMARY_NAMESPACE, + SPARK_DEPOSITS_NAMESPACE, + &id, + ) + .await + { Ok(data) => data, Err(e) => { if let io::ErrorKind::NotFound = e.kind() { @@ -245,9 +294,15 @@ impl breez_sdk_spark::Storage for SparkStore { let data = serde_json::to_vec(&deposit) .map_err(|e| StorageError::Serialization(format!("{e:?}")))?; - self.0 - .write(SPARK_PRIMARY_NAMESPACE, SPARK_DEPOSITS_NAMESPACE, &id, &data) - .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; + KVStore::write( + self.0.as_ref(), + SPARK_PRIMARY_NAMESPACE, + SPARK_DEPOSITS_NAMESPACE, + &id, + data, + ) + .await + .map_err(|e| StorageError::Implementation(format!("{e:?}")))?; Ok(()) } diff --git a/orange-sdk/tests/integration_tests.rs b/orange-sdk/tests/integration_tests.rs index 0d02b5e..79dd8ad 100644 --- a/orange-sdk/tests/integration_tests.rs +++ b/orange-sdk/tests/integration_tests.rs @@ -1,8 +1,6 @@ #![cfg(feature = "_test-utils")] -use crate::test_utils::{ - TestParams, build_test_nodes, generate_blocks, open_channel_from_lsp, wait_next_event, -}; +use crate::test_utils::{generate_blocks, open_channel_from_lsp, wait_for_tx, wait_next_event}; use bitcoin_payment_instructions::amount::Amount; use bitcoin_payment_instructions::http_resolver::HTTPHrnResolver; use bitcoin_payment_instructions::{ParseError, 
PaymentInstructions}; @@ -10,6 +8,7 @@ use ldk_node::NodeError; use ldk_node::bitcoin::Network; use ldk_node::lightning_invoice::{Bolt11InvoiceDescription, Description}; use ldk_node::payment::{ConfirmationStatus, PaymentDirection, PaymentStatus}; +use log::info; use orange_sdk::bitcoin::hashes::Hash; use orange_sdk::{Event, PaymentInfo, PaymentType, TxStatus, WalletError}; use std::sync::Arc; @@ -17,22 +16,23 @@ use std::time::Duration; mod test_utils; -#[test] -fn test_node_start() { - let TestParams { wallet, rt, .. } = build_test_nodes(); - - rt.block_on(async move { - let bal = wallet.get_balance().await.unwrap(); +#[tokio::test(flavor = "multi_thread")] +async fn test_node_start() { + test_utils::run_test(|params| async move { + let bal = params.wallet.get_balance().await.unwrap(); assert_eq!(bal.available_balance(), Amount::ZERO); assert_eq!(bal.pending_balance, Amount::ZERO); }) + .await; } -#[test] -fn test_receive_to_trusted() { - let TestParams { wallet, third_party, rt, .. } = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +#[test_log::test] +async fn test_receive_to_trusted() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(&params.wallet); + let third_party = Arc::clone(&params.third_party); - rt.block_on(async move { let starting_bal = wallet.get_balance().await.unwrap(); assert_eq!(starting_bal.available_balance(), Amount::ZERO); assert_eq!(starting_bal.pending_balance, Amount::ZERO); @@ -78,14 +78,19 @@ fn test_receive_to_trusted() { Some(recv_amt), "Amount should equal received amount for trusted wallet (no fees deducted)" ); + + info!("test passed"); }) + .await; } -#[test] -fn test_pay_from_trusted() { - let TestParams { wallet, third_party, lsp, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_pay_from_trusted() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let third_party = Arc::clone(¶ms.third_party); + let lsp = Arc::clone(¶ms.lsp); - rt.block_on(async move { let starting_bal = wallet.get_balance().await.unwrap(); assert_eq!(starting_bal.available_balance(), Amount::ZERO); assert_eq!(starting_bal.pending_balance, Amount::ZERO); @@ -160,13 +165,16 @@ fn test_pay_from_trusted() { pt => panic!("Payment type should be OutgoingLightningBolt11, got {pt:?}"), } }) + .await; } -#[test] -fn test_sweep_to_ln() { - let TestParams { wallet, lsp, third_party, rt, .. } = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_sweep_to_ln() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let lsp = Arc::clone(¶ms.lsp); + let third_party = Arc::clone(¶ms.third_party); - rt.block_on(async move { let starting_bal = wallet.get_balance().await.unwrap(); assert_eq!(starting_bal.available_balance(), Amount::ZERO); assert_eq!(starting_bal.pending_balance, Amount::ZERO); @@ -316,13 +324,15 @@ fn test_sweep_to_ln() { expected_total.milli_sats() ); }) + .await; } -#[test] -fn test_receive_to_ln() { - let TestParams { wallet, third_party, rt, .. } = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_receive_to_ln() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let third_party = Arc::clone(¶ms.third_party); - rt.block_on(async move { let recv_amt = open_channel_from_lsp(&wallet, Arc::clone(&third_party)).await; let txs = wallet.list_transactions().await.unwrap(); @@ -356,13 +366,19 @@ fn test_receive_to_ln() { fee_ratio * 100.0 ); }) + .await; } -#[test] -fn test_receive_to_onchain() { - let TestParams { wallet, lsp, bitcoind, third_party, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +#[test_log::test] +async fn test_receive_onchain() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let lsp = Arc::clone(¶ms.lsp); + let bitcoind = Arc::clone(¶ms.bitcoind); + let third_party = Arc::clone(¶ms.third_party); + let electrsd = Arc::clone(¶ms.electrsd); - rt.block_on(async move { let starting_bal = wallet.get_balance().await.unwrap(); assert_eq!(starting_bal.available_balance(), Amount::ZERO); assert_eq!(starting_bal.pending_balance, Amount::ZERO); @@ -375,8 +391,10 @@ fn test_receive_to_onchain() { .send_to_address(&uri.address.unwrap(), recv_amt.sats().unwrap(), None) .unwrap(); + wait_for_tx(&electrsd.client, sent_txid).await; + // confirm transaction - generate_blocks(&bitcoind, 6); + generate_blocks(&bitcoind, &electrsd, 6).await; // check we received on-chain, should be pending // wait for payment success @@ -421,9 +439,9 @@ fn test_receive_to_onchain() { // a rebalance should be initiated, we need to mine the channel opening transaction // for it to be confirmed and reflected in the wallet's history - generate_blocks(&bitcoind, 6); + generate_blocks(&bitcoind, &electrsd, 6).await; tokio::time::sleep(Duration::from_secs(5)).await; // wait for sync - generate_blocks(&bitcoind, 6); // confirm the channel opening transaction + generate_blocks(&bitcoind, &electrsd, 6).await; // confirm the channel opening transaction tokio::time::sleep(Duration::from_secs(5)).await; // wait for sync // wait for rebalance to be initiated @@ -464,17 +482,130 @@ fn test_receive_to_onchain() { assert!(wallet.next_event().is_none()); }) + .await; } -fn run_test_pay_lightning_from_self_custody(amountless: bool) { - let TestParams { wallet, bitcoind, third_party, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +#[test_log::test] +async fn test_receive_to_onchain_with_channel() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let lsp = Arc::clone(¶ms.lsp); + let bitcoind = Arc::clone(¶ms.bitcoind); + let third_party = Arc::clone(¶ms.third_party); + let electrsd = Arc::clone(¶ms.electrsd); + + let start = open_channel_from_lsp(&wallet, Arc::clone(&third_party)).await; + + let starting_bal = wallet.get_balance().await.unwrap(); + // channel amt - opening fees + assert_eq!( + starting_bal.available_balance(), + start.saturating_sub(Amount::from_sats(2_000).unwrap()) + ); + assert_eq!(starting_bal.pending_balance, Amount::ZERO); + + let recv_amt = Amount::from_sats(300_000).unwrap(); + + let uri = wallet.get_single_use_receive_uri(Some(recv_amt)).await.unwrap(); + let sent_txid = third_party + .onchain_payment() + .send_to_address(&uri.address.unwrap(), recv_amt.sats().unwrap(), None) + .unwrap(); + + println!("Sent txid: {sent_txid}"); + + wait_for_tx(&electrsd.client, sent_txid).await; + + // confirm transaction + generate_blocks(&bitcoind, &electrsd, 6).await; + wallet.sync_ln_wallet().unwrap(); + + // check we received on-chain, should be pending + // wait for payment success + test_utils::wait_for_condition("pending balance to update", || async { + // onchain balance is always listed as pending until we splice it into the channel. + wallet.get_balance().await.unwrap().pending_balance == recv_amt + }) + .await; + + println!("waiting for onchain recv event"); + let event = wait_next_event(&wallet).await; + match event { + Event::OnchainPaymentReceived { txid, amount_sat, status, .. } => { + assert_eq!(txid, sent_txid); + assert_eq!(amount_sat, recv_amt.sats().unwrap()); + assert!(matches!(status, ConfirmationStatus::Confirmed { .. 
})); + }, + ev => panic!("Expected OnchainPaymentReceived event, got {ev:?}"), + } + + println!("waiting for splice pending event"); + let event = wait_next_event(&wallet).await; + match event { + Event::SplicePending { counterparty_node_id, .. } => { + assert_eq!(counterparty_node_id, lsp.node_id()); + }, + ev => panic!("Expected SplicePending event, got {ev:?}"), + } + + // confirm splice + generate_blocks(&bitcoind, &electrsd, 6).await; + tokio::time::sleep(Duration::from_secs(5)).await; + + let event = wait_next_event(&wallet).await; + match event { + Event::ChannelOpened { counterparty_node_id, .. } => { + assert_eq!(counterparty_node_id, lsp.node_id()); + }, + ev => panic!("Expected ChannelOpened event, got {ev:?}"), + } + + let txs = wallet.list_transactions().await.unwrap(); + assert_eq!(txs.len(), 2); + let tx = txs.into_iter().last().unwrap(); + + // Comprehensive validation for on-chain receive after rebalance + assert!(!tx.outbound, "Incoming payment should not be outbound"); + assert_eq!(tx.status, TxStatus::Completed, "Payment should be completed"); + assert_eq!( + tx.payment_type, + PaymentType::IncomingOnChain { txid: Some(sent_txid) }, + "Payment type should be IncomingOnChain with correct txid" + ); + assert_ne!(tx.time_since_epoch, Duration::ZERO, "Time should be set"); + assert_eq!(tx.amount, Some(recv_amt), "Amount should equal received amount"); + assert!( + tx.fee.unwrap() > Amount::ZERO, + "On-chain receive should have rebalance fees after channel opening" + ); + + // Validate fee is reasonable (should be less than 5% of received amount for rebalance) + let fee_ratio = tx.fee.unwrap().milli_sats() as f64 / recv_amt.milli_sats() as f64; + assert!( + fee_ratio < 0.05, + "Rebalance fee should be less than 5% of received amount, got {:.2}%", + fee_ratio * 100.0 + ); + + let next = wallet.next_event(); + assert!(next.is_none(), "Expected no more events, got {next:?}"); + }) + .await; +} + +async fn 
run_test_pay_lightning_from_self_custody(amountless: bool) { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let bitcoind = Arc::clone(¶ms.bitcoind); + let third_party = Arc::clone(¶ms.third_party); + let electrsd = Arc::clone(¶ms.electrsd); - rt.block_on(async move { // get a channel so we can make a payment open_channel_from_lsp(&wallet, Arc::clone(&third_party)).await; // wait for sync - generate_blocks(&bitcoind, 6); + generate_blocks(&bitcoind, &electrsd, 6).await; test_utils::wait_for_condition("wallet sync after channel open", || async { wallet.channels().iter().any(|a| a.confirmations.is_some_and(|c| c > 0) && a.is_usable) && third_party @@ -558,23 +689,27 @@ fn run_test_pay_lightning_from_self_custody(amountless: bool) { && p.direction == PaymentDirection::Inbound && p.amount_msat == Some(amount.milli_sats()))); }) + .await; } -#[test] -fn test_pay_lightning_from_self_custody() { - run_test_pay_lightning_from_self_custody(false); - run_test_pay_lightning_from_self_custody(true); +#[tokio::test(flavor = "multi_thread")] +async fn test_pay_lightning_from_self_custody() { + run_test_pay_lightning_from_self_custody(false).await; + run_test_pay_lightning_from_self_custody(true).await; } -fn run_test_pay_bolt12_from_self_custody(amountless: bool) { - let TestParams { wallet, bitcoind, third_party, rt, .. 
} = build_test_nodes(); +async fn run_test_pay_bolt12_from_self_custody(amountless: bool) { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let bitcoind = Arc::clone(¶ms.bitcoind); + let third_party = Arc::clone(¶ms.third_party); + let electrsd = Arc::clone(¶ms.electrsd); - rt.block_on(async move { // get a channel so we can make a payment open_channel_from_lsp(&wallet, Arc::clone(&third_party)).await; // wait for sync - generate_blocks(&bitcoind, 6); + generate_blocks(&bitcoind, &electrsd, 6).await; test_utils::wait_for_condition("wallet sync after channel open", || async { wallet.channels().iter().any(|a| a.confirmations.is_some_and(|c| c > 0) && a.is_usable) && third_party @@ -648,19 +783,23 @@ fn run_test_pay_bolt12_from_self_custody(amountless: bool) { && p.direction == PaymentDirection::Inbound && p.amount_msat == Some(amount.milli_sats()))); }) + .await; } -#[test] -fn test_pay_bolt12_from_self_custody() { - run_test_pay_bolt12_from_self_custody(false); - run_test_pay_bolt12_from_self_custody(true); +#[tokio::test(flavor = "multi_thread")] +async fn test_pay_bolt12_from_self_custody() { + run_test_pay_bolt12_from_self_custody(false).await; + run_test_pay_bolt12_from_self_custody(true).await; } -#[test] -fn test_pay_onchain_from_self_custody() { - let TestParams { wallet, bitcoind, third_party, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_pay_onchain_from_self_custody() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let bitcoind = Arc::clone(¶ms.bitcoind); + let third_party = Arc::clone(¶ms.third_party); + let electrsd = Arc::clone(¶ms.electrsd); - rt.block_on(async move { // disable rebalancing so we have on-chain funds wallet.set_rebalance_enabled(false); @@ -680,7 +819,7 @@ fn test_pay_onchain_from_self_custody() { .unwrap(); // confirm tx - generate_blocks(&bitcoind, 6); + generate_blocks(&bitcoind, &electrsd, 6).await; // wait for node to sync and see the balance update test_utils::wait_for_condition("wallet sync after on-chain receive", || async { @@ -696,8 +835,14 @@ fn test_pay_onchain_from_self_custody() { let info = PaymentInfo::build(instr, Some(send_amount)).unwrap(); wallet.pay(&info).await.unwrap(); + // sleep for a second to wait for proper broadcast + tokio::time::sleep(Duration::from_secs(1)).await; + // confirm the tx - generate_blocks(&bitcoind, 6); + generate_blocks(&bitcoind, &electrsd, 6).await; + + // sleep for a second to wait for sync + tokio::time::sleep(Duration::from_secs(1)).await; // wait for payment to complete test_utils::wait_for_condition("on-chain payment completion", || async { @@ -759,13 +904,128 @@ fn test_pay_onchain_from_self_custody() { }) .await; }) + .await; +} + +#[tokio::test(flavor = "multi_thread")] +#[test_log::test] +async fn test_pay_onchain_from_channel() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let bitcoind = Arc::clone(¶ms.bitcoind); + let third_party = Arc::clone(¶ms.third_party); + let electrsd = Arc::clone(¶ms.electrsd); + + // get a channel so we can make a payment + let recv = open_channel_from_lsp(&wallet, Arc::clone(&third_party)).await; + + let starting_bal = wallet.get_balance().await.unwrap(); + assert_eq!( + starting_bal.available_balance(), + 
recv.saturating_sub(Amount::from_sats(2_000).unwrap()) + ); + assert_eq!(starting_bal.pending_balance, Amount::ZERO); + + // wait for sync + generate_blocks(&bitcoind, &electrsd, 6).await; + test_utils::wait_for_condition("wallet sync after channel open", || async { + wallet.channels().iter().any(|a| a.confirmations.is_some_and(|c| c > 0) && a.is_usable) + }) + .await; + + // get address from third party node + let addr = third_party.onchain_payment().new_address().unwrap(); + let send_amount = Amount::from_sats(10_000).unwrap(); + + let instr = wallet.parse_payment_instructions(addr.to_string().as_str()).await.unwrap(); + let info = PaymentInfo::build(instr, Some(send_amount)).unwrap(); + wallet.pay(&info).await.unwrap(); + + // sleep for a second to wait for proper broadcast + tokio::time::sleep(Duration::from_secs(1)).await; + + // confirm the tx + generate_blocks(&bitcoind, &electrsd, 6).await; + + // sleep for a second to wait for sync + tokio::time::sleep(Duration::from_secs(1)).await; + + // wait for payment to complete + test_utils::wait_for_condition("on-chain payment completion", || async { + let payments = wallet.list_transactions().await.unwrap(); + let payment = payments.into_iter().find(|p| p.outbound); + if payment.as_ref().is_some_and(|p| p.status == TxStatus::Failed) { + panic!("Payment failed"); + } + payment.is_some_and(|p| p.status == TxStatus::Completed) + }) + .await; + + // check the payment is correct + let payments = wallet.list_transactions().await.unwrap(); + let payment = payments.into_iter().find(|p| p.outbound).unwrap(); + + // Comprehensive validation for outgoing on-chain payment + assert_eq!(payment.amount, Some(send_amount), "Amount should equal sent amount"); + assert!( + payment.fee.is_some_and(|f| f > Amount::ZERO), + "On-chain payment should have non-zero fees" + ); + assert!(payment.outbound, "Outgoing payment should be outbound"); + assert!( + matches!(payment.payment_type, PaymentType::OutgoingOnChain { .. 
}), + "Payment type should be OutgoingOnChain" + ); + assert_eq!(payment.status, TxStatus::Completed, "Payment should be completed"); + assert_ne!(payment.time_since_epoch, Duration::ZERO, "Time should be set"); + + // Validate fee is reasonable for on-chain (should be less than 1% of sent amount) + let fee_ratio = payment.fee.unwrap().milli_sats() as f64 / send_amount.milli_sats() as f64; + assert!( + fee_ratio < 0.01, + "On-chain fee should be less than 1% of sent amount, got {:.2}%", + fee_ratio * 100.0 + ); + + // Check that payment_type contains txid for completed payments + if let PaymentType::OutgoingOnChain { txid } = &payment.payment_type { + assert!(txid.is_some(), "Completed on-chain payment should have txid"); + } + + // check balance left our wallet + let bal = wallet.get_balance().await.unwrap(); + // fixme change to exact match once we have the real feee + assert!( + bal.available_balance() + < starting_bal + .available_balance() + .saturating_sub(send_amount) + .saturating_sub(payment.fee.unwrap()) + ); + + // Wait for third party node to receive it + test_utils::wait_for_condition("on-chain payment received", || async { + let payments = third_party.list_payments(); + payments.iter().any(|p| { + p.status == PaymentStatus::Succeeded + && p.direction == PaymentDirection::Inbound + && p.amount_msat == Some(send_amount.milli_sats()) + }) + }) + .await; + }) + .await; } -#[test] -fn test_force_close_handling() { - let TestParams { wallet, lsp, bitcoind, third_party, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_force_close_handling() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let lsp = Arc::clone(¶ms.lsp); + let bitcoind = Arc::clone(¶ms.bitcoind); + let third_party = Arc::clone(¶ms.third_party); + let electrsd = Arc::clone(¶ms.electrsd); - rt.block_on(async move { let starting_bal = wallet.get_balance().await.unwrap(); assert_eq!(starting_bal.available_balance(), Amount::ZERO); assert_eq!(starting_bal.pending_balance, Amount::ZERO); @@ -777,7 +1037,7 @@ fn test_force_close_handling() { open_channel_from_lsp(&wallet, Arc::clone(&third_party)).await; // mine some blocks to ensure the channel is confirmed - generate_blocks(&bitcoind, 6); + generate_blocks(&bitcoind, &electrsd, 6).await; // get channel details let channel = lsp @@ -803,13 +1063,18 @@ fn test_force_close_handling() { let rebalancing = wallet.get_rebalance_enabled(); assert!(!rebalancing); }) + .await; } -#[test] -fn test_close_all_channels() { - let TestParams { wallet, lsp, bitcoind, third_party, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_close_all_channels() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let lsp = Arc::clone(¶ms.lsp); + let bitcoind = Arc::clone(¶ms.bitcoind); + let third_party = Arc::clone(¶ms.third_party); + let electrsd = Arc::clone(¶ms.electrsd); - rt.block_on(async move { let starting_bal = wallet.get_balance().await.unwrap(); assert_eq!(starting_bal.available_balance(), Amount::ZERO); assert_eq!(starting_bal.pending_balance, Amount::ZERO); @@ -821,7 +1086,7 @@ fn test_close_all_channels() { open_channel_from_lsp(&wallet, Arc::clone(&third_party)).await; // mine some blocks to ensure the channel is confirmed - generate_blocks(&bitcoind, 6); + generate_blocks(&bitcoind, &electrsd, 6).await; // init closing all channels wallet.close_channels().unwrap(); @@ -839,13 +1104,15 @@ fn test_close_all_channels() { let rebalancing = wallet.get_rebalance_enabled(); assert!(!rebalancing); }) + .await; } -#[test] -fn test_threshold_boundary_trusted_balance_limit() { - let TestParams { wallet, third_party, rt, .. } = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_threshold_boundary_trusted_balance_limit() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let third_party = Arc::clone(¶ms.third_party); - rt.block_on(async move { // we're not testing rebalancing here, so disable it to keep things simple // on slow CI this can cause tests to fail if rebalancing kicks in wallet.set_rebalance_enabled(false); @@ -920,13 +1187,15 @@ fn test_threshold_boundary_trusted_balance_limit() { "Payment above limit should use Lightning with fees" ); }) + .await; } -#[test] -fn test_threshold_boundary_rebalance_min() { - let TestParams { wallet, third_party, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_threshold_boundary_rebalance_min() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let third_party = Arc::clone(¶ms.third_party); - rt.block_on(async move { let starting_bal = wallet.get_balance().await.unwrap(); let tunables = wallet.get_tunables(); let rebalance_min = tunables.rebalance_min; @@ -943,7 +1212,7 @@ fn test_threshold_boundary_rebalance_min() { .await; test_utils::wait_for_condition("wait for transaction", || async { - wallet.list_transactions().await.unwrap().len() >= 1 + !wallet.list_transactions().await.unwrap().is_empty() }) .await; @@ -996,13 +1265,14 @@ fn test_threshold_boundary_rebalance_min() { "Total balance should still be below trusted_balance_limit" ); }) + .await; } -#[test] -fn test_threshold_boundary_onchain_receive_threshold() { - let TestParams { wallet, rt, .. } = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_threshold_boundary_onchain_receive_threshold() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); - rt.block_on(async move { let tunables = wallet.get_tunables(); let onchain_threshold = tunables.onchain_receive_threshold; @@ -1064,13 +1334,14 @@ fn test_threshold_boundary_onchain_receive_threshold() { ); } }) + .await; } -#[test] -fn test_threshold_combinations_and_edge_cases() { - let TestParams { wallet, rt, .. } = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_threshold_combinations_and_edge_cases() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); - rt.block_on(async move { let tunables = wallet.get_tunables(); // Test edge case: ensure thresholds are properly ordered @@ -1127,13 +1398,15 @@ fn test_threshold_combinations_and_edge_cases() { ); } }) + .await; } -#[test] -fn test_invalid_payment_instructions() { - let TestParams { wallet, third_party, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_invalid_payment_instructions() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let third_party = Arc::clone(¶ms.third_party); - rt.block_on(async move { // Test 1: Payment with insufficient balance let amount = Amount::from_sats(1_000_000).unwrap(); // 1 BTC - more than we have let desc = Bolt11InvoiceDescription::Direct(Description::empty()); @@ -1191,13 +1464,15 @@ fn test_invalid_payment_instructions() { let txs = wallet.list_transactions().await.unwrap(); assert_eq!(txs.len(), 0, "Failed payments should not be recorded in transaction list"); }) + .await; } -#[test] -fn test_payment_with_expired_invoice() { - let TestParams { wallet, third_party, rt, .. } = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_payment_with_expired_invoice() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let third_party = Arc::clone(¶ms.third_party); - rt.block_on(async move { // Add some balance first so the payment can theoretically succeed if not expired let initial_amount = Amount::from_sats(5000).unwrap(); let uri = wallet.get_single_use_receive_uri(Some(initial_amount)).await.unwrap(); @@ -1222,13 +1497,16 @@ fn test_payment_with_expired_invoice() { let parse_result = wallet.parse_payment_instructions(invoice.to_string().as_str()).await; assert!(matches!(parse_result.unwrap_err(), ParseError::InstructionsExpired)); }) + .await; } -#[test] -fn test_payment_network_mismatch() { - let TestParams { wallet, bitcoind, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_payment_network_mismatch() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let bitcoind = Arc::clone(¶ms.bitcoind); + let electrsd = Arc::clone(¶ms.electrsd); - rt.block_on(async move { // disable rebalancing so we have on-chain funds wallet.set_rebalance_enabled(false); @@ -1244,7 +1522,7 @@ fn test_payment_network_mismatch() { .unwrap(); // confirm tx - generate_blocks(&bitcoind, 6); + generate_blocks(&bitcoind, &electrsd, 6).await; test_utils::wait_for_condition("wallet sync after on-chain receive", || async { wallet.get_balance().await.unwrap().pending_balance >= recv_amount }) @@ -1263,10 +1541,14 @@ fn test_payment_network_mismatch() { ); // now force a correct parsing to ensure we fail when trying to pay - let instr = - PaymentInstructions::parse(wrong_network, Network::Bitcoin, &HTTPHrnResolver, true) - .await - .unwrap(); + let instr = PaymentInstructions::parse( + wrong_network, + Network::Bitcoin, + &HTTPHrnResolver::new(), + true, + ) + .await + .unwrap(); // If it parsed, trying to pay should fail due to network mismatch let amount = Amount::from_sats(1000).unwrap(); @@ -1277,18 +1559,22 @@ fn test_payment_network_mismatch() { "Payment to wrong network address should fail with LDK error, got {pay_result:?}" ); }) + .await; } -#[test] -fn test_concurrent_payments() { - let TestParams { wallet, bitcoind, third_party, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_concurrent_payments() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let bitcoind = Arc::clone(¶ms.bitcoind); + let electrsd = Arc::clone(¶ms.electrsd); + let third_party = Arc::clone(¶ms.third_party); - rt.block_on(async move { // First, build up sufficient balance for concurrent sending let _channel_amount = open_channel_from_lsp(&wallet, Arc::clone(&third_party)).await; // Wait for sync - generate_blocks(&bitcoind, 6); + generate_blocks(&bitcoind, &electrsd, 6).await; test_utils::wait_for_condition("wallet sync after channel open", || async { wallet.channels().iter().any(|a| a.confirmations.is_some_and(|c| c > 0) && a.is_usable) && third_party @@ -1469,13 +1755,15 @@ fn test_concurrent_payments() { "Concurrent transaction queries should return same count" ); }) + .await; } -#[test] -fn test_concurrent_receive_operations() { - let TestParams { wallet, third_party, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_concurrent_receive_operations() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let third_party = Arc::clone(¶ms.third_party); - rt.block_on(async move { let amount = Amount::from_sats(1000).unwrap(); // Test: Generate multiple receive URIs concurrently @@ -1498,9 +1786,7 @@ fn test_concurrent_receive_operations() { // Wait for first payment to complete test_utils::wait_for_condition("first payment to succeed", || async { - third_party - .payment(&payment_id_1) - .map_or(false, |p| p.status == PaymentStatus::Succeeded) + third_party.payment(&payment_id_1).is_some_and(|p| p.status == PaymentStatus::Succeeded) }) .await; @@ -1509,9 +1795,7 @@ fn test_concurrent_receive_operations() { // Wait for second payment to complete test_utils::wait_for_condition("second payment to succeed", || async { - third_party - .payment(&payment_id_2) - .map_or(false, |p| p.status == PaymentStatus::Succeeded) + third_party.payment(&payment_id_2).is_some_and(|p| p.status == PaymentStatus::Succeeded) }) .await; @@ -1530,13 +1814,15 @@ fn test_concurrent_receive_operations() { let incoming_count = txs.iter().filter(|tx| !tx.outbound).count(); assert_eq!(incoming_count, 2, "Should have exactly 2 incoming transactions"); }) + .await; } -#[test] -fn test_balance_consistency_under_load() { - let TestParams { wallet, third_party, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_balance_consistency_under_load() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let third_party = Arc::clone(¶ms.third_party); - rt.block_on(async move { // Add some initial balance let initial_amount = Amount::from_sats(10000).unwrap(); let uri = wallet.get_single_use_receive_uri(Some(initial_amount)).await.unwrap(); @@ -1583,13 +1869,14 @@ fn test_balance_consistency_under_load() { ); } }) + .await; } -#[test] -fn test_invalid_tunables_relationships() { - let TestParams { wallet, rt, .. } = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_invalid_tunables_relationships() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); - rt.block_on(async move { let current_tunables = wallet.get_tunables(); // Test 1: Verify default tunables are valid @@ -1674,13 +1961,14 @@ fn test_invalid_tunables_relationships() { ); } }) + .await; } -#[test] -fn test_extreme_amount_handling() { - let TestParams { wallet, rt, .. } = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_extreme_amount_handling() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); - rt.block_on(async move { // Test 1: Large but reasonable Bitcoin amount let large_reasonable = Amount::from_sats(1_000_000).unwrap(); // 1M sats = 0.01 BTC let uri_result = wallet.get_single_use_receive_uri(Some(large_reasonable)).await; @@ -1744,13 +2032,14 @@ fn test_extreme_amount_handling() { } // On-chain address depends on threshold, not msat precision }) + .await; } -#[test] -fn test_wallet_configuration_validation() { - let TestParams { wallet, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_wallet_configuration_validation() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); - rt.block_on(async move { // Test 1: Verify wallet is using expected network // This is more of a sanity check since we can't easily test invalid networks // without creating new wallets @@ -1809,13 +2098,15 @@ fn test_wallet_configuration_validation() { "Address inclusion should be consistent" ); }) + .await; } -#[test] -fn test_edge_case_payment_instruction_parsing() { - let TestParams { wallet, third_party, rt, .. } = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_edge_case_payment_instruction_parsing() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let third_party = Arc::clone(¶ms.third_party); - rt.block_on(async move { // Test 1: Empty strings let empty_result = wallet.parse_payment_instructions("").await; assert!( @@ -1863,18 +2154,23 @@ fn test_edge_case_payment_instruction_parsing() { assert!(result.is_ok(), "Failed to parse payment instructions"); } }) + .await; } -#[test] -fn test_lsp_connectivity_fallback() { - let TestParams { wallet, lsp, bitcoind, third_party, rt, .. 
} = build_test_nodes(); +#[tokio::test(flavor = "multi_thread")] +async fn test_lsp_connectivity_fallback() { + test_utils::run_test(|params| async move { + let wallet = Arc::clone(¶ms.wallet); + let lsp = Arc::clone(¶ms.lsp); + let bitcoind = Arc::clone(¶ms.bitcoind); + let electrsd = Arc::clone(¶ms.electrsd); + let third_party = Arc::clone(¶ms.third_party); - rt.block_on(async move { // open a channel with the LSP open_channel_from_lsp(&wallet, Arc::clone(&third_party)).await; // confirm channel - generate_blocks(&bitcoind, 6); + generate_blocks(&bitcoind, &electrsd, 6).await; test_utils::wait_for_condition("wallet sync after channel open", || async { wallet.channels().iter().any(|a| a.confirmations.is_some_and(|c| c > 0) && a.is_usable) && third_party @@ -1940,5 +2236,6 @@ fn test_lsp_connectivity_fallback() { "Small amount should still generate a valid invoice even with LSP offline" ); assert!(uri_small.from_trusted); - }); + }) + .await; } diff --git a/orange-sdk/tests/test_utils.rs b/orange-sdk/tests/test_utils.rs index 7293bf3..e516ee4 100644 --- a/orange-sdk/tests/test_utils.rs +++ b/orange-sdk/tests/test_utils.rs @@ -9,11 +9,14 @@ use cdk::types::FeeReserve; use cdk_ldk_node::{BitcoinRpcConfig, GossipSource}; use corepc_node::client::bitcoin::Network; use corepc_node::{Conf, Node as Bitcoind, get_available_port}; +use electrsd::ElectrsD; +use electrsd::electrum_client::ElectrumApi; use ldk_node::bitcoin::hashes::Hash; use ldk_node::lightning::ln::msgs::SocketAddress; use ldk_node::liquidity::LSPS2ServiceConfig; use ldk_node::payment::PaymentStatus; use ldk_node::{Node, bitcoin}; +use orange_sdk::bitcoin::Txid; #[cfg(not(feature = "_cashu-tests"))] use orange_sdk::trusted_wallet::dummy::DummyTrustedWalletExtraConfig; use orange_sdk::{ @@ -28,7 +31,6 @@ use std::net::SocketAddr; use std::str::FromStr; use std::sync::Arc; use std::time::Duration; -use tokio::runtime::Runtime; use uuid::Uuid; /// Waits for an async condition to be true, polling at a specified 
interval until a timeout. @@ -62,17 +64,26 @@ pub async fn wait_next_event(wallet: &orange_sdk::Wallet) -> orange_sdk::Event { event } -fn create_bitcoind(uuid: Uuid) -> Bitcoind { +async fn create_bitcoind(uuid: Uuid) -> (Arc, Arc) { let mut conf = Conf::default(); - conf.args.push("-txindex"); - conf.args.push("-rpcworkqueue=100"); + // conf.args.push("-txindex"); + conf.args.push("-rest"); + conf.args.push("-rpcworkqueue=200"); + conf.args.push("-fallbackfee=0.00002"); conf.staticdir = Some(temp_dir().join(format!("orange-test-{uuid}/bitcoind"))); let bitcoind = Bitcoind::with_conf(corepc_node::downloaded_exe_path().unwrap(), &conf) - .expect(&format!("Failed to start bitcoind for test {uuid}")); + .unwrap_or_else(|_| panic!("Failed to start bitcoind for test {uuid}")); - // Wait for bitcoind to be ready before returning + // Wait for bitcoind to be ready before electrsd starts wait_for_bitcoind_ready(&bitcoind); + let mut electrsd_conf = electrsd::Conf::default(); + electrsd_conf.http_enabled = true; + electrsd_conf.network = "regtest"; + let electrsd = + ElectrsD::with_conf(electrsd::downloaded_exe_path().unwrap(), &bitcoind, &electrsd_conf) + .unwrap_or_else(|_| panic!("Failed to start electrsd for test {uuid}")); + // mine 101 blocks to get some spendable funds, split it up into multiple calls // to avoid potentially hitting RPC timeouts on slower CI systems let address = bitcoind.client.new_address().unwrap(); @@ -80,7 +91,9 @@ fn create_bitcoind(uuid: Uuid) -> Bitcoind { let _block_hashes = bitcoind.client.generate_to_address(1, &address).unwrap(); } - bitcoind + wait_for_block(&electrsd.client, 101).await; + + (Arc::new(bitcoind), Arc::new(electrsd)) } fn wait_for_bitcoind_ready(bitcoind: &Bitcoind) { @@ -104,15 +117,21 @@ fn wait_for_bitcoind_ready(bitcoind: &Bitcoind) { } } -pub fn generate_blocks(bitcoind: &Bitcoind, num: usize) { +pub async fn generate_blocks(bitcoind: &Bitcoind, electrs: &ElectrsD, num: usize) { + let blockchain_info = + 
bitcoind.client.get_blockchain_info().expect("failed to get blockchain info"); + let cur_height = blockchain_info.blocks; + let address = bitcoind.client.new_address().unwrap(); let _block_hashes = bitcoind .client .generate_to_address(num, &address) - .expect(&format!("failed to generate {num} blocks")); + .unwrap_or_else(|_| panic!("failed to generate {num} blocks")); + + wait_for_block(&electrs.client, cur_height as usize + num).await; } -fn create_lsp(rt: Arc, uuid: Uuid, bitcoind: &Bitcoind) -> Arc { +fn create_lsp(uuid: Uuid, bitcoind: &Bitcoind) -> Arc { let mut builder = ldk_node::Builder::new(); builder.set_network(Network::Regtest); let mut seed: [u8; 64] = [0; 64]; @@ -143,6 +162,7 @@ fn create_lsp(rt: Arc, uuid: Uuid, bitcoind: &Bitcoind) -> Arc { min_channel_lifetime: 10_000, min_channel_opening_fee_msat: 0, max_client_to_self_delay: 1024, + client_trusts_lsp: true, }; builder.set_liquidity_provider_lsps2(lsps2_service_config); @@ -153,10 +173,10 @@ fn create_lsp(rt: Arc, uuid: Uuid, bitcoind: &Bitcoind) -> Arc { let ldk_node = Arc::new(builder.build().unwrap()); - ldk_node.start_with_runtime(Arc::clone(&rt)).unwrap(); + ldk_node.start().unwrap(); let events_ref = Arc::clone(&ldk_node); - rt.spawn(async move { + tokio::spawn(async move { loop { let event = events_ref.next_event_async().await; println!("LSP: {event:?}"); @@ -169,7 +189,7 @@ fn create_lsp(rt: Arc, uuid: Uuid, bitcoind: &Bitcoind) -> Arc { ldk_node } -fn create_third_party(rt: Arc, uuid: Uuid, bitcoind: &Bitcoind) -> Arc { +fn create_third_party(uuid: Uuid, bitcoind: &Bitcoind) -> Arc { let mut builder = ldk_node::Builder::new(); builder.set_network(Network::Regtest); let mut seed: [u8; 64] = [0; 64]; @@ -197,10 +217,10 @@ fn create_third_party(rt: Arc, uuid: Uuid, bitcoind: &Bitcoind) -> Arc< let ldk_node = Arc::new(builder.build().unwrap()); - ldk_node.start_with_runtime(Arc::clone(&rt)).unwrap(); + ldk_node.start().unwrap(); let events_ref = Arc::clone(&ldk_node); - rt.spawn(async move 
{ + tokio::spawn(async move { loop { let event = events_ref.next_event_async().await; println!("3rd party: {event:?}"); @@ -213,60 +233,88 @@ fn create_third_party(rt: Arc, uuid: Uuid, bitcoind: &Bitcoind) -> Arc< ldk_node } -fn fund_node(node: &Node, bitcoind: &Bitcoind) { +async fn fund_node(node: &Node, bitcoind: &Bitcoind, electrsd: &ElectrsD) { let addr = node.onchain_payment().new_address().unwrap(); - bitcoind.client.send_to_address(&addr, bitcoin::Amount::from_btc(1.0).unwrap()).unwrap(); - generate_blocks(bitcoind, 6); + let res = + bitcoind.client.send_to_address(&addr, bitcoin::Amount::from_btc(1.0).unwrap()).unwrap(); + wait_for_tx(&electrsd.client, res.txid().unwrap()).await; + generate_blocks(bitcoind, electrsd, 6).await; } +#[derive(Clone)] pub struct TestParams { - pub wallet: orange_sdk::Wallet, + pub wallet: Arc, pub lsp: Arc, pub third_party: Arc, pub bitcoind: Arc, - pub rt: Arc, + pub electrsd: Arc, #[cfg(feature = "_cashu-tests")] pub _mint: Arc, } -pub fn build_test_nodes() -> TestParams { - let test_id = Uuid::now_v7(); - let bitcoind = Arc::new(create_bitcoind(test_id)); +impl TestParams { + async fn stop(&self) { + self.wallet.stop().await; + + #[cfg(feature = "_cashu-tests")] + let _ = self._mint.stop().await; + + let _ = self.lsp.stop(); + let _ = self.third_party.stop(); + } +} + +/// Runs a test with automatically managed TestParams lifecycle. +/// The test closure receives TestParams and must return it when done. +/// Cleanup happens automatically after the test completes. 
+pub async fn run_test(test: F) +where + F: FnOnce(TestParams) -> Fut, + Fut: Future, +{ + let params = build_test_nodes().await; + + println!("=== test start ==="); + + // Run the test and get params back + test(params.clone()).await; + + // Always clean up + params.stop().await; +} - let rt = Arc::new(tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap()); +async fn build_test_nodes() -> TestParams { + let test_id = Uuid::now_v7(); + let (bitcoind, electrsd) = create_bitcoind(test_id).await; - let lsp = create_lsp(Arc::clone(&rt), test_id, &bitcoind); - fund_node(&lsp, &bitcoind); - let third_party = create_third_party(Arc::clone(&rt), test_id, &bitcoind); + let lsp = create_lsp(test_id, &bitcoind); + fund_node(&lsp, &bitcoind, &electrsd).await; + let third_party = create_third_party(test_id, &bitcoind); let start_bal = third_party.list_balances().total_onchain_balance_sats; - fund_node(&third_party, &bitcoind); + fund_node(&third_party, &bitcoind, &electrsd).await; // wait for node to sync (needs blocking wait as we are not in async context here) let third = Arc::clone(&third_party); - rt.block_on(async move { - wait_for_condition("third_party node sync after funding", || { - let res = third.list_balances().total_onchain_balance_sats > start_bal; - async move { res } - }) - .await; - }); + wait_for_condition("third_party node sync after funding", || { + let res = third.list_balances().total_onchain_balance_sats > start_bal; + async move { res } + }) + .await; let lsp_listen = lsp.listening_addresses().unwrap().first().unwrap().clone(); // open a channel from payer to LSP third_party.open_channel(lsp.node_id(), lsp_listen.clone(), 10_000_000, None, None).unwrap(); wait_for_tx_broadcast(&bitcoind); - generate_blocks(&bitcoind, 6); + generate_blocks(&bitcoind, &electrsd, 6).await; // wait for channel ready (needs blocking wait as we are not in async context here) let third_party_clone = Arc::clone(&third_party); - rt.block_on(async move { - 
wait_for_condition("channel to become usable", || { - let res = third_party_clone.list_channels().first().is_some_and(|c| c.is_usable); - async move { res } - }) - .await; - }); + wait_for_condition("channel to become usable", || { + let res = third_party_clone.list_channels().first().is_some_and(|c| c.is_usable); + async move { res } + }) + .await; // make sure it actually became usable assert!(third_party.list_channels().first().unwrap().is_usable); @@ -276,19 +324,19 @@ pub fn build_test_nodes() -> TestParams { rand::thread_rng().fill_bytes(&mut seed); #[cfg(not(feature = "_cashu-tests"))] - let wallet: orange_sdk::Wallet = { + let wallet = { let dummy_wallet_config = DummyTrustedWalletExtraConfig { uuid: test_id, lsp: Arc::clone(&lsp), bitcoind: Arc::clone(&bitcoind), - rt: Arc::clone(&rt), }; let tmp = temp_dir().join(format!("orange-test-{test_id}/ldk-node")); - let cookie = bitcoind.params.get_cookie_values().unwrap().unwrap(); - let rt_clone = Arc::clone(&rt); - let bitcoind_port = bitcoind.params.rpc_socket.port(); + // take esplora url and just get the port, as we know it's running on localhost + let base_url = electrsd.esplora_url.as_ref().unwrap(); + let port = base_url.split(':').next_back().unwrap(); + let esplora_url = format!("http://localhost:{port}"); let wallet_config = WalletConfig { storage_config: StorageConfig::LocalSQLite(tmp.to_str().unwrap().to_string()), @@ -296,18 +344,14 @@ pub fn build_test_nodes() -> TestParams { scorer_url: None, rgs_url: None, tunables: Tunables::default(), - chain_source: ChainSource::BitcoindRPC { - host: "127.0.0.1".to_string(), - port: bitcoind_port, - user: cookie.user, - password: cookie.password, - }, + chain_source: ChainSource::Esplora { url: esplora_url, username: None, password: None }, lsp: (lsp_listen, lsp_node_id, None), network: Network::Regtest, seed: Seed::Seed64(seed), extra_config: ExtraConfig::Dummy(dummy_wallet_config), }; - rt.block_on(async move { Wallet::new_with_runtime(rt_clone, 
wallet_config).await.unwrap() }) + + Arc::new(Wallet::new(wallet_config).await.unwrap()) }; #[cfg(feature = "_cashu-tests")] @@ -332,7 +376,7 @@ pub fn build_test_nodes() -> TestParams { tmp.to_str().unwrap().to_string(), FeeReserve { min_fee_reserve: Default::default(), percent_fee_reserve: 0.0 }, vec![cdk_addr.into()], - Some(rt.clone()), + None, ) .unwrap(); let cdk = Arc::new(cdk); @@ -345,7 +389,7 @@ pub fn build_test_nodes() -> TestParams { let bitcoind_clone = Arc::clone(&bitcoind); let lsp_listen_clone = lsp_listen.clone(); - let mint = rt.block_on(async move { + let mint = { // build mint let mem_db = Arc::new(cdk_sqlite::mint::memory::empty().await.unwrap()); let mut mint_seed: [u8; 64] = [0; 64]; @@ -394,7 +438,7 @@ pub fn build_test_nodes() -> TestParams { .client .send_to_address(&addr, bitcoin::Amount::from_btc(1.0).unwrap()) .unwrap(); - generate_blocks(&bitcoind_clone, 6); + generate_blocks(&bitcoind_clone, &electrsd, 6).await; // wait for cdk node to sync wait_for_condition("cdk node sync after funding", || { @@ -403,11 +447,18 @@ pub fn build_test_nodes() -> TestParams { }) .await; + let lsp_listen = lsp_listen_clone.to_string(); cdk.node() - .open_channel(lsp_node_id, lsp_listen_clone, 10_000_000, Some(5_000_000_000), None) + .open_channel( + lsp_node_id, + lsp_listen.parse().unwrap(), + 10_000_000, + Some(5_000_000_000), + None, + ) .unwrap(); wait_for_tx_broadcast(&bitcoind_clone); - generate_blocks(&bitcoind_clone, 6); + generate_blocks(&bitcoind_clone, &electrsd, 6).await; // wait for sync/channel ready wait_for_condition("cdk channel to become usable", || { @@ -417,9 +468,12 @@ pub fn build_test_nodes() -> TestParams { .await; mint - }); + }; - let rt_clone = Arc::clone(&rt); + // take esplora url and just get the port, as we know it's running on localhost + let base_url = electrsd.esplora_url.as_ref().unwrap(); + let port = base_url.split(':').next_back().unwrap(); + let esplora_url = format!("http://localhost:{port}"); let tmp = 
temp_dir().join(format!("orange-test-{test_id}/wallet")); let wallet_config = WalletConfig { @@ -427,12 +481,7 @@ pub fn build_test_nodes() -> TestParams { logger_type: LoggerType::LogFacade, scorer_url: None, tunables: Tunables::default(), - chain_source: ChainSource::BitcoindRPC { - host: "127.0.0.1".to_string(), - port: bitcoind_port, - user: cookie.user, - password: cookie.password, - }, + chain_source: ChainSource::Esplora { url: esplora_url, username: None, password: None }, lsp: (lsp_listen, lsp_node_id, None), rgs_url: None, network: Network::Regtest, @@ -442,16 +491,13 @@ pub fn build_test_nodes() -> TestParams { unit: orange_sdk::CurrencyUnit::Sat, }), }; - let wallet = - rt.block_on( - async move { Wallet::new_with_runtime(rt_clone, wallet_config).await.unwrap() }, - ); + let wallet = Arc::new(Wallet::new(wallet_config).await.unwrap()); - return TestParams { wallet, lsp, third_party, bitcoind, rt, _mint: mint }; + return TestParams { wallet, lsp, third_party, bitcoind, electrsd, _mint: mint }; }; #[cfg(not(feature = "_cashu-tests"))] - TestParams { wallet, lsp, third_party, bitcoind, rt } + TestParams { wallet, lsp, third_party, bitcoind, electrsd } } pub async fn open_channel_from_lsp(wallet: &orange_sdk::Wallet, payer: Arc) -> Amount { @@ -478,10 +524,10 @@ pub async fn open_channel_from_lsp(wallet: &orange_sdk::Wallet, payer: Arc }) .await; - let event = wait_next_event(&wallet).await; + let event = wait_next_event(wallet).await; assert!(matches!(event, orange_sdk::Event::ChannelOpened { .. })); - let event = wait_next_event(&wallet).await; + let event = wait_next_event(wallet).await; match event { orange_sdk::Event::PaymentReceived { payment_hash, amount_msat, lsp_fee_msats, .. 
} => { assert!(lsp_fee_msats.is_some()); // we expect a fee to be paid for opening a channel @@ -506,3 +552,59 @@ fn wait_for_tx_broadcast(bitcoind: &Bitcoind) { std::thread::sleep(Duration::from_millis(250)); } } + +pub(crate) async fn wait_for_block(electrs: &E, min_height: usize) { + let mut header = match electrs.block_headers_subscribe() { + Ok(header) => header, + Err(_) => { + // While subscribing should succeed the first time around, we ran into some cases where + // it didn't. Since we can't proceed without subscribing, we try again after a delay + // and panic if it still fails. + tokio::time::sleep(Duration::from_secs(3)).await; + electrs.block_headers_subscribe().expect("failed to subscribe to block headers") + }, + }; + loop { + if header.height >= min_height { + break; + } + header = exponential_backoff_poll(|| { + electrs.ping().expect("failed to ping electrs"); + electrs.block_headers_pop().expect("failed to pop block header") + }) + .await; + } +} + +pub(crate) async fn wait_for_tx(electrs: &E, txid: Txid) { + if electrs.transaction_get(&txid).is_ok() { + return; + } + + exponential_backoff_poll(|| { + electrs.ping().unwrap(); + electrs.transaction_get(&txid).ok() + }) + .await; +} + +pub(crate) async fn exponential_backoff_poll(mut poll: F) -> T +where + F: FnMut() -> Option, +{ + let mut delay = Duration::from_millis(64); + let mut tries = 0; + loop { + match poll() { + Some(data) => break data, + None if delay.as_millis() < 512 => { + delay = delay.mul_f32(2.0); + }, + + None => {}, + } + assert!(tries < 20, "Reached max tries."); + tries += 1; + tokio::time::sleep(delay).await; + } +}