diff --git a/src/main/java/com/trilead/ssh2/Connection.java b/src/main/java/com/trilead/ssh2/Connection.java index 9e0c8718..1cacafcd 100644 --- a/src/main/java/com/trilead/ssh2/Connection.java +++ b/src/main/java/com/trilead/ssh2/Connection.java @@ -564,29 +564,75 @@ private void close(Throwable t, boolean hard) /** * Same as - * @return see comments for the - * {@link #connect(ServerHostKeyVerifier, int, int) connect(ServerHostKeyVerifier, int, int)} - * method. - * @throws IOException on error - */ + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(null, 0, 0, IpVersion.IPV4_AND_IPV6)}. + * + * @return see comments for the + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)} + * method. + * @throws IOException on error + */ public synchronized ConnectionInfo connect() throws IOException { - return connect(null, 0, 0); + return connect(null, 0, 0, IpVersion.IPV4_AND_IPV6); } /** * Same as - * {@link #connect(ServerHostKeyVerifier, int, int) connect(verifier, 0, 0)}. + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(null, 0, 0, ipVersion)}. * * @return see comments for the - * {@link #connect(ServerHostKeyVerifier, int, int) connect(ServerHostKeyVerifier, int, int)} + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)} + * method. + * @throws IOException + */ + public synchronized ConnectionInfo connect(IpVersion ipVersion) throws IOException + { + return connect(null, 0, 0, ipVersion); + } + + + /** + * Same as + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(verifier, 0, 0, IpVersion.IPV4_AND_IPV6)}. + * + * @return see comments for the + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)} * method. * @param verifier the verifier * @throws IOException on error */ public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier) throws IOException { - return connect(verifier, 0, 0); + return connect(verifier, 0, 0, IpVersion.IPV4_AND_IPV6); + } + + /** + * Same as + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(verifier, 0, 0, ipVersion)}. + * + * @return see comments for the + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)} + * method. + * @throws IOException + */ + public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier, IpVersion ipVersion) throws IOException + { + return connect(verifier, 0, 0, ipVersion); + } + + /** + * Same as + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(verifier, connectTimeout, kexTimeout, IpVersion.IPV4_AND_IPV6)}. + * + * @return see comments for the + * {@link #connect(ServerHostKeyVerifier, int, int, IpVersion) connect(ServerHostKeyVerifier, int, int, IpVersion)} + * method. + * @throws IOException + */ + public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier, int connectTimeout, int kexTimeout) + throws IOException + { + return connect(verifier, connectTimeout, kexTimeout, IpVersion.IPV4_AND_IPV6); } /** @@ -648,6 +694,11 @@ public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier) throw * but it will only have an effect after the * verifier returns. * + * @param ipVersion + * Specify whether the connection should be restricted to one of + * IPv4 or IPv6, with a default of allowing both. See + * {@link IpVersion}. + * * @return A {@link ConnectionInfo} object containing the details of the * established connection. * @@ -671,7 +722,7 @@ public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier) throw * proxy is buggy and does not return a proper HTTP response, * then a normal IOException is thrown instead. */ - public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier, int connectTimeout, int kexTimeout) + public synchronized ConnectionInfo connect(ServerHostKeyVerifier verifier, int connectTimeout, int kexTimeout, IpVersion ipVersion) throws IOException { final class TimeoutState @@ -746,7 +797,7 @@ public void run() try { - tm.initialize(cryptoWishList, verifier, dhgexpara, connectTimeout, getOrCreateSecureRND(), proxyData); + tm.initialize(cryptoWishList, verifier, dhgexpara, connectTimeout, ipVersion, getOrCreateSecureRND(), proxyData); } catch (SocketTimeoutException se) { diff --git a/src/main/java/com/trilead/ssh2/IpVersion.java b/src/main/java/com/trilead/ssh2/IpVersion.java new file mode 100644 index 00000000..b390ff85 --- /dev/null +++ b/src/main/java/com/trilead/ssh2/IpVersion.java @@ -0,0 +1,12 @@ + +package com.trilead.ssh2; + +/** + * Allow the caller to restrict the IP version of the connection to + * be established. + */ +public enum IpVersion { + IPV4_AND_IPV6, + IPV4_ONLY, + IPV6_ONLY +} diff --git a/src/main/java/com/trilead/ssh2/transport/HappyEyeballsConnector.java b/src/main/java/com/trilead/ssh2/transport/HappyEyeballsConnector.java new file mode 100644 index 00000000..5961d8e9 --- /dev/null +++ b/src/main/java/com/trilead/ssh2/transport/HappyEyeballsConnector.java @@ -0,0 +1,296 @@ +package com.trilead.ssh2.transport; + +import java.io.IOException; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicBoolean; + +import com.trilead.ssh2.IpVersion; + +/** + * Implements Happy Eyeballs (RFC 8305) connection algorithm. + * + * This algorithm improves connection times when both IPv4 and IPv6 + * addresses are available by: + *
    + *
  1. Resolving all addresses (A and AAAA records)
  2. + *
  3. Starting IPv6 connection attempts first
  4. + *
  5. After a short delay, starting IPv4 attempts in parallel
  6. + *
  7. Using whichever connection succeeds first
  8. + *
  9. Cancelling/closing remaining attempts
  10. + *
+ */ +class HappyEyeballsConnector { + + static final int CONNECTION_ATTEMPT_DELAY_MS = 250; + + private static final ExecutorService EXECUTOR = Executors.newCachedThreadPool(r -> { + Thread t = new Thread(r, "HappyEyeballs-Connector"); + t.setDaemon(true); + return t; + }); + + @FunctionalInterface + interface DnsResolver { + InetAddress[] resolve(String hostname) throws UnknownHostException; + } + + @FunctionalInterface + interface SocketFactory { + Socket createSocket(); + } + + private final DnsResolver dnsResolver; + private final SocketFactory socketFactory; + private final int connectionAttemptDelayMs; + + HappyEyeballsConnector() { + this(InetAddress::getAllByName, Socket::new, CONNECTION_ATTEMPT_DELAY_MS); + } + + HappyEyeballsConnector(DnsResolver dnsResolver, SocketFactory socketFactory, int connectionAttemptDelayMs) { + this.dnsResolver = dnsResolver; + this.socketFactory = socketFactory; + this.connectionAttemptDelayMs = connectionAttemptDelayMs; + } + + /** + * Connect to the given hostname and port using Happy Eyeballs algorithm. + * + * @param hostname the hostname to connect to + * @param port the port to connect to + * @param connectTimeout the connection timeout in milliseconds (0 for infinite) + * @param ipVersion controls which IP versions to use + * @return a connected socket + * @throws IOException if connection fails + */ + Socket connect(String hostname, int port, int connectTimeout, IpVersion ipVersion) + throws IOException { + + List addresses = resolveAddresses(hostname, ipVersion); + + if (addresses.isEmpty()) { + throw new UnknownHostException("No addresses found for: " + hostname); + } + + if (addresses.size() == 1) { + return connectSimple(addresses.get(0), port, connectTimeout); + } + + List sortedAddresses = interleaveByFamily(addresses); + return connectWithRacing(sortedAddresses, port, connectTimeout); + } + + private List resolveAddresses(String hostname, IpVersion ipVersion) + throws UnknownHostException { + InetAddress[] allAddresses = dnsResolver.resolve(hostname); + return filterByIpVersion(allAddresses, ipVersion); + } + + static List filterByIpVersion(InetAddress[] addresses, IpVersion ipVersion) { + List filtered = new ArrayList<>(); + + for (InetAddress addr : addresses) { + boolean isIPv6 = addr instanceof Inet6Address; + + if (ipVersion == IpVersion.IPV4_ONLY && isIPv6) { + continue; + } + if (ipVersion == IpVersion.IPV6_ONLY && !isIPv6) { + continue; + } + filtered.add(addr); + } + + return filtered; + } + + static List interleaveByFamily(List addresses) { + List ipv6 = new ArrayList<>(); + List ipv4 = new ArrayList<>(); + + for (InetAddress addr : addresses) { + if (addr instanceof Inet6Address) { + ipv6.add(addr); + } else { + ipv4.add(addr); + } + } + + List result = new ArrayList<>(); + int maxSize = Math.max(ipv6.size(), ipv4.size()); + + for (int i = 0; i < maxSize; i++) { + if (i < ipv6.size()) + result.add(ipv6.get(i)); + if (i < ipv4.size()) + result.add(ipv4.get(i)); + } + + return result; + } + + private Socket connectWithRacing(List addresses, int port, int connectTimeout) + throws IOException { + + AtomicBoolean winnerFound = new AtomicBoolean(false); + List> futures = new ArrayList<>(); + List socketsToClose = new ArrayList<>(); + + try { + for (int i = 0; i < addresses.size(); i++) { + InetAddress address = addresses.get(i); + int delay = i * connectionAttemptDelayMs; + + Callable task = createConnectionTask( + address, port, connectTimeout, delay, winnerFound, socketsToClose); + futures.add(EXECUTOR.submit(task)); + } + + return waitForFirstSuccess(futures); + + } finally { + for (Future future : futures) { + future.cancel(true); + } + + synchronized (socketsToClose) { + for (Socket socket : socketsToClose) { + closeQuietly(socket); + } + } + } + } + + private Callable createConnectionTask( + InetAddress address, + int port, + int connectTimeout, + int delay, + AtomicBoolean winnerFound, + List socketsToClose) { + + return () -> { + if (delay > 0) { + Thread.sleep(delay); + } + + if (winnerFound.get()) { + throw new CancellationException("Another connection won"); + } + + Socket socket = socketFactory.createSocket(); + synchronized (socketsToClose) { + socketsToClose.add(socket); + } + + try { + socket.connect(new InetSocketAddress(address, port), connectTimeout); + socket.setSoTimeout(0); + + if (winnerFound.compareAndSet(false, true)) { + synchronized (socketsToClose) { + socketsToClose.remove(socket); + } + return socket; + } else { + closeQuietly(socket); + throw new CancellationException("Another connection won"); + } + } catch (IOException e) { + closeQuietly(socket); + synchronized (socketsToClose) { + socketsToClose.remove(socket); + } + throw e; + } + }; + } + + private Socket waitForFirstSuccess(List> futures) throws IOException { + IOException lastException = null; + List> pending = new ArrayList<>(futures); + + while (!pending.isEmpty()) { + Future completed = null; + + for (Future future : pending) { + if (future.isDone()) { + completed = future; + break; + } + } + + if (completed == null) { + try { + Thread.sleep(10); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Connection interrupted", e); + } + continue; + } + + pending.remove(completed); + + try { + Socket socket = completed.get(); + if (socket != null && socket.isConnected()) { + return socket; + } + } catch (CancellationException e) { + // Task was cancelled, try next + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof IOException) { + lastException = (IOException) cause; + } else if (cause instanceof InterruptedException) { + Thread.currentThread().interrupt(); + throw new IOException("Connection interrupted", cause); + } else { + lastException = new IOException("Connection failed", cause); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new IOException("Connection interrupted", e); + } + } + + if (lastException != null) { + throw lastException; + } + throw new IOException("All connection attempts failed"); + } + + private Socket connectSimple(InetAddress address, int port, int timeout) throws IOException { + Socket socket = socketFactory.createSocket(); + try { + socket.connect(new InetSocketAddress(address, port), timeout); + socket.setSoTimeout(0); + return socket; + } catch (IOException e) { + closeQuietly(socket); + throw e; + } + } + + private static void closeQuietly(Socket socket) { + if (socket != null) { + try { + socket.close(); + } catch (IOException ignored) { + } + } + } +} diff --git a/src/main/java/com/trilead/ssh2/transport/TransportManager.java b/src/main/java/com/trilead/ssh2/transport/TransportManager.java index 5e90cfc4..f7376b58 100644 --- a/src/main/java/com/trilead/ssh2/transport/TransportManager.java +++ b/src/main/java/com/trilead/ssh2/transport/TransportManager.java @@ -4,8 +4,6 @@ import com.trilead.ssh2.ExtensionInfo; import com.trilead.ssh2.packets.PacketExtInfo; import java.io.IOException; -import java.net.InetAddress; -import java.net.InetSocketAddress; import java.net.Socket; import java.security.SecureRandom; import java.util.Vector; @@ -13,6 +11,7 @@ import com.trilead.ssh2.ConnectionInfo; import com.trilead.ssh2.ConnectionMonitor; import com.trilead.ssh2.DHGexParameters; +import com.trilead.ssh2.IpVersion; import com.trilead.ssh2.ProxyData; import com.trilead.ssh2.ServerHostKeyVerifier; import com.trilead.ssh2.compression.ICompressor; @@ -280,30 +279,27 @@ public void close(Throwable cause, boolean useDisconnectPacket) } } - private void establishConnection(ProxyData proxyData, int connectTimeout) throws IOException + + private void establishConnection(ProxyData proxyData, int connectTimeout, IpVersion ipVersion) throws IOException { if (proxyData == null) - sock = connectDirect(hostname, port, connectTimeout); + sock = connectDirect(hostname, port, connectTimeout, ipVersion); else sock = proxyData.openConnection(hostname, port, connectTimeout); } - private static Socket connectDirect(String hostname, int port, int connectTimeout) + private static Socket connectDirect(String hostname, int port, int connectTimeout, IpVersion ipVersion) throws IOException { - Socket sock = new Socket(); - InetAddress addr = InetAddress.getByName(hostname); - sock.connect(new InetSocketAddress(addr, port), connectTimeout); - sock.setSoTimeout(0); - return sock; + return new HappyEyeballsConnector().connect(hostname, port, connectTimeout, ipVersion); } public void initialize(CryptoWishList cwl, ServerHostKeyVerifier verifier, DHGexParameters dhgex, - int connectTimeout, SecureRandom rnd, ProxyData proxyData) throws IOException + int connectTimeout, IpVersion ipVersion, SecureRandom rnd, ProxyData proxyData) throws IOException { /* First, establish the TCP connection to the SSH-2 server */ - establishConnection(proxyData, connectTimeout); + establishConnection(proxyData, connectTimeout, ipVersion); /* Parse the server line and say hello - important: this information is later needed for the * key exchange (to stop man-in-the-middle attacks) - that is why we wrap it into an object diff --git a/src/test/java/com/trilead/ssh2/transport/HappyEyeballsConnectorTest.java b/src/test/java/com/trilead/ssh2/transport/HappyEyeballsConnectorTest.java new file mode 100644 index 00000000..e1807cd8 --- /dev/null +++ b/src/test/java/com/trilead/ssh2/transport/HappyEyeballsConnectorTest.java @@ -0,0 +1,272 @@ +package com.trilead.ssh2.transport; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; + +import org.junit.jupiter.api.Test; + +import com.trilead.ssh2.IpVersion; + +class HappyEyeballsConnectorTest { + + @Test + void filterByIpVersion_withIpv4Only_returnsOnlyIpv4() throws Exception { + InetAddress ipv4 = Inet4Address.getByName("127.0.0.1"); + InetAddress ipv6 = Inet6Address.getByName("::1"); + InetAddress[] addresses = { ipv4, ipv6 }; + + List result = HappyEyeballsConnector.filterByIpVersion(addresses, IpVersion.IPV4_ONLY); + + assertEquals(1, result.size()); + assertFalse(result.get(0) instanceof Inet6Address); + } + + @Test + void filterByIpVersion_withIpv6Only_returnsOnlyIpv6() throws Exception { + InetAddress ipv4 = Inet4Address.getByName("127.0.0.1"); + InetAddress ipv6 = Inet6Address.getByName("::1"); + InetAddress[] addresses = { ipv4, ipv6 }; + + List result = HappyEyeballsConnector.filterByIpVersion(addresses, IpVersion.IPV6_ONLY); + + assertEquals(1, result.size()); + assertTrue(result.get(0) instanceof Inet6Address); + } + + @Test + void filterByIpVersion_withBoth_returnsAll() throws Exception { + InetAddress ipv4 = Inet4Address.getByName("127.0.0.1"); + InetAddress ipv6 = Inet6Address.getByName("::1"); + InetAddress[] addresses = { ipv4, ipv6 }; + + List result = HappyEyeballsConnector.filterByIpVersion(addresses, IpVersion.IPV4_AND_IPV6); + + assertEquals(2, result.size()); + } + + @Test + void filterByIpVersion_withIpv4Only_andNoIpv4Addresses_returnsEmpty() throws Exception { + InetAddress ipv6 = Inet6Address.getByName("::1"); + InetAddress[] addresses = { ipv6 }; + + List result = HappyEyeballsConnector.filterByIpVersion(addresses, IpVersion.IPV4_ONLY); + + assertTrue(result.isEmpty()); + } + + @Test + void interleaveByFamily_withMixedAddresses_interleavesCorrectly() throws Exception { + InetAddress ipv4a = Inet4Address.getByName("127.0.0.1"); + InetAddress ipv4b = Inet4Address.getByName("127.0.0.2"); + InetAddress ipv6a = Inet6Address.getByName("::1"); + InetAddress ipv6b = Inet6Address.getByName("::2"); + + List input = Arrays.asList(ipv4a, ipv4b, ipv6a, ipv6b); + List result = HappyEyeballsConnector.interleaveByFamily(input); + + assertEquals(4, result.size()); + // Should be: ipv6a, ipv4a, ipv6b, ipv4b (IPv6 first per RFC 8305) + assertTrue(result.get(0) instanceof Inet6Address); + assertFalse(result.get(1) instanceof Inet6Address); + assertTrue(result.get(2) instanceof Inet6Address); + assertFalse(result.get(3) instanceof Inet6Address); + } + + @Test + void interleaveByFamily_withOnlyIpv4_returnsAllIpv4() throws Exception { + InetAddress ipv4a = Inet4Address.getByName("127.0.0.1"); + InetAddress ipv4b = Inet4Address.getByName("127.0.0.2"); + + List input = Arrays.asList(ipv4a, ipv4b); + List result = HappyEyeballsConnector.interleaveByFamily(input); + + assertEquals(2, result.size()); + assertFalse(result.get(0) instanceof Inet6Address); + assertFalse(result.get(1) instanceof Inet6Address); + } + + @Test + void interleaveByFamily_withUnequalCounts_handlesCorrectly() throws Exception { + InetAddress ipv4 = Inet4Address.getByName("127.0.0.1"); + InetAddress ipv6a = Inet6Address.getByName("::1"); + InetAddress ipv6b = Inet6Address.getByName("::2"); + InetAddress ipv6c = Inet6Address.getByName("::3"); + + List input = Arrays.asList(ipv4, ipv6a, ipv6b, ipv6c); + List result = HappyEyeballsConnector.interleaveByFamily(input); + + assertEquals(4, result.size()); + // Should be: ipv6a, ipv4, ipv6b, ipv6c + assertTrue(result.get(0) instanceof Inet6Address); + assertFalse(result.get(1) instanceof Inet6Address); + assertTrue(result.get(2) instanceof Inet6Address); + assertTrue(result.get(3) instanceof Inet6Address); + } + + @Test + void connect_withSingleAddress_connectsDirectly() throws Exception { + try (ServerSocket server = new ServerSocket(0)) { + int port = server.getLocalPort(); + InetAddress addr = InetAddress.getByName("127.0.0.1"); + + HappyEyeballsConnector connector = new HappyEyeballsConnector( + hostname -> new InetAddress[] { addr }, + Socket::new, + 250); + + Thread acceptThread = new Thread(() -> { + try { + server.accept().close(); + } catch (IOException ignored) { + } + }); + acceptThread.start(); + + Socket socket = connector.connect("test.example.com", port, 5000, IpVersion.IPV4_AND_IPV6); + + assertTrue(socket.isConnected()); + socket.close(); + acceptThread.join(1000); + } + } + + @Test + void connect_withNoAddresses_throwsUnknownHostException() { + HappyEyeballsConnector connector = new HappyEyeballsConnector( + hostname -> new InetAddress[] {}, + Socket::new, + 250); + + assertThrows(UnknownHostException.class, + () -> connector.connect("test.example.com", 22, 5000, IpVersion.IPV4_AND_IPV6)); + } + + @Test + void connect_withDnsFailure_throwsUnknownHostException() { + HappyEyeballsConnector connector = new HappyEyeballsConnector( + hostname -> { + throw new UnknownHostException("DNS failed"); + }, + Socket::new, + 250); + + assertThrows(UnknownHostException.class, + () -> connector.connect("test.example.com", 22, 5000, IpVersion.IPV4_AND_IPV6)); + } + + @Test + void connect_withIpv4Only_filtersToIpv4() throws Exception { + try (ServerSocket server = new ServerSocket(0)) { + int port = server.getLocalPort(); + InetAddress ipv4 = InetAddress.getByName("127.0.0.1"); + InetAddress ipv6 = Inet6Address.getByName("::1"); + + AtomicInteger socketCount = new AtomicInteger(0); + + HappyEyeballsConnector connector = new HappyEyeballsConnector( + hostname -> new InetAddress[] { ipv6, ipv4 }, + () -> { + socketCount.incrementAndGet(); + return new Socket(); + }, + 250); + + Thread acceptThread = new Thread(() -> { + try { + server.accept().close(); + } catch (IOException ignored) { + } + }); + acceptThread.start(); + + Socket socket = connector.connect("test.example.com", port, 5000, IpVersion.IPV4_ONLY); + + assertTrue(socket.isConnected()); + assertEquals(1, socketCount.get(), "Should only create one socket for single filtered address"); + socket.close(); + acceptThread.join(1000); + } + } + + @Test + void connect_withMultipleAddresses_racesConnections() throws Exception { + try (ServerSocket server = new ServerSocket(0)) { + int port = server.getLocalPort(); + InetAddress addr1 = InetAddress.getByName("127.0.0.1"); + InetAddress addr2 = InetAddress.getByName("127.0.0.1"); + + List createdSockets = new ArrayList<>(); + + HappyEyeballsConnector connector = new HappyEyeballsConnector( + hostname -> new InetAddress[] { addr1, addr2 }, + () -> { + Socket s = new Socket(); + synchronized (createdSockets) { + createdSockets.add(s); + } + return s; + }, + 50 // Short delay for faster test + ); + + Thread acceptThread = new Thread(() -> { + try { + server.accept().close(); + } catch (IOException ignored) { + } + }); + acceptThread.start(); + + Socket socket = connector.connect("test.example.com", port, 5000, IpVersion.IPV4_AND_IPV6); + + assertTrue(socket.isConnected()); + socket.close(); + acceptThread.join(1000); + + // Give time for cleanup + Thread.sleep(100); + + // All non-winning sockets should be closed + synchronized (createdSockets) { + for (Socket s : createdSockets) { + if (s != socket) { + assertTrue(s.isClosed(), "Losing sockets should be closed"); + } + } + } + } + } + + @Test + void connect_withAllConnectionsFailing_throwsIOException() throws Exception { + InetAddress addr1 = InetAddress.getByName("127.0.0.1"); + InetAddress addr2 = InetAddress.getByName("127.0.0.1"); + + HappyEyeballsConnector connector = new HappyEyeballsConnector( + hostname -> new InetAddress[] { addr1, addr2 }, + Socket::new, + 10); + + // Connect to a port that's not listening + assertThrows(IOException.class, () -> connector.connect("test.example.com", 1, 100, IpVersion.IPV4_AND_IPV6)); + } + + @Test + void connectionAttemptDelay_isConfigurable() { + assertEquals(250, HappyEyeballsConnector.CONNECTION_ATTEMPT_DELAY_MS); + } +}