diff --git a/src/main/java/com/trilead/ssh2/Connection.java b/src/main/java/com/trilead/ssh2/Connection.java
index 9e0c8718..1cacafcd 100644
--- a/src/main/java/com/trilead/ssh2/Connection.java
+++ b/src/main/java/com/trilead/ssh2/Connection.java
@@ -564,29 +564,75 @@ private void close(Throwable t, boolean hard)
/**
* Same as
- * @return see comments for the
- * {@link #connect(ServerHostKeyVerifier, int, int) connect(ServerHostKeyVerifier, int, int)}
- * method.
- * @throws IOException on error
- */
+ * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(null, 0, 0, IpVersion.IPV4_AND_IPV6)}.
+ *
+ * @return see comments for the
+ * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)}
+ * method.
+ * @throws IOException on error
+ */
public synchronized ConnectionInfo connect() throws IOException
{
- return connect(null, 0, 0);
+ return connect(null, 0, 0, IpVersion.IPV4_AND_IPV6);
}
/**
* Same as
- * {@link #connect(ServerHostKeyVerifier, int, int) connect(verifier, 0, 0)}.
+ * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(null, 0, 0, ipVersion)}.
*
* @return see comments for the
- * {@link #connect(ServerHostKeyVerifier, int, int) connect(ServerHostKeyVerifier, int, int)}
+ * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)}
+ * method.
+ * @throws IOException
+ */
+ public synchronized ConnectionInfo connect(IpVersion ipVersion) throws IOException
+ {
+ return connect(null, 0, 0, ipVersion);
+ }
+
+
+ /**
+ * Same as
+ * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(verifier, 0, 0, IpVersion.IPV4_AND_IPV6)}.
+ *
+ * @return see comments for the
+ * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)}
* method.
* @param verifier the verifier
* @throws IOException on error
*/
public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier) throws IOException
{
- return connect(verifier, 0, 0);
+ return connect(verifier, 0, 0, IpVersion.IPV4_AND_IPV6);
+ }
+
+ /**
+ * Same as
+ * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(verifier, 0, 0, ipVersion)}.
+ *
+ * @return see comments for the
+ * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)}
+ * method.
+ * @throws IOException
+ */
+ public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier, IpVersion ipVersion) throws IOException
+ {
+ return connect(verifier, 0, 0, ipVersion);
+ }
+
+ /**
+ * Same as
+ * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(verifier, connectTimeout, kexTimeout, IpVersion.IPV4_AND_IPV6)}.
+ *
+ * @return see comments for the
+ * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)}
+ * method.
+ * @throws IOException
+ */
+ public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier, int connectTimeout, int kexTimeout)
+ throws IOException
+ {
+ return connect(verifier, connectTimeout, kexTimeout, IpVersion.IPV4_AND_IPV6);
}
/**
@@ -648,6 +694,11 @@ public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier) throw
* but it will only have an effect after the
* verifier returns.
*
+ * @param ipVersion
+ * Specify whether the connection should be restricted to one of
+ * IPv4 or IPv6, with a default of allowing both. See
+ * {@link IpVersion}.
+ *
* @return A {@link ConnectionInfo} object containing the details of the
* established connection.
*
@@ -671,7 +722,7 @@ public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier) throw
* proxy is buggy and does not return a proper HTTP response,
* then a normal IOException is thrown instead.
*/
- public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier, int connectTimeout, int kexTimeout)
+ public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier, int connectTimeout, int kexTimeout, IpVersion ipVersion)
throws IOException
{
final class TimeoutState
@@ -746,7 +797,7 @@ public void run()
try
{
- tm.initialize(cryptoWishList, verifier, dhgexpara, connectTimeout, getOrCreateSecureRND(), proxyData);
+ tm.initialize(cryptoWishList, verifier, dhgexpara, connectTimeout, ipVersion, getOrCreateSecureRND(), proxyData);
}
catch (SocketTimeoutException se)
{
diff --git a/src/main/java/com/trilead/ssh2/IpVersion.java b/src/main/java/com/trilead/ssh2/IpVersion.java
new file mode 100644
index 00000000..b390ff85
--- /dev/null
+++ b/src/main/java/com/trilead/ssh2/IpVersion.java
@@ -0,0 +1,12 @@
+
+package com.trilead.ssh2;
+
+/**
+ * Allow the caller to restrict the IP version of the connection to
+ * be established.
+ */
+public enum IpVersion {
+ IPV4_AND_IPV6,
+ IPV4_ONLY,
+ IPV6_ONLY
+}
diff --git a/src/main/java/com/trilead/ssh2/transport/HappyEyeballsConnector.java b/src/main/java/com/trilead/ssh2/transport/HappyEyeballsConnector.java
new file mode 100644
index 00000000..5961d8e9
--- /dev/null
+++ b/src/main/java/com/trilead/ssh2/transport/HappyEyeballsConnector.java
@@ -0,0 +1,296 @@
+package com.trilead.ssh2.transport;
+
+import java.io.IOException;
+import java.net.Inet6Address;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.Socket;
+import java.net.UnknownHostException;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import com.trilead.ssh2.IpVersion;
+
+/**
+ * Implements Happy Eyeballs (RFC 8305) connection algorithm.
+ *
+ * This algorithm improves connection times when both IPv4 and IPv6
+ * addresses are available by:
+ *
+ * - Resolving all addresses (A and AAAA records)
+ * - Starting IPv6 connection attempts first
+ * - After a short delay, starting IPv4 attempts in parallel
+ * - Using whichever connection succeeds first
+ * - Cancelling/closing remaining attempts
+ *
+ */
+class HappyEyeballsConnector {
+
+ static final int CONNECTION_ATTEMPT_DELAY_MS = 250;
+
+ private static final ExecutorService EXECUTOR = Executors.newCachedThreadPool(r -> {
+ Thread t = new Thread(r, "HappyEyeballs-Connector");
+ t.setDaemon(true);
+ return t;
+ });
+
+ @FunctionalInterface
+ interface DnsResolver {
+ InetAddress[] resolve(String hostname) throws UnknownHostException;
+ }
+
+ @FunctionalInterface
+ interface SocketFactory {
+ Socket createSocket();
+ }
+
+ private final DnsResolver dnsResolver;
+ private final SocketFactory socketFactory;
+ private final int connectionAttemptDelayMs;
+
+ HappyEyeballsConnector() {
+ this(InetAddress::getAllByName, Socket::new, CONNECTION_ATTEMPT_DELAY_MS);
+ }
+
+ HappyEyeballsConnector(DnsResolver dnsResolver, SocketFactory socketFactory, int connectionAttemptDelayMs) {
+ this.dnsResolver = dnsResolver;
+ this.socketFactory = socketFactory;
+ this.connectionAttemptDelayMs = connectionAttemptDelayMs;
+ }
+
+ /**
+ * Connect to the given hostname and port using Happy Eyeballs algorithm.
+ *
+ * @param hostname the hostname to connect to
+ * @param port the port to connect to
+ * @param connectTimeout the connection timeout in milliseconds (0 for infinite)
+ * @param ipVersion controls which IP versions to use
+ * @return a connected socket
+ * @throws IOException if connection fails
+ */
+ Socket connect(String hostname, int port, int connectTimeout, IpVersion ipVersion)
+ throws IOException {
+
+ List addresses = resolveAddresses(hostname, ipVersion);
+
+ if (addresses.isEmpty()) {
+ throw new UnknownHostException("No addresses found for: " + hostname);
+ }
+
+ if (addresses.size() == 1) {
+ return connectSimple(addresses.get(0), port, connectTimeout);
+ }
+
+ List sortedAddresses = interleaveByFamily(addresses);
+ return connectWithRacing(sortedAddresses, port, connectTimeout);
+ }
+
+ private List resolveAddresses(String hostname, IpVersion ipVersion)
+ throws UnknownHostException {
+ InetAddress[] allAddresses = dnsResolver.resolve(hostname);
+ return filterByIpVersion(allAddresses, ipVersion);
+ }
+
+ static List filterByIpVersion(InetAddress[] addresses, IpVersion ipVersion) {
+ List filtered = new ArrayList<>();
+
+ for (InetAddress addr : addresses) {
+ boolean isIPv6 = addr instanceof Inet6Address;
+
+ if (ipVersion == IpVersion.IPV4_ONLY && isIPv6) {
+ continue;
+ }
+ if (ipVersion == IpVersion.IPV6_ONLY && !isIPv6) {
+ continue;
+ }
+ filtered.add(addr);
+ }
+
+ return filtered;
+ }
+
+ static List interleaveByFamily(List addresses) {
+ List ipv6 = new ArrayList<>();
+ List ipv4 = new ArrayList<>();
+
+ for (InetAddress addr : addresses) {
+ if (addr instanceof Inet6Address) {
+ ipv6.add(addr);
+ } else {
+ ipv4.add(addr);
+ }
+ }
+
+ List result = new ArrayList<>();
+ int maxSize = Math.max(ipv6.size(), ipv4.size());
+
+ for (int i = 0; i < maxSize; i++) {
+ if (i < ipv6.size())
+ result.add(ipv6.get(i));
+ if (i < ipv4.size())
+ result.add(ipv4.get(i));
+ }
+
+ return result;
+ }
+
+ private Socket connectWithRacing(List addresses, int port, int connectTimeout)
+ throws IOException {
+
+ AtomicBoolean winnerFound = new AtomicBoolean(false);
+ List> futures = new ArrayList<>();
+ List socketsToClose = new ArrayList<>();
+
+ try {
+ for (int i = 0; i < addresses.size(); i++) {
+ InetAddress address = addresses.get(i);
+ int delay = i * connectionAttemptDelayMs;
+
+ Callable task = createConnectionTask(
+ address, port, connectTimeout, delay, winnerFound, socketsToClose);
+ futures.add(EXECUTOR.submit(task));
+ }
+
+ return waitForFirstSuccess(futures);
+
+ } finally {
+ for (Future future : futures) {
+ future.cancel(true);
+ }
+
+ synchronized (socketsToClose) {
+ for (Socket socket : socketsToClose) {
+ closeQuietly(socket);
+ }
+ }
+ }
+ }
+
+ private Callable createConnectionTask(
+ InetAddress address,
+ int port,
+ int connectTimeout,
+ int delay,
+ AtomicBoolean winnerFound,
+ List socketsToClose) {
+
+ return () -> {
+ if (delay > 0) {
+ Thread.sleep(delay);
+ }
+
+ if (winnerFound.get()) {
+ throw new CancellationException("Another connection won");
+ }
+
+ Socket socket = socketFactory.createSocket();
+ synchronized (socketsToClose) {
+ socketsToClose.add(socket);
+ }
+
+ try {
+ socket.connect(new InetSocketAddress(address, port), connectTimeout);
+ socket.setSoTimeout(0);
+
+ if (winnerFound.compareAndSet(false, true)) {
+ synchronized (socketsToClose) {
+ socketsToClose.remove(socket);
+ }
+ return socket;
+ } else {
+ closeQuietly(socket);
+ throw new CancellationException("Another connection won");
+ }
+ } catch (IOException e) {
+ closeQuietly(socket);
+ synchronized (socketsToClose) {
+ socketsToClose.remove(socket);
+ }
+ throw e;
+ }
+ };
+ }
+
+ private Socket waitForFirstSuccess(List> futures) throws IOException {
+ IOException lastException = null;
+ List> pending = new ArrayList<>(futures);
+
+ while (!pending.isEmpty()) {
+ Future completed = null;
+
+ for (Future future : pending) {
+ if (future.isDone()) {
+ completed = future;
+ break;
+ }
+ }
+
+ if (completed == null) {
+ try {
+ Thread.sleep(10);
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new IOException("Connection interrupted", e);
+ }
+ continue;
+ }
+
+ pending.remove(completed);
+
+ try {
+ Socket socket = completed.get();
+ if (socket != null && socket.isConnected()) {
+ return socket;
+ }
+ } catch (CancellationException e) {
+ // Task was cancelled, try next
+ } catch (ExecutionException e) {
+ Throwable cause = e.getCause();
+ if (cause instanceof IOException) {
+ lastException = (IOException) cause;
+ } else if (cause instanceof InterruptedException) {
+ Thread.currentThread().interrupt();
+ throw new IOException("Connection interrupted", cause);
+ } else {
+ lastException = new IOException("Connection failed", cause);
+ }
+ } catch (InterruptedException e) {
+ Thread.currentThread().interrupt();
+ throw new IOException("Connection interrupted", e);
+ }
+ }
+
+ if (lastException != null) {
+ throw lastException;
+ }
+ throw new IOException("All connection attempts failed");
+ }
+
+ private Socket connectSimple(InetAddress address, int port, int timeout) throws IOException {
+ Socket socket = socketFactory.createSocket();
+ try {
+ socket.connect(new InetSocketAddress(address, port), timeout);
+ socket.setSoTimeout(0);
+ return socket;
+ } catch (IOException e) {
+ closeQuietly(socket);
+ throw e;
+ }
+ }
+
+ private static void closeQuietly(Socket socket) {
+ if (socket != null) {
+ try {
+ socket.close();
+ } catch (IOException ignored) {
+ }
+ }
+ }
+}
diff --git a/src/main/java/com/trilead/ssh2/transport/TransportManager.java b/src/main/java/com/trilead/ssh2/transport/TransportManager.java
index 5e90cfc4..f7376b58 100644
--- a/src/main/java/com/trilead/ssh2/transport/TransportManager.java
+++ b/src/main/java/com/trilead/ssh2/transport/TransportManager.java
@@ -4,8 +4,6 @@
import com.trilead.ssh2.ExtensionInfo;
import com.trilead.ssh2.packets.PacketExtInfo;
import java.io.IOException;
-import java.net.InetAddress;
-import java.net.InetSocketAddress;
import java.net.Socket;
import java.security.SecureRandom;
import java.util.Vector;
@@ -13,6 +11,7 @@
import com.trilead.ssh2.ConnectionInfo;
import com.trilead.ssh2.ConnectionMonitor;
import com.trilead.ssh2.DHGexParameters;
+import com.trilead.ssh2.IpVersion;
import com.trilead.ssh2.ProxyData;
import com.trilead.ssh2.ServerHostKeyVerifier;
import com.trilead.ssh2.compression.ICompressor;
@@ -280,30 +279,27 @@ public void close(Throwable cause, boolean useDisconnectPacket)
}
}
- private void establishConnection(ProxyData proxyData, int connectTimeout) throws IOException
+
+ private void establishConnection(ProxyData proxyData, int connectTimeout, IpVersion ipVersion) throws IOException
{
if (proxyData == null)
- sock = connectDirect(hostname, port, connectTimeout);
+ sock = connectDirect(hostname, port, connectTimeout, ipVersion);
else
sock = proxyData.openConnection(hostname, port, connectTimeout);
}
- private static Socket connectDirect(String hostname, int port, int connectTimeout)
+ private static Socket connectDirect(String hostname, int port, int connectTimeout, IpVersion ipVersion)
throws IOException
{
- Socket sock = new Socket();
- InetAddress addr = InetAddress.getByName(hostname);
- sock.connect(new InetSocketAddress(addr, port), connectTimeout);
- sock.setSoTimeout(0);
- return sock;
+ return new HappyEyeballsConnector().connect(hostname, port, connectTimeout, ipVersion);
}
public void initialize(CryptoWishList cwl, ServerHostKeyVerifier verifier, DHGexParameters dhgex,
- int connectTimeout, SecureRandom rnd, ProxyData proxyData) throws IOException
+ int connectTimeout, IpVersion ipVersion, SecureRandom rnd, ProxyData proxyData) throws IOException
{
/* First, establish the TCP connection to the SSH-2 server */
- establishConnection(proxyData, connectTimeout);
+ establishConnection(proxyData, connectTimeout, ipVersion);
/* Parse the server line and say hello - important: this information is later needed for the
* key exchange (to stop man-in-the-middle attacks) - that is why we wrap it into an object
diff --git a/src/test/java/com/trilead/ssh2/transport/HappyEyeballsConnectorTest.java b/src/test/java/com/trilead/ssh2/transport/HappyEyeballsConnectorTest.java
new file mode 100644
index 00000000..e1807cd8
--- /dev/null
+++ b/src/test/java/com/trilead/ssh2/transport/HappyEyeballsConnectorTest.java
@@ -0,0 +1,272 @@
+package com.trilead.ssh2.transport;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.io.IOException;
+import java.net.Inet4Address;
+import java.net.Inet6Address;
+import java.net.InetAddress;
+import java.net.ServerSocket;
+import java.net.Socket;
+import java.net.UnknownHostException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.junit.jupiter.api.Test;
+
+import com.trilead.ssh2.IpVersion;
+
+class HappyEyeballsConnectorTest {
+
+ @Test
+ void filterByIpVersion_withIpv4Only_returnsOnlyIpv4() throws Exception {
+ InetAddress ipv4 = Inet4Address.getByName("127.0.0.1");
+ InetAddress ipv6 = Inet6Address.getByName("::1");
+ InetAddress[] addresses = { ipv4, ipv6 };
+
+ List result = HappyEyeballsConnector.filterByIpVersion(addresses, IpVersion.IPV4_ONLY);
+
+ assertEquals(1, result.size());
+ assertFalse(result.get(0) instanceof Inet6Address);
+ }
+
+ @Test
+ void filterByIpVersion_withIpv6Only_returnsOnlyIpv6() throws Exception {
+ InetAddress ipv4 = Inet4Address.getByName("127.0.0.1");
+ InetAddress ipv6 = Inet6Address.getByName("::1");
+ InetAddress[] addresses = { ipv4, ipv6 };
+
+ List result = HappyEyeballsConnector.filterByIpVersion(addresses, IpVersion.IPV6_ONLY);
+
+ assertEquals(1, result.size());
+ assertTrue(result.get(0) instanceof Inet6Address);
+ }
+
+ @Test
+ void filterByIpVersion_withBoth_returnsAll() throws Exception {
+ InetAddress ipv4 = Inet4Address.getByName("127.0.0.1");
+ InetAddress ipv6 = Inet6Address.getByName("::1");
+ InetAddress[] addresses = { ipv4, ipv6 };
+
+ List result = HappyEyeballsConnector.filterByIpVersion(addresses, IpVersion.IPV4_AND_IPV6);
+
+ assertEquals(2, result.size());
+ }
+
+ @Test
+ void filterByIpVersion_withIpv4Only_andNoIpv4Addresses_returnsEmpty() throws Exception {
+ InetAddress ipv6 = Inet6Address.getByName("::1");
+ InetAddress[] addresses = { ipv6 };
+
+ List result = HappyEyeballsConnector.filterByIpVersion(addresses, IpVersion.IPV4_ONLY);
+
+ assertTrue(result.isEmpty());
+ }
+
+ @Test
+ void interleaveByFamily_withMixedAddresses_interleavesCorrectly() throws Exception {
+ InetAddress ipv4a = Inet4Address.getByName("127.0.0.1");
+ InetAddress ipv4b = Inet4Address.getByName("127.0.0.2");
+ InetAddress ipv6a = Inet6Address.getByName("::1");
+ InetAddress ipv6b = Inet6Address.getByName("::2");
+
+ List input = Arrays.asList(ipv4a, ipv4b, ipv6a, ipv6b);
+ List result = HappyEyeballsConnector.interleaveByFamily(input);
+
+ assertEquals(4, result.size());
+ // Should be: ipv6a, ipv4a, ipv6b, ipv4b (IPv6 first per RFC 8305)
+ assertTrue(result.get(0) instanceof Inet6Address);
+ assertFalse(result.get(1) instanceof Inet6Address);
+ assertTrue(result.get(2) instanceof Inet6Address);
+ assertFalse(result.get(3) instanceof Inet6Address);
+ }
+
+ @Test
+ void interleaveByFamily_withOnlyIpv4_returnsAllIpv4() throws Exception {
+ InetAddress ipv4a = Inet4Address.getByName("127.0.0.1");
+ InetAddress ipv4b = Inet4Address.getByName("127.0.0.2");
+
+ List input = Arrays.asList(ipv4a, ipv4b);
+ List result = HappyEyeballsConnector.interleaveByFamily(input);
+
+ assertEquals(2, result.size());
+ assertFalse(result.get(0) instanceof Inet6Address);
+ assertFalse(result.get(1) instanceof Inet6Address);
+ }
+
+ @Test
+ void interleaveByFamily_withUnequalCounts_handlesCorrectly() throws Exception {
+ InetAddress ipv4 = Inet4Address.getByName("127.0.0.1");
+ InetAddress ipv6a = Inet6Address.getByName("::1");
+ InetAddress ipv6b = Inet6Address.getByName("::2");
+ InetAddress ipv6c = Inet6Address.getByName("::3");
+
+ List input = Arrays.asList(ipv4, ipv6a, ipv6b, ipv6c);
+ List result = HappyEyeballsConnector.interleaveByFamily(input);
+
+ assertEquals(4, result.size());
+ // Should be: ipv6a, ipv4, ipv6b, ipv6c
+ assertTrue(result.get(0) instanceof Inet6Address);
+ assertFalse(result.get(1) instanceof Inet6Address);
+ assertTrue(result.get(2) instanceof Inet6Address);
+ assertTrue(result.get(3) instanceof Inet6Address);
+ }
+
+ @Test
+ void connect_withSingleAddress_connectsDirectly() throws Exception {
+ try (ServerSocket server = new ServerSocket(0)) {
+ int port = server.getLocalPort();
+ InetAddress addr = InetAddress.getByName("127.0.0.1");
+
+ HappyEyeballsConnector connector = new HappyEyeballsConnector(
+ hostname -> new InetAddress[] { addr },
+ Socket::new,
+ 250);
+
+ Thread acceptThread = new Thread(() -> {
+ try {
+ server.accept().close();
+ } catch (IOException ignored) {
+ }
+ });
+ acceptThread.start();
+
+ Socket socket = connector.connect("test.example.com", port, 5000, IpVersion.IPV4_AND_IPV6);
+
+ assertTrue(socket.isConnected());
+ socket.close();
+ acceptThread.join(1000);
+ }
+ }
+
+ @Test
+ void connect_withNoAddresses_throwsUnknownHostException() {
+ HappyEyeballsConnector connector = new HappyEyeballsConnector(
+ hostname -> new InetAddress[] {},
+ Socket::new,
+ 250);
+
+ assertThrows(UnknownHostException.class,
+ () -> connector.connect("test.example.com", 22, 5000, IpVersion.IPV4_AND_IPV6));
+ }
+
+ @Test
+ void connect_withDnsFailure_throwsUnknownHostException() {
+ HappyEyeballsConnector connector = new HappyEyeballsConnector(
+ hostname -> {
+ throw new UnknownHostException("DNS failed");
+ },
+ Socket::new,
+ 250);
+
+ assertThrows(UnknownHostException.class,
+ () -> connector.connect("test.example.com", 22, 5000, IpVersion.IPV4_AND_IPV6));
+ }
+
+ @Test
+ void connect_withIpv4Only_filtersToIpv4() throws Exception {
+ try (ServerSocket server = new ServerSocket(0)) {
+ int port = server.getLocalPort();
+ InetAddress ipv4 = InetAddress.getByName("127.0.0.1");
+ InetAddress ipv6 = Inet6Address.getByName("::1");
+
+ AtomicInteger socketCount = new AtomicInteger(0);
+
+ HappyEyeballsConnector connector = new HappyEyeballsConnector(
+ hostname -> new InetAddress[] { ipv6, ipv4 },
+ () -> {
+ socketCount.incrementAndGet();
+ return new Socket();
+ },
+ 250);
+
+ Thread acceptThread = new Thread(() -> {
+ try {
+ server.accept().close();
+ } catch (IOException ignored) {
+ }
+ });
+ acceptThread.start();
+
+ Socket socket = connector.connect("test.example.com", port, 5000, IpVersion.IPV4_ONLY);
+
+ assertTrue(socket.isConnected());
+ assertEquals(1, socketCount.get(), "Should only create one socket for single filtered address");
+ socket.close();
+ acceptThread.join(1000);
+ }
+ }
+
+ @Test
+ void connect_withMultipleAddresses_racesConnections() throws Exception {
+ try (ServerSocket server = new ServerSocket(0)) {
+ int port = server.getLocalPort();
+ InetAddress addr1 = InetAddress.getByName("127.0.0.1");
+ InetAddress addr2 = InetAddress.getByName("127.0.0.1");
+
+ List createdSockets = new ArrayList<>();
+
+ HappyEyeballsConnector connector = new HappyEyeballsConnector(
+ hostname -> new InetAddress[] { addr1, addr2 },
+ () -> {
+ Socket s = new Socket();
+ synchronized (createdSockets) {
+ createdSockets.add(s);
+ }
+ return s;
+ },
+ 50 // Short delay for faster test
+ );
+
+ Thread acceptThread = new Thread(() -> {
+ try {
+ server.accept().close();
+ } catch (IOException ignored) {
+ }
+ });
+ acceptThread.start();
+
+ Socket socket = connector.connect("test.example.com", port, 5000, IpVersion.IPV4_AND_IPV6);
+
+ assertTrue(socket.isConnected());
+ socket.close();
+ acceptThread.join(1000);
+
+ // Give time for cleanup
+ Thread.sleep(100);
+
+ // All non-winning sockets should be closed
+ synchronized (createdSockets) {
+ for (Socket s : createdSockets) {
+ if (s != socket) {
+ assertTrue(s.isClosed(), "Losing sockets should be closed");
+ }
+ }
+ }
+ }
+ }
+
+ @Test
+ void connect_withAllConnectionsFailing_throwsIOException() throws Exception {
+ InetAddress addr1 = InetAddress.getByName("127.0.0.1");
+ InetAddress addr2 = InetAddress.getByName("127.0.0.1");
+
+ HappyEyeballsConnector connector = new HappyEyeballsConnector(
+ hostname -> new InetAddress[] { addr1, addr2 },
+ Socket::new,
+ 10);
+
+ // Connect to a port that's not listening
+ assertThrows(IOException.class, () -> connector.connect("test.example.com", 1, 100, IpVersion.IPV4_AND_IPV6));
+ }
+
+ @Test
+ void connectionAttemptDelay_isConfigurable() {
+ assertEquals(250, HappyEyeballsConnector.CONNECTION_ATTEMPT_DELAY_MS);
+ }
+}