diff --git a/buildpack/spring-boot-buildpack-platform/src/main/java/org/springframework/boot/buildpack/platform/io/InspectedContent.java b/buildpack/spring-boot-buildpack-platform/src/main/java/org/springframework/boot/buildpack/platform/io/InspectedContent.java index 0d135a6aae14..56fd60d247ab 100644 --- a/buildpack/spring-boot-buildpack-platform/src/main/java/org/springframework/boot/buildpack/platform/io/InspectedContent.java +++ b/buildpack/spring-boot-buildpack-platform/src/main/java/org/springframework/boot/buildpack/platform/io/InspectedContent.java @@ -148,6 +148,11 @@ private InspectingOutputStream(Inspector[] inspectors) { this.delegate = new ByteArrayOutputStream(); } + @Override + public void close() throws IOException { + this.delegate.close(); + } + @Override public void write(int b) throws IOException { this.singleByteBuffer[0] = (byte) (b & 0xFF); diff --git a/module/spring-boot-rsocket/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java b/module/spring-boot-rsocket/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java index 4b55599397ad..5c300fa7819d 100644 --- a/module/spring-boot-rsocket/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java +++ b/module/spring-boot-rsocket/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java @@ -269,13 +269,13 @@ private TcpServer apply(TcpServer server) { } - private static final class HttpServerSslCustomizer extends SslCustomizer { + static final class HttpServerSslCustomizer extends SslCustomizer { private final SslProvider sslProvider; private final Map serverNameSslProviders; - private HttpServerSslCustomizer(Ssl.@Nullable ClientAuth clientAuth, SslBundle sslBundle, + HttpServerSslCustomizer(Ssl.@Nullable ClientAuth clientAuth, SslBundle sslBundle, Map serverNameSslBundles) { super(Ssl.ClientAuth.map(clientAuth, ClientAuth.NONE, ClientAuth.OPTIONAL, ClientAuth.REQUIRE)); this.sslProvider = createSslProvider(sslBundle); @@ -287,11 +287,13 @@ private HttpServer apply(HttpServer server) { } private void applySecurity(SslContextSpec spec) { - spec.sslContext(this.sslProvider.getSslContext()).setSniAsyncMappings((serverName, promise) -> { - SslProvider provider = (serverName != null) ? this.serverNameSslProviders.get(serverName) - : this.sslProvider; - return promise.setSuccess(provider); - }); + spec.sslContext(this.sslProvider.getSslContext()) + .setSniAsyncMappings((serverName, promise) -> promise.setSuccess(getSslProvider(serverName))); + } + + SslProvider getSslProvider(@Nullable String serverName) { + return (serverName != null) ? this.serverNameSslProviders.getOrDefault(serverName, this.sslProvider) + : this.sslProvider; } private Map createServerNameSslProviders(Map serverNameSslBundles) { diff --git a/module/spring-boot-rsocket/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java b/module/spring-boot-rsocket/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java index 1a9232e2c6bb..7f102a592418 100644 --- a/module/spring-boot-rsocket/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java +++ b/module/spring-boot-rsocket/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java @@ -20,6 +20,8 @@ import java.nio.channels.ClosedChannelException; import java.time.Duration; import java.util.Arrays; +import java.util.Collections; +import java.util.Map; import java.util.concurrent.Callable; import io.netty.buffer.PooledByteBufAllocator; @@ -267,6 +269,26 @@ void websocketTransportBasicSslCertificateFromFileSystemWithBundle(@ResourcePath testBasicSslWithPemCertificateFromBundle(testCert, testKey, testCert, Transport.WEBSOCKET); } + @Test + @WithPackageResources({ "test-cert.pem", "test-key.pem" }) + void websocketTransportSslProviderFallsBackToDefaultWhenServerNameIsUnmapped() { + SslBundle defaultBundle = createBundle("test-cert.pem", "test-key.pem"); + SslBundle mappedBundle = createBundle("test-cert.pem", "test-key.pem"); + NettyRSocketServerFactory.HttpServerSslCustomizer customizer = new NettyRSocketServerFactory.HttpServerSslCustomizer( + Ssl.ClientAuth.NONE, defaultBundle, Map.of("mapped.example", mappedBundle)); + assertThat(customizer.getSslProvider("unmapped.example")).isSameAs(customizer.getSslProvider(null)); + } + + @Test + @WithPackageResources({ "test-cert.pem", "test-key.pem" }) + @SuppressWarnings("NullAway") // Test null check + void websocketTransportSslProviderReturnsDefaultWhenServerNameIsNull() { + SslBundle defaultBundle = createBundle("test-cert.pem", "test-key.pem"); + NettyRSocketServerFactory.HttpServerSslCustomizer customizer = new NettyRSocketServerFactory.HttpServerSslCustomizer( + Ssl.ClientAuth.NONE, defaultBundle, Collections.emptyMap()); + assertThat(customizer.getSslProvider(null)).isNotNull(); + } + private void checkEchoRequest() { String payload = "test payload"; assertThat(this.requester).isNotNull(); @@ -338,6 +360,12 @@ private void testBasicSslWithPemCertificateFromBundle(String certificate, String checkEchoRequest(); } + private static SslBundle createBundle(String certificate, String certificatePrivateKey) { + PemSslStoreDetails keyStoreDetails = PemSslStoreDetails.forCertificate("classpath:" + certificate) + .withPrivateKey("classpath:" + certificatePrivateKey); + return SslBundle.of(new PemSslStoreBundle(keyStoreDetails, null)); + } + @Test void tcpTransportSslRejectsInsecureClient() { NettyRSocketServerFactory factory = getFactory();