From 8e4d787565b9bc3929c9e24c82accd6c8b7233e7 Mon Sep 17 00:00:00 2001 From: Will Sargent Date: Sun, 6 Apr 2014 17:52:28 -0700 Subject: [PATCH] Move the hostname verification to after the SSL handshake has completed. --- .../asynchttpclient/async/BasicHttpsTest.java | 9 +- .../asynchttpclient/async/util/TestUtils.java | 96 ++++++++++++------- .../providers/grizzly/ConnectionManager.java | 52 +++++----- .../grizzly/GrizzlyAsyncHttpProvider.java | 2 +- .../netty/request/NettyConnectListener.java | 40 ++++++-- 5 files changed, 130 insertions(+), 69 deletions(-) diff --git a/api/src/test/java/org/asynchttpclient/async/BasicHttpsTest.java b/api/src/test/java/org/asynchttpclient/async/BasicHttpsTest.java index 7bc7233783..24675057e7 100644 --- a/api/src/test/java/org/asynchttpclient/async/BasicHttpsTest.java +++ b/api/src/test/java/org/asynchttpclient/async/BasicHttpsTest.java @@ -30,6 +30,7 @@ import javax.net.ssl.SSLHandshakeException; import javax.servlet.http.HttpServletResponse; +import java.io.IOException; import java.net.ConnectException; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -100,17 +101,19 @@ public void reconnectsAfterFailedCertificationPath() throws Exception { String body = "hello there"; // first request fails because server certificate is rejected + Throwable cause = null; try { c.preparePost(getTargetUrl()).setBody(body).setHeader("Content-Type", "text/html").execute().get(TIMEOUT, TimeUnit.SECONDS); } catch (final ExecutionException e) { - Throwable cause = e.getCause(); + cause = e.getCause(); if (cause instanceof ConnectException) { - assertNotNull(cause.getCause()); + //assertNotNull(cause.getCause()); assertTrue(cause.getCause() instanceof SSLHandshakeException, "Expected an SSLHandshakeException, got a " + cause.getCause()); } else { - assertTrue(cause instanceof SSLHandshakeException, "Expected an SSLHandshakeException, got a " + cause); + assertTrue(cause instanceof IOException, "Expected an IOException, got a " + cause); } } + assertNotNull(cause); trusted.set(true); diff --git a/api/src/test/java/org/asynchttpclient/async/util/TestUtils.java b/api/src/test/java/org/asynchttpclient/async/util/TestUtils.java index 75aab93fe0..7c9b46c50b 100644 --- a/api/src/test/java/org/asynchttpclient/async/util/TestUtils.java +++ b/api/src/test/java/org/asynchttpclient/async/util/TestUtils.java @@ -22,11 +22,7 @@ import org.eclipse.jetty.util.security.Constraint; import org.eclipse.jetty.util.ssl.SslContextFactory; -import javax.net.ssl.KeyManager; -import javax.net.ssl.KeyManagerFactory; -import javax.net.ssl.SSLContext; -import javax.net.ssl.TrustManager; -import javax.net.ssl.X509TrustManager; +import javax.net.ssl.*; import java.io.File; import java.io.FileNotFoundException; @@ -38,8 +34,7 @@ import java.net.URISyntaxException; import java.net.URL; import java.nio.charset.Charset; -import java.security.KeyStore; -import java.security.SecureRandom; +import java.security.*; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.ArrayList; @@ -149,7 +144,6 @@ public static void addHttpsConnector(Server server, int port) throws URISyntaxEx ServerConnector connector = new ServerConnector(server, new SslConnectionFactory(sslContextFactory, "http/1.1"), new HttpConnectionFactory(httpsConfig)); connector.setPort(port); - server.addConnector(connector); server.addConnector(connector); } @@ -191,21 +185,38 @@ private static void addAuthHandler(Server server, String auth, LoginAuthenticato server.setHandler(security); } + private static KeyManager[] createKeyManagers() throws GeneralSecurityException, IOException { + InputStream keyStoreStream = Thread.currentThread().getContextClassLoader().getResourceAsStream("ssltest-cacerts.jks"); + char[] keyStorePassword = "changeit".toCharArray(); + KeyStore ks = KeyStore.getInstance("JKS"); + ks.load(keyStoreStream, keyStorePassword); + assert(ks.size() > 0); + + // Set up key manager factory to use our key store + char[] certificatePassword = "changeit".toCharArray(); + KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); + kmf.init(ks, certificatePassword); + + // Initialize the SSLContext to work with our key managers. + return kmf.getKeyManagers(); + } + + private static TrustManager[] createTrustManagers() throws GeneralSecurityException, IOException { + InputStream keyStoreStream = Thread.currentThread().getContextClassLoader().getResourceAsStream("ssltest-keystore.jks"); + char[] keyStorePassword = "changeit".toCharArray(); + KeyStore ks = KeyStore.getInstance("JKS"); + ks.load(keyStoreStream, keyStorePassword); + assert(ks.size() > 0); + + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(ks); + return tmf.getTrustManagers(); + } + public static SSLContext createSSLContext(AtomicBoolean trust) { try { - InputStream keyStoreStream = HostnameVerifierTest.class.getResourceAsStream("ssltest-cacerts.jks"); - char[] keyStorePassword = "changeit".toCharArray(); - KeyStore ks = KeyStore.getInstance("JKS"); - ks.load(keyStoreStream, keyStorePassword); - - // Set up key manager factory to use our key store - char[] certificatePassword = "changeit".toCharArray(); - KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); - kmf.init(ks, certificatePassword); - - // Initialize the SSLContext to work with our key managers. - KeyManager[] keyManagers = kmf.getKeyManagers(); - TrustManager[] trustManagers = new TrustManager[] { dummyTrustManager(trust) }; + KeyManager[] keyManagers = createKeyManagers(); + TrustManager[] trustManagers = new TrustManager[] { dummyTrustManager(trust, (X509TrustManager) createTrustManagers()[0]) }; SecureRandom secureRandom = new SecureRandom(); SSLContext sslContext = SSLContext.getInstance("TLS"); @@ -217,21 +228,40 @@ public static SSLContext createSSLContext(AtomicBoolean trust) { } } - private static final TrustManager dummyTrustManager(final AtomicBoolean trust) { - return new X509TrustManager() { - public X509Certificate[] getAcceptedIssuers() { - return new X509Certificate[0]; - } + public static class DummyTrustManager implements X509TrustManager { - public void checkClientTrusted(X509Certificate[] chain, String authType) throws CertificateException { - } + private final X509TrustManager tm; + private final AtomicBoolean trust; - public void checkServerTrusted(X509Certificate[] chain, String authType) throws CertificateException { - if (!trust.get()) { - throw new CertificateException("Server certificate not trusted."); - } + public DummyTrustManager(final AtomicBoolean trust, final X509TrustManager tm) { + this.trust = trust; + this.tm = tm; + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + tm.checkClientTrusted(chain, authType); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + if (!trust.get()) { + throw new CertificateException("Server certificate not trusted."); } - }; + tm.checkServerTrusted(chain, authType); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return tm.getAcceptedIssuers(); + } + } + + private static TrustManager dummyTrustManager(final AtomicBoolean trust, final X509TrustManager tm) { + return new DummyTrustManager(trust, tm); + } public static File getClasspathFile(String file) throws FileNotFoundException { diff --git a/providers/grizzly/src/main/java/org/asynchttpclient/providers/grizzly/ConnectionManager.java b/providers/grizzly/src/main/java/org/asynchttpclient/providers/grizzly/ConnectionManager.java index 41984d449f..e90f418664 100644 --- a/providers/grizzly/src/main/java/org/asynchttpclient/providers/grizzly/ConnectionManager.java +++ b/providers/grizzly/src/main/java/org/asynchttpclient/providers/grizzly/ConnectionManager.java @@ -20,18 +20,22 @@ import org.asynchttpclient.ConnectionPoolKeyStrategy; import org.asynchttpclient.ProxyServer; import org.asynchttpclient.Request; +import org.asynchttpclient.util.Base64; import org.glassfish.grizzly.CompletionHandler; import org.glassfish.grizzly.Connection; -import org.glassfish.grizzly.EmptyCompletionHandler; import org.glassfish.grizzly.Grizzly; import org.glassfish.grizzly.GrizzlyFuture; import org.glassfish.grizzly.attributes.Attribute; import org.glassfish.grizzly.connectionpool.EndpointKey; import org.glassfish.grizzly.filterchain.FilterChainBuilder; import org.glassfish.grizzly.impl.FutureImpl; +import org.glassfish.grizzly.ssl.SSLBaseFilter; +import org.glassfish.grizzly.ssl.SSLFilter; import org.glassfish.grizzly.ssl.SSLUtils; import org.glassfish.grizzly.utils.Futures; import org.glassfish.grizzly.utils.IdleTimeoutFilter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLSession; @@ -51,6 +55,8 @@ public class ConnectionManager { + private final static Logger LOGGER = LoggerFactory.getLogger(ConnectionManager.class); + private static final Attribute DO_NOT_CACHE = Grizzly.DEFAULT_ATTRIBUTE_BUILDER.createAttribute(ConnectionManager.class .getName()); private final ConnectionPool connectionPool; @@ -60,13 +66,15 @@ public class ConnectionManager { private final FilterChainBuilder secureBuilder; private final FilterChainBuilder nonSecureBuilder; private final boolean asyncConnect; + private final SSLFilter sslFilter; // ------------------------------------------------------------ Constructors ConnectionManager(final GrizzlyAsyncHttpProvider provider,// final ConnectionPool connectionPool,// final FilterChainBuilder secureBuilder,// - final FilterChainBuilder nonSecureBuilder) { + final FilterChainBuilder nonSecureBuilder,// + final SSLFilter sslFilter) { this.provider = provider; final AsyncHttpClientConfig config = provider.getClientConfig(); @@ -87,6 +95,7 @@ public class ConnectionManager { AsyncHttpProviderConfig providerConfig = config.getAsyncHttpProviderConfig(); asyncConnect = providerConfig instanceof GrizzlyAsyncHttpProviderConfig ? GrizzlyAsyncHttpProviderConfig.class.cast(providerConfig) .isAsyncConnectMode() : false; + this.sslFilter = sslFilter; } // ---------------------------------------------------------- Public Methods @@ -95,7 +104,7 @@ public void doTrackedConnection(final Request request,// final GrizzlyResponseFuture requestFuture,// final CompletionHandler connectHandler) throws IOException { final EndpointKey key = getEndPointKey(request, requestFuture.getProxyServer()); - CompletionHandler handler = wrapHandler(request, getVerifier(), connectHandler); + CompletionHandler handler = wrapHandler(request, getVerifier(), connectHandler, sslFilter); if (asyncConnect) { connectionPool.take(key, handler); } else { @@ -136,37 +145,32 @@ public Connection obtainConnection(final Request request, final GrizzlyResponseF // --------------------------------------------------Package Private Methods static CompletionHandler wrapHandler(final Request request, final HostnameVerifier verifier, - final CompletionHandler delegate) { + final CompletionHandler delegate, final SSLFilter sslFilter) { final URI uri = request.getURI(); if (Utils.isSecure(uri) && verifier != null) { - return new EmptyCompletionHandler() { + SSLBaseFilter.HandshakeListener handshakeListener = new SSLBaseFilter.HandshakeListener() { @Override - public void completed(Connection result) { - final String host = uri.getHost(); - final SSLSession session = SSLUtils.getSSLEngine(result).getSession(); - if (!verifier.verify(host, session)) { - failed(new ConnectException("Host name verification failed for host " + host)); - } else { - delegate.completed(result); - } - + public void onStart(Connection connection) { + // do nothing + LOGGER.debug("SSL Handshake onStart: "); } @Override - public void cancelled() { - delegate.cancelled(); - } + public void onComplete(Connection connection) { + sslFilter.removeHandshakeListener(this); - @Override - public void failed(Throwable throwable) { - delegate.failed(throwable); - } + final String host = uri.getHost(); + final SSLSession session = SSLUtils.getSSLEngine(connection).getSession(); + LOGGER.debug("SSL Handshake onComplete: session = {}, id = {}, isValid = {}, host = {}", session.toString(), Base64.encode(session.getId()), session.isValid(), host); - @Override - public void updated(Connection result) { - delegate.updated(result); + if (!verifier.verify(host, session)) { + connection.close(); // XXX what's the correct way to kill a connection? + IOException e = new ConnectException("Host name verification failed for host " + host); + delegate.failed(e); + } } }; + sslFilter.addHandshakeListener(handshakeListener); } return delegate; } diff --git a/providers/grizzly/src/main/java/org/asynchttpclient/providers/grizzly/GrizzlyAsyncHttpProvider.java b/providers/grizzly/src/main/java/org/asynchttpclient/providers/grizzly/GrizzlyAsyncHttpProvider.java index a7d221093f..722a1b7e88 100644 --- a/providers/grizzly/src/main/java/org/asynchttpclient/providers/grizzly/GrizzlyAsyncHttpProvider.java +++ b/providers/grizzly/src/main/java/org/asynchttpclient/providers/grizzly/GrizzlyAsyncHttpProvider.java @@ -334,7 +334,7 @@ public void onTimeout(Connection connection) { } else { pool = null; } - connectionManager = new ConnectionManager(this, pool, secure, nonSecure); + connectionManager = new ConnectionManager(this, pool, secure, nonSecure, filter); } diff --git a/providers/netty/src/main/java/org/asynchttpclient/providers/netty/request/NettyConnectListener.java b/providers/netty/src/main/java/org/asynchttpclient/providers/netty/request/NettyConnectListener.java index 7bfc6d8717..c023406e44 100644 --- a/providers/netty/src/main/java/org/asynchttpclient/providers/netty/request/NettyConnectListener.java +++ b/providers/netty/src/main/java/org/asynchttpclient/providers/netty/request/NettyConnectListener.java @@ -16,10 +16,13 @@ */ package org.asynchttpclient.providers.netty.request; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; import org.asynchttpclient.AsyncHttpClientConfig; import org.asynchttpclient.providers.netty.channel.Channels; import org.asynchttpclient.providers.netty.future.NettyResponseFuture; import org.asynchttpclient.providers.netty.future.StackTraceInspector; +import org.asynchttpclient.util.Base64; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -28,6 +31,10 @@ import io.netty.channel.ChannelFutureListener; import io.netty.handler.ssl.SslHandler; +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLSession; import java.net.ConnectException; import java.nio.channels.ClosedChannelException; @@ -54,15 +61,32 @@ public NettyResponseFuture future() { public void onFutureSuccess(final Channel channel) throws ConnectException { Channels.setDefaultAttribute(channel, future); - SslHandler sslHandler = Channels.getSslHandler(channel); - - if (sslHandler != null && !config.getHostnameVerifier().verify(future.getURI().getHost(), sslHandler.engine().getSession())) { - ConnectException exception = new ConnectException("HostnameVerifier exception"); - future.abort(exception); - throw exception; + final HostnameVerifier hostnameVerifier = config.getHostnameVerifier(); + final SslHandler sslHandler = Channels.getSslHandler(channel); + if (hostnameVerifier != null && sslHandler != null) { + final String host = future.getURI().getHost(); + sslHandler.handshakeFuture().addListener(new GenericFutureListener>() { + @Override + public void operationComplete(Future handshakeFuture) throws Exception { + if (handshakeFuture.isSuccess()) { + Channel channel = (Channel) handshakeFuture.getNow(); + SSLEngine engine = sslHandler.engine(); + SSLSession session = engine.getSession(); + + LOGGER.debug("onFutureSuccess: session = {}, id = {}, isValid = {}, host = {}", session.toString(), Base64.encode(session.getId()), session.isValid(), host); + if (!hostnameVerifier.verify(host, session)) { + ConnectException exception = new ConnectException("HostnameVerifier exception"); + future.abort(exception); + throw exception; + } else { + requestSender.writeRequest(future, channel); + } + } + } + }); + } else { + requestSender.writeRequest(future, channel); } - - requestSender.writeRequest(future, channel); } public void onFutureFailure(Channel channel, Throwable cause) {