diff --git a/api/src/main/java/io/grpc/EquivalentAddressGroup.java b/api/src/main/java/io/grpc/EquivalentAddressGroup.java index bf8a864902c..8ed6c4ad0e4 100644 --- a/api/src/main/java/io/grpc/EquivalentAddressGroup.java +++ b/api/src/main/java/io/grpc/EquivalentAddressGroup.java @@ -55,6 +55,10 @@ public final class EquivalentAddressGroup { */ public static final Attributes.Key ATTR_LOCALITY_NAME = Attributes.Key.create("io.grpc.EquivalentAddressGroup.LOCALITY"); + /** Name associated with individual address, if available (e.g., DNS name). */ + @Attr + public static final Attributes.Key ATTR_ADDRESS_NAME = + Attributes.Key.create("io.grpc.xds.XdsAttributes.addressName"); private final List addrs; private final Attributes attrs; diff --git a/core/src/main/java/io/grpc/internal/CertificateUtils.java b/core/src/main/java/io/grpc/internal/CertificateUtils.java index 91d17de93cb..e0171e8d357 100644 --- a/core/src/main/java/io/grpc/internal/CertificateUtils.java +++ b/core/src/main/java/io/grpc/internal/CertificateUtils.java @@ -26,6 +26,7 @@ import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; import java.util.Collection; +import java.util.List; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import javax.security.auth.x500.X500Principal; @@ -34,6 +35,16 @@ * Contains certificate/key PEM file utility method(s) for internal usage. */ public final class CertificateUtils { + private static Class x509ExtendedTrustManagerClass; + + static { + try { + x509ExtendedTrustManagerClass = Class.forName("javax.net.ssl.X509ExtendedTrustManager"); + } catch (ClassNotFoundException e) { + // Will disallow per-rpc authority override via call option. + } + } + /** * Creates X509TrustManagers using the provided CA certs. */ @@ -71,6 +82,17 @@ public static TrustManager[] createTrustManager(InputStream rootCerts) return trustManagerFactory.getTrustManagers(); } + public static TrustManager getX509ExtendedTrustManager(List trustManagers) { + if (x509ExtendedTrustManagerClass != null) { + for (TrustManager trustManager : trustManagers) { + if (x509ExtendedTrustManagerClass.isInstance(trustManager)) { + return trustManager; + } + } + } + return null; + } + private static X509Certificate[] getX509Certificates(InputStream inputStream) throws CertificateException { CertificateFactory factory = CertificateFactory.getInstance("X.509"); diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index 039ea6c4f24..ba4fde8d15c 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -26,6 +26,8 @@ import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; import java.util.concurrent.Executor; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; /** * Internal accessor for {@link ProtocolNegotiators}. @@ -38,13 +40,18 @@ private InternalProtocolNegotiators() {} * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. - * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks + * + * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, ObjectPool executorPool, - Optional handshakeCompleteRunnable) { + Optional handshakeCompleteRunnable, + TrustManager extendedX509TrustManager, + String sni, + boolean isXdsTarget) { final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, - executorPool, handshakeCompleteRunnable, null); + executorPool, handshakeCompleteRunnable, (X509TrustManager) extendedX509TrustManager, sni, + isXdsTarget); final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { @Override @@ -62,17 +69,19 @@ public void close() { negotiator.close(); } } - + return new TlsNegotiator(); } - + /** * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. */ - public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) { - return tls(sslContext, null, Optional.absent()); + public static InternalProtocolNegotiator.ProtocolNegotiator tls( + SslContext sslContext, String sni, boolean isXdsTarget, + TrustManager extendedX509TrustManager) { + return tls(sslContext, null, Optional.absent(), extendedX509TrustManager, sni, isXdsTarget); } /** diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 46566eaca1a..29bb7be51bc 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -652,7 +652,8 @@ static ProtocolNegotiator createProtocolNegotiatorByType( case PLAINTEXT_UPGRADE: return ProtocolNegotiators.plaintextUpgrade(); case TLS: - return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null); + return ProtocolNegotiators.tls( + sslContext, executorPool, Optional.absent(), null, null, false); default: throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType); } diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 77308c76ace..216e45f8bf1 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -21,6 +21,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Optional; import com.google.common.base.Preconditions; +import com.google.common.base.Strings; import com.google.errorprone.annotations.ForOverride; import io.grpc.Attributes; import io.grpc.CallCredentials; @@ -102,15 +103,6 @@ final class ProtocolNegotiators { private static final EnumSet understoodServerTlsFeatures = EnumSet.of( TlsServerCredentials.Feature.MTLS, TlsServerCredentials.Feature.CUSTOM_MANAGERS); - private static Class x509ExtendedTrustManagerClass; - - static { - try { - x509ExtendedTrustManagerClass = Class.forName("javax.net.ssl.X509ExtendedTrustManager"); - } catch (ClassNotFoundException e) { - // Will disallow per-rpc authority override via call option. - } - } private ProtocolNegotiators() { } @@ -139,7 +131,7 @@ public static FromChannelCredentialsResult from(ChannelCredentials creds) { trustManagers = tlsCreds.getTrustManagers(); } else if (tlsCreds.getRootCertificates() != null) { trustManagers = Arrays.asList(CertificateUtils.createTrustManager( - new ByteArrayInputStream(tlsCreds.getRootCertificates()))); + new ByteArrayInputStream(tlsCreds.getRootCertificates()))); } else { // else use system default TrustManagerFactory tmf = TrustManagerFactory.getInstance( TrustManagerFactory.getDefaultAlgorithm()); @@ -147,17 +139,10 @@ public static FromChannelCredentialsResult from(ChannelCredentials creds) { trustManagers = Arrays.asList(tmf.getTrustManagers()); } builder.trustManager(new FixedTrustManagerFactory(trustManagers)); - TrustManager x509ExtendedTrustManager = null; - if (x509ExtendedTrustManagerClass != null) { - for (TrustManager trustManager : trustManagers) { - if (x509ExtendedTrustManagerClass.isInstance(trustManager)) { - x509ExtendedTrustManager = trustManager; - break; - } - } - } + TrustManager x509ExtendedTrustManager = + CertificateUtils.getX509ExtendedTrustManager(trustManagers); return FromChannelCredentialsResult.negotiator(tlsClientFactory(builder.build(), - (X509TrustManager) x509ExtendedTrustManager)); + (X509TrustManager) x509ExtendedTrustManager)); } catch (SSLException | GeneralSecurityException ex) { log.log(Level.FINE, "Exception building SslContext", ex); return FromChannelCredentialsResult.error( @@ -473,7 +458,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); if (!sslContext.applicationProtocolNegotiator().protocols().contains( - sslHandler.applicationProtocol())) { + sslHandler.applicationProtocol())) { logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed for new client.", null); ctx.fireExceptionCaught(unavailableException( "Failed protocol negotiation: Unable to find compatible protocol")); @@ -579,7 +564,7 @@ static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator { public ClientTlsProtocolNegotiator(SslContext sslContext, ObjectPool executorPool, Optional handshakeCompleteRunnable, - X509TrustManager x509ExtendedTrustManager) { + X509TrustManager x509ExtendedTrustManager, String sni, boolean isXdsTarget) { this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); this.executorPool = executorPool; if (this.executorPool != null) { @@ -587,12 +572,19 @@ public ClientTlsProtocolNegotiator(SslContext sslContext, } this.handshakeCompleteRunnable = handshakeCompleteRunnable; this.x509ExtendedTrustManager = x509ExtendedTrustManager; + this.sni = sni; + this.isXdsTarget = isXdsTarget; } private final SslContext sslContext; private final ObjectPool executorPool; private final Optional handshakeCompleteRunnable; private final X509TrustManager x509ExtendedTrustManager; + private final String sni; + // For xds targets there may be no SNI determined, and no SNI may be sent in that case. + // Non xds-targets will always use channel authority for SNI. This field is used to handle + // the two cases differently. + private final boolean isXdsTarget; private Executor executor; @Override @@ -604,9 +596,10 @@ public AsciiString scheme() { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler); ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger(); - ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(), - this.executor, negotiationLogger, handshakeCompleteRunnable, this, - x509ExtendedTrustManager); + ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, + isXdsTarget ? sni : grpcHandler.getAuthority(), + this.executor, negotiationLogger, handshakeCompleteRunnable, null, + x509ExtendedTrustManager); return new WaitUntilActiveHandler(cth, negotiationLogger); } @@ -633,16 +626,21 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { private final X509TrustManager x509ExtendedTrustManager; private SSLEngine sslEngine; - ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority, + ClientTlsHandler(ChannelHandler next, SslContext sslContext, String sniHostPort, Executor executor, ChannelLogger negotiationLogger, Optional handshakeCompleteRunnable, ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, - X509TrustManager x509ExtendedTrustManager) { + X509TrustManager x509ExtendedTrustManager) { super(next, negotiationLogger); this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); - HostPort hostPort = parseAuthority(authority); - this.host = hostPort.host; - this.port = hostPort.port; + if (!Strings.isNullOrEmpty(sniHostPort)) { + HostPort hostPort = parseAuthority(sniHostPort); + this.host = hostPort.host; + this.port = hostPort.port; + } else { + this.host = null; + this.port = 0; + } this.executor = executor; this.handshakeCompleteRunnable = handshakeCompleteRunnable; this.x509ExtendedTrustManager = x509ExtendedTrustManager; @@ -651,7 +649,11 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { @Override @IgnoreJRERequirement protected void handlerAdded0(ChannelHandlerContext ctx) { - sslEngine = sslContext.newEngine(ctx.alloc(), host, port); + if (host != null) { + sslEngine = sslContext.newEngine(ctx.alloc(), host, port); + } else { + sslEngine = sslContext.newEngine(ctx.alloc()); + } SSLParameters sslParams = sslEngine.getSSLParameters(); sslParams.setEndpointIdentificationAlgorithm("HTTPS"); sslEngine.setSSLParameters(sslParams); @@ -746,13 +748,14 @@ static HostPort parseAuthority(String authority) { * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. + * * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static ProtocolNegotiator tls(SslContext sslContext, ObjectPool executorPool, Optional handshakeCompleteRunnable, - X509TrustManager x509ExtendedTrustManager) { + X509TrustManager x509ExtendedTrustManager, String sni, boolean isXdsTarget) { return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable, - x509ExtendedTrustManager); + x509ExtendedTrustManager, sni, isXdsTarget); } /** @@ -762,7 +765,7 @@ public static ProtocolNegotiator tls(SslContext sslContext, */ public static ProtocolNegotiator tls(SslContext sslContext, X509TrustManager x509ExtendedTrustManager) { - return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager); + return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager, null, false); } public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext, @@ -1060,8 +1063,8 @@ static final class PlaintextHandler extends ProtocolNegotiationHandler { protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { ProtocolNegotiationEvent existingPne = getProtocolNegotiationEvent(); Attributes attrs = existingPne.getAttributes().toBuilder() - .set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, (authority) -> Status.OK) - .build(); + .set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, (authority) -> Status.OK) + .build(); replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs)); fireProtocolNegotiationEvent(ctx); } @@ -1221,4 +1224,4 @@ public String getPeerHost() { return peerHost; } } -} +} \ No newline at end of file diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 55abe29e93a..10cf7480868 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -877,7 +877,7 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception { .keyManager(clientCert, clientKey) .build(); ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool, - Optional.absent(), null); + Optional.absent(), null, null, false); // after starting the client, the Executor in the client pool should be used assertEquals(true, clientExecutorPool.isInUse()); final NettyClientTransport transport = newTransport(negotiator); diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index 638fe960a32..1e1fa07c228 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -918,7 +918,7 @@ public String applicationProtocol() { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg, noopLogger, Optional.absent(), - getClientTlsProtocolNegotiator(), null); + null, null); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -957,7 +957,7 @@ public String applicationProtocol() { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg, noopLogger, Optional.absent(), - getClientTlsProtocolNegotiator(), null); + null, null); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -982,7 +982,7 @@ public String applicationProtocol() { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg, noopLogger, Optional.absent(), - getClientTlsProtocolNegotiator(), null); + null, null); pipeline.addLast(handler); final AtomicReference error = new AtomicReference<>(); @@ -1011,7 +1011,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { public void clientTlsHandler_closeDuringNegotiation() throws Exception { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", null, noopLogger, Optional.absent(), - getClientTlsProtocolNegotiator(), null); + null, null); pipeline.addLast(new WriteBufferingAndExceptionHandler(handler)); ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE); @@ -1023,12 +1023,6 @@ public void clientTlsHandler_closeDuringNegotiation() throws Exception { .isEqualTo(Status.Code.UNAVAILABLE); } - private ClientTlsProtocolNegotiator getClientTlsProtocolNegotiator() throws SSLException { - return new ClientTlsProtocolNegotiator(GrpcSslContexts.forClient().trustManager( - TlsTesting.loadCert("ca.pem")).build(), - null, Optional.absent(), null); - } - @Test public void engineLog() { ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null); @@ -1277,7 +1271,7 @@ public void clientTlsHandler_firesNegotiation() throws Exception { } FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler(); ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, - null, Optional.absent(), null); + null, Optional.absent(), null, null, false); WriteBufferingAndExceptionHandler clientWbaeh = new WriteBufferingAndExceptionHandler(pn.newHandler(gh)); diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java index 03976cc7d7b..e1fe04d26d3 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java @@ -38,7 +38,6 @@ import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.InternalProtocolNegotiators.ProtocolNegotiationHandler; -import io.grpc.s2a.internal.handshaker.S2AIdentity; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; @@ -259,7 +258,8 @@ public void onSuccess(SslContext sslContext) { public void run() { s2aStub.close(); } - })) + }), + null, null, false) .newHandler(grpcHandler); // Delegate the rest of the handshake to the TLS handler. and remove the diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index fba66e2e8d7..a33e97b3317 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -152,8 +152,8 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { childSwitchLb.handleResolvedAddresses( resolvedAddresses.toBuilder() .setAttributes(attributes.toBuilder() - .set(NameResolver.ATTR_BACKEND_SERVICE, cluster) - .build()) + .set(NameResolver.ATTR_BACKEND_SERVICE, cluster) + .build()) .setLoadBalancingPolicyConfig(config.childConfig) .build()); return Status.OK; @@ -241,9 +241,9 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { .set(ATTR_CLUSTER_LOCALITY, localityAtomicReference); if (GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false)) { String hostname = args.getAddresses().get(0).getAttributes() - .get(XdsAttributes.ATTR_ADDRESS_NAME); + .get(EquivalentAddressGroup.ATTR_ADDRESS_NAME); if (hostname != null) { - attrsBuilder.set(XdsAttributes.ATTR_ADDRESS_NAME, hostname); + attrsBuilder.set(EquivalentAddressGroup.ATTR_ADDRESS_NAME, hostname); } } args = args.toBuilder().setAddresses(addresses).setAttributes(attrsBuilder.build()).build(); @@ -320,7 +320,7 @@ private ClusterLocality createClusterLocalityFromAttributes(Attributes addressAt (lrsServerInfo == null) ? null : xdsClient.addClusterLocalityStats(lrsServerInfo, cluster, - edsServiceName, locality); + edsServiceName, locality); return new ClusterLocality(localityStats, localityName); } @@ -362,7 +362,7 @@ private void updateSslContextProviderSupplier(@Nullable UpstreamTlsContext tlsCo sslContextProviderSupplier = tlsContext != null ? new SslContextProviderSupplier(tlsContext, - (TlsContextManager) xdsClient.getSecurityConfig()) + (TlsContextManager) xdsClient.getSecurityConfig()) : null; } @@ -377,8 +377,9 @@ private class RequestLimitingSubchannelPicker extends SubchannelPicker { private final Map filterMetadata; private RequestLimitingSubchannelPicker(SubchannelPicker delegate, - List dropPolicies, long maxConcurrentRequests, - Map filterMetadata) { + List dropPolicies, + long maxConcurrentRequests, + Map filterMetadata) { this.delegate = delegate; this.dropPolicies = dropPolicies; this.maxConcurrentRequests = maxConcurrentRequests; @@ -438,7 +439,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { result = PickResult.withSubchannel(result.getSubchannel(), result.getStreamTracerFactory(), result.getSubchannel().getAttributes().get( - XdsAttributes.ATTR_ADDRESS_NAME)); + EquivalentAddressGroup.ATTR_ADDRESS_NAME)); } } return result; @@ -542,4 +543,4 @@ void release() { } } } -} +} \ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java index 06fafbb6cf1..50b8097fca9 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java @@ -194,7 +194,7 @@ StatusOr edsUpdateToResult( .set(XdsAttributes.ATTR_LOCALITY_WEIGHT, localityLbInfo.localityWeight()) .set(XdsAttributes.ATTR_SERVER_WEIGHT, weight) - .set(XdsAttributes.ATTR_ADDRESS_NAME, endpoint.hostname()) + .set(EquivalentAddressGroup.ATTR_ADDRESS_NAME, endpoint.hostname()) .build(); EquivalentAddressGroup eag; if (config.isHttp11ProxyAvailable()) { diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index 9c2ee641423..e6ef72c15b8 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -73,20 +73,55 @@ public int hashCode() { public static final class UpstreamTlsContext extends BaseTlsContext { + private final String sni; + private final boolean autoHostSni; + private final boolean autoSniSanValidation; + @VisibleForTesting public UpstreamTlsContext(CommonTlsContext commonTlsContext) { super(commonTlsContext); + this.sni = null; + this.autoHostSni = false; + this.autoSniSanValidation = false; + } + + @VisibleForTesting + public UpstreamTlsContext( + io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext + upstreamTlsContext) { + super(upstreamTlsContext.getCommonTlsContext()); + this.sni = upstreamTlsContext.getSni(); + this.autoHostSni = upstreamTlsContext.getAutoHostSni(); + this.autoSniSanValidation = upstreamTlsContext.getAutoSniSanValidation(); } public static UpstreamTlsContext fromEnvoyProtoUpstreamTlsContext( io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext) { - return new UpstreamTlsContext(upstreamTlsContext.getCommonTlsContext()); + UpstreamTlsContext o = new UpstreamTlsContext(upstreamTlsContext); + return o; + } + + public String getSni() { + return sni; + } + + public boolean getAutoHostSni() { + return autoHostSni; + } + + public boolean getAutoSniSanValidation() { + return autoSniSanValidation; } @Override public String toString() { - return "UpstreamTlsContext{" + "commonTlsContext=" + commonTlsContext + '}'; + return "UpstreamTlsContext{" + + "commonTlsContext=" + commonTlsContext + + "sni=" + sni + + "\nauto_host_sni=" + autoHostSni + + "\nauto_sni_san_validation=" + autoSniSanValidation + + "}"; } } diff --git a/xds/src/main/java/io/grpc/xds/TlsContextManager.java b/xds/src/main/java/io/grpc/xds/TlsContextManager.java index 772a6cff102..4e9470d91e6 100644 --- a/xds/src/main/java/io/grpc/xds/TlsContextManager.java +++ b/xds/src/main/java/io/grpc/xds/TlsContextManager.java @@ -30,7 +30,7 @@ SslContextProvider findOrCreateServerSslContextProvider( /** Creates a SslContextProvider. Used for retrieving a client-side SslContext. */ SslContextProvider findOrCreateClientSslContextProvider( - UpstreamTlsContext upstreamTlsContext); + UpstreamTlsContext upstreamTlsContext, String sni); /** * Releases an instance of the given client-side {@link SslContextProvider}. @@ -41,7 +41,8 @@ SslContextProvider findOrCreateClientSslContextProvider( *

Caller must not release a reference more than once. It's advised that you clear the * reference to the instance with the null returned by this method. */ - SslContextProvider releaseClientSslContextProvider(SslContextProvider sslContextProvider); + SslContextProvider releaseClientSslContextProvider(SslContextProvider sslContextProvider, + String sni); /** * Releases an instance of the given server-side {@link SslContextProvider}. diff --git a/xds/src/main/java/io/grpc/xds/XdsAttributes.java b/xds/src/main/java/io/grpc/xds/XdsAttributes.java index 2e165201e5f..0e770173219 100644 --- a/xds/src/main/java/io/grpc/xds/XdsAttributes.java +++ b/xds/src/main/java/io/grpc/xds/XdsAttributes.java @@ -88,11 +88,6 @@ final class XdsAttributes { static final Attributes.Key ATTR_SERVER_WEIGHT = Attributes.Key.create("io.grpc.xds.XdsAttributes.serverWeight"); - /** Name associated with individual address, if available (e.g., DNS name). */ - @EquivalentAddressGroup.Attr - static final Attributes.Key ATTR_ADDRESS_NAME = - Attributes.Key.create("io.grpc.xds.XdsAttributes.addressName"); - /** * Filter chain match for network filters. */ diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index c3bc7c2e326..20001b6558d 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -525,7 +525,7 @@ public void onClose(Status status, Metadata trailers) { Result.newBuilder() .setConfig(config) .setInterceptor(combineInterceptors( - ImmutableList.of(filters, new ClusterSelectionInterceptor()))) + ImmutableList.of(new ClusterSelectionInterceptor(), filters))) .build(); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java index 37d289c1c47..a66644786db 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java @@ -20,10 +20,12 @@ import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; import io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProviderFactory; +import java.util.AbstractMap; /** Factory to create client-side SslContextProvider from UpstreamTlsContext. */ final class ClientSslContextProviderFactory - implements ValueFactory { + implements ValueFactory, + SslContextProvider> { private BootstrapInfo bootstrapInfo; private final CertProviderClientSslContextProviderFactory @@ -41,9 +43,10 @@ final class ClientSslContextProviderFactory /** Creates an SslContextProvider from the given UpstreamTlsContext. */ @Override - public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) { + public SslContextProvider create( + AbstractMap.SimpleImmutableEntry key) { return certProviderClientSslContextProviderFactory.getProvider( - upstreamTlsContext, + key.getKey(), key.getValue(), bootstrapInfo.node().toEnvoyProtoNode(), bootstrapInfo.certProviders()); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java index 6bf66d022ff..f2d40696c9f 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java @@ -30,9 +30,11 @@ import java.io.IOException; import java.security.cert.CertStoreException; import java.security.cert.CertificateException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; /** Base class for dynamic {@link SslContextProvider}s. */ @Internal @@ -40,7 +42,8 @@ public abstract class DynamicSslContextProvider extends SslContextProvider { protected final List pendingCallbacks = new ArrayList<>(); @Nullable protected final CertificateValidationContext staticCertificateValidationContext; - @Nullable protected SslContext sslContext; + @Nullable protected AbstractMap.SimpleImmutableEntry + sslContextAndExtendedX509TrustManager; protected DynamicSslContextProvider( BaseTlsContext tlsContext, CertificateValidationContext staticCertValidationContext) { @@ -49,15 +52,19 @@ protected DynamicSslContextProvider( } @Nullable - public SslContext getSslContext() { - return sslContext; + public AbstractMap.SimpleImmutableEntry + getSslContextAndExtendedX509TrustManager() { + return sslContextAndExtendedX509TrustManager; } protected abstract CertificateValidationContext generateCertificateValidationContext(); - /** Gets a server or client side SslContextBuilder. */ - protected abstract SslContextBuilder getSslContextBuilder( - CertificateValidationContext certificateValidationContext) + /** + * Gets a server or client side SslContextBuilder. + */ + protected abstract AbstractMap.SimpleImmutableEntry + getSslContextBuilderAndExtendedX509TrustManager( + CertificateValidationContext certificateValidationContext) throws CertificateException, IOException, CertStoreException; // this gets called only when requested secrets are ready... @@ -65,7 +72,8 @@ protected final void updateSslContext() { try { CertificateValidationContext localCertValidationContext = generateCertificateValidationContext(); - SslContextBuilder sslContextBuilder = getSslContextBuilder(localCertValidationContext); + AbstractMap.SimpleImmutableEntry sslContextBuilderAndTm = + getSslContextBuilderAndExtendedX509TrustManager(localCertValidationContext); CommonTlsContext commonTlsContext = getCommonTlsContext(); if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) { List alpnList = commonTlsContext.getAlpnProtocolsList(); @@ -75,16 +83,18 @@ protected final void updateSslContext() { ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, alpnList); - sslContextBuilder.applicationProtocolConfig(apn); + sslContextBuilderAndTm.getKey().applicationProtocolConfig(apn); } List pendingCallbacksCopy; - SslContext sslContextCopy; + AbstractMap.SimpleImmutableEntry + sslContextAndExtendedX09TrustManagerCopy; synchronized (pendingCallbacks) { - sslContext = sslContextBuilder.build(); - sslContextCopy = sslContext; + sslContextAndExtendedX509TrustManager = new AbstractMap.SimpleImmutableEntry<>( + sslContextBuilderAndTm.getKey().build(), sslContextBuilderAndTm.getValue()); + sslContextAndExtendedX09TrustManagerCopy = sslContextAndExtendedX509TrustManager; pendingCallbacksCopy = clonePendingCallbacksAndClear(); } - makePendingCallbacks(sslContextCopy, pendingCallbacksCopy); + makePendingCallbacks(sslContextAndExtendedX09TrustManagerCopy, pendingCallbacksCopy); } catch (Exception e) { onError(Status.fromThrowable(e)); throw new RuntimeException(e); @@ -92,12 +102,13 @@ protected final void updateSslContext() { } protected final void callPerformCallback( - Callback callback, final SslContext sslContextCopy) { + Callback callback, + final AbstractMap.SimpleImmutableEntry sslContextAndTmCopy) { performCallback( new SslContextGetter() { @Override - public SslContext get() { - return sslContextCopy; + public AbstractMap.SimpleImmutableEntry get() { + return sslContextAndTmCopy; } }, callback @@ -108,10 +119,10 @@ public SslContext get() { public final void addCallback(Callback callback) { checkNotNull(callback, "callback"); // if there is a computed sslContext just send it - SslContext sslContextCopy = null; + AbstractMap.SimpleImmutableEntry sslContextCopy = null; synchronized (pendingCallbacks) { - if (sslContext != null) { - sslContextCopy = sslContext; + if (sslContextAndExtendedX509TrustManager != null) { + sslContextCopy = sslContextAndExtendedX509TrustManager; } else { pendingCallbacks.add(callback); } @@ -122,9 +133,11 @@ public final void addCallback(Callback callback) { } private final void makePendingCallbacks( - SslContext sslContextCopy, List pendingCallbacksCopy) { + AbstractMap.SimpleImmutableEntry + sslContextAndExtendedX509TrustManagerCopy, + List pendingCallbacksCopy) { for (Callback callback : pendingCallbacksCopy) { - callPerformCallback(callback, sslContextCopy); + callPerformCallback(callback, sslContextAndExtendedX509TrustManagerCopy); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index c34fab74032..86d28a0554b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -19,7 +19,9 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; import io.grpc.Grpc; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; @@ -29,6 +31,9 @@ import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.ProtocolNegotiationEvent; +import io.grpc.xds.EnvoyServerProtoData; +import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; +import io.grpc.xds.internal.security.trust.CertificateUtils; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; @@ -36,12 +41,14 @@ import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; import java.security.cert.CertStoreException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; /** * Provides client and server side gRPC {@link ProtocolNegotiator}s to provide the SSL @@ -60,14 +67,14 @@ private SecurityProtocolNegotiators() { private static final AsciiString SCHEME = AsciiString.of("http"); public static final Attributes.Key - ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER = - Attributes.Key.create("io.grpc.xds.internal.security.server.sslContextProviderSupplier"); + ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER = + Attributes.Key.create("io.grpc.xds.internal.security.server.sslContextProviderSupplier"); /** Attribute key for SslContextProviderSupplier (used from client) for a subchannel. */ @Grpc.TransportAttr public static final Attributes.Key ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER = - Attributes.Key.create("io.grpc.xds.internal.security.SslContextProviderSupplier"); + Attributes.Key.create("io.grpc.xds.internal.security.SslContextProviderSupplier"); /** * Returns a {@link InternalProtocolNegotiator.ClientFactory}. @@ -142,7 +149,8 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { fallbackProtocolNegotiator, "No TLS config and no fallbackProtocolNegotiator!"); return fallbackProtocolNegotiator.newHandler(grpcHandler); } - return new ClientSecurityHandler(grpcHandler, localSslContextProviderSupplier); + return new ClientSecurityHandler(grpcHandler, localSslContextProviderSupplier, + grpcHandler.getEagAttributes().get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)); } @Override @@ -185,10 +193,12 @@ static final class ClientSecurityHandler extends InternalProtocolNegotiators.ProtocolNegotiationHandler { private final GrpcHttp2ConnectionHandler grpcHandler; private final SslContextProviderSupplier sslContextProviderSupplier; + private final String sni; ClientSecurityHandler( GrpcHttp2ConnectionHandler grpcHandler, - SslContextProviderSupplier sslContextProviderSupplier) { + SslContextProviderSupplier sslContextProviderSupplier, + String endpointHostname) { super( // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior @@ -202,6 +212,19 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { checkNotNull(grpcHandler, "grpcHandler"); this.grpcHandler = grpcHandler; this.sslContextProviderSupplier = sslContextProviderSupplier; + EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); + UpstreamTlsContext upstreamTlsContext = ((UpstreamTlsContext) tlsContext); + if (CertificateUtils.isXdsSniEnabled) { + sni = upstreamTlsContext.getAutoHostSni() && !Strings.isNullOrEmpty(endpointHostname) + ? endpointHostname : upstreamTlsContext.getSni(); + } else { + sni = grpcHandler.getAuthority(); + } + } + + @VisibleForTesting + String getSni() { + return sni; } @Override @@ -213,7 +236,8 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { if (ctx.isRemoved()) { return; } @@ -222,7 +246,9 @@ public void updateSslContext(SslContext sslContext) { "ClientSecurityHandler.updateSslContext authority={0}, ctx.name={1}", new Object[]{grpcHandler.getAuthority(), ctx.name()}); ChannelHandler handler = - InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); + InternalProtocolNegotiators.tls( + sslContextAndTm.getKey(), sni, true, sslContextAndTm.getValue()) + .newHandler(grpcHandler); // Delegate rest of handshake to TLS handler ctx.pipeline().addAfter(ctx.name(), null, handler); @@ -234,8 +260,8 @@ public void updateSslContext(SslContext sslContext) { public void onException(Throwable throwable) { ctx.fireExceptionCaught(throwable); } - } - ); + }, + sni); } @Override @@ -325,13 +351,13 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc @VisibleForTesting static final class ServerSecurityHandler - extends InternalProtocolNegotiators.ProtocolNegotiationHandler { + extends InternalProtocolNegotiators.ProtocolNegotiationHandler { private final GrpcHttp2ConnectionHandler grpcHandler; private final SslContextProviderSupplier sslContextProviderSupplier; ServerSecurityHandler( - GrpcHttp2ConnectionHandler grpcHandler, - SslContextProviderSupplier sslContextProviderSupplier) { + GrpcHttp2ConnectionHandler grpcHandler, + SslContextProviderSupplier sslContextProviderSupplier) { super( // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior @@ -356,9 +382,10 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSslContext(SslContext sslContext) { - ChannelHandler handler = - InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHandler); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + ChannelHandler handler = InternalProtocolNegotiators.serverTls( + sslContextAndTm.getKey()).newHandler(grpcHandler); // Delegate rest of handshake to TLS handler if (!ctx.isRemoved()) { @@ -372,8 +399,7 @@ public void updateSslContext(SslContext sslContext) { public void onException(Throwable throwable) { ctx.fireExceptionCaught(throwable); } - } - ); + }, null); } } -} +} \ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java index a0c4ed37dfb..49cbab2ca03 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java @@ -32,7 +32,9 @@ import java.io.IOException; import java.security.cert.CertStoreException; import java.security.cert.CertificateException; +import java.util.AbstractMap; import java.util.concurrent.Executor; +import javax.net.ssl.TrustManager; /** * A SslContextProvider is a "container" or provider of SslContext. This is used by gRPC-xds to @@ -57,7 +59,8 @@ protected Callback(Executor executor) { } /** Informs callee of new/updated SslContext. */ - @VisibleForTesting public abstract void updateSslContext(SslContext sslContext); + @VisibleForTesting public abstract void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContext); /** Informs callee of an exception that was generated. */ @VisibleForTesting protected abstract void onException(Throwable throwable); @@ -111,7 +114,7 @@ public UpstreamTlsContext getUpstreamTlsContext() { public abstract void addCallback(Callback callback); protected final void performCallback( - final SslContextGetter sslContextGetter, final Callback callback) { + final SslContextGetter sslContextGetter, final Callback callback) { checkNotNull(sslContextGetter, "sslContextGetter"); checkNotNull(callback, "callback"); callback.executor.execute( @@ -119,8 +122,9 @@ protected final void performCallback( @Override public void run() { try { - SslContext sslContext = sslContextGetter.get(); - callback.updateSslContext(sslContext); + AbstractMap.SimpleImmutableEntry sslContextAndTm = + sslContextGetter.get(); + callback.updateSslContextAndExtendedX509TrustManager(sslContextAndTm); } catch (Throwable e) { callback.onException(e); } @@ -130,6 +134,6 @@ public void run() { /** Allows implementations to compute or get SslContext. */ protected interface SslContextGetter { - SslContext get() throws Exception; + AbstractMap.SimpleImmutableEntry get() throws Exception; } -} +} \ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index 5f629273179..8819874ad00 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -25,7 +25,11 @@ import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; +import java.util.AbstractMap; +import java.util.HashSet; import java.util.Objects; +import java.util.Set; +import javax.net.ssl.TrustManager; /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} @@ -38,6 +42,7 @@ public final class SslContextProviderSupplier implements Closeable { private final BaseTlsContext tlsContext; private final TlsContextManager tlsContextManager; + private final Set snisSentByClients = new HashSet<>(); private SslContextProvider sslContextProvider; private boolean shutdown; @@ -52,31 +57,33 @@ public BaseTlsContext getTlsContext() { } /** Updates SslContext via the passed callback. */ - public synchronized void updateSslContext(final SslContextProvider.Callback callback) { + public synchronized void updateSslContext( + final SslContextProvider.Callback callback, String sni) { checkNotNull(callback, "callback"); try { if (!shutdown) { if (sslContextProvider == null) { - sslContextProvider = getSslContextProvider(); + sslContextProvider = getSslContextProvider(sni); } } // we want to increment the ref-count so call findOrCreate again... - final SslContextProvider toRelease = getSslContextProvider(); + final SslContextProvider toRelease = getSslContextProvider(sni); toRelease.addCallback( new SslContextProvider.Callback(callback.getExecutor()) { - @Override - public void updateSslContext(SslContext sslContext) { - callback.updateSslContext(sslContext); - releaseSslContextProvider(toRelease); - } - - @Override - public void onException(Throwable throwable) { - callback.onException(throwable); - releaseSslContextProvider(toRelease); - } - }); + @Override + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + callback.updateSslContextAndExtendedX509TrustManager(sslContextAndTm); + releaseSslContextProvider(toRelease, sni); + } + + @Override + public void onException(Throwable throwable) { + callback.onException(throwable); + releaseSslContextProvider(toRelease, sni); + } + }); } catch (final Throwable throwable) { callback.getExecutor().execute(new Runnable() { @Override @@ -87,18 +94,23 @@ public void run() { } } - private void releaseSslContextProvider(SslContextProvider toRelease) { + private void releaseSslContextProvider(SslContextProvider toRelease, String sni) { if (tlsContext instanceof UpstreamTlsContext) { - tlsContextManager.releaseClientSslContextProvider(toRelease); + tlsContextManager.releaseClientSslContextProvider(toRelease, sni); + snisSentByClients.remove(sni); } else { tlsContextManager.releaseServerSslContextProvider(toRelease); } } - private SslContextProvider getSslContextProvider() { - return tlsContext instanceof UpstreamTlsContext - ? tlsContextManager.findOrCreateClientSslContextProvider((UpstreamTlsContext) tlsContext) - : tlsContextManager.findOrCreateServerSslContextProvider((DownstreamTlsContext) tlsContext); + private SslContextProvider getSslContextProvider(String sni) { + if (tlsContext instanceof UpstreamTlsContext) { + snisSentByClients.add(sni); + return tlsContextManager.findOrCreateClientSslContextProvider( + (UpstreamTlsContext) tlsContext, sni); + } + return tlsContextManager.findOrCreateServerSslContextProvider( + (DownstreamTlsContext) tlsContext); } @VisibleForTesting public boolean isShutdown() { @@ -110,7 +122,9 @@ private SslContextProvider getSslContextProvider() { public synchronized void close() { if (sslContextProvider != null) { if (tlsContext instanceof UpstreamTlsContext) { - tlsContextManager.releaseClientSslContextProvider(sslContextProvider); + for (String sni: snisSentByClients) { + tlsContextManager.releaseClientSslContextProvider(sslContextProvider, sni); + } } else { tlsContextManager.releaseServerSslContextProvider(sslContextProvider); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java b/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java index 34a8863c52b..a36eeb7a727 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java @@ -25,6 +25,7 @@ import io.grpc.xds.TlsContextManager; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; +import java.util.AbstractMap; /** * Class to manage {@link SslContextProvider} objects created from inputs we get from xDS. Used by @@ -34,7 +35,8 @@ */ public final class TlsContextManagerImpl implements TlsContextManager { - private final ReferenceCountingMap mapForClients; + private final ReferenceCountingMap, + SslContextProvider> mapForClients; private final ReferenceCountingMap mapForServers; /** @@ -48,7 +50,8 @@ public final class TlsContextManagerImpl implements TlsContextManager { @VisibleForTesting TlsContextManagerImpl( - ValueFactory clientFactory, + ValueFactory, + SslContextProvider> clientFactory, ValueFactory serverFactory) { checkNotNull(clientFactory, "clientFactory"); checkNotNull(serverFactory, "serverFactory"); @@ -69,18 +72,18 @@ public SslContextProvider findOrCreateServerSslContextProvider( @Override public SslContextProvider findOrCreateClientSslContextProvider( - UpstreamTlsContext upstreamTlsContext) { + UpstreamTlsContext upstreamTlsContext, String sni) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); - CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder(); - upstreamTlsContext = new UpstreamTlsContext(builder.build()); - return mapForClients.get(upstreamTlsContext); + return mapForClients.get(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, sni)); } @Override public SslContextProvider releaseClientSslContextProvider( - SslContextProvider clientSslContextProvider) { + SslContextProvider clientSslContextProvider, String sni) { checkNotNull(clientSslContextProvider, "clientSslContextProvider"); - return mapForClients.release(clientSslContextProvider.getUpstreamTlsContext(), + return mapForClients.release( + new AbstractMap.SimpleImmutableEntry<>( + clientSslContextProvider.getUpstreamTlsContext(), sni), clientSslContextProvider); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index d7c2267c48f..cff29389cf3 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -26,8 +26,11 @@ import io.netty.handler.ssl.SslContextBuilder; import java.security.cert.CertStoreException; import java.security.cert.X509Certificate; +import java.util.AbstractMap; +import java.util.Arrays; import java.util.Map; import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { @@ -39,6 +42,7 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP CommonTlsContext.CertificateProviderInstance rootCertInstance, CertificateValidationContext staticCertValidationContext, UpstreamTlsContext upstreamTlsContext, + String sniForSanMatching, CertificateProviderStore certificateProviderStore) { super( node, @@ -47,32 +51,47 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP rootCertInstance, staticCertValidationContext, upstreamTlsContext, - certificateProviderStore); + certificateProviderStore, + upstreamTlsContext.getAutoSniSanValidation() ? sniForSanMatching : null); } @Override - protected final SslContextBuilder getSslContextBuilder( - CertificateValidationContext certificateValidationContextdationContext) - throws CertStoreException { + protected final AbstractMap.SimpleImmutableEntry + getSslContextBuilderAndExtendedX509TrustManager( + CertificateValidationContext certificateValidationContext) + throws CertStoreException { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); - // Null rootCertInstance implies hasSystemRootCerts because of the check in - // CertProviderClientSslContextProviderFactory. - if (rootCertInstance != null) { - if (savedSpiffeTrustMap != null) { - sslContextBuilder = sslContextBuilder.trustManager( + if (savedSpiffeTrustMap != null) { + sslContextBuilder = sslContextBuilder.trustManager( + new XdsTrustManagerFactory( + savedSpiffeTrustMap, + certificateValidationContext, sniForSanMatching)); + } else if (savedTrustedRoots != null) { + sslContextBuilder = sslContextBuilder.trustManager( new XdsTrustManagerFactory( - savedSpiffeTrustMap, - certificateValidationContextdationContext)); - } else { - sslContextBuilder = sslContextBuilder.trustManager( - new XdsTrustManagerFactory( - savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContextdationContext)); - } + savedTrustedRoots.toArray(new X509Certificate[0]), + certificateValidationContext, sniForSanMatching)); + } else { + // Should be impossible because of the check in CertProviderClientSslContextProviderFactory + throw new IllegalStateException("There must be trusted roots or a SPIFFE trust map"); + } + XdsTrustManagerFactory trustManagerFactory; + if (savedSpiffeTrustMap != null) { + trustManagerFactory = new XdsTrustManagerFactory( + savedSpiffeTrustMap, + certificateValidationContext, sniForSanMatching); + sslContextBuilder = sslContextBuilder.trustManager(trustManagerFactory); + } else { + trustManagerFactory = new XdsTrustManagerFactory( + savedTrustedRoots.toArray(new X509Certificate[0]), + certificateValidationContext, sniForSanMatching); + sslContextBuilder = sslContextBuilder.trustManager(trustManagerFactory); } if (isMtls()) { sslContextBuilder.keyManager(savedKey, savedCertChain); } - return sslContextBuilder; + return new AbstractMap.SimpleImmutableEntry<>(sslContextBuilder, + io.grpc.internal.CertificateUtils.getX509ExtendedTrustManager( + Arrays.asList(trustManagerFactory.getTrustManagers()))); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java index 6205c1c3a63..c65fcc59a45 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java @@ -55,7 +55,7 @@ public static CertProviderClientSslContextProviderFactory getInstance() { */ public SslContextProvider getProvider( UpstreamTlsContext upstreamTlsContext, - Node node, + String sni, Node node, @Nullable Map certProviders) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); @@ -74,6 +74,7 @@ public SslContextProvider getProvider( rootCertInstance, staticCertValidationContext, upstreamTlsContext, + sni, certificateProviderStore); } throw new UnsupportedOperationException("Unsupported configurations in UpstreamTlsContext!"); diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java index ef65bbfb6f9..a6140133f38 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java @@ -30,8 +30,10 @@ import java.security.cert.CertStoreException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.AbstractMap; import java.util.Map; import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; /** A server SslContext provider using CertificateProviderInstance to fetch secrets. */ final class CertProviderServerSslContextProvider extends CertProviderSslContextProvider { @@ -51,27 +53,30 @@ final class CertProviderServerSslContextProvider extends CertProviderSslContextP rootCertInstance, staticCertValidationContext, downstreamTlsContext, - certificateProviderStore); + certificateProviderStore, + null); } @Override - protected final SslContextBuilder getSslContextBuilder( - CertificateValidationContext certificateValidationContextdationContext) - throws CertStoreException, CertificateException, IOException { + protected final AbstractMap.SimpleImmutableEntry + getSslContextBuilderAndExtendedX509TrustManager( + CertificateValidationContext certificateValidationContextdationContext) + throws CertStoreException, CertificateException, IOException { SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(savedKey, savedCertChain); XdsTrustManagerFactory trustManagerFactory = null; if (isMtls() && savedSpiffeTrustMap != null) { trustManagerFactory = new XdsTrustManagerFactory( savedSpiffeTrustMap, - certificateValidationContextdationContext); + certificateValidationContextdationContext, null); } else if (isMtls()) { trustManagerFactory = new XdsTrustManagerFactory( savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContextdationContext); + certificateValidationContextdationContext, null); } setClientAuthValues(sslContextBuilder, trustManagerFactory); sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder); - return sslContextBuilder; + // TrustManager in the below return value is not used on the server side, so setting it to null + return new AbstractMap.SimpleImmutableEntry<>(sslContextBuilder, null); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index 801dabeecb7..883fda4cd4e 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -24,6 +24,7 @@ import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.internal.security.CommonTlsContextUtil; import io.grpc.xds.internal.security.DynamicSslContextProvider; +import java.io.Closeable; import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.List; @@ -34,8 +35,8 @@ abstract class CertProviderSslContextProvider extends DynamicSslContextProvider implements CertificateProvider.Watcher { - @Nullable private final CertificateProviderStore.Handle certHandle; - @Nullable private final CertificateProviderStore.Handle rootCertHandle; + @Nullable private final NoExceptionCloseable certHandle; + @Nullable private final NoExceptionCloseable rootCertHandle; @Nullable private final CertificateProviderInstance certInstance; @Nullable protected final CertificateProviderInstance rootCertInstance; @Nullable protected PrivateKey savedKey; @@ -43,6 +44,13 @@ abstract class CertProviderSslContextProvider extends DynamicSslContextProvider @Nullable protected List savedTrustedRoots; @Nullable protected Map> savedSpiffeTrustMap; private final boolean isUsingSystemRootCerts; + // Logically this field belongs in the client cert provider subclass but it had to be kept here + // because of the whole lot of things that happen in this class's constructor when the client + // cert provider is created, and + // CertProviderClientSslContextProvider::getSslContextBuilderAndExtendedX509TrustManager gets + // called before the constructor is complete, and this field needs to be set in order to be + // passed on to the TrustManager constructor. + protected final String sniForSanMatching; protected CertProviderSslContextProvider( Node node, @@ -51,28 +59,38 @@ protected CertProviderSslContextProvider( CertificateProviderInstance rootCertInstance, CertificateValidationContext staticCertValidationContext, BaseTlsContext tlsContext, - CertificateProviderStore certificateProviderStore) { + CertificateProviderStore certificateProviderStore, String sniForSanMatching) { super(tlsContext, staticCertValidationContext); this.certInstance = certInstance; this.rootCertInstance = rootCertInstance; - String certInstanceName = null; - if (certInstance != null && certInstance.isInitialized()) { - certInstanceName = certInstance.getInstanceName(); + this.sniForSanMatching = sniForSanMatching; + this.isUsingSystemRootCerts = rootCertInstance == null + && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext()); + boolean createCertInstance = certInstance != null && certInstance.isInitialized(); + boolean createRootCertInstance = rootCertInstance != null && rootCertInstance.isInitialized(); + boolean sharedCertInstance = createCertInstance && createRootCertInstance + && rootCertInstance.getInstanceName().equals(certInstance.getInstanceName()); + if (createCertInstance) { CertificateProviderInfo certProviderInstanceConfig = - getCertProviderConfig(certProviders, certInstanceName); + getCertProviderConfig(certProviders, certInstance.getInstanceName()); + CertificateProvider.Watcher watcher = this; + if (!sharedCertInstance && !isUsingSystemRootCerts) { + watcher = new IgnoreUpdatesWatcher(watcher, /* ignoreRootCertUpdates= */ true); + } + // TODO: Previously we'd hang if certProviderInstanceConfig were null or + // certInstance.isInitialized() == false. Now we'll proceed. Those should be errors, or are + // they impossible and should be assertions? certHandle = certProviderInstanceConfig == null ? null : certificateProviderStore.createOrGetProvider( certInstance.getCertificateName(), certProviderInstanceConfig.pluginName(), certProviderInstanceConfig.config(), - this, - true); + watcher, + true)::close; } else { certHandle = null; } - if (rootCertInstance != null - && rootCertInstance.isInitialized() - && !rootCertInstance.getInstanceName().equals(certInstanceName)) { + if (createRootCertInstance && !sharedCertInstance) { CertificateProviderInfo certProviderInstanceConfig = getCertProviderConfig(certProviders, rootCertInstance.getInstanceName()); rootCertHandle = certProviderInstanceConfig == null ? null @@ -80,13 +98,16 @@ protected CertProviderSslContextProvider( rootCertInstance.getCertificateName(), certProviderInstanceConfig.pluginName(), certProviderInstanceConfig.config(), - this, - true); + new IgnoreUpdatesWatcher(this, /* ignoreRootCertUpdates= */ false), + false)::close; + } else if (rootCertInstance == null + && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext())) { + SystemRootCertificateProvider systemRootProvider = new SystemRootCertificateProvider(this); + systemRootProvider.start(); + rootCertHandle = systemRootProvider::close; } else { rootCertHandle = null; } - this.isUsingSystemRootCerts = rootCertInstance == null - && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext()); } private static CertificateProviderInfo getCertProviderConfig( @@ -150,17 +171,16 @@ public final void updateSpiffeTrustMap(Map> spiffe private void updateSslContextWhenReady() { if (isMtls()) { - if (savedKey != null - && (savedTrustedRoots != null || isUsingSystemRootCerts || savedSpiffeTrustMap != null)) { + if (savedKey != null && (savedTrustedRoots != null || savedSpiffeTrustMap != null)) { updateSslContext(); clearKeysAndCerts(); } - } else if (isClientSideTls()) { + } else if (isRegularTlsAndClientSide()) { if (savedTrustedRoots != null || savedSpiffeTrustMap != null) { updateSslContext(); clearKeysAndCerts(); } - } else if (isServerSideTls()) { + } else if (isRegularTlsAndServerSide()) { if (savedKey != null) { updateSslContext(); clearKeysAndCerts(); @@ -170,7 +190,9 @@ private void updateSslContextWhenReady() { private void clearKeysAndCerts() { savedKey = null; - savedTrustedRoots = null; + if (!isUsingSystemRootCerts) { + savedTrustedRoots = null; + } savedSpiffeTrustMap = null; savedCertChain = null; } @@ -179,11 +201,11 @@ protected final boolean isMtls() { return certInstance != null && (rootCertInstance != null || isUsingSystemRootCerts); } - protected final boolean isClientSideTls() { - return rootCertInstance != null && certInstance == null; + protected final boolean isRegularTlsAndClientSide() { + return (rootCertInstance != null || isUsingSystemRootCerts) && certInstance == null; } - protected final boolean isServerSideTls() { + protected final boolean isRegularTlsAndServerSide() { return certInstance != null && rootCertInstance == null; } @@ -201,4 +223,9 @@ public final void close() { rootCertHandle.close(); } } + + interface NoExceptionCloseable extends Closeable { + @Override + void close(); + } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java new file mode 100644 index 00000000000..cd9d88be41b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java @@ -0,0 +1,68 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.security.certprovider; + +import static java.util.Objects.requireNonNull; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.Status; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.List; +import java.util.Map; + +public final class IgnoreUpdatesWatcher implements CertificateProvider.Watcher { + private final CertificateProvider.Watcher delegate; + private final boolean ignoreRootCertUpdates; + + public IgnoreUpdatesWatcher( + CertificateProvider.Watcher delegate, boolean ignoreRootCertUpdates) { + this.delegate = requireNonNull(delegate, "delegate"); + this.ignoreRootCertUpdates = ignoreRootCertUpdates; + } + + @Override + public void updateCertificate(PrivateKey key, List certChain) { + if (ignoreRootCertUpdates) { + delegate.updateCertificate(key, certChain); + } + } + + @Override + public void updateTrustedRoots(List trustedRoots) { + if (!ignoreRootCertUpdates) { + delegate.updateTrustedRoots(trustedRoots); + } + } + + @Override + public void updateSpiffeTrustMap(Map> spiffeTrustMap) { + if (!ignoreRootCertUpdates) { + delegate.updateSpiffeTrustMap(spiffeTrustMap); + } + } + + @Override + public void onError(Status errorStatus) { + delegate.onError(errorStatus); + } + + @VisibleForTesting + public CertificateProvider.Watcher getDelegate() { + return delegate; + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java new file mode 100644 index 00000000000..7c60f714e71 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java @@ -0,0 +1,71 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.security.certprovider; + +import io.grpc.Status; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; + +/** + * An non-registered provider for CertProviderSslContextProvider to use the same code path for + * system root certs as provider-obtained certs. + */ +final class SystemRootCertificateProvider extends CertificateProvider { + public SystemRootCertificateProvider(CertificateProvider.Watcher watcher) { + super(new DistributorWatcher(), false); + getWatcher().addWatcher(watcher); + } + + @Override + public void start() { + try { + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init((KeyStore) null); + + List trustManagers = Arrays.asList(trustManagerFactory.getTrustManagers()); + List rootCerts = trustManagers.stream() + .filter(X509TrustManager.class::isInstance) + .map(X509TrustManager.class::cast) + .map(trustManager -> Arrays.asList(trustManager.getAcceptedIssuers())) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + getWatcher().updateTrustedRoots(rootCerts); + } catch (KeyStoreException | NoSuchAlgorithmException ex) { + getWatcher().onError(Status.UNAVAILABLE + .withDescription("Could not load system root certs") + .withCause(ex)); + } + } + + @Override + public void close() { + // Unnecessary because there's no more callbacks, but do it for good measure + for (Watcher watcher : getWatcher().getDownstreamWatchers()) { + getWatcher().removeWatcher(watcher); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java index 86b6dd95c3e..89b4abd3029 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java @@ -16,6 +16,7 @@ package io.grpc.xds.internal.security.trust; +import io.grpc.internal.GrpcUtil; import java.io.BufferedInputStream; import java.io.File; import java.io.FileInputStream; @@ -29,6 +30,8 @@ * Contains certificate utility method(s). */ public final class CertificateUtils { + public static boolean isXdsSniEnabled = GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_SNI", false); + /** * Generates X509Certificate array from a file on disk. * diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java index 8cb44117065..eea47e1ef9f 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java @@ -58,43 +58,50 @@ public XdsTrustManagerFactory(CertificateValidationContext certificateValidation this( getTrustedCaFromCertContext(certificateValidationContext), certificateValidationContext, - false); + false, + null); } public XdsTrustManagerFactory( - X509Certificate[] certs, CertificateValidationContext staticCertificateValidationContext) - throws CertStoreException { - this(certs, staticCertificateValidationContext, true); + X509Certificate[] certs, CertificateValidationContext staticCertificateValidationContext, + String sniForSanMatching) throws CertStoreException { + this(certs, staticCertificateValidationContext, true, sniForSanMatching); } public XdsTrustManagerFactory(Map> spiffeTrustMap, - CertificateValidationContext staticCertificateValidationContext) throws CertStoreException { - this(spiffeTrustMap, staticCertificateValidationContext, true); + CertificateValidationContext staticCertificateValidationContext, String sniForSanMatching) + throws CertStoreException { + this(spiffeTrustMap, staticCertificateValidationContext, true, sniForSanMatching); } private XdsTrustManagerFactory( X509Certificate[] certs, CertificateValidationContext certificateValidationContext, - boolean validationContextIsStatic) + boolean validationContextIsStatic, + String sniForSanMatching) throws CertStoreException { if (validationContextIsStatic) { checkArgument( - certificateValidationContext == null || !certificateValidationContext.hasTrustedCa(), + certificateValidationContext == null || !certificateValidationContext.hasTrustedCa() + || certificateValidationContext.hasSystemRootCerts(), "only static certificateValidationContext expected"); } - xdsX509TrustManager = createX509TrustManager(certs, certificateValidationContext); + xdsX509TrustManager = createX509TrustManager( + certs, certificateValidationContext, sniForSanMatching); } private XdsTrustManagerFactory( Map> spiffeTrustMap, CertificateValidationContext certificateValidationContext, - boolean validationContextIsStatic) + boolean validationContextIsStatic, + String sniForSanMatching) throws CertStoreException { if (validationContextIsStatic) { checkArgument( certificateValidationContext == null || !certificateValidationContext.hasTrustedCa(), "only static certificateValidationContext expected"); - xdsX509TrustManager = createX509TrustManager(spiffeTrustMap, certificateValidationContext); + xdsX509TrustManager = createX509TrustManager( + spiffeTrustMap, certificateValidationContext, sniForSanMatching); } } @@ -121,21 +128,23 @@ private static X509Certificate[] getTrustedCaFromCertContext( @VisibleForTesting static XdsX509TrustManager createX509TrustManager( - X509Certificate[] certs, CertificateValidationContext certContext) throws CertStoreException { - return new XdsX509TrustManager(certContext, createTrustManager(certs)); + X509Certificate[] certs, CertificateValidationContext certContext, String sniForSanMatching) + throws CertStoreException { + return new XdsX509TrustManager(certContext, createTrustManager(certs), sniForSanMatching); } @VisibleForTesting static XdsX509TrustManager createX509TrustManager( Map> spiffeTrustMapFile, - CertificateValidationContext certContext) throws CertStoreException { + CertificateValidationContext certContext, String sniForSanMatching) + throws CertStoreException { checkNotNull(spiffeTrustMapFile, "spiffeTrustMapFile"); Map delegates = new HashMap<>(); for (Map.Entry> entry:spiffeTrustMapFile.entrySet()) { delegates.put(entry.getKey(), createTrustManager( entry.getValue().toArray(new X509Certificate[0]))); } - return new XdsX509TrustManager(certContext, delegates); + return new XdsX509TrustManager(certContext, delegates, sniForSanMatching); } private static X509ExtendedTrustManager createTrustManager(X509Certificate[] certs) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java index 1ecfe378d29..371aca62414 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java @@ -21,6 +21,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Optional; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.re2j.Pattern; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; @@ -60,21 +61,30 @@ final class XdsX509TrustManager extends X509ExtendedTrustManager implements X509 private final X509ExtendedTrustManager delegate; private final Map spiffeTrustMapDelegates; private final CertificateValidationContext certContext; + private final String sniForSanMatching; XdsX509TrustManager(@Nullable CertificateValidationContext certContext, X509ExtendedTrustManager delegate) { + this(certContext, delegate, null); + } + + XdsX509TrustManager(@Nullable CertificateValidationContext certContext, + X509ExtendedTrustManager delegate, @Nullable String sniForSanMatching) { checkNotNull(delegate, "delegate"); this.certContext = certContext; this.delegate = delegate; this.spiffeTrustMapDelegates = null; + this.sniForSanMatching = sniForSanMatching; } XdsX509TrustManager(@Nullable CertificateValidationContext certContext, - Map spiffeTrustMapDelegates) { + Map spiffeTrustMapDelegates, + @Nullable String sniForSanMatching) { checkNotNull(spiffeTrustMapDelegates, "spiffeTrustMapDelegates"); this.spiffeTrustMapDelegates = ImmutableMap.copyOf(spiffeTrustMapDelegates); this.certContext = certContext; this.delegate = null; + this.sniForSanMatching = sniForSanMatching; } private static boolean verifyDnsNameInPattern( @@ -208,7 +218,10 @@ void verifySubjectAltNameInChain(X509Certificate[] peerCertChain) throws Certifi return; } @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names - List verifyList = certContext.getMatchSubjectAltNamesList(); + List verifyList = + CertificateUtils.isXdsSniEnabled && !Strings.isNullOrEmpty(sniForSanMatching) + ? ImmutableList.of(StringMatcher.newBuilder().setExact(sniForSanMatching).build()) + : certContext.getMatchSubjectAltNamesList(); if (verifyList.isEmpty()) { return; } diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java index 2059022c203..0d2e198c494 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java @@ -117,7 +117,7 @@ public class CdsLoadBalancer2Test { private static final String EDS_SERVICE_NAME = "backend-service-1.googleapis.com"; private static final String NODE_ID = "node-id"; private final io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("cert-instance-name", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext("cert-instance-name", true, null, false); private static final Cluster EDS_CLUSTER = Cluster.newBuilder() .setName(CLUSTER) .setType(Cluster.DiscoveryType.EDS) @@ -158,7 +158,7 @@ public void setUp() throws Exception { lbRegistry.register(new FakeLoadBalancerProvider("least_request_experimental", new LeastRequestLoadBalancerProvider())); lbRegistry.register(new FakeLoadBalancerProvider("wrr_locality_experimental", - new WrrLocalityLoadBalancerProvider())); + new WrrLocalityLoadBalancerProvider())); CdsLoadBalancerProvider cdsLoadBalancerProvider = new CdsLoadBalancerProvider(lbRegistry); lbRegistry.register(cdsLoadBalancerProvider); loadBalancer = (CdsLoadBalancer2) cdsLoadBalancerProvider.newLoadBalancer(helper); @@ -269,24 +269,24 @@ public void discoverTopLevelLogicalDnsCluster() { .setName(CLUSTER) .setType(Cluster.DiscoveryType.LOGICAL_DNS) .setLoadAssignment(ClusterLoadAssignment.newBuilder() - .addEndpoints(LocalityLbEndpoints.newBuilder() - .addLbEndpoints(LbEndpoint.newBuilder() - .setEndpoint(Endpoint.newBuilder() - .setAddress(Address.newBuilder() - .setSocketAddress(SocketAddress.newBuilder() - .setAddress("dns.example.com") - .setPortValue(1111))))))) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .addLbEndpoints(LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("dns.example.com") + .setPortValue(1111))))))) .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() - .setServiceName(EDS_SERVICE_NAME) - .setEdsConfig(ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.newBuilder()))) + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) .setLbPolicy(Cluster.LbPolicy.LEAST_REQUEST) .setLrsServer(ConfigSource.newBuilder() - .setSelf(SelfConfigSource.getDefaultInstance())) + .setSelf(SelfConfigSource.getDefaultInstance())) .setCircuitBreakers(CircuitBreakers.newBuilder() .addThresholds(CircuitBreakers.Thresholds.newBuilder() - .setPriority(RoutingPriority.DEFAULT) - .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) .setTransportSocket(TransportSocket.newBuilder() .setName("envoy.transport_sockets.tls") .setTypedConfig(Any.pack(UpstreamTlsContext.newBuilder() @@ -303,8 +303,8 @@ public void discoverTopLevelLogicalDnsCluster() { ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanism).isEqualTo( DiscoveryMechanism.forLogicalDns( - CLUSTER, "dns.example.com:1111", lrsServerInfo, 100L, upstreamTlsContext, - Collections.emptyMap())); + CLUSTER, "dns.example.com:1111", lrsServerInfo, 100L, upstreamTlsContext, + Collections.emptyMap())); assertThat( GracefulSwitchLoadBalancerAccessor.getChildProvider(childLbConfig.lbConfig).getPolicyName()) .isEqualTo("wrr_locality_experimental"); @@ -326,8 +326,8 @@ public void nonAggregateCluster_resourceUpdate() { Cluster cluster = EDS_CLUSTER.toBuilder() .setCircuitBreakers(CircuitBreakers.newBuilder() .addThresholds(CircuitBreakers.Thresholds.newBuilder() - .setPriority(RoutingPriority.DEFAULT) - .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) .build(); controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); startXdsDepManager(); @@ -337,14 +337,14 @@ public void nonAggregateCluster_resourceUpdate() { FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanism).isEqualTo( - DiscoveryMechanism.forEds( + DiscoveryMechanism.forEds( CLUSTER, EDS_SERVICE_NAME, null, 100L, null, Collections.emptyMap(), null)); cluster = EDS_CLUSTER.toBuilder() .setCircuitBreakers(CircuitBreakers.newBuilder() .addThresholds(CircuitBreakers.Thresholds.newBuilder() - .setPriority(RoutingPriority.DEFAULT) - .setMaxRequests(UInt32Value.newBuilder().setValue(200)))) + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(200)))) .build(); controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); @@ -352,7 +352,7 @@ public void nonAggregateCluster_resourceUpdate() { childBalancer = Iterables.getOnlyElement(childBalancers); childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanism).isEqualTo( - DiscoveryMechanism.forEds( + DiscoveryMechanism.forEds( CLUSTER, EDS_SERVICE_NAME, null, 200L, null, Collections.emptyMap(), null)); } @@ -366,7 +366,7 @@ public void nonAggregateCluster_resourceRevoked() { FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanism).isEqualTo( - DiscoveryMechanism.forEds( + DiscoveryMechanism.forEds( CLUSTER, EDS_SERVICE_NAME, null, null, null, Collections.emptyMap(), null)); controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); @@ -397,7 +397,7 @@ public void dynamicCluster() { FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanism).isEqualTo( - DiscoveryMechanism.forEds( + DiscoveryMechanism.forEds( clusterName, EDS_SERVICE_NAME, null, null, null, Collections.emptyMap(), null)); assertThat(this.lastXdsConfig.getClusters()).containsKey(clusterName); @@ -419,44 +419,44 @@ public void discoverAggregateCluster_createsPriorityLbPolicy() { controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( // CLUSTER (aggr.) -> [cluster1 (aggr.), cluster2 (logical DNS), cluster3 (EDS)] CLUSTER, Cluster.newBuilder() - .setName(CLUSTER) - .setClusterType(Cluster.CustomClusterType.newBuilder() - .setName("envoy.clusters.aggregate") - .setTypedConfig(Any.pack(ClusterConfig.newBuilder() - .addClusters(cluster1) - .addClusters(cluster2) - .addClusters(cluster3) - .build()))) - .setLbPolicy(Cluster.LbPolicy.RING_HASH) - .build(), + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster1) + .addClusters(cluster2) + .addClusters(cluster3) + .build()))) + .setLbPolicy(Cluster.LbPolicy.RING_HASH) + .build(), // cluster1 (aggr.) -> [cluster3 (EDS), cluster4 (EDS)] cluster1, Cluster.newBuilder() - .setName(cluster1) - .setClusterType(Cluster.CustomClusterType.newBuilder() - .setName("envoy.clusters.aggregate") - .setTypedConfig(Any.pack(ClusterConfig.newBuilder() - .addClusters(cluster3) - .addClusters(cluster4) - .build()))) - .build(), + .setName(cluster1) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster3) + .addClusters(cluster4) + .build()))) + .build(), cluster2, Cluster.newBuilder() - .setName(cluster2) - .setType(Cluster.DiscoveryType.LOGICAL_DNS) - .setLoadAssignment(ClusterLoadAssignment.newBuilder() - .addEndpoints(LocalityLbEndpoints.newBuilder() - .addLbEndpoints(LbEndpoint.newBuilder() - .setEndpoint(Endpoint.newBuilder() - .setAddress(Address.newBuilder() - .setSocketAddress(SocketAddress.newBuilder() - .setAddress("dns.example.com") - .setPortValue(1111))))))) - .build(), + .setName(cluster2) + .setType(Cluster.DiscoveryType.LOGICAL_DNS) + .setLoadAssignment(ClusterLoadAssignment.newBuilder() + .addEndpoints(LocalityLbEndpoints.newBuilder() + .addLbEndpoints(LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("dns.example.com") + .setPortValue(1111))))))) + .build(), cluster3, EDS_CLUSTER.toBuilder() .setName(cluster3) .setCircuitBreakers(CircuitBreakers.newBuilder() .addThresholds(CircuitBreakers.Thresholds.newBuilder() - .setPriority(RoutingPriority.DEFAULT) - .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) .build(), cluster4, EDS_CLUSTER.toBuilder().setName(cluster4).build())); startXdsDepManager(); @@ -466,14 +466,14 @@ public void discoverAggregateCluster_createsPriorityLbPolicy() { FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); PriorityLoadBalancerProvider.PriorityLbConfig childLbConfig = - (PriorityLoadBalancerProvider.PriorityLbConfig) childBalancer.config; + (PriorityLoadBalancerProvider.PriorityLbConfig) childBalancer.config; assertThat(childLbConfig.priorities).hasSize(3); assertThat(childLbConfig.priorities.get(0)).isEqualTo(cluster3); assertThat(childLbConfig.priorities.get(1)).isEqualTo(cluster4); assertThat(childLbConfig.priorities.get(2)).isEqualTo(cluster2); assertThat(childLbConfig.childConfigs).hasSize(3); PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig childConfig3 = - childLbConfig.childConfigs.get(cluster3); + childLbConfig.childConfigs.get(cluster3); assertThat( GracefulSwitchLoadBalancerAccessor.getChildProvider(childConfig3.childConfig) .getPolicyName()) @@ -485,7 +485,7 @@ public void discoverAggregateCluster_createsPriorityLbPolicy() { .getPolicyName()) .isEqualTo("cds_experimental"); PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig childConfig2 = - childLbConfig.childConfigs.get(cluster2); + childLbConfig.childConfigs.get(cluster2); assertThat( GracefulSwitchLoadBalancerAccessor.getChildProvider(childConfig2.childConfig) .getPolicyName()) @@ -542,12 +542,12 @@ public void aggregateCluster_noChildren() { controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( // CLUSTER (aggr.) -> [] CLUSTER, Cluster.newBuilder() - .setName(CLUSTER) - .setClusterType(Cluster.CustomClusterType.newBuilder() - .setName("envoy.clusters.aggregate") - .setTypedConfig(Any.pack(ClusterConfig.newBuilder() - .build()))) - .build())); + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .build()))) + .build())); startXdsDepManager(); verify(helper) @@ -571,14 +571,14 @@ public void aggregateCluster_noNonAggregateClusterExits_returnErrorPicker() { controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( // CLUSTER (aggr.) -> [cluster1 (missing)] CLUSTER, Cluster.newBuilder() - .setName(CLUSTER) - .setClusterType(Cluster.CustomClusterType.newBuilder() - .setName("envoy.clusters.aggregate") - .setTypedConfig(Any.pack(ClusterConfig.newBuilder() - .addClusters(cluster1) - .build()))) - .setLbPolicy(Cluster.LbPolicy.RING_HASH) - .build())); + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster1) + .build()))) + .setLbPolicy(Cluster.LbPolicy.RING_HASH) + .build())); startXdsDepManager(); verify(helper).updateBalancingState( @@ -604,9 +604,9 @@ public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThroug .setName(CLUSTER) .setType(Cluster.DiscoveryType.EDS) .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() - .setServiceName(EDS_SERVICE_NAME) - .setEdsConfig(ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.newBuilder()))) + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) .build(); controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); startXdsDepManager(); @@ -627,16 +627,16 @@ public void unknownLbProvider() { .setName(CLUSTER) .setType(Cluster.DiscoveryType.EDS) .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() - .setServiceName(EDS_SERVICE_NAME) - .setEdsConfig(ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.newBuilder()))) + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() .addPolicies(Policy.newBuilder() - .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() - .setTypedConfig(Any.pack(TypedStruct.newBuilder() - .setTypeUrl("type.googleapis.com/unknownLb") - .setValue(Struct.getDefaultInstance()) - .build()))))) + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setTypedConfig(Any.pack(TypedStruct.newBuilder() + .setTypeUrl("type.googleapis.com/unknownLb") + .setValue(Struct.getDefaultInstance()) + .build()))))) .build(); controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); startXdsDepManager(); @@ -654,17 +654,17 @@ public void invalidLbConfig() { .setName(CLUSTER) .setType(Cluster.DiscoveryType.EDS) .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() - .setServiceName(EDS_SERVICE_NAME) - .setEdsConfig(ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.newBuilder()))) + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() .addPolicies(Policy.newBuilder() - .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() - .setTypedConfig(Any.pack(TypedStruct.newBuilder() - .setTypeUrl("type.googleapis.com/ring_hash_experimental") - .setValue(Struct.newBuilder() - .putFields("minRingSize", Value.newBuilder().setNumberValue(-1).build())) - .build()))))) + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setTypedConfig(Any.pack(TypedStruct.newBuilder() + .setTypeUrl("type.googleapis.com/ring_hash_experimental") + .setValue(Struct.newBuilder() + .putFields("minRingSize", Value.newBuilder().setNumberValue(-1).build())) + .build()))))) .build(); controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); startXdsDepManager(); @@ -693,9 +693,9 @@ private void startXdsDepManager(final CdsConfig cdsConfig) { loadBalancer.acceptResolvedAddresses(ResolvedAddresses.newBuilder() .setAddresses(Collections.emptyList()) .setAttributes(Attributes.newBuilder() - .set(XdsAttributes.XDS_CONFIG, xdsConfig.getValue()) - .set(XdsAttributes.XDS_CLUSTER_SUBSCRIPT_REGISTRY, xdsDepManager) - .build()) + .set(XdsAttributes.XDS_CONFIG, xdsConfig.getValue()) + .set(XdsAttributes.XDS_CLUSTER_SUBSCRIPT_REGISTRY, xdsDepManager) + .build()) .setLoadBalancingPolicyConfig(cdsConfig) .build()); }); @@ -782,4 +782,4 @@ public void shutdown() { childBalancers.remove(this); } } -} +} \ No newline at end of file diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index c5e3f80f170..7b3d355b79a 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -811,10 +811,10 @@ public void endpointAddressesAttachedWithClusterName() { new FixedResultPicker(PickResult.withSubchannel(subchannel))); } }); - assertThat(subchannel.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME)).isEqualTo( + assertThat(subchannel.getAttributes().get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)).isEqualTo( "authority-host-name"); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME)) + assertThat(eag.getAttributes().get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)) .isEqualTo("authority-host-name"); } @@ -863,9 +863,9 @@ public void endpointAddressesAttachedWithClusterName() { } }); // Sub Channel wrapper args won't have the address name although addresses will. - assertThat(subchannel.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME)).isNull(); + assertThat(subchannel.getAttributes().get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)).isNull(); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME)) + assertThat(eag.getAttributes().get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)) .isEqualTo("authority-host-name"); } @@ -881,7 +881,7 @@ public void endpointAddressesAttachedWithClusterName() { @Test public void endpointAddressesAttachedWithTlsConfig_securityEnabledByDefault() { UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true, null, false); LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); WeightedTargetConfig weightedTargetConfig = buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); @@ -926,7 +926,7 @@ public void endpointAddressesAttachedWithTlsConfig_securityEnabledByDefault() { // Config with a new UpstreamTlsContext. upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe1", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe1", true, null, false); config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( @@ -1019,7 +1019,7 @@ public String toString() { // Unique but arbitrary string .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, locality.toString()); if (authorityHostname != null) { - attributes.set(XdsAttributes.ATTR_ADDRESS_NAME, authorityHostname); + attributes.set(EquivalentAddressGroup.ATTR_ADDRESS_NAME, authorityHostname); } EquivalentAddressGroup eag = new EquivalentAddressGroup(new FakeSocketAddress(name), attributes.build()); @@ -1259,7 +1259,7 @@ public Map getServerLrsClientMap() { private static final class FakeTlsContextManager implements TlsContextManager { @Override public SslContextProvider findOrCreateClientSslContextProvider( - UpstreamTlsContext upstreamTlsContext) { + UpstreamTlsContext upstreamTlsContext, String sni) { SslContextProvider sslContextProvider = mock(SslContextProvider.class); when(sslContextProvider.getUpstreamTlsContext()).thenReturn(upstreamTlsContext); return sslContextProvider; @@ -1267,7 +1267,7 @@ public SslContextProvider findOrCreateClientSslContextProvider( @Override public SslContextProvider releaseClientSslContextProvider( - SslContextProvider sslContextProvider) { + SslContextProvider sslContextProvider, String sni) { // no-op return null; } diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index be68018792b..8b2b43956f8 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -408,7 +408,7 @@ public void edsClustersEndpointHostname_addedToAddressAttribute() { assertThat( childBalancer.addresses.get(0).getAttributes() - .get(XdsAttributes.ATTR_ADDRESS_NAME)).isEqualTo("hostname1"); + .get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)).isEqualTo("hostname1"); } @Test @@ -897,7 +897,7 @@ public void onlyLogicalDnsCluster_endpointsResolved() { newInetSocketAddress("127.0.2.1", 9000), newInetSocketAddress("127.0.2.2", 9000)))), childBalancer.addresses); assertThat(childBalancer.addresses.get(0).getAttributes() - .get(XdsAttributes.ATTR_ADDRESS_NAME)).isEqualTo(DNS_HOST_NAME + ":9000"); + .get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)).isEqualTo(DNS_HOST_NAME + ":9000"); } @Test @@ -995,7 +995,8 @@ public void config_equalsTester() { ServerInfo lrsServerInfo = ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); UpstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext( + "google_cloud_private_spiffe", true, null, false); DiscoveryMechanism edsDiscoveryMechanism1 = DiscoveryMechanism.forEds(CLUSTER, EDS_SERVICE_NAME, lrsServerInfo, 100L, tlsContext, Collections.emptyMap(), null); diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java index ff97afe6916..1a491ae6e14 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java @@ -301,7 +301,7 @@ public void releaseOldSupplierOnTemporaryError_noClose() throws Exception { private void callUpdateSslContext(SslContextProviderSupplier sslContextProviderSupplier) { assertThat(sslContextProviderSupplier).isNotNull(); SslContextProvider.Callback callback = mock(SslContextProvider.Callback.class); - sslContextProviderSupplier.updateSslContext(callback); + sslContextProviderSupplier.updateSslContext(callback, null); } private void sendListenerUpdate( diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 23068d665bf..25fdfb3665f 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -37,6 +37,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.Grpc; @@ -75,6 +76,7 @@ import io.grpc.xds.internal.security.SslContextProviderSupplier; import io.grpc.xds.internal.security.TlsContextManagerImpl; import io.grpc.xds.internal.security.certprovider.FileWatcherCertificateProviderProvider; +import io.grpc.xds.internal.security.trust.CertificateUtils; import io.netty.handler.ssl.NotSslRecordException; import java.io.File; import java.io.FileOutputStream; @@ -83,7 +85,6 @@ import java.net.Inet4Address; import java.net.InetSocketAddress; import java.net.URI; -import java.net.URISyntaxException; import java.nio.file.Files; import java.nio.file.Path; import java.security.KeyStore; @@ -117,6 +118,8 @@ @RunWith(Parameterized.class) public class XdsSecurityClientServerTest { + private static final String SNI_IN_UTC = "waterzooi.test.google.be"; + @Parameter public Boolean enableSpiffe; private Boolean originalEnableSpiffe; @@ -206,7 +209,8 @@ public void tlsClientServer_noClientAuthentication() throws Exception { * Uses common_tls_context.combined_validation_context in upstream_tls_context. */ @Test - public void tlsClientServer_useSystemRootCerts_useCombinedValidationContext() throws Exception { + public void tlsClientServer_useSystemRootCerts_noMtls_useCombinedValidationContext() + throws Exception { Path trustStoreFilePath = getCacertFilePathForTestCa(); try { setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); @@ -217,7 +221,7 @@ public void tlsClientServer_useSystemRootCerts_useCombinedValidationContext() th UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true); + CLIENT_PEM_FILE, true, SNI_IN_UTC, false, null, false, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -233,7 +237,7 @@ public void tlsClientServer_useSystemRootCerts_useCombinedValidationContext() th * Uses common_tls_context.validation_context in upstream_tls_context. */ @Test - public void tlsClientServer_useSystemRootCerts_validationContext() throws Exception { + public void tlsClientServer_useSystemRootCerts_noMtls_validationContext() throws Exception { Path trustStoreFilePath = getCacertFilePathForTestCa().toAbsolutePath(); try { setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); @@ -244,7 +248,7 @@ public void tlsClientServer_useSystemRootCerts_validationContext() throws Except UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, false); + CLIENT_PEM_FILE, false, SNI_IN_UTC, false, null, false, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -255,9 +259,157 @@ public void tlsClientServer_useSystemRootCerts_validationContext() throws Except } } + @Test + public void tlsClientServer_useSystemRootCerts_mtls() throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, true); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, SNI_IN_UTC, true, null, false, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + @Test + public void tlsClientServer_noAutoSniValidation_failureToMatchSubjAltNames() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, "server1.test.google.in", false, null, false, false); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + fail("Expected handshake failure exception"); + } catch (StatusRuntimeException e) { + assertThat(e.getCause()).isInstanceOf(SSLHandshakeException.class); + assertThat(e.getCause().getCause()).isInstanceOf(CertificateException.class); + assertThat(e.getCause().getCause().getMessage()).isEqualTo( + "Peer certificate SAN check failed"); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + @Test + public void tlsClientServer_autoSniValidation_sniInUtc() + throws Exception { + CertificateUtils.isXdsSniEnabled = true; + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, + // SAN matcher in CommonValidationContext. Will be overridden by autoSniSanValidation + "server1.test.google.in", + false, + SNI_IN_UTC, + false, true); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + CertificateUtils.isXdsSniEnabled = false; + } + } + + @Test + public void tlsClientServer_autoSniValidation_sniFromHostname() + throws Exception { + CertificateUtils.isXdsSniEnabled = true; + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, + // SAN matcher in CommonValidationContext. Will be overridden by autoSniSanValidation + "server1.test.google.in", + false, + "", + true, true); + + // TODO: Change this to foo.test.gooogle.fr that needs wildcard matching after + // https://github.com/grpc/grpc-java/pull/12345 is done + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY, + "waterzooi.test.google.be"); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + CertificateUtils.isXdsSniEnabled = false; + } + } + + @Test + public void tlsClientServer_autoSniValidation_noSNIApplicable_usesMatcherFromCmnVdnCtx() + throws Exception { + CertificateUtils.isXdsSniEnabled = true; + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, + // This is what will get used for the SAN validation since no SNI was used + "waterzooi.test.google.be", + false, + "", + false, true); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + CertificateUtils.isXdsSniEnabled = false; + } + } + /** * Use system root ca cert for TLS channel - mTLS. - * Uses common_tls_context.combined_validation_context in upstream_tls_context. */ @Test public void tlsClientServer_useSystemRootCerts_requireClientAuth() throws Exception { @@ -266,12 +418,12 @@ public void tlsClientServer_useSystemRootCerts_requireClientAuth() throws Except setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, - null, false, false); + null, false, true); buildServerWithTlsContext(downstreamTlsContext); UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true); + CLIENT_PEM_FILE, true, SNI_IN_UTC, false, null, false, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -546,24 +698,33 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContext(String cli .buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile, CA_PEM_FILE, null, null, null, null, spiffeFile); return CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", hasIdentityCert); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", hasIdentityCert, null, false); } + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts( String clientKeyFile, String clientPemFile, - boolean useCombinedValidationContext) { + boolean useCombinedValidationContext, + String sanToMatch, + boolean isMtls, + String sniInUpstreamTlsContext, + boolean autoHostSni, boolean autoSniSanValidation) { bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile, CA_PEM_FILE, null, null, null, null, null); if (useCombinedValidationContext) { return CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( - "google_cloud_private_spiffe-client", "ROOT", null, + isMtls ? "google_cloud_private_spiffe-client" : null, + isMtls ? "ROOT" : null, null, null, null, CertificateValidationContext.newBuilder() .setSystemRootCerts( CertificateValidationContext.SystemRootCerts.newBuilder().build()) - .build()); + .addMatchSubjectAltNames( + StringMatcher.newBuilder() + .setExact(sanToMatch)) + .build(), sniInUpstreamTlsContext, autoHostSni, autoSniSanValidation); } return CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( "google_cloud_private_spiffe-client", "ROOT", null, @@ -648,8 +809,18 @@ static EnvoyServerProtoData.Listener buildListener( } private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( - final UpstreamTlsContext upstreamTlsContext, String overrideAuthority) - throws URISyntaxException { + final UpstreamTlsContext upstreamTlsContext, String overrideAuthority) { + return getBlockingStub(upstreamTlsContext, overrideAuthority, overrideAuthority); + } + + // Two separate parameters for overrideAuthority and addrAttribute is for the SAN SNI validation test + // tlsClientServer_useSystemRootCerts_sni_san_validation_from_hostname that uses hostname passed for SNI. + // foo.test.google.fr is used for virtual host matching via authority but it can't be used + // for SNI in this testcase because foo.test.google.fr needs wildcard matching to match against *.test.google.fr + // in the certificate SNI, which isn't implemented yet (https://github.com/grpc/grpc-java/pull/12345 implements it) + // so use an exact match SAN such as waterzooi.test.google.be for SNI for this testcase. + private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( + final UpstreamTlsContext upstreamTlsContext, String overrideAuthority, String addrNameAttribute) { ManagedChannelBuilder channelBuilder = Grpc.newChannelBuilder( "sectest://localhost:" + port, @@ -661,14 +832,16 @@ private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( InetSocketAddress socketAddress = new InetSocketAddress(Inet4Address.getLoopbackAddress(), port); tlsContextManagerForClient = new TlsContextManagerImpl(bootstrapInfoForClient); - sslContextAttributes = - (upstreamTlsContext != null) - ? Attributes.newBuilder() - .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, - new SslContextProviderSupplier( - upstreamTlsContext, tlsContextManagerForClient)) - .build() - : Attributes.EMPTY; + Attributes.Builder sslContextAttributesBuilder = (upstreamTlsContext != null) + ? Attributes.newBuilder() + .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + new SslContextProviderSupplier( + upstreamTlsContext, tlsContextManagerForClient)) + : Attributes.newBuilder(); + if (addrNameAttribute != null) { + sslContextAttributesBuilder.set(EquivalentAddressGroup.ATTR_ADDRESS_NAME, addrNameAttribute); + } + sslContextAttributes = sslContextAttributesBuilder.build(); fakeNameResolverFactory.setServers( ImmutableList.of(new EquivalentAddressGroup(socketAddress, sslContextAttributes))); return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(channelBuilder.build())); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java index 397fe01e0f5..53b7d1df83a 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java @@ -17,6 +17,7 @@ package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.addCertificateValidationContext; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; @@ -24,6 +25,7 @@ import com.google.common.collect.ImmutableSet; import io.envoyproxy.envoy.config.core.v3.DataSource; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate; @@ -37,7 +39,9 @@ import io.grpc.xds.internal.security.certprovider.CertificateProviderProvider; import io.grpc.xds.internal.security.certprovider.CertificateProviderRegistry; import io.grpc.xds.internal.security.certprovider.CertificateProviderStore; +import io.grpc.xds.internal.security.certprovider.IgnoreUpdatesWatcher; import io.grpc.xds.internal.security.certprovider.TestCertificateProvider; +import java.util.AbstractMap; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -49,6 +53,8 @@ @RunWith(JUnit4.class) public class ClientSslContextProviderFactoryTest { + static final String SNI = "sni"; + CertificateProviderRegistry certificateProviderRegistry; CertificateProviderStore certificateProviderStore; CertProviderClientSslContextProviderFactory certProviderClientSslContextProviderFactory; @@ -74,20 +80,20 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio "gcp_id", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, null, false, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); // verify that bootstrapInfo is cached... sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); } @@ -98,28 +104,26 @@ public void bothPresent_expectCertProviderClientSslContextProvider() final CertificateProvider.DistributorWatcher[] watcherCaptor = new CertificateProvider.DistributorWatcher[1]; createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); - UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( - "gcp_id", - "cert-default", - "gcp_id", - "root-default", - /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); - - CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder(); + CommonTlsContext.Builder builder = CommonTlsContext.newBuilder() + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder() + .setInstanceName("gcp_id") + .setCertificateName("cert-default")); + builder = + addCertificateValidationContext( + builder, "gcp_id", "root-default", null); builder = addFilenames(builder, "foo.pem", "foo.key", "root.pem"); - upstreamTlsContext = new UpstreamTlsContext(builder.build()); - + UpstreamTlsContext upstreamTlsContext = new UpstreamTlsContext(builder.build()); + Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -135,17 +139,17 @@ public void createCertProviderClientSslContextProvider_onlyRootCert() "gcp_id", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, null, false, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -169,17 +173,17 @@ public void createCertProviderClientSslContextProvider_withStaticContext() "gcp_id", "root-default", /* alpnProtocols= */ null, - staticCertValidationContext); + staticCertValidationContext, null, false, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = new ClientSslContextProviderFactory(bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -199,18 +203,18 @@ public void createCertProviderClientSslContextProvider_2providers() "file_provider", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, null, false, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -243,11 +247,11 @@ public void createNewCertProviderClientSslContextProvider_withSans() { new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -277,10 +281,10 @@ public void createNewCertProviderClientSslContextProvider_onlyRootCert() { new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } static void createAndRegisterProviderProvider( @@ -310,11 +314,18 @@ public CertificateProvider answer(InvocationOnMock invocation) throws Throwable } static void verifyWatcher( - SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor) { + SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor, + boolean usesDelegateWatcher) { assertThat(watcherCaptor).isNotNull(); assertThat(watcherCaptor.getDownstreamWatchers()).hasSize(1); - assertThat(watcherCaptor.getDownstreamWatchers().iterator().next()) - .isSameInstanceAs(sslContextProvider); + if (usesDelegateWatcher) { + assertThat(((IgnoreUpdatesWatcher) watcherCaptor.getDownstreamWatchers().iterator().next()) + .getDelegate()) + .isSameInstanceAs(sslContextProvider); + } else { + assertThat(watcherCaptor.getDownstreamWatchers().iterator().next()) + .isSameInstanceAs(sslContextProvider); + } } static CommonTlsContext.Builder addFilenames( diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index 48814dece1d..567ce01ecf6 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -36,10 +36,12 @@ import java.io.InputStream; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.AbstractMap; import java.util.Arrays; import java.util.List; import java.util.concurrent.Executor; import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; /** Utility class for client and server ssl provider tests. */ public class CommonTlsContextTestsUtil { @@ -149,23 +151,31 @@ public static String getTempFileNameForResourcesFile(String resFile) throws IOEx * Helper method to build UpstreamTlsContext for above tests. Called from other classes as well. */ static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( - CommonTlsContext commonTlsContext) { - UpstreamTlsContext upstreamTlsContext = - UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build(); + CommonTlsContext commonTlsContext, String sni, boolean autoHostSni, boolean autoSniSanValidation) { + UpstreamTlsContext.Builder upstreamTlsContext = + UpstreamTlsContext.newBuilder() + .setCommonTlsContext(commonTlsContext) + .setAutoHostSni(autoHostSni) + .setAutoSniSanValidation(autoSniSanValidation); + if (sni != null) { + upstreamTlsContext.setSni(sni); + } return EnvoyServerProtoData.UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext( - upstreamTlsContext); + upstreamTlsContext.build()); } /** Helper method to build UpstreamTlsContext for multiple test classes. */ public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( - String commonInstanceName, boolean hasIdentityCert) { + String commonInstanceName, boolean hasIdentityCert, String sni, boolean autoHostSni) { return buildUpstreamTlsContextForCertProviderInstance( hasIdentityCert ? commonInstanceName : null, hasIdentityCert ? "default" : null, commonInstanceName, "ROOT", null, - null); + null, + sni, + autoHostSni, false); } /** Gets a cert from contents of a resource. */ @@ -224,7 +234,7 @@ private static CommonTlsContext buildNewCommonTlsContextForCertProviderInstance( return builder.build(); } - private static CommonTlsContext.Builder addCertificateValidationContext( + public static CommonTlsContext.Builder addCertificateValidationContext( CommonTlsContext.Builder builder, String rootInstanceName, String rootCertName, @@ -271,12 +281,15 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( /** Helper method to build UpstreamTlsContext for CertProvider tests. */ public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContextForCertProviderInstance( - @Nullable String certInstanceName, - @Nullable String certName, - @Nullable String rootInstanceName, - @Nullable String rootCertName, - Iterable alpnProtocols, - CertificateValidationContext staticCertValidationContext) { + @Nullable String certInstanceName, + @Nullable String certName, + @Nullable String rootInstanceName, + @Nullable String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext, + String sni, + boolean autoHostSni, + boolean autoSniSanValidation) { return buildUpstreamTlsContext( buildCommonTlsContextForCertProviderInstance( certInstanceName, @@ -284,7 +297,8 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext)); + staticCertValidationContext), + sni, autoHostSni, autoSniSanValidation); } /** Helper method to build UpstreamTlsContext for CertProvider tests. */ @@ -303,7 +317,7 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext)); + staticCertValidationContext), null, false, false); } /** Helper method to build DownstreamTlsContext for CertProvider tests. */ @@ -347,14 +361,15 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( } /** Perform some simple checks on sslContext. */ - public static void doChecksOnSslContext(boolean server, SslContext sslContext, + public static void doChecksOnSslContext(boolean server, + AbstractMap.SimpleImmutableEntry sslContextAndTm, List expectedApnProtos) { if (server) { - assertThat(sslContext.isServer()).isTrue(); + assertThat(sslContextAndTm.getKey().isServer()).isTrue(); } else { - assertThat(sslContext.isClient()).isTrue(); + assertThat(sslContextAndTm.getKey().isClient()).isTrue(); } - List apnProtos = sslContext.applicationProtocolNegotiator().protocols(); + List apnProtos = sslContextAndTm.getKey().applicationProtocolNegotiator().protocols(); assertThat(apnProtos).isNotNull(); if (expectedApnProtos != null) { assertThat(apnProtos).isEqualTo(expectedApnProtos); @@ -380,7 +395,7 @@ public static TestCallback getValueThruCallback(SslContextProvider provider, Exe public static class TestCallback extends SslContextProvider.Callback { - public SslContext updatedSslContext; + public AbstractMap.SimpleImmutableEntry updatedSslContext; public Throwable updatedThrowable; public TestCallback(Executor executor) { @@ -388,7 +403,7 @@ public TestCallback(Executor executor) { } @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContextAndExtendedX509TrustManager(AbstractMap.SimpleImmutableEntry sslContext) { updatedSslContext = sslContext; } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index a0139618f9f..a8bf66f0fc4 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -28,9 +28,7 @@ import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; @@ -38,6 +36,7 @@ import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; +import io.grpc.EquivalentAddressGroup; import io.grpc.internal.FakeClock; import io.grpc.internal.TestUtils.NoopChannelLogger; import io.grpc.netty.GrpcHttp2ConnectionHandler; @@ -53,6 +52,7 @@ import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSecurityHandler; import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSecurityProtocolNegotiator; import io.grpc.xds.internal.security.certprovider.CommonCertProviderTestUtils; +import io.grpc.xds.internal.security.trust.CertificateUtils; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; @@ -73,6 +73,7 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.cert.CertStoreException; +import java.util.AbstractMap; import java.util.Iterator; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -82,10 +83,16 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import javax.net.ssl.TrustManager; + /** Unit tests for {@link SecurityProtocolNegotiators}. */ @RunWith(JUnit4.class) public class SecurityProtocolNegotiatorsTest { + private static final String HOSTNAME = "hostname"; + private static final String SNI_IN_UTC = "sni-in-upstream-tls-context"; + private static final String FAKE_AUTHORITY = "authority"; + private final GrpcHttp2ConnectionHandler grpcHandler = FakeGrpcHttp2ConnectionHandler.newHandler(); @@ -122,7 +129,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_noFallback_expectExceptio @Test public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() { UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build()); + CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build(), null, false, false); ClientSecurityProtocolNegotiator pn = new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext()); GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); @@ -141,6 +148,35 @@ public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class); } + @Test + public void clientSecurityProtocolNegotiatorNewHandler_autoHostSni_hostnameIsPassedToClientSecurityHandler() { + CertificateUtils.isXdsSniEnabled = true; + try { + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build(), null, true, false); + ClientSecurityProtocolNegotiator pn = + new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext()); + GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); + ChannelLogger logger = mock(ChannelLogger.class); + doNothing().when(logger).log(any(ChannelLogLevel.class), anyString()); + when(mockHandler.getNegotiationLogger()).thenReturn(logger); + TlsContextManager mockTlsContextManager = mock(TlsContextManager.class); + when(mockHandler.getEagAttributes()) + .thenReturn( + Attributes.newBuilder() + .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager)) + .set(EquivalentAddressGroup.ATTR_ADDRESS_NAME, FAKE_AUTHORITY) + .build()); + ChannelHandler newHandler = pn.newHandler(mockHandler); + assertThat(newHandler).isNotNull(); + assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class); + assertThat(((ClientSecurityHandler) newHandler).getSni()).isEqualTo(FAKE_AUTHORITY); + } finally { + CertificateUtils.isXdsSniEnabled = false; + } + } + @Test public void clientSecurityHandler_addLast() throws InterruptedException, TimeoutException, ExecutionException { @@ -151,13 +187,13 @@ public void clientSecurityHandler_addLast() CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, false); SslContextProviderSupplier sslContextProviderSupplier = new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier); + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); pipeline.addLast(clientSecurityHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler); assertNotNull(channelHandlerCtx); @@ -168,19 +204,20 @@ public void clientSecurityHandler_addLast() sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { - future.set(sslContext); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + future.set(sslContextAndTm); } @Override protected void onException(Throwable throwable) { future.set(throwable); } - }); + }, FAKE_AUTHORITY); assertThat(executor.runDueTasks()).isEqualTo(1); channel.runPendingTasks(); Object fromFuture = future.get(2, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); + assertThat(fromFuture).isInstanceOf(AbstractMap.SimpleImmutableEntry.class); channel.runPendingTasks(); channelHandlerCtx = pipeline.context(clientSecurityHandler); assertThat(channelHandlerCtx).isNull(); @@ -194,6 +231,117 @@ protected void onException(Throwable throwable) { CommonCertProviderTestUtils.register0(); } + @Test + public void sniInClientSecurityHandler_autoHostSniIsTrue_usesEndpointHostname() { + CertificateUtils.isXdsSniEnabled = true; + try { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(HOSTNAME); + } finally { + CertificateUtils.isXdsSniEnabled = false; + } + } + + @Test + public void sniInClientSecurityHandler_autoHostSniIsTrue_endpointHostnameIsEmpty_usesSniFromUpstreamTlsContext() { + CertificateUtils.isXdsSniEnabled = true; + try { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, SNI_IN_UTC, true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, ""); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } finally { + CertificateUtils.isXdsSniEnabled = false; + } + } + + @Test + public void sniInClientSecurityHandler_autoHostSniIsTrue_endpointHostnameIsNull_usesSniFromUpstreamTlsContext() { + CertificateUtils.isXdsSniEnabled = true; + try { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, SNI_IN_UTC, true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, null); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } finally { + CertificateUtils.isXdsSniEnabled = false; + } + } + + @Test + public void sniInClientSecurityHandler_autoHostSniIsFalse_usesSniFromUpstreamTlsContext() { + CertificateUtils.isXdsSniEnabled = true; + try { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, SNI_IN_UTC, false); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } finally { + CertificateUtils.isXdsSniEnabled = false; + } + } + + @Test + public void sniFeatureNotEnabled_usesChannelAuthorityForSni() { + CertificateUtils.isXdsSniEnabled = false; + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, "", false); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(FAKE_AUTHORITY); + } + @Test public void serverSecurityHandler_addLast() throws InterruptedException, TimeoutException, ExecutionException { @@ -245,19 +393,20 @@ public SocketAddress remoteAddress() { sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { - future.set(sslContext); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + future.set(sslContextAndTm); } @Override protected void onException(Throwable throwable) { future.set(throwable); } - }); + }, null); channel.runPendingTasks(); // need this for tasks to execute on eventLoop assertThat(executor.runDueTasks()).isEqualTo(1); Object fromFuture = future.get(2, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); + assertThat(fromFuture).isInstanceOf(AbstractMap.SimpleImmutableEntry.class); channel.runPendingTasks(); channelHandlerCtx = pipeline.context(SecurityProtocolNegotiators.ServerSecurityHandler.class); assertThat(channelHandlerCtx).isNull(); @@ -356,53 +505,59 @@ public void nullTlsContext_nullFallbackProtocolNegotiator_expectException() { @Test public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() throws InterruptedException, TimeoutException, ExecutionException { - FakeClock executor = new FakeClock(); - CommonCertProviderTestUtils.register(executor); - Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils - .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null, null); - UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); - - SslContextProviderSupplier sslContextProviderSupplier = - new SslContextProviderSupplier(upstreamTlsContext, - new TlsContextManagerImpl(bootstrapInfoForClient)); - ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier); - - pipeline.addLast(clientSecurityHandler); - channelHandlerCtx = pipeline.context(clientSecurityHandler); - assertNotNull(channelHandlerCtx); // non-null since we just added it - - // kick off protocol negotiation. - pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); - final SettableFuture future = SettableFuture.create(); - sslContextProviderSupplier - .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { - @Override - public void updateSslContext(SslContext sslContext) { - future.set(sslContext); - } - - @Override - protected void onException(Throwable throwable) { - future.set(throwable); - } - }); - executor.runDueTasks(); - channel.runPendingTasks(); // need this for tasks to execute on eventLoop - Object fromFuture = future.get(5, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); - channel.runPendingTasks(); - channelHandlerCtx = pipeline.context(clientSecurityHandler); - assertThat(channelHandlerCtx).isNull(); - Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; - - pipeline.fireUserEventTriggered(sslEvent); - channel.runPendingTasks(); // need this for tasks to execute on eventLoop - assertTrue(channel.isOpen()); - CommonCertProviderTestUtils.register0(); + CertificateUtils.isXdsSniEnabled = true; + try { + FakeClock executor = new FakeClock(); + CommonCertProviderTestUtils.register(executor); + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, false); + + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + + pipeline.addLast(clientSecurityHandler); + channelHandlerCtx = pipeline.context(clientSecurityHandler); + assertNotNull(channelHandlerCtx); // non-null since we just added it + + // kick off protocol negotiation. + pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + final SettableFuture future = SettableFuture.create(); + sslContextProviderSupplier + .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { + @Override + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + future.set(sslContextAndTm); + } + + @Override + protected void onException(Throwable throwable) { + future.set(throwable); + } + }, ""); + executor.runDueTasks(); + channel.runPendingTasks(); // need this for tasks to execute on eventLoop + Object fromFuture = future.get(5, TimeUnit.SECONDS); + assertThat(fromFuture).isInstanceOf(AbstractMap.SimpleImmutableEntry.class); + channel.runPendingTasks(); + channelHandlerCtx = pipeline.context(clientSecurityHandler); + assertThat(channelHandlerCtx).isNull(); + Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; + + pipeline.fireUserEventTriggered(sslEvent); + channel.runPendingTasks(); // need this for tasks to execute on eventLoop + assertTrue(channel.isOpen()); + CommonCertProviderTestUtils.register0(); + } finally { + CertificateUtils.isXdsSniEnabled = false; + } } @Test @@ -414,13 +569,13 @@ public void clientSecurityProtocolNegotiatorNewHandler_handleHandlerRemoved() { CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, true); SslContextProviderSupplier sslContextProviderSupplier = new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier); + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); pipeline.addLast(clientSecurityHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler); @@ -458,7 +613,7 @@ static FakeGrpcHttp2ConnectionHandler newHandler() { @Override public String getAuthority() { - return "authority"; + return FAKE_AUTHORITY; } } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java index cf86b511f1f..7a5a6c00639 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java @@ -78,7 +78,7 @@ public void createCertProviderServerSslContextProvider() throws XdsInitializatio serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); // verify that bootstrapInfo is cached... sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); @@ -117,7 +117,7 @@ public void bothPresent_expectCertProviderServerSslContextProvider() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -144,7 +144,7 @@ public void createCertProviderServerSslContextProvider_onlyCertInstance() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -179,7 +179,7 @@ public void createCertProviderServerSslContextProvider_withStaticContext() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); } @Test @@ -210,8 +210,8 @@ public void createCertProviderServerSslContextProvider_2providers() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -249,7 +249,7 @@ public void createNewCertProviderServerSslContextProvider_withSans() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index f476818297d..c8d5c2ef519 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -17,8 +17,9 @@ package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.buildUpstreamTlsContext; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.any; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; @@ -26,10 +27,13 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; +import java.util.AbstractMap; import java.util.concurrent.Executor; +import javax.net.ssl.TrustManager; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -46,93 +50,167 @@ public class SslContextProviderSupplierTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + private static final String SNI = "sni"; + @Mock private TlsContextManager mockTlsContextManager; + @Mock private Executor mockExecutor; private SslContextProviderSupplier supplier; private SslContextProvider mockSslContextProvider; private EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; private SslContextProvider.Callback mockCallback; - private void prepareSupplier() { - upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + private void prepareSupplier(boolean createUpstreamTlsContext) { + if (createUpstreamTlsContext) { + upstreamTlsContext = + buildUpstreamTlsContext("google_cloud_private_spiffe", true, null, false); + } mockSslContextProvider = mock(SslContextProvider.class); doReturn(mockSslContextProvider) .when(mockTlsContextManager) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); } private void callUpdateSslContext() { mockCallback = mock(SslContextProvider.Callback.class); - Executor mockExecutor = mock(Executor.class); doReturn(mockExecutor).when(mockCallback).getExecutor(); - supplier.updateSslContext(mockCallback); + supplier.updateSslContext(mockCallback, SNI); } @Test public void get_updateSecret() { - prepareSupplier(); + prepareSupplier(true); callUpdateSslContext(); verify(mockTlsContextManager, times(2)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); verify(mockTlsContextManager, times(0)) - .releaseClientSslContextProvider(any(SslContextProvider.class)); + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI)); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); - SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSslContext(mockSslContext); - verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); + AbstractMap.SimpleImmutableEntry mockSslContextAndTm = + (AbstractMap.SimpleImmutableEntry) + mock(AbstractMap.SimpleImmutableEntry.class); + capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); + verify(mockCallback, times(1)).updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); - supplier.updateSslContext(mockCallback); + supplier.updateSslContext(mockCallback, SNI); verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); + } + + @Test + public void autoHostSniFalse_usesSniFromUpstreamTlsContext() { + prepareSupplier(true); + callUpdateSslContext(); + verify(mockTlsContextManager, times(2)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); + verify(mockTlsContextManager, times(0)) + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI)); + ArgumentCaptor callbackCaptor = + ArgumentCaptor.forClass(SslContextProvider.Callback.class); + verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); + SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); + assertThat(capturedCallback).isNotNull(); + AbstractMap.SimpleImmutableEntry mockSslContextAndTm = + (AbstractMap.SimpleImmutableEntry) + mock(AbstractMap.SimpleImmutableEntry.class); + capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); + verify(mockCallback, times(1)) + .updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); + SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); + supplier.updateSslContext(mockCallback, SNI); + verify(mockTlsContextManager, times(3)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); } @Test public void get_onException() { - prepareSupplier(); + prepareSupplier(true); callUpdateSslContext(); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); - verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); + verify(mockSslContextProvider, times(1)) + .addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); Exception exception = new Exception("test"); capturedCallback.onException(exception); verify(mockCallback, times(1)).onException(eq(exception)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); + } + + @Test + public void systemRootCertsWithMtls_callbackExecutedFromProvider() { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + null, + "root-default", + null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build()); + prepareSupplier(false); + + callUpdateSslContext(); + + verify(mockTlsContextManager, times(2)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); + verify(mockTlsContextManager, times(0)) + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI)); + ArgumentCaptor callbackCaptor = + ArgumentCaptor.forClass(SslContextProvider.Callback.class); + verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); + SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); + assertThat(capturedCallback).isNotNull(); + AbstractMap.SimpleImmutableEntry mockSslContextAndTm = + (AbstractMap.SimpleImmutableEntry) + mock(AbstractMap.SimpleImmutableEntry.class); + capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); + verify(mockCallback, times(1)) + .updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); + SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); + supplier.updateSslContext(mockCallback, SNI); + verify(mockTlsContextManager, times(3)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); } @Test public void testClose() { - prepareSupplier(); + prepareSupplier(true); callUpdateSslContext(); supplier.close(); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); - supplier.updateSslContext(mockCallback); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); + supplier.updateSslContext(mockCallback, SNI); verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(any(SslContextProvider.class)); + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI)); } @Test public void testClose_nullSslContextProvider() { - prepareSupplier(); + prepareSupplier(true); doThrow(new NullPointerException()).when(mockTlsContextManager) - .releaseClientSslContextProvider(null); + .releaseClientSslContextProvider(null, SNI); supplier.close(); verify(mockTlsContextManager, never()) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); callUpdateSslContext(); verify(mockTlsContextManager, times(1)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java index 035096a3528..3db482690d2 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java @@ -35,6 +35,7 @@ import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; +import java.util.AbstractMap; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -49,7 +50,9 @@ public class TlsContextManagerTest { @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); - @Mock ValueFactory mockClientFactory; + private static final String SNI = "sni"; + + @Mock ValueFactory, SslContextProvider> mockClientFactory; @Mock ValueFactory mockServerFactory; @@ -79,15 +82,15 @@ public void createClientSslContextProvider() { CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false, null, false); TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(bootstrapInfoForClient); SslContextProvider clientSecretProvider = - tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext, SNI); assertThat(clientSecretProvider).isNotNull(); SslContextProvider clientSecretProvider1 = - tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext, SNI); assertThat(clientSecretProvider1).isSameInstanceAs(clientSecretProvider); } @@ -123,18 +126,18 @@ public void createClientSslContextProvider_differentInstance() { CA_PEM_FILE, "cert-instance-2", CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false, null, false); TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(bootstrapInfoForClient); SslContextProvider clientSecretProvider = - tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext, SNI); assertThat(clientSecretProvider).isNotNull(); UpstreamTlsContext upstreamTlsContext1 = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("cert-instance-2", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext("cert-instance-2", true, null, false); SslContextProvider clientSecretProvider1 = - tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext1); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext1, SNI); assertThat(clientSecretProvider1).isNotSameInstanceAs(clientSecretProvider); } @@ -161,18 +164,19 @@ public void createServerSslContextProvider_releaseInstance() { public void createClientSslContextProvider_releaseInstance() { UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, false); TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(mockClientFactory, mockServerFactory); SslContextProvider mockProvider = mock(SslContextProvider.class); - when(mockClientFactory.create(upstreamTlsContext)).thenReturn(mockProvider); + when(mockClientFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI))) + .thenReturn(mockProvider); SslContextProvider clientSecretProvider = - tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext, SNI); assertThat(clientSecretProvider).isSameInstanceAs(mockProvider); verify(mockProvider, never()).close(); when(mockProvider.getUpstreamTlsContext()).thenReturn(upstreamTlsContext); - tlsContextManagerImpl.releaseClientSslContextProvider(mockProvider); + tlsContextManagerImpl.releaseClientSslContextProvider(mockProvider, SNI); verify(mockProvider, times(1)).close(); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index b0800458d66..78afee8ae86 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -51,6 +51,7 @@ /** Unit tests for {@link CertProviderClientSslContextProvider}. */ @RunWith(JUnit4.class) public class CertProviderClientSslContextProviderTest { + private static final String SNI = "sni"; private static final Logger logger = Logger.getLogger(CertProviderClientSslContextProviderTest.class.getName()); @@ -72,19 +73,33 @@ private CertProviderClientSslContextProvider getSslContextProvider( String rootInstanceName, Bootstrapper.BootstrapInfo bootstrapInfo, Iterable alpnProtocols, - CertificateValidationContext staticCertValidationContext) { - EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( - certInstanceName, - "cert-default", - rootInstanceName, - "root-default", - alpnProtocols, - staticCertValidationContext); + CertificateValidationContext staticCertValidationContext, + boolean useSystemRootCerts) { + EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; + if (useSystemRootCerts) { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext); + } else { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext, + null, false, false); + } return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( upstreamTlsContext, - bootstrapInfo.node().toEnvoyProtoNode(), + SNI, bootstrapInfo.node().toEnvoyProtoNode(), bootstrapInfo.certProviders()); } @@ -106,7 +121,7 @@ private CertProviderClientSslContextProvider getNewSslContextProvider( return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( upstreamTlsContext, - bootstrapInfo.node().toEnvoyProtoNode(), + SNI, bootstrapInfo.node().toEnvoyProtoNode(), bootstrapInfo.certProviders()); } @@ -122,12 +137,12 @@ public void testProviderForClient_mtls() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -135,11 +150,11 @@ public void testProviderForClient_mtls() throws Exception { ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -168,7 +183,88 @@ public void testProviderForClient_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); + } + + @Test + public void testProviderForClient_systemRootCerts_regularTls() { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( + null, + null, + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(), + true); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback.updatedSslContext).isEqualTo(provider.getSslContextAndExtendedX509TrustManager()); + + assertThat(watcherCaptor[0]).isNull(); + } + + @Test + public void testProviderForClient_systemRootCerts_mtls() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( + "gcp_id", + null, + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(), + true); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); + + // now generate cert update + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + TestCallback testCallback1 = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // now update id cert: sslContext should be updated i.e. different from the previous one + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -190,7 +286,7 @@ public void testProviderForClient_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -198,11 +294,11 @@ public void testProviderForClient_mtls_newXds() throws Exception { ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -231,7 +327,7 @@ public void testProviderForClient_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -248,7 +344,7 @@ public void testProviderForClient_queueExecutor() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); QueuedExecutor queuedExecutor = new QueuedExecutor(); TestCallback testCallback = @@ -281,16 +377,16 @@ public void testProviderForClient_tls() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -318,7 +414,7 @@ public void testProviderForClient_sslContextException_onError() throws Exception "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */null, - staticCertValidationContext); + staticCertValidationContext, false); TestCallback testCallback = new TestCallback(MoreExecutors.directExecutor()); provider.addCallback(testCallback); @@ -350,7 +446,7 @@ public void testProviderForClient_rootInstanceNull_and_notUsingSystemRootCerts_e /* rootInstanceName= */ null, CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); fail("exception expected"); } catch (UnsupportedOperationException expected) { assertThat(expected).hasMessageThat().contains("Unsupported configurations in " @@ -373,7 +469,7 @@ public void testProviderForClient_rootInstanceNull_but_isUsingSystemRootCerts_va CertificateValidationContext.newBuilder() .setSystemRootCerts( CertificateValidationContext.SystemRootCerts.newBuilder().build()) - .build()); + .build(), false); } static class QueuedExecutor implements Executor { diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java index 423829ff5af..6515a97f526 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java @@ -127,7 +127,7 @@ public void testProviderForServer_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -135,11 +135,11 @@ public void testProviderForServer_mtls() throws Exception { ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -168,7 +168,7 @@ public void testProviderForServer_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -196,7 +196,7 @@ public void testProviderForServer_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -204,11 +204,11 @@ public void testProviderForServer_mtls_newXds() throws Exception { ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -237,7 +237,7 @@ public void testProviderForServer_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -294,14 +294,14 @@ public void testProviderForServer_tls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( CommonCertProviderTestUtils.getPrivateKey(SERVER_0_KEY_FILE), ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java index db83961cfc3..5d156e29c08 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java @@ -91,7 +91,7 @@ public void constructor_fromRootCert() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, null); assertThat(factory).isNotNull(); TrustManager[] tms = factory.getTrustManagers(); assertThat(tms).isNotNull(); @@ -115,7 +115,7 @@ public void constructor_fromSpiffeTrustMap() "san2"); // Single domain and single cert XdsTrustManagerFactory factory = new XdsTrustManagerFactory(ImmutableMap - .of("example.com", ImmutableList.of(x509Cert)), staticValidationContext); + .of("example.com", ImmutableList.of(x509Cert)), staticValidationContext, null); assertThat(factory).isNotNull(); TrustManager[] tms = factory.getTrustManagers(); assertThat(tms).isNotNull(); @@ -131,7 +131,7 @@ public void constructor_fromSpiffeTrustMap() X509Certificate anotherCert = TestUtils.loadX509Cert(CLIENT_PEM_FILE); factory = new XdsTrustManagerFactory(ImmutableMap .of("example.com", ImmutableList.of(x509Cert), - "google.com", ImmutableList.of(x509Cert, anotherCert)), staticValidationContext); + "google.com", ImmutableList.of(x509Cert, anotherCert)), staticValidationContext, null); assertThat(factory).isNotNull(); tms = factory.getTrustManagers(); assertThat(tms).isNotNull(); @@ -154,7 +154,7 @@ public void constructorRootCert_checkServerTrusted() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "waterzooi.test.google.be"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, null); XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] serverChain = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); @@ -167,7 +167,7 @@ public void constructorRootCert_nonStaticContext_throwsException() X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); try { new XdsTrustManagerFactory( - new X509Certificate[] {x509Cert}, getCertContextFromPath(CA_PEM_FILE)); + new X509Certificate[] {x509Cert}, getCertContextFromPath(CA_PEM_FILE), null); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected) @@ -176,6 +176,19 @@ public void constructorRootCert_nonStaticContext_throwsException() } } + @Test + public void constructorRootCert_nonStaticContext_systemRootCerts_valid() + throws CertificateException, IOException, CertStoreException { + X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); + CertificateValidationContext certValidationContext = CertificateValidationContext.newBuilder() + .setTrustedCa( + DataSource.newBuilder().setFilename(TestUtils.loadCert(CA_PEM_FILE).getAbsolutePath())) + .setSystemRootCerts(CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(); + new XdsTrustManagerFactory( + new X509Certificate[] {x509Cert}, certValidationContext, null); + } + @Test public void constructorRootCert_checkServerTrusted_throwsException() throws CertificateException, IOException, CertStoreException { @@ -183,7 +196,7 @@ public void constructorRootCert_checkServerTrusted_throwsException() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, null); XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] serverChain = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); @@ -204,7 +217,7 @@ public void constructorRootCert_checkClientTrusted_throwsException() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, null); XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] clientChain = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java index 6fa3d2e7d24..09f5b1028cc 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java @@ -42,6 +42,7 @@ import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.List; import javax.net.ssl.SSLEngine; @@ -53,6 +54,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -60,7 +62,7 @@ /** * Unit tests for {@link XdsX509TrustManager}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class XdsX509TrustManagerTest { @Rule @@ -73,6 +75,18 @@ public class XdsX509TrustManagerTest { private SSLSession mockSession; private XdsX509TrustManager trustManager; + private boolean useSniForSanMatching; + + @Parameterized.Parameters + public static Collection data() { + return Arrays.asList(new Object[][] { + { true }, { false } + }); + } + + public XdsX509TrustManagerTest(boolean useSniForSanMatching) { + this.useSniForSanMatching = useSniForSanMatching; + } @Test public void nullCertContextTest() throws CertificateException, IOException { @@ -93,41 +107,67 @@ public void emptySanListContextTest() throws CertificateException, IOException { @Test public void missingPeerCerts() { - StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); - @SuppressWarnings("deprecation") - CertificateValidationContext certContext = - CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = true; + trustManager = new XdsX509TrustManager( + CertificateValidationContext.getDefaultInstance(), mockDelegate, "foo.com"); + } else { + StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); + } try { trustManager.verifySubjectAltNameInChain(null); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate(s) missing"); + } finally { + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = false; + } } } @Test public void emptyArrayPeerCerts() { - StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); - @SuppressWarnings("deprecation") - CertificateValidationContext certContext = - CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = true; + trustManager = new XdsX509TrustManager( + CertificateValidationContext.getDefaultInstance(), mockDelegate, "foo.com"); + } else { + StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); + } try { trustManager.verifySubjectAltNameInChain(new X509Certificate[0]); fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate(s) missing"); + } finally { + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = false; + } } } @Test public void noSansInPeerCerts() throws CertificateException, IOException { - StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); - @SuppressWarnings("deprecation") - CertificateValidationContext certContext = - CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = true; + trustManager = new XdsX509TrustManager( + CertificateValidationContext.getDefaultInstance(), mockDelegate, "foo.com"); + } else { + StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); + } X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(CLIENT_PEM_FILE)); try { @@ -135,22 +175,75 @@ public void noSansInPeerCerts() throws CertificateException, IOException { fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); + } finally { + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = false; + } } } @Test public void oneSanInPeerCertsVerifies() throws CertificateException, IOException { + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = true; + trustManager = new XdsX509TrustManager( + CertificateValidationContext.getDefaultInstance(), mockDelegate, "waterzooi.test.google.be"); + } else { + StringMatcher stringMatcher = + StringMatcher.newBuilder() + .setExact("waterzooi.test.google.be") + .setIgnoreCase(false) + .build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); + } + try { + X509Certificate[] certs = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + trustManager.verifySubjectAltNameInChain(certs); + } finally { + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = false; + } + } + } + + @Test + public void autoSanSniValidation_overrides_subAltNamesToMatch() throws CertificateException, IOException { + CertificateUtils.isXdsSniEnabled = true; + try { + StringMatcher stringMatcher = + StringMatcher.newBuilder() + .setExact("notgonnabeused.test.google.be") + .setIgnoreCase(false) + .build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, "waterzooi.test.google.be"); + X509Certificate[] certs = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + trustManager.verifySubjectAltNameInChain(certs); + } finally { + CertificateUtils.isXdsSniEnabled = false; + } + } + + @Test + public void emptySni_noAutoSanSniValidation() throws CertificateException, IOException { StringMatcher stringMatcher = - StringMatcher.newBuilder() - .setExact("waterzooi.test.google.be") - .setIgnoreCase(false) - .build(); + StringMatcher.newBuilder() + .setExact("waterzooi.test.google.be") + .setIgnoreCase(false) + .build(); @SuppressWarnings("deprecation") CertificateValidationContext certContext = - CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, ""); X509Certificate[] certs = - CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); } @@ -420,11 +513,17 @@ public void oneSanInPeerCertsVerifiesMultipleVerifySans() @Test public void oneSanInPeerCertsNotFoundException() throws CertificateException, IOException { - StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); - @SuppressWarnings("deprecation") - CertificateValidationContext certContext = - CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = true; + trustManager = new XdsX509TrustManager(CertificateValidationContext.getDefaultInstance(), mockDelegate, + "x.foo.com"); + } else { + StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); + } X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { @@ -432,6 +531,10 @@ public void oneSanInPeerCertsNotFoundException() fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); + } finally { + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = false; + } } } @@ -477,12 +580,18 @@ public void wildcardSanInPeerCertsSubdomainMismatch() // 2. Asterisk (*) cannot match across domain name labels. // For example, *.example.com matches test.example.com but does not match // sub.test.example.com. - StringMatcher stringMatcher = - StringMatcher.newBuilder().setExact("sub.abc.test.youtube.com").build(); - @SuppressWarnings("deprecation") - CertificateValidationContext certContext = - CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = true; + trustManager = new XdsX509TrustManager(CertificateValidationContext.getDefaultInstance(), + mockDelegate, "sub.abc.test.youtube.com"); + } else { + StringMatcher stringMatcher = + StringMatcher.newBuilder().setExact("sub.abc.test.youtube.com").build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); + } X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { @@ -490,6 +599,10 @@ public void wildcardSanInPeerCertsSubdomainMismatch() fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); + } finally { + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = false; + } } } @@ -509,6 +622,15 @@ public void oneIpAddressInPeerCertsVerifies() throws CertificateException, IOExc trustManager.verifySubjectAltNameInChain(certs); } + @Test + public void oneIpAddressInPeerCertsVerifies_autoSniSanValidation() throws CertificateException, IOException { + trustManager = new XdsX509TrustManager(CertificateValidationContext.getDefaultInstance(), mockDelegate, + "192.168.1.3"); + X509Certificate[] certs = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + trustManager.verifySubjectAltNameInChain(certs); + } + @Test public void oneIpAddressInPeerCertsMismatch() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); @@ -530,6 +652,15 @@ public void oneIpAddressInPeerCertsMismatch() throws CertificateException, IOExc } } + @Test + public void oneIpAddressInPeerCertsMismatch_autoSniSanValidation() throws CertificateException, IOException { + trustManager = new XdsX509TrustManager(CertificateValidationContext.getDefaultInstance(), mockDelegate, + "192.168.1.3"); + X509Certificate[] certs = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + trustManager.verifySubjectAltNameInChain(certs); + } + @Test public void checkServerTrustedSslEngine() throws CertificateException, IOException, CertStoreException { @@ -550,7 +681,7 @@ public void checkServerTrustedSslEngineSpiffeTrustMap() List caCerts = Arrays.asList(CertificateUtils .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); trustManager = XdsTrustManagerFactory.createX509TrustManager( - ImmutableMap.of("example.com", caCerts), null); + ImmutableMap.of("example.com", caCerts), null, null); trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine); verify(sslEngine, times(1)).getHandshakeSession(); assertThat(sslEngine.getSSLParameters().getEndpointIdentificationAlgorithm()).isEmpty(); @@ -565,7 +696,7 @@ public void checkServerTrustedSslEngineSpiffeTrustMap_missing_spiffe_id() List caCerts = Arrays.asList(CertificateUtils .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); trustManager = XdsTrustManagerFactory.createX509TrustManager( - ImmutableMap.of("example.com", caCerts), null); + ImmutableMap.of("example.com", caCerts), null, null); try { trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine); fail("exception expected"); @@ -584,7 +715,7 @@ public void checkServerTrustedSpiffeSslEngineTrustMap_missing_trust_domain() List caCerts = Arrays.asList(CertificateUtils .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); trustManager = XdsTrustManagerFactory.createX509TrustManager( - ImmutableMap.of("unknown.com", caCerts), null); + ImmutableMap.of("unknown.com", caCerts), null, null); try { trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine); fail("exception expected"); @@ -602,7 +733,7 @@ public void checkClientTrustedSpiffeTrustMap() List caCerts = Arrays.asList(CertificateUtils .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); trustManager = XdsTrustManagerFactory.createX509TrustManager( - ImmutableMap.of("foo.bar.com", caCerts), null); + ImmutableMap.of("foo.bar.com", caCerts), null, null); trustManager.checkClientTrusted(clientCerts, "RSA"); } @@ -643,7 +774,7 @@ public void checkServerTrustedSslSocketSpiffeTrustMap() List caCerts = Arrays.asList(CertificateUtils .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); trustManager = XdsTrustManagerFactory.createX509TrustManager( - ImmutableMap.of("example.com", caCerts), null); + ImmutableMap.of("example.com", caCerts), null, null); trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslSocket); verify(sslSocket, times(1)).isConnected(); verify(sslSocket, times(1)).getHandshakeSession(); @@ -717,7 +848,7 @@ private SSLParameters buildTrustManagerAndGetSslParameters() X509Certificate[] caCerts = CertificateUtils.toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE)); trustManager = XdsTrustManagerFactory.createX509TrustManager(caCerts, - null); + null, null); when(mockSession.getProtocol()).thenReturn("TLSv1.2"); when(mockSession.getPeerHost()).thenReturn("peer-host-from-mock"); SSLParameters sslParams = new SSLParameters();