From c85f9c1286ad94809b9639dc49f2532bfc2872e1 Mon Sep 17 00:00:00 2001
From: Andrzej Haczewski
Date: Thu, 18 Jun 2026 17:55:00 +0200
Subject: [PATCH] lore-client: fix QUIC fallback across resolved addresses

Signed-off-by: Andrzej Haczewski
---
NOTE(review): this patch text was recovered from a whitespace-collapsed copy.
Line breaks and indentation were re-derived from the hunk line counts and
rustfmt conventions. Generic parameter lists that had been stripped from the
copy (`Vec<SocketAddr>`, `<F, Fut, T>`, `Option<T>`,
`Future<Output = Option<T>>`, `Option<quinn::Connection>`,
`None::<SocketAddr>`) are inferred from usage -- confirm against the original
commit before applying.

 lore-transport/src/quic/client.rs | 362 +++++++++++++++++++++++++++---
 1 file changed, 333 insertions(+), 29 deletions(-)

diff --git a/lore-transport/src/quic/client.rs b/lore-transport/src/quic/client.rs
index 3a74bc4..36f8679 100644
--- a/lore-transport/src/quic/client.rs
+++ b/lore-transport/src/quic/client.rs
@@ -16,7 +16,9 @@ use std::time::Instant;
 
 use async_trait::async_trait;
 use bytes::Bytes;
+use futures::StreamExt;
 use futures::TryFutureExt;
+use futures::stream::FuturesUnordered;
 use lore_base::error::Disconnected;
 use lore_base::error::NotAuthorized;
 use lore_base::lore_debug;
@@ -70,6 +72,8 @@ pub struct EndpointConfig {
 
 const IDLE_TIMEOUT_MS: u32 = 30000;
 const KEEP_ALIVE_MS: u64 = 500;
+const HAPPY_EYEBALLS_DELAY_MS: u64 = 250;
+const HAPPY_EYEBALLS_MAX_IN_FLIGHT: usize = 10;
 pub const DEFAULT_EXPECTED_RTT_MS: u64 = 100;
 
 #[derive(Clone, Debug)]
@@ -683,12 +687,14 @@ pub async fn connect(
     let remote_url = config.remote_url.as_str();
     let url = Url::parse(remote_url).internal_with(|| format!("remote {remote_url} is invalid"))?;
     let host = url.host_str().unwrap_or_default().to_string();
-    let remote_addrs = (
+    let remote_addrs: Vec<_> = (
         strip_ipv6_brackets(host.as_str()),
         url.port().unwrap_or(config.default_port),
     )
         .to_socket_addrs()
-        .internal_with(|| format!("remote {remote_url} is invalid"))?;
+        .internal_with(|| format!("remote {remote_url} is invalid"))?
+        .collect();
+    let remote_addrs = interleave_socket_addrs(remote_addrs);
     let server_name = config.sni_override.as_deref().unwrap_or(host.as_str());
     let validate_certificate = url.scheme().ends_with("s");
 
@@ -746,40 +752,151 @@ pub async fn connect(
 
     client_config.transport_config(Arc::new(transport_config));
 
-    for remote_addr in remote_addrs {
-        lore_debug!("QUIC connecting to {host} at {remote_addr}");
-        let bind = if remote_addr.is_ipv6() {
-            SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
+    let connection = connect_happy_eyeballs(
+        remote_addrs,
+        Duration::from_millis(HAPPY_EYEBALLS_DELAY_MS),
+        |remote_addr| {
+            connect_to_addr(
+                client_config.clone(),
+                host.clone(),
+                remote_addr,
+                server_name.to_string(),
+            )
+        },
+    )
+    .await;
+    if let Some(connection) = connection {
+        return Ok(connection);
+    }
+
+    // Silent propagation of connection errors
+    lore_debug!("QUIC connect failed {remote_url}");
+    Err(ProtocolError::internal(format!("connect: {remote_url}")))
+}
+
+fn interleave_socket_addrs(remote_addrs: Vec<SocketAddr>) -> Vec<SocketAddr> {
+    let Some(first) = remote_addrs.first() else {
+        return remote_addrs;
+    };
+    let prefer_ipv6 = first.is_ipv6();
+    let (preferred, fallback): (Vec<_>, Vec<_>) = remote_addrs
+        .into_iter()
+        .partition(|addr| addr.is_ipv6() == prefer_ipv6);
+    let mut preferred = preferred.into_iter();
+    let mut fallback = fallback.into_iter();
+    let mut interleaved = Vec::with_capacity(preferred.len() + fallback.len());
+
+    loop {
+        if let Some(addr) = preferred.next() {
+            interleaved.push(addr);
         } else {
-            SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
-        };
-        match quinn::Endpoint::client(bind) {
-            Ok(mut endpoint) => {
-                endpoint.set_default_client_config(client_config.clone());
-                match endpoint.connect(remote_addr, server_name) {
-                    Ok(connecting) => match connecting.await {
-                        Ok(connection) => {
-                            lore_debug!("Success QUIC connecting to {remote_addr}");
-                            return Ok(connection);
-                        }
-                        Err(err) => {
-                            lore_debug!("Failed QUIC connecting to {remote_addr}: {err}");
-                        }
-                    },
-                    Err(err) => {
-                        lore_debug!("Failed QUIC connect to {remote_addr}: {err}");
-                    }
+            interleaved.extend(fallback);
+            break;
+        }
+        if let Some(addr) = fallback.next() {
+            interleaved.push(addr);
+        } else {
+            interleaved.extend(preferred);
+            break;
+        }
+    }
+
+    interleaved
+}
+
+async fn connect_happy_eyeballs<F, Fut, T>(
+    remote_addrs: Vec<SocketAddr>,
+    attempt_delay: Duration,
+    mut connect: F,
+) -> Option<T>
+where
+    F: FnMut(SocketAddr) -> Fut,
+    Fut: Future<Output = Option<T>>,
+{
+    let mut remote_addrs = remote_addrs.into_iter();
+    let mut attempts = FuturesUnordered::new();
+    attempts.push(connect(remote_addrs.next()?));
+
+    let mut next_addr = remote_addrs.next();
+    let delay = tokio::time::sleep(attempt_delay);
+    tokio::pin!(delay);
+
+    loop {
+        if next_addr.is_none() {
+            while let Some(result) = attempts.next().await {
+                if result.is_some() {
+                    return result;
                 }
             }
-            Err(err) => {
-                lore_debug!("QUIC failed binding socket to {bind} for {remote_addr}: {err}");
+            return None;
+        }
+
+        tokio::select! {
+            result = attempts.next(), if !attempts.is_empty() => {
+                if let Some(Some(connection)) = result {
+                    // Dropping `attempts` cancels the losing Quinn handshakes because each
+                    // production future owns its `Connecting` and `Endpoint`.
+                    return Some(connection);
+                }
+                if attempts.is_empty() {
+                    let Some(addr) = next_addr.take() else {
+                        continue;
+                    };
+                    attempts.push(connect(addr));
+                    next_addr = remote_addrs.next();
+                    delay.as_mut().reset(tokio::time::Instant::now() + attempt_delay);
+                }
+            }
+            _ = &mut delay, if attempts.len() < HAPPY_EYEBALLS_MAX_IN_FLIGHT => {
+                let Some(addr) = next_addr.take() else {
+                    continue;
+                };
+                attempts.push(connect(addr));
+                next_addr = remote_addrs.next();
+                delay.as_mut().reset(tokio::time::Instant::now() + attempt_delay);
             }
         }
     }
+}
 
-    // Silent propagation of connection errors
-    lore_debug!("QUIC connect failed {remote_url}");
-    Err(ProtocolError::internal(format!("connect: {remote_url}")))
+async fn connect_to_addr(
+    client_config: quinn::ClientConfig,
+    host: String,
+    remote_addr: SocketAddr,
+    server_name: String,
+) -> Option<quinn::Connection> {
+    lore_debug!("QUIC connecting to {host} at {remote_addr}");
+    let bind = if remote_addr.is_ipv6() {
+        SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0)
+    } else {
+        SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0)
+    };
+    let mut endpoint = match quinn::Endpoint::client(bind) {
+        Ok(endpoint) => endpoint,
+        Err(err) => {
+            lore_debug!("QUIC failed binding socket to {bind} for {remote_addr}: {err}");
+            return None;
+        }
+    };
+    endpoint.set_default_client_config(client_config);
+
+    let connecting = match endpoint.connect(remote_addr, server_name.as_str()) {
+        Ok(connecting) => connecting,
+        Err(err) => {
+            lore_debug!("Failed QUIC connect to {remote_addr}: {err}");
+            return None;
+        }
+    };
+    match connecting.await {
+        Ok(connection) => {
+            lore_debug!("Success QUIC connecting to {remote_addr}");
+            Some(connection)
+        }
+        Err(err) => {
+            lore_debug!("Failed QUIC connecting to {remote_addr}: {err}");
+            None
+        }
+    }
 }
 
 pub async fn reconnect(
@@ -1152,3 +1269,190 @@ pub async fn send_command(
         QuicClientError::Read
     })?
 }
+
+#[cfg(test)]
+mod tests {
+    use std::collections::HashMap;
+    use std::net::SocketAddr;
+    use std::sync::Arc;
+    use std::time::Duration;
+
+    use parking_lot::Mutex;
+    use tokio::sync::Semaphore;
+    use tokio::sync::mpsc;
+
+    use super::HAPPY_EYEBALLS_MAX_IN_FLIGHT;
+    use super::connect_happy_eyeballs;
+    use super::interleave_socket_addrs;
+
+    fn ipv6_addr() -> SocketAddr {
+        "[::1]:41337".parse().unwrap()
+    }
+
+    fn ipv4_addr() -> SocketAddr {
+        "127.0.0.1:41337".parse().unwrap()
+    }
+
+    #[test]
+    fn happy_eyeballs_interleaves_ipv6_first_addresses() {
+        let ipv6_second = "[::2]:41337".parse().unwrap();
+        let ipv6_third = "[::3]:41337".parse().unwrap();
+        let ipv4_second = "127.0.0.2:41337".parse().unwrap();
+
+        assert_eq!(
+            interleave_socket_addrs(vec![
+                ipv6_addr(),
+                ipv6_second,
+                ipv6_third,
+                ipv4_addr(),
+                ipv4_second,
+            ]),
+            vec![
+                ipv6_addr(),
+                ipv4_addr(),
+                ipv6_second,
+                ipv4_second,
+                ipv6_third,
+            ]
+        );
+    }
+
+    #[test]
+    fn happy_eyeballs_interleaves_ipv4_first_addresses() {
+        let ipv4_second = "127.0.0.2:41337".parse().unwrap();
+        let ipv6_second = "[::2]:41337".parse().unwrap();
+
+        assert_eq!(
+            interleave_socket_addrs(vec![ipv4_addr(), ipv4_second, ipv6_addr(), ipv6_second,]),
+            vec![ipv4_addr(), ipv6_addr(), ipv4_second, ipv6_second]
+        );
+    }
+
+    #[tokio::test]
+    async fn happy_eyeballs_starts_fallback_while_first_attempt_is_stalled() {
+        let attempts = Arc::new(Mutex::new(Vec::new()));
+        let attempt_log = attempts.clone();
+
+        let result = tokio::time::timeout(
+            Duration::from_secs(2),
+            connect_happy_eyeballs(
+                vec![ipv6_addr(), ipv4_addr()],
+                Duration::from_millis(10),
+                move |addr| {
+                    let attempt_log = attempt_log.clone();
+                    async move {
+                        attempt_log.lock().push(addr);
+                        if addr.is_ipv6() {
+                            std::future::pending().await
+                        } else {
+                            Some(addr)
+                        }
+                    }
+                },
+            ),
+        )
+        .await
+        .expect("fallback should not wait for the stalled first attempt");
+
+        assert_eq!(result, Some(ipv4_addr()));
+        assert_eq!(*attempts.lock(), vec![ipv6_addr(), ipv4_addr()]);
+    }
+
+    #[tokio::test]
+    async fn happy_eyeballs_advances_immediately_after_failure() {
+        let started = std::time::Instant::now();
+
+        let result = connect_happy_eyeballs(
+            vec![ipv6_addr(), ipv4_addr()],
+            Duration::from_secs(1),
+            |addr| async move { if addr.is_ipv6() { None } else { Some(addr) } },
+        )
+        .await;
+
+        assert_eq!(result, Some(ipv4_addr()));
+        assert!(started.elapsed() < Duration::from_millis(750));
+    }
+
+    #[tokio::test]
+    async fn happy_eyeballs_does_not_start_fallback_after_first_success() {
+        let attempts = Arc::new(Mutex::new(HashMap::new()));
+        let attempt_counts = attempts.clone();
+
+        let result = connect_happy_eyeballs(
+            vec![ipv6_addr(), ipv4_addr()],
+            Duration::from_millis(10),
+            move |addr| {
+                let attempt_counts = attempt_counts.clone();
+                async move {
+                    *attempt_counts.lock().entry(addr).or_insert(0) += 1;
+                    Some(addr)
+                }
+            },
+        )
+        .await;
+
+        assert_eq!(result, Some(ipv6_addr()));
+        assert_eq!(attempts.lock().get(&ipv6_addr()), Some(&1));
+        assert_eq!(attempts.lock().get(&ipv4_addr()), None);
+    }
+
+    #[tokio::test]
+    async fn happy_eyeballs_returns_none_when_all_attempts_fail() {
+        let result = connect_happy_eyeballs(
+            vec![ipv6_addr(), ipv4_addr()],
+            Duration::from_millis(10),
+            |_| async { None::<SocketAddr> },
+        )
+        .await;
+
+        assert_eq!(result, None);
+    }
+
+    #[tokio::test]
+    async fn happy_eyeballs_bounds_in_flight_attempts() {
+        let remote_addrs: Vec<_> = (1..=HAPPY_EYEBALLS_MAX_IN_FLIGHT + 1)
+            .map(|port| SocketAddr::new(ipv6_addr().ip(), port as u16))
+            .collect();
+        let release = Arc::new(Semaphore::new(0));
+        let attempt_release = release.clone();
+        let (started_tx, mut started_rx) = mpsc::unbounded_channel();
+
+        let task = lore_base::lore_spawn!(connect_happy_eyeballs(
+            remote_addrs.clone(),
+            Duration::from_millis(1),
+            move |addr| {
+                started_tx.send(addr).unwrap();
+                let attempt_release = attempt_release.clone();
+                async move {
+                    attempt_release.acquire().await.unwrap().forget();
+                    None::<SocketAddr>
+                }
+            },
+        ));
+
+        for expected in remote_addrs.iter().take(HAPPY_EYEBALLS_MAX_IN_FLIGHT) {
+            assert_eq!(
+                tokio::time::timeout(Duration::from_secs(2), started_rx.recv())
+                    .await
+                    .expect("attempt should start"),
+                Some(*expected)
+            );
+        }
+        assert!(
+            tokio::time::timeout(Duration::from_millis(50), started_rx.recv())
+                .await
+                .is_err(),
+            "attempts above the in-flight limit should remain queued"
+        );
+
+        release.add_permits(1);
+        assert_eq!(
+            tokio::time::timeout(Duration::from_secs(2), started_rx.recv())
+                .await
+                .expect("queued attempt should start when a slot opens"),
+            Some(remote_addrs[HAPPY_EYEBALLS_MAX_IN_FLIGHT])
+        );
+
+        task.abort();
+    }
+}