Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 183 additions & 40 deletions java/src/main/java/org/wildfly/openssl/OpenSSLContextSPI.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@

import static org.wildfly.openssl.OpenSSLEngine.isTLS13Supported;

import javax.net.ssl.KeyManager;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIMatcher;
import javax.net.ssl.SSLContextSpi;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSessionContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509KeyManager;
import javax.net.ssl.X509TrustManager;

import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
Expand All @@ -34,23 +48,20 @@
import java.security.SecureRandom;
import java.security.cert.CertificateException;
import java.security.cert.CertificateFactory;
import java.security.cert.CertificateParsingException;
import java.security.cert.X509Certificate;
import java.util.Base64;
import java.util.Collection;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLContextSpi;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLServerSocketFactory;
import javax.net.ssl.SSLSessionContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509KeyManager;
import javax.net.ssl.X509TrustManager;
import java.util.regex.Pattern;

public abstract class OpenSSLContextSPI extends SSLContextSpi {

Expand Down Expand Up @@ -135,10 +146,16 @@ public static String[] getAvailableCipherSuites() {
OpenSSLContextSPI(final int value) throws SSLException {
this.supportedCiphers = value;
SSL.init();
ctx = makeSSLContext();
}

private long makeSSLContext() throws RuntimeException {
final long sslCtx;

try {
// Create SSL Context
try {
ctx = SSL.getInstance().makeSSLContext(value, SSL.SSL_MODE_COMBINED);
sslCtx = SSL.getInstance().makeSSLContext(this.supportedCiphers, SSL.SSL_MODE_COMBINED);
} catch (Exception e) {
// If the sslEngine is disabled on the AprLifecycleListener
// there will be an Exception here but there is no way to check
Expand All @@ -147,19 +164,19 @@ public static String[] getAvailableCipherSuites() {
}
try {
//disable unsafe renegotiation
SSL.getInstance().clearSSLContextOptions(ctx, SSL.SSL_OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION);
SSL.getInstance().clearSSLContextOptions(sslCtx, SSL.SSL_OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION);
} catch (UnsatisfiedLinkError e) {
// Ignore
}
// Disable compression
SSL.getInstance().setSSLContextOptions(ctx, SSL.SSL_OP_NO_COMPRESSION);
SSL.getInstance().setSSLContextOptions(sslCtx, SSL.SSL_OP_NO_COMPRESSION);

// Disable TLS Session Tickets (RFC4507) to protect perfect forward secrecy
SSL.getInstance().setSSLContextOptions(ctx, SSL.SSL_OP_NO_TICKET);
SSL.getInstance().setSSLContextOptions(sslCtx, SSL.SSL_OP_NO_TICKET);
} catch (Exception e) {
throw new RuntimeException(Messages.MESSAGES.failedToInitializeSslContext(), e);
}

return sslCtx;
}

/**
Expand All @@ -175,17 +192,31 @@ private synchronized void init(KeyManager[] kms, TrustManager[] tms) throws KeyM
return;
}

// a single subject can have multiple certificates for different algorithms, as
// aliases are required to be unique, the subject is the next best thing to establish
// some form of grouping, as a single context can have multiple certificates
// for different algorithms
final Map<String, Long> subjectToSSLContextMap = new LinkedHashMap<>();

// this simple map is used later on during certificate selection in the SNICallback,
// as a single ssl ctx can have multiple certificate, and SNI uses a requested
// hostname to allow the server to choose the certificate, we flatten everything
final Map<SNIMatcher, Long> x509CertificateToSSLContextMap = new LinkedHashMap<>();

try {
// Load Server key and certificate
X509KeyManager keyManager = chooseKeyManager(kms);
if (keyManager != null) {
for (String algorithm : ALGORITHMS) {

int counter = 0;

boolean rsa = algorithm.equals("RSA");
final String[] aliases = keyManager.getServerAliases(algorithm, null);
if (aliases != null && aliases.length != 0) {
for(String alias: aliases) {

counter++;
X509Certificate[] certificateChain = keyManager.getCertificateChain(alias);
PrivateKey key = keyManager.getPrivateKey(alias);
if(key == null || certificateChain == null || key.getEncoded() == null) {
Expand All @@ -207,12 +238,76 @@ private synchronized void init(KeyManager[] kms, TrustManager[] tms) throws KeyM
encodedIntermediaries[i - 1] = certificateChain[i].getEncoded();
}
X509Certificate certificate = certificateChain[0];
SSL.getInstance().setCertificate(ctx, certificate.getEncoded(), encodedIntermediaries, sb.toString().getBytes(StandardCharsets.US_ASCII), rsa ? SSL.SSL_AIDX_RSA : SSL.SSL_AIDX_DSA);
break;

// for a single subject multiple certificates with different algorithms can exist, if
// we already have a context for a specific subject, use it, otherwise generate a new context
// to be used with SNI
Long sslCtx = subjectToSSLContextMap.get(certificate.getSubjectX500Principal().getName());

// if no existing context could be found, and this is the first round, establish the
// "default" context
if (sslCtx == null) {
if (counter == 1) {
sslCtx = ctx;
} else {
sslCtx = makeSSLContext();
}

subjectToSSLContextMap.put(certificate.getSubjectX500Principal().getName(), sslCtx);
}

// set the certifcates to use for this context
SSL.getInstance().setCertificate(sslCtx, certificate.getEncoded(), encodedIntermediaries, sb.toString().getBytes(StandardCharsets.US_ASCII), rsa ? SSL.SSL_AIDX_RSA : SSL.SSL_AIDX_DSA);
x509CertificateToSSLContextMap.put(getHostnamesSNIMatcher(certificate), sslCtx);
}
}
}
}

if (x509CertificateToSSLContextMap.size() > 1) {
SSL.registerDefault(ctx, new SSL.SNICallBack() {

@Override
public long getSslContext(String sniHostName) {
if (sniHostName == null || sniHostName.isEmpty()) {
return ctx;
}

final String lowerSniHostname = sniHostName.toLowerCase(Locale.ENGLISH);
final SNIHostName lowerSniHostnameForMatcher = new SNIHostName(lowerSniHostname.getBytes(StandardCharsets.UTF_8));

String sniHostnameAsWildcard = null;
SNIHostName sniHostnameAsWildcardForMatcher = null;

final int idx = lowerSniHostname.indexOf('.');

if (idx > 0) {
sniHostnameAsWildcard = "*" + lowerSniHostname.substring(idx);
sniHostnameAsWildcardForMatcher = new SNIHostName(sniHostnameAsWildcard.getBytes(StandardCharsets.UTF_8));
}

long wildcardSSLContext = 0L;

for (final SNIMatcher hostnameMatcher: x509CertificateToSSLContextMap.keySet()) {
// find a ssl ctx by hostname, return if its a perfect match
if (hostnameMatcher.matches(lowerSniHostnameForMatcher)) {
return x509CertificateToSSLContextMap.get(hostnameMatcher);
}

// check if context might be good with as a wildcard cert, but
// there might be another ctx with a better match, so don't
// return it yet, let's wait until we checked all ctx avail
if (sniHostnameAsWildcard != null && hostnameMatcher.matches(sniHostnameAsWildcardForMatcher)) {
wildcardSSLContext = x509CertificateToSSLContextMap.get(hostnameMatcher);
}
}

// if we have a ssl ctx with a matching wildcard cert, prefer it
return wildcardSSLContext != 0L ? wildcardSSLContext : ctx;
}
});
}

/*
// Support Client Certificates
SSL.getInstance().setCACertificate(ctx,
Expand All @@ -227,29 +322,15 @@ private synchronized void init(KeyManager[] kms, TrustManager[] tms) throws KeyM
*/
// Client certificate verification

SSL.getInstance().setSessionCacheSize(ctx, DEFAULT_SESSION_CACHE_SIZE);
final X509TrustManager manager = chooseTrustManager(tms);
if(manager != null) {
SSL.getInstance().setCertVerifyCallback(ctx, (ssl, chain, cipherNo, server) -> {
X509Certificate[] peerCerts = certificates(chain);
Cipher cipher = Cipher.valueOf(cipherNo);
String auth = cipher == null ? "RSA" : cipher.getAu().toString();
try {
if(server) {
manager.checkClientTrusted(peerCerts, auth);
} else {
manager.checkServerTrusted(peerCerts, auth);
}
return true;
} catch (Exception e) {
if (LOG.isLoggable(Level.FINE)) {
LOG.log(Level.FINE, "Certificate verification failed", e);
}
}
return false;
});
}
final Set<Long> sslContexts = new HashSet<>(x509CertificateToSSLContextMap.values());

if (sslContexts.isEmpty()) {
configureSSLContext(tms, ctx);
} else {
for (long sslCtx : sslContexts) {
configureSSLContext(tms, sslCtx);
}
}

serverSessionContext = new OpenSSLServerSessionContext(ctx);
serverSessionContext.setSessionIdContext("test".getBytes(StandardCharsets.US_ASCII));
Expand All @@ -264,6 +345,31 @@ private synchronized void init(KeyManager[] kms, TrustManager[] tms) throws KeyM
}
}

private void configureSSLContext(final TrustManager[] tms, final long sslCtx) {
SSL.getInstance().setSessionCacheSize(sslCtx, DEFAULT_SESSION_CACHE_SIZE);
final X509TrustManager manager = chooseTrustManager(tms);
if(manager != null) {
SSL.getInstance().setCertVerifyCallback(sslCtx, (ssl, chain, cipherNo, server) -> {
X509Certificate[] peerCerts = certificates(chain);
Cipher cipher = Cipher.valueOf(cipherNo);
String auth = cipher == null ? "RSA" : cipher.getAu().toString();
try {
if(server) {
manager.checkClientTrusted(peerCerts, auth);
} else {
manager.checkServerTrusted(peerCerts, auth);
}
return true;
} catch (Exception e) {
if (LOG.isLoggable(Level.FINE)) {
LOG.log(Level.FINE, "Certificate verification failed", e);
}
}
return false;
});
}
}

private X509KeyManager chooseKeyManager(KeyManager[] tms) {
if(tms == null) {
return null;
Expand Down Expand Up @@ -442,6 +548,43 @@ public void sessionRemoved(byte[] session) {
serverSessionContext.remove(session);
}

private SNIMatcher getHostnamesSNIMatcher(final X509Certificate cert) {
if (cert == null) {
return SNIHostName.createSNIMatcher("");
}

final StringBuilder builder = new StringBuilder();

// extract all the valid "hostnames" from the SANs
try {
final Collection<List<?>> sansList = cert.getSubjectAlternativeNames();

if (sansList != null && !sansList.isEmpty()) {
for (final List<?> san : sansList) {
if ((Integer) san.get(0) == 2) { // DNS
final Object sanData = san.get(1);

if (sanData instanceof String) {
builder.append(Pattern.quote((String) sanData));
builder.append("|");
}
}
}
}
} catch (final CertificateParsingException ex) {
final String msg = String.format("Unable to parse SANS of own certificate [%s].", cert.getSubjectX500Principal().getName());
LOG.log(Level.WARNING, msg, ex);
}

final int len = builder.length();

if (len > 0 && builder.charAt(len - 1) == '|') {
builder.deleteCharAt(len - 1);
}

return SNIHostName.createSNIMatcher(builder.toString());
}

public static final class OpenSSLTLSContextSpi extends OpenSSLContextSPI {

public OpenSSLTLSContextSpi() throws SSLException {
Expand Down
65 changes: 32 additions & 33 deletions java/src/main/java/org/wildfly/openssl/OpenSSLEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -1077,44 +1077,43 @@ private void beginHandshakeImplicitly() throws SSLException {
}
}

protected String alpnCallback(final String[] data) {
String version = SSL.getInstance().getVersion(ssl);
if((protocolSelector == null && applicationProtocols == null) || version == null || ! (version.equals("TLSv1.2") || version.equals("TLSv1.3"))) {
//only offer ALPN on TLS 1.2+, try and force http/1.1 if it is offered, otherwise fail the connection
//it seems wrong to hard code protocols in the SSL impl, but openssl does not really allow alpn to be enabled
//on a per engine basis
for(String i : data) {
if(i.equals("http/1.1")) {
return i;
}
}
selectedApplicationProtocol = "";
return null;
}
if (protocolSelector != null) {
selectedApplicationProtocol = protocolSelector.apply(OpenSSLEngine.this, Arrays.asList(data));
return selectedApplicationProtocol;
}

for (String proto : applicationProtocols) {
for (String clientProto : data) {
if (clientProto.equals(proto)) {
selectedApplicationProtocol = proto;
return proto;
}
}
}
selectedApplicationProtocol = "";
return null;
}

private void handshake() throws SSLException {
initSsl();
if (!alpnRegistered) {
alpnRegistered = true;
if (!isClientMode()) {
SSL.getInstance().setServerALPNCallback(ssl, new ServerALPNCallback() {
@Override
public String select(String[] data) {
String version = SSL.getInstance().getVersion(ssl);
if((protocolSelector == null && applicationProtocols == null) || version == null || ! (version.equals("TLSv1.2") || version.equals("TLSv1.3"))) {
//only offer ALPN on TLS 1.2+, try and force http/1.1 if it is offered, otherwise fail the connection
//it seems wrong to hard code protocols in the SSL impl, but openssl does not really allow alpn to be enabled
//on a per engine basis
for(String i : data) {
if(i.equals("http/1.1")) {
return i;
}
}
selectedApplicationProtocol = "";
return null;
}
if (protocolSelector != null) {
selectedApplicationProtocol = protocolSelector.apply(OpenSSLEngine.this, Arrays.asList(data));
return selectedApplicationProtocol;
}

for (String proto : applicationProtocols) {
for (String clientProto : data) {
if (clientProto.equals(proto)) {
selectedApplicationProtocol = proto;
return proto;
}
}
}
selectedApplicationProtocol = "";
return null;
}
});
SSL.getInstance().setServerALPNCallback(ssl, new OpenSSLEngineServerALPNCallback(this));
} else if(applicationProtocols != null){
SSL.getInstance().setAlpnProtos(ssl, applicationProtocols);
}
Expand Down
Loading