From 5ca3321dbd1be88d3cdb8239fdf660b1eab83564 Mon Sep 17 00:00:00 2001 From: ppb2020 Date: Wed, 1 Jan 2025 15:05:13 -0500 Subject: [PATCH 1/2] Add support for specifying IP version --- .../java/com/trilead/ssh2/Connection.java | 73 ++++++++++++++++--- src/main/java/com/trilead/ssh2/IpVersion.java | 12 +++ .../ssh2/transport/TransportManager.java | 43 +++++++++-- 3 files changed, 111 insertions(+), 17 deletions(-) create mode 100644 src/main/java/com/trilead/ssh2/IpVersion.java diff --git a/src/main/java/com/trilead/ssh2/Connection.java b/src/main/java/com/trilead/ssh2/Connection.java index 9e0c8718..1cacafcd 100644 --- a/src/main/java/com/trilead/ssh2/Connection.java +++ b/src/main/java/com/trilead/ssh2/Connection.java @@ -564,29 +564,75 @@ private void close(Throwable t, boolean hard) /** * Same as - * @return see comments for the - * {@link #connect(ServerHostKeyVerifier, int, int) connect(ServerHostKeyVerifier, int, int)} - * method. - * @throws IOException on error - */ + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(null, 0, 0, IpVersion.IPV4_AND_IPV6)}. + * + * @return see comments for the + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)} + * method. + * @throws IOException on error + */ public synchronized ConnectionInfo connect() throws IOException { - return connect(null, 0, 0); + return connect(null, 0, 0, IpVersion.IPV4_AND_IPV6); } /** * Same as - * {@link #connect(ServerHostKeyVerifier, int, int) connect(verifier, 0, 0)}. + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(null, 0, 0, ipVersion)}. * * @return see comments for the - * {@link #connect(ServerHostKeyVerifier, int, int) connect(ServerHostKeyVerifier, int, int)} + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)} + * method. + * @throws IOException + */ + public synchronized ConnectionInfo connect(IpVersion ipVersion) throws IOException + { + return connect(null, 0, 0, ipVersion); + } + + + /** + * Same as + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(verifier, 0, 0, IpVersion.IPV4_AND_IPV6)}. + * + * @return see comments for the + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)} * method. * @param verifier the verifier * @throws IOException on error */ public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier) throws IOException { - return connect(verifier, 0, 0); + return connect(verifier, 0, 0, IpVersion.IPV4_AND_IPV6); + } + + /** + * Same as + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(verifier, 0, 0, ipVersion)}. + * + * @return see comments for the + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)} + * method. + * @throws IOException + */ + public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier, IpVersion ipVersion) throws IOException + { + return connect(verifier, 0, 0, ipVersion); + } + + /** + * Same as + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(verifier, connectTimeout, kexTimeout, IpVersion.IPV4_AND_IPV6)}. + * + * @return see comments for the + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)} + * method. + * @throws IOException + */ + public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier, int connectTimeout, int kexTimeout) + throws IOException + { + return connect(verifier, connectTimeout, kexTimeout, IpVersion.IPV4_AND_IPV6); } /** @@ -648,6 +694,11 @@ public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier) throw * but it will only have an effect after the * verifier returns. * + * @param ipVersion + * Specify whether the connection should be restricted to one of + * IPv4 or IPv6, with a default of allowing both. See + * {@link IpVersion}. + * * @return A {@link ConnectionInfo} object containing the details of the * established connection. * @@ -671,7 +722,7 @@ public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier) throw * proxy is buggy and does not return a proper HTTP response, * then a normal IOException is thrown instead. */ - public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier, int connectTimeout, int kexTimeout) + public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier, int connectTimeout, int kexTimeout, IpVersion ipVersion) throws IOException { final class TimeoutState @@ -746,7 +797,7 @@ public void run() try { - tm.initialize(cryptoWishList, verifier, dhgexpara, connectTimeout, getOrCreateSecureRND(), proxyData); + tm.initialize(cryptoWishList, verifier, dhgexpara, connectTimeout, ipVersion, getOrCreateSecureRND(), proxyData); } catch (SocketTimeoutException se) { diff --git a/src/main/java/com/trilead/ssh2/IpVersion.java b/src/main/java/com/trilead/ssh2/IpVersion.java new file mode 100644 index 00000000..b390ff85 --- /dev/null +++ b/src/main/java/com/trilead/ssh2/IpVersion.java @@ -0,0 +1,12 @@ + +package com.trilead.ssh2; + +/** + * Allow the caller to restrict the IP version of the connection to + * be established. + */ +public enum IpVersion { + IPV4_AND_IPV6, + IPV4_ONLY, + IPV6_ONLY +} diff --git a/src/main/java/com/trilead/ssh2/transport/TransportManager.java b/src/main/java/com/trilead/ssh2/transport/TransportManager.java index 5e90cfc4..82e075e4 100644 --- a/src/main/java/com/trilead/ssh2/transport/TransportManager.java +++ b/src/main/java/com/trilead/ssh2/transport/TransportManager.java @@ -4,6 +4,7 @@ import com.trilead.ssh2.ExtensionInfo; import com.trilead.ssh2.packets.PacketExtInfo; import java.io.IOException; +import java.net.Inet6Address; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.Socket; @@ -13,6 +14,7 @@ import com.trilead.ssh2.ConnectionInfo; import com.trilead.ssh2.ConnectionMonitor; import com.trilead.ssh2.DHGexParameters; +import com.trilead.ssh2.IpVersion; import com.trilead.ssh2.ProxyData; import com.trilead.ssh2.ServerHostKeyVerifier; import com.trilead.ssh2.compression.ICompressor; @@ -280,30 +282,59 @@ public void close(Throwable cause, boolean useDisconnectPacket) } } - private void establishConnection(ProxyData proxyData, int connectTimeout) throws IOException + private static InetAddress getIPv4Address(InetAddress[] addresses) { + for (InetAddress address : addresses) { + if (! (address instanceof Inet6Address)) { + return address; + } + } + return null; + } + private static Inet6Address getIPv6Address(InetAddress[] addresses) { + for (InetAddress address : addresses) { + if (address instanceof Inet6Address) { + return (Inet6Address) address; + } + } + return null; + } + + private void establishConnection(ProxyData proxyData, int connectTimeout, IpVersion ipVersion) throws IOException { if (proxyData == null) - sock = connectDirect(hostname, port, connectTimeout); + sock = connectDirect(hostname, port, connectTimeout, ipVersion); else sock = proxyData.openConnection(hostname, port, connectTimeout); } - private static Socket connectDirect(String hostname, int port, int connectTimeout) + private static Socket connectDirect(String hostname, int port, int connectTimeout, IpVersion ipVersion) throws IOException { Socket sock = new Socket(); - InetAddress addr = InetAddress.getByName(hostname); + InetAddress addr; + if (ipVersion == IpVersion.IPV4_ONLY) + { + addr = getIPv4Address(InetAddress.getAllByName(hostname)); + } + else if (ipVersion == IpVersion.IPV6_ONLY) + { + addr = getIPv6Address(InetAddress.getAllByName(hostname)); + } + else // Assume (ipVersion == IpVersion.IPV4_AND_IPV6), the default. + { + addr = InetAddress.getByName(hostname); + } sock.connect(new InetSocketAddress(addr, port), connectTimeout); sock.setSoTimeout(0); return sock; } public void initialize(CryptoWishList cwl, ServerHostKeyVerifier verifier, DHGexParameters dhgex, - int connectTimeout, SecureRandom rnd, ProxyData proxyData) throws IOException + int connectTimeout, IpVersion ipVersion, SecureRandom rnd, ProxyData proxyData) throws IOException { /* First, establish the TCP connection to the SSH-2 server */ - establishConnection(proxyData, connectTimeout); + establishConnection(proxyData, connectTimeout, ipVersion); /* Parse the server line and say hello - important: this information is later needed for the * key exchange (to stop man-in-the-middle attacks) - that is why we wrap it into an object From 2c508a21390f1588497c9cb285dc8957c7dbe785 Mon Sep 17 00:00:00 2001 From: Kenny Root Date: Sat, 24 Jan 2026 16:16:16 -0800 Subject: [PATCH 2/2] Try to connect to IPv4 and IPv6 simultaneously Implement Happy Eyeballs (RFC 8305) connection algorithm to deal better with IPv4 and IPv6 simultaneously. You can still select IPv4 or IPv6 only via API and connecting directly to a single IPv4 or IPv6 address bypasses this algorithm. Additionally providing a proxy will bypass Happy Eyeballs as well. --- .../transport/HappyEyeballsConnector.java | 296 ++++++++++++++++++ .../ssh2/transport/TransportManager.java | 37 +-- .../transport/HappyEyeballsConnectorTest.java | 272 ++++++++++++++++ 3 files changed, 569 insertions(+), 36 deletions(-) create mode 100644 src/main/java/com/trilead/ssh2/transport/HappyEyeballsConnector.java create mode 100644 src/test/java/com/trilead/ssh2/transport/HappyEyeballsConnectorTest.java diff --git a/src/main/java/com/trilead/ssh2/transport/HappyEyeballsConnector.java b/src/main/java/com/trilead/ssh2/transport/HappyEyeballsConnector.java new file mode 100644 index 00000000..5961d8e9 --- /dev/null +++ b/src/main/java/com/trilead/ssh2/transport/HappyEyeballsConnector.java @@ -0,0 +1,296 @@ +package com.trilead.ssh2.transport; + +import java.io.IOException; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.trilead.ssh2.IpVersion; + +/** + * Implements Happy Eyeballs (RFC 8305) connection algorithm. + * + * This algorithm improves connection times when both IPv4 and IPv6 + * addresses are available by: + *
    + *
  1. Resolving all addresses (A and AAAA records)
  2. + *
  3. Starting IPv6 connection attempts first
  4. + *
  5. After a short delay, starting IPv4 attempts in parallel
  6. + *
  7. Using whichever connection succeeds first
  8. + *
  9. Cancelling/closing remaining attempts
  10. + *
+ */ +class HappyEyeballsConnector { + + static final int CONNECTION_ATTEMPT_DELAY_MS = 250; + + private static final ExecutorService EXECUTOR = Executors.newCachedThreadPool(r -> { + Thread t = new Thread(r, "HappyEyeballs-Connector"); + t.setDaemon(true); + return t; + }); + + @FunctionalInterface + interface DnsResolver { + InetAddress[] resolve(String hostname) throws UnknownHostException; + } + + @FunctionalInterface + interface SocketFactory { + Socket createSocket(); + } + + private final DnsResolver dnsResolver; + private final SocketFactory socketFactory; + private final int connectionAttemptDelayMs; + + HappyEyeballsConnector() { + this(InetAddress::getAllByName, Socket::new, CONNECTION_ATTEMPT_DELAY_MS); + } + + HappyEyeballsConnector(DnsResolver dnsResolver, SocketFactory socketFactory, int connectionAttemptDelayMs) { + this.dnsResolver = dnsResolver; + this.socketFactory = socketFactory; + this.connectionAttemptDelayMs = connectionAttemptDelayMs; + } + + /** + * Connect to the given hostname and port using Happy Eyeballs algorithm. + * + * @param hostname the hostname to connect to + * @param port the port to connect to + * @param connectTimeout the connection timeout in milliseconds (0 for infinite) + * @param ipVersion controls which IP versions to use + * @return a connected socket + * @throws IOException if connection fails + */ + Socket connect(String hostname, int port, int connectTimeout, IpVersion ipVersion) + throws IOException { + + List addresses = resolveAddresses(hostname, ipVersion); + + if (addresses.isEmpty()) { + throw new UnknownHostException("No addresses found for: " + hostname); + } + + if (addresses.size() == 1) { + return connectSimple(addresses.get(0), port, connectTimeout); + } + + List sortedAddresses = interleaveByFamily(addresses); + return connectWithRacing(sortedAddresses, port, connectTimeout); + } + + private List resolveAddresses(String hostname, IpVersion ipVersion) + throws UnknownHostException { + InetAddress[] allAddresses = dnsResolver.resolve(hostname); + return filterByIpVersion(allAddresses, ipVersion); + } + + static List filterByIpVersion(InetAddress[] addresses, IpVersion ipVersion) { + List filtered = new ArrayList<>(); + + for (InetAddress addr : addresses) { + boolean isIPv6 = addr instanceof Inet6Address; + + if (ipVersion == IpVersion.IPV4_ONLY && isIPv6) { + continue; + } + if (ipVersion == IpVersion.IPV6_ONLY && !isIPv6) { + continue; + } + filtered.add(addr); + } + + return filtered; + } + + static List interleaveByFamily(List addresses) { + List ipv6 = new ArrayList<>(); + List ipv4 = new ArrayList<>(); + + for (InetAddress addr : addresses) { + if (addr instanceof Inet6Address) { + ipv6.add(addr); + } else { + ipv4.add(addr); + } + } + + List result = new ArrayList<>(); + int maxSize = Math.max(ipv6.size(), ipv4.size()); + + for (int i = 0; i < maxSize; i++) { + if (i < ipv6.size()) + result.add(ipv6.get(i)); + if (i < ipv4.size()) + result.add(ipv4.get(i)); + } + + return result; + } + + private Socket connectWithRacing(List addresses, int port, int connectTimeout) + throws IOException { + + AtomicBoolean winnerFound = new AtomicBoolean(false); + List> futures = new ArrayList<>(); + List socketsToClose = new ArrayList<>(); + + try { + for (int i = 0; i < addresses.size(); i++) { + InetAddress address = addresses.get(i); + int delay = i * connectionAttemptDelayMs; + + Callable task = createConnectionTask( + address, port, connectTimeout, delay, winnerFound, socketsToClose); + futures.add(EXECUTOR.submit(task)); + } + + return waitForFirstSuccess(futures); + + } finally { + for (Future future : futures) { + future.cancel(true); + } + + synchronized (socketsToClose) { + for (Socket socket : socketsToClose) { + closeQuietly(socket); + } + } + } + } + + private Callable createConnectionTask( + InetAddress address, + int port, + int connectTimeout, + int delay, + AtomicBoolean winnerFound, + List socketsToClose) { + + return () -> { + if (delay > 0) { + Thread.sleep(delay); + } + + if (winnerFound.get()) { + throw new CancellationException("Another connection won"); + } + + Socket socket = socketFactory.createSocket(); + synchronized (socketsToClose) { + socketsToClose.add(socket); + } + + try { + socket.connect(new InetSocketAddress(address, port), connectTimeout); + socket.setSoTimeout(0); + + if (winnerFound.compareAndSet(false, true)) { + synchronized (socketsToClose) { + socketsToClose.remove(socket); + } + return socket; + } else { + closeQuietly(socket); + throw new CancellationException("Another connection won"); + } + } catch (IOException e) { + closeQuietly(socket); + synchronized (socketsToClose) { + socketsToClose.remove(socket); + } + throw e; + } + }; + } + + private Socket waitForFirstSuccess(List> futures) throws IOException { + IOException lastException = null; + List> pending = new ArrayList<>(futures); + + while (!pending.isEmpty()) { + Future completed = null; + + for (Future future : pending) { + if (future.isDone()) { + completed = future; + break; + } + } + + if (completed == null) { + try { + Thread.sleep(10); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Connection interrupted", e); + } + continue; + } + + pending.remove(completed); + + try { + Socket socket = completed.get(); + if (socket != null && socket.isConnected()) { + return socket; + } + } catch (CancellationException e) { + // Task was cancelled, try next + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof IOException) { + lastException = (IOException) cause; + } else if (cause instanceof InterruptedException) { + Thread.currentThread().interrupt(); + throw new IOException("Connection interrupted", cause); + } else { + lastException = new IOException("Connection failed", cause); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Connection interrupted", e); + } + } + + if (lastException != null) { + throw lastException; + } + throw new IOException("All connection attempts failed"); + } + + private Socket connectSimple(InetAddress address, int port, int timeout) throws IOException { + Socket socket = socketFactory.createSocket(); + try { + socket.connect(new InetSocketAddress(address, port), timeout); + socket.setSoTimeout(0); + return socket; + } catch (IOException e) { + closeQuietly(socket); + throw e; + } + } + + private static void closeQuietly(Socket socket) { + if (socket != null) { + try { + socket.close(); + } catch (IOException ignored) { + } + } + } +} diff --git a/src/main/java/com/trilead/ssh2/transport/TransportManager.java b/src/main/java/com/trilead/ssh2/transport/TransportManager.java index 82e075e4..f7376b58 100644 --- a/src/main/java/com/trilead/ssh2/transport/TransportManager.java +++ b/src/main/java/com/trilead/ssh2/transport/TransportManager.java @@ -4,9 +4,6 @@ import com.trilead.ssh2.ExtensionInfo; import com.trilead.ssh2.packets.PacketExtInfo; import java.io.IOException; -import java.net.Inet6Address; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.Socket; import java.security.SecureRandom; import java.util.Vector; @@ -282,22 +279,6 @@ public void close(Throwable cause, boolean useDisconnectPacket) } } - private static InetAddress getIPv4Address(InetAddress[] addresses) { - for (InetAddress address : addresses) { - if (! (address instanceof Inet6Address)) { - return address; - } - } - return null; - } - private static Inet6Address getIPv6Address(InetAddress[] addresses) { - for (InetAddress address : addresses) { - if (address instanceof Inet6Address) { - return (Inet6Address) address; - } - } - return null; - } private void establishConnection(ProxyData proxyData, int connectTimeout, IpVersion ipVersion) throws IOException { @@ -310,23 +291,7 @@ private void establishConnection(ProxyData proxyData, int connectTimeout, IpVers private static Socket connectDirect(String hostname, int port, int connectTimeout, IpVersion ipVersion) throws IOException { - Socket sock = new Socket(); - InetAddress addr; - if (ipVersion == IpVersion.IPV4_ONLY) - { - addr = getIPv4Address(InetAddress.getAllByName(hostname)); - } - else if (ipVersion == IpVersion.IPV6_ONLY) - { - addr = getIPv6Address(InetAddress.getAllByName(hostname)); - } - else // Assume (ipVersion == IpVersion.IPV4_AND_IPV6), the default. - { - addr = InetAddress.getByName(hostname); - } - sock.connect(new InetSocketAddress(addr, port), connectTimeout); - sock.setSoTimeout(0); - return sock; + return new HappyEyeballsConnector().connect(hostname, port, connectTimeout, ipVersion); } public void initialize(CryptoWishList cwl, ServerHostKeyVerifier verifier, DHGexParameters dhgex, diff --git a/src/test/java/com/trilead/ssh2/transport/HappyEyeballsConnectorTest.java b/src/test/java/com/trilead/ssh2/transport/HappyEyeballsConnectorTest.java new file mode 100644 index 00000000..e1807cd8 --- /dev/null +++ b/src/test/java/com/trilead/ssh2/transport/HappyEyeballsConnectorTest.java @@ -0,0 +1,272 @@ +package com.trilead.ssh2.transport; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.Test; + +import com.trilead.ssh2.IpVersion; + +class HappyEyeballsConnectorTest { + + @Test + void filterByIpVersion_withIpv4Only_returnsOnlyIpv4() throws Exception { + InetAddress ipv4 = Inet4Address.getByName("127.0.0.1"); + InetAddress ipv6 = Inet6Address.getByName("::1"); + InetAddress[] addresses = { ipv4, ipv6 }; + + List result = HappyEyeballsConnector.filterByIpVersion(addresses, IpVersion.IPV4_ONLY); + + assertEquals(1, result.size()); + assertFalse(result.get(0) instanceof Inet6Address); + } + + @Test + void filterByIpVersion_withIpv6Only_returnsOnlyIpv6() throws Exception { + InetAddress ipv4 = Inet4Address.getByName("127.0.0.1"); + InetAddress ipv6 = Inet6Address.getByName("::1"); + InetAddress[] addresses = { ipv4, ipv6 }; + + List result = HappyEyeballsConnector.filterByIpVersion(addresses, IpVersion.IPV6_ONLY); + + assertEquals(1, result.size()); + assertTrue(result.get(0) instanceof Inet6Address); + } + + @Test + void filterByIpVersion_withBoth_returnsAll() throws Exception { + InetAddress ipv4 = Inet4Address.getByName("127.0.0.1"); + InetAddress ipv6 = Inet6Address.getByName("::1"); + InetAddress[] addresses = { ipv4, ipv6 }; + + List result = HappyEyeballsConnector.filterByIpVersion(addresses, IpVersion.IPV4_AND_IPV6); + + assertEquals(2, result.size()); + } + + @Test + void filterByIpVersion_withIpv4Only_andNoIpv4Addresses_returnsEmpty() throws Exception { + InetAddress ipv6 = Inet6Address.getByName("::1"); + InetAddress[] addresses = { ipv6 }; + + List result = HappyEyeballsConnector.filterByIpVersion(addresses, IpVersion.IPV4_ONLY); + + assertTrue(result.isEmpty()); + } + + @Test + void interleaveByFamily_withMixedAddresses_interleavesCorrectly() throws Exception { + InetAddress ipv4a = Inet4Address.getByName("127.0.0.1"); + InetAddress ipv4b = Inet4Address.getByName("127.0.0.2"); + InetAddress ipv6a = Inet6Address.getByName("::1"); + InetAddress ipv6b = Inet6Address.getByName("::2"); + + List input = Arrays.asList(ipv4a, ipv4b, ipv6a, ipv6b); + List result = HappyEyeballsConnector.interleaveByFamily(input); + + assertEquals(4, result.size()); + // Should be: ipv6a, ipv4a, ipv6b, ipv4b (IPv6 first per RFC 8305) + assertTrue(result.get(0) instanceof Inet6Address); + assertFalse(result.get(1) instanceof Inet6Address); + assertTrue(result.get(2) instanceof Inet6Address); + assertFalse(result.get(3) instanceof Inet6Address); + } + + @Test + void interleaveByFamily_withOnlyIpv4_returnsAllIpv4() throws Exception { + InetAddress ipv4a = Inet4Address.getByName("127.0.0.1"); + InetAddress ipv4b = Inet4Address.getByName("127.0.0.2"); + + List input = Arrays.asList(ipv4a, ipv4b); + List result = HappyEyeballsConnector.interleaveByFamily(input); + + assertEquals(2, result.size()); + assertFalse(result.get(0) instanceof Inet6Address); + assertFalse(result.get(1) instanceof Inet6Address); + } + + @Test + void interleaveByFamily_withUnequalCounts_handlesCorrectly() throws Exception { + InetAddress ipv4 = Inet4Address.getByName("127.0.0.1"); + InetAddress ipv6a = Inet6Address.getByName("::1"); + InetAddress ipv6b = Inet6Address.getByName("::2"); + InetAddress ipv6c = Inet6Address.getByName("::3"); + + List input = Arrays.asList(ipv4, ipv6a, ipv6b, ipv6c); + List result = HappyEyeballsConnector.interleaveByFamily(input); + + assertEquals(4, result.size()); + // Should be: ipv6a, ipv4, ipv6b, ipv6c + assertTrue(result.get(0) instanceof Inet6Address); + assertFalse(result.get(1) instanceof Inet6Address); + assertTrue(result.get(2) instanceof Inet6Address); + assertTrue(result.get(3) instanceof Inet6Address); + } + + @Test + void connect_withSingleAddress_connectsDirectly() throws Exception { + try (ServerSocket server = new ServerSocket(0)) { + int port = server.getLocalPort(); + InetAddress addr = InetAddress.getByName("127.0.0.1"); + + HappyEyeballsConnector connector = new HappyEyeballsConnector( + hostname -> new InetAddress[] { addr }, + Socket::new, + 250); + + Thread acceptThread = new Thread(() -> { + try { + server.accept().close(); + } catch (IOException ignored) { + } + }); + acceptThread.start(); + + Socket socket = connector.connect("test.example.com", port, 5000, IpVersion.IPV4_AND_IPV6); + + assertTrue(socket.isConnected()); + socket.close(); + acceptThread.join(1000); + } + } + + @Test + void connect_withNoAddresses_throwsUnknownHostException() { + HappyEyeballsConnector connector = new HappyEyeballsConnector( + hostname -> new InetAddress[] {}, + Socket::new, + 250); + + assertThrows(UnknownHostException.class, + () -> connector.connect("test.example.com", 22, 5000, IpVersion.IPV4_AND_IPV6)); + } + + @Test + void connect_withDnsFailure_throwsUnknownHostException() { + HappyEyeballsConnector connector = new HappyEyeballsConnector( + hostname -> { + throw new UnknownHostException("DNS failed"); + }, + Socket::new, + 250); + + assertThrows(UnknownHostException.class, + () -> connector.connect("test.example.com", 22, 5000, IpVersion.IPV4_AND_IPV6)); + } + + @Test + void connect_withIpv4Only_filtersToIpv4() throws Exception { + try (ServerSocket server = new ServerSocket(0)) { + int port = server.getLocalPort(); + InetAddress ipv4 = InetAddress.getByName("127.0.0.1"); + InetAddress ipv6 = Inet6Address.getByName("::1"); + + AtomicInteger socketCount = new AtomicInteger(0); + + HappyEyeballsConnector connector = new HappyEyeballsConnector( + hostname -> new InetAddress[] { ipv6, ipv4 }, + () -> { + socketCount.incrementAndGet(); + return new Socket(); + }, + 250); + + Thread acceptThread = new Thread(() -> { + try { + server.accept().close(); + } catch (IOException ignored) { + } + }); + acceptThread.start(); + + Socket socket = connector.connect("test.example.com", port, 5000, IpVersion.IPV4_ONLY); + + assertTrue(socket.isConnected()); + assertEquals(1, socketCount.get(), "Should only create one socket for single filtered address"); + socket.close(); + acceptThread.join(1000); + } + } + + @Test + void connect_withMultipleAddresses_racesConnections() throws Exception { + try (ServerSocket server = new ServerSocket(0)) { + int port = server.getLocalPort(); + InetAddress addr1 = InetAddress.getByName("127.0.0.1"); + InetAddress addr2 = InetAddress.getByName("127.0.0.1"); + + List createdSockets = new ArrayList<>(); + + HappyEyeballsConnector connector = new HappyEyeballsConnector( + hostname -> new InetAddress[] { addr1, addr2 }, + () -> { + Socket s = new Socket(); + synchronized (createdSockets) { + createdSockets.add(s); + } + return s; + }, + 50 // Short delay for faster test + ); + + Thread acceptThread = new Thread(() -> { + try { + server.accept().close(); + } catch (IOException ignored) { + } + }); + acceptThread.start(); + + Socket socket = connector.connect("test.example.com", port, 5000, IpVersion.IPV4_AND_IPV6); + + assertTrue(socket.isConnected()); + socket.close(); + acceptThread.join(1000); + + // Give time for cleanup + Thread.sleep(100); + + // All non-winning sockets should be closed + synchronized (createdSockets) { + for (Socket s : createdSockets) { + if (s != socket) { + assertTrue(s.isClosed(), "Losing sockets should be closed"); + } + } + } + } + } + + @Test + void connect_withAllConnectionsFailing_throwsIOException() throws Exception { + InetAddress addr1 = InetAddress.getByName("127.0.0.1"); + InetAddress addr2 = InetAddress.getByName("127.0.0.1"); + + HappyEyeballsConnector connector = new HappyEyeballsConnector( + hostname -> new InetAddress[] { addr1, addr2 }, + Socket::new, + 10); + + // Connect to a port that's not listening + assertThrows(IOException.class, () -> connector.connect("test.example.com", 1, 100, IpVersion.IPV4_AND_IPV6)); + } + + @Test + void connectionAttemptDelay_isConfigurable() { + assertEquals(250, HappyEyeballsConnector.CONNECTION_ATTEMPT_DELAY_MS); + } +}