From ed5072e1481c2f355e4f8b2e3eab92f1a7a7b7cd Mon Sep 17 00:00:00 2001 From: Kannan J Date: Tue, 20 May 2025 07:19:41 +0000 Subject: [PATCH 01/51] Add Docker fiels for xds example server and client. --- examples/example-xds/xds-client.Dockerfile | 47 ++++++++++++++++++++++ examples/example-xds/xds-server.Dockerfile | 47 ++++++++++++++++++++++ 2 files changed, 94 insertions(+) create mode 100644 examples/example-xds/xds-client.Dockerfile create mode 100644 examples/example-xds/xds-server.Dockerfile diff --git a/examples/example-xds/xds-client.Dockerfile b/examples/example-xds/xds-client.Dockerfile new file mode 100644 index 00000000000..0f34d219177 --- /dev/null +++ b/examples/example-xds/xds-client.Dockerfile @@ -0,0 +1,47 @@ +# Copyright 2024 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# Stage 1: Build XDS client +# + +FROM eclipse-temurin:11-jdk AS build + +WORKDIR /grpc-java/examples +COPY . . + +RUN cd example-xds && ../gradlew installDist -PskipCodegen=true -PskipAndroid=true + +# +# Stage 2: +# +# - Copy only the necessary files to reduce Docker image size. +# - Have an ENTRYPOINT script which will launch the XDS client +# with the given parameters. +# + +FROM eclipse-temurin:11-jre + +WORKDIR /grpc-java/ +COPY --from=build /grpc-java/examples/example-xds/build/install/example-xds/. . + +# Intentionally after the COPY to force the update on each build. +# Update Ubuntu system packages: +RUN apt-get update \ + && apt-get -y upgrade \ + && apt-get -y autoremove \ + && rm -rf /var/lib/apt/lists/* + +# Client +ENTRYPOINT ["bin/xds-hello-world-client"] diff --git a/examples/example-xds/xds-server.Dockerfile b/examples/example-xds/xds-server.Dockerfile new file mode 100644 index 00000000000..542fb0263af --- /dev/null +++ b/examples/example-xds/xds-server.Dockerfile @@ -0,0 +1,47 @@ +# Copyright 2024 gRPC authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# +# Stage 1: Build XDS server +# + +FROM eclipse-temurin:11-jdk AS build + +WORKDIR /grpc-java/examples +COPY . . + +RUN cd example-xds && ../gradlew installDist -PskipCodegen=true -PskipAndroid=true + +# +# Stage 2: +# +# - Copy only the necessary files to reduce Docker image size. +# - Have an ENTRYPOINT script which will launch the XDS server +# with the given parameters. +# + +FROM eclipse-temurin:11-jre + +WORKDIR /grpc-java/ +COPY --from=build /grpc-java/examples/example-xds/build/install/example-xds/. . + +# Intentionally after the COPY to force the update on each build. +# Update Ubuntu system packages: +RUN apt-get update \ + && apt-get -y upgrade \ + && apt-get -y autoremove \ + && rm -rf /var/lib/apt/lists/* + +# Server +ENTRYPOINT ["bin/xds-hello-world-server"] From 6263ccec5df096b63bfbc47475c11f30dce67de1 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Tue, 1 Jul 2025 05:07:31 +0000 Subject: [PATCH 02/51] Changes needed for System root certs to work. Commented out the change for SNI in ProtocolNegotiators.java --- .../testing/integration/XdsTestClient.java | 2 + .../io/grpc/netty/ProtocolNegotiators.java | 6 ++- .../io/grpc/xds/GcpAuthenticationFilter.java | 7 +++ .../java/io/grpc/xds/XdsNameResolver.java | 2 +- .../security/SslContextProviderSupplier.java | 45 ++++++++++++------- .../CertProviderClientSslContextProvider.java | 2 +- .../CertProviderSslContextProvider.java | 4 ++ 7 files changed, 50 insertions(+), 18 deletions(-) diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java index 89519041a79..f341836e71d 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java @@ -452,12 +452,14 @@ public void onNext(SimpleResponse response) { private void handleRpcCompleted(long requestId, RpcType rpcType, String hostname, Set watchers) { + logger.info("RPC completed"); statsAccumulator.recordRpcFinished(rpcType, Status.OK); notifyWatchers(watchers, rpcType, requestId, hostname); } private void handleRpcError(long requestId, RpcType rpcType, Status status, Set watchers) { + logger.info("RPC error with status " + status); statsAccumulator.recordRpcFinished(rpcType, status); notifyWatchers(watchers, rpcType, requestId, null); } diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 77308c76ace..5c0c0437efe 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -651,7 +651,11 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { @Override @IgnoreJRERequirement protected void handlerAdded0(ChannelHandlerContext ctx) { - sslEngine = sslContext.newEngine(ctx.alloc(), host, port); + /*if (host.equals("psm-grpc-server")) { + sslEngine = sslContext.newEngine(ctx.alloc(), "kannanj-psm-server-20250604-1226-8bkw5-830293263384.us-east7.run.app", 443); + } else {*/ + sslEngine = sslContext.newEngine(ctx.alloc(), host, port); + // } SSLParameters sslParams = sslEngine.getSSLParameters(); sslParams.setEndpointIdentificationAlgorithm("HTTPS"); sslEngine.setSSLParameters(sslParams); diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java index dc133eaaf1a..6eba753f716 100644 --- a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java +++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java @@ -37,6 +37,7 @@ import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.CompositeCallCredentials; +import io.grpc.InternalLogId; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -45,10 +46,13 @@ import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; import io.grpc.xds.MetadataRegistry.MetadataValueParser; import io.grpc.xds.XdsConfig.XdsClusterConfig; +import io.grpc.xds.client.XdsLogger; +import io.grpc.xds.client.XdsLogger.XdsLogLevel; import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; import java.util.LinkedHashMap; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; import java.util.function.Function; import javax.annotation.Nullable; @@ -61,6 +65,7 @@ final class GcpAuthenticationFilter implements Filter { static final String TYPE_URL = "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig"; private final LruCache callCredentialsCache; + private final XdsLogger logger = XdsLogger.withLogId(InternalLogId.allocate("bootstrapper", null)); final String filterInstanceName; GcpAuthenticationFilter(String name, int cacheSize) { @@ -193,6 +198,8 @@ public ClientCall interceptCall( } else { callOptions = callOptions.withCallCredentials(newCallCredentials); } + logger.log(XdsLogLevel.INFO, "Time to expiry of the auth token=" + callOptions.getDeadline().timeRemaining( + TimeUnit.SECONDS)); return next.newCall(method, callOptions); } }; diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index a14abf95f41..8c4423923fb 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -515,7 +515,7 @@ public void onClose(Status status, Metadata trailers) { Result.newBuilder() .setConfig(config) .setInterceptor(combineInterceptors( - ImmutableList.of(filters, new ClusterSelectionInterceptor()))) + ImmutableList.of(new ClusterSelectionInterceptor(), filters))) .build(); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index 5f629273179..9be306eca21 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -20,12 +20,15 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; +import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.TlsContextManager; +import io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProvider; import io.netty.handler.ssl.SslContext; import java.util.Objects; +import javax.net.ssl.SSLException; /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} @@ -62,21 +65,33 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call } // we want to increment the ref-count so call findOrCreate again... final SslContextProvider toRelease = getSslContextProvider(); - toRelease.addCallback( - new SslContextProvider.Callback(callback.getExecutor()) { - - @Override - public void updateSslContext(SslContext sslContext) { - callback.updateSslContext(sslContext); - releaseSslContextProvider(toRelease); - } - - @Override - public void onException(Throwable throwable) { - callback.onException(throwable); - releaseSslContextProvider(toRelease); - } - }); + if (toRelease instanceof CertProviderClientSslContextProvider + && ((CertProviderClientSslContextProvider) toRelease).isUsingSystemRootCerts()) { + callback.getExecutor().execute(() -> { + try { + callback.updateSslContext(GrpcSslContexts.forClient().build()); + releaseSslContextProvider(toRelease); + } catch (SSLException e) { + callback.onException(e); + } + }); + } else { + toRelease.addCallback( + new SslContextProvider.Callback(callback.getExecutor()) { + + @Override + public void updateSslContext(SslContext sslContext) { + callback.updateSslContext(sslContext); + releaseSslContextProvider(toRelease); + } + + @Override + public void onException(Throwable throwable) { + callback.onException(throwable); + releaseSslContextProvider(toRelease); + } + }); + }; } catch (final Throwable throwable) { callback.getExecutor().execute(new Runnable() { @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index d7c2267c48f..d1eeefa61ea 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -30,7 +30,7 @@ import javax.annotation.Nullable; /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ -final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { +public final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { CertProviderClientSslContextProvider( Node node, diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index 801dabeecb7..cfdafff7466 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -89,6 +89,10 @@ protected CertProviderSslContextProvider( && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext()); } + public boolean isUsingSystemRootCerts() { + return this.isUsingSystemRootCerts; + } + private static CertificateProviderInfo getCertProviderConfig( @Nullable Map certProviders, String pluginInstanceName) { return certProviders != null ? certProviders.get(pluginInstanceName) : null; From 5e794bf72bda69f24483528d88791edd5a9f7932 Mon Sep 17 00:00:00 2001 From: deadEternally Date: Tue, 29 Jul 2025 13:21:10 +0530 Subject: [PATCH 03/51] In-progress changes. --- .../netty/InternalProtocolNegotiators.java | 11 ++-- .../io/grpc/netty/NettyChannelBuilder.java | 2 +- .../io/grpc/netty/ProtocolNegotiators.java | 50 ++++++++++--------- .../grpc/netty/NettyClientTransportTest.java | 2 +- .../grpc/netty/ProtocolNegotiatorsTest.java | 12 ++--- .../io/grpc/xds/ClusterImplLoadBalancer.java | 29 ++++------- .../io/grpc/xds/EnvoyServerProtoData.java | 19 +++++-- .../security/SecurityProtocolNegotiators.java | 6 +-- .../internal/security/SslContextProvider.java | 4 +- .../security/SslContextProviderSupplier.java | 15 ++++-- .../security/CommonTlsContextTestsUtil.java | 2 +- .../SecurityProtocolNegotiatorsTest.java | 6 +-- .../SslContextProviderSupplierTest.java | 4 +- 13 files changed, 90 insertions(+), 72 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index 039ea6c4f24..a9ecefc1c39 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -42,9 +42,10 @@ private InternalProtocolNegotiators() {} */ public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, ObjectPool executorPool, - Optional handshakeCompleteRunnable) { + Optional handshakeCompleteRunnable, + String sni) { final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, - executorPool, handshakeCompleteRunnable, null); + executorPool, handshakeCompleteRunnable, null, sni); final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { @Override @@ -71,8 +72,8 @@ public void close() { * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. */ - public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) { - return tls(sslContext, null, Optional.absent()); + public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, String sni) { + return tls(sslContext, null, Optional.absent(), sni); } /** @@ -170,7 +171,7 @@ public static ChannelHandler clientTlsHandler( ChannelHandler next, SslContext sslContext, String authority, ChannelLogger negotiationLogger) { return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger, - Optional.absent(), null, null); + Optional.absent(), null, null, sni); } public static class ProtocolNegotiationHandler diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 46566eaca1a..2db5ab20a91 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -652,7 +652,7 @@ static ProtocolNegotiator createProtocolNegotiatorByType( case PLAINTEXT_UPGRADE: return ProtocolNegotiators.plaintextUpgrade(); case TLS: - return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null); + return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null, null); default: throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType); } diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 5c0c0437efe..23537412128 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -46,6 +46,7 @@ import io.grpc.internal.GrpcUtil; import io.grpc.internal.NoopSslSession; import io.grpc.internal.ObjectPool; +import io.netty.channel.Channel; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; @@ -60,6 +61,7 @@ import io.netty.handler.codec.http2.Http2ClientUpgradeCodec; import io.netty.handler.proxy.HttpProxyHandler; import io.netty.handler.proxy.ProxyConnectionEvent; +import io.netty.handler.ssl.ClientAuth; import io.netty.handler.ssl.OpenSsl; import io.netty.handler.ssl.OpenSslEngine; import io.netty.handler.ssl.SslContext; @@ -223,15 +225,15 @@ public static FromServerCredentialsResult from(ServerCredentials creds) { } // else use system default switch (tlsCreds.getClientAuth()) { case OPTIONAL: - builder.clientAuth(io.netty.handler.ssl.ClientAuth.OPTIONAL); + builder.clientAuth(ClientAuth.OPTIONAL); break; case REQUIRE: - builder.clientAuth(io.netty.handler.ssl.ClientAuth.REQUIRE); + builder.clientAuth(ClientAuth.REQUIRE); break; case NONE: - builder.clientAuth(io.netty.handler.ssl.ClientAuth.NONE); + builder.clientAuth(ClientAuth.NONE); break; default: @@ -578,8 +580,8 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator { public ClientTlsProtocolNegotiator(SslContext sslContext, - ObjectPool executorPool, Optional handshakeCompleteRunnable, - X509TrustManager x509ExtendedTrustManager) { + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager, String sni) { this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); this.executorPool = executorPool; if (this.executorPool != null) { @@ -587,12 +589,14 @@ public ClientTlsProtocolNegotiator(SslContext sslContext, } this.handshakeCompleteRunnable = handshakeCompleteRunnable; this.x509ExtendedTrustManager = x509ExtendedTrustManager; + this.sni = sni; } private final SslContext sslContext; private final ObjectPool executorPool; private final Optional handshakeCompleteRunnable; private final X509TrustManager x509ExtendedTrustManager; + private final String sni; private Executor executor; @Override @@ -606,7 +610,7 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger(); ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(), this.executor, negotiationLogger, handshakeCompleteRunnable, this, - x509ExtendedTrustManager); + x509ExtendedTrustManager, sni); return new WaitUntilActiveHandler(cth, negotiationLogger); } @@ -631,15 +635,17 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { private Executor executor; private final Optional handshakeCompleteRunnable; private final X509TrustManager x509ExtendedTrustManager; + private final String sni; private SSLEngine sslEngine; ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority, - Executor executor, ChannelLogger negotiationLogger, - Optional handshakeCompleteRunnable, - ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, - X509TrustManager x509ExtendedTrustManager) { + Executor executor, ChannelLogger negotiationLogger, + Optional handshakeCompleteRunnable, + ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, + X509TrustManager x509ExtendedTrustManager, String sni) { super(next, negotiationLogger); this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); + this.sni = sni; HostPort hostPort = parseAuthority(authority); this.host = hostPort.host; this.port = hostPort.port; @@ -651,11 +657,7 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { @Override @IgnoreJRERequirement protected void handlerAdded0(ChannelHandlerContext ctx) { - /*if (host.equals("psm-grpc-server")) { - sslEngine = sslContext.newEngine(ctx.alloc(), "kannanj-psm-server-20250604-1226-8bkw5-830293263384.us-east7.run.app", 443); - } else {*/ - sslEngine = sslContext.newEngine(ctx.alloc(), host, port); - // } + sslEngine = sslContext.newEngine(ctx.alloc(), sni != null? sni : host, port); SSLParameters sslParams = sslEngine.getSSLParameters(); sslParams.setEndpointIdentificationAlgorithm("HTTPS"); sslEngine.setSSLParameters(sslParams); @@ -748,25 +750,27 @@ static HostPort parseAuthority(String authority) { /** * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will - * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} + * be negotiated, the {@code handler} is added and writes to the {@link Channel} * may happen immediately, even before the TLS Handshake is complete. + * * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks + * @param sni */ public static ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool, Optional handshakeCompleteRunnable, - X509TrustManager x509ExtendedTrustManager) { + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager, String sni) { return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable, - x509ExtendedTrustManager); + x509ExtendedTrustManager, sni); } /** * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will - * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} + * be negotiated, the {@code handler} is added and writes to the {@link Channel} * may happen immediately, even before the TLS Handshake is complete. */ public static ProtocolNegotiator tls(SslContext sslContext, X509TrustManager x509ExtendedTrustManager) { - return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager); + return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager, null); } public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext, @@ -908,8 +912,8 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } /** - * Returns a {@link io.netty.channel.ChannelHandler} that ensures that the {@code handler} is - * added to the pipeline writes to the {@link io.netty.channel.Channel} may happen immediately, + * Returns a {@link ChannelHandler} that ensures that the {@code handler} is + * added to the pipeline writes to the {@link Channel} may happen immediately, * even before it is active. */ public static ProtocolNegotiator plaintext() { diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 65683dd8396..f04acb3a42b 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -836,7 +836,7 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception { .keyManager(clientCert, clientKey) .build(); ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool, - Optional.absent(), null); + Optional.absent(), null, sni); // after starting the client, the Executor in the client pool should be used assertEquals(true, clientExecutorPool.isInUse()); final NettyClientTransport transport = newTransport(negotiator); diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index 638fe960a32..249a346183f 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -918,7 +918,7 @@ public String applicationProtocol() { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg, noopLogger, Optional.absent(), - getClientTlsProtocolNegotiator(), null); + getClientTlsProtocolNegotiator(), null, sni); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -957,7 +957,7 @@ public String applicationProtocol() { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg, noopLogger, Optional.absent(), - getClientTlsProtocolNegotiator(), null); + getClientTlsProtocolNegotiator(), null, sni); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -982,7 +982,7 @@ public String applicationProtocol() { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg, noopLogger, Optional.absent(), - getClientTlsProtocolNegotiator(), null); + getClientTlsProtocolNegotiator(), null, sni); pipeline.addLast(handler); final AtomicReference error = new AtomicReference<>(); @@ -1011,7 +1011,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { public void clientTlsHandler_closeDuringNegotiation() throws Exception { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", null, noopLogger, Optional.absent(), - getClientTlsProtocolNegotiator(), null); + getClientTlsProtocolNegotiator(), null, sni); pipeline.addLast(new WriteBufferingAndExceptionHandler(handler)); ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE); @@ -1026,7 +1026,7 @@ public void clientTlsHandler_closeDuringNegotiation() throws Exception { private ClientTlsProtocolNegotiator getClientTlsProtocolNegotiator() throws SSLException { return new ClientTlsProtocolNegotiator(GrpcSslContexts.forClient().trustManager( TlsTesting.loadCert("ca.pem")).build(), - null, Optional.absent(), null); + null, Optional.absent(), null, sni); } @Test @@ -1277,7 +1277,7 @@ public void clientTlsHandler_firesNegotiation() throws Exception { } FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler(); ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, - null, Optional.absent(), null); + null, Optional.absent(), null, sni); WriteBufferingAndExceptionHandler clientWbaeh = new WriteBufferingAndExceptionHandler(pn.newHandler(gh)); diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 034cdee0815..71928a12077 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -146,7 +146,7 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { childLbHelper.updateDropPolicies(config.dropCategories); childLbHelper.updateMaxConcurrentRequests(config.maxConcurrentRequests); - childLbHelper.updateSslContextProviderSupplier(config.tlsContext); + childLbHelper.updateSslContext(config.tlsContext); childLbHelper.updateFilterMetadata(config.filterMetadata); childSwitchLb.handleResolvedAddresses( @@ -184,7 +184,7 @@ public void shutdown() { if (childSwitchLb != null) { childSwitchLb.shutdown(); if (childLbHelper != null) { - childLbHelper.updateSslContextProviderSupplier(null); + childLbHelper.updateSslContext(null); childLbHelper = null; } } @@ -204,7 +204,7 @@ private final class ClusterImplLbHelper extends ForwardingLoadBalancerHelper { private List dropPolicies = Collections.emptyList(); private long maxConcurrentRequests = DEFAULT_PER_CLUSTER_MAX_CONCURRENT_REQUESTS; @Nullable - private SslContextProviderSupplier sslContextProviderSupplier; + private UpstreamTlsContext tlsContext; private Map filterMetadata = ImmutableMap.of(); @Nullable private final ServerInfo lrsServerInfo; @@ -293,10 +293,12 @@ private List withAdditionalAttributes( for (EquivalentAddressGroup eag : addresses) { Attributes.Builder attrBuilder = eag.getAttributes().toBuilder().set( XdsAttributes.ATTR_CLUSTER_NAME, cluster); - if (sslContextProviderSupplier != null) { + if (tlsContext != null) { attrBuilder.set( SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, - sslContextProviderSupplier); + new SslContextProviderSupplier(tlsContext, + (TlsContextManager) xdsClient.getSecurityConfig(), + eag.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME))); } newAddresses.add(new EquivalentAddressGroup(eag.getAddresses(), attrBuilder.build())); } @@ -348,22 +350,11 @@ private void updateMaxConcurrentRequests(@Nullable Long maxConcurrentRequests) { updateBalancingState(currentState, currentPicker); } - private void updateSslContextProviderSupplier(@Nullable UpstreamTlsContext tlsContext) { - UpstreamTlsContext currentTlsContext = - sslContextProviderSupplier != null - ? (UpstreamTlsContext)sslContextProviderSupplier.getTlsContext() - : null; - if (Objects.equals(currentTlsContext, tlsContext)) { + private void updateSslContext(@Nullable UpstreamTlsContext tlsContext) { + if (Objects.equals(this.tlsContext, tlsContext)) { return; } - if (sslContextProviderSupplier != null) { - sslContextProviderSupplier.close(); - } - sslContextProviderSupplier = - tlsContext != null - ? new SslContextProviderSupplier(tlsContext, - (TlsContextManager) xdsClient.getSecurityConfig()) - : null; + this.tlsContext = tlsContext; } private void updateFilterMetadata(Map filterMetadata) { diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index fd2a1d2a069..461f31066f0 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -73,15 +73,28 @@ public int hashCode() { public static final class UpstreamTlsContext extends BaseTlsContext { + private final String sni; + private final boolean auto_host_sni; + @VisibleForTesting - public UpstreamTlsContext(CommonTlsContext commonTlsContext) { - super(commonTlsContext); + public UpstreamTlsContext(io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext) { + super(upstreamTlsContext.getCommonTlsContext()); + this.sni = upstreamTlsContext.getSni(); + this.auto_host_sni = upstreamTlsContext.getAutoHostSni(); } public static UpstreamTlsContext fromEnvoyProtoUpstreamTlsContext( io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext) { - return new UpstreamTlsContext(upstreamTlsContext.getCommonTlsContext()); + return new UpstreamTlsContext(upstreamTlsContext); + } + + public String getSni() { + return sni; + } + + public boolean getAutoHostSni() { + return auto_host_sni; } @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index c34fab74032..be0e81cc672 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -213,7 +213,7 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContext(SslContext sslContext, String sni) { if (ctx.isRemoved()) { return; } @@ -222,7 +222,7 @@ public void updateSslContext(SslContext sslContext) { "ClientSecurityHandler.updateSslContext authority={0}, ctx.name={1}", new Object[]{grpcHandler.getAuthority(), ctx.name()}); ChannelHandler handler = - InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); + InternalProtocolNegotiators.tls(sslContext, sni).newHandler(grpcHandler); // Delegate rest of handshake to TLS handler ctx.pipeline().addAfter(ctx.name(), null, handler); @@ -356,7 +356,7 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContext(SslContext sslContext, String sni) { ChannelHandler handler = InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHandler); diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java index a0c4ed37dfb..d9be5748792 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java @@ -57,7 +57,7 @@ protected Callback(Executor executor) { } /** Informs callee of new/updated SslContext. */ - @VisibleForTesting public abstract void updateSslContext(SslContext sslContext); + @VisibleForTesting public abstract void updateSslContext(SslContext sslContext, String sni); /** Informs callee of an exception that was generated. */ @VisibleForTesting protected abstract void onException(Throwable throwable); @@ -120,7 +120,7 @@ protected final void performCallback( public void run() { try { SslContext sslContext = sslContextGetter.get(); - callback.updateSslContext(sslContext); + callback.updateSslContext(sslContext, sni); } catch (Throwable e) { callback.onException(e); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index 9be306eca21..69486dbd546 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -41,13 +41,20 @@ public final class SslContextProviderSupplier implements Closeable { private final BaseTlsContext tlsContext; private final TlsContextManager tlsContextManager; + private final String hostname; private SslContextProvider sslContextProvider; private boolean shutdown; public SslContextProviderSupplier( BaseTlsContext tlsContext, TlsContextManager tlsContextManager) { + this(tlsContext, tlsContextManager, null); + } + + public SslContextProviderSupplier( + BaseTlsContext tlsContext, TlsContextManager tlsContextManager, String hostname) { this.tlsContext = checkNotNull(tlsContext, "tlsContext"); this.tlsContextManager = checkNotNull(tlsContextManager, "tlsContextManager"); + this.hostname = hostname; } public BaseTlsContext getTlsContext() { @@ -63,13 +70,15 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call sslContextProvider = getSslContextProvider(); } } + UpstreamTlsContext upstreamTlsContext = ((UpstreamTlsContext) tlsContext); + String sni = upstreamTlsContext.getAutoHostSni() ? hostname : upstreamTlsContext.getSni(); // we want to increment the ref-count so call findOrCreate again... final SslContextProvider toRelease = getSslContextProvider(); if (toRelease instanceof CertProviderClientSslContextProvider && ((CertProviderClientSslContextProvider) toRelease).isUsingSystemRootCerts()) { callback.getExecutor().execute(() -> { try { - callback.updateSslContext(GrpcSslContexts.forClient().build()); + callback.updateSslContext(GrpcSslContexts.forClient().build(), sni); releaseSslContextProvider(toRelease); } catch (SSLException e) { callback.onException(e); @@ -80,8 +89,8 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call new SslContextProvider.Callback(callback.getExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { - callback.updateSslContext(sslContext); + public void updateSslContext(SslContext sslContext, String sni) { + callback.updateSslContext(sslContext, sni); releaseSslContextProvider(toRelease); } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index 48814dece1d..a991e96b6db 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -388,7 +388,7 @@ public TestCallback(Executor executor) { } @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContext(SslContext sslContext, String sni) { updatedSslContext = sslContext; } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index a0139618f9f..c588c5ea1ba 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -168,7 +168,7 @@ public void clientSecurityHandler_addLast() sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContext(SslContext sslContext, String sni) { future.set(sslContext); } @@ -245,7 +245,7 @@ public SocketAddress remoteAddress() { sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContext(SslContext sslContext, String sni) { future.set(sslContext); } @@ -381,7 +381,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEv sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContext(SslContext sslContext, String sni) { future.set(sslContext); } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index f476818297d..a14a3eac695 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -83,8 +83,8 @@ public void get_updateSecret() { SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSslContext(mockSslContext); - verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); + capturedCallback.updateSslContext(mockSslContext, sni); + verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext), sni); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); From 30ffa7b23ecae6a8015ce8852a3b3cbe946fb00d Mon Sep 17 00:00:00 2001 From: deadEternally Date: Sun, 3 Aug 2025 12:54:43 +0530 Subject: [PATCH 04/51] Save changes. --- .../grpc/netty/InternalProtocolNegotiators.java | 2 +- .../java/io/grpc/netty/ProtocolNegotiators.java | 5 ++--- .../io/grpc/netty/ProtocolNegotiatorsTest.java | 16 +++++----------- .../io/grpc/xds/ClusterImplLoadBalancer.java | 13 ++++++++++--- .../ClientSslContextProviderFactory.java | 7 ++++--- .../internal/security/SslContextProvider.java | 1 + .../CertProviderClientSslContextProvider.java | 2 +- ...tProviderClientSslContextProviderFactory.java | 3 ++- ...CertProviderClientSslContextProviderTest.java | 4 ++-- 9 files changed, 28 insertions(+), 25 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index a9ecefc1c39..605ef7cdd6e 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -171,7 +171,7 @@ public static ChannelHandler clientTlsHandler( ChannelHandler next, SslContext sslContext, String authority, ChannelLogger negotiationLogger) { return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger, - Optional.absent(), null, null, sni); + Optional.absent(), null, null); } public static class ProtocolNegotiationHandler diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 23537412128..fb33da78d0a 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -609,7 +609,7 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler); ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger(); ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(), - this.executor, negotiationLogger, handshakeCompleteRunnable, this, + this.executor, negotiationLogger, handshakeCompleteRunnable, x509ExtendedTrustManager, sni); return new WaitUntilActiveHandler(cth, negotiationLogger); } @@ -641,7 +641,6 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority, Executor executor, ChannelLogger negotiationLogger, Optional handshakeCompleteRunnable, - ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, X509TrustManager x509ExtendedTrustManager, String sni) { super(next, negotiationLogger); this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); @@ -754,7 +753,7 @@ static HostPort parseAuthority(String authority) { * may happen immediately, even before the TLS Handshake is complete. * * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks - * @param sni + * @param sni the SNI value to use in the Tls handshake */ public static ProtocolNegotiator tls(SslContext sslContext, ObjectPool executorPool, Optional handshakeCompleteRunnable, diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index 249a346183f..b62b3e57a7e 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -918,7 +918,7 @@ public String applicationProtocol() { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg, noopLogger, Optional.absent(), - getClientTlsProtocolNegotiator(), null, sni); + null, null); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -957,7 +957,7 @@ public String applicationProtocol() { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg, noopLogger, Optional.absent(), - getClientTlsProtocolNegotiator(), null, sni); + null, null); pipeline.addLast(handler); pipeline.replace(SslHandler.class, null, goodSslHandler); pipeline.fireUserEventTriggered(ProtocolNegotiationEvent.DEFAULT); @@ -982,7 +982,7 @@ public String applicationProtocol() { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", elg, noopLogger, Optional.absent(), - getClientTlsProtocolNegotiator(), null, sni); + null, null); pipeline.addLast(handler); final AtomicReference error = new AtomicReference<>(); @@ -1011,7 +1011,7 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { public void clientTlsHandler_closeDuringNegotiation() throws Exception { ClientTlsHandler handler = new ClientTlsHandler(grpcHandler, sslContext, "authority", null, noopLogger, Optional.absent(), - getClientTlsProtocolNegotiator(), null, sni); + null, null); pipeline.addLast(new WriteBufferingAndExceptionHandler(handler)); ChannelFuture pendingWrite = channel.writeAndFlush(NettyClientHandler.NOOP_MESSAGE); @@ -1023,12 +1023,6 @@ public void clientTlsHandler_closeDuringNegotiation() throws Exception { .isEqualTo(Status.Code.UNAVAILABLE); } - private ClientTlsProtocolNegotiator getClientTlsProtocolNegotiator() throws SSLException { - return new ClientTlsProtocolNegotiator(GrpcSslContexts.forClient().trustManager( - TlsTesting.loadCert("ca.pem")).build(), - null, Optional.absent(), null, sni); - } - @Test public void engineLog() { ChannelHandler handler = new ServerTlsHandler(grpcHandler, sslContext, null); @@ -1277,7 +1271,7 @@ public void clientTlsHandler_firesNegotiation() throws Exception { } FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler(); ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, - null, Optional.absent(), null, sni); + null, Optional.absent(), null, null); WriteBufferingAndExceptionHandler clientWbaeh = new WriteBufferingAndExceptionHandler(pn.newHandler(gh)); diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 71928a12077..ce74f9f33ef 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -59,6 +59,7 @@ import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -208,6 +209,7 @@ private final class ClusterImplLbHelper extends ForwardingLoadBalancerHelper { private Map filterMetadata = ImmutableMap.of(); @Nullable private final ServerInfo lrsServerInfo; + private final Map sslContextProviderSupplierMap = new HashMap<>(); private ClusterImplLbHelper(AtomicLong inFlights, @Nullable ServerInfo lrsServerInfo) { this.inFlights = checkNotNull(inFlights, "inFlights"); @@ -294,11 +296,16 @@ private List withAdditionalAttributes( Attributes.Builder attrBuilder = eag.getAttributes().toBuilder().set( XdsAttributes.ATTR_CLUSTER_NAME, cluster); if (tlsContext != null) { + String addressNameAttr = eag.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME); + if (!sslContextProviderSupplierMap.containsKey(addressNameAttr)) { + sslContextProviderSupplierMap.put(addressNameAttr, + new SslContextProviderSupplier(tlsContext, + (TlsContextManager) xdsClient.getSecurityConfig(), + eag.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME))); + } attrBuilder.set( SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, - new SslContextProviderSupplier(tlsContext, - (TlsContextManager) xdsClient.getSecurityConfig(), - eag.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME))); + sslContextProviderSupplierMap.get(addressNameAttr)); } newAddresses.add(new EquivalentAddressGroup(eag.getAddresses(), attrBuilder.build())); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java index 37d289c1c47..4df2a896ef8 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java @@ -20,10 +20,11 @@ import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; import io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProviderFactory; +import java.util.AbstractMap; /** Factory to create client-side SslContextProvider from UpstreamTlsContext. */ final class ClientSslContextProviderFactory - implements ValueFactory { + implements ValueFactory, SslContextProvider> { private BootstrapInfo bootstrapInfo; private final CertProviderClientSslContextProviderFactory @@ -41,9 +42,9 @@ final class ClientSslContextProviderFactory /** Creates an SslContextProvider from the given UpstreamTlsContext. */ @Override - public SslContextProvider create(UpstreamTlsContext upstreamTlsContext) { + public SslContextProvider create(AbstractMap.SimpleImmutableEntry key) { return certProviderClientSslContextProviderFactory.getProvider( - upstreamTlsContext, + key.getKey(), key.getValue(), bootstrapInfo.node().toEnvoyProtoNode(), bootstrapInfo.certProviders()); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java index d9be5748792..214fddce633 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java @@ -44,6 +44,7 @@ public abstract class SslContextProvider implements Closeable { protected final BaseTlsContext tlsContext; + private String sni; @VisibleForTesting public abstract static class Callback { private final Executor executor; diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index d1eeefa61ea..b25fddce65c 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -39,7 +39,7 @@ public final class CertProviderClientSslContextProvider extends CertProviderSslC CommonTlsContext.CertificateProviderInstance rootCertInstance, CertificateValidationContext staticCertValidationContext, UpstreamTlsContext upstreamTlsContext, - CertificateProviderStore certificateProviderStore) { + String sni, CertificateProviderStore certificateProviderStore) { super( node, certProviders, diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java index 6205c1c3a63..c65fcc59a45 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderFactory.java @@ -55,7 +55,7 @@ public static CertProviderClientSslContextProviderFactory getInstance() { */ public SslContextProvider getProvider( UpstreamTlsContext upstreamTlsContext, - Node node, + String sni, Node node, @Nullable Map certProviders) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext(); @@ -74,6 +74,7 @@ public SslContextProvider getProvider( rootCertInstance, staticCertValidationContext, upstreamTlsContext, + sni, certificateProviderStore); } throw new UnsupportedOperationException("Unsupported configurations in UpstreamTlsContext!"); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index b0800458d66..e9521a5a308 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -84,7 +84,7 @@ private CertProviderClientSslContextProvider getSslContextProvider( return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( upstreamTlsContext, - bootstrapInfo.node().toEnvoyProtoNode(), + key.getValue(), bootstrapInfo.node().toEnvoyProtoNode(), bootstrapInfo.certProviders()); } @@ -106,7 +106,7 @@ private CertProviderClientSslContextProvider getNewSslContextProvider( return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( upstreamTlsContext, - bootstrapInfo.node().toEnvoyProtoNode(), + key.getValue(), bootstrapInfo.node().toEnvoyProtoNode(), bootstrapInfo.certProviders()); } From 42c9df03c1ceea9b9d65ac368952eb6b9d30a614 Mon Sep 17 00:00:00 2001 From: kannanjgithub Date: Mon, 11 Aug 2025 10:37:47 +0530 Subject: [PATCH 05/51] Save changes. --- .../netty/InternalProtocolNegotiators.java | 22 +++--- .../io/grpc/netty/ProtocolNegotiators.java | 72 +++++++++---------- .../grpc/netty/NettyClientTransportTest.java | 3 +- .../io/grpc/xds/ClusterImplLoadBalancer.java | 48 +++++++------ .../io/grpc/xds/EnvoyServerProtoData.java | 19 +++-- .../java/io/grpc/xds/TlsContextManager.java | 5 +- .../security/SecurityProtocolNegotiators.java | 20 +++--- .../internal/security/SslContextProvider.java | 26 +++++-- .../security/SslContextProviderSupplier.java | 48 +++++++------ .../security/TlsContextManagerImpl.java | 16 ++--- .../grpc/xds/ClusterImplLoadBalancerTest.java | 4 +- .../ClientSslContextProviderFactoryTest.java | 44 ++++++------ .../security/CommonTlsContextTestsUtil.java | 2 +- .../SslContextProviderSupplierTest.java | 33 +++++---- .../security/TlsContextManagerTest.java | 20 +++--- ...tProviderClientSslContextProviderTest.java | 5 +- 16 files changed, 216 insertions(+), 171 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index 605ef7cdd6e..d46098afeb6 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -41,9 +41,9 @@ private InternalProtocolNegotiators() {} * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool, - Optional handshakeCompleteRunnable, - String sni) { + ObjectPool executorPool, + Optional handshakeCompleteRunnable, + String sni) { final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, executorPool, handshakeCompleteRunnable, null, sni); final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { @@ -63,17 +63,17 @@ public void close() { negotiator.close(); } } - + return new TlsNegotiator(); } - + /** * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. */ - public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, String sni) { - return tls(sslContext, null, Optional.absent(), sni); + public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) { + return tls(sslContext, null, Optional.absent(), null); } /** @@ -156,7 +156,7 @@ public void close() { * Internal version of {@link WaitUntilActiveHandler}. */ public static ChannelHandler waitUntilActiveHandler(ChannelHandler next, - ChannelLogger negotiationLogger) { + ChannelLogger negotiationLogger) { return new WaitUntilActiveHandler(next, negotiationLogger); } @@ -171,14 +171,14 @@ public static ChannelHandler clientTlsHandler( ChannelHandler next, SslContext sslContext, String authority, ChannelLogger negotiationLogger) { return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger, - Optional.absent(), null, null); + Optional.absent(), null); } public static class ProtocolNegotiationHandler extends ProtocolNegotiators.ProtocolNegotiationHandler { protected ProtocolNegotiationHandler(ChannelHandler next, String negotiatorName, - ChannelLogger negotiationLogger) { + ChannelLogger negotiationLogger) { super(next, negotiatorName, negotiationLogger); } @@ -186,4 +186,4 @@ protected ProtocolNegotiationHandler(ChannelHandler next, ChannelLogger negotiat super(next, negotiationLogger); } } -} +} \ No newline at end of file diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index fb33da78d0a..e31a5785beb 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -46,7 +46,6 @@ import io.grpc.internal.GrpcUtil; import io.grpc.internal.NoopSslSession; import io.grpc.internal.ObjectPool; -import io.netty.channel.Channel; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelHandler; @@ -61,7 +60,6 @@ import io.netty.handler.codec.http2.Http2ClientUpgradeCodec; import io.netty.handler.proxy.HttpProxyHandler; import io.netty.handler.proxy.ProxyConnectionEvent; -import io.netty.handler.ssl.ClientAuth; import io.netty.handler.ssl.OpenSsl; import io.netty.handler.ssl.OpenSslEngine; import io.netty.handler.ssl.SslContext; @@ -141,7 +139,7 @@ public static FromChannelCredentialsResult from(ChannelCredentials creds) { trustManagers = tlsCreds.getTrustManagers(); } else if (tlsCreds.getRootCertificates() != null) { trustManagers = Arrays.asList(CertificateUtils.createTrustManager( - new ByteArrayInputStream(tlsCreds.getRootCertificates()))); + new ByteArrayInputStream(tlsCreds.getRootCertificates()))); } else { // else use system default TrustManagerFactory tmf = TrustManagerFactory.getInstance( TrustManagerFactory.getDefaultAlgorithm()); @@ -159,7 +157,7 @@ public static FromChannelCredentialsResult from(ChannelCredentials creds) { } } return FromChannelCredentialsResult.negotiator(tlsClientFactory(builder.build(), - (X509TrustManager) x509ExtendedTrustManager)); + (X509TrustManager) x509ExtendedTrustManager)); } catch (SSLException | GeneralSecurityException ex) { log.log(Level.FINE, "Exception building SslContext", ex); return FromChannelCredentialsResult.error( @@ -225,15 +223,15 @@ public static FromServerCredentialsResult from(ServerCredentials creds) { } // else use system default switch (tlsCreds.getClientAuth()) { case OPTIONAL: - builder.clientAuth(ClientAuth.OPTIONAL); + builder.clientAuth(io.netty.handler.ssl.ClientAuth.OPTIONAL); break; case REQUIRE: - builder.clientAuth(ClientAuth.REQUIRE); + builder.clientAuth(io.netty.handler.ssl.ClientAuth.REQUIRE); break; case NONE: - builder.clientAuth(ClientAuth.NONE); + builder.clientAuth(io.netty.handler.ssl.ClientAuth.NONE); break; default: @@ -281,7 +279,7 @@ public static final class FromChannelCredentialsResult { public final String error; private FromChannelCredentialsResult(ProtocolNegotiator.ClientFactory negotiator, - CallCredentials creds, String error) { + CallCredentials creds, String error) { this.negotiator = negotiator; this.callCredentials = creds; this.error = error; @@ -395,7 +393,7 @@ public ProtocolNegotiator newNegotiator(ObjectPool offloadEx * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static ProtocolNegotiator serverTls(final SslContext sslContext, - final ObjectPool executorPool) { + final ObjectPool executorPool) { Preconditions.checkNotNull(sslContext, "sslContext"); final Executor executor; if (executorPool != null) { @@ -444,8 +442,8 @@ static final class ServerTlsHandler extends ChannelInboundHandlerAdapter { private ProtocolNegotiationEvent pne = ProtocolNegotiationEvent.DEFAULT; ServerTlsHandler(ChannelHandler next, - SslContext sslContext, - final ObjectPool executorPool) { + SslContext sslContext, + final ObjectPool executorPool) { this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); this.next = Preconditions.checkNotNull(next, "next"); if (executorPool != null) { @@ -475,7 +473,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); if (!sslContext.applicationProtocolNegotiator().protocols().contains( - sslHandler.applicationProtocol())) { + sslHandler.applicationProtocol())) { logSslEngineDetails(Level.FINE, ctx, "TLS negotiation failed for new client.", null); ctx.fireExceptionCaught(unavailableException( "Failed protocol negotiation: Unable to find compatible protocol")); @@ -502,8 +500,8 @@ private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession * Returns a {@link ProtocolNegotiator} that does HTTP CONNECT proxy negotiation. */ public static ProtocolNegotiator httpProxy(final SocketAddress proxyAddress, - final @Nullable String proxyUsername, final @Nullable String proxyPassword, - final ProtocolNegotiator negotiator) { + final @Nullable String proxyUsername, final @Nullable String proxyPassword, + final ProtocolNegotiator negotiator) { Preconditions.checkNotNull(negotiator, "negotiator"); Preconditions.checkNotNull(proxyAddress, "proxyAddress"); final AsciiString scheme = negotiator.scheme(); @@ -580,8 +578,10 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator { public ClientTlsProtocolNegotiator(SslContext sslContext, - ObjectPool executorPool, Optional handshakeCompleteRunnable, - X509TrustManager x509ExtendedTrustManager, String sni) { + ObjectPool executorPool, + Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager, + String sni) { this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); this.executorPool = executorPool; if (this.executorPool != null) { @@ -608,9 +608,9 @@ public AsciiString scheme() { public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler); ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger(); - ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, grpcHandler.getAuthority(), - this.executor, negotiationLogger, handshakeCompleteRunnable, - x509ExtendedTrustManager, sni); + ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, + sni != null? sni : grpcHandler.getAuthority(), + this.executor, negotiationLogger, handshakeCompleteRunnable, x509ExtendedTrustManager); return new WaitUntilActiveHandler(cth, negotiationLogger); } @@ -635,17 +635,15 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { private Executor executor; private final Optional handshakeCompleteRunnable; private final X509TrustManager x509ExtendedTrustManager; - private final String sni; private SSLEngine sslEngine; - ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority, + ClientTlsHandler(ChannelHandler next, SslContext sslContext, String sni, Executor executor, ChannelLogger negotiationLogger, Optional handshakeCompleteRunnable, - X509TrustManager x509ExtendedTrustManager, String sni) { + X509TrustManager x509ExtendedTrustManager) { super(next, negotiationLogger); this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); - this.sni = sni; - HostPort hostPort = parseAuthority(authority); + HostPort hostPort = parseAuthority(sni); this.host = hostPort.host; this.port = hostPort.port; this.executor = executor; @@ -656,7 +654,7 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { @Override @IgnoreJRERequirement protected void handlerAdded0(ChannelHandlerContext ctx) { - sslEngine = sslContext.newEngine(ctx.alloc(), sni != null? sni : host, port); + sslEngine = sslContext.newEngine(ctx.alloc(), host, port); SSLParameters sslParams = sslEngine.getSSLParameters(); sslParams.setEndpointIdentificationAlgorithm("HTTPS"); sslEngine.setSSLParameters(sslParams); @@ -749,11 +747,9 @@ static HostPort parseAuthority(String authority) { /** * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will - * be negotiated, the {@code handler} is added and writes to the {@link Channel} + * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. - * * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks - * @param sni the SNI value to use in the Tls handshake */ public static ProtocolNegotiator tls(SslContext sslContext, ObjectPool executorPool, Optional handshakeCompleteRunnable, @@ -764,16 +760,16 @@ public static ProtocolNegotiator tls(SslContext sslContext, /** * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will - * be negotiated, the {@code handler} is added and writes to the {@link Channel} + * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. */ public static ProtocolNegotiator tls(SslContext sslContext, - X509TrustManager x509ExtendedTrustManager) { + X509TrustManager x509ExtendedTrustManager) { return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager, null); } public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext, - X509TrustManager x509ExtendedTrustManager) { + X509TrustManager x509ExtendedTrustManager) { return new TlsProtocolNegotiatorClientFactory(sslContext, x509ExtendedTrustManager); } @@ -911,8 +907,8 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc } /** - * Returns a {@link ChannelHandler} that ensures that the {@code handler} is - * added to the pipeline writes to the {@link Channel} may happen immediately, + * Returns a {@link io.netty.channel.ChannelHandler} that ensures that the {@code handler} is + * added to the pipeline writes to the {@link io.netty.channel.Channel} may happen immediately, * even before it is active. */ public static ProtocolNegotiator plaintext() { @@ -941,7 +937,7 @@ private static RuntimeException unavailableException(String msg) { @VisibleForTesting static void logSslEngineDetails(Level level, ChannelHandlerContext ctx, String msg, - @Nullable Throwable t) { + @Nullable Throwable t) { if (!log.isLoggable(level)) { return; } @@ -1067,8 +1063,8 @@ static final class PlaintextHandler extends ProtocolNegotiationHandler { protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { ProtocolNegotiationEvent existingPne = getProtocolNegotiationEvent(); Attributes attrs = existingPne.getAttributes().toBuilder() - .set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, (authority) -> Status.OK) - .build(); + .set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, (authority) -> Status.OK) + .build(); replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs)); fireProtocolNegotiationEvent(ctx); } @@ -1130,7 +1126,7 @@ static class ProtocolNegotiationHandler extends ChannelDuplexHandler { private final ChannelLogger negotiationLogger; protected ProtocolNegotiationHandler(ChannelHandler next, String negotiatorName, - ChannelLogger negotiationLogger) { + ChannelLogger negotiationLogger) { this.next = Preconditions.checkNotNull(next, "next"); this.negotiatorName = negotiatorName; this.negotiationLogger = Preconditions.checkNotNull(negotiationLogger, "negotiationLogger"); @@ -1228,4 +1224,4 @@ public String getPeerHost() { return peerHost; } } -} +} \ No newline at end of file diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index f04acb3a42b..1d1c5990818 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -149,6 +149,7 @@ public class NettyClientTransportTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + private static final String SNI = "sni"; private static final SslContext SSL_CONTEXT = createSslContext(); @Mock @@ -836,7 +837,7 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception { .keyManager(clientCert, clientKey) .build(); ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool, - Optional.absent(), null, sni); + Optional.absent(), null, SNI); // after starting the client, the Executor in the client pool should be used assertEquals(true, clientExecutorPool.isInUse()); final NettyClientTransport transport = newTransport(negotiator); diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index ce74f9f33ef..027c8479952 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -59,7 +59,6 @@ import io.grpc.xds.orca.OrcaPerRequestUtil.OrcaPerRequestReportListener; import java.util.ArrayList; import java.util.Collections; -import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -147,14 +146,14 @@ public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { childLbHelper.updateDropPolicies(config.dropCategories); childLbHelper.updateMaxConcurrentRequests(config.maxConcurrentRequests); - childLbHelper.updateSslContext(config.tlsContext); + childLbHelper.updateSslContextProviderSupplier(config.tlsContext); childLbHelper.updateFilterMetadata(config.filterMetadata); childSwitchLb.handleResolvedAddresses( resolvedAddresses.toBuilder() .setAttributes(attributes.toBuilder() - .set(NameResolver.ATTR_BACKEND_SERVICE, cluster) - .build()) + .set(NameResolver.ATTR_BACKEND_SERVICE, cluster) + .build()) .setLoadBalancingPolicyConfig(config.childConfig) .build()); return Status.OK; @@ -185,7 +184,7 @@ public void shutdown() { if (childSwitchLb != null) { childSwitchLb.shutdown(); if (childLbHelper != null) { - childLbHelper.updateSslContext(null); + childLbHelper.updateSslContextProviderSupplier(null); childLbHelper = null; } } @@ -205,11 +204,10 @@ private final class ClusterImplLbHelper extends ForwardingLoadBalancerHelper { private List dropPolicies = Collections.emptyList(); private long maxConcurrentRequests = DEFAULT_PER_CLUSTER_MAX_CONCURRENT_REQUESTS; @Nullable - private UpstreamTlsContext tlsContext; + private SslContextProviderSupplier sslContextProviderSupplier; private Map filterMetadata = ImmutableMap.of(); @Nullable private final ServerInfo lrsServerInfo; - private final Map sslContextProviderSupplierMap = new HashMap<>(); private ClusterImplLbHelper(AtomicLong inFlights, @Nullable ServerInfo lrsServerInfo) { this.inFlights = checkNotNull(inFlights, "inFlights"); @@ -295,17 +293,10 @@ private List withAdditionalAttributes( for (EquivalentAddressGroup eag : addresses) { Attributes.Builder attrBuilder = eag.getAttributes().toBuilder().set( XdsAttributes.ATTR_CLUSTER_NAME, cluster); - if (tlsContext != null) { - String addressNameAttr = eag.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME); - if (!sslContextProviderSupplierMap.containsKey(addressNameAttr)) { - sslContextProviderSupplierMap.put(addressNameAttr, - new SslContextProviderSupplier(tlsContext, - (TlsContextManager) xdsClient.getSecurityConfig(), - eag.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME))); - } + if (sslContextProviderSupplier != null) { attrBuilder.set( SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, - sslContextProviderSupplierMap.get(addressNameAttr)); + sslContextProviderSupplier); } newAddresses.add(new EquivalentAddressGroup(eag.getAddresses(), attrBuilder.build())); } @@ -329,7 +320,7 @@ private ClusterLocality createClusterLocalityFromAttributes(Attributes addressAt (lrsServerInfo == null) ? null : xdsClient.addClusterLocalityStats(lrsServerInfo, cluster, - edsServiceName, locality); + edsServiceName, locality); return new ClusterLocality(localityStats, localityName); } @@ -357,11 +348,22 @@ private void updateMaxConcurrentRequests(@Nullable Long maxConcurrentRequests) { updateBalancingState(currentState, currentPicker); } - private void updateSslContext(@Nullable UpstreamTlsContext tlsContext) { - if (Objects.equals(this.tlsContext, tlsContext)) { + private void updateSslContextProviderSupplier(@Nullable UpstreamTlsContext tlsContext) { + UpstreamTlsContext currentTlsContext = + sslContextProviderSupplier != null + ? (UpstreamTlsContext)sslContextProviderSupplier.getTlsContext() + : null; + if (Objects.equals(currentTlsContext, tlsContext)) { return; } - this.tlsContext = tlsContext; + if (sslContextProviderSupplier != null) { + sslContextProviderSupplier.close(); + } + sslContextProviderSupplier = + tlsContext != null + ? new SslContextProviderSupplier(tlsContext, + (TlsContextManager) xdsClient.getSecurityConfig()) + : null; } private void updateFilterMetadata(Map filterMetadata) { @@ -375,8 +377,8 @@ private class RequestLimitingSubchannelPicker extends SubchannelPicker { private final Map filterMetadata; private RequestLimitingSubchannelPicker(SubchannelPicker delegate, - List dropPolicies, long maxConcurrentRequests, - Map filterMetadata) { + List dropPolicies, long maxConcurrentRequests, + Map filterMetadata) { this.delegate = delegate; this.dropPolicies = dropPolicies; this.maxConcurrentRequests = maxConcurrentRequests; @@ -540,4 +542,4 @@ void release() { } } } -} +} \ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index 461f31066f0..ed2b841e983 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -76,6 +76,13 @@ public static final class UpstreamTlsContext extends BaseTlsContext { private final String sni; private final boolean auto_host_sni; + @VisibleForTesting + public UpstreamTlsContext(CommonTlsContext commonTlsContext) { + super(commonTlsContext); + this.sni = null; + this.auto_host_sni = false; + } + @VisibleForTesting public UpstreamTlsContext(io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext) { super(upstreamTlsContext.getCommonTlsContext()); @@ -118,7 +125,7 @@ public static DownstreamTlsContext fromEnvoyProtoDownstreamTlsContext( io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext downstreamTlsContext) { return new DownstreamTlsContext(downstreamTlsContext.getCommonTlsContext(), - downstreamTlsContext.hasRequireClientCertificate()); + downstreamTlsContext.hasRequireClientCertificate()); } public boolean isRequireClientCertificate() { @@ -204,10 +211,10 @@ abstract static class FilterChainMatch { abstract String transportProtocol(); public static FilterChainMatch create(int destinationPort, - ImmutableList prefixRanges, - ImmutableList applicationProtocols, ImmutableList sourcePrefixRanges, - ConnectionSourceType connectionSourceType, ImmutableList sourcePorts, - ImmutableList serverNames, String transportProtocol) { + ImmutableList prefixRanges, + ImmutableList applicationProtocols, ImmutableList sourcePrefixRanges, + ConnectionSourceType connectionSourceType, ImmutableList sourcePorts, + ImmutableList serverNames, String transportProtocol) { return new AutoValue_EnvoyServerProtoData_FilterChainMatch( destinationPort, prefixRanges, applicationProtocols, sourcePrefixRanges, connectionSourceType, sourcePorts, serverNames, transportProtocol); @@ -419,4 +426,4 @@ static FailurePercentageEjection create( enforcementPercentage, minimumHosts, requestVolume); } } -} +} \ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/TlsContextManager.java b/xds/src/main/java/io/grpc/xds/TlsContextManager.java index 772a6cff102..4e9470d91e6 100644 --- a/xds/src/main/java/io/grpc/xds/TlsContextManager.java +++ b/xds/src/main/java/io/grpc/xds/TlsContextManager.java @@ -30,7 +30,7 @@ SslContextProvider findOrCreateServerSslContextProvider( /** Creates a SslContextProvider. Used for retrieving a client-side SslContext. */ SslContextProvider findOrCreateClientSslContextProvider( - UpstreamTlsContext upstreamTlsContext); + UpstreamTlsContext upstreamTlsContext, String sni); /** * Releases an instance of the given client-side {@link SslContextProvider}. @@ -41,7 +41,8 @@ SslContextProvider findOrCreateClientSslContextProvider( *

Caller must not release a reference more than once. It's advised that you clear the * reference to the instance with the null returned by this method. */ - SslContextProvider releaseClientSslContextProvider(SslContextProvider sslContextProvider); + SslContextProvider releaseClientSslContextProvider(SslContextProvider sslContextProvider, + String sni); /** * Releases an instance of the given server-side {@link SslContextProvider}. diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index be0e81cc672..8362c7a8cb0 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -60,14 +60,14 @@ private SecurityProtocolNegotiators() { private static final AsciiString SCHEME = AsciiString.of("http"); public static final Attributes.Key - ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER = - Attributes.Key.create("io.grpc.xds.internal.security.server.sslContextProviderSupplier"); + ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER = + Attributes.Key.create("io.grpc.xds.internal.security.server.sslContextProviderSupplier"); /** Attribute key for SslContextProviderSupplier (used from client) for a subchannel. */ @Grpc.TransportAttr public static final Attributes.Key ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER = - Attributes.Key.create("io.grpc.xds.internal.security.SslContextProviderSupplier"); + Attributes.Key.create("io.grpc.xds.internal.security.SslContextProviderSupplier"); /** * Returns a {@link InternalProtocolNegotiator.ClientFactory}. @@ -222,7 +222,7 @@ public void updateSslContext(SslContext sslContext, String sni) { "ClientSecurityHandler.updateSslContext authority={0}, ctx.name={1}", new Object[]{grpcHandler.getAuthority(), ctx.name()}); ChannelHandler handler = - InternalProtocolNegotiators.tls(sslContext, sni).newHandler(grpcHandler); + InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); // Delegate rest of handshake to TLS handler ctx.pipeline().addAfter(ctx.name(), null, handler); @@ -289,10 +289,10 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc if (evt instanceof ProtocolNegotiationEvent) { ProtocolNegotiationEvent pne = (ProtocolNegotiationEvent)evt; SslContextProviderSupplier sslContextProviderSupplier = InternalProtocolNegotiationEvent - .getAttributes(pne).get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER); + .getAttributes(pne).get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER); if (sslContextProviderSupplier == null) { logger.log(Level.FINE, "No sslContextProviderSupplier found in filterChainMatch " - + "for connection from {0} to {1}", + + "for connection from {0} to {1}", new Object[]{ctx.channel().remoteAddress(), ctx.channel().localAddress()}); if (fallbackProtocolNegotiator == null) { ctx.fireExceptionCaught(new CertStoreException("No certificate source found!")); @@ -325,13 +325,13 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc @VisibleForTesting static final class ServerSecurityHandler - extends InternalProtocolNegotiators.ProtocolNegotiationHandler { + extends InternalProtocolNegotiators.ProtocolNegotiationHandler { private final GrpcHttp2ConnectionHandler grpcHandler; private final SslContextProviderSupplier sslContextProviderSupplier; ServerSecurityHandler( - GrpcHttp2ConnectionHandler grpcHandler, - SslContextProviderSupplier sslContextProviderSupplier) { + GrpcHttp2ConnectionHandler grpcHandler, + SslContextProviderSupplier sslContextProviderSupplier) { super( // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior @@ -376,4 +376,4 @@ public void onException(Throwable throwable) { ); } } -} +} \ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java index 214fddce633..b61ad1cd815 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java @@ -44,19 +44,37 @@ public abstract class SslContextProvider implements Closeable { protected final BaseTlsContext tlsContext; - private String sni; @VisibleForTesting public abstract static class Callback { private final Executor executor; + private final String hostname; + private final boolean isClientSide; protected Callback(Executor executor) { this.executor = executor; + this.hostname = null; + this.isClientSide = false; + } + + // Only for client SslContextProvider. + protected Callback(Executor executor, String hostname) { + this.executor = executor; + this.hostname = hostname; + this.isClientSide = true; } @VisibleForTesting public Executor getExecutor() { return executor; } + protected String getHostname() { + return hostname; + } + + public boolean isClientSide() { + return isClientSide; + } + /** Informs callee of new/updated SslContext. */ @VisibleForTesting public abstract void updateSslContext(SslContext sslContext, String sni); @@ -112,7 +130,7 @@ public UpstreamTlsContext getUpstreamTlsContext() { public abstract void addCallback(Callback callback); protected final void performCallback( - final SslContextGetter sslContextGetter, final Callback callback) { + final SslContextGetter sslContextGetter, final Callback callback) { checkNotNull(sslContextGetter, "sslContextGetter"); checkNotNull(callback, "callback"); callback.executor.execute( @@ -121,7 +139,7 @@ protected final void performCallback( public void run() { try { SslContext sslContext = sslContextGetter.get(); - callback.updateSslContext(sslContext, sni); + callback.updateSslContext(sslContext, callback.getHostname()); } catch (Throwable e) { callback.onException(e); } @@ -133,4 +151,4 @@ public void run() { protected interface SslContextGetter { SslContext get() throws Exception; } -} +} \ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index 69486dbd546..e92dc3e0c8d 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -27,7 +27,9 @@ import io.grpc.xds.TlsContextManager; import io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProvider; import io.netty.handler.ssl.SslContext; +import java.util.HashSet; import java.util.Objects; +import java.util.Set; import javax.net.ssl.SSLException; /** @@ -41,20 +43,14 @@ public final class SslContextProviderSupplier implements Closeable { private final BaseTlsContext tlsContext; private final TlsContextManager tlsContextManager; - private final String hostname; + private final Set snisSentByClients = new HashSet<>(); private SslContextProvider sslContextProvider; private boolean shutdown; public SslContextProviderSupplier( BaseTlsContext tlsContext, TlsContextManager tlsContextManager) { - this(tlsContext, tlsContextManager, null); - } - - public SslContextProviderSupplier( - BaseTlsContext tlsContext, TlsContextManager tlsContextManager, String hostname) { this.tlsContext = checkNotNull(tlsContext, "tlsContext"); this.tlsContextManager = checkNotNull(tlsContextManager, "tlsContextManager"); - this.hostname = hostname; } public BaseTlsContext getTlsContext() { @@ -65,21 +61,26 @@ public BaseTlsContext getTlsContext() { public synchronized void updateSslContext(final SslContextProvider.Callback callback) { checkNotNull(callback, "callback"); try { + String sni; + if (callback.isClientSide()) { + UpstreamTlsContext upstreamTlsContext = ((UpstreamTlsContext) tlsContext); + sni = upstreamTlsContext.getAutoHostSni() ? callback.getHostname() : upstreamTlsContext.getSni(); + } else { + sni = null; + } if (!shutdown) { if (sslContextProvider == null) { - sslContextProvider = getSslContextProvider(); + sslContextProvider = getSslContextProvider(sni); } } - UpstreamTlsContext upstreamTlsContext = ((UpstreamTlsContext) tlsContext); - String sni = upstreamTlsContext.getAutoHostSni() ? hostname : upstreamTlsContext.getSni(); // we want to increment the ref-count so call findOrCreate again... - final SslContextProvider toRelease = getSslContextProvider(); + final SslContextProvider toRelease = getSslContextProvider(sni); if (toRelease instanceof CertProviderClientSslContextProvider && ((CertProviderClientSslContextProvider) toRelease).isUsingSystemRootCerts()) { callback.getExecutor().execute(() -> { try { callback.updateSslContext(GrpcSslContexts.forClient().build(), sni); - releaseSslContextProvider(toRelease); + releaseSslContextProvider(toRelease, sni); } catch (SSLException e) { callback.onException(e); } @@ -91,13 +92,13 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call @Override public void updateSslContext(SslContext sslContext, String sni) { callback.updateSslContext(sslContext, sni); - releaseSslContextProvider(toRelease); + releaseSslContextProvider(toRelease, sni); } @Override public void onException(Throwable throwable) { callback.onException(throwable); - releaseSslContextProvider(toRelease); + releaseSslContextProvider(toRelease, sni); } }); }; @@ -111,18 +112,21 @@ public void run() { } } - private void releaseSslContextProvider(SslContextProvider toRelease) { + private void releaseSslContextProvider(SslContextProvider toRelease, String sni) { if (tlsContext instanceof UpstreamTlsContext) { - tlsContextManager.releaseClientSslContextProvider(toRelease); + tlsContextManager.releaseClientSslContextProvider(toRelease, sni); + snisSentByClients.remove(sni); } else { tlsContextManager.releaseServerSslContextProvider(toRelease); } } - private SslContextProvider getSslContextProvider() { - return tlsContext instanceof UpstreamTlsContext - ? tlsContextManager.findOrCreateClientSslContextProvider((UpstreamTlsContext) tlsContext) - : tlsContextManager.findOrCreateServerSslContextProvider((DownstreamTlsContext) tlsContext); + private SslContextProvider getSslContextProvider(String sni) { + if (tlsContext instanceof UpstreamTlsContext) { + snisSentByClients.add(sni); + return tlsContextManager.findOrCreateClientSslContextProvider((UpstreamTlsContext) tlsContext, sni); + } + return tlsContextManager.findOrCreateServerSslContextProvider((DownstreamTlsContext) tlsContext); } @VisibleForTesting public boolean isShutdown() { @@ -134,7 +138,9 @@ private SslContextProvider getSslContextProvider() { public synchronized void close() { if (sslContextProvider != null) { if (tlsContext instanceof UpstreamTlsContext) { - tlsContextManager.releaseClientSslContextProvider(sslContextProvider); + for (String sni: snisSentByClients) { + tlsContextManager.releaseClientSslContextProvider(sslContextProvider, sni); + } } else { tlsContextManager.releaseServerSslContextProvider(sslContextProvider); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java b/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java index 34a8863c52b..560c28077da 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java @@ -25,6 +25,7 @@ import io.grpc.xds.TlsContextManager; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; +import java.util.AbstractMap; /** * Class to manage {@link SslContextProvider} objects created from inputs we get from xDS. Used by @@ -34,7 +35,7 @@ */ public final class TlsContextManagerImpl implements TlsContextManager { - private final ReferenceCountingMap mapForClients; + private final ReferenceCountingMap, SslContextProvider> mapForClients; private final ReferenceCountingMap mapForServers; /** @@ -48,7 +49,7 @@ public final class TlsContextManagerImpl implements TlsContextManager { @VisibleForTesting TlsContextManagerImpl( - ValueFactory clientFactory, + ValueFactory, SslContextProvider> clientFactory, ValueFactory serverFactory) { checkNotNull(clientFactory, "clientFactory"); checkNotNull(serverFactory, "serverFactory"); @@ -69,18 +70,17 @@ public SslContextProvider findOrCreateServerSslContextProvider( @Override public SslContextProvider findOrCreateClientSslContextProvider( - UpstreamTlsContext upstreamTlsContext) { + UpstreamTlsContext upstreamTlsContext, String sni) { checkNotNull(upstreamTlsContext, "upstreamTlsContext"); - CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder(); - upstreamTlsContext = new UpstreamTlsContext(builder.build()); - return mapForClients.get(upstreamTlsContext); + return mapForClients.get(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, sni)); } @Override public SslContextProvider releaseClientSslContextProvider( - SslContextProvider clientSslContextProvider) { + SslContextProvider clientSslContextProvider, String sni) { checkNotNull(clientSslContextProvider, "clientSslContextProvider"); - return mapForClients.release(clientSslContextProvider.getUpstreamTlsContext(), + return mapForClients.release( + new AbstractMap.SimpleImmutableEntry<>(clientSslContextProvider.getUpstreamTlsContext(), sni), clientSslContextProvider); } diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index 7df0630b779..aaf7b7f9cc8 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -1259,7 +1259,7 @@ public Map getServerLrsClientMap() { private static final class FakeTlsContextManager implements TlsContextManager { @Override public SslContextProvider findOrCreateClientSslContextProvider( - UpstreamTlsContext upstreamTlsContext) { + UpstreamTlsContext upstreamTlsContext, String sni) { SslContextProvider sslContextProvider = mock(SslContextProvider.class); when(sslContextProvider.getUpstreamTlsContext()).thenReturn(upstreamTlsContext); return sslContextProvider; @@ -1267,7 +1267,7 @@ public SslContextProvider findOrCreateClientSslContextProvider( @Override public SslContextProvider releaseClientSslContextProvider( - SslContextProvider sslContextProvider) { + SslContextProvider sslContextProvider, String sni) { // no-op return null; } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java index 397fe01e0f5..6a3140fdf24 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java @@ -17,6 +17,7 @@ package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.addCertificateValidationContext; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; @@ -24,10 +25,12 @@ import com.google.common.collect.ImmutableSet; import io.envoyproxy.envoy.config.core.v3.DataSource; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.CommonBootstrapperTestUtils; @@ -38,6 +41,7 @@ import io.grpc.xds.internal.security.certprovider.CertificateProviderRegistry; import io.grpc.xds.internal.security.certprovider.CertificateProviderStore; import io.grpc.xds.internal.security.certprovider.TestCertificateProvider; +import java.util.AbstractMap; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -49,6 +53,8 @@ @RunWith(JUnit4.class) public class ClientSslContextProviderFactoryTest { + static final String SNI = "sni"; + CertificateProviderRegistry certificateProviderRegistry; CertificateProviderStore certificateProviderStore; CertProviderClientSslContextProviderFactory certProviderClientSslContextProviderFactory; @@ -81,13 +87,13 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); // verify that bootstrapInfo is cached... sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); } @@ -98,25 +104,23 @@ public void bothPresent_expectCertProviderClientSslContextProvider() final CertificateProvider.DistributorWatcher[] watcherCaptor = new CertificateProvider.DistributorWatcher[1]; createAndRegisterProviderProvider(certificateProviderRegistry, watcherCaptor, "testca", 0); - UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( - "gcp_id", - "cert-default", - "gcp_id", - "root-default", - /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); - - CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder(); + CommonTlsContext.Builder builder = CommonTlsContext.newBuilder() + .setTlsCertificateProviderInstance( + CertificateProviderPluginInstance.newBuilder() + .setInstanceName("gcp_id") + .setCertificateName("cert-default")); + builder = + addCertificateValidationContext( + builder, "gcp_id", "root-default", null); builder = addFilenames(builder, "foo.pem", "foo.key", "root.pem"); - upstreamTlsContext = new UpstreamTlsContext(builder.build()); - + UpstreamTlsContext upstreamTlsContext = new UpstreamTlsContext(builder.build()); + Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); @@ -142,7 +146,7 @@ public void createCertProviderClientSslContextProvider_onlyRootCert() new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); @@ -176,7 +180,7 @@ public void createCertProviderClientSslContextProvider_withStaticContext() new ClientSslContextProviderFactory(bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); @@ -206,7 +210,7 @@ public void createCertProviderClientSslContextProvider_2providers() new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); @@ -243,7 +247,7 @@ public void createNewCertProviderClientSslContextProvider_withSans() { new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); @@ -277,7 +281,7 @@ public void createNewCertProviderClientSslContextProvider_onlyRootCert() { new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(upstreamTlsContext); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0]); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index a991e96b6db..d25d68d4959 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -224,7 +224,7 @@ private static CommonTlsContext buildNewCommonTlsContextForCertProviderInstance( return builder.build(); } - private static CommonTlsContext.Builder addCertificateValidationContext( + public static CommonTlsContext.Builder addCertificateValidationContext( CommonTlsContext.Builder builder, String rootInstanceName, String rootCertName, diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index a14a3eac695..2437023c1d2 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -25,6 +25,7 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.TlsContextManager; @@ -46,6 +47,8 @@ public class SslContextProviderSupplierTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + private static final String HOSTNAME = "hostname"; + @Mock private TlsContextManager mockTlsContextManager; private SslContextProviderSupplier supplier; private SslContextProvider mockSslContextProvider; @@ -58,12 +61,14 @@ private void prepareSupplier() { mockSslContextProvider = mock(SslContextProvider.class); doReturn(mockSslContextProvider) .when(mockTlsContextManager) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(HOSTNAME)); supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); } private void callUpdateSslContext() { mockCallback = mock(SslContextProvider.Callback.class); + when(mockCallback.isClientSide()).thenReturn(true); + when(mockCallback.getHostname()).thenReturn(HOSTNAME); Executor mockExecutor = mock(Executor.class); doReturn(mockExecutor).when(mockCallback).getExecutor(); supplier.updateSslContext(mockCallback); @@ -74,23 +79,23 @@ public void get_updateSecret() { prepareSupplier(); callUpdateSslContext(); verify(mockTlsContextManager, times(2)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(HOSTNAME)); verify(mockTlsContextManager, times(0)) - .releaseClientSslContextProvider(any(SslContextProvider.class)); + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(HOSTNAME)); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSslContext(mockSslContext, sni); - verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext), sni); + capturedCallback.updateSslContext(mockSslContext, HOSTNAME); + verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext), HOSTNAME); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(HOSTNAME)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); supplier.updateSslContext(mockCallback); verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(HOSTNAME)); } @Test @@ -106,7 +111,7 @@ public void get_onException() { capturedCallback.onException(exception); verify(mockCallback, times(1)).onException(eq(exception)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(HOSTNAME)); } @Test @@ -115,24 +120,24 @@ public void testClose() { callUpdateSslContext(); supplier.close(); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(HOSTNAME)); supplier.updateSslContext(mockCallback); verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(HOSTNAME)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(any(SslContextProvider.class)); + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(HOSTNAME)); } @Test public void testClose_nullSslContextProvider() { prepareSupplier(); doThrow(new NullPointerException()).when(mockTlsContextManager) - .releaseClientSslContextProvider(null); + .releaseClientSslContextProvider(null, HOSTNAME); supplier.close(); verify(mockTlsContextManager, never()) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(HOSTNAME)); callUpdateSslContext(); verify(mockTlsContextManager, times(1)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(HOSTNAME)); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java index 035096a3528..2cc2c940812 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java @@ -35,6 +35,7 @@ import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.CommonBootstrapperTestUtils; import io.grpc.xds.internal.security.ReferenceCountingMap.ValueFactory; +import java.util.AbstractMap; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -49,7 +50,9 @@ public class TlsContextManagerTest { @Rule public final MockitoRule mockitoRule = MockitoJUnit.rule(); - @Mock ValueFactory mockClientFactory; + private static final String SNI = "sni"; + + @Mock ValueFactory, SslContextProvider> mockClientFactory; @Mock ValueFactory mockServerFactory; @@ -83,11 +86,11 @@ public void createClientSslContextProvider() { TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(bootstrapInfoForClient); SslContextProvider clientSecretProvider = - tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext, SNI); assertThat(clientSecretProvider).isNotNull(); SslContextProvider clientSecretProvider1 = - tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext, SNI); assertThat(clientSecretProvider1).isSameInstanceAs(clientSecretProvider); } @@ -127,14 +130,14 @@ public void createClientSslContextProvider_differentInstance() { TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(bootstrapInfoForClient); SslContextProvider clientSecretProvider = - tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext, SNI); assertThat(clientSecretProvider).isNotNull(); UpstreamTlsContext upstreamTlsContext1 = CommonTlsContextTestsUtil.buildUpstreamTlsContext("cert-instance-2", true); SslContextProvider clientSecretProvider1 = - tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext1); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext1, SNI); assertThat(clientSecretProvider1).isNotSameInstanceAs(clientSecretProvider); } @@ -166,13 +169,14 @@ public void createClientSslContextProvider_releaseInstance() { TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(mockClientFactory, mockServerFactory); SslContextProvider mockProvider = mock(SslContextProvider.class); - when(mockClientFactory.create(upstreamTlsContext)).thenReturn(mockProvider); + when(mockClientFactory.create(new AbstractMap.SimpleImmutableEntry("sni", upstreamTlsContext))) + .thenReturn(mockProvider); SslContextProvider clientSecretProvider = - tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext); + tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext, SNI); assertThat(clientSecretProvider).isSameInstanceAs(mockProvider); verify(mockProvider, never()).close(); when(mockProvider.getUpstreamTlsContext()).thenReturn(upstreamTlsContext); - tlsContextManagerImpl.releaseClientSslContextProvider(mockProvider); + tlsContextManagerImpl.releaseClientSslContextProvider(mockProvider, SNI); verify(mockProvider, times(1)).close(); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index e9521a5a308..d5102219e51 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -51,6 +51,7 @@ /** Unit tests for {@link CertProviderClientSslContextProvider}. */ @RunWith(JUnit4.class) public class CertProviderClientSslContextProviderTest { + private static final String SNI = "sni"; private static final Logger logger = Logger.getLogger(CertProviderClientSslContextProviderTest.class.getName()); @@ -84,7 +85,7 @@ private CertProviderClientSslContextProvider getSslContextProvider( return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( upstreamTlsContext, - key.getValue(), bootstrapInfo.node().toEnvoyProtoNode(), + SNI, bootstrapInfo.node().toEnvoyProtoNode(), bootstrapInfo.certProviders()); } @@ -106,7 +107,7 @@ private CertProviderClientSslContextProvider getNewSslContextProvider( return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( upstreamTlsContext, - key.getValue(), bootstrapInfo.node().toEnvoyProtoNode(), + SNI, bootstrapInfo.node().toEnvoyProtoNode(), bootstrapInfo.certProviders()); } From f12bc61a96a1d934212d0eca0a61ad8268d52ec5 Mon Sep 17 00:00:00 2001 From: kannanjgithub Date: Fri, 15 Aug 2025 21:48:51 +0530 Subject: [PATCH 06/51] save changed --- .../io/grpc/xds/ClusterImplLoadBalancer.java | 6 +-- .../grpc/xds/ClusterResolverLoadBalancer.java | 5 ++- .../main/java/io/grpc/xds/XdsAttributes.java | 5 --- .../security/SecurityProtocolNegotiators.java | 16 +++++-- .../internal/security/SslContextProvider.java | 7 --- .../security/SslContextProviderSupplier.java | 44 +++++++------------ .../security/trust/XdsX509TrustManager.java | 18 +++++++- .../grpc/xds/ClusterImplLoadBalancerTest.java | 10 ++--- .../xds/ClusterResolverLoadBalancerTest.java | 7 +-- .../SecurityProtocolNegotiatorsTest.java | 8 ++-- .../SslContextProviderSupplierTest.java | 7 ++- 11 files changed, 70 insertions(+), 63 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 027c8479952..c5491a92bed 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -241,9 +241,9 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { .set(ATTR_CLUSTER_LOCALITY, localityAtomicReference); if (GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false)) { String hostname = args.getAddresses().get(0).getAttributes() - .get(XdsAttributes.ATTR_ADDRESS_NAME); + .get(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME); if (hostname != null) { - attrsBuilder.set(XdsAttributes.ATTR_ADDRESS_NAME, hostname); + attrsBuilder.set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, hostname); } } args = args.toBuilder().setAddresses(addresses).setAttributes(attrsBuilder.build()).build(); @@ -438,7 +438,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { result = PickResult.withSubchannel(result.getSubchannel(), result.getStreamTracerFactory(), result.getSubchannel().getAttributes().get( - XdsAttributes.ATTR_ADDRESS_NAME)); + SecurityProtocolNegotiators.ATTR_ADDRESS_NAME)); } } return result; diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java index 080760303bf..6ef0020287c 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java @@ -61,6 +61,7 @@ import io.grpc.xds.client.XdsClient.ResourceWatcher; import io.grpc.xds.client.XdsLogger; import io.grpc.xds.client.XdsLogger.XdsLogLevel; +import io.grpc.xds.internal.security.SecurityProtocolNegotiators; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; @@ -432,7 +433,7 @@ public void run() { .set(XdsAttributes.ATTR_LOCALITY_WEIGHT, localityLbInfo.localityWeight()) .set(XdsAttributes.ATTR_SERVER_WEIGHT, weight) - .set(XdsAttributes.ATTR_ADDRESS_NAME, endpoint.hostname()) + .set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, endpoint.hostname()) .build(); EquivalentAddressGroup eag; @@ -680,7 +681,7 @@ public Status onResult2(final ResolutionResult resolutionResult) { Attributes attr = eag.getAttributes().toBuilder() .set(XdsAttributes.ATTR_LOCALITY, LOGICAL_DNS_CLUSTER_LOCALITY) .set(XdsAttributes.ATTR_LOCALITY_NAME, localityName) - .set(XdsAttributes.ATTR_ADDRESS_NAME, dnsHostName) + .set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, dnsHostName) .build(); eag = new EquivalentAddressGroup(eag.getAddresses(), attr); eag = AddressFilter.setPathFilter(eag, Arrays.asList(priorityName, localityName)); diff --git a/xds/src/main/java/io/grpc/xds/XdsAttributes.java b/xds/src/main/java/io/grpc/xds/XdsAttributes.java index 4a64fdb1453..42f5ad90461 100644 --- a/xds/src/main/java/io/grpc/xds/XdsAttributes.java +++ b/xds/src/main/java/io/grpc/xds/XdsAttributes.java @@ -95,11 +95,6 @@ final class XdsAttributes { static final Attributes.Key ATTR_SERVER_WEIGHT = Attributes.Key.create("io.grpc.xds.XdsAttributes.serverWeight"); - /** Name associated with individual address, if available (e.g., DNS name). */ - @EquivalentAddressGroup.Attr - static final Attributes.Key ATTR_ADDRESS_NAME = - Attributes.Key.create("io.grpc.xds.XdsAttributes.addressName"); - /** * Filter chain match for network filters. */ diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index 8362c7a8cb0..3f604cd5be1 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -20,6 +20,7 @@ import com.google.common.annotations.VisibleForTesting; import io.grpc.Attributes; +import io.grpc.EquivalentAddressGroup; import io.grpc.Grpc; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; @@ -50,6 +51,11 @@ @VisibleForTesting public final class SecurityProtocolNegotiators { + /** Name associated with individual address, if available (e.g., DNS name). */ + @EquivalentAddressGroup.Attr + public static final Attributes.Key ATTR_ADDRESS_NAME = + Attributes.Key.create("io.grpc.xds.XdsAttributes.addressName"); + // Prevent instantiation. private SecurityProtocolNegotiators() { } @@ -142,7 +148,8 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { fallbackProtocolNegotiator, "No TLS config and no fallbackProtocolNegotiator!"); return fallbackProtocolNegotiator.newHandler(grpcHandler); } - return new ClientSecurityHandler(grpcHandler, localSslContextProviderSupplier); + return new ClientSecurityHandler(grpcHandler, localSslContextProviderSupplier, + grpcHandler.getEagAttributes().get(ATTR_ADDRESS_NAME)); } @Override @@ -185,10 +192,12 @@ static final class ClientSecurityHandler extends InternalProtocolNegotiators.ProtocolNegotiationHandler { private final GrpcHttp2ConnectionHandler grpcHandler; private final SslContextProviderSupplier sslContextProviderSupplier; + private final String hostname; ClientSecurityHandler( GrpcHttp2ConnectionHandler grpcHandler, - SslContextProviderSupplier sslContextProviderSupplier) { + SslContextProviderSupplier sslContextProviderSupplier, + String hostname) { super( // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior @@ -202,6 +211,7 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { checkNotNull(grpcHandler, "grpcHandler"); this.grpcHandler = grpcHandler; this.sslContextProviderSupplier = sslContextProviderSupplier; + this.hostname = hostname; } @Override @@ -210,7 +220,7 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { ctx.pipeline().addBefore(ctx.name(), null, bufferReads); sslContextProviderSupplier.updateSslContext( - new SslContextProvider.Callback(ctx.executor()) { + new SslContextProvider.Callback(ctx.executor(), hostname) { @Override public void updateSslContext(SslContext sslContext, String sni) { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java index b61ad1cd815..970551d5611 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java @@ -48,19 +48,16 @@ public abstract class SslContextProvider implements Closeable { @VisibleForTesting public abstract static class Callback { private final Executor executor; private final String hostname; - private final boolean isClientSide; protected Callback(Executor executor) { this.executor = executor; this.hostname = null; - this.isClientSide = false; } // Only for client SslContextProvider. protected Callback(Executor executor, String hostname) { this.executor = executor; this.hostname = hostname; - this.isClientSide = true; } @VisibleForTesting public Executor getExecutor() { @@ -71,10 +68,6 @@ protected String getHostname() { return hostname; } - public boolean isClientSide() { - return isClientSide; - } - /** Informs callee of new/updated SslContext. */ @VisibleForTesting public abstract void updateSslContext(SslContext sslContext, String sni); diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index e92dc3e0c8d..33c224caa3a 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -62,7 +62,7 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call checkNotNull(callback, "callback"); try { String sni; - if (callback.isClientSide()) { + if (tlsContext instanceof UpstreamTlsContext) { UpstreamTlsContext upstreamTlsContext = ((UpstreamTlsContext) tlsContext); sni = upstreamTlsContext.getAutoHostSni() ? callback.getHostname() : upstreamTlsContext.getSni(); } else { @@ -75,33 +75,21 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call } // we want to increment the ref-count so call findOrCreate again... final SslContextProvider toRelease = getSslContextProvider(sni); - if (toRelease instanceof CertProviderClientSslContextProvider - && ((CertProviderClientSslContextProvider) toRelease).isUsingSystemRootCerts()) { - callback.getExecutor().execute(() -> { - try { - callback.updateSslContext(GrpcSslContexts.forClient().build(), sni); - releaseSslContextProvider(toRelease, sni); - } catch (SSLException e) { - callback.onException(e); - } - }); - } else { - toRelease.addCallback( - new SslContextProvider.Callback(callback.getExecutor()) { - - @Override - public void updateSslContext(SslContext sslContext, String sni) { - callback.updateSslContext(sslContext, sni); - releaseSslContextProvider(toRelease, sni); - } - - @Override - public void onException(Throwable throwable) { - callback.onException(throwable); - releaseSslContextProvider(toRelease, sni); - } - }); - }; + toRelease.addCallback( + new SslContextProvider.Callback(callback.getExecutor()) { + + @Override + public void updateSslContext(SslContext sslContext, String sni) { + callback.updateSslContext(sslContext, sni); + releaseSslContextProvider(toRelease, sni); + } + + @Override + public void onException(Throwable throwable) { + callback.onException(throwable); + releaseSslContextProvider(toRelease, sni); + } + }); } catch (final Throwable throwable) { callback.getExecutor().execute(new Runnable() { @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java index 1ecfe378d29..390c131aa92 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java @@ -21,6 +21,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Optional; import com.google.common.base.Strings; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.re2j.Pattern; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; @@ -60,21 +61,34 @@ final class XdsX509TrustManager extends X509ExtendedTrustManager implements X509 private final X509ExtendedTrustManager delegate; private final Map spiffeTrustMapDelegates; private final CertificateValidationContext certContext; + private final String sni; XdsX509TrustManager(@Nullable CertificateValidationContext certContext, X509ExtendedTrustManager delegate) { + this(certContext, delegate, null); + } + + XdsX509TrustManager(@Nullable CertificateValidationContext certContext, + X509ExtendedTrustManager delegate, @Nullable String sni) { checkNotNull(delegate, "delegate"); this.certContext = certContext; this.delegate = delegate; this.spiffeTrustMapDelegates = null; + this.sni = sni; + } + + XdsX509TrustManager(@Nullable CertificateValidationContext certContext, + Map spiffeTrustMapDelegates) { + this(certContext, spiffeTrustMapDelegates, null); } XdsX509TrustManager(@Nullable CertificateValidationContext certContext, - Map spiffeTrustMapDelegates) { + Map spiffeTrustMapDelegates, @Nullable String sni) { checkNotNull(spiffeTrustMapDelegates, "spiffeTrustMapDelegates"); this.spiffeTrustMapDelegates = ImmutableMap.copyOf(spiffeTrustMapDelegates); this.certContext = certContext; this.delegate = null; + this.sni = sni; } private static boolean verifyDnsNameInPattern( @@ -208,7 +222,7 @@ void verifySubjectAltNameInChain(X509Certificate[] peerCertChain) throws Certifi return; } @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names - List verifyList = certContext.getMatchSubjectAltNamesList(); + List verifyList = sni != null? ImmutableList.of(StringMatcher.newBuilder().setExact(sni).build()) : certContext.getMatchSubjectAltNamesList(); if (verifyList.isEmpty()) { return; } diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index aaf7b7f9cc8..b6600d1890a 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -811,10 +811,10 @@ public void endpointAddressesAttachedWithClusterName() { new FixedResultPicker(PickResult.withSubchannel(subchannel))); } }); - assertThat(subchannel.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME)).isEqualTo( + assertThat(subchannel.getAttributes().get(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME)).isEqualTo( "authority-host-name"); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME)) + assertThat(eag.getAttributes().get(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME)) .isEqualTo("authority-host-name"); } @@ -863,9 +863,9 @@ public void endpointAddressesAttachedWithClusterName() { } }); // Sub Channel wrapper args won't have the address name although addresses will. - assertThat(subchannel.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME)).isNull(); + assertThat(subchannel.getAttributes().get(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME)).isNull(); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(XdsAttributes.ATTR_ADDRESS_NAME)) + assertThat(eag.getAttributes().get(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME)) .isEqualTo("authority-host-name"); } @@ -1019,7 +1019,7 @@ public String toString() { // Unique but arbitrary string .set(XdsAttributes.ATTR_LOCALITY_NAME, locality.toString()); if (authorityHostname != null) { - attributes.set(XdsAttributes.ATTR_ADDRESS_NAME, authorityHostname); + attributes.set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, authorityHostname); } EquivalentAddressGroup eag = new EquivalentAddressGroup(new FakeSocketAddress(name), attributes.build()); diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index d701f281c01..cc850ae9a08 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -84,6 +84,7 @@ import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsResourceType; import io.grpc.xds.internal.security.CommonTlsContextTestsUtil; +import io.grpc.xds.internal.security.SecurityProtocolNegotiators; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.net.URI; @@ -378,7 +379,7 @@ public void edsClustersEndpointHostname_addedToAddressAttribute() { assertThat( childBalancer.addresses.get(0).getAttributes() - .get(XdsAttributes.ATTR_ADDRESS_NAME)).isEqualTo("hostname1"); + .get(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME)).isEqualTo("hostname1"); } @Test @@ -864,9 +865,9 @@ void do_onlyLogicalDnsCluster_endpointsResolved() { Collections.emptyList(), "pick_first"); assertAddressesEqual(Arrays.asList(endpoint1, endpoint2), childBalancer.addresses); assertThat(childBalancer.addresses.get(0).getAttributes() - .get(XdsAttributes.ATTR_ADDRESS_NAME)).isEqualTo(DNS_HOST_NAME); + .get(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME)).isEqualTo(DNS_HOST_NAME); assertThat(childBalancer.addresses.get(1).getAttributes() - .get(XdsAttributes.ATTR_ADDRESS_NAME)).isEqualTo(DNS_HOST_NAME); + .get(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME)).isEqualTo(DNS_HOST_NAME); } @Test diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index c588c5ea1ba..616ff27f6d3 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -86,6 +86,8 @@ @RunWith(JUnit4.class) public class SecurityProtocolNegotiatorsTest { + private static final String HOSTNAME = "hostname"; + private final GrpcHttp2ConnectionHandler grpcHandler = FakeGrpcHttp2ConnectionHandler.newHandler(); @@ -157,7 +159,7 @@ public void clientSecurityHandler_addLast() new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier); + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); pipeline.addLast(clientSecurityHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler); assertNotNull(channelHandlerCtx); @@ -369,7 +371,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEv new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier); + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); pipeline.addLast(clientSecurityHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler); @@ -420,7 +422,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_handleHandlerRemoved() { new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier); + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); pipeline.addLast(clientSecurityHandler); channelHandlerCtx = pipeline.context(clientSecurityHandler); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index 2437023c1d2..4500e7630c2 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -55,7 +55,7 @@ public class SslContextProviderSupplierTest { private EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; private SslContextProvider.Callback mockCallback; - private void prepareSupplier() { + private void prepareSupplier(boolean autoHostSni) { upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); mockSslContextProvider = mock(SslContextProvider.class); @@ -65,9 +65,12 @@ private void prepareSupplier() { supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); } + private EnvoyServerProtoData.UpstreamTlsContext getUpstreamTlsContext(boolean autoHostSni) { + + } + private void callUpdateSslContext() { mockCallback = mock(SslContextProvider.Callback.class); - when(mockCallback.isClientSide()).thenReturn(true); when(mockCallback.getHostname()).thenReturn(HOSTNAME); Executor mockExecutor = mock(Executor.class); doReturn(mockExecutor).when(mockCallback).getExecutor(); From e9c4e3c65b70029bcbf15a87d8e5fa5c03d5148f Mon Sep 17 00:00:00 2001 From: kannanjgithub Date: Tue, 19 Aug 2025 17:00:55 +0530 Subject: [PATCH 07/51] Save changes. --- .../io/grpc/xds/CdsLoadBalancer2Test.java | 2 +- .../grpc/xds/ClusterImplLoadBalancerTest.java | 4 +-- .../xds/ClusterResolverLoadBalancerTest.java | 2 +- .../grpc/xds/XdsSecurityClientServerTest.java | 4 +-- .../ClientSslContextProviderFactoryTest.java | 9 +++-- .../security/CommonTlsContextTestsUtil.java | 34 +++++++++++-------- .../SecurityProtocolNegotiatorsTest.java | 8 ++--- .../SslContextProviderSupplierTest.java | 14 +++----- .../security/TlsContextManagerTest.java | 8 ++--- ...tProviderClientSslContextProviderTest.java | 2 +- 10 files changed, 44 insertions(+), 43 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java index 479bde76ce5..8a64a17cb65 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java @@ -107,7 +107,7 @@ public class CdsLoadBalancer2Test { .node(BOOTSTRAP_NODE) .build(); private final UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true, null, false); private final OutlierDetection outlierDetection = OutlierDetection.create( null, null, null, null, SuccessRateEjection.create(null, null, null, null), null); diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index b6600d1890a..77b3830fb42 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -881,7 +881,7 @@ public void endpointAddressesAttachedWithClusterName() { @Test public void endpointAddressesAttachedWithTlsConfig_securityEnabledByDefault() { UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true, null, false); LoadBalancerProvider weightedTargetProvider = new WeightedTargetLoadBalancerProvider(); WeightedTargetConfig weightedTargetConfig = buildWeightedTargetConfig(ImmutableMap.of(locality, 10)); @@ -926,7 +926,7 @@ public void endpointAddressesAttachedWithTlsConfig_securityEnabledByDefault() { // Config with a new UpstreamTlsContext. upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe1", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe1", true, null, false); config = new ClusterImplConfig(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, null, Collections.emptyList(), GracefulSwitchLoadBalancer.createLoadBalancingPolicyConfig( diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index cc850ae9a08..b99725568ac 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -134,7 +134,7 @@ public class ClusterResolverLoadBalancerTest { private final Locality locality3 = Locality.create("test-region-3", "test-zone-3", "test-subzone-3"); private final UpstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true, null, false); private final OutlierDetection outlierDetection = OutlierDetection.create( 100L, 100L, 100L, 100, SuccessRateEjection.create(100, 100, 100, 100), FailurePercentageEjection.create(100, 100, 100, 100)); diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 23068d665bf..051bdfeed22 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -546,7 +546,7 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContext(String cli .buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile, CA_PEM_FILE, null, null, null, null, spiffeFile); return CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", hasIdentityCert); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", hasIdentityCert, null, false); } private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts( @@ -563,7 +563,7 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSys CertificateValidationContext.newBuilder() .setSystemRootCerts( CertificateValidationContext.SystemRootCerts.newBuilder().build()) - .build()); + .build(), null, false); } return CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( "google_cloud_private_spiffe-client", "ROOT", null, diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java index 6a3140fdf24..aef1e79061c 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java @@ -30,7 +30,6 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.TlsCertificate; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; -import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.client.Bootstrapper; import io.grpc.xds.client.CommonBootstrapperTestUtils; @@ -80,7 +79,7 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio "gcp_id", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, null, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = @@ -139,7 +138,7 @@ public void createCertProviderClientSslContextProvider_onlyRootCert() "gcp_id", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, null, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = @@ -173,7 +172,7 @@ public void createCertProviderClientSslContextProvider_withStaticContext() "gcp_id", "root-default", /* alpnProtocols= */ null, - staticCertValidationContext); + staticCertValidationContext, null, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = @@ -203,7 +202,7 @@ public void createCertProviderClientSslContextProvider_2providers() "file_provider", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, null, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index d25d68d4959..9d256397bd8 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -149,23 +149,28 @@ public static String getTempFileNameForResourcesFile(String resFile) throws IOEx * Helper method to build UpstreamTlsContext for above tests. Called from other classes as well. */ static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( - CommonTlsContext commonTlsContext) { - UpstreamTlsContext upstreamTlsContext = - UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).build(); + CommonTlsContext commonTlsContext, String sni, boolean autoHostSni) { + UpstreamTlsContext.Builder upstreamTlsContext = + UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).setAutoHostSni(autoHostSni); + if (sni != null) { + upstreamTlsContext.setSni(sni); + } return EnvoyServerProtoData.UpstreamTlsContext.fromEnvoyProtoUpstreamTlsContext( - upstreamTlsContext); + upstreamTlsContext.build()); } /** Helper method to build UpstreamTlsContext for multiple test classes. */ public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( - String commonInstanceName, boolean hasIdentityCert) { + String commonInstanceName, boolean hasIdentityCert, String sni, boolean autoHostSni) { return buildUpstreamTlsContextForCertProviderInstance( hasIdentityCert ? commonInstanceName : null, hasIdentityCert ? "default" : null, commonInstanceName, "ROOT", null, - null); + null, + sni, + autoHostSni); } /** Gets a cert from contents of a resource. */ @@ -271,12 +276,12 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( /** Helper method to build UpstreamTlsContext for CertProvider tests. */ public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContextForCertProviderInstance( - @Nullable String certInstanceName, - @Nullable String certName, - @Nullable String rootInstanceName, - @Nullable String rootCertName, - Iterable alpnProtocols, - CertificateValidationContext staticCertValidationContext) { + @Nullable String certInstanceName, + @Nullable String certName, + @Nullable String rootInstanceName, + @Nullable String rootCertName, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext, String sni, boolean autoHostSni) { return buildUpstreamTlsContext( buildCommonTlsContextForCertProviderInstance( certInstanceName, @@ -284,7 +289,8 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext)); + staticCertValidationContext), + sni, autoHostSni); } /** Helper method to build UpstreamTlsContext for CertProvider tests. */ @@ -303,7 +309,7 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext)); + staticCertValidationContext), null, false); } /** Helper method to build DownstreamTlsContext for CertProvider tests. */ diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index 616ff27f6d3..64f770aee05 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -124,7 +124,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_noFallback_expectExceptio @Test public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() { UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build()); + CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build(), null, false); ClientSecurityProtocolNegotiator pn = new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext()); GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); @@ -153,7 +153,7 @@ public void clientSecurityHandler_addLast() CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, false); SslContextProviderSupplier sslContextProviderSupplier = new SslContextProviderSupplier(upstreamTlsContext, @@ -365,7 +365,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEv CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, false); SslContextProviderSupplier sslContextProviderSupplier = new SslContextProviderSupplier(upstreamTlsContext, @@ -416,7 +416,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_handleHandlerRemoved() { CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, false); SslContextProviderSupplier sslContextProviderSupplier = new SslContextProviderSupplier(upstreamTlsContext, diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index 4500e7630c2..ed4e25963aa 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -57,7 +57,7 @@ public class SslContextProviderSupplierTest { private void prepareSupplier(boolean autoHostSni) { upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true, null, autoHostSni); mockSslContextProvider = mock(SslContextProvider.class); doReturn(mockSslContextProvider) .when(mockTlsContextManager) @@ -65,10 +65,6 @@ private void prepareSupplier(boolean autoHostSni) { supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); } - private EnvoyServerProtoData.UpstreamTlsContext getUpstreamTlsContext(boolean autoHostSni) { - - } - private void callUpdateSslContext() { mockCallback = mock(SslContextProvider.Callback.class); when(mockCallback.getHostname()).thenReturn(HOSTNAME); @@ -79,7 +75,7 @@ private void callUpdateSslContext() { @Test public void get_updateSecret() { - prepareSupplier(); + prepareSupplier(false); callUpdateSslContext(); verify(mockTlsContextManager, times(2)) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(HOSTNAME)); @@ -103,7 +99,7 @@ public void get_updateSecret() { @Test public void get_onException() { - prepareSupplier(); + prepareSupplier(false); callUpdateSslContext(); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); @@ -119,7 +115,7 @@ public void get_onException() { @Test public void testClose() { - prepareSupplier(); + prepareSupplier(false); callUpdateSslContext(); supplier.close(); verify(mockTlsContextManager, times(1)) @@ -133,7 +129,7 @@ public void testClose() { @Test public void testClose_nullSslContextProvider() { - prepareSupplier(); + prepareSupplier(false); doThrow(new NullPointerException()).when(mockTlsContextManager) .releaseClientSslContextProvider(null, HOSTNAME); supplier.close(); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java index 2cc2c940812..67817aff71d 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java @@ -82,7 +82,7 @@ public void createClientSslContextProvider() { CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false, null, false); TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(bootstrapInfoForClient); SslContextProvider clientSecretProvider = @@ -126,7 +126,7 @@ public void createClientSslContextProvider_differentInstance() { CA_PEM_FILE, "cert-instance-2", CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", false, null, false); TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(bootstrapInfoForClient); SslContextProvider clientSecretProvider = @@ -134,7 +134,7 @@ public void createClientSslContextProvider_differentInstance() { assertThat(clientSecretProvider).isNotNull(); UpstreamTlsContext upstreamTlsContext1 = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("cert-instance-2", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext("cert-instance-2", true, null, false); SslContextProvider clientSecretProvider1 = tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext1, SNI); @@ -164,7 +164,7 @@ public void createServerSslContextProvider_releaseInstance() { public void createClientSslContextProvider_releaseInstance() { UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, false); TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(mockClientFactory, mockServerFactory); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index d5102219e51..4c13cf6f10e 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -81,7 +81,7 @@ private CertProviderClientSslContextProvider getSslContextProvider( rootInstanceName, "root-default", alpnProtocols, - staticCertValidationContext); + staticCertValidationContext, null, false); return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( upstreamTlsContext, From dd8fa02df210394f0304406a75a2d841b932b92b Mon Sep 17 00:00:00 2001 From: Kannan J Date: Mon, 1 Sep 2025 06:42:59 +0000 Subject: [PATCH 08/51] Save changes. --- .../netty/InternalProtocolNegotiators.java | 6 +- .../io/grpc/netty/ProtocolNegotiators.java | 12 +- .../security/SecurityProtocolNegotiators.java | 33 ++-- .../internal/security/SslContextProvider.java | 4 +- .../security/SslContextProviderSupplier.java | 23 +-- .../XdsClientWrapperForServerSdsTestMisc.java | 2 +- .../security/CommonTlsContextTestsUtil.java | 2 +- .../SecurityProtocolNegotiatorsTest.java | 87 +++++++++- .../SslContextProviderSupplierTest.java | 162 ++++++++++++++---- 9 files changed, 252 insertions(+), 79 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index d46098afeb6..831712dff86 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -72,8 +72,8 @@ public void close() { * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. */ - public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext) { - return tls(sslContext, null, Optional.absent(), null); + public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, String sni) { + return tls(sslContext, null, Optional.absent(), sni); } /** @@ -171,7 +171,7 @@ public static ChannelHandler clientTlsHandler( ChannelHandler next, SslContext sslContext, String authority, ChannelLogger negotiationLogger) { return new ClientTlsHandler(next, sslContext, authority, null, negotiationLogger, - Optional.absent(), null); + Optional.absent(), null, null); } public static class ProtocolNegotiationHandler diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index e31a5785beb..2f63d41dbdc 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -21,6 +21,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Optional; import com.google.common.base.Preconditions; +import com.google.common.base.Strings; import com.google.errorprone.annotations.ForOverride; import io.grpc.Attributes; import io.grpc.CallCredentials; @@ -89,6 +90,7 @@ import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.X509TrustManager; + import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** @@ -609,8 +611,8 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler); ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger(); ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, - sni != null? sni : grpcHandler.getAuthority(), - this.executor, negotiationLogger, handshakeCompleteRunnable, x509ExtendedTrustManager); + !Strings.isNullOrEmpty(sni)? sni : grpcHandler.getAuthority(), + this.executor, negotiationLogger, handshakeCompleteRunnable, null, x509ExtendedTrustManager); return new WaitUntilActiveHandler(cth, negotiationLogger); } @@ -637,13 +639,13 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { private final X509TrustManager x509ExtendedTrustManager; private SSLEngine sslEngine; - ClientTlsHandler(ChannelHandler next, SslContext sslContext, String sni, + ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority, Executor executor, ChannelLogger negotiationLogger, Optional handshakeCompleteRunnable, - X509TrustManager x509ExtendedTrustManager) { + ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, X509TrustManager x509ExtendedTrustManager) { super(next, negotiationLogger); this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); - HostPort hostPort = parseAuthority(sni); + HostPort hostPort = parseAuthority(authority); this.host = hostPort.host; this.port = hostPort.port; this.executor = executor; diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index 3f604cd5be1..db7065807e4 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -19,8 +19,10 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Strings; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; +import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.Grpc; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; @@ -30,6 +32,7 @@ import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.ProtocolNegotiationEvent; +import io.grpc.xds.EnvoyServerProtoData; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; @@ -192,12 +195,12 @@ static final class ClientSecurityHandler extends InternalProtocolNegotiators.ProtocolNegotiationHandler { private final GrpcHttp2ConnectionHandler grpcHandler; private final SslContextProviderSupplier sslContextProviderSupplier; - private final String hostname; + private final String sni; ClientSecurityHandler( GrpcHttp2ConnectionHandler grpcHandler, SslContextProviderSupplier sslContextProviderSupplier, - String hostname) { + String endpointHostname) { super( // superclass (InternalProtocolNegotiators.ProtocolNegotiationHandler) expects 'next' // handler but we don't have a next handler _yet_. So we "disable" superclass's behavior @@ -211,7 +214,15 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { checkNotNull(grpcHandler, "grpcHandler"); this.grpcHandler = grpcHandler; this.sslContextProviderSupplier = sslContextProviderSupplier; - this.hostname = hostname; + EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); + UpstreamTlsContext upstreamTlsContext = ((UpstreamTlsContext) tlsContext); + sni = upstreamTlsContext.getAutoHostSni() && !Strings.isNullOrEmpty(endpointHostname) + ? endpointHostname : upstreamTlsContext.getSni(); + } + + @VisibleForTesting + String getSni() { + return sni; } @Override @@ -220,10 +231,10 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { ctx.pipeline().addBefore(ctx.name(), null, bufferReads); sslContextProviderSupplier.updateSslContext( - new SslContextProvider.Callback(ctx.executor(), hostname) { + new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSslContext(SslContext sslContext, String sni) { + public void updateSslContext(SslContext sslContext) { if (ctx.isRemoved()) { return; } @@ -232,7 +243,7 @@ public void updateSslContext(SslContext sslContext, String sni) { "ClientSecurityHandler.updateSslContext authority={0}, ctx.name={1}", new Object[]{grpcHandler.getAuthority(), ctx.name()}); ChannelHandler handler = - InternalProtocolNegotiators.tls(sslContext).newHandler(grpcHandler); + InternalProtocolNegotiators.tls(sslContext, sni).newHandler(grpcHandler); // Delegate rest of handshake to TLS handler ctx.pipeline().addAfter(ctx.name(), null, handler); @@ -244,8 +255,8 @@ public void updateSslContext(SslContext sslContext, String sni) { public void onException(Throwable throwable) { ctx.fireExceptionCaught(throwable); } - } - ); + }, + sni); } @Override @@ -366,7 +377,7 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSslContext(SslContext sslContext, String sni) { + public void updateSslContext(SslContext sslContext) { ChannelHandler handler = InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHandler); @@ -382,8 +393,8 @@ public void updateSslContext(SslContext sslContext, String sni) { public void onException(Throwable throwable) { ctx.fireExceptionCaught(throwable); } - } - ); + }, + null); } } } \ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java index 970551d5611..1b28e223b5b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java @@ -69,7 +69,7 @@ protected String getHostname() { } /** Informs callee of new/updated SslContext. */ - @VisibleForTesting public abstract void updateSslContext(SslContext sslContext, String sni); + @VisibleForTesting public abstract void updateSslContext(SslContext sslContext); /** Informs callee of an exception that was generated. */ @VisibleForTesting protected abstract void onException(Throwable throwable); @@ -132,7 +132,7 @@ protected final void performCallback( public void run() { try { SslContext sslContext = sslContextGetter.get(); - callback.updateSslContext(sslContext, callback.getHostname()); + callback.updateSslContext(sslContext); } catch (Throwable e) { callback.onException(e); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index 33c224caa3a..b1a8ebb8404 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -20,17 +20,16 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; -import io.grpc.netty.GrpcSslContexts; +import com.google.common.base.Strings; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.TlsContextManager; -import io.grpc.xds.internal.security.certprovider.CertProviderClientSslContextProvider; import io.netty.handler.ssl.SslContext; + import java.util.HashSet; import java.util.Objects; import java.util.Set; -import javax.net.ssl.SSLException; /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} @@ -58,29 +57,17 @@ public BaseTlsContext getTlsContext() { } /** Updates SslContext via the passed callback. */ - public synchronized void updateSslContext(final SslContextProvider.Callback callback) { + public synchronized void updateSslContext(final SslContextProvider.Callback callback, String sni) { checkNotNull(callback, "callback"); try { - String sni; - if (tlsContext instanceof UpstreamTlsContext) { - UpstreamTlsContext upstreamTlsContext = ((UpstreamTlsContext) tlsContext); - sni = upstreamTlsContext.getAutoHostSni() ? callback.getHostname() : upstreamTlsContext.getSni(); - } else { - sni = null; - } - if (!shutdown) { - if (sslContextProvider == null) { - sslContextProvider = getSslContextProvider(sni); - } - } // we want to increment the ref-count so call findOrCreate again... final SslContextProvider toRelease = getSslContextProvider(sni); toRelease.addCallback( new SslContextProvider.Callback(callback.getExecutor()) { @Override - public void updateSslContext(SslContext sslContext, String sni) { - callback.updateSslContext(sslContext, sni); + public void updateSslContext(SslContext sslContext) { + callback.updateSslContext(sslContext); releaseSslContextProvider(toRelease, sni); } diff --git a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java index ff97afe6916..1a491ae6e14 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientWrapperForServerSdsTestMisc.java @@ -301,7 +301,7 @@ public void releaseOldSupplierOnTemporaryError_noClose() throws Exception { private void callUpdateSslContext(SslContextProviderSupplier sslContextProviderSupplier) { assertThat(sslContextProviderSupplier).isNotNull(); SslContextProvider.Callback callback = mock(SslContextProvider.Callback.class); - sslContextProviderSupplier.updateSslContext(callback); + sslContextProviderSupplier.updateSslContext(callback, null); } private void sendListenerUpdate( diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index 9d256397bd8..21d5d21f457 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -394,7 +394,7 @@ public TestCallback(Executor executor) { } @Override - public void updateSslContext(SslContext sslContext, String sni) { + public void updateSslContext(SslContext sslContext) { updatedSslContext = sslContext; } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index 64f770aee05..b835b39ef46 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -87,6 +87,7 @@ public class SecurityProtocolNegotiatorsTest { private static final String HOSTNAME = "hostname"; + private static final String SNI_IN_UTC = "sni-in-upstream-tls-context"; private final GrpcHttp2ConnectionHandler grpcHandler = FakeGrpcHttp2ConnectionHandler.newHandler(); @@ -170,7 +171,7 @@ public void clientSecurityHandler_addLast() sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext, String sni) { + public void updateSslContext(SslContext sslContext) { future.set(sslContext); } @@ -178,7 +179,7 @@ public void updateSslContext(SslContext sslContext, String sni) { protected void onException(Throwable throwable) { future.set(throwable); } - }); + }, null); assertThat(executor.runDueTasks()).isEqualTo(1); channel.runPendingTasks(); Object fromFuture = future.get(2, TimeUnit.SECONDS); @@ -196,6 +197,78 @@ protected void onException(Throwable throwable) { CommonCertProviderTestUtils.register0(); } + @Test + public void sniInClientSecurityHandler_autoHostSniIsTrue_usesEndpointHostname() { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(HOSTNAME); + } + + @Test + public void sniInClientSecurityHandler_autoHostSniIsTrue_endpointHostnameIsEmpty_usesSniFromUpstreamTlsContext() { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, SNI_IN_UTC, true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, ""); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } + + @Test + public void sniInClientSecurityHandler_autoHostSniIsTrue_endpointHostnameIsNull_usesSniFromUpstreamTlsContext() { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, SNI_IN_UTC, true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, null); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } + + @Test + public void sniInClientSecurityHandler_autoHostSniIsFalse_usesSniFromUpstreamTlsContext() { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, SNI_IN_UTC, false); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } + @Test public void serverSecurityHandler_addLast() throws InterruptedException, TimeoutException, ExecutionException { @@ -247,7 +320,7 @@ public SocketAddress remoteAddress() { sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext, String sni) { + public void updateSslContext(SslContext sslContext) { future.set(sslContext); } @@ -255,7 +328,7 @@ public void updateSslContext(SslContext sslContext, String sni) { protected void onException(Throwable throwable) { future.set(throwable); } - }); + }, null); channel.runPendingTasks(); // need this for tasks to execute on eventLoop assertThat(executor.runDueTasks()).isEqualTo(1); Object fromFuture = future.get(2, TimeUnit.SECONDS); @@ -383,7 +456,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEv sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext, String sni) { + public void updateSslContext(SslContext sslContext) { future.set(sslContext); } @@ -391,7 +464,7 @@ public void updateSslContext(SslContext sslContext, String sni) { protected void onException(Throwable throwable) { future.set(throwable); } - }); + }, null); executor.runDueTasks(); channel.runPendingTasks(); // need this for tasks to execute on eventLoop Object fromFuture = future.get(5, TimeUnit.SECONDS); @@ -416,7 +489,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_handleHandlerRemoved() { CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, false); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, true); SslContextProviderSupplier sslContextProviderSupplier = new SslContextProviderSupplier(upstreamTlsContext, diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index ed4e25963aa..dc16cf82298 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -47,7 +47,8 @@ public class SslContextProviderSupplierTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - private static final String HOSTNAME = "hostname"; + private static final String ENDPOINT_HOSTNAME_FROM_ATTR = "endpoint-hostname-from-attribute"; + private static final String SNI_IN_UTC = "sni-in-upstream-tls-context"; @Mock private TlsContextManager mockTlsContextManager; private SslContextProviderSupplier supplier; @@ -55,52 +56,151 @@ public class SslContextProviderSupplierTest { private EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; private SslContextProvider.Callback mockCallback; - private void prepareSupplier(boolean autoHostSni) { - upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true, null, autoHostSni); + private void prepareSupplier(boolean autoHostSni, String sniInUTC, String sniSentByClient) { + upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext( + "google_cloud_private_spiffe", true, sniInUTC, autoHostSni); mockSslContextProvider = mock(SslContextProvider.class); doReturn(mockSslContextProvider) .when(mockTlsContextManager) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(HOSTNAME)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(sniSentByClient)); supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); } - private void callUpdateSslContext() { + private void callUpdateSslContext(String endpointHostname) { mockCallback = mock(SslContextProvider.Callback.class); - when(mockCallback.getHostname()).thenReturn(HOSTNAME); + when(mockCallback.getHostname()).thenReturn(endpointHostname); Executor mockExecutor = mock(Executor.class); doReturn(mockExecutor).when(mockCallback).getExecutor(); - supplier.updateSslContext(mockCallback); + supplier.updateSslContext(mockCallback, null); } @Test public void get_updateSecret() { - prepareSupplier(false); - callUpdateSslContext(); + prepareSupplier(false, null, ""); + callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR); verify(mockTlsContextManager, times(2)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(HOSTNAME)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq("")); verify(mockTlsContextManager, times(0)) - .releaseClientSslContextProvider(any(SslContextProvider.class), eq(HOSTNAME)); + .releaseClientSslContextProvider(any(SslContextProvider.class), eq("")); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSslContext(mockSslContext, HOSTNAME); - verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext), HOSTNAME); + capturedCallback.updateSslContext(mockSslContext); + verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(HOSTNAME)); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq("")); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); - supplier.updateSslContext(mockCallback); + supplier.updateSslContext(mockCallback, null); verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(HOSTNAME)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq("")); + } + + @Test + public void autoHostSniFalse_usesSniFromUpstreamTlsContext() { + prepareSupplier(false, SNI_IN_UTC, SNI_IN_UTC); + callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR); + verify(mockTlsContextManager, times(2)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC)); + verify(mockTlsContextManager, times(0)) + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI_IN_UTC)); + ArgumentCaptor callbackCaptor = + ArgumentCaptor.forClass(SslContextProvider.Callback.class); + verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); + SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); + assertThat(capturedCallback).isNotNull(); + SslContext mockSslContext = mock(SslContext.class); + capturedCallback.updateSslContext(mockSslContext); + verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI_IN_UTC)); + SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); + supplier.updateSslContext(mockCallback, null); + verify(mockTlsContextManager, times(3)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC)); + } + + @Test + public void autoHostSniTrue_usesSniFromEndpointHostname() { + prepareSupplier(true, SNI_IN_UTC, ENDPOINT_HOSTNAME_FROM_ATTR); + callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR); + verify(mockTlsContextManager, times(2)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(ENDPOINT_HOSTNAME_FROM_ATTR)); + verify(mockTlsContextManager, times(0)) + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(ENDPOINT_HOSTNAME_FROM_ATTR)); + ArgumentCaptor callbackCaptor = + ArgumentCaptor.forClass(SslContextProvider.Callback.class); + verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); + SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); + assertThat(capturedCallback).isNotNull(); + SslContext mockSslContext = mock(SslContext.class); + capturedCallback.updateSslContext(mockSslContext); + verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(ENDPOINT_HOSTNAME_FROM_ATTR)); + SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); + when(mockCallback.getHostname()).thenReturn(ENDPOINT_HOSTNAME_FROM_ATTR); + supplier.updateSslContext(mockCallback, null); + verify(mockTlsContextManager, times(3)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(ENDPOINT_HOSTNAME_FROM_ATTR)); + } + + @Test + public void autoHostSniTrue_endpointHostNameIsNull_usesSniFromUpstreamTlsContext() { + prepareSupplier(true, SNI_IN_UTC, SNI_IN_UTC); + callUpdateSslContext(null); + verify(mockTlsContextManager, times(2)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC)); + verify(mockTlsContextManager, times(0)) + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI_IN_UTC)); + ArgumentCaptor callbackCaptor = + ArgumentCaptor.forClass(SslContextProvider.Callback.class); + verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); + SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); + assertThat(capturedCallback).isNotNull(); + SslContext mockSslContext = mock(SslContext.class); + capturedCallback.updateSslContext(mockSslContext); + verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI_IN_UTC)); + SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); + when(mockCallback.getHostname()).thenReturn(null); + supplier.updateSslContext(mockCallback, null); + verify(mockTlsContextManager, times(3)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC)); + } + + @Test + public void autoHostSniTrue_endpointHostNameIsEmpty_usesSniFromUpstreamTlsContext() { + prepareSupplier(true, SNI_IN_UTC, SNI_IN_UTC); + callUpdateSslContext(""); + verify(mockTlsContextManager, times(2)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC)); + verify(mockTlsContextManager, times(0)) + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI_IN_UTC)); + ArgumentCaptor callbackCaptor = + ArgumentCaptor.forClass(SslContextProvider.Callback.class); + verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); + SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); + assertThat(capturedCallback).isNotNull(); + SslContext mockSslContext = mock(SslContext.class); + capturedCallback.updateSslContext(mockSslContext); + verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI_IN_UTC)); + SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); + when(mockCallback.getHostname()).thenReturn(""); + supplier.updateSslContext(mockCallback, null); + verify(mockTlsContextManager, times(3)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC)); } @Test public void get_onException() { - prepareSupplier(false); - callUpdateSslContext(); + prepareSupplier(false, null, ""); + callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); @@ -110,33 +210,33 @@ public void get_onException() { capturedCallback.onException(exception); verify(mockCallback, times(1)).onException(eq(exception)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(HOSTNAME)); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq("")); } @Test public void testClose() { - prepareSupplier(false); - callUpdateSslContext(); + prepareSupplier(false, null, ""); + callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR); supplier.close(); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(HOSTNAME)); - supplier.updateSslContext(mockCallback); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq("")); + supplier.updateSslContext(mockCallback, null); verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(HOSTNAME)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq("")); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(any(SslContextProvider.class), eq(HOSTNAME)); + .releaseClientSslContextProvider(any(SslContextProvider.class), eq("")); } @Test public void testClose_nullSslContextProvider() { - prepareSupplier(false); + prepareSupplier(false, null, ""); doThrow(new NullPointerException()).when(mockTlsContextManager) - .releaseClientSslContextProvider(null, HOSTNAME); + .releaseClientSslContextProvider(null, ""); supplier.close(); verify(mockTlsContextManager, never()) - .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(HOSTNAME)); - callUpdateSslContext(); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq("")); + callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR); verify(mockTlsContextManager, times(1)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(HOSTNAME)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq("")); } } From a3710655684d0963dba5a704500e6dfe748516fb Mon Sep 17 00:00:00 2001 From: Kannan J Date: Mon, 1 Sep 2025 11:26:18 +0000 Subject: [PATCH 09/51] Save changes. --- .../internal/security/SslContextProvider.java | 10 -- .../security/SslContextProviderSupplier.java | 6 + .../SslContextProviderSupplierTest.java | 146 ++++-------------- 3 files changed, 40 insertions(+), 122 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java index 1b28e223b5b..fcd3a899c8d 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java @@ -54,20 +54,10 @@ protected Callback(Executor executor) { this.hostname = null; } - // Only for client SslContextProvider. - protected Callback(Executor executor, String hostname) { - this.executor = executor; - this.hostname = hostname; - } - @VisibleForTesting public Executor getExecutor() { return executor; } - protected String getHostname() { - return hostname; - } - /** Informs callee of new/updated SslContext. */ @VisibleForTesting public abstract void updateSslContext(SslContext sslContext); diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index b1a8ebb8404..052e08d4b84 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -60,6 +60,12 @@ public BaseTlsContext getTlsContext() { public synchronized void updateSslContext(final SslContextProvider.Callback callback, String sni) { checkNotNull(callback, "callback"); try { + if (!shutdown) { + if (sslContextProvider == null) { + sslContextProvider = getSslContextProvider(sni); + } + } + // we want to increment the ref-count so call findOrCreate again... final SslContextProvider toRelease = getSslContextProvider(sni); toRelease.addCallback( diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index dc16cf82298..5ac94648605 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -25,7 +25,6 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.TlsContextManager; @@ -47,8 +46,7 @@ public class SslContextProviderSupplierTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - private static final String ENDPOINT_HOSTNAME_FROM_ATTR = "endpoint-hostname-from-attribute"; - private static final String SNI_IN_UTC = "sni-in-upstream-tls-context"; + private static final String SNI = "sni"; @Mock private TlsContextManager mockTlsContextManager; private SslContextProviderSupplier supplier; @@ -56,32 +54,31 @@ public class SslContextProviderSupplierTest { private EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; private SslContextProvider.Callback mockCallback; - private void prepareSupplier(boolean autoHostSni, String sniInUTC, String sniSentByClient) { + private void prepareSupplier() { upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext( - "google_cloud_private_spiffe", true, sniInUTC, autoHostSni); + "google_cloud_private_spiffe", true, SNI, false); mockSslContextProvider = mock(SslContextProvider.class); doReturn(mockSslContextProvider) .when(mockTlsContextManager) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(sniSentByClient)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); } - private void callUpdateSslContext(String endpointHostname) { + private void callUpdateSslContext() { mockCallback = mock(SslContextProvider.Callback.class); - when(mockCallback.getHostname()).thenReturn(endpointHostname); Executor mockExecutor = mock(Executor.class); doReturn(mockExecutor).when(mockCallback).getExecutor(); - supplier.updateSslContext(mockCallback, null); + supplier.updateSslContext(mockCallback, SNI); } @Test public void get_updateSecret() { - prepareSupplier(false, null, ""); - callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR); + prepareSupplier(); + callUpdateSslContext(); verify(mockTlsContextManager, times(2)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq("")); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); verify(mockTlsContextManager, times(0)) - .releaseClientSslContextProvider(any(SslContextProvider.class), eq("")); + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI)); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); @@ -91,21 +88,21 @@ public void get_updateSecret() { capturedCallback.updateSslContext(mockSslContext); verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider), eq("")); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); - supplier.updateSslContext(mockCallback, null); + supplier.updateSslContext(mockCallback, SNI); verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq("")); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); } @Test public void autoHostSniFalse_usesSniFromUpstreamTlsContext() { - prepareSupplier(false, SNI_IN_UTC, SNI_IN_UTC); - callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR); + prepareSupplier(); + callUpdateSslContext(); verify(mockTlsContextManager, times(2)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); verify(mockTlsContextManager, times(0)) - .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI_IN_UTC)); + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI)); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); @@ -115,92 +112,17 @@ public void autoHostSniFalse_usesSniFromUpstreamTlsContext() { capturedCallback.updateSslContext(mockSslContext); verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI_IN_UTC)); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); - supplier.updateSslContext(mockCallback, null); + supplier.updateSslContext(mockCallback, SNI); verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC)); - } - - @Test - public void autoHostSniTrue_usesSniFromEndpointHostname() { - prepareSupplier(true, SNI_IN_UTC, ENDPOINT_HOSTNAME_FROM_ATTR); - callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR); - verify(mockTlsContextManager, times(2)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(ENDPOINT_HOSTNAME_FROM_ATTR)); - verify(mockTlsContextManager, times(0)) - .releaseClientSslContextProvider(any(SslContextProvider.class), eq(ENDPOINT_HOSTNAME_FROM_ATTR)); - ArgumentCaptor callbackCaptor = - ArgumentCaptor.forClass(SslContextProvider.Callback.class); - verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); - SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); - assertThat(capturedCallback).isNotNull(); - SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSslContext(mockSslContext); - verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); - verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(ENDPOINT_HOSTNAME_FROM_ATTR)); - SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); - when(mockCallback.getHostname()).thenReturn(ENDPOINT_HOSTNAME_FROM_ATTR); - supplier.updateSslContext(mockCallback, null); - verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(ENDPOINT_HOSTNAME_FROM_ATTR)); - } - - @Test - public void autoHostSniTrue_endpointHostNameIsNull_usesSniFromUpstreamTlsContext() { - prepareSupplier(true, SNI_IN_UTC, SNI_IN_UTC); - callUpdateSslContext(null); - verify(mockTlsContextManager, times(2)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC)); - verify(mockTlsContextManager, times(0)) - .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI_IN_UTC)); - ArgumentCaptor callbackCaptor = - ArgumentCaptor.forClass(SslContextProvider.Callback.class); - verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); - SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); - assertThat(capturedCallback).isNotNull(); - SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSslContext(mockSslContext); - verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); - verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI_IN_UTC)); - SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); - when(mockCallback.getHostname()).thenReturn(null); - supplier.updateSslContext(mockCallback, null); - verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC)); - } - - @Test - public void autoHostSniTrue_endpointHostNameIsEmpty_usesSniFromUpstreamTlsContext() { - prepareSupplier(true, SNI_IN_UTC, SNI_IN_UTC); - callUpdateSslContext(""); - verify(mockTlsContextManager, times(2)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC)); - verify(mockTlsContextManager, times(0)) - .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI_IN_UTC)); - ArgumentCaptor callbackCaptor = - ArgumentCaptor.forClass(SslContextProvider.Callback.class); - verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); - SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); - assertThat(capturedCallback).isNotNull(); - SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSslContext(mockSslContext); - verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); - verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI_IN_UTC)); - SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); - when(mockCallback.getHostname()).thenReturn(""); - supplier.updateSslContext(mockCallback, null); - verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI_IN_UTC)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); } @Test public void get_onException() { - prepareSupplier(false, null, ""); - callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR); + prepareSupplier(); + callUpdateSslContext(); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); @@ -210,33 +132,33 @@ public void get_onException() { capturedCallback.onException(exception); verify(mockCallback, times(1)).onException(eq(exception)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider), eq("")); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); } @Test public void testClose() { - prepareSupplier(false, null, ""); - callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR); + prepareSupplier(); + callUpdateSslContext(); supplier.close(); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider), eq("")); - supplier.updateSslContext(mockCallback, null); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); + supplier.updateSslContext(mockCallback, SNI); verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq("")); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(any(SslContextProvider.class), eq("")); + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI)); } @Test public void testClose_nullSslContextProvider() { - prepareSupplier(false, null, ""); + prepareSupplier(); doThrow(new NullPointerException()).when(mockTlsContextManager) - .releaseClientSslContextProvider(null, ""); + .releaseClientSslContextProvider(null, SNI); supplier.close(); verify(mockTlsContextManager, never()) - .releaseClientSslContextProvider(eq(mockSslContextProvider), eq("")); - callUpdateSslContext(ENDPOINT_HOSTNAME_FROM_ATTR); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); + callUpdateSslContext(); verify(mockTlsContextManager, times(1)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq("")); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); } } From a6f1bc9bac35ce4fe3c4c5a2f7d46298fa7da616 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Tue, 2 Sep 2025 11:23:05 +0000 Subject: [PATCH 10/51] XdsX509TrustManager changes for auto sni san validation. --- .../io/grpc/xds/EnvoyServerProtoData.java | 7 + .../CertProviderClientSslContextProvider.java | 23 ++- .../CertProviderServerSslContextProvider.java | 4 +- .../trust/XdsTrustManagerFactory.java | 31 ++-- .../security/trust/XdsX509TrustManager.java | 19 +- .../trust/XdsTrustManagerFactoryTest.java | 14 +- .../trust/XdsX509TrustManagerTest.java | 174 ++++++++++++++---- 7 files changed, 188 insertions(+), 84 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index ed2b841e983..b3d161b6bbb 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -75,12 +75,14 @@ public static final class UpstreamTlsContext extends BaseTlsContext { private final String sni; private final boolean auto_host_sni; + private final boolean auto_sni_san_validation; @VisibleForTesting public UpstreamTlsContext(CommonTlsContext commonTlsContext) { super(commonTlsContext); this.sni = null; this.auto_host_sni = false; + this.auto_sni_san_validation = false; } @VisibleForTesting @@ -88,6 +90,7 @@ public UpstreamTlsContext(io.envoyproxy.envoy.extensions.transport_sockets.tls.v super(upstreamTlsContext.getCommonTlsContext()); this.sni = upstreamTlsContext.getSni(); this.auto_host_sni = upstreamTlsContext.getAutoHostSni(); + this.auto_sni_san_validation = upstreamTlsContext.getAutoSniSanValidation(); } public static UpstreamTlsContext fromEnvoyProtoUpstreamTlsContext( @@ -104,6 +107,10 @@ public boolean getAutoHostSni() { return auto_host_sni; } + public boolean getAutoSniSanValidation() { + return auto_sni_san_validation; + } + @Override public String toString() { return "UpstreamTlsContext{" + "commonTlsContext=" + commonTlsContext + '}'; diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index b25fddce65c..54d128ec89a 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -32,14 +32,16 @@ /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ public final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { + private final String sniForSanMatching; + CertProviderClientSslContextProvider( - Node node, - @Nullable Map certProviders, - CommonTlsContext.CertificateProviderInstance certInstance, - CommonTlsContext.CertificateProviderInstance rootCertInstance, - CertificateValidationContext staticCertValidationContext, - UpstreamTlsContext upstreamTlsContext, - String sni, CertificateProviderStore certificateProviderStore) { + Node node, + @Nullable Map certProviders, + CommonTlsContext.CertificateProviderInstance certInstance, + CommonTlsContext.CertificateProviderInstance rootCertInstance, + CertificateValidationContext staticCertValidationContext, + UpstreamTlsContext upstreamTlsContext, + String sniForSanMatching, CertificateProviderStore certificateProviderStore) { super( node, certProviders, @@ -48,11 +50,12 @@ public final class CertProviderClientSslContextProvider extends CertProviderSslC staticCertValidationContext, upstreamTlsContext, certificateProviderStore); + this.sniForSanMatching = upstreamTlsContext.getAutoSniSanValidation()? sniForSanMatching : null; } @Override protected final SslContextBuilder getSslContextBuilder( - CertificateValidationContext certificateValidationContextdationContext) + CertificateValidationContext certificateValidationContext) throws CertStoreException { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); // Null rootCertInstance implies hasSystemRootCerts because of the check in @@ -62,12 +65,12 @@ protected final SslContextBuilder getSslContextBuilder( sslContextBuilder = sslContextBuilder.trustManager( new XdsTrustManagerFactory( savedSpiffeTrustMap, - certificateValidationContextdationContext)); + certificateValidationContext, sniForSanMatching)); } else { sslContextBuilder = sslContextBuilder.trustManager( new XdsTrustManagerFactory( savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContextdationContext)); + certificateValidationContext, sniForSanMatching)); } } if (isMtls()) { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java index ef65bbfb6f9..2488fcb1199 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java @@ -63,11 +63,11 @@ protected final SslContextBuilder getSslContextBuilder( if (isMtls() && savedSpiffeTrustMap != null) { trustManagerFactory = new XdsTrustManagerFactory( savedSpiffeTrustMap, - certificateValidationContextdationContext); + certificateValidationContextdationContext, null); } else if (isMtls()) { trustManagerFactory = new XdsTrustManagerFactory( savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContextdationContext); + certificateValidationContextdationContext, null); } setClientAuthValues(sslContextBuilder, trustManagerFactory); sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder); diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java index 8cb44117065..3ba31d8ff2b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java @@ -58,43 +58,46 @@ public XdsTrustManagerFactory(CertificateValidationContext certificateValidation this( getTrustedCaFromCertContext(certificateValidationContext), certificateValidationContext, - false); + false, + null); } public XdsTrustManagerFactory( - X509Certificate[] certs, CertificateValidationContext staticCertificateValidationContext) + X509Certificate[] certs, CertificateValidationContext staticCertificateValidationContext, String sniForSanMatching) throws CertStoreException { - this(certs, staticCertificateValidationContext, true); + this(certs, staticCertificateValidationContext, true, sniForSanMatching); } public XdsTrustManagerFactory(Map> spiffeTrustMap, - CertificateValidationContext staticCertificateValidationContext) throws CertStoreException { - this(spiffeTrustMap, staticCertificateValidationContext, true); + CertificateValidationContext staticCertificateValidationContext, String sniForSanMatching) throws CertStoreException { + this(spiffeTrustMap, staticCertificateValidationContext, true, sniForSanMatching); } private XdsTrustManagerFactory( X509Certificate[] certs, CertificateValidationContext certificateValidationContext, - boolean validationContextIsStatic) + boolean validationContextIsStatic, + String sniForSanMatching) throws CertStoreException { if (validationContextIsStatic) { checkArgument( certificateValidationContext == null || !certificateValidationContext.hasTrustedCa(), "only static certificateValidationContext expected"); } - xdsX509TrustManager = createX509TrustManager(certs, certificateValidationContext); + xdsX509TrustManager = createX509TrustManager(certs, certificateValidationContext, sniForSanMatching); } private XdsTrustManagerFactory( Map> spiffeTrustMap, CertificateValidationContext certificateValidationContext, - boolean validationContextIsStatic) + boolean validationContextIsStatic, + String sniForSanMatching) throws CertStoreException { if (validationContextIsStatic) { checkArgument( certificateValidationContext == null || !certificateValidationContext.hasTrustedCa(), "only static certificateValidationContext expected"); - xdsX509TrustManager = createX509TrustManager(spiffeTrustMap, certificateValidationContext); + xdsX509TrustManager = createX509TrustManager(spiffeTrustMap, certificateValidationContext, sniForSanMatching); } } @@ -121,21 +124,21 @@ private static X509Certificate[] getTrustedCaFromCertContext( @VisibleForTesting static XdsX509TrustManager createX509TrustManager( - X509Certificate[] certs, CertificateValidationContext certContext) throws CertStoreException { - return new XdsX509TrustManager(certContext, createTrustManager(certs)); + X509Certificate[] certs, CertificateValidationContext certContext, String sniForSanMatching) throws CertStoreException { + return new XdsX509TrustManager(certContext, createTrustManager(certs), sniForSanMatching); } @VisibleForTesting static XdsX509TrustManager createX509TrustManager( - Map> spiffeTrustMapFile, - CertificateValidationContext certContext) throws CertStoreException { + Map> spiffeTrustMapFile, + CertificateValidationContext certContext, String sniForSanMatching) throws CertStoreException { checkNotNull(spiffeTrustMapFile, "spiffeTrustMapFile"); Map delegates = new HashMap<>(); for (Map.Entry> entry:spiffeTrustMapFile.entrySet()) { delegates.put(entry.getKey(), createTrustManager( entry.getValue().toArray(new X509Certificate[0]))); } - return new XdsX509TrustManager(certContext, delegates); + return new XdsX509TrustManager(certContext, delegates, sniForSanMatching); } private static X509ExtendedTrustManager createTrustManager(X509Certificate[] certs) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java index 390c131aa92..b32f0821b60 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java @@ -61,7 +61,7 @@ final class XdsX509TrustManager extends X509ExtendedTrustManager implements X509 private final X509ExtendedTrustManager delegate; private final Map spiffeTrustMapDelegates; private final CertificateValidationContext certContext; - private final String sni; + private final String sniForSanMatching; XdsX509TrustManager(@Nullable CertificateValidationContext certContext, X509ExtendedTrustManager delegate) { @@ -69,26 +69,21 @@ final class XdsX509TrustManager extends X509ExtendedTrustManager implements X509 } XdsX509TrustManager(@Nullable CertificateValidationContext certContext, - X509ExtendedTrustManager delegate, @Nullable String sni) { + X509ExtendedTrustManager delegate, @Nullable String sniForSanMatching) { checkNotNull(delegate, "delegate"); this.certContext = certContext; this.delegate = delegate; this.spiffeTrustMapDelegates = null; - this.sni = sni; + this.sniForSanMatching = sniForSanMatching; } XdsX509TrustManager(@Nullable CertificateValidationContext certContext, - Map spiffeTrustMapDelegates) { - this(certContext, spiffeTrustMapDelegates, null); - } - - XdsX509TrustManager(@Nullable CertificateValidationContext certContext, - Map spiffeTrustMapDelegates, @Nullable String sni) { + Map spiffeTrustMapDelegates, @Nullable String sniForSanMatching) { checkNotNull(spiffeTrustMapDelegates, "spiffeTrustMapDelegates"); this.spiffeTrustMapDelegates = ImmutableMap.copyOf(spiffeTrustMapDelegates); this.certContext = certContext; this.delegate = null; - this.sni = sni; + this.sniForSanMatching = sniForSanMatching; } private static boolean verifyDnsNameInPattern( @@ -222,7 +217,9 @@ void verifySubjectAltNameInChain(X509Certificate[] peerCertChain) throws Certifi return; } @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names - List verifyList = sni != null? ImmutableList.of(StringMatcher.newBuilder().setExact(sni).build()) : certContext.getMatchSubjectAltNamesList(); + List verifyList = !Strings.isNullOrEmpty(sniForSanMatching) + ? ImmutableList.of(StringMatcher.newBuilder().setExact(sniForSanMatching).build()) + : certContext.getMatchSubjectAltNamesList(); if (verifyList.isEmpty()) { return; } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java index db83961cfc3..36e75327419 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java @@ -91,7 +91,7 @@ public void constructor_fromRootCert() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, null); assertThat(factory).isNotNull(); TrustManager[] tms = factory.getTrustManagers(); assertThat(tms).isNotNull(); @@ -115,7 +115,7 @@ public void constructor_fromSpiffeTrustMap() "san2"); // Single domain and single cert XdsTrustManagerFactory factory = new XdsTrustManagerFactory(ImmutableMap - .of("example.com", ImmutableList.of(x509Cert)), staticValidationContext); + .of("example.com", ImmutableList.of(x509Cert)), staticValidationContext, null); assertThat(factory).isNotNull(); TrustManager[] tms = factory.getTrustManagers(); assertThat(tms).isNotNull(); @@ -131,7 +131,7 @@ public void constructor_fromSpiffeTrustMap() X509Certificate anotherCert = TestUtils.loadX509Cert(CLIENT_PEM_FILE); factory = new XdsTrustManagerFactory(ImmutableMap .of("example.com", ImmutableList.of(x509Cert), - "google.com", ImmutableList.of(x509Cert, anotherCert)), staticValidationContext); + "google.com", ImmutableList.of(x509Cert, anotherCert)), staticValidationContext, null); assertThat(factory).isNotNull(); tms = factory.getTrustManagers(); assertThat(tms).isNotNull(); @@ -154,7 +154,7 @@ public void constructorRootCert_checkServerTrusted() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "waterzooi.test.google.be"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, null); XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] serverChain = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); @@ -167,7 +167,7 @@ public void constructorRootCert_nonStaticContext_throwsException() X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); try { new XdsTrustManagerFactory( - new X509Certificate[] {x509Cert}, getCertContextFromPath(CA_PEM_FILE)); + new X509Certificate[] {x509Cert}, getCertContextFromPath(CA_PEM_FILE), null); Assert.fail("no exception thrown"); } catch (IllegalArgumentException expected) { assertThat(expected) @@ -183,7 +183,7 @@ public void constructorRootCert_checkServerTrusted_throwsException() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, null); XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] serverChain = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); @@ -204,7 +204,7 @@ public void constructorRootCert_checkClientTrusted_throwsException() CertificateValidationContext staticValidationContext = buildStaticValidationContext("san1", "san2"); XdsTrustManagerFactory factory = - new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext); + new XdsTrustManagerFactory(new X509Certificate[]{x509Cert}, staticValidationContext, null); XdsX509TrustManager xdsX509TrustManager = (XdsX509TrustManager) factory.getTrustManagers()[0]; X509Certificate[] clientChain = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java index 6fa3d2e7d24..40232ed9425 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java @@ -42,6 +42,7 @@ import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.Arrays; +import java.util.Collection; import java.util.Collections; import java.util.List; import javax.net.ssl.SSLEngine; @@ -53,6 +54,7 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import org.junit.runners.Parameterized; import org.mockito.Mock; import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; @@ -60,7 +62,7 @@ /** * Unit tests for {@link XdsX509TrustManager}. */ -@RunWith(JUnit4.class) +@RunWith(Parameterized.class) public class XdsX509TrustManagerTest { @Rule @@ -73,6 +75,18 @@ public class XdsX509TrustManagerTest { private SSLSession mockSession; private XdsX509TrustManager trustManager; + private boolean useSniForSanMatching; + + @Parameterized.Parameters + public static Collection data() { + return Arrays.asList(new Object[][] { + { true }, { false } + }); + } + + public XdsX509TrustManagerTest(boolean useSniForSanMatching) { + this.useSniForSanMatching = useSniForSanMatching; + } @Test public void nullCertContextTest() throws CertificateException, IOException { @@ -93,11 +107,16 @@ public void emptySanListContextTest() throws CertificateException, IOException { @Test public void missingPeerCerts() { - StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); - @SuppressWarnings("deprecation") - CertificateValidationContext certContext = - CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + if (useSniForSanMatching) { + trustManager = new XdsX509TrustManager( + CertificateValidationContext.getDefaultInstance(), mockDelegate, "foo.com"); + } else { + StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); + } try { trustManager.verifySubjectAltNameInChain(null); fail("no exception thrown"); @@ -108,11 +127,16 @@ public void missingPeerCerts() { @Test public void emptyArrayPeerCerts() { - StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); - @SuppressWarnings("deprecation") - CertificateValidationContext certContext = - CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + if (useSniForSanMatching) { + trustManager = new XdsX509TrustManager( + CertificateValidationContext.getDefaultInstance(), mockDelegate, "foo.com"); + } else { + StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); + } try { trustManager.verifySubjectAltNameInChain(new X509Certificate[0]); fail("no exception thrown"); @@ -123,11 +147,16 @@ public void emptyArrayPeerCerts() { @Test public void noSansInPeerCerts() throws CertificateException, IOException { - StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); - @SuppressWarnings("deprecation") - CertificateValidationContext certContext = - CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + if (useSniForSanMatching) { + trustManager = new XdsX509TrustManager( + CertificateValidationContext.getDefaultInstance(), mockDelegate, "foo.com"); + } else { + StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("foo.com").build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); + } X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(CLIENT_PEM_FILE)); try { @@ -140,17 +169,54 @@ public void noSansInPeerCerts() throws CertificateException, IOException { @Test public void oneSanInPeerCertsVerifies() throws CertificateException, IOException { + if (useSniForSanMatching) { + trustManager = new XdsX509TrustManager( + CertificateValidationContext.getDefaultInstance(), mockDelegate, "waterzooi.test.google.be"); + } else { + StringMatcher stringMatcher = + StringMatcher.newBuilder() + .setExact("waterzooi.test.google.be") + .setIgnoreCase(false) + .build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); + } + X509Certificate[] certs = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + trustManager.verifySubjectAltNameInChain(certs); + } + + @Test + public void autoSanSniValidation_precedes_subAltNamesToMatch() throws CertificateException, IOException { StringMatcher stringMatcher = - StringMatcher.newBuilder() - .setExact("waterzooi.test.google.be") - .setIgnoreCase(false) - .build(); + StringMatcher.newBuilder() + .setExact("notgonnabeused.test.google.be") + .setIgnoreCase(false) + .build(); @SuppressWarnings("deprecation") CertificateValidationContext certContext = - CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, "waterzooi.test.google.be"); X509Certificate[] certs = - CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + trustManager.verifySubjectAltNameInChain(certs); + } + + @Test + public void emptySni_noAutoSanSniValidation() throws CertificateException, IOException { + StringMatcher stringMatcher = + StringMatcher.newBuilder() + .setExact("waterzooi.test.google.be") + .setIgnoreCase(false) + .build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, ""); + X509Certificate[] certs = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); trustManager.verifySubjectAltNameInChain(certs); } @@ -420,11 +486,16 @@ public void oneSanInPeerCertsVerifiesMultipleVerifySans() @Test public void oneSanInPeerCertsNotFoundException() throws CertificateException, IOException { - StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); - @SuppressWarnings("deprecation") - CertificateValidationContext certContext = - CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + if (useSniForSanMatching) { + trustManager = new XdsX509TrustManager(CertificateValidationContext.getDefaultInstance(), mockDelegate, + "x.foo.com"); + } else { + StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); + } X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { @@ -477,12 +548,17 @@ public void wildcardSanInPeerCertsSubdomainMismatch() // 2. Asterisk (*) cannot match across domain name labels. // For example, *.example.com matches test.example.com but does not match // sub.test.example.com. - StringMatcher stringMatcher = - StringMatcher.newBuilder().setExact("sub.abc.test.youtube.com").build(); - @SuppressWarnings("deprecation") - CertificateValidationContext certContext = - CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate); + if (useSniForSanMatching) { + trustManager = new XdsX509TrustManager(CertificateValidationContext.getDefaultInstance(), + mockDelegate, "sub.abc.test.youtube.com"); + } else { + StringMatcher stringMatcher = + StringMatcher.newBuilder().setExact("sub.abc.test.youtube.com").build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate); + } X509Certificate[] certs = CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); try { @@ -509,6 +585,15 @@ public void oneIpAddressInPeerCertsVerifies() throws CertificateException, IOExc trustManager.verifySubjectAltNameInChain(certs); } + @Test + public void oneIpAddressInPeerCertsVerifies_autoSniSanValidation() throws CertificateException, IOException { + trustManager = new XdsX509TrustManager(CertificateValidationContext.getDefaultInstance(), mockDelegate, + "192.168.1.3"); + X509Certificate[] certs = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + trustManager.verifySubjectAltNameInChain(certs); + } + @Test public void oneIpAddressInPeerCertsMismatch() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder().setExact("x.foo.com").build(); @@ -530,6 +615,15 @@ public void oneIpAddressInPeerCertsMismatch() throws CertificateException, IOExc } } + @Test + public void oneIpAddressInPeerCertsMismatch_autoSniSanValidation() throws CertificateException, IOException { + trustManager = new XdsX509TrustManager(CertificateValidationContext.getDefaultInstance(), mockDelegate, + "192.168.1.3"); + X509Certificate[] certs = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + trustManager.verifySubjectAltNameInChain(certs); + } + @Test public void checkServerTrustedSslEngine() throws CertificateException, IOException, CertStoreException { @@ -550,7 +644,7 @@ public void checkServerTrustedSslEngineSpiffeTrustMap() List caCerts = Arrays.asList(CertificateUtils .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); trustManager = XdsTrustManagerFactory.createX509TrustManager( - ImmutableMap.of("example.com", caCerts), null); + ImmutableMap.of("example.com", caCerts), null, null); trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine); verify(sslEngine, times(1)).getHandshakeSession(); assertThat(sslEngine.getSSLParameters().getEndpointIdentificationAlgorithm()).isEmpty(); @@ -565,7 +659,7 @@ public void checkServerTrustedSslEngineSpiffeTrustMap_missing_spiffe_id() List caCerts = Arrays.asList(CertificateUtils .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); trustManager = XdsTrustManagerFactory.createX509TrustManager( - ImmutableMap.of("example.com", caCerts), null); + ImmutableMap.of("example.com", caCerts), null, null); try { trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine); fail("exception expected"); @@ -584,7 +678,7 @@ public void checkServerTrustedSpiffeSslEngineTrustMap_missing_trust_domain() List caCerts = Arrays.asList(CertificateUtils .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); trustManager = XdsTrustManagerFactory.createX509TrustManager( - ImmutableMap.of("unknown.com", caCerts), null); + ImmutableMap.of("unknown.com", caCerts), null, null); try { trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslEngine); fail("exception expected"); @@ -602,7 +696,7 @@ public void checkClientTrustedSpiffeTrustMap() List caCerts = Arrays.asList(CertificateUtils .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); trustManager = XdsTrustManagerFactory.createX509TrustManager( - ImmutableMap.of("foo.bar.com", caCerts), null); + ImmutableMap.of("foo.bar.com", caCerts), null, null); trustManager.checkClientTrusted(clientCerts, "RSA"); } @@ -643,7 +737,7 @@ public void checkServerTrustedSslSocketSpiffeTrustMap() List caCerts = Arrays.asList(CertificateUtils .toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE))); trustManager = XdsTrustManagerFactory.createX509TrustManager( - ImmutableMap.of("example.com", caCerts), null); + ImmutableMap.of("example.com", caCerts), null, null); trustManager.checkServerTrusted(serverCerts, "ECDHE_ECDSA", sslSocket); verify(sslSocket, times(1)).isConnected(); verify(sslSocket, times(1)).getHandshakeSession(); @@ -717,7 +811,7 @@ private SSLParameters buildTrustManagerAndGetSslParameters() X509Certificate[] caCerts = CertificateUtils.toX509Certificates(TlsTesting.loadCert(CA_PEM_FILE)); trustManager = XdsTrustManagerFactory.createX509TrustManager(caCerts, - null); + null, null); when(mockSession.getProtocol()).thenReturn("TLSv1.2"); when(mockSession.getPeerHost()).thenReturn("peer-host-from-mock"); SSLParameters sslParams = new SSLParameters(); From a576df08f07ef8f6155b23cb399e90b3fd7b5f99 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Tue, 2 Sep 2025 11:48:04 +0000 Subject: [PATCH 11/51] Fallback flag when no sni is available to send to specify to use xds channel authority itself. --- .../security/SecurityProtocolNegotiators.java | 9 ++++++- .../SecurityProtocolNegotiatorsTest.java | 26 ++++++++++++++++++- 2 files changed, 33 insertions(+), 2 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index db7065807e4..f9cc329e541 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -54,6 +54,9 @@ @VisibleForTesting public final class SecurityProtocolNegotiators { + static boolean useChannelAuthorityIfNoSniApplicable = + GrpcUtil.getFlag("GRPC_USE_CHANNEL_AUTHORITY_IF_NO_SNI_APPLICABLE", false); + /** Name associated with individual address, if available (e.g., DNS name). */ @EquivalentAddressGroup.Attr public static final Attributes.Key ATTR_ADDRESS_NAME = @@ -216,8 +219,12 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { this.sslContextProviderSupplier = sslContextProviderSupplier; EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); UpstreamTlsContext upstreamTlsContext = ((UpstreamTlsContext) tlsContext); - sni = upstreamTlsContext.getAutoHostSni() && !Strings.isNullOrEmpty(endpointHostname) + String sniVal = upstreamTlsContext.getAutoHostSni() && !Strings.isNullOrEmpty(endpointHostname) ? endpointHostname : upstreamTlsContext.getSni(); + if (Strings.isNullOrEmpty(sniVal) && useChannelAuthorityIfNoSniApplicable) { + sniVal = grpcHandler.getAuthority(); + } + sni = sniVal; } @VisibleForTesting diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index b835b39ef46..27bbbc9d735 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -88,6 +88,7 @@ public class SecurityProtocolNegotiatorsTest { private static final String HOSTNAME = "hostname"; private static final String SNI_IN_UTC = "sni-in-upstream-tls-context"; + private static final String FAKE_AUTHORITY = "authority"; private final GrpcHttp2ConnectionHandler grpcHandler = FakeGrpcHttp2ConnectionHandler.newHandler(); @@ -269,6 +270,29 @@ public void sniInClientSecurityHandler_autoHostSniIsFalse_usesSniFromUpstreamTls assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); } + @Test + public void emptySni_useChannelAuthorityIfNoSniApplicableIsTrue_usesChannelAuthority() { + SecurityProtocolNegotiators.useChannelAuthorityIfNoSniApplicable = true; + try { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, "", false); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(FAKE_AUTHORITY); + } finally { + SecurityProtocolNegotiators.useChannelAuthorityIfNoSniApplicable = false; + } + } + @Test public void serverSecurityHandler_addLast() throws InterruptedException, TimeoutException, ExecutionException { @@ -533,7 +557,7 @@ static FakeGrpcHttp2ConnectionHandler newHandler() { @Override public String getAuthority() { - return "authority"; + return FAKE_AUTHORITY; } } } From ce1f2d0bc0320225ad62ce3602cd8401406d2f47 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Tue, 2 Sep 2025 13:42:50 +0000 Subject: [PATCH 12/51] Save changes --- .../security/SslContextProviderSupplier.java | 5 +- .../SecurityProtocolNegotiatorsTest.java | 53 +++++++++++++++++++ 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index 052e08d4b84..458c5835223 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -16,11 +16,8 @@ package io.grpc.xds.internal.security; -import static com.google.common.base.Preconditions.checkNotNull; - import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; -import com.google.common.base.Strings; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; @@ -31,6 +28,8 @@ import java.util.Objects; import java.util.Set; +import static com.google.common.base.Preconditions.checkNotNull; + /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} * and communicate it to the consumer i.e. {@link SecurityProtocolNegotiators} diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index 27bbbc9d735..0fd72e59457 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -197,6 +197,59 @@ protected void onException(Throwable throwable) { .contains("ProtocolNegotiators.ClientTlsHandler"); CommonCertProviderTestUtils.register0(); } + + @Test + public void clientSecurityHandler_addLast() + throws InterruptedException, TimeoutException, ExecutionException { + FakeClock executor = new FakeClock(); + CommonCertProviderTestUtils.register(executor); + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, false); + + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + pipeline.addLast(clientSecurityHandler); + channelHandlerCtx = pipeline.context(clientSecurityHandler); + assertNotNull(channelHandlerCtx); + + // kick off protocol negotiation. + pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + final SettableFuture future = SettableFuture.create(); + sslContextProviderSupplier + .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { + @Override + public void updateSslContext(SslContext sslContext) { + future.set(sslContext); + } + + @Override + protected void onException(Throwable throwable) { + future.set(throwable); + } + }, null); + assertThat(executor.runDueTasks()).isEqualTo(1); + channel.runPendingTasks(); + Object fromFuture = future.get(2, TimeUnit.SECONDS); + assertThat(fromFuture).isInstanceOf(SslContext.class); + channel.runPendingTasks(); + channelHandlerCtx = pipeline.context(clientSecurityHandler); + assertThat(channelHandlerCtx).isNull(); + + // pipeline should have SslHandler and ClientTlsHandler + Iterator> iterator = pipeline.iterator(); + assertThat(iterator.next().getValue()).isInstanceOf(SslHandler.class); + // ProtocolNegotiators.ClientTlsHandler.class not accessible, get canonical name + assertThat(iterator.next().getValue().getClass().getCanonicalName()) + .contains("ProtocolNegotiators.ClientTlsHandler"); + CommonCertProviderTestUtils.register0(); + } @Test public void sniInClientSecurityHandler_autoHostSniIsTrue_usesEndpointHostname() { From 5a4f758bed9919057661e62248951ad56b83322b Mon Sep 17 00:00:00 2001 From: Kannan J Date: Wed, 3 Sep 2025 13:37:56 +0000 Subject: [PATCH 13/51] Unit test for auto host sni hostname propagation to ClientSecurityHander. --- .../SecurityProtocolNegotiatorsTest.java | 81 ++++++------------- 1 file changed, 25 insertions(+), 56 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index 0fd72e59457..43d02a9feaa 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -28,9 +28,7 @@ import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.doNothing; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; +import static org.mockito.Mockito.*; import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.SettableFuture; @@ -145,6 +143,30 @@ public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class); } + @Test + public void clientSecurityProtocolNegotiatorNewHandler_autoHostSni_hostnameIsPassedToClientSecurityHandler() { + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build(), null, true); + ClientSecurityProtocolNegotiator pn = + new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext()); + GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); + ChannelLogger logger = mock(ChannelLogger.class); + doNothing().when(logger).log(any(ChannelLogLevel.class), anyString()); + when(mockHandler.getNegotiationLogger()).thenReturn(logger); + TlsContextManager mockTlsContextManager = mock(TlsContextManager.class); + when(mockHandler.getEagAttributes()) + .thenReturn( + Attributes.newBuilder() + .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager)) + .set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, FAKE_AUTHORITY) + .build()); + ChannelHandler newHandler = pn.newHandler(mockHandler); + assertThat(newHandler).isNotNull(); + assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class); + assertThat(((ClientSecurityHandler) newHandler).getSni()).isEqualTo(FAKE_AUTHORITY); + } + @Test public void clientSecurityHandler_addLast() throws InterruptedException, TimeoutException, ExecutionException { @@ -197,59 +219,6 @@ protected void onException(Throwable throwable) { .contains("ProtocolNegotiators.ClientTlsHandler"); CommonCertProviderTestUtils.register0(); } - - @Test - public void clientSecurityHandler_addLast() - throws InterruptedException, TimeoutException, ExecutionException { - FakeClock executor = new FakeClock(); - CommonCertProviderTestUtils.register(executor); - Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils - .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null, null); - UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, false); - - SslContextProviderSupplier sslContextProviderSupplier = - new SslContextProviderSupplier(upstreamTlsContext, - new TlsContextManagerImpl(bootstrapInfoForClient)); - ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); - pipeline.addLast(clientSecurityHandler); - channelHandlerCtx = pipeline.context(clientSecurityHandler); - assertNotNull(channelHandlerCtx); - - // kick off protocol negotiation. - pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); - final SettableFuture future = SettableFuture.create(); - sslContextProviderSupplier - .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { - @Override - public void updateSslContext(SslContext sslContext) { - future.set(sslContext); - } - - @Override - protected void onException(Throwable throwable) { - future.set(throwable); - } - }, null); - assertThat(executor.runDueTasks()).isEqualTo(1); - channel.runPendingTasks(); - Object fromFuture = future.get(2, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); - channel.runPendingTasks(); - channelHandlerCtx = pipeline.context(clientSecurityHandler); - assertThat(channelHandlerCtx).isNull(); - - // pipeline should have SslHandler and ClientTlsHandler - Iterator> iterator = pipeline.iterator(); - assertThat(iterator.next().getValue()).isInstanceOf(SslHandler.class); - // ProtocolNegotiators.ClientTlsHandler.class not accessible, get canonical name - assertThat(iterator.next().getValue().getClass().getCanonicalName()) - .contains("ProtocolNegotiators.ClientTlsHandler"); - CommonCertProviderTestUtils.register0(); - } @Test public void sniInClientSecurityHandler_autoHostSniIsTrue_usesEndpointHostname() { From 40769980a0236f097a9613d30e082964a0ca57a2 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Mon, 8 Sep 2025 06:14:52 +0000 Subject: [PATCH 14/51] Save changes. --- .../CertProviderSslContextProvider.java | 3 + .../security/CommonTlsContextTestsUtil.java | 7 ++- ...tProviderClientSslContextProviderTest.java | 57 +++++++++++++++++++ 3 files changed, 64 insertions(+), 3 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index 801dabeecb7..639fe92eab6 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -138,6 +138,9 @@ public final void updateCertificate(PrivateKey key, List certCh @Override public final void updateTrustedRoots(List trustedRoots) { + if (isUsingSystemRootCerts) { + return; + } savedTrustedRoots = trustedRoots; updateSslContextWhenReady(); } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index 48814dece1d..5ad1d757114 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -229,9 +229,6 @@ private static CommonTlsContext.Builder addCertificateValidationContext( String rootInstanceName, String rootCertName, CertificateValidationContext staticCertValidationContext) { - if (staticCertValidationContext == null && rootInstanceName == null) { - return builder; - } CertificateValidationContext.Builder contextBuilder; if (staticCertValidationContext == null) { contextBuilder = CertificateValidationContext.newBuilder(); @@ -243,6 +240,10 @@ private static CommonTlsContext.Builder addCertificateValidationContext( .setInstanceName(rootInstanceName) .setCertificateName(rootCertName)); builder.setValidationContext(contextBuilder.build()); + } else { + builder.setValidationContext(contextBuilder.setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build()); } return builder.setCombinedValidationContext(CombinedCertificateValidationContext.newBuilder() .setDefaultValidationContext(contextBuilder)); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index b0800458d66..fe8b1cb9b4c 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -173,6 +173,63 @@ public void testProviderForClient_mtls() throws Exception { assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } + @Test + public void testProviderForClient_systemRootCerts() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( + "gcp_id", + null, + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + /* staticCertValidationContext= */ null); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate cert update, updates SslContext + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.getSslContext()).isNotNull(); + + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + + doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); + TestCallback testCallback1 = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // just do root cert update: trusted roots is not updated (because of system root certs config) + // and sslContext should still be the same + watcherCaptor[0].updateTrustedRoots( + ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); + + // now update id cert: sslContext should be updated i.e.different from the previous one + watcherCaptor[0].updateCertificate( + CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNotNull(); + testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); + } + @Test public void testProviderForClient_mtls_newXds() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = From 968d56414900c6cebcce92996a3e4e45f14f2c9e Mon Sep 17 00:00:00 2001 From: Kannan J Date: Mon, 8 Sep 2025 11:57:07 +0000 Subject: [PATCH 15/51] Save changes. --- .../grpc/xds/XdsSecurityClientServerTest.java | 2 +- .../ClientSslContextProviderFactoryTest.java | 10 +++---- .../security/CommonTlsContextTestsUtil.java | 22 +++++++++----- ...tProviderClientSslContextProviderTest.java | 29 ++++++++++--------- 4 files changed, 36 insertions(+), 27 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 23068d665bf..a92b2d868cf 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -563,7 +563,7 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSys CertificateValidationContext.newBuilder() .setSystemRootCerts( CertificateValidationContext.SystemRootCerts.newBuilder().build()) - .build()); + .build(), false); } return CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( "google_cloud_private_spiffe-client", "ROOT", null, diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java index 397fe01e0f5..e11efc44b49 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java @@ -74,7 +74,7 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio "gcp_id", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = @@ -105,7 +105,7 @@ public void bothPresent_expectCertProviderClientSslContextProvider() "gcp_id", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder(); builder = addFilenames(builder, "foo.pem", "foo.key", "root.pem"); @@ -135,7 +135,7 @@ public void createCertProviderClientSslContextProvider_onlyRootCert() "gcp_id", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = @@ -169,7 +169,7 @@ public void createCertProviderClientSslContextProvider_withStaticContext() "gcp_id", "root-default", /* alpnProtocols= */ null, - staticCertValidationContext); + staticCertValidationContext, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = @@ -199,7 +199,7 @@ public void createCertProviderClientSslContextProvider_2providers() "file_provider", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index 5ad1d757114..3171a43e05f 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -165,7 +165,7 @@ public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( commonInstanceName, "ROOT", null, - null); + null, false); } /** Gets a cert from contents of a resource. */ @@ -182,7 +182,8 @@ private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( String rootInstanceName, String rootCertName, Iterable alpnProtocols, - CertificateValidationContext staticCertValidationContext) { + CertificateValidationContext staticCertValidationContext, + boolean useSystemRootCerts) { CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); if (certInstanceName != null) { builder = @@ -193,7 +194,8 @@ private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( } builder = addCertificateValidationContext( - builder, rootInstanceName, rootCertName, staticCertValidationContext); + builder, rootInstanceName, rootCertName, staticCertValidationContext, + useSystemRootCerts); if (alpnProtocols != null) { builder.addAllAlpnProtocols(alpnProtocols); } @@ -228,7 +230,8 @@ private static CommonTlsContext.Builder addCertificateValidationContext( CommonTlsContext.Builder builder, String rootInstanceName, String rootCertName, - CertificateValidationContext staticCertValidationContext) { + CertificateValidationContext staticCertValidationContext, + boolean useSystemRootCerts) { CertificateValidationContext.Builder contextBuilder; if (staticCertValidationContext == null) { contextBuilder = CertificateValidationContext.newBuilder(); @@ -240,7 +243,7 @@ private static CommonTlsContext.Builder addCertificateValidationContext( .setInstanceName(rootInstanceName) .setCertificateName(rootCertName)); builder.setValidationContext(contextBuilder.build()); - } else { + } else if (useSystemRootCerts) { builder.setValidationContext(contextBuilder.setSystemRootCerts( CertificateValidationContext.SystemRootCerts.getDefaultInstance()) .build()); @@ -277,7 +280,8 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( @Nullable String rootInstanceName, @Nullable String rootCertName, Iterable alpnProtocols, - CertificateValidationContext staticCertValidationContext) { + CertificateValidationContext staticCertValidationContext, + boolean useSystemRootCerts) { return buildUpstreamTlsContext( buildCommonTlsContextForCertProviderInstance( certInstanceName, @@ -285,7 +289,8 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext)); + staticCertValidationContext, + useSystemRootCerts)); } /** Helper method to build UpstreamTlsContext for CertProvider tests. */ @@ -324,7 +329,8 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext), requireClientCert); + staticCertValidationContext, +false), requireClientCert); } /** Helper method to build DownstreamTlsContext for CertProvider tests. */ diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index fe8b1cb9b4c..d3e00f5b59d 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -68,11 +68,12 @@ public void setUp() throws Exception { /** Helper method to build CertProviderClientSslContextProvider. */ private CertProviderClientSslContextProvider getSslContextProvider( - String certInstanceName, - String rootInstanceName, - Bootstrapper.BootstrapInfo bootstrapInfo, - Iterable alpnProtocols, - CertificateValidationContext staticCertValidationContext) { + String certInstanceName, + String rootInstanceName, + Bootstrapper.BootstrapInfo bootstrapInfo, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext, + boolean useSystemRootCerts) { EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( certInstanceName, @@ -80,7 +81,8 @@ private CertProviderClientSslContextProvider getSslContextProvider( rootInstanceName, "root-default", alpnProtocols, - staticCertValidationContext); + staticCertValidationContext, + useSystemRootCerts); return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( upstreamTlsContext, @@ -122,7 +124,7 @@ public void testProviderForClient_mtls() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); @@ -185,7 +187,8 @@ public void testProviderForClient_systemRootCerts() throws Exception { null, CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, + true); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); @@ -305,7 +308,7 @@ public void testProviderForClient_queueExecutor() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); QueuedExecutor queuedExecutor = new QueuedExecutor(); TestCallback testCallback = @@ -338,7 +341,7 @@ public void testProviderForClient_tls() throws Exception { "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); @@ -375,7 +378,7 @@ public void testProviderForClient_sslContextException_onError() throws Exception "gcp_id", CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */null, - staticCertValidationContext); + staticCertValidationContext, false); TestCallback testCallback = new TestCallback(MoreExecutors.directExecutor()); provider.addCallback(testCallback); @@ -407,7 +410,7 @@ public void testProviderForClient_rootInstanceNull_and_notUsingSystemRootCerts_e /* rootInstanceName= */ null, CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null); + /* staticCertValidationContext= */ null, false); fail("exception expected"); } catch (UnsupportedOperationException expected) { assertThat(expected).hasMessageThat().contains("Unsupported configurations in " @@ -430,7 +433,7 @@ public void testProviderForClient_rootInstanceNull_but_isUsingSystemRootCerts_va CertificateValidationContext.newBuilder() .setSystemRootCerts( CertificateValidationContext.SystemRootCerts.newBuilder().build()) - .build()); + .build(), false); } static class QueuedExecutor implements Executor { From 6c1898a7d90a29a33afcba5ff7b4040d7f02d382 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Mon, 8 Sep 2025 12:29:14 +0000 Subject: [PATCH 16/51] Save changes. --- .../grpc/xds/XdsSecurityClientServerTest.java | 2 +- .../ClientSslContextProviderFactoryTest.java | 10 +++--- .../security/CommonTlsContextTestsUtil.java | 25 +++++--------- ...tProviderClientSslContextProviderTest.java | 33 +++++++++++++------ 4 files changed, 37 insertions(+), 33 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index a92b2d868cf..23068d665bf 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -563,7 +563,7 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSys CertificateValidationContext.newBuilder() .setSystemRootCerts( CertificateValidationContext.SystemRootCerts.newBuilder().build()) - .build(), false); + .build()); } return CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( "google_cloud_private_spiffe-client", "ROOT", null, diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java index e11efc44b49..397fe01e0f5 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java @@ -74,7 +74,7 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio "gcp_id", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null, false); + /* staticCertValidationContext= */ null); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = @@ -105,7 +105,7 @@ public void bothPresent_expectCertProviderClientSslContextProvider() "gcp_id", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null, false); + /* staticCertValidationContext= */ null); CommonTlsContext.Builder builder = upstreamTlsContext.getCommonTlsContext().toBuilder(); builder = addFilenames(builder, "foo.pem", "foo.key", "root.pem"); @@ -135,7 +135,7 @@ public void createCertProviderClientSslContextProvider_onlyRootCert() "gcp_id", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null, false); + /* staticCertValidationContext= */ null); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = @@ -169,7 +169,7 @@ public void createCertProviderClientSslContextProvider_withStaticContext() "gcp_id", "root-default", /* alpnProtocols= */ null, - staticCertValidationContext, false); + staticCertValidationContext); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = @@ -199,7 +199,7 @@ public void createCertProviderClientSslContextProvider_2providers() "file_provider", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null, false); + /* staticCertValidationContext= */ null); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index 3171a43e05f..e39aae56a69 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -165,7 +165,7 @@ public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( commonInstanceName, "ROOT", null, - null, false); + null); } /** Gets a cert from contents of a resource. */ @@ -182,8 +182,7 @@ private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( String rootInstanceName, String rootCertName, Iterable alpnProtocols, - CertificateValidationContext staticCertValidationContext, - boolean useSystemRootCerts) { + CertificateValidationContext staticCertValidationContext) { CommonTlsContext.Builder builder = CommonTlsContext.newBuilder(); if (certInstanceName != null) { builder = @@ -194,8 +193,7 @@ private static CommonTlsContext buildCommonTlsContextForCertProviderInstance( } builder = addCertificateValidationContext( - builder, rootInstanceName, rootCertName, staticCertValidationContext, - useSystemRootCerts); + builder, rootInstanceName, rootCertName, staticCertValidationContext); if (alpnProtocols != null) { builder.addAllAlpnProtocols(alpnProtocols); } @@ -230,8 +228,7 @@ private static CommonTlsContext.Builder addCertificateValidationContext( CommonTlsContext.Builder builder, String rootInstanceName, String rootCertName, - CertificateValidationContext staticCertValidationContext, - boolean useSystemRootCerts) { + CertificateValidationContext staticCertValidationContext) { CertificateValidationContext.Builder contextBuilder; if (staticCertValidationContext == null) { contextBuilder = CertificateValidationContext.newBuilder(); @@ -243,10 +240,6 @@ private static CommonTlsContext.Builder addCertificateValidationContext( .setInstanceName(rootInstanceName) .setCertificateName(rootCertName)); builder.setValidationContext(contextBuilder.build()); - } else if (useSystemRootCerts) { - builder.setValidationContext(contextBuilder.setSystemRootCerts( - CertificateValidationContext.SystemRootCerts.getDefaultInstance()) - .build()); } return builder.setCombinedValidationContext(CombinedCertificateValidationContext.newBuilder() .setDefaultValidationContext(contextBuilder)); @@ -280,8 +273,7 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( @Nullable String rootInstanceName, @Nullable String rootCertName, Iterable alpnProtocols, - CertificateValidationContext staticCertValidationContext, - boolean useSystemRootCerts) { + CertificateValidationContext staticCertValidationContext) { return buildUpstreamTlsContext( buildCommonTlsContextForCertProviderInstance( certInstanceName, @@ -289,8 +281,7 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext, - useSystemRootCerts)); + staticCertValidationContext)); } /** Helper method to build UpstreamTlsContext for CertProvider tests. */ @@ -329,8 +320,8 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext, -false), requireClientCert); + staticCertValidationContext), + requireClientCert); } /** Helper method to build DownstreamTlsContext for CertProvider tests. */ diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index d3e00f5b59d..ab12d2f32d9 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -74,15 +74,26 @@ private CertProviderClientSslContextProvider getSslContextProvider( Iterable alpnProtocols, CertificateValidationContext staticCertValidationContext, boolean useSystemRootCerts) { - EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( - certInstanceName, - "cert-default", - rootInstanceName, - "root-default", - alpnProtocols, - staticCertValidationContext, - useSystemRootCerts); + EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; + if (useSystemRootCerts) { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext); + } else { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( + certInstanceName, + "cert-default", + rootInstanceName, + "root-default", + alpnProtocols, + staticCertValidationContext); + } return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( upstreamTlsContext, @@ -187,7 +198,9 @@ public void testProviderForClient_systemRootCerts() throws Exception { null, CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts(CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(), true); assertThat(provider.savedKey).isNull(); From 5be2aa2da33b6c341fb3899f215695ff68ec6b0d Mon Sep 17 00:00:00 2001 From: Kannan J Date: Mon, 8 Sep 2025 12:31:36 +0000 Subject: [PATCH 17/51] Save changes. --- .../internal/security/CommonTlsContextTestsUtil.java | 3 +-- .../CertProviderClientSslContextProviderTest.java | 12 ++++++------ 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index e39aae56a69..718695f3b3f 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -320,8 +320,7 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext), - requireClientCert); + staticCertValidationContext), requireClientCert); } /** Helper method to build DownstreamTlsContext for CertProvider tests. */ diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index ab12d2f32d9..3d04a1cdf17 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -68,12 +68,12 @@ public void setUp() throws Exception { /** Helper method to build CertProviderClientSslContextProvider. */ private CertProviderClientSslContextProvider getSslContextProvider( - String certInstanceName, - String rootInstanceName, - Bootstrapper.BootstrapInfo bootstrapInfo, - Iterable alpnProtocols, - CertificateValidationContext staticCertValidationContext, - boolean useSystemRootCerts) { + String certInstanceName, + String rootInstanceName, + Bootstrapper.BootstrapInfo bootstrapInfo, + Iterable alpnProtocols, + CertificateValidationContext staticCertValidationContext, + boolean useSystemRootCerts) { EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; if (useSystemRootCerts) { upstreamTlsContext = From 90abe55967760f18fe60d285d476eae4c6ea3a0b Mon Sep 17 00:00:00 2001 From: Kannan J Date: Tue, 9 Sep 2025 07:01:16 +0000 Subject: [PATCH 18/51] style --- ...tProviderClientSslContextProviderTest.java | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index 3d04a1cdf17..0643295d7e5 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -193,15 +193,16 @@ public void testProviderForClient_systemRootCerts() throws Exception { TestCertificateProvider.createAndRegisterProviderProvider( certificateProviderRegistry, watcherCaptor, "testca", 0); CertProviderClientSslContextProvider provider = - getSslContextProvider( - "gcp_id", - null, - CommonBootstrapperTestUtils.getTestBootstrapInfo(), - /* alpnProtocols= */ null, - CertificateValidationContext.newBuilder() - .setSystemRootCerts(CertificateValidationContext.SystemRootCerts.getDefaultInstance()) - .build(), - true); + getSslContextProvider( + "gcp_id", + null, + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(), + true); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); From 4c44e4c66806abd40daf11b53c3691f0e889fcac Mon Sep 17 00:00:00 2001 From: Kannan J Date: Tue, 9 Sep 2025 09:07:40 +0000 Subject: [PATCH 19/51] Add comment and rename some confusing method names. --- .../certprovider/CertProviderSslContextProvider.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index 801dabeecb7..d2bd0d09f93 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -155,12 +155,12 @@ private void updateSslContextWhenReady() { updateSslContext(); clearKeysAndCerts(); } - } else if (isClientSideTls()) { + } else if (isNormalTlsAndClientSide()) { if (savedTrustedRoots != null || savedSpiffeTrustMap != null) { updateSslContext(); clearKeysAndCerts(); } - } else if (isServerSideTls()) { + } else if (isNormalTlsAndServerSide()) { if (savedKey != null) { updateSslContext(); clearKeysAndCerts(); @@ -179,11 +179,13 @@ protected final boolean isMtls() { return certInstance != null && (rootCertInstance != null || isUsingSystemRootCerts); } - protected final boolean isClientSideTls() { + protected final boolean isNormalTlsAndClientSide() { + // We don't do (rootCertInstance != null || isUsingSystemRootCerts) here because of where this method is called + // from. With the rootCertInstance being null when using system root certs, there is nothing to update. return rootCertInstance != null && certInstance == null; } - protected final boolean isServerSideTls() { + protected final boolean isNormalTlsAndServerSide() { return certInstance != null && rootCertInstance == null; } From 199cc69a359bb8108ea1430e185d1ecc139b3171 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Tue, 9 Sep 2025 09:24:20 +0000 Subject: [PATCH 20/51] style. --- .../certprovider/CertProviderSslContextProvider.java | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index 234d4a06654..b88cfd2b032 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -183,8 +183,9 @@ protected final boolean isMtls() { } protected final boolean isNormalTlsAndClientSide() { - // We don't do (rootCertInstance != null || isUsingSystemRootCerts) here because of where this method is called - // from. With the rootCertInstance being null when using system root certs, there is nothing to update. + // We don't do (rootCertInstance != null || isUsingSystemRootCerts) here because of how this + // method is used. With the rootCertInstance being null when using system root certs, there + // is nothing to update in the SslContext return rootCertInstance != null && certInstance == null; } From 37cd044cfa34b038337e0d4e3b543edd5a7bb7a6 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Wed, 10 Sep 2025 08:27:30 +0000 Subject: [PATCH 21/51] Handle Sslcontext updates for System root certs with and without Mtls. --- .../security/SslContextProviderSupplier.java | 47 ++++++---- .../SslContextProviderSupplierTest.java | 88 ++++++++++++++++--- ...tProviderClientSslContextProviderTest.java | 64 +++++++++++--- 3 files changed, 161 insertions(+), 38 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index 5f629273179..b8113bd4752 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -20,12 +20,14 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; +import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; import java.util.Objects; +import javax.net.ssl.SSLException; /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} @@ -62,21 +64,36 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call } // we want to increment the ref-count so call findOrCreate again... final SslContextProvider toRelease = getSslContextProvider(); - toRelease.addCallback( - new SslContextProvider.Callback(callback.getExecutor()) { - - @Override - public void updateSslContext(SslContext sslContext) { - callback.updateSslContext(sslContext); - releaseSslContextProvider(toRelease); - } - - @Override - public void onException(Throwable throwable) { - callback.onException(throwable); - releaseSslContextProvider(toRelease); - } - }); + // When using system root certs on client side, SslContext updates via CertificateProvider is + // only required if Mtls is also enabled, i.e. tlsContext has a cert provider instance. + if (tlsContext instanceof UpstreamTlsContext + && !CommonTlsContextUtil.hasCertProviderInstance(tlsContext.getCommonTlsContext()) + && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext())) { + callback.getExecutor().execute(() -> { + try { + callback.updateSslContext(GrpcSslContexts.forClient().build()); + releaseSslContextProvider(toRelease); + } catch (SSLException e) { + callback.onException(e); + } + }); + } else { + toRelease.addCallback( + new SslContextProvider.Callback(callback.getExecutor()) { + + @Override + public void updateSslContext(SslContext sslContext) { + callback.updateSslContext(sslContext); + releaseSslContextProvider(toRelease); + } + + @Override + public void onException(Throwable throwable) { + callback.onException(throwable); + releaseSslContextProvider(toRelease); + } + }); + } } catch (final Throwable throwable) { callback.getExecutor().execute(new Runnable() { @Override diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index f476818297d..999b5e2cec5 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -17,15 +17,18 @@ package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.buildUpstreamTlsContext; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; -import static org.mockito.Mockito.any; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; @@ -47,14 +50,17 @@ public class SslContextProviderSupplierTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Mock private TlsContextManager mockTlsContextManager; + @Mock private Executor mockExecutor; private SslContextProviderSupplier supplier; private SslContextProvider mockSslContextProvider; private EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; private SslContextProvider.Callback mockCallback; - private void prepareSupplier() { - upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + private void prepareSupplier(boolean createUpstreamTlsContext) { + if (createUpstreamTlsContext) { + upstreamTlsContext = + buildUpstreamTlsContext("google_cloud_private_spiffe", true); + } mockSslContextProvider = mock(SslContextProvider.class); doReturn(mockSslContextProvider) .when(mockTlsContextManager) @@ -64,14 +70,13 @@ private void prepareSupplier() { private void callUpdateSslContext() { mockCallback = mock(SslContextProvider.Callback.class); - Executor mockExecutor = mock(Executor.class); doReturn(mockExecutor).when(mockCallback).getExecutor(); supplier.updateSslContext(mockCallback); } @Test public void get_updateSecret() { - prepareSupplier(); + prepareSupplier(true); callUpdateSslContext(); verify(mockTlsContextManager, times(2)) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); @@ -95,11 +100,12 @@ public void get_updateSecret() { @Test public void get_onException() { - prepareSupplier(); + prepareSupplier(true); callUpdateSslContext(); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); - verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); + verify(mockSslContextProvider, times(1)) + .addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); Exception exception = new Exception("test"); @@ -109,9 +115,71 @@ public void get_onException() { .releaseClientSslContextProvider(eq(mockSslContextProvider)); } + @Test + public void systemRootCertsWithMtls_callbackExecutedFromProvider() { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + "gcp_id", + "cert-default", + null, + "root-default", + null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build()); + prepareSupplier(false); + + callUpdateSslContext(); + + verify(mockTlsContextManager, times(2)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + verify(mockTlsContextManager, times(0)) + .releaseClientSslContextProvider(any(SslContextProvider.class)); + ArgumentCaptor callbackCaptor = + ArgumentCaptor.forClass(SslContextProvider.Callback.class); + verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); + SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); + assertThat(capturedCallback).isNotNull(); + SslContext mockSslContext = mock(SslContext.class); + capturedCallback.updateSslContext(mockSslContext); + verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(eq(mockSslContextProvider)); + SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); + supplier.updateSslContext(mockCallback); + verify(mockTlsContextManager, times(3)) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + } + + @Test + public void systemRootCertsWithRegularTls_callbackExecutedFromSupplier() { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( + null, + null, + null, + "root-default", + null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build()); + supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); + reset(mockTlsContextManager); + + callUpdateSslContext(); + ArgumentCaptor runnableArgumentCaptor = ArgumentCaptor.forClass(Runnable.class); + verify(mockExecutor).execute(runnableArgumentCaptor.capture()); + runnableArgumentCaptor.getValue().run(); + verify(mockCallback, times(1)).updateSslContext(any(SslContext.class)); + verify(mockTlsContextManager, times(1)) + .releaseClientSslContextProvider(eq(mockSslContextProvider)); + } + @Test public void testClose() { - prepareSupplier(); + prepareSupplier(true); callUpdateSslContext(); supplier.close(); verify(mockTlsContextManager, times(1)) @@ -125,7 +193,7 @@ public void testClose() { @Test public void testClose_nullSslContextProvider() { - prepareSupplier(); + prepareSupplier(true); doThrow(new NullPointerException()).when(mockTlsContextManager) .releaseClientSslContextProvider(null); supplier.close(); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index 0643295d7e5..07d422a97d0 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -187,14 +187,19 @@ public void testProviderForClient_mtls() throws Exception { } @Test - public void testProviderForClient_systemRootCerts() throws Exception { + /** + * Note this route will not really be invoked since {@link SslContextProviderSupplier} will + * shortcircuit creating the certificate provider and directly invoke the callback with the + * SslContext in this case. + */ + public void testProviderForClient_systemRootCerts_regularTls() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = new CertificateProvider.DistributorWatcher[1]; TestCertificateProvider.createAndRegisterProviderProvider( certificateProviderRegistry, watcherCaptor, "testca", 0); CertProviderClientSslContextProvider provider = getSslContextProvider( - "gcp_id", + null, null, CommonBootstrapperTestUtils.getTestBootstrapInfo(), /* alpnProtocols= */ null, @@ -209,36 +214,69 @@ public void testProviderForClient_systemRootCerts() throws Exception { assertThat(provider.savedTrustedRoots).isNull(); assertThat(provider.getSslContext()).isNull(); - // now generate cert update, updates SslContext + assertThat(watcherCaptor[0]).isNull(); + } + + @Test + public void testProviderForClient_systemRootCerts_mtls() throws Exception { + final CertificateProvider.DistributorWatcher[] watcherCaptor = + new CertificateProvider.DistributorWatcher[1]; + TestCertificateProvider.createAndRegisterProviderProvider( + certificateProviderRegistry, watcherCaptor, "testca", 0); + CertProviderClientSslContextProvider provider = + getSslContextProvider( + "gcp_id", + null, + CommonBootstrapperTestUtils.getTestBootstrapInfo(), + /* alpnProtocols= */ null, + CertificateValidationContext.newBuilder() + .setSystemRootCerts( + CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(), + true); + + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.getSslContext()).isNull(); + + // now generate root cert update, will get ignored because of systemRootCerts config + watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + assertThat(provider.getSslContext()).isNull(); + assertThat(provider.savedKey).isNull(); + assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNull(); + + // now generate cert update watcherCaptor[0].updateCertificate( - CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), - ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); + CommonCertProviderTestUtils.getPrivateKey(CLIENT_KEY_FILE), + ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.getSslContext()).isNotNull(); TestCallback testCallback = - CommonTlsContextTestsUtil.getValueThruCallback(provider); + CommonTlsContextTestsUtil.getValueThruCallback(provider); doChecksOnSslContext(false, testCallback.updatedSslContext, /* expectedApnProtos= */ null); TestCallback testCallback1 = - CommonTlsContextTestsUtil.getValueThruCallback(provider); + CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); - // just do root cert update: trusted roots is not updated (because of system root certs config) - // and sslContext should still be the same + // just do root cert update: sslContext should still be the same, will get ignored because of + // systemRootCerts config watcherCaptor[0].updateTrustedRoots( - ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); + ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); - // now update id cert: sslContext should be updated i.e.different from the previous one + // now update id cert: sslContext should be updated i.e. different from the previous one watcherCaptor[0].updateCertificate( - CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), - ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); + CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), + ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); From 6958b4e399830397e991047894efcf77e6848f30 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Wed, 10 Sep 2025 09:35:00 +0000 Subject: [PATCH 22/51] Merge fixes --- .../S2AProtocolNegotiatorFactory.java | 3 +- .../security/SslContextProviderSupplier.java | 8 +- .../io/grpc/xds/CdsLoadBalancer2Test.java | 875 +++++++----------- .../SslContextProviderSupplierTest.java | 18 +- ...tProviderClientSslContextProviderTest.java | 3 +- 5 files changed, 371 insertions(+), 536 deletions(-) diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java index 03976cc7d7b..81aca657e40 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java @@ -259,7 +259,8 @@ public void onSuccess(SslContext sslContext) { public void run() { s2aStub.close(); } - })) + }), + null) .newHandler(grpcHandler); // Delegate the rest of the handshake to the TLS handler. and remove the diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index 66ea98e3948..8baa0e26796 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -68,7 +68,7 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call } // we want to increment the ref-count so call findOrCreate again... - final SslContextProvider toRelease = getSslContextProvider(); + final SslContextProvider toRelease = getSslContextProvider(sni); // When using system root certs on client side, SslContext updates via CertificateProvider is // only required if Mtls is also enabled, i.e. tlsContext has a cert provider instance. if (tlsContext instanceof UpstreamTlsContext @@ -77,7 +77,7 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call callback.getExecutor().execute(() -> { try { callback.updateSslContext(GrpcSslContexts.forClient().build()); - releaseSslContextProvider(toRelease); + releaseSslContextProvider(toRelease, sni); } catch (SSLException e) { callback.onException(e); } @@ -89,13 +89,13 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call @Override public void updateSslContext(SslContext sslContext) { callback.updateSslContext(sslContext); - releaseSslContextProvider(toRelease); + releaseSslContextProvider(toRelease, sni); } @Override public void onException(Throwable throwable) { callback.onException(throwable); - releaseSslContextProvider(toRelease); + releaseSslContextProvider(toRelease, sni); } }); } diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java index ea349da6905..771f8c2596d 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java @@ -33,6 +33,29 @@ import com.github.xds.type.v3.TypedStruct; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; +import com.google.protobuf.Any; +import com.google.protobuf.Struct; +import com.google.protobuf.UInt32Value; +import com.google.protobuf.Value; +import io.envoyproxy.envoy.config.cluster.v3.CircuitBreakers; +import io.envoyproxy.envoy.config.cluster.v3.Cluster; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy; +import io.envoyproxy.envoy.config.cluster.v3.LoadBalancingPolicy.Policy; +import io.envoyproxy.envoy.config.cluster.v3.OutlierDetection; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.AggregatedConfigSource; +import io.envoyproxy.envoy.config.core.v3.ConfigSource; +import io.envoyproxy.envoy.config.core.v3.RoutingPriority; +import io.envoyproxy.envoy.config.core.v3.SelfConfigSource; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.config.core.v3.TransportSocket; +import io.envoyproxy.envoy.config.core.v3.TypedExtensionConfig; +import io.envoyproxy.envoy.config.endpoint.v3.ClusterLoadAssignment; +import io.envoyproxy.envoy.config.endpoint.v3.Endpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LbEndpoint; +import io.envoyproxy.envoy.config.endpoint.v3.LocalityLbEndpoints; +import io.envoyproxy.envoy.extensions.clusters.aggregate.v3.ClusterConfig; +import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext; import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.ConnectivityState; @@ -94,19 +117,15 @@ public class CdsLoadBalancer2Test { private static final String EDS_SERVICE_NAME = "backend-service-1.googleapis.com"; private static final String NODE_ID = "node-id"; private final io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("cert-instance-name", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext("cert-instance-name", true, null, false); private static final Cluster EDS_CLUSTER = Cluster.newBuilder() .setName(CLUSTER) .setType(Cluster.DiscoveryType.EDS) .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() .setServiceName(EDS_SERVICE_NAME) .setEdsConfig(ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.newBuilder()))) + .setAds(AggregatedConfigSource.newBuilder()))) .build(); - private final UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true, null, false); - private final OutlierDetection outlierDetection = OutlierDetection.create( - null, null, null, null, SuccessRateEjection.create(null, null, null, null), null); private final FakeClock fakeClock = new FakeClock(); private final LoadBalancerRegistry lbRegistry = new LoadBalancerRegistry(); @@ -116,9 +135,9 @@ public class CdsLoadBalancer2Test { Arrays.asList("control-plane.example.com"), serverInfo -> new GrpcXdsTransportFactory.GrpcXdsTransport( InProcessChannelBuilder - .forName(serverInfo.target()) - .directExecutor() - .build()), + .forName(serverInfo.target()) + .directExecutor() + .build()), fakeClock); private final ServerInfo lrsServerInfo = xdsClient.getBootstrapInfo().servers().get(0); private XdsDependencyManager xdsDepManager; @@ -139,7 +158,7 @@ public void setUp() throws Exception { lbRegistry.register(new FakeLoadBalancerProvider("least_request_experimental", new LeastRequestLoadBalancerProvider())); lbRegistry.register(new FakeLoadBalancerProvider("wrr_locality_experimental", - new WrrLocalityLoadBalancerProvider())); + new WrrLocalityLoadBalancerProvider())); CdsLoadBalancerProvider cdsLoadBalancerProvider = new CdsLoadBalancerProvider(lbRegistry); lbRegistry.register(cdsLoadBalancerProvider); loadBalancer = (CdsLoadBalancer2) cdsLoadBalancerProvider.newLoadBalancer(helper); @@ -208,16 +227,16 @@ public void discoverTopLevelEdsCluster() { .setName(CLUSTER) .setType(Cluster.DiscoveryType.EDS) .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() - .setServiceName(EDS_SERVICE_NAME) - .setEdsConfig(ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.newBuilder()))) + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) .setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN) .setLrsServer(ConfigSource.newBuilder() - .setSelf(SelfConfigSource.getDefaultInstance())) + .setSelf(SelfConfigSource.getDefaultInstance())) .setCircuitBreakers(CircuitBreakers.newBuilder() .addThresholds(CircuitBreakers.Thresholds.newBuilder() - .setPriority(RoutingPriority.DEFAULT) - .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) .setTransportSocket(TransportSocket.newBuilder() .setName("envoy.transport_sockets.tls") .setTypedConfig(Any.pack(UpstreamTlsContext.newBuilder() @@ -235,10 +254,10 @@ public void discoverTopLevelEdsCluster() { ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanism).isEqualTo( DiscoveryMechanism.forEds( - CLUSTER, EDS_SERVICE_NAME, lrsServerInfo, 100L, upstreamTlsContext, - Collections.emptyMap(), io.grpc.xds.EnvoyServerProtoData.OutlierDetection.create( - null, null, null, null, SuccessRateEjection.create(null, null, null, null), - FailurePercentageEjection.create(null, null, null, null)))); + CLUSTER, EDS_SERVICE_NAME, lrsServerInfo, 100L, upstreamTlsContext, + Collections.emptyMap(), io.grpc.xds.EnvoyServerProtoData.OutlierDetection.create( + null, null, null, null, SuccessRateEjection.create(null, null, null, null), + FailurePercentageEjection.create(null, null, null, null)))); assertThat( GracefulSwitchLoadBalancerAccessor.getChildProvider(childLbConfig.lbConfig).getPolicyName()) .isEqualTo("wrr_locality_experimental"); @@ -250,24 +269,24 @@ public void discoverTopLevelLogicalDnsCluster() { .setName(CLUSTER) .setType(Cluster.DiscoveryType.LOGICAL_DNS) .setLoadAssignment(ClusterLoadAssignment.newBuilder() - .addEndpoints(LocalityLbEndpoints.newBuilder() - .addLbEndpoints(LbEndpoint.newBuilder() - .setEndpoint(Endpoint.newBuilder() - .setAddress(Address.newBuilder() - .setSocketAddress(SocketAddress.newBuilder() - .setAddress("dns.example.com") - .setPortValue(1111))))))) + .addEndpoints(LocalityLbEndpoints.newBuilder() + .addLbEndpoints(LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("dns.example.com") + .setPortValue(1111))))))) .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() - .setServiceName(EDS_SERVICE_NAME) - .setEdsConfig(ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.newBuilder()))) + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) .setLbPolicy(Cluster.LbPolicy.LEAST_REQUEST) .setLrsServer(ConfigSource.newBuilder() - .setSelf(SelfConfigSource.getDefaultInstance())) + .setSelf(SelfConfigSource.getDefaultInstance())) .setCircuitBreakers(CircuitBreakers.newBuilder() .addThresholds(CircuitBreakers.Thresholds.newBuilder() - .setPriority(RoutingPriority.DEFAULT) - .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) .setTransportSocket(TransportSocket.newBuilder() .setName("envoy.transport_sockets.tls") .setTypedConfig(Any.pack(UpstreamTlsContext.newBuilder() @@ -284,8 +303,8 @@ public void discoverTopLevelLogicalDnsCluster() { ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanism).isEqualTo( DiscoveryMechanism.forLogicalDns( - CLUSTER, "dns.example.com:1111", lrsServerInfo, 100L, upstreamTlsContext, - Collections.emptyMap())); + CLUSTER, "dns.example.com:1111", lrsServerInfo, 100L, upstreamTlsContext, + Collections.emptyMap())); assertThat( GracefulSwitchLoadBalancerAccessor.getChildProvider(childLbConfig.lbConfig).getPolicyName()) .isEqualTo("wrr_locality_experimental"); @@ -307,8 +326,8 @@ public void nonAggregateCluster_resourceUpdate() { Cluster cluster = EDS_CLUSTER.toBuilder() .setCircuitBreakers(CircuitBreakers.newBuilder() .addThresholds(CircuitBreakers.Thresholds.newBuilder() - .setPriority(RoutingPriority.DEFAULT) - .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) .build(); controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); startXdsDepManager(); @@ -317,511 +336,378 @@ public void nonAggregateCluster_resourceUpdate() { assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, null, null, null, - 100L, upstreamTlsContext, outlierDetection); + assertThat(childLbConfig.discoveryMechanism).isEqualTo( + DiscoveryMechanism.forEds( + CLUSTER, EDS_SERVICE_NAME, null, 100L, null, Collections.emptyMap(), null)); - update = CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, null, - outlierDetection, false).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); + cluster = EDS_CLUSTER.toBuilder() + .setCircuitBreakers(CircuitBreakers.newBuilder() + .addThresholds(CircuitBreakers.Thresholds.newBuilder() + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(200)))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(1); + childBalancer = Iterables.getOnlyElement(childBalancers); childLbConfig = (ClusterResolverConfig) childBalancer.config; - instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 200L, null, outlierDetection); + assertThat(childLbConfig.discoveryMechanism).isEqualTo( + DiscoveryMechanism.forEds( + CLUSTER, EDS_SERVICE_NAME, null, 200L, null, Collections.emptyMap(), null)); } @Test public void nonAggregateCluster_resourceRevoked() { - CdsUpdate update = - CdsUpdate.forLogicalDns(CLUSTER, DNS_HOST_NAME, null, 100L, upstreamTlsContext, - false) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, EDS_CLUSTER)); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, CLUSTER, DiscoveryMechanism.Type.LOGICAL_DNS, null, - DNS_HOST_NAME, null, 100L, upstreamTlsContext, null); + assertThat(childLbConfig.discoveryMechanism).isEqualTo( + DiscoveryMechanism.forEds( + CLUSTER, EDS_SERVICE_NAME, null, null, null, Collections.emptyMap(), null)); + + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of()); - xdsClient.deliverResourceNotExist(CLUSTER); assertThat(childBalancer.shutdown).isTrue(); Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER - + " xDS node ID: " + NODE_ID); + "CDS resource " + CLUSTER + " does not exist nodeID: " + NODE_ID); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), unavailable, null); + assertPickerStatus(pickerCaptor.getValue(), unavailable); assertThat(childBalancer.shutdown).isTrue(); assertThat(childBalancers).isEmpty(); } @Test - public void discoverAggregateCluster() { - String cluster1 = "cluster-01.googleapis.com"; - String cluster2 = "cluster-02.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1 (aggr.), cluster2 (logical DNS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Arrays.asList(cluster1, cluster2)) - .ringHashLbPolicy(100L, 1000L).build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - assertThat(childBalancers).isEmpty(); - String cluster3 = "cluster-03.googleapis.com"; - String cluster4 = "cluster-04.googleapis.com"; - // cluster1 (aggr.) -> [cluster3 (EDS), cluster4 (EDS)] - CdsUpdate update1 = - CdsUpdate.forAggregate(cluster1, Arrays.asList(cluster3, cluster4)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - assertThat(xdsClient.watchers.keySet()).containsExactly( - CLUSTER, cluster1, cluster2, cluster3, cluster4); - assertThat(childBalancers).isEmpty(); - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection, false).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); - assertThat(childBalancers).isEmpty(); - CdsUpdate update2 = - CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, null, 100L, null, false) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - assertThat(childBalancers).isEmpty(); - CdsUpdate update4 = - CdsUpdate.forEds(cluster4, null, LRS_SERVER_INFO, 300L, null, outlierDetection, false) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster4, update4); - assertThat(childBalancers).hasSize(1); // all non-aggregate clusters discovered - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - assertThat(childBalancer.name).isEqualTo(CLUSTER_RESOLVER_POLICY_NAME); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(3); - // Clusters on higher level has higher priority: [cluster2, cluster3, cluster4] - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, null, 100L, null, null); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster3, - DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(2), cluster4, - DiscoveryMechanism.Type.EDS, null, null, LRS_SERVER_INFO, 300L, null, outlierDetection); - assertThat( - GracefulSwitchLoadBalancerAccessor.getChildProvider(childLbConfig.lbConfig).getPolicyName()) - .isEqualTo("ring_hash_experimental"); // dominated by top-level cluster's config - RingHashConfig ringHashConfig = (RingHashConfig) - GracefulSwitchLoadBalancerAccessor.getChildConfig(childLbConfig.lbConfig); - assertThat(ringHashConfig.minRingSize).isEqualTo(100L); - assertThat(ringHashConfig.maxRingSize).isEqualTo(1000L); - } - - @Test - public void aggregateCluster_noNonAggregateClusterExits_returnErrorPicker() { - String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1 (EDS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - xdsClient.deliverResourceNotExist(cluster1); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER - + " xDS node ID: " + NODE_ID); - assertPicker(pickerCaptor.getValue(), unavailable, null); - assertThat(childBalancers).isEmpty(); - } + public void dynamicCluster() { + String clusterName = "cluster2"; + Cluster cluster = EDS_CLUSTER.toBuilder() + .setName(clusterName) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + clusterName, cluster, + CLUSTER, Cluster.newBuilder().setName(CLUSTER).build())); + startXdsDepManager(new CdsConfig(clusterName, /*dynamic=*/ true)); - @Test - public void aggregateCluster_descendantClustersRevoked() { - String cluster1 = "cluster-01.googleapis.com"; - String cluster2 = "cluster-02.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1 (EDS), cluster2 (logical DNS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Arrays.asList(cluster1, cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - CdsUpdate update1 = CdsUpdate.forEds(cluster1, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection, false).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - CdsUpdate update2 = - CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, - false) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(2); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster1, - DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, - null); - - // Revoke cluster1, should still be able to proceed with cluster2. - xdsClient.deliverResourceNotExist(cluster1); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - assertDiscoveryMechanism(Iterables.getOnlyElement(childLbConfig.discoveryMechanisms), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, - null); - verify(helper, never()).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), any(SubchannelPicker.class)); - - // All revoked. - xdsClient.deliverResourceNotExist(cluster2); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER - + " xDS node ID: " + NODE_ID); - assertPicker(pickerCaptor.getValue(), unavailable, null); - assertThat(childBalancer.shutdown).isTrue(); - assertThat(childBalancers).isEmpty(); - } + assertThat(childLbConfig.discoveryMechanism).isEqualTo( + DiscoveryMechanism.forEds( + clusterName, EDS_SERVICE_NAME, null, null, null, Collections.emptyMap(), null)); - @Test - public void aggregateCluster_rootClusterRevoked() { - String cluster1 = "cluster-01.googleapis.com"; - String cluster2 = "cluster-02.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1 (EDS), cluster2 (logical DNS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Arrays.asList(cluster1, cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - CdsUpdate update1 = CdsUpdate.forEds(cluster1, EDS_SERVICE_NAME, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection, false).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - CdsUpdate update2 = - CdsUpdate.forLogicalDns(cluster2, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, - false) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(2); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(0), cluster1, - DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, null, LRS_SERVER_INFO, 200L, - upstreamTlsContext, outlierDetection); - assertDiscoveryMechanism(childLbConfig.discoveryMechanisms.get(1), cluster2, - DiscoveryMechanism.Type.LOGICAL_DNS, null, DNS_HOST_NAME, LRS_SERVER_INFO, 100L, null, - null); - - xdsClient.deliverResourceNotExist(CLUSTER); - assertThat(xdsClient.watchers.keySet()) - .containsExactly(CLUSTER); // subscription to all descendant clusters cancelled - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER - + " xDS node ID: " + NODE_ID); - assertPicker(pickerCaptor.getValue(), unavailable, null); - assertThat(childBalancer.shutdown).isTrue(); - assertThat(childBalancers).isEmpty(); + assertThat(this.lastXdsConfig.getClusters()).containsKey(clusterName); + shutdownLoadBalancer(); + assertThat(this.lastXdsConfig.getClusters()).doesNotContainKey(clusterName); } @Test - public void aggregateCluster_intermediateClusterChanges() { + public void discoverAggregateCluster_createsPriorityLbPolicy() { + lbRegistry.register(new FakeLoadBalancerProvider(PRIORITY_POLICY_NAME)); + CdsLoadBalancerProvider cdsLoadBalancerProvider = new CdsLoadBalancerProvider(lbRegistry); + lbRegistry.register(cdsLoadBalancerProvider); + loadBalancer = (CdsLoadBalancer2) cdsLoadBalancerProvider.newLoadBalancer(helper); + String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - - // CLUSTER (aggr.) -> [cluster2 (aggr.)] String cluster2 = "cluster-02.googleapis.com"; - update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster2); - - // cluster2 (aggr.) -> [cluster3 (EDS)] String cluster3 = "cluster-03.googleapis.com"; - CdsUpdate update2 = - CdsUpdate.forAggregate(cluster2, Collections.singletonList(cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster2, cluster3); - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection, false).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); + String cluster4 = "cluster-04.googleapis.com"; + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + // CLUSTER (aggr.) -> [cluster1 (aggr.), cluster2 (logical DNS), cluster3 (EDS)] + CLUSTER, Cluster.newBuilder() + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster1) + .addClusters(cluster2) + .addClusters(cluster3) + .build()))) + .setLbPolicy(Cluster.LbPolicy.RING_HASH) + .build(), + // cluster1 (aggr.) -> [cluster3 (EDS), cluster4 (EDS)] + cluster1, Cluster.newBuilder() + .setName(cluster1) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster3) + .addClusters(cluster4) + .build()))) + .build(), + cluster2, Cluster.newBuilder() + .setName(cluster2) + .setType(Cluster.DiscoveryType.LOGICAL_DNS) + .setLoadAssignment(ClusterLoadAssignment.newBuilder() + .addEndpoints(LocalityLbEndpoints.newBuilder() + .addLbEndpoints(LbEndpoint.newBuilder() + .setEndpoint(Endpoint.newBuilder() + .setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress("dns.example.com") + .setPortValue(1111))))))) + .build(), + cluster3, EDS_CLUSTER.toBuilder() + .setName(cluster3) + .setCircuitBreakers(CircuitBreakers.newBuilder() + .addThresholds(CircuitBreakers.Thresholds.newBuilder() + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .build(), + cluster4, EDS_CLUSTER.toBuilder().setName(cluster4).build())); + startXdsDepManager(); + + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(1); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, cluster3, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 100L, upstreamTlsContext, outlierDetection); - - // cluster2 revoked - xdsClient.deliverResourceNotExist(cluster2); - assertThat(xdsClient.watchers.keySet()) - .containsExactly(CLUSTER, cluster2); // cancelled subscription to cluster3 - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: found 0 leaf (logical DNS or EDS) clusters for root cluster " + CLUSTER - + " xDS node ID: " + NODE_ID); - assertPicker(pickerCaptor.getValue(), unavailable, null); - assertThat(childBalancer.shutdown).isTrue(); - assertThat(childBalancers).isEmpty(); + assertThat(childBalancer.name).isEqualTo(PRIORITY_POLICY_NAME); + PriorityLoadBalancerProvider.PriorityLbConfig childLbConfig = + (PriorityLoadBalancerProvider.PriorityLbConfig) childBalancer.config; + assertThat(childLbConfig.priorities).hasSize(3); + assertThat(childLbConfig.priorities.get(0)).isEqualTo(cluster3); + assertThat(childLbConfig.priorities.get(1)).isEqualTo(cluster4); + assertThat(childLbConfig.priorities.get(2)).isEqualTo(cluster2); + assertThat(childLbConfig.childConfigs).hasSize(3); + PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig childConfig3 = + childLbConfig.childConfigs.get(cluster3); + assertThat( + GracefulSwitchLoadBalancerAccessor.getChildProvider(childConfig3.childConfig) + .getPolicyName()) + .isEqualTo("cds_experimental"); + PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig childConfig4 = + childLbConfig.childConfigs.get(cluster4); + assertThat( + GracefulSwitchLoadBalancerAccessor.getChildProvider(childConfig4.childConfig) + .getPolicyName()) + .isEqualTo("cds_experimental"); + PriorityLoadBalancerProvider.PriorityLbConfig.PriorityChildConfig childConfig2 = + childLbConfig.childConfigs.get(cluster2); + assertThat( + GracefulSwitchLoadBalancerAccessor.getChildProvider(childConfig2.childConfig) + .getPolicyName()) + .isEqualTo("cds_experimental"); } @Test - public void aggregateCluster_withLoops() { - String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - - // CLUSTER (aggr.) -> [cluster2 (aggr.)] - String cluster2 = "cluster-02.googleapis.com"; - update = - CdsUpdate.forAggregate(cluster1, Collections.singletonList(cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); - - // cluster2 (aggr.) -> [cluster3 (EDS), cluster1 (parent), cluster2 (self), cluster3 (dup)] - String cluster3 = "cluster-03.googleapis.com"; - CdsUpdate update2 = - CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3, cluster1, cluster2, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2, cluster3); - - reset(helper); - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection, false).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: circular aggregate clusters directly under cluster-02.googleapis.com for root" - + " cluster cluster-foo.googleapis.com, named [cluster-01.googleapis.com," - + " cluster-02.googleapis.com], xDS node ID: " + NODE_ID); - assertPicker(pickerCaptor.getValue(), unavailable, null); - } + // Both priorities will get tried using real priority LB policy. + public void discoverAggregateCluster_testChildCdsLbPolicyParsing() { + lbRegistry.register(new PriorityLoadBalancerProvider()); + CdsLoadBalancerProvider cdsLoadBalancerProvider = new CdsLoadBalancerProvider(lbRegistry); + lbRegistry.register(cdsLoadBalancerProvider); + loadBalancer = (CdsLoadBalancer2) cdsLoadBalancerProvider.newLoadBalancer(helper); - @Test - public void aggregateCluster_withLoops_afterEds() { String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - - // CLUSTER (aggr.) -> [cluster2 (aggr.)] String cluster2 = "cluster-02.googleapis.com"; - update = - CdsUpdate.forAggregate(cluster1, Collections.singletonList(cluster2)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + // CLUSTER (aggr.) -> [cluster1 (EDS), cluster2 (EDS)] + CLUSTER, Cluster.newBuilder() + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster1) + .addClusters(cluster2) + .build()))) + .build(), + cluster1, EDS_CLUSTER.toBuilder().setName(cluster1).build(), + cluster2, EDS_CLUSTER.toBuilder().setName(cluster2).build())); + startXdsDepManager(); - String cluster3 = "cluster-03.googleapis.com"; - CdsUpdate update2 = - CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection, false).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); - - // cluster2 (aggr.) -> [cluster3 (EDS)] - CdsUpdate update2a = - CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3, cluster1, cluster2, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2a); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1, cluster2, cluster3); - verify(helper).updateBalancingState( - eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status unavailable = Status.UNAVAILABLE.withDescription( - "CDS error: circular aggregate clusters directly under cluster-02.googleapis.com for root" - + " cluster cluster-foo.googleapis.com, named [cluster-01.googleapis.com," - + " cluster-02.googleapis.com], xDS node ID: " + NODE_ID); - assertPicker(pickerCaptor.getValue(), unavailable, null); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); + assertThat(childBalancers).hasSize(2); + ClusterResolverConfig cluster1ResolverConfig = + (ClusterResolverConfig) childBalancers.get(0).config; + assertThat(cluster1ResolverConfig.discoveryMechanism.cluster) + .isEqualTo("cluster-01.googleapis.com"); + assertThat(cluster1ResolverConfig.discoveryMechanism.type) + .isEqualTo(DiscoveryMechanism.Type.EDS); + assertThat(cluster1ResolverConfig.discoveryMechanism.edsServiceName) + .isEqualTo("backend-service-1.googleapis.com"); + ClusterResolverConfig cluster2ResolverConfig = + (ClusterResolverConfig) childBalancers.get(1).config; + assertThat(cluster2ResolverConfig.discoveryMechanism.cluster) + .isEqualTo("cluster-02.googleapis.com"); + assertThat(cluster2ResolverConfig.discoveryMechanism.type) + .isEqualTo(DiscoveryMechanism.Type.EDS); + assertThat(cluster2ResolverConfig.discoveryMechanism.edsServiceName) + .isEqualTo("backend-service-1.googleapis.com"); } @Test - public void aggregateCluster_duplicateChildren() { - String cluster1 = "cluster-01.googleapis.com"; - String cluster2 = "cluster-02.googleapis.com"; - String cluster3 = "cluster-03.googleapis.com"; - String cluster4 = "cluster-04.googleapis.com"; - - // CLUSTER (aggr.) -> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - - // cluster1 (aggr) -> [cluster3 (EDS), cluster2 (aggr), cluster4 (aggr)] - CdsUpdate update1 = - CdsUpdate.forAggregate(cluster1, Arrays.asList(cluster3, cluster2, cluster4, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - assertThat(xdsClient.watchers.keySet()).containsExactly( - cluster3, cluster4, cluster2, cluster1, CLUSTER); - xdsClient.watchers.values().forEach(list -> assertThat(list.size()).isEqualTo(1)); - - // cluster2 (agg) -> [cluster3 (EDS), cluster4 {agg}] with dups - CdsUpdate update2 = - CdsUpdate.forAggregate(cluster2, Arrays.asList(cluster3, cluster4, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster2, update2); - - // Define EDS cluster - CdsUpdate update3 = CdsUpdate.forEds(cluster3, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection, false).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster3, update3); - - // cluster4 (agg) -> [cluster3 (EDS)] with dups (3 copies) - CdsUpdate update4 = - CdsUpdate.forAggregate(cluster4, Arrays.asList(cluster3, cluster3, cluster3)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster4, update4); - xdsClient.watchers.values().forEach(list -> assertThat(list.size()).isEqualTo(1)); + public void aggregateCluster_noChildren() { + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + // CLUSTER (aggr.) -> [] + CLUSTER, Cluster.newBuilder() + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .build()))) + .build())); + startXdsDepManager(); - FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - DiscoveryMechanism instance = Iterables.getOnlyElement(childLbConfig.discoveryMechanisms); - assertDiscoveryMechanism(instance, cluster3, DiscoveryMechanism.Type.EDS, EDS_SERVICE_NAME, - null, LRS_SERVER_INFO, 100L, upstreamTlsContext, outlierDetection); + verify(helper) + .updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); + Status actualStatus = result.getStatus(); + assertThat(actualStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(actualStatus.getDescription()) + .contains("aggregate ClusterConfig.clusters must not be empty"); + assertThat(childBalancers).isEmpty(); } @Test - public void aggregateCluster_discoveryErrorBeforeChildLbCreated_returnErrorPicker() { + public void aggregateCluster_noNonAggregateClusterExits_returnErrorPicker() { + lbRegistry.register(new PriorityLoadBalancerProvider()); + CdsLoadBalancerProvider cdsLoadBalancerProvider = new CdsLoadBalancerProvider(lbRegistry); + lbRegistry.register(cdsLoadBalancerProvider); + loadBalancer = (CdsLoadBalancer2) cdsLoadBalancerProvider.newLoadBalancer(helper); + String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - assertThat(xdsClient.watchers.keySet()).containsExactly(CLUSTER, cluster1); - Status error = Status.RESOURCE_EXHAUSTED.withDescription("OOM"); - xdsClient.deliverError(error); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of( + // CLUSTER (aggr.) -> [cluster1 (missing)] + CLUSTER, Cluster.newBuilder() + .setName(CLUSTER) + .setClusterType(Cluster.CustomClusterType.newBuilder() + .setName("envoy.clusters.aggregate") + .setTypedConfig(Any.pack(ClusterConfig.newBuilder() + .addClusters(cluster1) + .build()))) + .setLbPolicy(Cluster.LbPolicy.RING_HASH) + .build())); + startXdsDepManager(); + verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - Status expectedError = Status.UNAVAILABLE.withDescription( - "Unable to load CDS cluster-foo.googleapis.com. xDS server returned: " - + "RESOURCE_EXHAUSTED: OOM xDS node ID: " + NODE_ID); - assertPicker(pickerCaptor.getValue(), expectedError, null); + Status status = Status.UNAVAILABLE.withDescription( + "CDS resource " + cluster1 + " does not exist nodeID: " + NODE_ID); + assertPickerStatus(pickerCaptor.getValue(), status); assertThat(childBalancers).isEmpty(); } @Test - public void aggregateCluster_discoveryErrorAfterChildLbCreated_propagateToChildLb() { - String cluster1 = "cluster-01.googleapis.com"; - // CLUSTER (aggr.) -> [cluster1 (logical DNS)] - CdsUpdate update = - CdsUpdate.forAggregate(CLUSTER, Collections.singletonList(cluster1)) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); - CdsUpdate update1 = - CdsUpdate.forLogicalDns(cluster1, DNS_HOST_NAME, LRS_SERVER_INFO, 200L, null, - false) - .roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(cluster1, update1); - FakeLoadBalancer childLb = Iterables.getOnlyElement(childBalancers); - ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childLb.config; - assertThat(childLbConfig.discoveryMechanisms).hasSize(1); - - Status error = Status.RESOURCE_EXHAUSTED.withDescription("OOM"); - xdsClient.deliverError(error); - assertThat(childLb.upstreamError.getCode()).isEqualTo(Status.Code.UNAVAILABLE); - assertThat(childLb.upstreamError.getDescription()).contains("RESOURCE_EXHAUSTED: OOM"); - assertThat(childLb.shutdown).isFalse(); // child LB may choose to keep working - } - - @Test - public void handleNameResolutionErrorFromUpstream_beforeChildLbCreated_returnErrorPicker() { - Status upstreamError = Status.UNAVAILABLE.withDescription( - "unreachable xDS node ID: " + NODE_ID); - loadBalancer.handleNameResolutionError(upstreamError); + public void handleNameResolutionErrorFromUpstream_beforeChildLbCreated_failingPicker() { + Status status = Status.UNAVAILABLE.withDescription("unreachable"); + loadBalancer.handleNameResolutionError(status); verify(helper).updateBalancingState( eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); - assertPicker(pickerCaptor.getValue(), upstreamError, null); + assertPickerStatus(pickerCaptor.getValue(), status); } @Test public void handleNameResolutionErrorFromUpstream_afterChildLbCreated_fallThrough() { - CdsUpdate update = CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, - upstreamTlsContext, outlierDetection, false).roundRobinLbPolicy().build(); - xdsClient.deliverCdsUpdate(CLUSTER, update); + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + startXdsDepManager(); + verify(helper, never()).updateBalancingState(eq(ConnectivityState.TRANSIENT_FAILURE), any()); FakeLoadBalancer childBalancer = Iterables.getOnlyElement(childBalancers); assertThat(childBalancer.shutdown).isFalse(); + loadBalancer.handleNameResolutionError(Status.UNAVAILABLE.withDescription("unreachable")); assertThat(childBalancer.upstreamError.getCode()).isEqualTo(Code.UNAVAILABLE); assertThat(childBalancer.upstreamError.getDescription()).isEqualTo("unreachable"); - verify(helper, never()).updateBalancingState( - any(ConnectivityState.class), any(SubchannelPicker.class)); + verify(helper).updateBalancingState( + eq(ConnectivityState.CONNECTING), any(SubchannelPicker.class)); } @Test public void unknownLbProvider() { - try { - xdsClient.deliverCdsUpdate(CLUSTER, - CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, - outlierDetection, false) - .lbPolicyConfig(ImmutableMap.of("unknownLb", ImmutableMap.of("foo", "bar"))).build()); - } catch (Exception e) { - assertThat(e).hasMessageThat().contains("unknownLb"); - return; - } - fail("Expected the unknown LB to cause an exception"); + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() + .addPolicies(Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setTypedConfig(Any.pack(TypedStruct.newBuilder() + .setTypeUrl("type.googleapis.com/unknownLb") + .setValue(Struct.getDefaultInstance()) + .build()))))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + startXdsDepManager(); + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); + Status actualStatus = result.getStatus(); + assertThat(actualStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(actualStatus.getDescription()).contains("Invalid LoadBalancingPolicy"); } @Test public void invalidLbConfig() { - try { - xdsClient.deliverCdsUpdate(CLUSTER, - CdsUpdate.forEds(CLUSTER, EDS_SERVICE_NAME, LRS_SERVER_INFO, 100L, upstreamTlsContext, - outlierDetection, false).lbPolicyConfig( - ImmutableMap.of("ring_hash_experimental", ImmutableMap.of("minRingSize", "-1"))) + Cluster cluster = Cluster.newBuilder() + .setName(CLUSTER) + .setType(Cluster.DiscoveryType.EDS) + .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) + .setLoadBalancingPolicy(LoadBalancingPolicy.newBuilder() + .addPolicies(Policy.newBuilder() + .setTypedExtensionConfig(TypedExtensionConfig.newBuilder() + .setTypedConfig(Any.pack(TypedStruct.newBuilder() + .setTypeUrl("type.googleapis.com/ring_hash_experimental") + .setValue(Struct.newBuilder() + .putFields("minRingSize", Value.newBuilder().setNumberValue(-1).build())) + .build()))))) + .build(); + controlPlaneService.setXdsConfig(ADS_TYPE_URL_CDS, ImmutableMap.of(CLUSTER, cluster)); + startXdsDepManager(); + verify(helper).updateBalancingState( + eq(ConnectivityState.TRANSIENT_FAILURE), pickerCaptor.capture()); + PickResult result = pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)); + Status actualStatus = result.getStatus(); + assertThat(actualStatus.getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(actualStatus.getDescription()).contains("Invalid 'minRingSize'"); + } + + private void startXdsDepManager() { + startXdsDepManager(new CdsConfig(CLUSTER)); + } + + private void startXdsDepManager(final CdsConfig cdsConfig) { + xdsDepManager.start( + xdsConfig -> { + if (!xdsConfig.hasValue()) { + throw new AssertionError("" + xdsConfig.getStatus()); + } + this.lastXdsConfig = xdsConfig.getValue(); + if (loadBalancer == null) { + return; + } + loadBalancer.acceptResolvedAddresses(ResolvedAddresses.newBuilder() + .setAddresses(Collections.emptyList()) + .setAttributes(Attributes.newBuilder() + .set(XdsAttributes.XDS_CONFIG, xdsConfig.getValue()) + .set(XdsAttributes.XDS_CLUSTER_SUBSCRIPT_REGISTRY, xdsDepManager) + .build()) + .setLoadBalancingPolicyConfig(cdsConfig) .build()); - } catch (Exception e) { - assertThat(e).hasMessageThat().contains("Unable to parse"); - return; - } - fail("Expected the invalid config to cause an exception"); + }); + // trigger does not exist timer, so broken config is more obvious + fakeClock.forwardTime(10, TimeUnit.MINUTES); } - private static void assertPicker(SubchannelPicker picker, Status expectedStatus, - @Nullable Subchannel expectedSubchannel) { + private static void assertPickerStatus(SubchannelPicker picker, Status expectedStatus) { PickResult result = picker.pickSubchannel(mock(PickSubchannelArgs.class)); Status actualStatus = result.getStatus(); assertThat(actualStatus.getCode()).isEqualTo(expectedStatus.getCode()); assertThat(actualStatus.getDescription()).isEqualTo(expectedStatus.getDescription()); - if (actualStatus.isOk()) { - assertThat(result.getSubchannel()).isSameInstanceAs(expectedSubchannel); - } - } - - private static void assertDiscoveryMechanism(DiscoveryMechanism instance, String name, - DiscoveryMechanism.Type type, @Nullable String edsServiceName, @Nullable String dnsHostName, - @Nullable ServerInfo lrsServerInfo, @Nullable Long maxConcurrentRequests, - @Nullable UpstreamTlsContext tlsContext, @Nullable OutlierDetection outlierDetection) { - assertThat(instance.cluster).isEqualTo(name); - assertThat(instance.type).isEqualTo(type); - assertThat(instance.edsServiceName).isEqualTo(edsServiceName); - assertThat(instance.dnsHostName).isEqualTo(dnsHostName); - assertThat(instance.lrsServerInfo).isEqualTo(lrsServerInfo); - assertThat(instance.maxConcurrentRequests).isEqualTo(maxConcurrentRequests); - assertThat(instance.tlsContext).isEqualTo(tlsContext); - assertThat(instance.outlierDetection).isEqualTo(outlierDetection); } private final class FakeLoadBalancerProvider extends LoadBalancerProvider { @@ -880,8 +766,9 @@ private final class FakeLoadBalancer extends LoadBalancer { } @Override - public void handleResolvedAddresses(ResolvedAddresses resolvedAddresses) { + public Status acceptResolvedAddresses(ResolvedAddresses resolvedAddresses) { config = resolvedAddresses.getLoadBalancingPolicyConfig(); + return Status.OK; } @Override @@ -895,58 +782,4 @@ public void shutdown() { childBalancers.remove(this); } } - - private static final class FakeXdsClient extends XdsClient { - // watchers needs to support any non-cyclic shaped graphs - private final Map>> watchers = new HashMap<>(); - - @Override - @SuppressWarnings("unchecked") - public void watchXdsResource(XdsResourceType type, - String resourceName, - ResourceWatcher watcher, Executor syncContext) { - assertThat(type.typeName()).isEqualTo("CDS"); - watchers.computeIfAbsent(resourceName, k -> new ArrayList<>()) - .add((ResourceWatcher)watcher); - } - - @Override - public void cancelXdsResourceWatch(XdsResourceType type, - String resourceName, - ResourceWatcher watcher) { - assertThat(type.typeName()).isEqualTo("CDS"); - assertThat(watchers).containsKey(resourceName); - List> watcherList = watchers.get(resourceName); - assertThat(watcherList.remove(watcher)).isTrue(); - if (watcherList.isEmpty()) { - watchers.remove(resourceName); - } - } - - @Override - public BootstrapInfo getBootstrapInfo() { - return BOOTSTRAP_INFO; - } - - private void deliverCdsUpdate(String clusterName, CdsUpdate update) { - if (watchers.containsKey(clusterName)) { - List> resourceWatchers = - ImmutableList.copyOf(watchers.get(clusterName)); - resourceWatchers.forEach(w -> w.onChanged(update)); - } - } - - private void deliverResourceNotExist(String clusterName) { - if (watchers.containsKey(clusterName)) { - ImmutableList.copyOf(watchers.get(clusterName)) - .forEach(w -> w.onResourceDoesNotExist(clusterName)); - } - } - - private void deliverError(Status error) { - watchers.values().stream() - .flatMap(List::stream) - .forEach(w -> w.onError(error)); - } - } -} +} \ No newline at end of file diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index f2fd37d4cfe..3b77d17b370 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -61,19 +61,19 @@ public class SslContextProviderSupplierTest { private void prepareSupplier(boolean createUpstreamTlsContext) { if (createUpstreamTlsContext) { upstreamTlsContext = - buildUpstreamTlsContext("google_cloud_private_spiffe", true); + buildUpstreamTlsContext("google_cloud_private_spiffe", true, null, false); } mockSslContextProvider = mock(SslContextProvider.class); doReturn(mockSslContextProvider) .when(mockTlsContextManager) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); } private void callUpdateSslContext() { mockCallback = mock(SslContextProvider.Callback.class); doReturn(mockExecutor).when(mockCallback).getExecutor(); - supplier.updateSslContext(mockCallback); + supplier.updateSslContext(mockCallback, SNI); } @Test @@ -159,9 +159,9 @@ public void systemRootCertsWithMtls_callbackExecutedFromProvider() { callUpdateSslContext(); verify(mockTlsContextManager, times(2)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); verify(mockTlsContextManager, times(0)) - .releaseClientSslContextProvider(any(SslContextProvider.class)); + .releaseClientSslContextProvider(any(SslContextProvider.class), eq(SNI)); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); @@ -171,11 +171,11 @@ public void systemRootCertsWithMtls_callbackExecutedFromProvider() { capturedCallback.updateSslContext(mockSslContext); verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); - supplier.updateSslContext(mockCallback); + supplier.updateSslContext(mockCallback, SNI); verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); } @Test @@ -200,7 +200,7 @@ public void systemRootCertsWithRegularTls_callbackExecutedFromSupplier() { runnableArgumentCaptor.getValue().run(); verify(mockCallback, times(1)).updateSslContext(any(SslContext.class)); verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); + .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); } @Test diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index 00534507311..491221a6168 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -93,7 +93,8 @@ private CertProviderClientSslContextProvider getSslContextProvider( rootInstanceName, "root-default", alpnProtocols, - staticCertValidationContext); + staticCertValidationContext, + null, false); } return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( From 139805ec569e2592737d74c436f0c3357b688ea3 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Wed, 10 Sep 2025 09:38:14 +0000 Subject: [PATCH 23/51] Style changes. --- .../CertProviderClientSslContextProviderTest.java | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index 07d422a97d0..bfd4bd42211 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -187,11 +187,9 @@ public void testProviderForClient_mtls() throws Exception { } @Test - /** - * Note this route will not really be invoked since {@link SslContextProviderSupplier} will - * shortcircuit creating the certificate provider and directly invoke the callback with the - * SslContext in this case. - */ + // Note: This code flow will not really be invoked since {@link SslContextProviderSupplier} will + // shortcircuit creating the certificate provider and directly invoke the callback with the + // SslContext in this case. public void testProviderForClient_systemRootCerts_regularTls() throws Exception { final CertificateProvider.DistributorWatcher[] watcherCaptor = new CertificateProvider.DistributorWatcher[1]; From 4417fcc5b23de4c37b65ee3e59f97c0f706f8652 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Wed, 10 Sep 2025 11:43:16 +0000 Subject: [PATCH 24/51] Fix some mistakes in code. --- .../netty/InternalProtocolNegotiators.java | 9 +++++---- .../io/grpc/netty/ProtocolNegotiators.java | 20 ++++++++++--------- .../grpc/netty/NettyClientTransportTest.java | 3 +-- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index 831712dff86..c0e22b7483a 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -41,9 +41,9 @@ private InternalProtocolNegotiators() {} * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool, - Optional handshakeCompleteRunnable, - String sni) { + ObjectPool executorPool, + Optional handshakeCompleteRunnable, + String sni) { final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, executorPool, handshakeCompleteRunnable, null, sni); final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { @@ -72,7 +72,8 @@ public void close() { * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. */ - public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, String sni) { + public static InternalProtocolNegotiator.ProtocolNegotiator tls( + SslContext sslContext, String sni) { return tls(sslContext, null, Optional.absent(), sni); } diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 2f63d41dbdc..b0ec651735b 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -90,7 +90,6 @@ import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.X509TrustManager; - import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** @@ -502,7 +501,8 @@ private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession * Returns a {@link ProtocolNegotiator} that does HTTP CONNECT proxy negotiation. */ public static ProtocolNegotiator httpProxy(final SocketAddress proxyAddress, - final @Nullable String proxyUsername, final @Nullable String proxyPassword, + final @Nullable String proxyUsername, + final @Nullable String proxyPassword, final ProtocolNegotiator negotiator) { Preconditions.checkNotNull(negotiator, "negotiator"); Preconditions.checkNotNull(proxyAddress, "proxyAddress"); @@ -611,8 +611,9 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler); ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger(); ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, - !Strings.isNullOrEmpty(sni)? sni : grpcHandler.getAuthority(), - this.executor, negotiationLogger, handshakeCompleteRunnable, null, x509ExtendedTrustManager); + !Strings.isNullOrEmpty(sni) ? sni : grpcHandler.getAuthority(), + this.executor, negotiationLogger, handshakeCompleteRunnable, null, + x509ExtendedTrustManager); return new WaitUntilActiveHandler(cth, negotiationLogger); } @@ -642,7 +643,8 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority, Executor executor, ChannelLogger negotiationLogger, Optional handshakeCompleteRunnable, - ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, X509TrustManager x509ExtendedTrustManager) { + ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, + X509TrustManager x509ExtendedTrustManager) { super(next, negotiationLogger); this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); HostPort hostPort = parseAuthority(authority); @@ -754,8 +756,8 @@ static HostPort parseAuthority(String authority) { * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool, Optional handshakeCompleteRunnable, - X509TrustManager x509ExtendedTrustManager, String sni) { + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager, String sni) { return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable, x509ExtendedTrustManager, sni); } @@ -766,12 +768,12 @@ public static ProtocolNegotiator tls(SslContext sslContext, * may happen immediately, even before the TLS Handshake is complete. */ public static ProtocolNegotiator tls(SslContext sslContext, - X509TrustManager x509ExtendedTrustManager) { + X509TrustManager x509ExtendedTrustManager) { return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager, null); } public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext, - X509TrustManager x509ExtendedTrustManager) { + X509TrustManager x509ExtendedTrustManager) { return new TlsProtocolNegotiatorClientFactory(sslContext, x509ExtendedTrustManager); } diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 1d1c5990818..0cabbc0428f 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -149,7 +149,6 @@ public class NettyClientTransportTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - private static final String SNI = "sni"; private static final SslContext SSL_CONTEXT = createSslContext(); @Mock @@ -837,7 +836,7 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception { .keyManager(clientCert, clientKey) .build(); ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool, - Optional.absent(), null, SNI); + Optional.absent(), null, null); // after starting the client, the Executor in the client pool should be used assertEquals(true, clientExecutorPool.isInUse()); final NettyClientTransport transport = newTransport(negotiator); From 7f48afa7782339f2646b48d8f3127c0a876ce035 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Thu, 11 Sep 2025 06:28:12 +0000 Subject: [PATCH 25/51] Remove special-casing for System root certs in SslContextProviderSupplier and handle it in --- .../security/SslContextProviderSupplier.java | 47 +++------- .../CertProviderClientSslContextProvider.java | 14 ++- .../CertProviderSslContextProvider.java | 8 +- .../SslContextProviderSupplierTest.java | 94 +++---------------- ...tProviderClientSslContextProviderTest.java | 10 +- 5 files changed, 49 insertions(+), 124 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index b8113bd4752..5f629273179 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -20,14 +20,12 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; -import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; import java.util.Objects; -import javax.net.ssl.SSLException; /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} @@ -64,36 +62,21 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call } // we want to increment the ref-count so call findOrCreate again... final SslContextProvider toRelease = getSslContextProvider(); - // When using system root certs on client side, SslContext updates via CertificateProvider is - // only required if Mtls is also enabled, i.e. tlsContext has a cert provider instance. - if (tlsContext instanceof UpstreamTlsContext - && !CommonTlsContextUtil.hasCertProviderInstance(tlsContext.getCommonTlsContext()) - && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext())) { - callback.getExecutor().execute(() -> { - try { - callback.updateSslContext(GrpcSslContexts.forClient().build()); - releaseSslContextProvider(toRelease); - } catch (SSLException e) { - callback.onException(e); - } - }); - } else { - toRelease.addCallback( - new SslContextProvider.Callback(callback.getExecutor()) { - - @Override - public void updateSslContext(SslContext sslContext) { - callback.updateSslContext(sslContext); - releaseSslContextProvider(toRelease); - } - - @Override - public void onException(Throwable throwable) { - callback.onException(throwable); - releaseSslContextProvider(toRelease); - } - }); - } + toRelease.addCallback( + new SslContextProvider.Callback(callback.getExecutor()) { + + @Override + public void updateSslContext(SslContext sslContext) { + callback.updateSslContext(sslContext); + releaseSslContextProvider(toRelease); + } + + @Override + public void onException(Throwable throwable) { + callback.onException(throwable); + releaseSslContextProvider(toRelease); + } + }); } catch (final Throwable throwable) { callback.getExecutor().execute(new Runnable() { @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index d7c2267c48f..79374ede827 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -28,6 +28,7 @@ import java.security.cert.X509Certificate; import java.util.Map; import javax.annotation.Nullable; +import javax.net.ssl.SSLException; /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { @@ -48,6 +49,17 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP staticCertValidationContext, upstreamTlsContext, certificateProviderStore); + // Null rootCertInstance implies hasSystemRootCerts because of the check in + // CertProviderClientSslContextProviderFactory. + if (rootCertInstance == null && !isMtls()) { + try { + // Instantiate sslContext so that addCallback will immediately update the callback with + // the SslContext. + sslContext = getSslContextBuilder(staticCertificateValidationContext).build(); + } catch (SSLException | CertStoreException e) { + throw new RuntimeException(e); + } + } } @Override @@ -55,8 +67,6 @@ protected final SslContextBuilder getSslContextBuilder( CertificateValidationContext certificateValidationContextdationContext) throws CertStoreException { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); - // Null rootCertInstance implies hasSystemRootCerts because of the check in - // CertProviderClientSslContextProviderFactory. if (rootCertInstance != null) { if (savedSpiffeTrustMap != null) { sslContextBuilder = sslContextBuilder.trustManager( diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index b88cfd2b032..dea60abc35f 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -158,12 +158,12 @@ private void updateSslContextWhenReady() { updateSslContext(); clearKeysAndCerts(); } - } else if (isNormalTlsAndClientSide()) { + } else if (isRegularTlsAndClientSide()) { if (savedTrustedRoots != null || savedSpiffeTrustMap != null) { updateSslContext(); clearKeysAndCerts(); } - } else if (isNormalTlsAndServerSide()) { + } else if (isRegularTlsAndServerSide()) { if (savedKey != null) { updateSslContext(); clearKeysAndCerts(); @@ -182,14 +182,14 @@ protected final boolean isMtls() { return certInstance != null && (rootCertInstance != null || isUsingSystemRootCerts); } - protected final boolean isNormalTlsAndClientSide() { + protected final boolean isRegularTlsAndClientSide() { // We don't do (rootCertInstance != null || isUsingSystemRootCerts) here because of how this // method is used. With the rootCertInstance being null when using system root certs, there // is nothing to update in the SslContext return rootCertInstance != null && certInstance == null; } - protected final boolean isNormalTlsAndServerSide() { + protected final boolean isRegularTlsAndServerSide() { return certInstance != null && rootCertInstance == null; } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index 999b5e2cec5..f5b462b250d 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -17,18 +17,15 @@ package io.grpc.xds.internal.security; import static com.google.common.truth.Truth.assertThat; -import static io.grpc.xds.internal.security.CommonTlsContextTestsUtil.buildUpstreamTlsContext; -import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.any; import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; -import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; @@ -50,33 +47,31 @@ public class SslContextProviderSupplierTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Mock private TlsContextManager mockTlsContextManager; - @Mock private Executor mockExecutor; private SslContextProviderSupplier supplier; private SslContextProvider mockSslContextProvider; private EnvoyServerProtoData.UpstreamTlsContext upstreamTlsContext; private SslContextProvider.Callback mockCallback; - private void prepareSupplier(boolean createUpstreamTlsContext) { - if (createUpstreamTlsContext) { - upstreamTlsContext = - buildUpstreamTlsContext("google_cloud_private_spiffe", true); - } + private void prepareSupplier() { + upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); mockSslContextProvider = mock(SslContextProvider.class); doReturn(mockSslContextProvider) - .when(mockTlsContextManager) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .when(mockTlsContextManager) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); } private void callUpdateSslContext() { mockCallback = mock(SslContextProvider.Callback.class); + Executor mockExecutor = mock(Executor.class); doReturn(mockExecutor).when(mockCallback).getExecutor(); supplier.updateSslContext(mockCallback); } @Test public void get_updateSecret() { - prepareSupplier(true); + prepareSupplier(); callUpdateSslContext(); verify(mockTlsContextManager, times(2)) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); @@ -100,12 +95,11 @@ public void get_updateSecret() { @Test public void get_onException() { - prepareSupplier(true); + prepareSupplier(); callUpdateSslContext(); ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(SslContextProvider.Callback.class); - verify(mockSslContextProvider, times(1)) - .addCallback(callbackCaptor.capture()); + verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); Exception exception = new Exception("test"); @@ -115,71 +109,9 @@ public void get_onException() { .releaseClientSslContextProvider(eq(mockSslContextProvider)); } - @Test - public void systemRootCertsWithMtls_callbackExecutedFromProvider() { - upstreamTlsContext = - CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( - "gcp_id", - "cert-default", - null, - "root-default", - null, - CertificateValidationContext.newBuilder() - .setSystemRootCerts( - CertificateValidationContext.SystemRootCerts.getDefaultInstance()) - .build()); - prepareSupplier(false); - - callUpdateSslContext(); - - verify(mockTlsContextManager, times(2)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); - verify(mockTlsContextManager, times(0)) - .releaseClientSslContextProvider(any(SslContextProvider.class)); - ArgumentCaptor callbackCaptor = - ArgumentCaptor.forClass(SslContextProvider.Callback.class); - verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); - SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); - assertThat(capturedCallback).isNotNull(); - SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSslContext(mockSslContext); - verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); - verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); - SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); - supplier.updateSslContext(mockCallback); - verify(mockTlsContextManager, times(3)) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); - } - - @Test - public void systemRootCertsWithRegularTls_callbackExecutedFromSupplier() { - upstreamTlsContext = - CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( - null, - null, - null, - "root-default", - null, - CertificateValidationContext.newBuilder() - .setSystemRootCerts( - CertificateValidationContext.SystemRootCerts.getDefaultInstance()) - .build()); - supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); - reset(mockTlsContextManager); - - callUpdateSslContext(); - ArgumentCaptor runnableArgumentCaptor = ArgumentCaptor.forClass(Runnable.class); - verify(mockExecutor).execute(runnableArgumentCaptor.capture()); - runnableArgumentCaptor.getValue().run(); - verify(mockCallback, times(1)).updateSslContext(any(SslContext.class)); - verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider)); - } - @Test public void testClose() { - prepareSupplier(true); + prepareSupplier(); callUpdateSslContext(); supplier.close(); verify(mockTlsContextManager, times(1)) @@ -193,7 +125,7 @@ public void testClose() { @Test public void testClose_nullSslContextProvider() { - prepareSupplier(true); + prepareSupplier(); doThrow(new NullPointerException()).when(mockTlsContextManager) .releaseClientSslContextProvider(null); supplier.close(); @@ -203,4 +135,4 @@ public void testClose_nullSslContextProvider() { verify(mockTlsContextManager, times(1)) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); } -} +} \ No newline at end of file diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index bfd4bd42211..3b2ca05231e 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -187,10 +187,7 @@ public void testProviderForClient_mtls() throws Exception { } @Test - // Note: This code flow will not really be invoked since {@link SslContextProviderSupplier} will - // shortcircuit creating the certificate provider and directly invoke the callback with the - // SslContext in this case. - public void testProviderForClient_systemRootCerts_regularTls() throws Exception { + public void testProviderForClient_systemRootCerts_regularTls() { final CertificateProvider.DistributorWatcher[] watcherCaptor = new CertificateProvider.DistributorWatcher[1]; TestCertificateProvider.createAndRegisterProviderProvider( @@ -210,7 +207,10 @@ public void testProviderForClient_systemRootCerts_regularTls() throws Exception assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContext()).isNotNull(); + TestCallback testCallback = + CommonTlsContextTestsUtil.getValueThruCallback(provider); + assertThat(testCallback.updatedSslContext).isEqualTo(provider.getSslContext()); assertThat(watcherCaptor[0]).isNull(); } From d2b722a9038fabe57e8b9bc4d6a9455f3dc7a67a Mon Sep 17 00:00:00 2001 From: Kannan J Date: Thu, 11 Sep 2025 06:34:31 +0000 Subject: [PATCH 26/51] Formatting changes. --- .../xds/internal/security/CommonTlsContextTestsUtil.java | 3 +++ .../internal/security/SslContextProviderSupplierTest.java | 8 ++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index 718695f3b3f..48814dece1d 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -229,6 +229,9 @@ private static CommonTlsContext.Builder addCertificateValidationContext( String rootInstanceName, String rootCertName, CertificateValidationContext staticCertValidationContext) { + if (staticCertValidationContext == null && rootInstanceName == null) { + return builder; + } CertificateValidationContext.Builder contextBuilder; if (staticCertValidationContext == null) { contextBuilder = CertificateValidationContext.newBuilder(); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index f5b462b250d..f476818297d 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -54,11 +54,11 @@ public class SslContextProviderSupplierTest { private void prepareSupplier() { upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); mockSslContextProvider = mock(SslContextProvider.class); doReturn(mockSslContextProvider) - .when(mockTlsContextManager) - .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); + .when(mockTlsContextManager) + .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); } @@ -135,4 +135,4 @@ public void testClose_nullSslContextProvider() { verify(mockTlsContextManager, times(1)) .findOrCreateClientSslContextProvider(eq(upstreamTlsContext)); } -} \ No newline at end of file +} From acb8fa5d790d86814d24b5adf22c7718529a4b60 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Thu, 11 Sep 2025 07:25:08 +0000 Subject: [PATCH 27/51] Merge with changes to not special case system root certs in SslContextProviderSupplier itself but have it be handled only in the ClientCertificateSslContextProviderSupplier. --- .../security/SslContextProviderSupplier.java | 45 ++++++------------- .../CertProviderClientSslContextProvider.java | 2 + .../SslContextProviderSupplierTest.java | 25 ----------- 3 files changed, 15 insertions(+), 57 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index 3f48b7f2cd1..26661eddfe2 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -16,11 +16,8 @@ package io.grpc.xds.internal.security; -import static com.google.common.base.Preconditions.checkNotNull; - import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; -import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; @@ -29,9 +26,10 @@ import java.util.HashSet; import java.util.Objects; -import javax.net.ssl.SSLException; import java.util.Set; +import static com.google.common.base.Preconditions.checkNotNull; + /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} * and communicate it to the consumer i.e. {@link SecurityProtocolNegotiators} @@ -67,39 +65,22 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call } } // we want to increment the ref-count so call findOrCreate again... - final SslContextProvider toRelease = getSslContextProvider(); + final SslContextProvider toRelease = getSslContextProvider(sni); toRelease.addCallback( new SslContextProvider.Callback(callback.getExecutor()) { - final SslContextProvider toRelease = getSslContextProvider(sni); - // When using system root certs on client side, SslContext updates via CertificateProvider is - // only required if Mtls is also enabled, i.e. tlsContext has a cert provider instance. - if (tlsContext instanceof UpstreamTlsContext - && !CommonTlsContextUtil.hasCertProviderInstance(tlsContext.getCommonTlsContext()) - && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext())) { - callback.getExecutor().execute(() -> { - try { - callback.updateSslContext(GrpcSslContexts.forClient().build()); + + @Override + public void updateSslContext(SslContext sslContext) { + callback.updateSslContext(sslContext); + releaseSslContextProvider(toRelease, sni); + } + + @Override + public void onException(Throwable throwable) { + callback.onException(throwable); releaseSslContextProvider(toRelease, sni); - } catch (SSLException e) { - callback.onException(e); } }); - } else { - toRelease.addCallback( - new SslContextProvider.Callback(callback.getExecutor()) { - - @Override - public void updateSslContext(SslContext sslContext) { - callback.updateSslContext(sslContext); - releaseSslContextProvider(toRelease, sni); - } - - @Override - public void onException(Throwable throwable) { - callback.onException(throwable); - releaseSslContextProvider(toRelease, sni); - } - }); } catch (final Throwable throwable) { callback.getExecutor().execute(new Runnable() { @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index be5c5d1c8c2..48ce4f6d541 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -33,6 +33,8 @@ /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { + private final String sniForSanMatching; + CertProviderClientSslContextProvider( Node node, @Nullable Map certProviders, diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index 3b77d17b370..f79d1299060 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -178,31 +178,6 @@ public void systemRootCertsWithMtls_callbackExecutedFromProvider() { .findOrCreateClientSslContextProvider(eq(upstreamTlsContext), eq(SNI)); } - @Test - public void systemRootCertsWithRegularTls_callbackExecutedFromSupplier() { - upstreamTlsContext = - CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( - null, - null, - null, - "root-default", - null, - CertificateValidationContext.newBuilder() - .setSystemRootCerts( - CertificateValidationContext.SystemRootCerts.getDefaultInstance()) - .build()); - supplier = new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager); - reset(mockTlsContextManager); - - callUpdateSslContext(); - ArgumentCaptor runnableArgumentCaptor = ArgumentCaptor.forClass(Runnable.class); - verify(mockExecutor).execute(runnableArgumentCaptor.capture()); - runnableArgumentCaptor.getValue().run(); - verify(mockCallback, times(1)).updateSslContext(any(SslContext.class)); - verify(mockTlsContextManager, times(1)) - .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); - } - @Test public void testClose() { prepareSupplier(true); From 13200faee8bd8446fe4dd390e9f8395d92afa0c4 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Thu, 11 Sep 2025 08:13:42 +0000 Subject: [PATCH 28/51] nit --- .../xds/internal/security/SslContextProviderSupplierTest.java | 1 - 1 file changed, 1 deletion(-) diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index f79d1299060..bac85686688 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -24,7 +24,6 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; -import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; From e95725d9334f23c98882fee954048ebe8427b9d7 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Thu, 11 Sep 2025 12:00:50 +0000 Subject: [PATCH 29/51] Trust manager handling for system root certs. --- .../CertProviderClientSslContextProvider.java | 40 +++++++++++++-- .../grpc/xds/XdsSecurityClientServerTest.java | 51 +++++++++++++++++-- 2 files changed, 83 insertions(+), 8 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index 79374ede827..7e6de5f1dd0 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -24,11 +24,22 @@ import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.internal.security.trust.XdsTrustManagerFactory; import io.netty.handler.ssl.SslContextBuilder; + +import java.io.IOException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; import java.security.cert.CertStoreException; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; import java.security.cert.X509Certificate; -import java.util.Map; +import java.util.*; +import java.util.stream.Collectors; import javax.annotation.Nullable; import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { @@ -56,7 +67,7 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP // Instantiate sslContext so that addCallback will immediately update the callback with // the SslContext. sslContext = getSslContextBuilder(staticCertificateValidationContext).build(); - } catch (SSLException | CertStoreException e) { + } catch (CertStoreException | CertificateException | IOException e) { throw new RuntimeException(e); } } @@ -65,7 +76,7 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP @Override protected final SslContextBuilder getSslContextBuilder( CertificateValidationContext certificateValidationContextdationContext) - throws CertStoreException { + throws CertificateException, IOException, CertStoreException { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); if (rootCertInstance != null) { if (savedSpiffeTrustMap != null) { @@ -79,10 +90,33 @@ protected final SslContextBuilder getSslContextBuilder( savedTrustedRoots.toArray(new X509Certificate[0]), certificateValidationContextdationContext)); } + } else { + try { + sslContextBuilder = sslContextBuilder.trustManager( + new XdsTrustManagerFactory( + getX509CertificatesFromSystemTrustStore(), + certificateValidationContextdationContext)); + } catch (KeyStoreException | NoSuchAlgorithmException e) { + throw new CertStoreException(e); + } } if (isMtls()) { sslContextBuilder.keyManager(savedKey, savedCertChain); } return sslContextBuilder; } + + private X509Certificate[] getX509CertificatesFromSystemTrustStore() throws KeyStoreException, CertificateException, IOException, NoSuchAlgorithmException { + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init((KeyStore) null); + + List trustManagers = Arrays.asList(trustManagerFactory.getTrustManagers()); + List rootCerts = trustManagers.stream() + .filter(X509TrustManager.class::isInstance) + .map(X509TrustManager.class::cast) + .map(trustManager -> Arrays.asList(trustManager.getAcceptedIssuers())) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + return rootCerts.toArray(new X509Certificate[rootCerts.size()]); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 23068d665bf..08557a508e2 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -37,6 +37,7 @@ import com.google.common.util.concurrent.SettableFuture; import io.envoyproxy.envoy.config.core.v3.SocketAddress.Protocol; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; import io.grpc.Grpc; @@ -117,6 +118,10 @@ @RunWith(Parameterized.class) public class XdsSecurityClientServerTest { + // TODO: Change this is a specific domain after + // https://github.com/grpc/grpc-java/issues/12326 is fixed + private static final String SAN_TO_MATCH = "*.test.google.fr"; + @Parameter public Boolean enableSpiffe; private Boolean originalEnableSpiffe; @@ -217,7 +222,7 @@ public void tlsClientServer_useSystemRootCerts_useCombinedValidationContext() th UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true); + CLIENT_PEM_FILE, true, SAN_TO_MATCH); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -244,7 +249,7 @@ public void tlsClientServer_useSystemRootCerts_validationContext() throws Except UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, false); + CLIENT_PEM_FILE, false, SAN_TO_MATCH); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -255,6 +260,39 @@ public void tlsClientServer_useSystemRootCerts_validationContext() throws Except } } + /** + * Use system root ca cert for TLS channel - no mTLS. + * Subj Alt Names to match are specified in the validaton context. + */ + @Test + public void tlsClientServer_useSystemRootCerts_failureToMatchSubjAltNames() throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, "server1.test.google.in"); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + fail("Expected handshake failure exception"); + } catch (StatusRuntimeException e) { + assertThat(e.getCause()).isInstanceOf(SSLHandshakeException.class); + assertThat(e.getCause().getCause()).isInstanceOf(CertificateException.class); + assertThat(e.getCause().getCause().getMessage()).isEqualTo( + "Peer certificate SAN check failed"); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + /** * Use system root ca cert for TLS channel - mTLS. * Uses common_tls_context.combined_validation_context in upstream_tls_context. @@ -266,12 +304,12 @@ public void tlsClientServer_useSystemRootCerts_requireClientAuth() throws Except setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); DownstreamTlsContext downstreamTlsContext = setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, - null, false, false); + null, false, true); buildServerWithTlsContext(downstreamTlsContext); UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true); + CLIENT_PEM_FILE, true, SAN_TO_MATCH); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -552,7 +590,7 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContext(String cli private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts( String clientKeyFile, String clientPemFile, - boolean useCombinedValidationContext) { + boolean useCombinedValidationContext, String sanToMatch) { bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile, CA_PEM_FILE, null, null, null, null, null); @@ -563,6 +601,9 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSys CertificateValidationContext.newBuilder() .setSystemRootCerts( CertificateValidationContext.SystemRootCerts.newBuilder().build()) + .addMatchSubjectAltNames( + StringMatcher.newBuilder() + .setExact(sanToMatch)) .build()); } return CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( From 180f373f84299005285131dddab557f469507010 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Thu, 11 Sep 2025 12:12:45 +0000 Subject: [PATCH 30/51] Fix style --- .../CertProviderClientSslContextProvider.java | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index 7e6de5f1dd0..b080417bd53 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -24,19 +24,19 @@ import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.internal.security.trust.XdsTrustManagerFactory; import io.netty.handler.ssl.SslContextBuilder; - import java.io.IOException; import java.security.KeyStore; import java.security.KeyStoreException; import java.security.NoSuchAlgorithmException; import java.security.cert.CertStoreException; -import java.security.cert.Certificate; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; -import java.util.*; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import javax.annotation.Nullable; -import javax.net.ssl.SSLException; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.X509TrustManager; @@ -106,8 +106,10 @@ protected final SslContextBuilder getSslContextBuilder( return sslContextBuilder; } - private X509Certificate[] getX509CertificatesFromSystemTrustStore() throws KeyStoreException, CertificateException, IOException, NoSuchAlgorithmException { - TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + private X509Certificate[] getX509CertificatesFromSystemTrustStore() + throws KeyStoreException, NoSuchAlgorithmException { + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); trustManagerFactory.init((KeyStore) null); List trustManagers = Arrays.asList(trustManagerFactory.getTrustManagers()); From 381beb2c53de5702d69241a5396748cd4c32da6e Mon Sep 17 00:00:00 2001 From: Kannan J Date: Thu, 11 Sep 2025 13:09:56 +0000 Subject: [PATCH 31/51] Fixes. --- .../CertProviderClientSslContextProvider.java | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index b080417bd53..1acdcad65e9 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -22,6 +22,7 @@ import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; +import io.grpc.xds.internal.security.CommonTlsContextUtil; import io.grpc.xds.internal.security.trust.XdsTrustManagerFactory; import io.netty.handler.ssl.SslContextBuilder; import java.io.IOException; @@ -60,9 +61,9 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP staticCertValidationContext, upstreamTlsContext, certificateProviderStore); - // Null rootCertInstance implies hasSystemRootCerts because of the check in - // CertProviderClientSslContextProviderFactory. - if (rootCertInstance == null && !isMtls()) { + if (rootCertInstance == null + && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext()) + && !isMtls()) { try { // Instantiate sslContext so that addCallback will immediately update the callback with // the SslContext. @@ -75,7 +76,7 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP @Override protected final SslContextBuilder getSslContextBuilder( - CertificateValidationContext certificateValidationContextdationContext) + CertificateValidationContext certificateValidationContext) throws CertificateException, IOException, CertStoreException { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); if (rootCertInstance != null) { @@ -83,19 +84,19 @@ protected final SslContextBuilder getSslContextBuilder( sslContextBuilder = sslContextBuilder.trustManager( new XdsTrustManagerFactory( savedSpiffeTrustMap, - certificateValidationContextdationContext)); + certificateValidationContext)); } else { sslContextBuilder = sslContextBuilder.trustManager( new XdsTrustManagerFactory( savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContextdationContext)); + certificateValidationContext)); } } else { try { sslContextBuilder = sslContextBuilder.trustManager( new XdsTrustManagerFactory( getX509CertificatesFromSystemTrustStore(), - certificateValidationContextdationContext)); + certificateValidationContext)); } catch (KeyStoreException | NoSuchAlgorithmException e) { throw new CertStoreException(e); } From 18f5d5ab0dd7a9abe8f51415818faec8446b098c Mon Sep 17 00:00:00 2001 From: Kannan J Date: Thu, 11 Sep 2025 13:50:52 +0000 Subject: [PATCH 32/51] Fix unit tests to cover both mtls and non-mtls for system root certs. --- .../grpc/xds/XdsSecurityClientServerTest.java | 41 +++++++++++++++---- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 08557a508e2..6b6c57ba61b 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -211,7 +211,8 @@ public void tlsClientServer_noClientAuthentication() throws Exception { * Uses common_tls_context.combined_validation_context in upstream_tls_context. */ @Test - public void tlsClientServer_useSystemRootCerts_useCombinedValidationContext() throws Exception { + public void tlsClientServer_useSystemRootCerts_noMtls_useCombinedValidationContext() + throws Exception { Path trustStoreFilePath = getCacertFilePathForTestCa(); try { setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); @@ -222,7 +223,7 @@ public void tlsClientServer_useSystemRootCerts_useCombinedValidationContext() th UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, SAN_TO_MATCH); + CLIENT_PEM_FILE, true, SAN_TO_MATCH, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -238,7 +239,7 @@ public void tlsClientServer_useSystemRootCerts_useCombinedValidationContext() th * Uses common_tls_context.validation_context in upstream_tls_context. */ @Test - public void tlsClientServer_useSystemRootCerts_validationContext() throws Exception { + public void tlsClientServer_useSystemRootCerts_noMtls_validationContext() throws Exception { Path trustStoreFilePath = getCacertFilePathForTestCa().toAbsolutePath(); try { setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); @@ -249,7 +250,7 @@ public void tlsClientServer_useSystemRootCerts_validationContext() throws Except UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, false, SAN_TO_MATCH); + CLIENT_PEM_FILE, false, SAN_TO_MATCH, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -260,6 +261,29 @@ public void tlsClientServer_useSystemRootCerts_validationContext() throws Except } } + @Test + public void tlsClientServer_useSystemRootCerts_mtls() throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, true); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, SAN_TO_MATCH, true); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + assertThat(unaryRpc(/* requestMessage= */ "buddy", blockingStub)).isEqualTo("Hello buddy"); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + /** * Use system root ca cert for TLS channel - no mTLS. * Subj Alt Names to match are specified in the validaton context. @@ -276,7 +300,7 @@ public void tlsClientServer_useSystemRootCerts_failureToMatchSubjAltNames() thro UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, "server1.test.google.in"); + CLIENT_PEM_FILE, true, "server1.test.google.in", false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -309,7 +333,7 @@ public void tlsClientServer_useSystemRootCerts_requireClientAuth() throws Except UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, SAN_TO_MATCH); + CLIENT_PEM_FILE, true, SAN_TO_MATCH, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -590,13 +614,14 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContext(String cli private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts( String clientKeyFile, String clientPemFile, - boolean useCombinedValidationContext, String sanToMatch) { + boolean useCombinedValidationContext, String sanToMatch, boolean isMtls) { bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile, CA_PEM_FILE, null, null, null, null, null); if (useCombinedValidationContext) { return CommonTlsContextTestsUtil.buildUpstreamTlsContextForCertProviderInstance( - "google_cloud_private_spiffe-client", "ROOT", null, + isMtls ? "google_cloud_private_spiffe-client" : null, + isMtls ? "ROOT" : null, null, null, null, CertificateValidationContext.newBuilder() .setSystemRootCerts( From e18d6cdddc9cfca6f5280a38b05a324a9714f105 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Fri, 12 Sep 2025 12:07:56 +0000 Subject: [PATCH 33/51] Suppress warning. --- xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java | 1 + 1 file changed, 1 insertion(+) diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 6b6c57ba61b..9d2f552db1c 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -611,6 +611,7 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContext(String cli .buildUpstreamTlsContext("google_cloud_private_spiffe-client", hasIdentityCert); } + @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts( String clientKeyFile, String clientPemFile, From 2ecbdb95e80ba296ff05f791839c6bbeb2921042 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Fri, 12 Sep 2025 13:05:47 +0000 Subject: [PATCH 34/51] Save changes. --- .../io/grpc/xds/GcpAuthenticationFilter.java | 2 - .../CertProviderClientSslContextProvider.java | 27 ++++++----- .../grpc/xds/XdsSecurityClientServerTest.java | 48 +++++++++++++++---- .../security/CommonTlsContextTestsUtil.java | 11 +++-- .../SecurityProtocolNegotiatorsTest.java | 4 +- 5 files changed, 63 insertions(+), 29 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java index 6eba753f716..edd6cd6a190 100644 --- a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java +++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java @@ -198,8 +198,6 @@ public ClientCall interceptCall( } else { callOptions = callOptions.withCallCredentials(newCallCredentials); } - logger.log(XdsLogLevel.INFO, "Time to expiry of the auth token=" + callOptions.getDeadline().timeRemaining( - TimeUnit.SECONDS)); return next.newCall(method, callOptions); } }; diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index fc27630ac3a..8345352ecf4 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -85,18 +85,23 @@ protected final SslContextBuilder getSslContextBuilder( if (rootCertInstance != null) { if (savedSpiffeTrustMap != null) { sslContextBuilder = sslContextBuilder.trustManager( - new XdsTrustManagerFactory( - savedSpiffeTrustMap, - certificateValidationContext, sniForSanMatching)); + new XdsTrustManagerFactory( + savedSpiffeTrustMap, + certificateValidationContext, sniForSanMatching)); } else { - try { - sslContextBuilder = sslContextBuilder.trustManager( - new XdsTrustManagerFactory( - getX509CertificatesFromSystemTrustStore(), - certificateValidationContext)); - } catch (KeyStoreException | NoSuchAlgorithmException e) { - throw new CertStoreException(e); - } + sslContextBuilder = sslContextBuilder.trustManager( + new XdsTrustManagerFactory( + savedTrustedRoots.toArray(new X509Certificate[0]), + certificateValidationContext, sniForSanMatching)); + } + } else { + try { + sslContextBuilder = sslContextBuilder.trustManager( + new XdsTrustManagerFactory( + getX509CertificatesFromSystemTrustStore(), + certificateValidationContext, sniForSanMatching)); + } catch (KeyStoreException | NoSuchAlgorithmException e) { + throw new CertStoreException(e); } } if (isMtls()) { diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 9657794c55d..df760c9fc67 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -118,7 +118,7 @@ @RunWith(Parameterized.class) public class XdsSecurityClientServerTest { - // TODO: Change this is a specific domain after + // TODO: Change this to a specific domain after // https://github.com/grpc/grpc-java/issues/12326 is fixed private static final String SAN_TO_MATCH = "*.test.google.fr"; @@ -223,7 +223,7 @@ public void tlsClientServer_useSystemRootCerts_noMtls_useCombinedValidationConte UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, SAN_TO_MATCH, false); + CLIENT_PEM_FILE, true, SAN_TO_MATCH, false, null); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -250,7 +250,7 @@ public void tlsClientServer_useSystemRootCerts_noMtls_validationContext() throws UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, false, SAN_TO_MATCH, false); + CLIENT_PEM_FILE, false, SAN_TO_MATCH, false, null); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -273,7 +273,7 @@ public void tlsClientServer_useSystemRootCerts_mtls() throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, SAN_TO_MATCH, true); + CLIENT_PEM_FILE, true, SAN_TO_MATCH, true, null); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -286,10 +286,11 @@ public void tlsClientServer_useSystemRootCerts_mtls() throws Exception { /** * Use system root ca cert for TLS channel - no mTLS. - * Subj Alt Names to match are specified in the validaton context. + * Subj Alt Names to match are specified in the validation context. */ @Test - public void tlsClientServer_useSystemRootCerts_failureToMatchSubjAltNames() throws Exception { + public void tlsClientServer_useSystemRootCerts_noAutoSniValidation_failureToMatchSubjAltNames() + throws Exception { Path trustStoreFilePath = getCacertFilePathForTestCa(); try { setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); @@ -300,7 +301,7 @@ public void tlsClientServer_useSystemRootCerts_failureToMatchSubjAltNames() thro UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, "server1.test.google.in", false); + CLIENT_PEM_FILE, true, "server1.test.google.in", false, null); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -317,6 +318,33 @@ public void tlsClientServer_useSystemRootCerts_failureToMatchSubjAltNames() thro } } + @Test + public void tlsClientServer_useSystemRootCerts_autoSniValidation() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, + // won't be used + "server1.test.google.in", + false, SAN_TO_MATCH); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + /** * Use system root ca cert for TLS channel - mTLS. * Uses common_tls_context.combined_validation_context in upstream_tls_context. @@ -333,7 +361,7 @@ public void tlsClientServer_useSystemRootCerts_requireClientAuth() throws Except UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, SAN_TO_MATCH, false); + CLIENT_PEM_FILE, true, SAN_TO_MATCH, false, null); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -615,7 +643,7 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContext(String cli private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts( String clientKeyFile, String clientPemFile, - boolean useCombinedValidationContext, String sanToMatch, boolean isMtls) { + boolean useCombinedValidationContext, String sanToMatch, boolean isMtls, String sni) { bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile, CA_PEM_FILE, null, null, null, null, null); @@ -630,7 +658,7 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSys .addMatchSubjectAltNames( StringMatcher.newBuilder() .setExact(sanToMatch)) - .build()); + .build(), sni, false); } return CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( "google_cloud_private_spiffe-client", "ROOT", null, diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index 21d5d21f457..b71e4b2a866 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -149,9 +149,12 @@ public static String getTempFileNameForResourcesFile(String resFile) throws IOEx * Helper method to build UpstreamTlsContext for above tests. Called from other classes as well. */ static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( - CommonTlsContext commonTlsContext, String sni, boolean autoHostSni) { + CommonTlsContext commonTlsContext, String sni, boolean autoHostSni, boolean autoSniSanValidation) { UpstreamTlsContext.Builder upstreamTlsContext = - UpstreamTlsContext.newBuilder().setCommonTlsContext(commonTlsContext).setAutoHostSni(autoHostSni); + UpstreamTlsContext.newBuilder() + .setCommonTlsContext(commonTlsContext) + .setAutoHostSni(autoHostSni) + .setAutoSniSanValidation(autoSniSanValidation); if (sni != null) { upstreamTlsContext.setSni(sni); } @@ -290,7 +293,7 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootCertName, alpnProtocols, staticCertValidationContext), - sni, autoHostSni); + sni, autoHostSni, false); } /** Helper method to build UpstreamTlsContext for CertProvider tests. */ @@ -309,7 +312,7 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootInstanceName, rootCertName, alpnProtocols, - staticCertValidationContext), null, false); + staticCertValidationContext), null, false, false); } /** Helper method to build DownstreamTlsContext for CertProvider tests. */ diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index 43d02a9feaa..a27871b9915 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -124,7 +124,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_noFallback_expectExceptio @Test public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() { UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build(), null, false); + CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build(), null, false, false); ClientSecurityProtocolNegotiator pn = new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext()); GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); @@ -146,7 +146,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() @Test public void clientSecurityProtocolNegotiatorNewHandler_autoHostSni_hostnameIsPassedToClientSecurityHandler() { UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build(), null, true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build(), null, true, false); ClientSecurityProtocolNegotiator pn = new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext()); GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); From 3845e16e9d31d2d87e3eaab53de2760ee7afdd26 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Fri, 12 Sep 2025 13:13:24 +0000 Subject: [PATCH 35/51] Use non wildcard SAN in the SAN matchers in validation context. --- .../test/java/io/grpc/xds/XdsSecurityClientServerTest.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 9d2f552db1c..03f1651f2c7 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -118,9 +118,7 @@ @RunWith(Parameterized.class) public class XdsSecurityClientServerTest { - // TODO: Change this is a specific domain after - // https://github.com/grpc/grpc-java/issues/12326 is fixed - private static final String SAN_TO_MATCH = "*.test.google.fr"; + private static final String SAN_TO_MATCH = "waterzooi.test.google.be"; @Parameter public Boolean enableSpiffe; From 0ca4f8b02790bbdbaeb4179dfc2c3c07c78619ae Mon Sep 17 00:00:00 2001 From: Kannan J Date: Mon, 15 Sep 2025 04:26:36 +0000 Subject: [PATCH 36/51] Save changes. --- .../io/grpc/xds/EnvoyServerProtoData.java | 10 ++++++-- .../CertProviderClientSslContextProvider.java | 2 +- .../grpc/xds/XdsSecurityClientServerTest.java | 25 ++++++++++++------- .../ClientSslContextProviderFactoryTest.java | 8 +++--- .../security/CommonTlsContextTestsUtil.java | 9 ++++--- ...tProviderClientSslContextProviderTest.java | 2 +- 6 files changed, 36 insertions(+), 20 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index b3d161b6bbb..3a30cf0aad2 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -96,7 +96,8 @@ public UpstreamTlsContext(io.envoyproxy.envoy.extensions.transport_sockets.tls.v public static UpstreamTlsContext fromEnvoyProtoUpstreamTlsContext( io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext) { - return new UpstreamTlsContext(upstreamTlsContext); + UpstreamTlsContext o = new UpstreamTlsContext(upstreamTlsContext); + return o; } public String getSni() { @@ -113,7 +114,12 @@ public boolean getAutoSniSanValidation() { @Override public String toString() { - return "UpstreamTlsContext{" + "commonTlsContext=" + commonTlsContext + '}'; + return "UpstreamTlsContext{" + + "commonTlsContext=" + commonTlsContext + + "sni=" + sni + + "\nauto_host_sni=" + auto_host_sni + + "\nauto_sni_san_validation=" + auto_sni_san_validation + + "}"; } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index 8345352ecf4..49e3048f3c9 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -63,6 +63,7 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP staticCertValidationContext, upstreamTlsContext, certificateProviderStore); + this.sniForSanMatching = upstreamTlsContext.getAutoSniSanValidation()? sniForSanMatching : null; if (rootCertInstance == null && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext()) && !isMtls()) { @@ -74,7 +75,6 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP throw new RuntimeException(e); } } - this.sniForSanMatching = upstreamTlsContext.getAutoSniSanValidation()? sniForSanMatching : null; } @Override diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 82e3feb3b90..03e6d4189b6 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -221,7 +221,7 @@ public void tlsClientServer_useSystemRootCerts_noMtls_useCombinedValidationConte UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, SAN_TO_MATCH, false, null); + CLIENT_PEM_FILE, true, SAN_TO_MATCH, false, null, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -248,7 +248,7 @@ public void tlsClientServer_useSystemRootCerts_noMtls_validationContext() throws UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, false, SAN_TO_MATCH, false, null); + CLIENT_PEM_FILE, false, SAN_TO_MATCH, false, null, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -271,7 +271,7 @@ public void tlsClientServer_useSystemRootCerts_mtls() throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, SAN_TO_MATCH, true, null); + CLIENT_PEM_FILE, true, SAN_TO_MATCH, true, null, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -299,7 +299,7 @@ public void tlsClientServer_useSystemRootCerts_noAutoSniValidation_failureToMatc UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, "server1.test.google.in", false, null); + CLIENT_PEM_FILE, true, "server1.test.google.in", false, null, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -330,9 +330,12 @@ public void tlsClientServer_useSystemRootCerts_autoSniValidation() UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, CLIENT_PEM_FILE, true, - // won't be used + // SAN matcher in CommonValidationContext. Will be overridden by autoSniSanValidation "server1.test.google.in", - false, SAN_TO_MATCH); + false, + // SNI in UpstreamTlsContext + SAN_TO_MATCH, + true); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -359,7 +362,7 @@ public void tlsClientServer_useSystemRootCerts_requireClientAuth() throws Except UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, SAN_TO_MATCH, false, null); + CLIENT_PEM_FILE, true, SAN_TO_MATCH, false, null, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -641,7 +644,11 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContext(String cli private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts( String clientKeyFile, String clientPemFile, - boolean useCombinedValidationContext, String sanToMatch, boolean isMtls, String sni) { + boolean useCombinedValidationContext, + String sanToMatch, + boolean isMtls, + String sniInUpstreamTlsContext, + boolean autoSniSanValidation) { bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile, CA_PEM_FILE, null, null, null, null, null); @@ -656,7 +663,7 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSys .addMatchSubjectAltNames( StringMatcher.newBuilder() .setExact(sanToMatch)) - .build(), sni, false); + .build(), sniInUpstreamTlsContext, false, autoSniSanValidation); } return CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( "google_cloud_private_spiffe-client", "ROOT", null, diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java index aef1e79061c..1adfed5cb38 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java @@ -79,7 +79,7 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio "gcp_id", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null, null, false); + /* staticCertValidationContext= */ null, null, false, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = @@ -138,7 +138,7 @@ public void createCertProviderClientSslContextProvider_onlyRootCert() "gcp_id", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null, null, false); + /* staticCertValidationContext= */ null, null, false, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = @@ -172,7 +172,7 @@ public void createCertProviderClientSslContextProvider_withStaticContext() "gcp_id", "root-default", /* alpnProtocols= */ null, - staticCertValidationContext, null, false); + staticCertValidationContext, null, false, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = @@ -202,7 +202,7 @@ public void createCertProviderClientSslContextProvider_2providers() "file_provider", "root-default", /* alpnProtocols= */ null, - /* staticCertValidationContext= */ null, null, false); + /* staticCertValidationContext= */ null, null, false, false); Bootstrapper.BootstrapInfo bootstrapInfo = CommonBootstrapperTestUtils.getTestBootstrapInfo(); clientSslContextProviderFactory = diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index b71e4b2a866..539ede1d60f 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -173,7 +173,7 @@ public static EnvoyServerProtoData.UpstreamTlsContext buildUpstreamTlsContext( null, null, sni, - autoHostSni); + autoHostSni, false); } /** Gets a cert from contents of a resource. */ @@ -284,7 +284,10 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( @Nullable String rootInstanceName, @Nullable String rootCertName, Iterable alpnProtocols, - CertificateValidationContext staticCertValidationContext, String sni, boolean autoHostSni) { + CertificateValidationContext staticCertValidationContext, + String sni, + boolean autoHostSni, + boolean autoSniSanValidation) { return buildUpstreamTlsContext( buildCommonTlsContextForCertProviderInstance( certInstanceName, @@ -293,7 +296,7 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( rootCertName, alpnProtocols, staticCertValidationContext), - sni, autoHostSni, false); + sni, autoHostSni, autoSniSanValidation); } /** Helper method to build UpstreamTlsContext for CertProvider tests. */ diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index 8bbfdfe2098..dc1ab9aae5f 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -94,7 +94,7 @@ private CertProviderClientSslContextProvider getSslContextProvider( "root-default", alpnProtocols, staticCertValidationContext, - null, false); + null, false, false); } return (CertProviderClientSslContextProvider) certProviderClientSslContextProviderFactory.getProvider( From 92f3182217f00a10fcc6d03595991c518e0cbea7 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Tue, 16 Sep 2025 04:56:15 +0000 Subject: [PATCH 37/51] Save changes. --- .../grpc/xds/XdsSecurityClientServerTest.java | 118 ++++++++++++++---- 1 file changed, 95 insertions(+), 23 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 03e6d4189b6..bc437cf3812 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -118,7 +118,7 @@ @RunWith(Parameterized.class) public class XdsSecurityClientServerTest { - private static final String SAN_TO_MATCH = "waterzooi.test.google.be"; + private static final String SNI_IN_UTC = "waterzooi.test.google.be"; @Parameter public Boolean enableSpiffe; @@ -221,7 +221,7 @@ public void tlsClientServer_useSystemRootCerts_noMtls_useCombinedValidationConte UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, SAN_TO_MATCH, false, null, false); + CLIENT_PEM_FILE, true, SNI_IN_UTC, false, null, false, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -248,7 +248,7 @@ public void tlsClientServer_useSystemRootCerts_noMtls_validationContext() throws UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, false, SAN_TO_MATCH, false, null, false); + CLIENT_PEM_FILE, false, SNI_IN_UTC, false, null, false, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -271,7 +271,7 @@ public void tlsClientServer_useSystemRootCerts_mtls() throws Exception { UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, SAN_TO_MATCH, true, null, false); + CLIENT_PEM_FILE, true, SNI_IN_UTC, true, null, false, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -287,7 +287,7 @@ public void tlsClientServer_useSystemRootCerts_mtls() throws Exception { * Subj Alt Names to match are specified in the validation context. */ @Test - public void tlsClientServer_useSystemRootCerts_noAutoSniValidation_failureToMatchSubjAltNames() + public void tlsClientServer_noAutoSniValidation_failureToMatchSubjAltNames() throws Exception { Path trustStoreFilePath = getCacertFilePathForTestCa(); try { @@ -299,7 +299,7 @@ public void tlsClientServer_useSystemRootCerts_noAutoSniValidation_failureToMatc UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, "server1.test.google.in", false, null, false); + CLIENT_PEM_FILE, true, "server1.test.google.in", false, null, false, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -317,7 +317,7 @@ public void tlsClientServer_useSystemRootCerts_noAutoSniValidation_failureToMatc } @Test - public void tlsClientServer_useSystemRootCerts_autoSniValidation() + public void tlsClientServer_autoSniValidation_sniInUTC() throws Exception { Path trustStoreFilePath = getCacertFilePathForTestCa(); try { @@ -333,9 +333,69 @@ public void tlsClientServer_useSystemRootCerts_autoSniValidation() // SAN matcher in CommonValidationContext. Will be overridden by autoSniSanValidation "server1.test.google.in", false, - // SNI in UpstreamTlsContext - SAN_TO_MATCH, - true); + SNI_IN_UTC, + false, true); + + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + @Test + public void tlsClientServer_sni_san_validation_from_hostname() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, + // SAN matcher in CommonValidationContext. Will be overridden by autoSniSanValidation + "server1.test.google.in", + false, + "", + true, true); + + // TODO: Change this to foo.test.gooogle.fr that needs wildcard matching after + // https://github.com/grpc/grpc-java/pull/12345 is done + SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = + getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY, + "waterzooi.test.google.be"); + unaryRpc(/* requestMessage= */ "buddy", blockingStub); + } finally { + Files.deleteIfExists(trustStoreFilePath); + clearTrustStoreSystemProperties(); + } + } + + @Test + public void tlsClientServer_autoSniValidation_noSNIApplicable_usesMatcherFromCmnVdnCtx() + throws Exception { + Path trustStoreFilePath = getCacertFilePathForTestCa(); + try { + setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); + DownstreamTlsContext downstreamTlsContext = + setBootstrapInfoAndBuildDownstreamTlsContext(SERVER_1_PEM_FILE, null, null, null, null, + null, false, false); + buildServerWithTlsContext(downstreamTlsContext); + + UpstreamTlsContext upstreamTlsContext = + setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, + CLIENT_PEM_FILE, true, + // This is what will get used for the SAN validation since no SNI was used + "waterzooi.test.google.be", + false, + "", + false, true); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -362,7 +422,7 @@ public void tlsClientServer_useSystemRootCerts_requireClientAuth() throws Except UpstreamTlsContext upstreamTlsContext = setBootstrapInfoAndBuildUpstreamTlsContextForUsingSystemRootCerts(CLIENT_KEY_FILE, - CLIENT_PEM_FILE, true, SAN_TO_MATCH, false, null, false); + CLIENT_PEM_FILE, true, SNI_IN_UTC, false, null, false, false); SimpleServiceGrpc.SimpleServiceBlockingStub blockingStub = getBlockingStub(upstreamTlsContext, /* overrideAuthority= */ OVERRIDE_AUTHORITY); @@ -648,7 +708,7 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSys String sanToMatch, boolean isMtls, String sniInUpstreamTlsContext, - boolean autoSniSanValidation) { + boolean autoHostSni, boolean autoSniSanValidation) { bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", clientKeyFile, clientPemFile, CA_PEM_FILE, null, null, null, null, null); @@ -663,7 +723,7 @@ private UpstreamTlsContext setBootstrapInfoAndBuildUpstreamTlsContextForUsingSys .addMatchSubjectAltNames( StringMatcher.newBuilder() .setExact(sanToMatch)) - .build(), sniInUpstreamTlsContext, false, autoSniSanValidation); + .build(), sniInUpstreamTlsContext, autoHostSni, autoSniSanValidation); } return CommonTlsContextTestsUtil.buildNewUpstreamTlsContextForCertProviderInstance( "google_cloud_private_spiffe-client", "ROOT", null, @@ -748,8 +808,18 @@ static EnvoyServerProtoData.Listener buildListener( } private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( - final UpstreamTlsContext upstreamTlsContext, String overrideAuthority) - throws URISyntaxException { + final UpstreamTlsContext upstreamTlsContext, String overrideAuthority) { + return getBlockingStub(upstreamTlsContext, overrideAuthority, overrideAuthority); + } + + // Two separate parameters for overrideAuthority and addrAttribute is for the SAN SNI validation test + // tlsClientServer_useSystemRootCerts_sni_san_validation_from_hostname that uses hostname passed for SNI. + // foo.test.google.fr is used for virtual host matching via authority but it can't be used + // for SNI in this testcase because foo.test.google.fr needs wildcard matching to match against *.test.google.fr + // in the certificate SNI, which isn't implemented yet (https://github.com/grpc/grpc-java/pull/12345 implements it) + // so use an exact match SAN such as waterzooi.test.google.be for SNI for this testcase. + private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( + final UpstreamTlsContext upstreamTlsContext, String overrideAuthority, String addrAttribute) { ManagedChannelBuilder channelBuilder = Grpc.newChannelBuilder( "sectest://localhost:" + port, @@ -761,14 +831,16 @@ private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( InetSocketAddress socketAddress = new InetSocketAddress(Inet4Address.getLoopbackAddress(), port); tlsContextManagerForClient = new TlsContextManagerImpl(bootstrapInfoForClient); - sslContextAttributes = - (upstreamTlsContext != null) - ? Attributes.newBuilder() - .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, - new SslContextProviderSupplier( - upstreamTlsContext, tlsContextManagerForClient)) - .build() - : Attributes.EMPTY; + Attributes.Builder sslContextAttributesBuilder = (upstreamTlsContext != null) + ? Attributes.newBuilder() + .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + new SslContextProviderSupplier( + upstreamTlsContext, tlsContextManagerForClient)) + : Attributes.newBuilder(); + if (addrAttribute != null) { + sslContextAttributesBuilder.set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, addrAttribute); + } + sslContextAttributes = sslContextAttributesBuilder.build(); fakeNameResolverFactory.setServers( ImmutableList.of(new EquivalentAddressGroup(socketAddress, sslContextAttributes))); return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(channelBuilder.build())); From 2f5ba5d3c0564d8b1dc4de217e121624700c3dc5 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Wed, 17 Sep 2025 05:13:12 +0000 Subject: [PATCH 38/51] Save changes. --- .../netty/InternalProtocolNegotiators.java | 14 +- .../io/grpc/netty/NettyChannelBuilder.java | 2 +- .../io/grpc/netty/ProtocolNegotiators.java | 16 +- .../grpc/netty/NettyClientTransportTest.java | 2 +- .../grpc/netty/ProtocolNegotiatorsTest.java | 2 +- .../S2AProtocolNegotiatorFactory.java | 3 +- .../security/SecurityProtocolNegotiators.java | 17 +- .../security/trust/XdsX509TrustManager.java | 5 +- .../grpc/xds/XdsSecurityClientServerTest.java | 6 +- .../SecurityProtocolNegotiatorsTest.java | 280 ++++++++++-------- .../trust/XdsX509TrustManagerTest.java | 2 +- 11 files changed, 188 insertions(+), 161 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index c0e22b7483a..ed9af8e237f 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -38,14 +38,16 @@ private InternalProtocolNegotiators() {} * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. + * * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks + * @param isXdsTarget */ public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool, - Optional handshakeCompleteRunnable, - String sni) { + ObjectPool executorPool, + Optional handshakeCompleteRunnable, + String sni, boolean isXdsTarget) { final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, - executorPool, handshakeCompleteRunnable, null, sni); + executorPool, handshakeCompleteRunnable, null, sni, isXdsTarget); final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { @Override @@ -73,8 +75,8 @@ public void close() { * may happen immediately, even before the TLS Handshake is complete. */ public static InternalProtocolNegotiator.ProtocolNegotiator tls( - SslContext sslContext, String sni) { - return tls(sslContext, null, Optional.absent(), sni); + SslContext sslContext, String sni, boolean isXdsTarget) { + return tls(sslContext, null, Optional.absent(), sni, isXdsTarget); } /** diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index 2db5ab20a91..b659dc7d9a9 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -652,7 +652,7 @@ static ProtocolNegotiator createProtocolNegotiatorByType( case PLAINTEXT_UPGRADE: return ProtocolNegotiators.plaintextUpgrade(); case TLS: - return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null, null); + return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null, null, false); default: throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType); } diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index b0ec651735b..06d810fcbda 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -21,7 +21,6 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Optional; import com.google.common.base.Preconditions; -import com.google.common.base.Strings; import com.google.errorprone.annotations.ForOverride; import io.grpc.Attributes; import io.grpc.CallCredentials; @@ -583,7 +582,7 @@ public ClientTlsProtocolNegotiator(SslContext sslContext, ObjectPool executorPool, Optional handshakeCompleteRunnable, X509TrustManager x509ExtendedTrustManager, - String sni) { + String sni, boolean isXdsTarget) { this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); this.executorPool = executorPool; if (this.executorPool != null) { @@ -592,6 +591,7 @@ public ClientTlsProtocolNegotiator(SslContext sslContext, this.handshakeCompleteRunnable = handshakeCompleteRunnable; this.x509ExtendedTrustManager = x509ExtendedTrustManager; this.sni = sni; + this.isXdsTarget = isXdsTarget; } private final SslContext sslContext; @@ -599,6 +599,7 @@ public ClientTlsProtocolNegotiator(SslContext sslContext, private final Optional handshakeCompleteRunnable; private final X509TrustManager x509ExtendedTrustManager; private final String sni; + private final boolean isXdsTarget; private Executor executor; @Override @@ -611,7 +612,7 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { ChannelHandler gnh = new GrpcNegotiationHandler(grpcHandler); ChannelLogger negotiationLogger = grpcHandler.getNegotiationLogger(); ChannelHandler cth = new ClientTlsHandler(gnh, sslContext, - !Strings.isNullOrEmpty(sni) ? sni : grpcHandler.getAuthority(), + isXdsTarget ? sni : grpcHandler.getAuthority(), this.executor, negotiationLogger, handshakeCompleteRunnable, null, x509ExtendedTrustManager); return new WaitUntilActiveHandler(cth, negotiationLogger); @@ -753,13 +754,14 @@ static HostPort parseAuthority(String authority) { * Returns a {@link ProtocolNegotiator} that ensures the pipeline is set up so that TLS will * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. + * * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool, Optional handshakeCompleteRunnable, - X509TrustManager x509ExtendedTrustManager, String sni) { + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager, String sni, boolean isXdsTarget) { return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable, - x509ExtendedTrustManager, sni); + x509ExtendedTrustManager, sni, isXdsTarget); } /** @@ -769,7 +771,7 @@ public static ProtocolNegotiator tls(SslContext sslContext, */ public static ProtocolNegotiator tls(SslContext sslContext, X509TrustManager x509ExtendedTrustManager) { - return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager, null); + return tls(sslContext, null, Optional.absent(), x509ExtendedTrustManager, null, false); } public static ProtocolNegotiator.ClientFactory tlsClientFactory(SslContext sslContext, diff --git a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java index 0cabbc0428f..d983d81293b 100644 --- a/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java +++ b/netty/src/test/java/io/grpc/netty/NettyClientTransportTest.java @@ -836,7 +836,7 @@ public void tlsNegotiationServerExecutorShouldSucceed() throws Exception { .keyManager(clientCert, clientKey) .build(); ProtocolNegotiator negotiator = ProtocolNegotiators.tls(clientContext, clientExecutorPool, - Optional.absent(), null, null); + Optional.absent(), null, null, false); // after starting the client, the Executor in the client pool should be used assertEquals(true, clientExecutorPool.isInUse()); final NettyClientTransport transport = newTransport(negotiator); diff --git a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java index b62b3e57a7e..1e1fa07c228 100644 --- a/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java +++ b/netty/src/test/java/io/grpc/netty/ProtocolNegotiatorsTest.java @@ -1271,7 +1271,7 @@ public void clientTlsHandler_firesNegotiation() throws Exception { } FakeGrpcHttp2ConnectionHandler gh = FakeGrpcHttp2ConnectionHandler.newHandler(); ClientTlsProtocolNegotiator pn = new ClientTlsProtocolNegotiator(clientSslContext, - null, Optional.absent(), null, null); + null, Optional.absent(), null, null, false); WriteBufferingAndExceptionHandler clientWbaeh = new WriteBufferingAndExceptionHandler(pn.newHandler(gh)); diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java index 81aca657e40..cbcad749109 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java @@ -38,7 +38,6 @@ import io.grpc.netty.InternalProtocolNegotiator.ProtocolNegotiator; import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.InternalProtocolNegotiators.ProtocolNegotiationHandler; -import io.grpc.s2a.internal.handshaker.S2AIdentity; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; @@ -260,7 +259,7 @@ public void run() { s2aStub.close(); } }), - null) + null, false) .newHandler(grpcHandler); // Delegate the rest of the handshake to the TLS handler. and remove the diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index f9cc329e541..75692716b65 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -54,9 +54,6 @@ @VisibleForTesting public final class SecurityProtocolNegotiators { - static boolean useChannelAuthorityIfNoSniApplicable = - GrpcUtil.getFlag("GRPC_USE_CHANNEL_AUTHORITY_IF_NO_SNI_APPLICABLE", false); - /** Name associated with individual address, if available (e.g., DNS name). */ @EquivalentAddressGroup.Attr public static final Attributes.Key ATTR_ADDRESS_NAME = @@ -196,6 +193,8 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { @VisibleForTesting static final class ClientSecurityHandler extends InternalProtocolNegotiators.ProtocolNegotiationHandler { + static boolean isXdsSniEnabled = GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_SNI", false); + private final GrpcHttp2ConnectionHandler grpcHandler; private final SslContextProviderSupplier sslContextProviderSupplier; private final String sni; @@ -219,12 +218,12 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { this.sslContextProviderSupplier = sslContextProviderSupplier; EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); UpstreamTlsContext upstreamTlsContext = ((UpstreamTlsContext) tlsContext); - String sniVal = upstreamTlsContext.getAutoHostSni() && !Strings.isNullOrEmpty(endpointHostname) - ? endpointHostname : upstreamTlsContext.getSni(); - if (Strings.isNullOrEmpty(sniVal) && useChannelAuthorityIfNoSniApplicable) { - sniVal = grpcHandler.getAuthority(); + if (isXdsSniEnabled) { + sni = upstreamTlsContext.getAutoHostSni() && !Strings.isNullOrEmpty(endpointHostname) + ? endpointHostname : upstreamTlsContext.getSni(); + } else { + sni = grpcHandler.getAuthority(); } - sni = sniVal; } @VisibleForTesting @@ -250,7 +249,7 @@ public void updateSslContext(SslContext sslContext) { "ClientSecurityHandler.updateSslContext authority={0}, ctx.name={1}", new Object[]{grpcHandler.getAuthority(), ctx.name()}); ChannelHandler handler = - InternalProtocolNegotiators.tls(sslContext, sni).newHandler(grpcHandler); + InternalProtocolNegotiators.tls(sslContext, sni, true).newHandler(grpcHandler); // Delegate rest of handshake to TLS handler ctx.pipeline().addAfter(ctx.name(), null, handler); diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java index b32f0821b60..970381f4bb5 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java @@ -27,6 +27,7 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.grpc.internal.GrpcUtil; import io.grpc.internal.SpiffeUtil; import java.net.Socket; import java.security.cert.CertificateException; @@ -52,6 +53,8 @@ */ final class XdsX509TrustManager extends X509ExtendedTrustManager implements X509TrustManager { + static boolean isXdsSniEnabled = GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_SNI", false); + // ref: io.grpc.okhttp.internal.OkHostnameVerifier and // sun.security.x509.GeneralNameInterface private static final int ALT_DNS_NAME = 2; @@ -217,7 +220,7 @@ void verifySubjectAltNameInChain(X509Certificate[] peerCertChain) throws Certifi return; } @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names - List verifyList = !Strings.isNullOrEmpty(sniForSanMatching) + List verifyList = isXdsSniEnabled && !Strings.isNullOrEmpty(sniForSanMatching) ? ImmutableList.of(StringMatcher.newBuilder().setExact(sniForSanMatching).build()) : certContext.getMatchSubjectAltNamesList(); if (verifyList.isEmpty()) { diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index bc437cf3812..ac5f4e92b2a 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -282,10 +282,6 @@ public void tlsClientServer_useSystemRootCerts_mtls() throws Exception { } } - /** - * Use system root ca cert for TLS channel - no mTLS. - * Subj Alt Names to match are specified in the validation context. - */ @Test public void tlsClientServer_noAutoSniValidation_failureToMatchSubjAltNames() throws Exception { @@ -346,7 +342,7 @@ public void tlsClientServer_autoSniValidation_sniInUTC() } @Test - public void tlsClientServer_sni_san_validation_from_hostname() + public void tlsClientServer_autoSniValidation_sniFromHostname() throws Exception { Path trustStoreFilePath = getCacertFilePathForTestCa(); try { diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index a27871b9915..8d92c5fbb31 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -145,26 +145,31 @@ public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() @Test public void clientSecurityProtocolNegotiatorNewHandler_autoHostSni_hostnameIsPassedToClientSecurityHandler() { - UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build(), null, true, false); - ClientSecurityProtocolNegotiator pn = - new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext()); - GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); - ChannelLogger logger = mock(ChannelLogger.class); - doNothing().when(logger).log(any(ChannelLogLevel.class), anyString()); - when(mockHandler.getNegotiationLogger()).thenReturn(logger); - TlsContextManager mockTlsContextManager = mock(TlsContextManager.class); - when(mockHandler.getEagAttributes()) - .thenReturn( - Attributes.newBuilder() - .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, - new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager)) - .set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, FAKE_AUTHORITY) - .build()); - ChannelHandler newHandler = pn.newHandler(mockHandler); - assertThat(newHandler).isNotNull(); - assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class); - assertThat(((ClientSecurityHandler) newHandler).getSni()).isEqualTo(FAKE_AUTHORITY); + ClientSecurityHandler.isXdsSniEnabled = true; + try { + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build(), null, true, false); + ClientSecurityProtocolNegotiator pn = + new ClientSecurityProtocolNegotiator(InternalProtocolNegotiators.plaintext()); + GrpcHttp2ConnectionHandler mockHandler = mock(GrpcHttp2ConnectionHandler.class); + ChannelLogger logger = mock(ChannelLogger.class); + doNothing().when(logger).log(any(ChannelLogLevel.class), anyString()); + when(mockHandler.getNegotiationLogger()).thenReturn(logger); + TlsContextManager mockTlsContextManager = mock(TlsContextManager.class); + when(mockHandler.getEagAttributes()) + .thenReturn( + Attributes.newBuilder() + .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, + new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager)) + .set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, FAKE_AUTHORITY) + .build()); + ChannelHandler newHandler = pn.newHandler(mockHandler); + assertThat(newHandler).isNotNull(); + assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class); + assertThat(((ClientSecurityHandler) newHandler).getSni()).isEqualTo(FAKE_AUTHORITY); + } finally { + ClientSecurityHandler.isXdsSniEnabled = false; + } } @Test @@ -222,66 +227,105 @@ protected void onException(Throwable throwable) { @Test public void sniInClientSecurityHandler_autoHostSniIsTrue_usesEndpointHostname() { - Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils - .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null, null); - UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, true); - SslContextProviderSupplier sslContextProviderSupplier = - new SslContextProviderSupplier(upstreamTlsContext, - new TlsContextManagerImpl(bootstrapInfoForClient)); + ClientSecurityHandler.isXdsSniEnabled = true; + try { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); - ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); - assertThat(clientSecurityHandler.getSni()).isEqualTo(HOSTNAME); + assertThat(clientSecurityHandler.getSni()).isEqualTo(HOSTNAME); + } finally { + ClientSecurityHandler.isXdsSniEnabled = false; + } } @Test public void sniInClientSecurityHandler_autoHostSniIsTrue_endpointHostnameIsEmpty_usesSniFromUpstreamTlsContext() { - Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils - .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null, null); - UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, SNI_IN_UTC, true); - SslContextProviderSupplier sslContextProviderSupplier = - new SslContextProviderSupplier(upstreamTlsContext, - new TlsContextManagerImpl(bootstrapInfoForClient)); + ClientSecurityHandler.isXdsSniEnabled = true; + try { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, SNI_IN_UTC, true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); - ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, ""); + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, ""); - assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } finally { + ClientSecurityHandler.isXdsSniEnabled = false; + } } @Test public void sniInClientSecurityHandler_autoHostSniIsTrue_endpointHostnameIsNull_usesSniFromUpstreamTlsContext() { - Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils - .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null, null); - UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, SNI_IN_UTC, true); - SslContextProviderSupplier sslContextProviderSupplier = - new SslContextProviderSupplier(upstreamTlsContext, - new TlsContextManagerImpl(bootstrapInfoForClient)); + ClientSecurityHandler.isXdsSniEnabled = true; + try { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, SNI_IN_UTC, true); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); - ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, null); + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, null); - assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } finally { + ClientSecurityHandler.isXdsSniEnabled = false; + } } @Test public void sniInClientSecurityHandler_autoHostSniIsFalse_usesSniFromUpstreamTlsContext() { + ClientSecurityHandler.isXdsSniEnabled = true; + try { + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, SNI_IN_UTC, false); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + + assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); + } finally { + ClientSecurityHandler.isXdsSniEnabled = false; + } + } + + @Test + public void sniFeatureNotEnabled_usesChannelAuthorityForSni() { + ClientSecurityHandler.isXdsSniEnabled = false; Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, SNI_IN_UTC, false); + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, "", false); SslContextProviderSupplier sslContextProviderSupplier = new SslContextProviderSupplier(upstreamTlsContext, new TlsContextManagerImpl(bootstrapInfoForClient)); @@ -289,30 +333,7 @@ public void sniInClientSecurityHandler_autoHostSniIsFalse_usesSniFromUpstreamTls ClientSecurityHandler clientSecurityHandler = new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); - assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); - } - - @Test - public void emptySni_useChannelAuthorityIfNoSniApplicableIsTrue_usesChannelAuthority() { - SecurityProtocolNegotiators.useChannelAuthorityIfNoSniApplicable = true; - try { - Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils - .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null, null); - UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, "", false); - SslContextProviderSupplier sslContextProviderSupplier = - new SslContextProviderSupplier(upstreamTlsContext, - new TlsContextManagerImpl(bootstrapInfoForClient)); - - ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); - - assertThat(clientSecurityHandler.getSni()).isEqualTo(FAKE_AUTHORITY); - } finally { - SecurityProtocolNegotiators.useChannelAuthorityIfNoSniApplicable = false; - } + assertThat(clientSecurityHandler.getSni()).isEqualTo(FAKE_AUTHORITY); } @Test @@ -477,53 +498,58 @@ public void nullTlsContext_nullFallbackProtocolNegotiator_expectException() { @Test public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() throws InterruptedException, TimeoutException, ExecutionException { - FakeClock executor = new FakeClock(); - CommonCertProviderTestUtils.register(executor); - Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils - .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, - CA_PEM_FILE, null, null, null, null, null); - UpstreamTlsContext upstreamTlsContext = - CommonTlsContextTestsUtil - .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, false); - - SslContextProviderSupplier sslContextProviderSupplier = - new SslContextProviderSupplier(upstreamTlsContext, - new TlsContextManagerImpl(bootstrapInfoForClient)); - ClientSecurityHandler clientSecurityHandler = - new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); - - pipeline.addLast(clientSecurityHandler); - channelHandlerCtx = pipeline.context(clientSecurityHandler); - assertNotNull(channelHandlerCtx); // non-null since we just added it - - // kick off protocol negotiation. - pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); - final SettableFuture future = SettableFuture.create(); - sslContextProviderSupplier - .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { - @Override - public void updateSslContext(SslContext sslContext) { - future.set(sslContext); - } - - @Override - protected void onException(Throwable throwable) { - future.set(throwable); - } - }, null); - executor.runDueTasks(); - channel.runPendingTasks(); // need this for tasks to execute on eventLoop - Object fromFuture = future.get(5, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); - channel.runPendingTasks(); - channelHandlerCtx = pipeline.context(clientSecurityHandler); - assertThat(channelHandlerCtx).isNull(); - Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; + ClientSecurityHandler.isXdsSniEnabled = true; + try { + FakeClock executor = new FakeClock(); + CommonCertProviderTestUtils.register(executor); + Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils + .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, + CA_PEM_FILE, null, null, null, null, null); + UpstreamTlsContext upstreamTlsContext = + CommonTlsContextTestsUtil + .buildUpstreamTlsContext("google_cloud_private_spiffe-client", true, null, false); - pipeline.fireUserEventTriggered(sslEvent); - channel.runPendingTasks(); // need this for tasks to execute on eventLoop - assertTrue(channel.isOpen()); - CommonCertProviderTestUtils.register0(); + SslContextProviderSupplier sslContextProviderSupplier = + new SslContextProviderSupplier(upstreamTlsContext, + new TlsContextManagerImpl(bootstrapInfoForClient)); + ClientSecurityHandler clientSecurityHandler = + new ClientSecurityHandler(grpcHandler, sslContextProviderSupplier, HOSTNAME); + + pipeline.addLast(clientSecurityHandler); + channelHandlerCtx = pipeline.context(clientSecurityHandler); + assertNotNull(channelHandlerCtx); // non-null since we just added it + + // kick off protocol negotiation. + pipeline.fireUserEventTriggered(InternalProtocolNegotiationEvent.getDefault()); + final SettableFuture future = SettableFuture.create(); + sslContextProviderSupplier + .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { + @Override + public void updateSslContext(SslContext sslContext) { + future.set(sslContext); + } + + @Override + protected void onException(Throwable throwable) { + future.set(throwable); + } + }, null); + executor.runDueTasks(); + channel.runPendingTasks(); // need this for tasks to execute on eventLoop + Object fromFuture = future.get(5, TimeUnit.SECONDS); + assertThat(fromFuture).isInstanceOf(SslContext.class); + channel.runPendingTasks(); + channelHandlerCtx = pipeline.context(clientSecurityHandler); + assertThat(channelHandlerCtx).isNull(); + Object sslEvent = SslHandshakeCompletionEvent.SUCCESS; + + pipeline.fireUserEventTriggered(sslEvent); + channel.runPendingTasks(); // need this for tasks to execute on eventLoop + assertTrue(channel.isOpen()); + CommonCertProviderTestUtils.register0(); + } finally { + ClientSecurityHandler.isXdsSniEnabled = false; + } } @Test diff --git a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java index 40232ed9425..e808635f61c 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java @@ -189,7 +189,7 @@ public void oneSanInPeerCertsVerifies() throws CertificateException, IOException } @Test - public void autoSanSniValidation_precedes_subAltNamesToMatch() throws CertificateException, IOException { + public void autoSanSniValidation_overrides_subAltNamesToMatch() throws CertificateException, IOException { StringMatcher stringMatcher = StringMatcher.newBuilder() .setExact("notgonnabeused.test.google.be") From 2985cc39679e86a4d2827ff216d64c06d764be91 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Wed, 17 Sep 2025 10:07:41 +0000 Subject: [PATCH 39/51] Fixes. --- .../io/grpc/netty/ProtocolNegotiators.java | 21 ++++-- .../security/SecurityProtocolNegotiators.java | 4 +- .../security/trust/CertificateUtils.java | 4 ++ .../security/trust/XdsX509TrustManager.java | 4 +- .../grpc/xds/XdsSecurityClientServerTest.java | 13 +++- .../SecurityProtocolNegotiatorsTest.java | 31 +++++---- .../security/TlsContextManagerTest.java | 2 +- .../trust/XdsX509TrustManagerTest.java | 69 ++++++++++++++----- 8 files changed, 103 insertions(+), 45 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 06d810fcbda..4b935cd7b96 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -21,6 +21,7 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Optional; import com.google.common.base.Preconditions; +import com.google.common.base.Strings; import com.google.errorprone.annotations.ForOverride; import io.grpc.Attributes; import io.grpc.CallCredentials; @@ -89,6 +90,7 @@ import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.X509TrustManager; + import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** @@ -641,16 +643,21 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { private final X509TrustManager x509ExtendedTrustManager; private SSLEngine sslEngine; - ClientTlsHandler(ChannelHandler next, SslContext sslContext, String authority, + ClientTlsHandler(ChannelHandler next, SslContext sslContext, String sniHostPort, Executor executor, ChannelLogger negotiationLogger, Optional handshakeCompleteRunnable, ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, X509TrustManager x509ExtendedTrustManager) { super(next, negotiationLogger); this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); - HostPort hostPort = parseAuthority(authority); - this.host = hostPort.host; - this.port = hostPort.port; + if (!Strings.isNullOrEmpty(sniHostPort)) { + HostPort hostPort = parseAuthority(sniHostPort); + this.host = hostPort.host; + this.port = hostPort.port; + } else { + this.host = null; + this.port = 0; + } this.executor = executor; this.handshakeCompleteRunnable = handshakeCompleteRunnable; this.x509ExtendedTrustManager = x509ExtendedTrustManager; @@ -659,7 +666,11 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { @Override @IgnoreJRERequirement protected void handlerAdded0(ChannelHandlerContext ctx) { - sslEngine = sslContext.newEngine(ctx.alloc(), host, port); + if (host != null) { + sslEngine = sslContext.newEngine(ctx.alloc(), host, port); + } else { + sslEngine = sslContext.newEngine(ctx.alloc()); + } SSLParameters sslParams = sslEngine.getSSLParameters(); sslParams.setEndpointIdentificationAlgorithm("HTTPS"); sslEngine.setSSLParameters(sslParams); diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index 75692716b65..bd6eee70304 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -33,6 +33,7 @@ import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.ProtocolNegotiationEvent; import io.grpc.xds.EnvoyServerProtoData; +import io.grpc.xds.internal.security.trust.CertificateUtils; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; import io.netty.channel.ChannelHandlerContext; @@ -193,7 +194,6 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { @VisibleForTesting static final class ClientSecurityHandler extends InternalProtocolNegotiators.ProtocolNegotiationHandler { - static boolean isXdsSniEnabled = GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_SNI", false); private final GrpcHttp2ConnectionHandler grpcHandler; private final SslContextProviderSupplier sslContextProviderSupplier; @@ -218,7 +218,7 @@ public void handlerAdded(ChannelHandlerContext ctx) throws Exception { this.sslContextProviderSupplier = sslContextProviderSupplier; EnvoyServerProtoData.BaseTlsContext tlsContext = sslContextProviderSupplier.getTlsContext(); UpstreamTlsContext upstreamTlsContext = ((UpstreamTlsContext) tlsContext); - if (isXdsSniEnabled) { + if (CertificateUtils.isXdsSniEnabled) { sni = upstreamTlsContext.getAutoHostSni() && !Strings.isNullOrEmpty(endpointHostname) ? endpointHostname : upstreamTlsContext.getSni(); } else { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java index 86b6dd95c3e..9ad871f29ed 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java @@ -16,6 +16,8 @@ package io.grpc.xds.internal.security.trust; +import io.grpc.internal.GrpcUtil; + import java.io.BufferedInputStream; import java.io.File; import java.io.FileInputStream; @@ -29,6 +31,8 @@ * Contains certificate utility method(s). */ public final class CertificateUtils { + public static boolean isXdsSniEnabled = GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_SNI", false); + /** * Generates X509Certificate array from a file on disk. * diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java index 970381f4bb5..21cb8b731d6 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java @@ -53,8 +53,6 @@ */ final class XdsX509TrustManager extends X509ExtendedTrustManager implements X509TrustManager { - static boolean isXdsSniEnabled = GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_SNI", false); - // ref: io.grpc.okhttp.internal.OkHostnameVerifier and // sun.security.x509.GeneralNameInterface private static final int ALT_DNS_NAME = 2; @@ -220,7 +218,7 @@ void verifySubjectAltNameInChain(X509Certificate[] peerCertChain) throws Certifi return; } @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names - List verifyList = isXdsSniEnabled && !Strings.isNullOrEmpty(sniForSanMatching) + List verifyList = CertificateUtils.isXdsSniEnabled && !Strings.isNullOrEmpty(sniForSanMatching) ? ImmutableList.of(StringMatcher.newBuilder().setExact(sniForSanMatching).build()) : certContext.getMatchSubjectAltNamesList(); if (verifyList.isEmpty()) { diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index ac5f4e92b2a..454f66458e2 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -76,6 +76,7 @@ import io.grpc.xds.internal.security.SslContextProviderSupplier; import io.grpc.xds.internal.security.TlsContextManagerImpl; import io.grpc.xds.internal.security.certprovider.FileWatcherCertificateProviderProvider; +import io.grpc.xds.internal.security.trust.CertificateUtils; import io.netty.handler.ssl.NotSslRecordException; import java.io.File; import java.io.FileOutputStream; @@ -315,6 +316,7 @@ public void tlsClientServer_noAutoSniValidation_failureToMatchSubjAltNames() @Test public void tlsClientServer_autoSniValidation_sniInUTC() throws Exception { + CertificateUtils.isXdsSniEnabled = true; Path trustStoreFilePath = getCacertFilePathForTestCa(); try { setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); @@ -338,12 +340,14 @@ public void tlsClientServer_autoSniValidation_sniInUTC() } finally { Files.deleteIfExists(trustStoreFilePath); clearTrustStoreSystemProperties(); + CertificateUtils.isXdsSniEnabled = false; } } @Test public void tlsClientServer_autoSniValidation_sniFromHostname() throws Exception { + CertificateUtils.isXdsSniEnabled = true; Path trustStoreFilePath = getCacertFilePathForTestCa(); try { setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); @@ -370,12 +374,14 @@ public void tlsClientServer_autoSniValidation_sniFromHostname() } finally { Files.deleteIfExists(trustStoreFilePath); clearTrustStoreSystemProperties(); + CertificateUtils.isXdsSniEnabled = false; } } @Test public void tlsClientServer_autoSniValidation_noSNIApplicable_usesMatcherFromCmnVdnCtx() throws Exception { + CertificateUtils.isXdsSniEnabled = true; Path trustStoreFilePath = getCacertFilePathForTestCa(); try { setTrustStoreSystemProperties(trustStoreFilePath.toAbsolutePath().toString()); @@ -399,6 +405,7 @@ public void tlsClientServer_autoSniValidation_noSNIApplicable_usesMatcherFromCmn } finally { Files.deleteIfExists(trustStoreFilePath); clearTrustStoreSystemProperties(); + CertificateUtils.isXdsSniEnabled = false; } } @@ -815,7 +822,7 @@ private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( // in the certificate SNI, which isn't implemented yet (https://github.com/grpc/grpc-java/pull/12345 implements it) // so use an exact match SAN such as waterzooi.test.google.be for SNI for this testcase. private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( - final UpstreamTlsContext upstreamTlsContext, String overrideAuthority, String addrAttribute) { + final UpstreamTlsContext upstreamTlsContext, String overrideAuthority, String addrNameAttribute) { ManagedChannelBuilder channelBuilder = Grpc.newChannelBuilder( "sectest://localhost:" + port, @@ -833,8 +840,8 @@ private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( new SslContextProviderSupplier( upstreamTlsContext, tlsContextManagerForClient)) : Attributes.newBuilder(); - if (addrAttribute != null) { - sslContextAttributesBuilder.set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, addrAttribute); + if (addrNameAttribute != null) { + sslContextAttributesBuilder.set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, addrNameAttribute); } sslContextAttributes = sslContextAttributesBuilder.build(); fakeNameResolverFactory.setServers( diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index 8d92c5fbb31..103b5d8b385 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -51,6 +51,7 @@ import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSecurityHandler; import io.grpc.xds.internal.security.SecurityProtocolNegotiators.ClientSecurityProtocolNegotiator; import io.grpc.xds.internal.security.certprovider.CommonCertProviderTestUtils; +import io.grpc.xds.internal.security.trust.CertificateUtils; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPipeline; @@ -145,7 +146,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_withTlsContextAttribute() @Test public void clientSecurityProtocolNegotiatorNewHandler_autoHostSni_hostnameIsPassedToClientSecurityHandler() { - ClientSecurityHandler.isXdsSniEnabled = true; + CertificateUtils.isXdsSniEnabled = true; try { UpstreamTlsContext upstreamTlsContext = CommonTlsContextTestsUtil.buildUpstreamTlsContext(CommonTlsContext.newBuilder().build(), null, true, false); @@ -168,7 +169,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_autoHostSni_hostnameIsPas assertThat(newHandler).isInstanceOf(ClientSecurityHandler.class); assertThat(((ClientSecurityHandler) newHandler).getSni()).isEqualTo(FAKE_AUTHORITY); } finally { - ClientSecurityHandler.isXdsSniEnabled = false; + CertificateUtils.isXdsSniEnabled = false; } } @@ -207,7 +208,7 @@ public void updateSslContext(SslContext sslContext) { protected void onException(Throwable throwable) { future.set(throwable); } - }, null); + }, FAKE_AUTHORITY); assertThat(executor.runDueTasks()).isEqualTo(1); channel.runPendingTasks(); Object fromFuture = future.get(2, TimeUnit.SECONDS); @@ -227,7 +228,7 @@ protected void onException(Throwable throwable) { @Test public void sniInClientSecurityHandler_autoHostSniIsTrue_usesEndpointHostname() { - ClientSecurityHandler.isXdsSniEnabled = true; + CertificateUtils.isXdsSniEnabled = true; try { Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, @@ -244,13 +245,13 @@ public void sniInClientSecurityHandler_autoHostSniIsTrue_usesEndpointHostname() assertThat(clientSecurityHandler.getSni()).isEqualTo(HOSTNAME); } finally { - ClientSecurityHandler.isXdsSniEnabled = false; + CertificateUtils.isXdsSniEnabled = false; } } @Test public void sniInClientSecurityHandler_autoHostSniIsTrue_endpointHostnameIsEmpty_usesSniFromUpstreamTlsContext() { - ClientSecurityHandler.isXdsSniEnabled = true; + CertificateUtils.isXdsSniEnabled = true; try { Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, @@ -267,13 +268,13 @@ public void sniInClientSecurityHandler_autoHostSniIsTrue_endpointHostnameIsEmpty assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); } finally { - ClientSecurityHandler.isXdsSniEnabled = false; + CertificateUtils.isXdsSniEnabled = false; } } @Test public void sniInClientSecurityHandler_autoHostSniIsTrue_endpointHostnameIsNull_usesSniFromUpstreamTlsContext() { - ClientSecurityHandler.isXdsSniEnabled = true; + CertificateUtils.isXdsSniEnabled = true; try { Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, @@ -290,13 +291,13 @@ public void sniInClientSecurityHandler_autoHostSniIsTrue_endpointHostnameIsNull_ assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); } finally { - ClientSecurityHandler.isXdsSniEnabled = false; + CertificateUtils.isXdsSniEnabled = false; } } @Test public void sniInClientSecurityHandler_autoHostSniIsFalse_usesSniFromUpstreamTlsContext() { - ClientSecurityHandler.isXdsSniEnabled = true; + CertificateUtils.isXdsSniEnabled = true; try { Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, @@ -313,13 +314,13 @@ public void sniInClientSecurityHandler_autoHostSniIsFalse_usesSniFromUpstreamTls assertThat(clientSecurityHandler.getSni()).isEqualTo(SNI_IN_UTC); } finally { - ClientSecurityHandler.isXdsSniEnabled = false; + CertificateUtils.isXdsSniEnabled = false; } } @Test public void sniFeatureNotEnabled_usesChannelAuthorityForSni() { - ClientSecurityHandler.isXdsSniEnabled = false; + CertificateUtils.isXdsSniEnabled = false; Bootstrapper.BootstrapInfo bootstrapInfoForClient = CommonBootstrapperTestUtils .buildBootstrapInfo("google_cloud_private_spiffe-client", CLIENT_KEY_FILE, CLIENT_PEM_FILE, CA_PEM_FILE, null, null, null, null, null); @@ -498,7 +499,7 @@ public void nullTlsContext_nullFallbackProtocolNegotiator_expectException() { @Test public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEvent() throws InterruptedException, TimeoutException, ExecutionException { - ClientSecurityHandler.isXdsSniEnabled = true; + CertificateUtils.isXdsSniEnabled = true; try { FakeClock executor = new FakeClock(); CommonCertProviderTestUtils.register(executor); @@ -533,7 +534,7 @@ public void updateSslContext(SslContext sslContext) { protected void onException(Throwable throwable) { future.set(throwable); } - }, null); + }, ""); executor.runDueTasks(); channel.runPendingTasks(); // need this for tasks to execute on eventLoop Object fromFuture = future.get(5, TimeUnit.SECONDS); @@ -548,7 +549,7 @@ protected void onException(Throwable throwable) { assertTrue(channel.isOpen()); CommonCertProviderTestUtils.register0(); } finally { - ClientSecurityHandler.isXdsSniEnabled = false; + CertificateUtils.isXdsSniEnabled = false; } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java index 67817aff71d..66352aa27bd 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java @@ -169,7 +169,7 @@ public void createClientSslContextProvider_releaseInstance() { TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(mockClientFactory, mockServerFactory); SslContextProvider mockProvider = mock(SslContextProvider.class); - when(mockClientFactory.create(new AbstractMap.SimpleImmutableEntry("sni", upstreamTlsContext))) + when(mockClientFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI))) .thenReturn(mockProvider); SslContextProvider clientSecretProvider = tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext, SNI); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java index e808635f61c..09f5b1028cc 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsX509TrustManagerTest.java @@ -108,6 +108,7 @@ public void emptySanListContextTest() throws CertificateException, IOException { @Test public void missingPeerCerts() { if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = true; trustManager = new XdsX509TrustManager( CertificateValidationContext.getDefaultInstance(), mockDelegate, "foo.com"); } else { @@ -122,12 +123,17 @@ public void missingPeerCerts() { fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate(s) missing"); + } finally { + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = false; + } } } @Test public void emptyArrayPeerCerts() { if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = true; trustManager = new XdsX509TrustManager( CertificateValidationContext.getDefaultInstance(), mockDelegate, "foo.com"); } else { @@ -142,12 +148,17 @@ public void emptyArrayPeerCerts() { fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate(s) missing"); + } finally { + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = false; + } } } @Test public void noSansInPeerCerts() throws CertificateException, IOException { if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = true; trustManager = new XdsX509TrustManager( CertificateValidationContext.getDefaultInstance(), mockDelegate, "foo.com"); } else { @@ -164,13 +175,18 @@ public void noSansInPeerCerts() throws CertificateException, IOException { fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); + } finally { + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = false; + } } } @Test public void oneSanInPeerCertsVerifies() throws CertificateException, IOException { if (useSniForSanMatching) { - trustManager = new XdsX509TrustManager( + CertificateUtils.isXdsSniEnabled = true; + trustManager = new XdsX509TrustManager( CertificateValidationContext.getDefaultInstance(), mockDelegate, "waterzooi.test.google.be"); } else { StringMatcher stringMatcher = @@ -183,25 +199,36 @@ public void oneSanInPeerCertsVerifies() throws CertificateException, IOException CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); trustManager = new XdsX509TrustManager(certContext, mockDelegate); } - X509Certificate[] certs = - CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + try { + X509Certificate[] certs = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + trustManager.verifySubjectAltNameInChain(certs); + } finally { + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = false; + } + } } @Test public void autoSanSniValidation_overrides_subAltNamesToMatch() throws CertificateException, IOException { - StringMatcher stringMatcher = - StringMatcher.newBuilder() - .setExact("notgonnabeused.test.google.be") - .setIgnoreCase(false) - .build(); - @SuppressWarnings("deprecation") - CertificateValidationContext certContext = - CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); - trustManager = new XdsX509TrustManager(certContext, mockDelegate, "waterzooi.test.google.be"); - X509Certificate[] certs = - CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); - trustManager.verifySubjectAltNameInChain(certs); + CertificateUtils.isXdsSniEnabled = true; + try { + StringMatcher stringMatcher = + StringMatcher.newBuilder() + .setExact("notgonnabeused.test.google.be") + .setIgnoreCase(false) + .build(); + @SuppressWarnings("deprecation") + CertificateValidationContext certContext = + CertificateValidationContext.newBuilder().addMatchSubjectAltNames(stringMatcher).build(); + trustManager = new XdsX509TrustManager(certContext, mockDelegate, "waterzooi.test.google.be"); + X509Certificate[] certs = + CertificateUtils.toX509Certificates(TlsTesting.loadCert(SERVER_1_PEM_FILE)); + trustManager.verifySubjectAltNameInChain(certs); + } finally { + CertificateUtils.isXdsSniEnabled = false; + } } @Test @@ -487,6 +514,7 @@ public void oneSanInPeerCertsVerifiesMultipleVerifySans() public void oneSanInPeerCertsNotFoundException() throws CertificateException, IOException { if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = true; trustManager = new XdsX509TrustManager(CertificateValidationContext.getDefaultInstance(), mockDelegate, "x.foo.com"); } else { @@ -503,6 +531,10 @@ public void oneSanInPeerCertsNotFoundException() fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); + } finally { + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = false; + } } } @@ -549,6 +581,7 @@ public void wildcardSanInPeerCertsSubdomainMismatch() // For example, *.example.com matches test.example.com but does not match // sub.test.example.com. if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = true; trustManager = new XdsX509TrustManager(CertificateValidationContext.getDefaultInstance(), mockDelegate, "sub.abc.test.youtube.com"); } else { @@ -566,6 +599,10 @@ public void wildcardSanInPeerCertsSubdomainMismatch() fail("no exception thrown"); } catch (CertificateException expected) { assertThat(expected).hasMessageThat().isEqualTo("Peer certificate SAN check failed"); + } finally { + if (useSniForSanMatching) { + CertificateUtils.isXdsSniEnabled = false; + } } } From 011a9ea92c8747c8014b25bb9ab8869323ead61b Mon Sep 17 00:00:00 2001 From: Kannan J Date: Wed, 17 Sep 2025 11:39:29 +0000 Subject: [PATCH 40/51] Allow trustedRootCerts to be present in static CertificateValidationContext when SystemRootCerts is also there. --- .../security/trust/XdsTrustManagerFactory.java | 3 ++- .../security/trust/XdsTrustManagerFactoryTest.java | 13 +++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java index 3ba31d8ff2b..7ad24b75d98 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java @@ -81,7 +81,8 @@ private XdsTrustManagerFactory( throws CertStoreException { if (validationContextIsStatic) { checkArgument( - certificateValidationContext == null || !certificateValidationContext.hasTrustedCa(), + certificateValidationContext == null || !certificateValidationContext.hasTrustedCa() + || certificateValidationContext.hasSystemRootCerts(), "only static certificateValidationContext expected"); } xdsX509TrustManager = createX509TrustManager(certs, certificateValidationContext, sniForSanMatching); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java index 36e75327419..5d156e29c08 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactoryTest.java @@ -176,6 +176,19 @@ public void constructorRootCert_nonStaticContext_throwsException() } } + @Test + public void constructorRootCert_nonStaticContext_systemRootCerts_valid() + throws CertificateException, IOException, CertStoreException { + X509Certificate x509Cert = TestUtils.loadX509Cert(CA_PEM_FILE); + CertificateValidationContext certValidationContext = CertificateValidationContext.newBuilder() + .setTrustedCa( + DataSource.newBuilder().setFilename(TestUtils.loadCert(CA_PEM_FILE).getAbsolutePath())) + .setSystemRootCerts(CertificateValidationContext.SystemRootCerts.getDefaultInstance()) + .build(); + new XdsTrustManagerFactory( + new X509Certificate[] {x509Cert}, certValidationContext, null); + } + @Test public void constructorRootCert_checkServerTrusted_throwsException() throws CertificateException, IOException, CertStoreException { From b828098de1b60ada3a65cb1cfa72c8fd5ab325b2 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Fri, 19 Sep 2025 14:25:23 +0000 Subject: [PATCH 41/51] Pass extended trust manager to protocol negotiator. --- .../io/grpc/internal/CertificateUtils.java | 22 +++++++++ .../netty/InternalProtocolNegotiators.java | 14 ++++-- .../io/grpc/netty/ProtocolNegotiators.java | 20 +------- .../S2AProtocolNegotiatorFactory.java | 2 +- .../security/DynamicSslContextProvider.java | 47 +++++++++++-------- .../security/SecurityProtocolNegotiators.java | 14 ++++-- .../internal/security/SslContextProvider.java | 13 +++-- .../security/SslContextProviderSupplier.java | 6 ++- .../CertProviderClientSslContextProvider.java | 40 +++++++++------- .../CertProviderServerSslContextProvider.java | 10 +++- .../security/CommonTlsContextTestsUtil.java | 16 ++++--- .../SecurityProtocolNegotiatorsTest.java | 9 ++-- .../SslContextProviderSupplierTest.java | 27 +++++++---- ...tProviderClientSslContextProviderTest.java | 32 ++++++------- ...tProviderServerSslContextProviderTest.java | 20 ++++---- 15 files changed, 176 insertions(+), 116 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/CertificateUtils.java b/core/src/main/java/io/grpc/internal/CertificateUtils.java index 91d17de93cb..e0171e8d357 100644 --- a/core/src/main/java/io/grpc/internal/CertificateUtils.java +++ b/core/src/main/java/io/grpc/internal/CertificateUtils.java @@ -26,6 +26,7 @@ import java.security.cert.CertificateFactory; import java.security.cert.X509Certificate; import java.util.Collection; +import java.util.List; import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import javax.security.auth.x500.X500Principal; @@ -34,6 +35,16 @@ * Contains certificate/key PEM file utility method(s) for internal usage. */ public final class CertificateUtils { + private static Class x509ExtendedTrustManagerClass; + + static { + try { + x509ExtendedTrustManagerClass = Class.forName("javax.net.ssl.X509ExtendedTrustManager"); + } catch (ClassNotFoundException e) { + // Will disallow per-rpc authority override via call option. + } + } + /** * Creates X509TrustManagers using the provided CA certs. */ @@ -71,6 +82,17 @@ public static TrustManager[] createTrustManager(InputStream rootCerts) return trustManagerFactory.getTrustManagers(); } + public static TrustManager getX509ExtendedTrustManager(List trustManagers) { + if (x509ExtendedTrustManagerClass != null) { + for (TrustManager trustManager : trustManagers) { + if (x509ExtendedTrustManagerClass.isInstance(trustManager)) { + return trustManager; + } + } + } + return null; + } + private static X509Certificate[] getX509Certificates(InputStream inputStream) throws CertificateException { CertificateFactory factory = CertificateFactory.getInstance("X.509"); diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index ed9af8e237f..b62e37d51e6 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -25,6 +25,9 @@ import io.netty.channel.ChannelHandler; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; + +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; import java.util.concurrent.Executor; /** @@ -39,15 +42,16 @@ private InternalProtocolNegotiators() {} * be negotiated, the {@code handler} is added and writes to the {@link io.netty.channel.Channel} * may happen immediately, even before the TLS Handshake is complete. * - * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks - * @param isXdsTarget + * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, ObjectPool executorPool, Optional handshakeCompleteRunnable, + TrustManager extendedX509TrustManager, String sni, boolean isXdsTarget) { final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, - executorPool, handshakeCompleteRunnable, null, sni, isXdsTarget); + executorPool, handshakeCompleteRunnable, (X509TrustManager) extendedX509TrustManager, sni, + isXdsTarget); final class TlsNegotiator implements InternalProtocolNegotiator.ProtocolNegotiator { @Override @@ -75,8 +79,8 @@ public void close() { * may happen immediately, even before the TLS Handshake is complete. */ public static InternalProtocolNegotiator.ProtocolNegotiator tls( - SslContext sslContext, String sni, boolean isXdsTarget) { - return tls(sslContext, null, Optional.absent(), sni, isXdsTarget); + SslContext sslContext, String sni, boolean isXdsTarget, TrustManager extendedX509TrustManager) { + return tls(sslContext, null, Optional.absent(), extendedX509TrustManager, sni, isXdsTarget); } /** diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 4b935cd7b96..c1ef4dcfc42 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -104,15 +104,6 @@ final class ProtocolNegotiators { private static final EnumSet understoodServerTlsFeatures = EnumSet.of( TlsServerCredentials.Feature.MTLS, TlsServerCredentials.Feature.CUSTOM_MANAGERS); - private static Class x509ExtendedTrustManagerClass; - - static { - try { - x509ExtendedTrustManagerClass = Class.forName("javax.net.ssl.X509ExtendedTrustManager"); - } catch (ClassNotFoundException e) { - // Will disallow per-rpc authority override via call option. - } - } private ProtocolNegotiators() { } @@ -149,15 +140,8 @@ public static FromChannelCredentialsResult from(ChannelCredentials creds) { trustManagers = Arrays.asList(tmf.getTrustManagers()); } builder.trustManager(new FixedTrustManagerFactory(trustManagers)); - TrustManager x509ExtendedTrustManager = null; - if (x509ExtendedTrustManagerClass != null) { - for (TrustManager trustManager : trustManagers) { - if (x509ExtendedTrustManagerClass.isInstance(trustManager)) { - x509ExtendedTrustManager = trustManager; - break; - } - } - } + TrustManager x509ExtendedTrustManager = + CertificateUtils.getX509ExtendedTrustManager(trustManagers); return FromChannelCredentialsResult.negotiator(tlsClientFactory(builder.build(), (X509TrustManager) x509ExtendedTrustManager)); } catch (SSLException | GeneralSecurityException ex) { diff --git a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java index cbcad749109..e1fe04d26d3 100644 --- a/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java +++ b/s2a/src/main/java/io/grpc/s2a/internal/handshaker/S2AProtocolNegotiatorFactory.java @@ -259,7 +259,7 @@ public void run() { s2aStub.close(); } }), - null, false) + null, null, false) .newHandler(grpcHandler); // Delegate the rest of the handshake to the TLS handler. and remove the diff --git a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java index 6bf66d022ff..0530b18fb14 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java @@ -30,9 +30,11 @@ import java.io.IOException; import java.security.cert.CertStoreException; import java.security.cert.CertificateException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; /** Base class for dynamic {@link SslContextProvider}s. */ @Internal @@ -40,7 +42,8 @@ public abstract class DynamicSslContextProvider extends SslContextProvider { protected final List pendingCallbacks = new ArrayList<>(); @Nullable protected final CertificateValidationContext staticCertificateValidationContext; - @Nullable protected SslContext sslContext; + @Nullable protected AbstractMap.SimpleImmutableEntry + sslContextAndExtendedX509TrustManager; protected DynamicSslContextProvider( BaseTlsContext tlsContext, CertificateValidationContext staticCertValidationContext) { @@ -49,14 +52,16 @@ protected DynamicSslContextProvider( } @Nullable - public SslContext getSslContext() { - return sslContext; + public AbstractMap.SimpleImmutableEntry getSslContextAndExtendedX509TrustManager() { + return sslContextAndExtendedX509TrustManager; } protected abstract CertificateValidationContext generateCertificateValidationContext(); - /** Gets a server or client side SslContextBuilder. */ - protected abstract SslContextBuilder getSslContextBuilder( + /** + * Gets a server or client side SslContextBuilder. + */ + protected abstract AbstractMap.SimpleImmutableEntry getSslContextBuilderAndExtendedX509TrustManager( CertificateValidationContext certificateValidationContext) throws CertificateException, IOException, CertStoreException; @@ -65,7 +70,8 @@ protected final void updateSslContext() { try { CertificateValidationContext localCertValidationContext = generateCertificateValidationContext(); - SslContextBuilder sslContextBuilder = getSslContextBuilder(localCertValidationContext); + AbstractMap.SimpleImmutableEntry sslContextBuilderAndTm = + getSslContextBuilderAndExtendedX509TrustManager(localCertValidationContext); CommonTlsContext commonTlsContext = getCommonTlsContext(); if (commonTlsContext != null && commonTlsContext.getAlpnProtocolsCount() > 0) { List alpnList = commonTlsContext.getAlpnProtocolsList(); @@ -75,16 +81,17 @@ protected final void updateSslContext() { ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, alpnList); - sslContextBuilder.applicationProtocolConfig(apn); + sslContextBuilderAndTm.getKey().applicationProtocolConfig(apn); } List pendingCallbacksCopy; - SslContext sslContextCopy; + AbstractMap.SimpleImmutableEntry sslContextAndExtendedX09TrustManagerCopy; synchronized (pendingCallbacks) { - sslContext = sslContextBuilder.build(); - sslContextCopy = sslContext; + sslContextAndExtendedX509TrustManager = new AbstractMap.SimpleImmutableEntry<>( + sslContextBuilderAndTm.getKey().build(), sslContextBuilderAndTm.getValue()); + sslContextAndExtendedX09TrustManagerCopy = sslContextAndExtendedX509TrustManager; pendingCallbacksCopy = clonePendingCallbacksAndClear(); } - makePendingCallbacks(sslContextCopy, pendingCallbacksCopy); + makePendingCallbacks(sslContextAndExtendedX09TrustManagerCopy, pendingCallbacksCopy); } catch (Exception e) { onError(Status.fromThrowable(e)); throw new RuntimeException(e); @@ -92,12 +99,12 @@ protected final void updateSslContext() { } protected final void callPerformCallback( - Callback callback, final SslContext sslContextCopy) { + Callback callback, final AbstractMap.SimpleImmutableEntry sslContextAndTmCopy) { performCallback( new SslContextGetter() { @Override - public SslContext get() { - return sslContextCopy; + public AbstractMap.SimpleImmutableEntry get() { + return sslContextAndTmCopy; } }, callback @@ -108,10 +115,10 @@ public SslContext get() { public final void addCallback(Callback callback) { checkNotNull(callback, "callback"); // if there is a computed sslContext just send it - SslContext sslContextCopy = null; + AbstractMap.SimpleImmutableEntry sslContextCopy = null; synchronized (pendingCallbacks) { - if (sslContext != null) { - sslContextCopy = sslContext; + if (sslContextAndExtendedX509TrustManager != null) { + sslContextCopy = sslContextAndExtendedX509TrustManager; } else { pendingCallbacks.add(callback); } @@ -122,9 +129,11 @@ public final void addCallback(Callback callback) { } private final void makePendingCallbacks( - SslContext sslContextCopy, List pendingCallbacksCopy) { + AbstractMap.SimpleImmutableEntry + sslContextAndExtendedX509TrustManagerCopy, + List pendingCallbacksCopy) { for (Callback callback : pendingCallbacksCopy) { - callPerformCallback(callback, sslContextCopy); + callPerformCallback(callback, sslContextAndExtendedX509TrustManagerCopy); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index bd6eee70304..7bedba4e0a0 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -41,12 +41,14 @@ import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; import java.security.cert.CertStoreException; +import java.util.AbstractMap; import java.util.ArrayList; import java.util.List; import java.util.concurrent.Executor; import java.util.logging.Level; import java.util.logging.Logger; import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; /** * Provides client and server side gRPC {@link ProtocolNegotiator}s to provide the SSL @@ -240,7 +242,8 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { if (ctx.isRemoved()) { return; } @@ -249,7 +252,9 @@ public void updateSslContext(SslContext sslContext) { "ClientSecurityHandler.updateSslContext authority={0}, ctx.name={1}", new Object[]{grpcHandler.getAuthority(), ctx.name()}); ChannelHandler handler = - InternalProtocolNegotiators.tls(sslContext, sni, true).newHandler(grpcHandler); + InternalProtocolNegotiators.tls( + sslContextAndTm.getKey(), sni, true, sslContextAndTm.getValue()) + .newHandler(grpcHandler); // Delegate rest of handshake to TLS handler ctx.pipeline().addAfter(ctx.name(), null, handler); @@ -383,9 +388,10 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { new SslContextProvider.Callback(ctx.executor()) { @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { ChannelHandler handler = - InternalProtocolNegotiators.serverTls(sslContext).newHandler(grpcHandler); + InternalProtocolNegotiators.serverTls(sslContextAndTm.getKey()).newHandler(grpcHandler); // Delegate rest of handshake to TLS handler if (!ctx.isRemoved()) { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java index fcd3a899c8d..31df4b797e7 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java @@ -29,9 +29,12 @@ import io.netty.handler.ssl.ClientAuth; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; + +import javax.net.ssl.TrustManager; import java.io.IOException; import java.security.cert.CertStoreException; import java.security.cert.CertificateException; +import java.util.AbstractMap; import java.util.concurrent.Executor; /** @@ -59,7 +62,8 @@ protected Callback(Executor executor) { } /** Informs callee of new/updated SslContext. */ - @VisibleForTesting public abstract void updateSslContext(SslContext sslContext); + @VisibleForTesting public abstract void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContext); /** Informs callee of an exception that was generated. */ @VisibleForTesting protected abstract void onException(Throwable throwable); @@ -121,8 +125,9 @@ protected final void performCallback( @Override public void run() { try { - SslContext sslContext = sslContextGetter.get(); - callback.updateSslContext(sslContext); + AbstractMap.SimpleImmutableEntry sslContextAndTm = + sslContextGetter.get(); + callback.updateSslContextAndExtendedX509TrustManager(sslContextAndTm); } catch (Throwable e) { callback.onException(e); } @@ -132,6 +137,6 @@ public void run() { /** Allows implementations to compute or get SslContext. */ protected interface SslContextGetter { - SslContext get() throws Exception; + AbstractMap.SimpleImmutableEntry get() throws Exception; } } \ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index 26661eddfe2..d0afd29c80b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -24,6 +24,8 @@ import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; +import javax.net.ssl.TrustManager; +import java.util.AbstractMap; import java.util.HashSet; import java.util.Objects; import java.util.Set; @@ -70,8 +72,8 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call new SslContextProvider.Callback(callback.getExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { - callback.updateSslContext(sslContext); + public void updateSslContextAndExtendedX509TrustManager(AbstractMap.SimpleImmutableEntry sslContextAndTm) { + callback.updateSslContextAndExtendedX509TrustManager(sslContextAndTm); releaseSslContextProvider(toRelease, sni); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index 49e3048f3c9..a7729ca03af 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -32,6 +32,7 @@ import java.security.cert.CertStoreException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.AbstractMap; import java.util.Arrays; import java.util.Collection; import java.util.List; @@ -70,7 +71,10 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP try { // Instantiate sslContext so that addCallback will immediately update the callback with // the SslContext. - sslContext = getSslContextBuilder(staticCertificateValidationContext).build(); + AbstractMap.SimpleImmutableEntry sslContextBuilderAndTm = + getSslContextBuilderAndExtendedX509TrustManager(staticCertificateValidationContext); + sslContextAndExtendedX509TrustManager = new AbstractMap.SimpleImmutableEntry( + sslContextBuilderAndTm.getKey().build(), sslContextBuilderAndTm.getValue()); } catch (CertStoreException | CertificateException | IOException e) { throw new RuntimeException(e); } @@ -78,28 +82,30 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP } @Override - protected final SslContextBuilder getSslContextBuilder( + protected final AbstractMap.SimpleImmutableEntry + getSslContextBuilderAndExtendedX509TrustManager( CertificateValidationContext certificateValidationContext) - throws CertificateException, IOException, CertStoreException { + throws CertificateException, IOException, CertStoreException { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); + XdsTrustManagerFactory trustManagerFactory; if (rootCertInstance != null) { if (savedSpiffeTrustMap != null) { - sslContextBuilder = sslContextBuilder.trustManager( - new XdsTrustManagerFactory( - savedSpiffeTrustMap, - certificateValidationContext, sniForSanMatching)); + trustManagerFactory = new XdsTrustManagerFactory( + savedSpiffeTrustMap, + certificateValidationContext, sniForSanMatching); + sslContextBuilder = sslContextBuilder.trustManager(trustManagerFactory); } else { - sslContextBuilder = sslContextBuilder.trustManager( - new XdsTrustManagerFactory( - savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContext, sniForSanMatching)); + trustManagerFactory = new XdsTrustManagerFactory( + savedTrustedRoots.toArray(new X509Certificate[0]), + certificateValidationContext, sniForSanMatching); + sslContextBuilder = sslContextBuilder.trustManager(trustManagerFactory); } } else { try { - sslContextBuilder = sslContextBuilder.trustManager( - new XdsTrustManagerFactory( - getX509CertificatesFromSystemTrustStore(), - certificateValidationContext, sniForSanMatching)); + trustManagerFactory = new XdsTrustManagerFactory( + getX509CertificatesFromSystemTrustStore(), + certificateValidationContext, sniForSanMatching); + sslContextBuilder = sslContextBuilder.trustManager(trustManagerFactory); } catch (KeyStoreException | NoSuchAlgorithmException e) { throw new CertStoreException(e); } @@ -107,7 +113,9 @@ protected final SslContextBuilder getSslContextBuilder( if (isMtls()) { sslContextBuilder.keyManager(savedKey, savedCertChain); } - return sslContextBuilder; + return new AbstractMap.SimpleImmutableEntry<>(sslContextBuilder, + io.grpc.internal.CertificateUtils.getX509ExtendedTrustManager( + Arrays.asList(trustManagerFactory.getTrustManagers()))); } private X509Certificate[] getX509CertificatesFromSystemTrustStore() diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java index 2488fcb1199..a356712c31b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java @@ -21,6 +21,7 @@ import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; +import io.grpc.internal.CertificateUtils; import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; @@ -30,8 +31,11 @@ import java.security.cert.CertStoreException; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.AbstractMap; +import java.util.Arrays; import java.util.Map; import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; /** A server SslContext provider using CertificateProviderInstance to fetch secrets. */ final class CertProviderServerSslContextProvider extends CertProviderSslContextProvider { @@ -55,7 +59,7 @@ final class CertProviderServerSslContextProvider extends CertProviderSslContextP } @Override - protected final SslContextBuilder getSslContextBuilder( + protected final AbstractMap.SimpleImmutableEntry getSslContextBuilderAndExtendedX509TrustManager( CertificateValidationContext certificateValidationContextdationContext) throws CertStoreException, CertificateException, IOException { SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(savedKey, savedCertChain); @@ -71,7 +75,9 @@ protected final SslContextBuilder getSslContextBuilder( } setClientAuthValues(sslContextBuilder, trustManagerFactory); sslContextBuilder = GrpcSslContexts.configure(sslContextBuilder); - return sslContextBuilder; + return new AbstractMap.SimpleImmutableEntry(sslContextBuilder, + CertificateUtils.getX509ExtendedTrustManager( + Arrays.asList(trustManagerFactory.getTrustManagers()))); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java index 539ede1d60f..335dc201483 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/CommonTlsContextTestsUtil.java @@ -36,10 +36,12 @@ import java.io.InputStream; import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import java.util.AbstractMap; import java.util.Arrays; import java.util.List; import java.util.concurrent.Executor; import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; /** Utility class for client and server ssl provider tests. */ public class CommonTlsContextTestsUtil { @@ -359,14 +361,14 @@ private static CommonTlsContext.Builder addNewCertificateValidationContext( } /** Perform some simple checks on sslContext. */ - public static void doChecksOnSslContext(boolean server, SslContext sslContext, - List expectedApnProtos) { + public static void doChecksOnSslContext(boolean server, AbstractMap.SimpleImmutableEntry sslContext, + List expectedApnProtos) { if (server) { - assertThat(sslContext.isServer()).isTrue(); + assertThat(sslContext.getKey().isServer()).isTrue(); } else { - assertThat(sslContext.isClient()).isTrue(); + assertThat(sslContext.getKey().isClient()).isTrue(); } - List apnProtos = sslContext.applicationProtocolNegotiator().protocols(); + List apnProtos = sslContext.getKey().applicationProtocolNegotiator().protocols(); assertThat(apnProtos).isNotNull(); if (expectedApnProtos != null) { assertThat(apnProtos).isEqualTo(expectedApnProtos); @@ -392,7 +394,7 @@ public static TestCallback getValueThruCallback(SslContextProvider provider, Exe public static class TestCallback extends SslContextProvider.Callback { - public SslContext updatedSslContext; + public AbstractMap.SimpleImmutableEntry updatedSslContext; public Throwable updatedThrowable; public TestCallback(Executor executor) { @@ -400,7 +402,7 @@ public TestCallback(Executor executor) { } @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContextAndExtendedX509TrustManager(AbstractMap.SimpleImmutableEntry sslContext) { updatedSslContext = sslContext; } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index 103b5d8b385..aae1a1ccf22 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -72,6 +72,7 @@ import java.net.InetSocketAddress; import java.net.SocketAddress; import java.security.cert.CertStoreException; +import java.util.AbstractMap; import java.util.Iterator; import java.util.Map; import java.util.concurrent.ExecutionException; @@ -81,6 +82,8 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; +import javax.net.ssl.TrustManager; + /** Unit tests for {@link SecurityProtocolNegotiators}. */ @RunWith(JUnit4.class) public class SecurityProtocolNegotiatorsTest { @@ -200,7 +203,7 @@ public void clientSecurityHandler_addLast() sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContextAndExtendedX509TrustManager(AbstractMap.SimpleImmutableEntry sslContext) { future.set(sslContext); } @@ -388,7 +391,7 @@ public SocketAddress remoteAddress() { sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContextAndExtendedX509TrustManager(AbstractMap.SimpleImmutableEntry sslContext) { future.set(sslContext); } @@ -526,7 +529,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEv sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContext(SslContext sslContext) { + public void updateSslContextAndExtendedX509TrustManager(AbstractMap.SimpleImmutableEntry sslContext) { future.set(sslContext); } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index bac85686688..0e83fbb8118 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -31,6 +31,8 @@ import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; + +import java.util.AbstractMap; import java.util.concurrent.Executor; import org.junit.Rule; import org.junit.Test; @@ -41,6 +43,8 @@ import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; +import javax.net.ssl.TrustManager; + /** * Unit tests for {@link SslContextProviderSupplier}. */ @@ -88,9 +92,10 @@ public void get_updateSecret() { verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); - SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSslContext(mockSslContext); - verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); + AbstractMap.SimpleImmutableEntry mockSslContextAndTm = + mock(AbstractMap.SimpleImmutableEntry.class); + capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); + verify(mockCallback, times(1)).updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); @@ -112,9 +117,11 @@ public void autoHostSniFalse_usesSniFromUpstreamTlsContext() { verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); - SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSslContext(mockSslContext); - verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); + AbstractMap.SimpleImmutableEntry mockSslContextAndTm = + mock(AbstractMap.SimpleImmutableEntry.class); + capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); + verify(mockCallback, times(1)) + .updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); @@ -166,9 +173,11 @@ public void systemRootCertsWithMtls_callbackExecutedFromProvider() { verify(mockSslContextProvider, times(1)).addCallback(callbackCaptor.capture()); SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); - SslContext mockSslContext = mock(SslContext.class); - capturedCallback.updateSslContext(mockSslContext); - verify(mockCallback, times(1)).updateSslContext(eq(mockSslContext)); + AbstractMap.SimpleImmutableEntry mockSslContextAndTm = + mock(AbstractMap.SimpleImmutableEntry.class); + capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); + verify(mockCallback, times(1)) + .updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); verify(mockTlsContextManager, times(1)) .releaseClientSslContextProvider(eq(mockSslContextProvider), eq(SNI)); SslContextProvider.Callback mockCallback = mock(SslContextProvider.Callback.class); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index dc1ab9aae5f..19622d7ac7b 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -142,7 +142,7 @@ public void testProviderForClient_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -150,11 +150,11 @@ public void testProviderForClient_mtls() throws Exception { ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -183,7 +183,7 @@ public void testProviderForClient_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -209,10 +209,10 @@ public void testProviderForClient_systemRootCerts_regularTls() { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); TestCallback testCallback = CommonTlsContextTestsUtil.getValueThruCallback(provider); - assertThat(testCallback.updatedSslContext).isEqualTo(provider.getSslContext()); + assertThat(testCallback.updatedSslContext).isEqualTo(provider.getSslContextAndExtendedX509TrustManager()); assertThat(watcherCaptor[0]).isNull(); } @@ -238,11 +238,11 @@ public void testProviderForClient_systemRootCerts_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate root cert update, will get ignored because of systemRootCerts config watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -253,7 +253,7 @@ public void testProviderForClient_systemRootCerts_mtls() throws Exception { ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); TestCallback testCallback = CommonTlsContextTestsUtil.getValueThruCallback(provider); @@ -280,7 +280,7 @@ public void testProviderForClient_systemRootCerts_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -302,7 +302,7 @@ public void testProviderForClient_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -310,11 +310,11 @@ public void testProviderForClient_mtls_newXds() throws Exception { ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -343,7 +343,7 @@ public void testProviderForClient_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -398,11 +398,11 @@ public void testProviderForClient_tls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java index 423829ff5af..6515a97f526 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProviderTest.java @@ -127,7 +127,7 @@ public void testProviderForServer_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -135,11 +135,11 @@ public void testProviderForServer_mtls() throws Exception { ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -168,7 +168,7 @@ public void testProviderForServer_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -196,7 +196,7 @@ public void testProviderForServer_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -204,11 +204,11 @@ public void testProviderForServer_mtls_newXds() throws Exception { ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); assertThat(provider.savedKey).isNotNull(); assertThat(provider.savedCertChain).isNotNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate root cert update watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); @@ -237,7 +237,7 @@ public void testProviderForServer_mtls_newXds() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); } @@ -294,14 +294,14 @@ public void testProviderForServer_tls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( CommonCertProviderTestUtils.getPrivateKey(SERVER_0_KEY_FILE), ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); - assertThat(provider.getSslContext()).isNotNull(); + assertThat(provider.getSslContextAndExtendedX509TrustManager()).isNotNull(); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); assertThat(provider.savedTrustedRoots).isNull(); From c19a24f33b9c2e93cad9a686ada2b899a564ca90 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Fri, 19 Sep 2025 15:12:17 +0000 Subject: [PATCH 42/51] Trust manager not needed on server side when invoking SslProvider.Callback. --- .../io/grpc/xds/ClusterImplLoadBalancer.java | 3 ++- .../security/DynamicSslContextProvider.java | 3 ++- .../security/SslContextProviderSupplier.java | 3 ++- .../CertProviderServerSslContextProvider.java | 5 ++--- .../security/CommonTlsContextTestsUtil.java | 11 +++++----- .../SecurityProtocolNegotiatorsTest.java | 21 +++++++++++-------- 6 files changed, 26 insertions(+), 20 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index c5491a92bed..45077cc2a57 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -377,7 +377,8 @@ private class RequestLimitingSubchannelPicker extends SubchannelPicker { private final Map filterMetadata; private RequestLimitingSubchannelPicker(SubchannelPicker delegate, - List dropPolicies, long maxConcurrentRequests, + List dropPolicies, + long maxConcurrentRequests, Map filterMetadata) { this.delegate = delegate; this.dropPolicies = dropPolicies; diff --git a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java index 0530b18fb14..027b5aebdec 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java @@ -99,7 +99,8 @@ protected final void updateSslContext() { } protected final void callPerformCallback( - Callback callback, final AbstractMap.SimpleImmutableEntry sslContextAndTmCopy) { + Callback callback, + final AbstractMap.SimpleImmutableEntry sslContextAndTmCopy) { performCallback( new SslContextGetter() { @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index d0afd29c80b..e85d2c498ba 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -72,7 +72,8 @@ public synchronized void updateSslContext(final SslContextProvider.Callback call new SslContextProvider.Callback(callback.getExecutor()) { @Override - public void updateSslContextAndExtendedX509TrustManager(AbstractMap.SimpleImmutableEntry sslContextAndTm) { + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { callback.updateSslContextAndExtendedX509TrustManager(sslContextAndTm); releaseSslContextProvider(toRelease, sni); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java index a356712c31b..25136fab9d6 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java @@ -75,9 +75,8 @@ protected final AbstractMap.SimpleImmutableEntry sslContext, - List expectedApnProtos) { + public static void doChecksOnSslContext(boolean server, + AbstractMap.SimpleImmutableEntry sslContextAndTm, + List expectedApnProtos) { if (server) { - assertThat(sslContext.getKey().isServer()).isTrue(); + assertThat(sslContextAndTm.getKey().isServer()).isTrue(); } else { - assertThat(sslContext.getKey().isClient()).isTrue(); + assertThat(sslContextAndTm.getKey().isClient()).isTrue(); } - List apnProtos = sslContext.getKey().applicationProtocolNegotiator().protocols(); + List apnProtos = sslContextAndTm.getKey().applicationProtocolNegotiator().protocols(); assertThat(apnProtos).isNotNull(); if (expectedApnProtos != null) { assertThat(apnProtos).isEqualTo(expectedApnProtos); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index aae1a1ccf22..7aa7fd56d2d 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -203,8 +203,9 @@ public void clientSecurityHandler_addLast() sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContextAndExtendedX509TrustManager(AbstractMap.SimpleImmutableEntry sslContext) { - future.set(sslContext); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + future.set(sslContextAndTm); } @Override @@ -215,7 +216,7 @@ protected void onException(Throwable throwable) { assertThat(executor.runDueTasks()).isEqualTo(1); channel.runPendingTasks(); Object fromFuture = future.get(2, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); + assertThat(fromFuture).isInstanceOf(AbstractMap.SimpleImmutableEntry.class); channel.runPendingTasks(); channelHandlerCtx = pipeline.context(clientSecurityHandler); assertThat(channelHandlerCtx).isNull(); @@ -391,8 +392,9 @@ public SocketAddress remoteAddress() { sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContextAndExtendedX509TrustManager(AbstractMap.SimpleImmutableEntry sslContext) { - future.set(sslContext); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + future.set(sslContextAndTm); } @Override @@ -403,7 +405,7 @@ protected void onException(Throwable throwable) { channel.runPendingTasks(); // need this for tasks to execute on eventLoop assertThat(executor.runDueTasks()).isEqualTo(1); Object fromFuture = future.get(2, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); + assertThat(fromFuture).isInstanceOf(AbstractMap.SimpleImmutableEntry.class); channel.runPendingTasks(); channelHandlerCtx = pipeline.context(SecurityProtocolNegotiators.ServerSecurityHandler.class); assertThat(channelHandlerCtx).isNull(); @@ -529,8 +531,9 @@ public void clientSecurityProtocolNegotiatorNewHandler_fireProtocolNegotiationEv sslContextProviderSupplier .updateSslContext(new SslContextProvider.Callback(MoreExecutors.directExecutor()) { @Override - public void updateSslContextAndExtendedX509TrustManager(AbstractMap.SimpleImmutableEntry sslContext) { - future.set(sslContext); + public void updateSslContextAndExtendedX509TrustManager( + AbstractMap.SimpleImmutableEntry sslContextAndTm) { + future.set(sslContextAndTm); } @Override @@ -541,7 +544,7 @@ protected void onException(Throwable throwable) { executor.runDueTasks(); channel.runPendingTasks(); // need this for tasks to execute on eventLoop Object fromFuture = future.get(5, TimeUnit.SECONDS); - assertThat(fromFuture).isInstanceOf(SslContext.class); + assertThat(fromFuture).isInstanceOf(AbstractMap.SimpleImmutableEntry.class); channel.runPendingTasks(); channelHandlerCtx = pipeline.context(clientSecurityHandler); assertThat(channelHandlerCtx).isNull(); From 5ba39b3ae883d1bda44df74186fe707a9932fe94 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Fri, 19 Sep 2025 15:39:26 +0000 Subject: [PATCH 43/51] Remove whitespace only formatting done by wrong indendation settings in Intellij. --- examples/example-xds/xds-client.Dockerfile | 47 ------------------- examples/example-xds/xds-server.Dockerfile | 47 ------------------- .../testing/integration/XdsTestClient.java | 2 - .../netty/InternalProtocolNegotiators.java | 16 +++---- .../io/grpc/netty/ProtocolNegotiators.java | 41 ++++++++-------- .../io/grpc/xds/EnvoyServerProtoData.java | 12 ++--- .../security/SecurityProtocolNegotiators.java | 7 ++- .../CertProviderClientSslContextProvider.java | 15 +++--- .../trust/XdsTrustManagerFactory.java | 8 ++-- .../security/trust/XdsX509TrustManager.java | 2 +- .../io/grpc/xds/CdsLoadBalancer2Test.java | 28 +++++------ 11 files changed, 63 insertions(+), 162 deletions(-) delete mode 100644 examples/example-xds/xds-client.Dockerfile delete mode 100644 examples/example-xds/xds-server.Dockerfile diff --git a/examples/example-xds/xds-client.Dockerfile b/examples/example-xds/xds-client.Dockerfile deleted file mode 100644 index 0f34d219177..00000000000 --- a/examples/example-xds/xds-client.Dockerfile +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2024 gRPC authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# -# Stage 1: Build XDS client -# - -FROM eclipse-temurin:11-jdk AS build - -WORKDIR /grpc-java/examples -COPY . . - -RUN cd example-xds && ../gradlew installDist -PskipCodegen=true -PskipAndroid=true - -# -# Stage 2: -# -# - Copy only the necessary files to reduce Docker image size. -# - Have an ENTRYPOINT script which will launch the XDS client -# with the given parameters. -# - -FROM eclipse-temurin:11-jre - -WORKDIR /grpc-java/ -COPY --from=build /grpc-java/examples/example-xds/build/install/example-xds/. . - -# Intentionally after the COPY to force the update on each build. -# Update Ubuntu system packages: -RUN apt-get update \ - && apt-get -y upgrade \ - && apt-get -y autoremove \ - && rm -rf /var/lib/apt/lists/* - -# Client -ENTRYPOINT ["bin/xds-hello-world-client"] diff --git a/examples/example-xds/xds-server.Dockerfile b/examples/example-xds/xds-server.Dockerfile deleted file mode 100644 index 542fb0263af..00000000000 --- a/examples/example-xds/xds-server.Dockerfile +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2024 gRPC authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# -# Stage 1: Build XDS server -# - -FROM eclipse-temurin:11-jdk AS build - -WORKDIR /grpc-java/examples -COPY . . - -RUN cd example-xds && ../gradlew installDist -PskipCodegen=true -PskipAndroid=true - -# -# Stage 2: -# -# - Copy only the necessary files to reduce Docker image size. -# - Have an ENTRYPOINT script which will launch the XDS server -# with the given parameters. -# - -FROM eclipse-temurin:11-jre - -WORKDIR /grpc-java/ -COPY --from=build /grpc-java/examples/example-xds/build/install/example-xds/. . - -# Intentionally after the COPY to force the update on each build. -# Update Ubuntu system packages: -RUN apt-get update \ - && apt-get -y upgrade \ - && apt-get -y autoremove \ - && rm -rf /var/lib/apt/lists/* - -# Server -ENTRYPOINT ["bin/xds-hello-world-server"] diff --git a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java index f341836e71d..89519041a79 100644 --- a/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java +++ b/interop-testing/src/main/java/io/grpc/testing/integration/XdsTestClient.java @@ -452,14 +452,12 @@ public void onNext(SimpleResponse response) { private void handleRpcCompleted(long requestId, RpcType rpcType, String hostname, Set watchers) { - logger.info("RPC completed"); statsAccumulator.recordRpcFinished(rpcType, Status.OK); notifyWatchers(watchers, rpcType, requestId, hostname); } private void handleRpcError(long requestId, RpcType rpcType, Status status, Set watchers) { - logger.info("RPC error with status " + status); statsAccumulator.recordRpcFinished(rpcType, status); notifyWatchers(watchers, rpcType, requestId, null); } diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index b62e37d51e6..5c27ab03a89 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -25,7 +25,6 @@ import io.netty.channel.ChannelHandler; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; - import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; import java.util.concurrent.Executor; @@ -45,10 +44,11 @@ private InternalProtocolNegotiators() {} * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static InternalProtocolNegotiator.ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool, - Optional handshakeCompleteRunnable, - TrustManager extendedX509TrustManager, - String sni, boolean isXdsTarget) { + ObjectPool executorPool, + Optional handshakeCompleteRunnable, + TrustManager extendedX509TrustManager, + String sni, + boolean isXdsTarget) { final io.grpc.netty.ProtocolNegotiator negotiator = ProtocolNegotiators.tls(sslContext, executorPool, handshakeCompleteRunnable, (X509TrustManager) extendedX509TrustManager, sni, isXdsTarget); @@ -163,7 +163,7 @@ public void close() { * Internal version of {@link WaitUntilActiveHandler}. */ public static ChannelHandler waitUntilActiveHandler(ChannelHandler next, - ChannelLogger negotiationLogger) { + ChannelLogger negotiationLogger) { return new WaitUntilActiveHandler(next, negotiationLogger); } @@ -185,7 +185,7 @@ public static class ProtocolNegotiationHandler extends ProtocolNegotiators.ProtocolNegotiationHandler { protected ProtocolNegotiationHandler(ChannelHandler next, String negotiatorName, - ChannelLogger negotiationLogger) { + ChannelLogger negotiationLogger) { super(next, negotiatorName, negotiationLogger); } @@ -193,4 +193,4 @@ protected ProtocolNegotiationHandler(ChannelHandler next, ChannelLogger negotiat super(next, negotiationLogger); } } -} \ No newline at end of file +} diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index c1ef4dcfc42..824fb96916f 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -132,7 +132,7 @@ public static FromChannelCredentialsResult from(ChannelCredentials creds) { trustManagers = tlsCreds.getTrustManagers(); } else if (tlsCreds.getRootCertificates() != null) { trustManagers = Arrays.asList(CertificateUtils.createTrustManager( - new ByteArrayInputStream(tlsCreds.getRootCertificates()))); + new ByteArrayInputStream(tlsCreds.getRootCertificates()))); } else { // else use system default TrustManagerFactory tmf = TrustManagerFactory.getInstance( TrustManagerFactory.getDefaultAlgorithm()); @@ -265,7 +265,7 @@ public static final class FromChannelCredentialsResult { public final String error; private FromChannelCredentialsResult(ProtocolNegotiator.ClientFactory negotiator, - CallCredentials creds, String error) { + CallCredentials creds, String error) { this.negotiator = negotiator; this.callCredentials = creds; this.error = error; @@ -379,7 +379,7 @@ public ProtocolNegotiator newNegotiator(ObjectPool offloadEx * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static ProtocolNegotiator serverTls(final SslContext sslContext, - final ObjectPool executorPool) { + final ObjectPool executorPool) { Preconditions.checkNotNull(sslContext, "sslContext"); final Executor executor; if (executorPool != null) { @@ -428,8 +428,8 @@ static final class ServerTlsHandler extends ChannelInboundHandlerAdapter { private ProtocolNegotiationEvent pne = ProtocolNegotiationEvent.DEFAULT; ServerTlsHandler(ChannelHandler next, - SslContext sslContext, - final ObjectPool executorPool) { + SslContext sslContext, + final ObjectPool executorPool) { this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); this.next = Preconditions.checkNotNull(next, "next"); if (executorPool != null) { @@ -486,9 +486,8 @@ private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession * Returns a {@link ProtocolNegotiator} that does HTTP CONNECT proxy negotiation. */ public static ProtocolNegotiator httpProxy(final SocketAddress proxyAddress, - final @Nullable String proxyUsername, - final @Nullable String proxyPassword, - final ProtocolNegotiator negotiator) { + final @Nullable String proxyUsername, final @Nullable String proxyPassword, + final ProtocolNegotiator negotiator) { Preconditions.checkNotNull(negotiator, "negotiator"); Preconditions.checkNotNull(proxyAddress, "proxyAddress"); final AsciiString scheme = negotiator.scheme(); @@ -565,10 +564,8 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator { public ClientTlsProtocolNegotiator(SslContext sslContext, - ObjectPool executorPool, - Optional handshakeCompleteRunnable, - X509TrustManager x509ExtendedTrustManager, - String sni, boolean isXdsTarget) { + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager, String sni, boolean isXdsTarget) { this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); this.executorPool = executorPool; if (this.executorPool != null) { @@ -628,10 +625,10 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { private SSLEngine sslEngine; ClientTlsHandler(ChannelHandler next, SslContext sslContext, String sniHostPort, - Executor executor, ChannelLogger negotiationLogger, - Optional handshakeCompleteRunnable, - ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, - X509TrustManager x509ExtendedTrustManager) { + Executor executor, ChannelLogger negotiationLogger, + Optional handshakeCompleteRunnable, + ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, + X509TrustManager x509ExtendedTrustManager) { super(next, negotiationLogger); this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); if (!Strings.isNullOrEmpty(sniHostPort)) { @@ -753,8 +750,8 @@ static HostPort parseAuthority(String authority) { * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool, Optional handshakeCompleteRunnable, - X509TrustManager x509ExtendedTrustManager, String sni, boolean isXdsTarget) { + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager, String sni, boolean isXdsTarget) { return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable, x509ExtendedTrustManager, sni, isXdsTarget); } @@ -938,7 +935,7 @@ private static RuntimeException unavailableException(String msg) { @VisibleForTesting static void logSslEngineDetails(Level level, ChannelHandlerContext ctx, String msg, - @Nullable Throwable t) { + @Nullable Throwable t) { if (!log.isLoggable(level)) { return; } @@ -1064,8 +1061,8 @@ static final class PlaintextHandler extends ProtocolNegotiationHandler { protected void protocolNegotiationEventTriggered(ChannelHandlerContext ctx) { ProtocolNegotiationEvent existingPne = getProtocolNegotiationEvent(); Attributes attrs = existingPne.getAttributes().toBuilder() - .set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, (authority) -> Status.OK) - .build(); + .set(GrpcAttributes.ATTR_AUTHORITY_VERIFIER, (authority) -> Status.OK) + .build(); replaceProtocolNegotiationEvent(existingPne.withAttributes(attrs)); fireProtocolNegotiationEvent(ctx); } @@ -1127,7 +1124,7 @@ static class ProtocolNegotiationHandler extends ChannelDuplexHandler { private final ChannelLogger negotiationLogger; protected ProtocolNegotiationHandler(ChannelHandler next, String negotiatorName, - ChannelLogger negotiationLogger) { + ChannelLogger negotiationLogger) { this.next = Preconditions.checkNotNull(next, "next"); this.negotiatorName = negotiatorName; this.negotiationLogger = Preconditions.checkNotNull(negotiationLogger, "negotiationLogger"); diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index 3a30cf0aad2..52ebc698d6d 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -138,7 +138,7 @@ public static DownstreamTlsContext fromEnvoyProtoDownstreamTlsContext( io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.DownstreamTlsContext downstreamTlsContext) { return new DownstreamTlsContext(downstreamTlsContext.getCommonTlsContext(), - downstreamTlsContext.hasRequireClientCertificate()); + downstreamTlsContext.hasRequireClientCertificate()); } public boolean isRequireClientCertificate() { @@ -224,10 +224,10 @@ abstract static class FilterChainMatch { abstract String transportProtocol(); public static FilterChainMatch create(int destinationPort, - ImmutableList prefixRanges, - ImmutableList applicationProtocols, ImmutableList sourcePrefixRanges, - ConnectionSourceType connectionSourceType, ImmutableList sourcePorts, - ImmutableList serverNames, String transportProtocol) { + ImmutableList prefixRanges, + ImmutableList applicationProtocols, ImmutableList sourcePrefixRanges, + ConnectionSourceType connectionSourceType, ImmutableList sourcePorts, + ImmutableList serverNames, String transportProtocol) { return new AutoValue_EnvoyServerProtoData_FilterChainMatch( destinationPort, prefixRanges, applicationProtocols, sourcePrefixRanges, connectionSourceType, sourcePorts, serverNames, transportProtocol); @@ -439,4 +439,4 @@ static FailurePercentageEjection create( enforcementPercentage, minimumHosts, requestVolume); } } -} \ No newline at end of file +} diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index 7bedba4e0a0..e698a45296a 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -196,7 +196,6 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { @VisibleForTesting static final class ClientSecurityHandler extends InternalProtocolNegotiators.ProtocolNegotiationHandler { - private final GrpcHttp2ConnectionHandler grpcHandler; private final SslContextProviderSupplier sslContextProviderSupplier; private final String sni; @@ -321,10 +320,10 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc if (evt instanceof ProtocolNegotiationEvent) { ProtocolNegotiationEvent pne = (ProtocolNegotiationEvent)evt; SslContextProviderSupplier sslContextProviderSupplier = InternalProtocolNegotiationEvent - .getAttributes(pne).get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER); + .getAttributes(pne).get(ATTR_SERVER_SSL_CONTEXT_PROVIDER_SUPPLIER); if (sslContextProviderSupplier == null) { logger.log(Level.FINE, "No sslContextProviderSupplier found in filterChainMatch " - + "for connection from {0} to {1}", + + "for connection from {0} to {1}", new Object[]{ctx.channel().remoteAddress(), ctx.channel().localAddress()}); if (fallbackProtocolNegotiator == null) { ctx.fireExceptionCaught(new CertStoreException("No certificate source found!")); @@ -406,7 +405,7 @@ public void onException(Throwable throwable) { ctx.fireExceptionCaught(throwable); } }, - null); + null); } } } \ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index a7729ca03af..ecbaaec390f 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -49,13 +49,14 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP private final String sniForSanMatching; CertProviderClientSslContextProvider( - Node node, - @Nullable Map certProviders, - CommonTlsContext.CertificateProviderInstance certInstance, - CommonTlsContext.CertificateProviderInstance rootCertInstance, - CertificateValidationContext staticCertValidationContext, - UpstreamTlsContext upstreamTlsContext, - String sniForSanMatching, CertificateProviderStore certificateProviderStore) { + Node node, + @Nullable Map certProviders, + CommonTlsContext.CertificateProviderInstance certInstance, + CommonTlsContext.CertificateProviderInstance rootCertInstance, + CertificateValidationContext staticCertValidationContext, + UpstreamTlsContext upstreamTlsContext, + String sniForSanMatching, + CertificateProviderStore certificateProviderStore) { super( node, certProviders, diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java index 7ad24b75d98..79c590f1ff4 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java @@ -69,7 +69,7 @@ public XdsTrustManagerFactory( } public XdsTrustManagerFactory(Map> spiffeTrustMap, - CertificateValidationContext staticCertificateValidationContext, String sniForSanMatching) throws CertStoreException { + CertificateValidationContext staticCertificateValidationContext, String sniForSanMatching) throws CertStoreException { this(spiffeTrustMap, staticCertificateValidationContext, true, sniForSanMatching); } @@ -125,14 +125,14 @@ private static X509Certificate[] getTrustedCaFromCertContext( @VisibleForTesting static XdsX509TrustManager createX509TrustManager( - X509Certificate[] certs, CertificateValidationContext certContext, String sniForSanMatching) throws CertStoreException { + X509Certificate[] certs, CertificateValidationContext certContext, String sniForSanMatching) throws CertStoreException { return new XdsX509TrustManager(certContext, createTrustManager(certs), sniForSanMatching); } @VisibleForTesting static XdsX509TrustManager createX509TrustManager( - Map> spiffeTrustMapFile, - CertificateValidationContext certContext, String sniForSanMatching) throws CertStoreException { + Map> spiffeTrustMapFile, + CertificateValidationContext certContext, String sniForSanMatching) throws CertStoreException { checkNotNull(spiffeTrustMapFile, "spiffeTrustMapFile"); Map delegates = new HashMap<>(); for (Map.Entry> entry:spiffeTrustMapFile.entrySet()) { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java index 21cb8b731d6..e8bd4501798 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java @@ -79,7 +79,7 @@ final class XdsX509TrustManager extends X509ExtendedTrustManager implements X509 } XdsX509TrustManager(@Nullable CertificateValidationContext certContext, - Map spiffeTrustMapDelegates, @Nullable String sniForSanMatching) { + Map spiffeTrustMapDelegates, @Nullable String sniForSanMatching) { checkNotNull(spiffeTrustMapDelegates, "spiffeTrustMapDelegates"); this.spiffeTrustMapDelegates = ImmutableMap.copyOf(spiffeTrustMapDelegates); this.certContext = certContext; diff --git a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java index 771f8c2596d..0d2e198c494 100644 --- a/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java +++ b/xds/src/test/java/io/grpc/xds/CdsLoadBalancer2Test.java @@ -124,7 +124,7 @@ public class CdsLoadBalancer2Test { .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() .setServiceName(EDS_SERVICE_NAME) .setEdsConfig(ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.newBuilder()))) + .setAds(AggregatedConfigSource.newBuilder()))) .build(); private final FakeClock fakeClock = new FakeClock(); @@ -135,9 +135,9 @@ public class CdsLoadBalancer2Test { Arrays.asList("control-plane.example.com"), serverInfo -> new GrpcXdsTransportFactory.GrpcXdsTransport( InProcessChannelBuilder - .forName(serverInfo.target()) - .directExecutor() - .build()), + .forName(serverInfo.target()) + .directExecutor() + .build()), fakeClock); private final ServerInfo lrsServerInfo = xdsClient.getBootstrapInfo().servers().get(0); private XdsDependencyManager xdsDepManager; @@ -227,16 +227,16 @@ public void discoverTopLevelEdsCluster() { .setName(CLUSTER) .setType(Cluster.DiscoveryType.EDS) .setEdsClusterConfig(Cluster.EdsClusterConfig.newBuilder() - .setServiceName(EDS_SERVICE_NAME) - .setEdsConfig(ConfigSource.newBuilder() - .setAds(AggregatedConfigSource.newBuilder()))) + .setServiceName(EDS_SERVICE_NAME) + .setEdsConfig(ConfigSource.newBuilder() + .setAds(AggregatedConfigSource.newBuilder()))) .setLbPolicy(Cluster.LbPolicy.ROUND_ROBIN) .setLrsServer(ConfigSource.newBuilder() - .setSelf(SelfConfigSource.getDefaultInstance())) + .setSelf(SelfConfigSource.getDefaultInstance())) .setCircuitBreakers(CircuitBreakers.newBuilder() .addThresholds(CircuitBreakers.Thresholds.newBuilder() - .setPriority(RoutingPriority.DEFAULT) - .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) + .setPriority(RoutingPriority.DEFAULT) + .setMaxRequests(UInt32Value.newBuilder().setValue(100)))) .setTransportSocket(TransportSocket.newBuilder() .setName("envoy.transport_sockets.tls") .setTypedConfig(Any.pack(UpstreamTlsContext.newBuilder() @@ -254,10 +254,10 @@ public void discoverTopLevelEdsCluster() { ClusterResolverConfig childLbConfig = (ClusterResolverConfig) childBalancer.config; assertThat(childLbConfig.discoveryMechanism).isEqualTo( DiscoveryMechanism.forEds( - CLUSTER, EDS_SERVICE_NAME, lrsServerInfo, 100L, upstreamTlsContext, - Collections.emptyMap(), io.grpc.xds.EnvoyServerProtoData.OutlierDetection.create( - null, null, null, null, SuccessRateEjection.create(null, null, null, null), - FailurePercentageEjection.create(null, null, null, null)))); + CLUSTER, EDS_SERVICE_NAME, lrsServerInfo, 100L, upstreamTlsContext, + Collections.emptyMap(), io.grpc.xds.EnvoyServerProtoData.OutlierDetection.create( + null, null, null, null, SuccessRateEjection.create(null, null, null, null), + FailurePercentageEjection.create(null, null, null, null)))); assertThat( GracefulSwitchLoadBalancerAccessor.getChildProvider(childLbConfig.lbConfig).getPolicyName()) .isEqualTo("wrr_locality_experimental"); From f13594373ec4e3107fcc44ac2a256d79472a40fb Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 22 Sep 2025 15:26:48 -0700 Subject: [PATCH 44/51] xds: Plumb system root certs similarly to CertProviders --- .../CertProviderClientSslContextProvider.java | 23 ++-- .../CertProviderSslContextProvider.java | 102 +++++++++++++----- .../SystemRootCertificateProvider.java | 71 ++++++++++++ 3 files changed, 159 insertions(+), 37 deletions(-) create mode 100644 xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index d7c2267c48f..131ae6b6125 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -55,20 +55,19 @@ protected final SslContextBuilder getSslContextBuilder( CertificateValidationContext certificateValidationContextdationContext) throws CertStoreException { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); - // Null rootCertInstance implies hasSystemRootCerts because of the check in - // CertProviderClientSslContextProviderFactory. - if (rootCertInstance != null) { - if (savedSpiffeTrustMap != null) { - sslContextBuilder = sslContextBuilder.trustManager( + if (savedSpiffeTrustMap != null) { + sslContextBuilder = sslContextBuilder.trustManager( + new XdsTrustManagerFactory( + savedSpiffeTrustMap, + certificateValidationContextdationContext)); + } else if (savedTrustedRoots != null) { + sslContextBuilder = sslContextBuilder.trustManager( new XdsTrustManagerFactory( - savedSpiffeTrustMap, + savedTrustedRoots.toArray(new X509Certificate[0]), certificateValidationContextdationContext)); - } else { - sslContextBuilder = sslContextBuilder.trustManager( - new XdsTrustManagerFactory( - savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContextdationContext)); - } + } else { + // Should be impossible because of the check in CertProviderClientSslContextProviderFactory + throw new IllegalStateException("There must be trusted roots or a SPIFFE trust map"); } if (isMtls()) { sslContextBuilder.keyManager(savedKey, savedCertChain); diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index 801dabeecb7..ef9a3cf062c 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -16,14 +16,18 @@ package io.grpc.xds.internal.security.certprovider; +import static java.util.Objects.requireNonNull; + import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; +import io.grpc.Status; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.internal.security.CommonTlsContextUtil; import io.grpc.xds.internal.security.DynamicSslContextProvider; +import java.io.Closeable; import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.List; @@ -34,15 +38,12 @@ abstract class CertProviderSslContextProvider extends DynamicSslContextProvider implements CertificateProvider.Watcher { - @Nullable private final CertificateProviderStore.Handle certHandle; - @Nullable private final CertificateProviderStore.Handle rootCertHandle; - @Nullable private final CertificateProviderInstance certInstance; - @Nullable protected final CertificateProviderInstance rootCertInstance; + @Nullable private final NoExceptionCloseable certHandle; + @Nullable private final NoExceptionCloseable rootCertHandle; @Nullable protected PrivateKey savedKey; @Nullable protected List savedCertChain; @Nullable protected List savedTrustedRoots; @Nullable protected Map> savedSpiffeTrustMap; - private final boolean isUsingSystemRootCerts; protected CertProviderSslContextProvider( Node node, @@ -53,26 +54,33 @@ protected CertProviderSslContextProvider( BaseTlsContext tlsContext, CertificateProviderStore certificateProviderStore) { super(tlsContext, staticCertValidationContext); - this.certInstance = certInstance; - this.rootCertInstance = rootCertInstance; - String certInstanceName = null; - if (certInstance != null && certInstance.isInitialized()) { - certInstanceName = certInstance.getInstanceName(); + boolean createCertInstance = certInstance != null && certInstance.isInitialized(); + boolean createRootCertInstance = rootCertInstance != null && rootCertInstance.isInitialized(); + boolean sharedCertInstance = createCertInstance && createRootCertInstance + && rootCertInstance.getInstanceName().equals(certInstance.getInstanceName()); + if (createCertInstance) { CertificateProviderInfo certProviderInstanceConfig = - getCertProviderConfig(certProviders, certInstanceName); + getCertProviderConfig(certProviders, certInstance.getInstanceName()); + CertificateProvider.Watcher watcher = this; + if (!sharedCertInstance) { + watcher = new IgnoreUpdatesWatcher(watcher, /* ignoreRootCertUpdates= */ true); + } + // TODO: Previously we'd hang if certProviderInstanceConfig were null or + // certInstance.isInitialized() == false. Now we'll proceed. Those should be errors, or are + // they impossible and should be assertions? certHandle = certProviderInstanceConfig == null ? null : certificateProviderStore.createOrGetProvider( certInstance.getCertificateName(), certProviderInstanceConfig.pluginName(), certProviderInstanceConfig.config(), - this, - true); + watcher, + true)::close; } else { certHandle = null; } - if (rootCertInstance != null - && rootCertInstance.isInitialized() - && !rootCertInstance.getInstanceName().equals(certInstanceName)) { + if (createRootCertInstance && sharedCertInstance) { + rootCertHandle = () -> { }; + } else if (createRootCertInstance && !sharedCertInstance) { CertificateProviderInfo certProviderInstanceConfig = getCertProviderConfig(certProviders, rootCertInstance.getInstanceName()); rootCertHandle = certProviderInstanceConfig == null ? null @@ -80,13 +88,16 @@ protected CertProviderSslContextProvider( rootCertInstance.getCertificateName(), certProviderInstanceConfig.pluginName(), certProviderInstanceConfig.config(), - this, - true); + new IgnoreUpdatesWatcher(this, /* ignoreRootCertUpdates= */ false), + false)::close; + } else if (rootCertInstance == null + && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext())) { + SystemRootCertificateProvider systemRootProvider = new SystemRootCertificateProvider(this); + systemRootProvider.start(); + rootCertHandle = systemRootProvider::close; } else { rootCertHandle = null; } - this.isUsingSystemRootCerts = rootCertInstance == null - && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext()); } private static CertificateProviderInfo getCertProviderConfig( @@ -150,8 +161,7 @@ public final void updateSpiffeTrustMap(Map> spiffe private void updateSslContextWhenReady() { if (isMtls()) { - if (savedKey != null - && (savedTrustedRoots != null || isUsingSystemRootCerts || savedSpiffeTrustMap != null)) { + if (savedKey != null && (savedTrustedRoots != null || savedSpiffeTrustMap != null)) { updateSslContext(); clearKeysAndCerts(); } @@ -176,15 +186,15 @@ private void clearKeysAndCerts() { } protected final boolean isMtls() { - return certInstance != null && (rootCertInstance != null || isUsingSystemRootCerts); + return certHandle != null && rootCertHandle != null; } protected final boolean isClientSideTls() { - return rootCertInstance != null && certInstance == null; + return rootCertHandle != null && certHandle == null; } protected final boolean isServerSideTls() { - return certInstance != null && rootCertInstance == null; + return certHandle != null && rootCertHandle == null; } @Override @@ -201,4 +211,46 @@ public final void close() { rootCertHandle.close(); } } + + interface NoExceptionCloseable extends Closeable { + @Override + void close(); + } + + static final class IgnoreUpdatesWatcher implements CertificateProvider.Watcher { + private final CertificateProvider.Watcher delegate; + private final boolean ignoreRootCertUpdates; + + public IgnoreUpdatesWatcher( + CertificateProvider.Watcher delegate, boolean ignoreRootCertUpdates) { + this.delegate = requireNonNull(delegate, "delegate"); + this.ignoreRootCertUpdates = ignoreRootCertUpdates; + } + + @Override + public void updateCertificate(PrivateKey key, List certChain) { + if (ignoreRootCertUpdates) { + delegate.updateCertificate(key, certChain); + } + } + + @Override + public void updateTrustedRoots(List trustedRoots) { + if (!ignoreRootCertUpdates) { + delegate.updateTrustedRoots(trustedRoots); + } + } + + @Override + public void updateSpiffeTrustMap(Map> spiffeTrustMap) { + if (!ignoreRootCertUpdates) { + delegate.updateSpiffeTrustMap(spiffeTrustMap); + } + } + + @Override + public void onError(Status errorStatus) { + delegate.onError(errorStatus); + } + } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java new file mode 100644 index 00000000000..7c60f714e71 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/SystemRootCertificateProvider.java @@ -0,0 +1,71 @@ +/* + * Copyright 2020 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.security.certprovider; + +import io.grpc.Status; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.stream.Collectors; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; + +/** + * An non-registered provider for CertProviderSslContextProvider to use the same code path for + * system root certs as provider-obtained certs. + */ +final class SystemRootCertificateProvider extends CertificateProvider { + public SystemRootCertificateProvider(CertificateProvider.Watcher watcher) { + super(new DistributorWatcher(), false); + getWatcher().addWatcher(watcher); + } + + @Override + public void start() { + try { + TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init((KeyStore) null); + + List trustManagers = Arrays.asList(trustManagerFactory.getTrustManagers()); + List rootCerts = trustManagers.stream() + .filter(X509TrustManager.class::isInstance) + .map(X509TrustManager.class::cast) + .map(trustManager -> Arrays.asList(trustManager.getAcceptedIssuers())) + .flatMap(Collection::stream) + .collect(Collectors.toList()); + getWatcher().updateTrustedRoots(rootCerts); + } catch (KeyStoreException | NoSuchAlgorithmException ex) { + getWatcher().onError(Status.UNAVAILABLE + .withDescription("Could not load system root certs") + .withCause(ex)); + } + } + + @Override + public void close() { + // Unnecessary because there's no more callbacks, but do it for good measure + for (Watcher watcher : getWatcher().getDownstreamWatchers()) { + getWatcher().removeWatcher(watcher); + } + } +} From 220e428b62d1337836ea3e725e98e4da41db69f2 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Tue, 23 Sep 2025 06:39:53 +0000 Subject: [PATCH 45/51] Save changes --- netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java | 3 +++ 1 file changed, 3 insertions(+) diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index 824fb96916f..aff61ef292d 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -582,6 +582,9 @@ public ClientTlsProtocolNegotiator(SslContext sslContext, private final Optional handshakeCompleteRunnable; private final X509TrustManager x509ExtendedTrustManager; private final String sni; + // For xds targets there may be no SNI determined, and no SNI may be sent in that case. + // Non xds-targets will always use channel authority for SNI. This field is used to handle + // the two cases differently. private final boolean isXdsTarget; private Executor executor; From 0d5eb0a7e2ed8fcdac2e5781b053299aaea15929 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Tue, 23 Sep 2025 11:28:46 +0000 Subject: [PATCH 46/51] Fix certs not updated for handshake. --- .../CertProviderSslContextProvider.java | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index ef9a3cf062c..05902115a94 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -40,10 +40,13 @@ abstract class CertProviderSslContextProvider extends DynamicSslContextProvider @Nullable private final NoExceptionCloseable certHandle; @Nullable private final NoExceptionCloseable rootCertHandle; + @Nullable private final CertificateProviderInstance certInstance; + @Nullable protected final CertificateProviderInstance rootCertInstance; @Nullable protected PrivateKey savedKey; @Nullable protected List savedCertChain; @Nullable protected List savedTrustedRoots; @Nullable protected Map> savedSpiffeTrustMap; + private final boolean isUsingSystemRootCerts; protected CertProviderSslContextProvider( Node node, @@ -54,6 +57,10 @@ protected CertProviderSslContextProvider( BaseTlsContext tlsContext, CertificateProviderStore certificateProviderStore) { super(tlsContext, staticCertValidationContext); + this.certInstance = certInstance; + this.rootCertInstance = rootCertInstance; + this.isUsingSystemRootCerts = rootCertInstance == null + && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext()); boolean createCertInstance = certInstance != null && certInstance.isInitialized(); boolean createRootCertInstance = rootCertInstance != null && rootCertInstance.isInitialized(); boolean sharedCertInstance = createCertInstance && createRootCertInstance @@ -186,15 +193,15 @@ private void clearKeysAndCerts() { } protected final boolean isMtls() { - return certHandle != null && rootCertHandle != null; + return certInstance != null && (rootCertInstance != null || isUsingSystemRootCerts); } protected final boolean isClientSideTls() { - return rootCertHandle != null && certHandle == null; + return rootCertInstance != null && certInstance == null; } protected final boolean isServerSideTls() { - return certHandle != null && rootCertHandle == null; + return certInstance != null && rootCertInstance == null; } @Override From c14a4884ca2873b2ff1647208c4d51df7a792a26 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Tue, 23 Sep 2025 13:54:58 +0000 Subject: [PATCH 47/51] More fixes for system root certs. --- .../CertProviderSslContextProvider.java | 54 ++------------- .../certprovider/IgnoreUpdatesWatcher.java | 68 +++++++++++++++++++ .../grpc/xds/XdsSecurityClientServerTest.java | 5 -- .../ClientSslContextProviderFactoryTest.java | 32 +++++---- .../ServerSslContextProviderFactoryTest.java | 16 ++--- ...tProviderClientSslContextProviderTest.java | 22 +----- 6 files changed, 104 insertions(+), 93 deletions(-) create mode 100644 xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index 1020932e834..967ed836007 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -16,13 +16,10 @@ package io.grpc.xds.internal.security.certprovider; -import static java.util.Objects.requireNonNull; - import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance; -import io.grpc.Status; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; import io.grpc.xds.internal.security.CommonTlsContextUtil; @@ -69,7 +66,7 @@ protected CertProviderSslContextProvider( CertificateProviderInfo certProviderInstanceConfig = getCertProviderConfig(certProviders, certInstance.getInstanceName()); CertificateProvider.Watcher watcher = this; - if (!sharedCertInstance) { + if (!sharedCertInstance && !isUsingSystemRootCerts) { watcher = new IgnoreUpdatesWatcher(watcher, /* ignoreRootCertUpdates= */ true); } // TODO: Previously we'd hang if certProviderInstanceConfig were null or @@ -156,9 +153,6 @@ public final void updateCertificate(PrivateKey key, List certCh @Override public final void updateTrustedRoots(List trustedRoots) { - if (isUsingSystemRootCerts) { - return; - } savedTrustedRoots = trustedRoots; updateSslContextWhenReady(); } @@ -190,7 +184,9 @@ private void updateSslContextWhenReady() { private void clearKeysAndCerts() { savedKey = null; - savedTrustedRoots = null; + if (!isUsingSystemRootCerts) { + savedTrustedRoots = null; + } savedSpiffeTrustMap = null; savedCertChain = null; } @@ -200,10 +196,7 @@ protected final boolean isMtls() { } protected final boolean isRegularTlsAndClientSide() { - // We don't do (rootCertInstance != null || isUsingSystemRootCerts) here because of how this - // method is used. With the rootCertInstance being null when using system root certs, there - // is nothing to update in the SslContext - return rootCertInstance != null && certInstance == null; + return (rootCertInstance != null || isUsingSystemRootCerts) && certInstance == null; } protected final boolean isRegularTlsAndServerSide() { @@ -229,41 +222,4 @@ interface NoExceptionCloseable extends Closeable { @Override void close(); } - - static final class IgnoreUpdatesWatcher implements CertificateProvider.Watcher { - private final CertificateProvider.Watcher delegate; - private final boolean ignoreRootCertUpdates; - - public IgnoreUpdatesWatcher( - CertificateProvider.Watcher delegate, boolean ignoreRootCertUpdates) { - this.delegate = requireNonNull(delegate, "delegate"); - this.ignoreRootCertUpdates = ignoreRootCertUpdates; - } - - @Override - public void updateCertificate(PrivateKey key, List certChain) { - if (ignoreRootCertUpdates) { - delegate.updateCertificate(key, certChain); - } - } - - @Override - public void updateTrustedRoots(List trustedRoots) { - if (!ignoreRootCertUpdates) { - delegate.updateTrustedRoots(trustedRoots); - } - } - - @Override - public void updateSpiffeTrustMap(Map> spiffeTrustMap) { - if (!ignoreRootCertUpdates) { - delegate.updateSpiffeTrustMap(spiffeTrustMap); - } - } - - @Override - public void onError(Status errorStatus) { - delegate.onError(errorStatus); - } - } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java new file mode 100644 index 00000000000..cd9d88be41b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/IgnoreUpdatesWatcher.java @@ -0,0 +1,68 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.security.certprovider; + +import static java.util.Objects.requireNonNull; + +import com.google.common.annotations.VisibleForTesting; +import io.grpc.Status; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.List; +import java.util.Map; + +public final class IgnoreUpdatesWatcher implements CertificateProvider.Watcher { + private final CertificateProvider.Watcher delegate; + private final boolean ignoreRootCertUpdates; + + public IgnoreUpdatesWatcher( + CertificateProvider.Watcher delegate, boolean ignoreRootCertUpdates) { + this.delegate = requireNonNull(delegate, "delegate"); + this.ignoreRootCertUpdates = ignoreRootCertUpdates; + } + + @Override + public void updateCertificate(PrivateKey key, List certChain) { + if (ignoreRootCertUpdates) { + delegate.updateCertificate(key, certChain); + } + } + + @Override + public void updateTrustedRoots(List trustedRoots) { + if (!ignoreRootCertUpdates) { + delegate.updateTrustedRoots(trustedRoots); + } + } + + @Override + public void updateSpiffeTrustMap(Map> spiffeTrustMap) { + if (!ignoreRootCertUpdates) { + delegate.updateSpiffeTrustMap(spiffeTrustMap); + } + } + + @Override + public void onError(Status errorStatus) { + delegate.onError(errorStatus); + } + + @VisibleForTesting + public CertificateProvider.Watcher getDelegate() { + return delegate; + } +} diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 03f1651f2c7..e46e440475a 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -282,10 +282,6 @@ public void tlsClientServer_useSystemRootCerts_mtls() throws Exception { } } - /** - * Use system root ca cert for TLS channel - no mTLS. - * Subj Alt Names to match are specified in the validaton context. - */ @Test public void tlsClientServer_useSystemRootCerts_failureToMatchSubjAltNames() throws Exception { Path trustStoreFilePath = getCacertFilePathForTestCa(); @@ -317,7 +313,6 @@ public void tlsClientServer_useSystemRootCerts_failureToMatchSubjAltNames() thro /** * Use system root ca cert for TLS channel - mTLS. - * Uses common_tls_context.combined_validation_context in upstream_tls_context. */ @Test public void tlsClientServer_useSystemRootCerts_requireClientAuth() throws Exception { diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java index 397fe01e0f5..a0eac581d5c 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java @@ -37,6 +37,7 @@ import io.grpc.xds.internal.security.certprovider.CertificateProviderProvider; import io.grpc.xds.internal.security.certprovider.CertificateProviderRegistry; import io.grpc.xds.internal.security.certprovider.CertificateProviderStore; +import io.grpc.xds.internal.security.certprovider.IgnoreUpdatesWatcher; import io.grpc.xds.internal.security.certprovider.TestCertificateProvider; import org.junit.Before; import org.junit.Test; @@ -84,7 +85,7 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); // verify that bootstrapInfo is cached... sslContextProvider = clientSslContextProviderFactory.create(upstreamTlsContext); @@ -119,7 +120,7 @@ public void bothPresent_expectCertProviderClientSslContextProvider() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -145,7 +146,7 @@ public void createCertProviderClientSslContextProvider_onlyRootCert() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -179,7 +180,7 @@ public void createCertProviderClientSslContextProvider_withStaticContext() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -209,8 +210,8 @@ public void createCertProviderClientSslContextProvider_2providers() clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -246,8 +247,8 @@ public void createNewCertProviderClientSslContextProvider_withSans() { clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -280,7 +281,7 @@ public void createNewCertProviderClientSslContextProvider_onlyRootCert() { clientSslContextProviderFactory.create(upstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } static void createAndRegisterProviderProvider( @@ -310,11 +311,18 @@ public CertificateProvider answer(InvocationOnMock invocation) throws Throwable } static void verifyWatcher( - SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor) { + SslContextProvider sslContextProvider, CertificateProvider.DistributorWatcher watcherCaptor, + boolean usesDelegateWatcher) { assertThat(watcherCaptor).isNotNull(); assertThat(watcherCaptor.getDownstreamWatchers()).hasSize(1); - assertThat(watcherCaptor.getDownstreamWatchers().iterator().next()) - .isSameInstanceAs(sslContextProvider); + if (usesDelegateWatcher) { + assertThat(((IgnoreUpdatesWatcher) watcherCaptor.getDownstreamWatchers().iterator().next()) + .getDelegate()) + .isSameInstanceAs(sslContextProvider); + } else { + assertThat(watcherCaptor.getDownstreamWatchers().iterator().next()) + .isSameInstanceAs(sslContextProvider); + } } static CommonTlsContext.Builder addFilenames( diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java index cf86b511f1f..7a5a6c00639 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ServerSslContextProviderFactoryTest.java @@ -78,7 +78,7 @@ public void createCertProviderServerSslContextProvider() throws XdsInitializatio serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); // verify that bootstrapInfo is cached... sslContextProvider = serverSslContextProviderFactory.create(downstreamTlsContext); @@ -117,7 +117,7 @@ public void bothPresent_expectCertProviderServerSslContextProvider() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -144,7 +144,7 @@ public void createCertProviderServerSslContextProvider_onlyCertInstance() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); } @Test @@ -179,7 +179,7 @@ public void createCertProviderServerSslContextProvider_withStaticContext() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); + verifyWatcher(sslContextProvider, watcherCaptor[0], false); } @Test @@ -210,8 +210,8 @@ public void createCertProviderServerSslContextProvider_2providers() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } @Test @@ -249,7 +249,7 @@ public void createNewCertProviderServerSslContextProvider_withSans() serverSslContextProviderFactory.create(downstreamTlsContext); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderServerSslContextProvider"); - verifyWatcher(sslContextProvider, watcherCaptor[0]); - verifyWatcher(sslContextProvider, watcherCaptor[1]); + verifyWatcher(sslContextProvider, watcherCaptor[0], true); + verifyWatcher(sslContextProvider, watcherCaptor[1], true); } } diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index 3b2ca05231e..84e858a8ab9 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -235,15 +235,8 @@ public void testProviderForClient_systemRootCerts_mtls() throws Exception { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); - assertThat(provider.savedTrustedRoots).isNull(); - assertThat(provider.getSslContext()).isNull(); - - // now generate root cert update, will get ignored because of systemRootCerts config - watcherCaptor[0].updateTrustedRoots(ImmutableList.of(getCertFromResourceName(CA_PEM_FILE))); + assertThat(provider.savedTrustedRoots).isNotNull(); assertThat(provider.getSslContext()).isNull(); - assertThat(provider.savedKey).isNull(); - assertThat(provider.savedCertChain).isNull(); - assertThat(provider.savedTrustedRoots).isNull(); // now generate cert update watcherCaptor[0].updateCertificate( @@ -251,6 +244,7 @@ public void testProviderForClient_systemRootCerts_mtls() throws Exception { ImmutableList.of(getCertFromResourceName(CLIENT_PEM_FILE))); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); assertThat(provider.getSslContext()).isNotNull(); TestCallback testCallback = @@ -261,23 +255,13 @@ public void testProviderForClient_systemRootCerts_mtls() throws Exception { CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); - // just do root cert update: sslContext should still be the same, will get ignored because of - // systemRootCerts config - watcherCaptor[0].updateTrustedRoots( - ImmutableList.of(getCertFromResourceName(SERVER_0_PEM_FILE))); - assertThat(provider.savedKey).isNull(); - assertThat(provider.savedCertChain).isNull(); - assertThat(provider.savedTrustedRoots).isNull(); - testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); - assertThat(testCallback1.updatedSslContext).isSameInstanceAs(testCallback.updatedSslContext); - // now update id cert: sslContext should be updated i.e. different from the previous one watcherCaptor[0].updateCertificate( CommonCertProviderTestUtils.getPrivateKey(SERVER_1_KEY_FILE), ImmutableList.of(getCertFromResourceName(SERVER_1_PEM_FILE))); assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); - assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); assertThat(provider.getSslContext()).isNotNull(); testCallback1 = CommonTlsContextTestsUtil.getValueThruCallback(provider); assertThat(testCallback1.updatedSslContext).isNotSameInstanceAs(testCallback.updatedSslContext); From 733f57c9943d6eaf2815c9b1cb4f378322637a5c Mon Sep 17 00:00:00 2001 From: Kannan J Date: Tue, 23 Sep 2025 14:09:01 +0000 Subject: [PATCH 48/51] More fixes for system root certs. --- .../certprovider/CertProviderClientSslContextProviderTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java index 84e858a8ab9..3c734df3f5a 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProviderTest.java @@ -206,7 +206,7 @@ public void testProviderForClient_systemRootCerts_regularTls() { assertThat(provider.savedKey).isNull(); assertThat(provider.savedCertChain).isNull(); - assertThat(provider.savedTrustedRoots).isNull(); + assertThat(provider.savedTrustedRoots).isNotNull(); assertThat(provider.getSslContext()).isNotNull(); TestCallback testCallback = CommonTlsContextTestsUtil.getValueThruCallback(provider); From 967fe8c268c4ac611f90a7a3c9a5a22942534f05 Mon Sep 17 00:00:00 2001 From: Kannan J Date: Wed, 24 Sep 2025 06:33:54 +0000 Subject: [PATCH 49/51] Address review comment to remove reundant if block --- .../security/certprovider/CertProviderSslContextProvider.java | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index 967ed836007..b4df876dfe3 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -82,9 +82,7 @@ protected CertProviderSslContextProvider( } else { certHandle = null; } - if (createRootCertInstance && sharedCertInstance) { - rootCertHandle = () -> { }; - } else if (createRootCertInstance && !sharedCertInstance) { + if (createRootCertInstance && !sharedCertInstance) { CertificateProviderInfo certProviderInstanceConfig = getCertProviderConfig(certProviders, rootCertInstance.getInstanceName()); rootCertHandle = certProviderInstanceConfig == null ? null From 9a817f8b4d00242cda788d9fea860a8df850368b Mon Sep 17 00:00:00 2001 From: Kannan J Date: Wed, 24 Sep 2025 09:56:57 +0000 Subject: [PATCH 50/51] Merge from system root certs PR. --- .../java/io/grpc/EquivalentAddressGroup.java | 4 ++++ .../io/grpc/xds/ClusterImplLoadBalancer.java | 6 ++--- .../grpc/xds/ClusterResolverLoadBalancer.java | 2 +- .../security/SecurityProtocolNegotiators.java | 7 +----- .../CertProviderClientSslContextProvider.java | 24 ++++--------------- .../grpc/xds/ClusterImplLoadBalancerTest.java | 10 ++++---- .../xds/ClusterResolverLoadBalancerTest.java | 6 ++--- .../grpc/xds/XdsSecurityClientServerTest.java | 3 +-- .../ClientSslContextProviderFactoryTest.java | 16 ++++++------- .../SecurityProtocolNegotiatorsTest.java | 3 ++- .../SslContextProviderSupplierTest.java | 9 ++++--- .../security/TlsContextManagerTest.java | 2 +- 12 files changed, 39 insertions(+), 53 deletions(-) diff --git a/api/src/main/java/io/grpc/EquivalentAddressGroup.java b/api/src/main/java/io/grpc/EquivalentAddressGroup.java index bf8a864902c..8ed6c4ad0e4 100644 --- a/api/src/main/java/io/grpc/EquivalentAddressGroup.java +++ b/api/src/main/java/io/grpc/EquivalentAddressGroup.java @@ -55,6 +55,10 @@ public final class EquivalentAddressGroup { */ public static final Attributes.Key ATTR_LOCALITY_NAME = Attributes.Key.create("io.grpc.EquivalentAddressGroup.LOCALITY"); + /** Name associated with individual address, if available (e.g., DNS name). */ + @Attr + public static final Attributes.Key ATTR_ADDRESS_NAME = + Attributes.Key.create("io.grpc.xds.XdsAttributes.addressName"); private final List addrs; private final Attributes attrs; diff --git a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java index 489b714cb1a..a33e97b3317 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterImplLoadBalancer.java @@ -241,9 +241,9 @@ public Subchannel createSubchannel(CreateSubchannelArgs args) { .set(ATTR_CLUSTER_LOCALITY, localityAtomicReference); if (GrpcUtil.getFlag("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false)) { String hostname = args.getAddresses().get(0).getAttributes() - .get(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME); + .get(EquivalentAddressGroup.ATTR_ADDRESS_NAME); if (hostname != null) { - attrsBuilder.set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, hostname); + attrsBuilder.set(EquivalentAddressGroup.ATTR_ADDRESS_NAME, hostname); } } args = args.toBuilder().setAddresses(addresses).setAttributes(attrsBuilder.build()).build(); @@ -439,7 +439,7 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { result = PickResult.withSubchannel(result.getSubchannel(), result.getStreamTracerFactory(), result.getSubchannel().getAttributes().get( - SecurityProtocolNegotiators.ATTR_ADDRESS_NAME)); + EquivalentAddressGroup.ATTR_ADDRESS_NAME)); } } return result; diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java index f4030c96382..6f419e59d63 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java @@ -195,7 +195,7 @@ StatusOr edsUpdateToResult( .set(XdsAttributes.ATTR_LOCALITY_WEIGHT, localityLbInfo.localityWeight()) .set(XdsAttributes.ATTR_SERVER_WEIGHT, weight) - .set(XdsAttributes.ATTR_ADDRESS_NAME, endpoint.hostname()) + .set(EquivalentAddressGroup.ATTR_ADDRESS_NAME, endpoint.hostname()) .build(); EquivalentAddressGroup eag; if (config.isHttp11ProxyAvailable()) { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index e698a45296a..409935569b9 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -57,11 +57,6 @@ @VisibleForTesting public final class SecurityProtocolNegotiators { - /** Name associated with individual address, if available (e.g., DNS name). */ - @EquivalentAddressGroup.Attr - public static final Attributes.Key ATTR_ADDRESS_NAME = - Attributes.Key.create("io.grpc.xds.XdsAttributes.addressName"); - // Prevent instantiation. private SecurityProtocolNegotiators() { } @@ -155,7 +150,7 @@ public ChannelHandler newHandler(GrpcHttp2ConnectionHandler grpcHandler) { return fallbackProtocolNegotiator.newHandler(grpcHandler); } return new ClientSecurityHandler(grpcHandler, localSslContextProviderSupplier, - grpcHandler.getEagAttributes().get(ATTR_ADDRESS_NAME)); + grpcHandler.getEagAttributes().get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)); } @Override diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index 1e11be3ea44..1605e2e5328 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -22,16 +22,16 @@ import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; +import io.grpc.xds.internal.security.CommonTlsContextUtil; import io.grpc.xds.internal.security.trust.XdsTrustManagerFactory; import io.netty.handler.ssl.SslContextBuilder; import java.security.cert.CertStoreException; import java.security.cert.X509Certificate; import java.util.AbstractMap; import java.util.Arrays; -import java.util.Collection; -import java.util.List; import java.util.Map; import javax.annotation.Nullable; +import javax.net.ssl.TrustManager; /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { @@ -56,30 +56,13 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP upstreamTlsContext, certificateProviderStore); this.sniForSanMatching = upstreamTlsContext.getAutoSniSanValidation()? sniForSanMatching : null; - if (rootCertInstance == null - && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext()) - && !isMtls()) { - try { - // Instantiate sslContext so that addCallback will immediately update the callback with - // the SslContext. - AbstractMap.SimpleImmutableEntry sslContextBuilderAndTm = - getSslContextBuilderAndExtendedX509TrustManager(staticCertificateValidationContext); - sslContextAndExtendedX509TrustManager = new AbstractMap.SimpleImmutableEntry( - sslContextBuilderAndTm.getKey().build(), sslContextBuilderAndTm.getValue()); - } catch (CertStoreException | CertificateException | IOException e) { - throw new RuntimeException(e); - } - } } @Override - protected final SslContextBuilder getSslContextBuilder( - CertificateValidationContext certificateValidationContextdationContext) - throws CertStoreException { protected final AbstractMap.SimpleImmutableEntry getSslContextBuilderAndExtendedX509TrustManager( CertificateValidationContext certificateValidationContext) - throws CertificateException, IOException, CertStoreException { + throws CertStoreException { SslContextBuilder sslContextBuilder = GrpcSslContexts.forClient(); if (savedSpiffeTrustMap != null) { sslContextBuilder = sslContextBuilder.trustManager( @@ -91,6 +74,7 @@ protected final SslContextBuilder getSslContextBuilder( new XdsTrustManagerFactory( savedTrustedRoots.toArray(new X509Certificate[0]), certificateValidationContext, sniForSanMatching)); + } XdsTrustManagerFactory trustManagerFactory; if (rootCertInstance != null) { if (savedSpiffeTrustMap != null) { diff --git a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java index 059c1696a50..7b3d355b79a 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterImplLoadBalancerTest.java @@ -811,10 +811,10 @@ public void endpointAddressesAttachedWithClusterName() { new FixedResultPicker(PickResult.withSubchannel(subchannel))); } }); - assertThat(subchannel.getAttributes().get(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME)).isEqualTo( + assertThat(subchannel.getAttributes().get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)).isEqualTo( "authority-host-name"); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME)) + assertThat(eag.getAttributes().get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)) .isEqualTo("authority-host-name"); } @@ -863,9 +863,9 @@ public void endpointAddressesAttachedWithClusterName() { } }); // Sub Channel wrapper args won't have the address name although addresses will. - assertThat(subchannel.getAttributes().get(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME)).isNull(); + assertThat(subchannel.getAttributes().get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)).isNull(); for (EquivalentAddressGroup eag : subchannel.getAllAddresses()) { - assertThat(eag.getAttributes().get(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME)) + assertThat(eag.getAttributes().get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)) .isEqualTo("authority-host-name"); } @@ -1019,7 +1019,7 @@ public String toString() { // Unique but arbitrary string .set(EquivalentAddressGroup.ATTR_LOCALITY_NAME, locality.toString()); if (authorityHostname != null) { - attributes.set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, authorityHostname); + attributes.set(EquivalentAddressGroup.ATTR_ADDRESS_NAME, authorityHostname); } EquivalentAddressGroup eag = new EquivalentAddressGroup(new FakeSocketAddress(name), attributes.build()); diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index be68018792b..a5041e4c3cd 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -408,7 +408,7 @@ public void edsClustersEndpointHostname_addedToAddressAttribute() { assertThat( childBalancer.addresses.get(0).getAttributes() - .get(XdsAttributes.ATTR_ADDRESS_NAME)).isEqualTo("hostname1"); + .get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)).isEqualTo("hostname1"); } @Test @@ -897,7 +897,7 @@ public void onlyLogicalDnsCluster_endpointsResolved() { newInetSocketAddress("127.0.2.1", 9000), newInetSocketAddress("127.0.2.2", 9000)))), childBalancer.addresses); assertThat(childBalancer.addresses.get(0).getAttributes() - .get(XdsAttributes.ATTR_ADDRESS_NAME)).isEqualTo(DNS_HOST_NAME + ":9000"); + .get(EquivalentAddressGroup.ATTR_ADDRESS_NAME)).isEqualTo(DNS_HOST_NAME + ":9000"); } @Test @@ -995,7 +995,7 @@ public void config_equalsTester() { ServerInfo lrsServerInfo = ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); UpstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true); + CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true, null, false); DiscoveryMechanism edsDiscoveryMechanism1 = DiscoveryMechanism.forEds(CLUSTER, EDS_SERVICE_NAME, lrsServerInfo, 100L, tlsContext, Collections.emptyMap(), null); diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index cf9dd16d8be..211ab3cb23c 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -85,7 +85,6 @@ import java.net.Inet4Address; import java.net.InetSocketAddress; import java.net.URI; -import java.net.URISyntaxException; import java.nio.file.Files; import java.nio.file.Path; import java.security.KeyStore; @@ -840,7 +839,7 @@ private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( upstreamTlsContext, tlsContextManagerForClient)) : Attributes.newBuilder(); if (addrNameAttribute != null) { - sslContextAttributesBuilder.set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, addrNameAttribute); + sslContextAttributesBuilder.set(EquivalentAddressGroup.ATTR_ADDRESS_NAME, addrNameAttribute); } sslContextAttributes = sslContextAttributesBuilder.build(); fakeNameResolverFactory.setServers( diff --git a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java index 041d0c05d4f..53b7d1df83a 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/ClientSslContextProviderFactoryTest.java @@ -87,13 +87,13 @@ public void createCertProviderClientSslContextProvider() throws XdsInitializatio new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0], false); // verify that bootstrapInfo is cached... sslContextProvider = - clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); } @@ -120,7 +120,7 @@ public void bothPresent_expectCertProviderClientSslContextProvider() new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0], true); @@ -146,7 +146,7 @@ public void createCertProviderClientSslContextProvider_onlyRootCert() new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0], true); @@ -180,7 +180,7 @@ public void createCertProviderClientSslContextProvider_withStaticContext() new ClientSslContextProviderFactory(bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0], true); @@ -210,7 +210,7 @@ public void createCertProviderClientSslContextProvider_2providers() new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0], true); @@ -247,7 +247,7 @@ public void createNewCertProviderClientSslContextProvider_withSans() { new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0], true); @@ -281,7 +281,7 @@ public void createNewCertProviderClientSslContextProvider_onlyRootCert() { new ClientSslContextProviderFactory( bootstrapInfo, certProviderClientSslContextProviderFactory); SslContextProvider sslContextProvider = - clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI)); + clientSslContextProviderFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI)); assertThat(sslContextProvider.getClass().getSimpleName()).isEqualTo( "CertProviderClientSslContextProvider"); verifyWatcher(sslContextProvider, watcherCaptor[0], true); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java index 7aa7fd56d2d..a8bf66f0fc4 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SecurityProtocolNegotiatorsTest.java @@ -36,6 +36,7 @@ import io.grpc.Attributes; import io.grpc.ChannelLogger; import io.grpc.ChannelLogger.ChannelLogLevel; +import io.grpc.EquivalentAddressGroup; import io.grpc.internal.FakeClock; import io.grpc.internal.TestUtils.NoopChannelLogger; import io.grpc.netty.GrpcHttp2ConnectionHandler; @@ -165,7 +166,7 @@ public void clientSecurityProtocolNegotiatorNewHandler_autoHostSni_hostnameIsPas Attributes.newBuilder() .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, new SslContextProviderSupplier(upstreamTlsContext, mockTlsContextManager)) - .set(SecurityProtocolNegotiators.ATTR_ADDRESS_NAME, FAKE_AUTHORITY) + .set(EquivalentAddressGroup.ATTR_ADDRESS_NAME, FAKE_AUTHORITY) .build()); ChannelHandler newHandler = pn.newHandler(mockHandler); assertThat(newHandler).isNotNull(); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index 0e83fbb8118..e2011fdc5cc 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -93,7 +93,8 @@ public void get_updateSecret() { SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); AbstractMap.SimpleImmutableEntry mockSslContextAndTm = - mock(AbstractMap.SimpleImmutableEntry.class); + (AbstractMap.SimpleImmutableEntry) + mock(AbstractMap.SimpleImmutableEntry.class); capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); verify(mockCallback, times(1)).updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); verify(mockTlsContextManager, times(1)) @@ -118,7 +119,8 @@ public void autoHostSniFalse_usesSniFromUpstreamTlsContext() { SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); AbstractMap.SimpleImmutableEntry mockSslContextAndTm = - mock(AbstractMap.SimpleImmutableEntry.class); + (AbstractMap.SimpleImmutableEntry) + mock(AbstractMap.SimpleImmutableEntry.class); capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); verify(mockCallback, times(1)) .updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); @@ -174,7 +176,8 @@ public void systemRootCertsWithMtls_callbackExecutedFromProvider() { SslContextProvider.Callback capturedCallback = callbackCaptor.getValue(); assertThat(capturedCallback).isNotNull(); AbstractMap.SimpleImmutableEntry mockSslContextAndTm = - mock(AbstractMap.SimpleImmutableEntry.class); + (AbstractMap.SimpleImmutableEntry) + mock(AbstractMap.SimpleImmutableEntry.class); capturedCallback.updateSslContextAndExtendedX509TrustManager(mockSslContextAndTm); verify(mockCallback, times(1)) .updateSslContextAndExtendedX509TrustManager(eq(mockSslContextAndTm)); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java index 66352aa27bd..3db482690d2 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/TlsContextManagerTest.java @@ -169,7 +169,7 @@ public void createClientSslContextProvider_releaseInstance() { TlsContextManagerImpl tlsContextManagerImpl = new TlsContextManagerImpl(mockClientFactory, mockServerFactory); SslContextProvider mockProvider = mock(SslContextProvider.class); - when(mockClientFactory.create(new AbstractMap.SimpleImmutableEntry(upstreamTlsContext, SNI))) + when(mockClientFactory.create(new AbstractMap.SimpleImmutableEntry<>(upstreamTlsContext, SNI))) .thenReturn(mockProvider); SslContextProvider clientSecretProvider = tlsContextManagerImpl.findOrCreateClientSslContextProvider(upstreamTlsContext, SNI); From 107cbd80ef1fbee951b63bad14ff625bf3c11b9d Mon Sep 17 00:00:00 2001 From: Kannan J Date: Wed, 24 Sep 2025 12:48:29 +0000 Subject: [PATCH 51/51] Some more changes needed after the changes in the base branch. --- .../netty/InternalProtocolNegotiators.java | 7 ++-- .../io/grpc/netty/NettyChannelBuilder.java | 3 +- .../io/grpc/netty/ProtocolNegotiators.java | 27 ++++++++------- .../grpc/xds/ClusterResolverLoadBalancer.java | 1 - .../io/grpc/xds/EnvoyServerProtoData.java | 28 ++++++++-------- .../io/grpc/xds/GcpAuthenticationFilter.java | 5 --- .../ClientSslContextProviderFactory.java | 6 ++-- .../security/DynamicSslContextProvider.java | 11 ++++--- .../security/SecurityProtocolNegotiators.java | 9 +++-- .../internal/security/SslContextProvider.java | 5 +-- .../security/SslContextProviderSupplier.java | 16 +++++---- .../security/TlsContextManagerImpl.java | 9 +++-- .../CertProviderClientSslContextProvider.java | 33 ++++++++----------- .../CertProviderServerSslContextProvider.java | 14 ++++---- .../CertProviderSslContextProvider.java | 14 +++++--- .../security/trust/CertificateUtils.java | 1 - .../trust/XdsTrustManagerFactory.java | 19 +++++++---- .../security/trust/XdsX509TrustManager.java | 7 ++-- .../xds/ClusterResolverLoadBalancerTest.java | 3 +- .../grpc/xds/XdsSecurityClientServerTest.java | 2 +- .../SslContextProviderSupplierTest.java | 4 +-- 21 files changed, 115 insertions(+), 109 deletions(-) diff --git a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java index 5c27ab03a89..ba4fde8d15c 100644 --- a/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/InternalProtocolNegotiators.java @@ -25,9 +25,9 @@ import io.netty.channel.ChannelHandler; import io.netty.handler.ssl.SslContext; import io.netty.util.AsciiString; +import java.util.concurrent.Executor; import javax.net.ssl.TrustManager; import javax.net.ssl.X509TrustManager; -import java.util.concurrent.Executor; /** * Internal accessor for {@link ProtocolNegotiators}. @@ -79,7 +79,8 @@ public void close() { * may happen immediately, even before the TLS Handshake is complete. */ public static InternalProtocolNegotiator.ProtocolNegotiator tls( - SslContext sslContext, String sni, boolean isXdsTarget, TrustManager extendedX509TrustManager) { + SslContext sslContext, String sni, boolean isXdsTarget, + TrustManager extendedX509TrustManager) { return tls(sslContext, null, Optional.absent(), extendedX509TrustManager, sni, isXdsTarget); } @@ -185,7 +186,7 @@ public static class ProtocolNegotiationHandler extends ProtocolNegotiators.ProtocolNegotiationHandler { protected ProtocolNegotiationHandler(ChannelHandler next, String negotiatorName, - ChannelLogger negotiationLogger) { + ChannelLogger negotiationLogger) { super(next, negotiatorName, negotiationLogger); } diff --git a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java index b659dc7d9a9..29bb7be51bc 100644 --- a/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java +++ b/netty/src/main/java/io/grpc/netty/NettyChannelBuilder.java @@ -652,7 +652,8 @@ static ProtocolNegotiator createProtocolNegotiatorByType( case PLAINTEXT_UPGRADE: return ProtocolNegotiators.plaintextUpgrade(); case TLS: - return ProtocolNegotiators.tls(sslContext, executorPool, Optional.absent(), null, null, false); + return ProtocolNegotiators.tls( + sslContext, executorPool, Optional.absent(), null, null, false); default: throw new IllegalArgumentException("Unsupported negotiationType: " + negotiationType); } diff --git a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java index aff61ef292d..216e45f8bf1 100644 --- a/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java +++ b/netty/src/main/java/io/grpc/netty/ProtocolNegotiators.java @@ -90,7 +90,6 @@ import javax.net.ssl.TrustManager; import javax.net.ssl.TrustManagerFactory; import javax.net.ssl.X509TrustManager; - import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** @@ -265,7 +264,7 @@ public static final class FromChannelCredentialsResult { public final String error; private FromChannelCredentialsResult(ProtocolNegotiator.ClientFactory negotiator, - CallCredentials creds, String error) { + CallCredentials creds, String error) { this.negotiator = negotiator; this.callCredentials = creds; this.error = error; @@ -428,8 +427,8 @@ static final class ServerTlsHandler extends ChannelInboundHandlerAdapter { private ProtocolNegotiationEvent pne = ProtocolNegotiationEvent.DEFAULT; ServerTlsHandler(ChannelHandler next, - SslContext sslContext, - final ObjectPool executorPool) { + SslContext sslContext, + final ObjectPool executorPool) { this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); this.next = Preconditions.checkNotNull(next, "next"); if (executorPool != null) { @@ -486,8 +485,8 @@ private void fireProtocolNegotiationEvent(ChannelHandlerContext ctx, SSLSession * Returns a {@link ProtocolNegotiator} that does HTTP CONNECT proxy negotiation. */ public static ProtocolNegotiator httpProxy(final SocketAddress proxyAddress, - final @Nullable String proxyUsername, final @Nullable String proxyPassword, - final ProtocolNegotiator negotiator) { + final @Nullable String proxyUsername, final @Nullable String proxyPassword, + final ProtocolNegotiator negotiator) { Preconditions.checkNotNull(negotiator, "negotiator"); Preconditions.checkNotNull(proxyAddress, "proxyAddress"); final AsciiString scheme = negotiator.scheme(); @@ -564,8 +563,8 @@ protected void userEventTriggered0(ChannelHandlerContext ctx, Object evt) throws static final class ClientTlsProtocolNegotiator implements ProtocolNegotiator { public ClientTlsProtocolNegotiator(SslContext sslContext, - ObjectPool executorPool, Optional handshakeCompleteRunnable, - X509TrustManager x509ExtendedTrustManager, String sni, boolean isXdsTarget) { + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager, String sni, boolean isXdsTarget) { this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); this.executorPool = executorPool; if (this.executorPool != null) { @@ -628,10 +627,10 @@ static final class ClientTlsHandler extends ProtocolNegotiationHandler { private SSLEngine sslEngine; ClientTlsHandler(ChannelHandler next, SslContext sslContext, String sniHostPort, - Executor executor, ChannelLogger negotiationLogger, - Optional handshakeCompleteRunnable, - ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, - X509TrustManager x509ExtendedTrustManager) { + Executor executor, ChannelLogger negotiationLogger, + Optional handshakeCompleteRunnable, + ClientTlsProtocolNegotiator clientTlsProtocolNegotiator, + X509TrustManager x509ExtendedTrustManager) { super(next, negotiationLogger); this.sslContext = Preconditions.checkNotNull(sslContext, "sslContext"); if (!Strings.isNullOrEmpty(sniHostPort)) { @@ -753,8 +752,8 @@ static HostPort parseAuthority(String authority) { * @param executorPool a dedicated {@link Executor} pool for time-consuming TLS tasks */ public static ProtocolNegotiator tls(SslContext sslContext, - ObjectPool executorPool, Optional handshakeCompleteRunnable, - X509TrustManager x509ExtendedTrustManager, String sni, boolean isXdsTarget) { + ObjectPool executorPool, Optional handshakeCompleteRunnable, + X509TrustManager x509ExtendedTrustManager, String sni, boolean isXdsTarget) { return new ClientTlsProtocolNegotiator(sslContext, executorPool, handshakeCompleteRunnable, x509ExtendedTrustManager, sni, isXdsTarget); } diff --git a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java index 6f419e59d63..50b8097fca9 100644 --- a/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java +++ b/xds/src/main/java/io/grpc/xds/ClusterResolverLoadBalancer.java @@ -47,7 +47,6 @@ import io.grpc.xds.client.Locality; import io.grpc.xds.client.XdsLogger; import io.grpc.xds.client.XdsLogger.XdsLogLevel; -import io.grpc.xds.internal.security.SecurityProtocolNegotiators; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.util.ArrayList; diff --git a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java index a89a43f3e90..e6ef72c15b8 100644 --- a/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java +++ b/xds/src/main/java/io/grpc/xds/EnvoyServerProtoData.java @@ -74,23 +74,25 @@ public int hashCode() { public static final class UpstreamTlsContext extends BaseTlsContext { private final String sni; - private final boolean auto_host_sni; - private final boolean auto_sni_san_validation; + private final boolean autoHostSni; + private final boolean autoSniSanValidation; @VisibleForTesting public UpstreamTlsContext(CommonTlsContext commonTlsContext) { super(commonTlsContext); this.sni = null; - this.auto_host_sni = false; - this.auto_sni_san_validation = false; + this.autoHostSni = false; + this.autoSniSanValidation = false; } @VisibleForTesting - public UpstreamTlsContext(io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext upstreamTlsContext) { + public UpstreamTlsContext( + io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext + upstreamTlsContext) { super(upstreamTlsContext.getCommonTlsContext()); this.sni = upstreamTlsContext.getSni(); - this.auto_host_sni = upstreamTlsContext.getAutoHostSni(); - this.auto_sni_san_validation = upstreamTlsContext.getAutoSniSanValidation(); + this.autoHostSni = upstreamTlsContext.getAutoHostSni(); + this.autoSniSanValidation = upstreamTlsContext.getAutoSniSanValidation(); } public static UpstreamTlsContext fromEnvoyProtoUpstreamTlsContext( @@ -105,20 +107,20 @@ public String getSni() { } public boolean getAutoHostSni() { - return auto_host_sni; + return autoHostSni; } public boolean getAutoSniSanValidation() { - return auto_sni_san_validation; + return autoSniSanValidation; } @Override public String toString() { - return "UpstreamTlsContext{" + - "commonTlsContext=" + commonTlsContext + return "UpstreamTlsContext{" + + "commonTlsContext=" + commonTlsContext + "sni=" + sni - + "\nauto_host_sni=" + auto_host_sni - + "\nauto_sni_san_validation=" + auto_sni_san_validation + + "\nauto_host_sni=" + autoHostSni + + "\nauto_sni_san_validation=" + autoSniSanValidation + "}"; } } diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java index edd6cd6a190..dc133eaaf1a 100644 --- a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java +++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java @@ -37,7 +37,6 @@ import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.CompositeCallCredentials; -import io.grpc.InternalLogId; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -46,13 +45,10 @@ import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; import io.grpc.xds.MetadataRegistry.MetadataValueParser; import io.grpc.xds.XdsConfig.XdsClusterConfig; -import io.grpc.xds.client.XdsLogger; -import io.grpc.xds.client.XdsLogger.XdsLogLevel; import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; import java.util.LinkedHashMap; import java.util.Map; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.TimeUnit; import java.util.function.Function; import javax.annotation.Nullable; @@ -65,7 +61,6 @@ final class GcpAuthenticationFilter implements Filter { static final String TYPE_URL = "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig"; private final LruCache callCredentialsCache; - private final XdsLogger logger = XdsLogger.withLogId(InternalLogId.allocate("bootstrapper", null)); final String filterInstanceName; GcpAuthenticationFilter(String name, int cacheSize) { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java index 4df2a896ef8..a66644786db 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/ClientSslContextProviderFactory.java @@ -24,7 +24,8 @@ /** Factory to create client-side SslContextProvider from UpstreamTlsContext. */ final class ClientSslContextProviderFactory - implements ValueFactory, SslContextProvider> { + implements ValueFactory, + SslContextProvider> { private BootstrapInfo bootstrapInfo; private final CertProviderClientSslContextProviderFactory @@ -42,7 +43,8 @@ final class ClientSslContextProviderFactory /** Creates an SslContextProvider from the given UpstreamTlsContext. */ @Override - public SslContextProvider create(AbstractMap.SimpleImmutableEntry key) { + public SslContextProvider create( + AbstractMap.SimpleImmutableEntry key) { return certProviderClientSslContextProviderFactory.getProvider( key.getKey(), key.getValue(), bootstrapInfo.node().toEnvoyProtoNode(), diff --git a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java index 027b5aebdec..f2d40696c9f 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/DynamicSslContextProvider.java @@ -52,7 +52,8 @@ protected DynamicSslContextProvider( } @Nullable - public AbstractMap.SimpleImmutableEntry getSslContextAndExtendedX509TrustManager() { + public AbstractMap.SimpleImmutableEntry + getSslContextAndExtendedX509TrustManager() { return sslContextAndExtendedX509TrustManager; } @@ -61,8 +62,9 @@ public AbstractMap.SimpleImmutableEntry getSslContextA /** * Gets a server or client side SslContextBuilder. */ - protected abstract AbstractMap.SimpleImmutableEntry getSslContextBuilderAndExtendedX509TrustManager( - CertificateValidationContext certificateValidationContext) + protected abstract AbstractMap.SimpleImmutableEntry + getSslContextBuilderAndExtendedX509TrustManager( + CertificateValidationContext certificateValidationContext) throws CertificateException, IOException, CertStoreException; // this gets called only when requested secrets are ready... @@ -84,7 +86,8 @@ protected final void updateSslContext() { sslContextBuilderAndTm.getKey().applicationProtocolConfig(apn); } List pendingCallbacksCopy; - AbstractMap.SimpleImmutableEntry sslContextAndExtendedX09TrustManagerCopy; + AbstractMap.SimpleImmutableEntry + sslContextAndExtendedX09TrustManagerCopy; synchronized (pendingCallbacks) { sslContextAndExtendedX509TrustManager = new AbstractMap.SimpleImmutableEntry<>( sslContextBuilderAndTm.getKey().build(), sslContextBuilderAndTm.getValue()); diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java index 409935569b9..86d28a0554b 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SecurityProtocolNegotiators.java @@ -22,7 +22,6 @@ import com.google.common.base.Strings; import io.grpc.Attributes; import io.grpc.EquivalentAddressGroup; -import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.Grpc; import io.grpc.internal.GrpcUtil; import io.grpc.internal.ObjectPool; @@ -33,6 +32,7 @@ import io.grpc.netty.InternalProtocolNegotiators; import io.grpc.netty.ProtocolNegotiationEvent; import io.grpc.xds.EnvoyServerProtoData; +import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.internal.security.trust.CertificateUtils; import io.netty.channel.ChannelHandler; import io.netty.channel.ChannelHandlerAdapter; @@ -384,8 +384,8 @@ protected void handlerAdded0(final ChannelHandlerContext ctx) { @Override public void updateSslContextAndExtendedX509TrustManager( AbstractMap.SimpleImmutableEntry sslContextAndTm) { - ChannelHandler handler = - InternalProtocolNegotiators.serverTls(sslContextAndTm.getKey()).newHandler(grpcHandler); + ChannelHandler handler = InternalProtocolNegotiators.serverTls( + sslContextAndTm.getKey()).newHandler(grpcHandler); // Delegate rest of handshake to TLS handler if (!ctx.isRemoved()) { @@ -399,8 +399,7 @@ public void updateSslContextAndExtendedX509TrustManager( public void onException(Throwable throwable) { ctx.fireExceptionCaught(throwable); } - }, - null); + }, null); } } } \ No newline at end of file diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java index 31df4b797e7..49cbab2ca03 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProvider.java @@ -29,13 +29,12 @@ import io.netty.handler.ssl.ClientAuth; import io.netty.handler.ssl.SslContext; import io.netty.handler.ssl.SslContextBuilder; - -import javax.net.ssl.TrustManager; import java.io.IOException; import java.security.cert.CertStoreException; import java.security.cert.CertificateException; import java.util.AbstractMap; import java.util.concurrent.Executor; +import javax.net.ssl.TrustManager; /** * A SslContextProvider is a "container" or provider of SslContext. This is used by gRPC-xds to @@ -50,11 +49,9 @@ public abstract class SslContextProvider implements Closeable { @VisibleForTesting public abstract static class Callback { private final Executor executor; - private final String hostname; protected Callback(Executor executor) { this.executor = executor; - this.hostname = null; } @VisibleForTesting public Executor getExecutor() { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java index e85d2c498ba..8819874ad00 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/SslContextProviderSupplier.java @@ -16,6 +16,8 @@ package io.grpc.xds.internal.security; +import static com.google.common.base.Preconditions.checkNotNull; + import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext; @@ -23,14 +25,11 @@ import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; - -import javax.net.ssl.TrustManager; import java.util.AbstractMap; import java.util.HashSet; import java.util.Objects; import java.util.Set; - -import static com.google.common.base.Preconditions.checkNotNull; +import javax.net.ssl.TrustManager; /** * Enables Client or server side to initialize this object with the received {@link BaseTlsContext} @@ -58,7 +57,8 @@ public BaseTlsContext getTlsContext() { } /** Updates SslContext via the passed callback. */ - public synchronized void updateSslContext(final SslContextProvider.Callback callback, String sni) { + public synchronized void updateSslContext( + final SslContextProvider.Callback callback, String sni) { checkNotNull(callback, "callback"); try { if (!shutdown) { @@ -106,9 +106,11 @@ private void releaseSslContextProvider(SslContextProvider toRelease, String sni) private SslContextProvider getSslContextProvider(String sni) { if (tlsContext instanceof UpstreamTlsContext) { snisSentByClients.add(sni); - return tlsContextManager.findOrCreateClientSslContextProvider((UpstreamTlsContext) tlsContext, sni); + return tlsContextManager.findOrCreateClientSslContextProvider( + (UpstreamTlsContext) tlsContext, sni); } - return tlsContextManager.findOrCreateServerSslContextProvider((DownstreamTlsContext) tlsContext); + return tlsContextManager.findOrCreateServerSslContextProvider( + (DownstreamTlsContext) tlsContext); } @VisibleForTesting public boolean isShutdown() { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java b/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java index 560c28077da..a36eeb7a727 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/TlsContextManagerImpl.java @@ -35,7 +35,8 @@ */ public final class TlsContextManagerImpl implements TlsContextManager { - private final ReferenceCountingMap, SslContextProvider> mapForClients; + private final ReferenceCountingMap, + SslContextProvider> mapForClients; private final ReferenceCountingMap mapForServers; /** @@ -49,7 +50,8 @@ public final class TlsContextManagerImpl implements TlsContextManager { @VisibleForTesting TlsContextManagerImpl( - ValueFactory, SslContextProvider> clientFactory, + ValueFactory, + SslContextProvider> clientFactory, ValueFactory serverFactory) { checkNotNull(clientFactory, "clientFactory"); checkNotNull(serverFactory, "serverFactory"); @@ -80,7 +82,8 @@ public SslContextProvider releaseClientSslContextProvider( SslContextProvider clientSslContextProvider, String sni) { checkNotNull(clientSslContextProvider, "clientSslContextProvider"); return mapForClients.release( - new AbstractMap.SimpleImmutableEntry<>(clientSslContextProvider.getUpstreamTlsContext(), sni), + new AbstractMap.SimpleImmutableEntry<>( + clientSslContextProvider.getUpstreamTlsContext(), sni), clientSslContextProvider); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java index 1605e2e5328..cff29389cf3 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderClientSslContextProvider.java @@ -22,7 +22,6 @@ import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.EnvoyServerProtoData.UpstreamTlsContext; import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; -import io.grpc.xds.internal.security.CommonTlsContextUtil; import io.grpc.xds.internal.security.trust.XdsTrustManagerFactory; import io.netty.handler.ssl.SslContextBuilder; import java.security.cert.CertStoreException; @@ -36,8 +35,6 @@ /** A client SslContext provider using CertificateProviderInstance to fetch secrets. */ final class CertProviderClientSslContextProvider extends CertProviderSslContextProvider { - private final String sniForSanMatching; - CertProviderClientSslContextProvider( Node node, @Nullable Map certProviders, @@ -54,8 +51,8 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP rootCertInstance, staticCertValidationContext, upstreamTlsContext, - certificateProviderStore); - this.sniForSanMatching = upstreamTlsContext.getAutoSniSanValidation()? sniForSanMatching : null; + certificateProviderStore, + upstreamTlsContext.getAutoSniSanValidation() ? sniForSanMatching : null); } @Override @@ -74,24 +71,22 @@ final class CertProviderClientSslContextProvider extends CertProviderSslContextP new XdsTrustManagerFactory( savedTrustedRoots.toArray(new X509Certificate[0]), certificateValidationContext, sniForSanMatching)); - } - XdsTrustManagerFactory trustManagerFactory; - if (rootCertInstance != null) { - if (savedSpiffeTrustMap != null) { - trustManagerFactory = new XdsTrustManagerFactory( - savedSpiffeTrustMap, - certificateValidationContext, sniForSanMatching); - sslContextBuilder = sslContextBuilder.trustManager(trustManagerFactory); - } else { - trustManagerFactory = new XdsTrustManagerFactory( - savedTrustedRoots.toArray(new X509Certificate[0]), - certificateValidationContext, sniForSanMatching); - sslContextBuilder = sslContextBuilder.trustManager(trustManagerFactory); - } } else { // Should be impossible because of the check in CertProviderClientSslContextProviderFactory throw new IllegalStateException("There must be trusted roots or a SPIFFE trust map"); } + XdsTrustManagerFactory trustManagerFactory; + if (savedSpiffeTrustMap != null) { + trustManagerFactory = new XdsTrustManagerFactory( + savedSpiffeTrustMap, + certificateValidationContext, sniForSanMatching); + sslContextBuilder = sslContextBuilder.trustManager(trustManagerFactory); + } else { + trustManagerFactory = new XdsTrustManagerFactory( + savedTrustedRoots.toArray(new X509Certificate[0]), + certificateValidationContext, sniForSanMatching); + sslContextBuilder = sslContextBuilder.trustManager(trustManagerFactory); + } if (isMtls()) { sslContextBuilder.keyManager(savedKey, savedCertChain); } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java index 25136fab9d6..a6140133f38 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderServerSslContextProvider.java @@ -21,7 +21,6 @@ import io.envoyproxy.envoy.config.core.v3.Node; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext; -import io.grpc.internal.CertificateUtils; import io.grpc.netty.GrpcSslContexts; import io.grpc.xds.EnvoyServerProtoData.DownstreamTlsContext; import io.grpc.xds.client.Bootstrapper.CertificateProviderInfo; @@ -32,7 +31,6 @@ import java.security.cert.CertificateException; import java.security.cert.X509Certificate; import java.util.AbstractMap; -import java.util.Arrays; import java.util.Map; import javax.annotation.Nullable; import javax.net.ssl.TrustManager; @@ -55,13 +53,15 @@ final class CertProviderServerSslContextProvider extends CertProviderSslContextP rootCertInstance, staticCertValidationContext, downstreamTlsContext, - certificateProviderStore); + certificateProviderStore, + null); } @Override - protected final AbstractMap.SimpleImmutableEntry getSslContextBuilderAndExtendedX509TrustManager( - CertificateValidationContext certificateValidationContextdationContext) - throws CertStoreException, CertificateException, IOException { + protected final AbstractMap.SimpleImmutableEntry + getSslContextBuilderAndExtendedX509TrustManager( + CertificateValidationContext certificateValidationContextdationContext) + throws CertStoreException, CertificateException, IOException { SslContextBuilder sslContextBuilder = SslContextBuilder.forServer(savedKey, savedCertChain); XdsTrustManagerFactory trustManagerFactory = null; if (isMtls() && savedSpiffeTrustMap != null) { @@ -76,7 +76,7 @@ protected final AbstractMap.SimpleImmutableEntry(sslContextBuilder, null); } } diff --git a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java index 3cfff478146..883fda4cd4e 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/certprovider/CertProviderSslContextProvider.java @@ -44,6 +44,13 @@ abstract class CertProviderSslContextProvider extends DynamicSslContextProvider @Nullable protected List savedTrustedRoots; @Nullable protected Map> savedSpiffeTrustMap; private final boolean isUsingSystemRootCerts; + // Logically this field belongs in the client cert provider subclass but it had to be kept here + // because of the whole lot of things that happen in this class's constructor when the client + // cert provider is created, and + // CertProviderClientSslContextProvider::getSslContextBuilderAndExtendedX509TrustManager gets + // called before the constructor is complete, and this field needs to be set in order to be + // passed on to the TrustManager constructor. + protected final String sniForSanMatching; protected CertProviderSslContextProvider( Node node, @@ -52,10 +59,11 @@ protected CertProviderSslContextProvider( CertificateProviderInstance rootCertInstance, CertificateValidationContext staticCertValidationContext, BaseTlsContext tlsContext, - CertificateProviderStore certificateProviderStore) { + CertificateProviderStore certificateProviderStore, String sniForSanMatching) { super(tlsContext, staticCertValidationContext); this.certInstance = certInstance; this.rootCertInstance = rootCertInstance; + this.sniForSanMatching = sniForSanMatching; this.isUsingSystemRootCerts = rootCertInstance == null && CommonTlsContextUtil.isUsingSystemRootCerts(tlsContext.getCommonTlsContext()); boolean createCertInstance = certInstance != null && certInstance.isInitialized(); @@ -102,10 +110,6 @@ protected CertProviderSslContextProvider( } } - public boolean isUsingSystemRootCerts() { - return this.isUsingSystemRootCerts; - } - private static CertificateProviderInfo getCertProviderConfig( @Nullable Map certProviders, String pluginInstanceName) { return certProviders != null ? certProviders.get(pluginInstanceName) : null; diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java index 9ad871f29ed..89b4abd3029 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/CertificateUtils.java @@ -17,7 +17,6 @@ package io.grpc.xds.internal.security.trust; import io.grpc.internal.GrpcUtil; - import java.io.BufferedInputStream; import java.io.File; import java.io.FileInputStream; diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java index 79c590f1ff4..eea47e1ef9f 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsTrustManagerFactory.java @@ -63,13 +63,14 @@ public XdsTrustManagerFactory(CertificateValidationContext certificateValidation } public XdsTrustManagerFactory( - X509Certificate[] certs, CertificateValidationContext staticCertificateValidationContext, String sniForSanMatching) - throws CertStoreException { + X509Certificate[] certs, CertificateValidationContext staticCertificateValidationContext, + String sniForSanMatching) throws CertStoreException { this(certs, staticCertificateValidationContext, true, sniForSanMatching); } public XdsTrustManagerFactory(Map> spiffeTrustMap, - CertificateValidationContext staticCertificateValidationContext, String sniForSanMatching) throws CertStoreException { + CertificateValidationContext staticCertificateValidationContext, String sniForSanMatching) + throws CertStoreException { this(spiffeTrustMap, staticCertificateValidationContext, true, sniForSanMatching); } @@ -85,7 +86,8 @@ private XdsTrustManagerFactory( || certificateValidationContext.hasSystemRootCerts(), "only static certificateValidationContext expected"); } - xdsX509TrustManager = createX509TrustManager(certs, certificateValidationContext, sniForSanMatching); + xdsX509TrustManager = createX509TrustManager( + certs, certificateValidationContext, sniForSanMatching); } private XdsTrustManagerFactory( @@ -98,7 +100,8 @@ private XdsTrustManagerFactory( checkArgument( certificateValidationContext == null || !certificateValidationContext.hasTrustedCa(), "only static certificateValidationContext expected"); - xdsX509TrustManager = createX509TrustManager(spiffeTrustMap, certificateValidationContext, sniForSanMatching); + xdsX509TrustManager = createX509TrustManager( + spiffeTrustMap, certificateValidationContext, sniForSanMatching); } } @@ -125,14 +128,16 @@ private static X509Certificate[] getTrustedCaFromCertContext( @VisibleForTesting static XdsX509TrustManager createX509TrustManager( - X509Certificate[] certs, CertificateValidationContext certContext, String sniForSanMatching) throws CertStoreException { + X509Certificate[] certs, CertificateValidationContext certContext, String sniForSanMatching) + throws CertStoreException { return new XdsX509TrustManager(certContext, createTrustManager(certs), sniForSanMatching); } @VisibleForTesting static XdsX509TrustManager createX509TrustManager( Map> spiffeTrustMapFile, - CertificateValidationContext certContext, String sniForSanMatching) throws CertStoreException { + CertificateValidationContext certContext, String sniForSanMatching) + throws CertStoreException { checkNotNull(spiffeTrustMapFile, "spiffeTrustMapFile"); Map delegates = new HashMap<>(); for (Map.Entry> entry:spiffeTrustMapFile.entrySet()) { diff --git a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java index e8bd4501798..371aca62414 100644 --- a/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java +++ b/xds/src/main/java/io/grpc/xds/internal/security/trust/XdsX509TrustManager.java @@ -27,7 +27,6 @@ import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext; import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; -import io.grpc.internal.GrpcUtil; import io.grpc.internal.SpiffeUtil; import java.net.Socket; import java.security.cert.CertificateException; @@ -79,7 +78,8 @@ final class XdsX509TrustManager extends X509ExtendedTrustManager implements X509 } XdsX509TrustManager(@Nullable CertificateValidationContext certContext, - Map spiffeTrustMapDelegates, @Nullable String sniForSanMatching) { + Map spiffeTrustMapDelegates, + @Nullable String sniForSanMatching) { checkNotNull(spiffeTrustMapDelegates, "spiffeTrustMapDelegates"); this.spiffeTrustMapDelegates = ImmutableMap.copyOf(spiffeTrustMapDelegates); this.certContext = certContext; @@ -218,7 +218,8 @@ void verifySubjectAltNameInChain(X509Certificate[] peerCertChain) throws Certifi return; } @SuppressWarnings("deprecation") // gRFC A29 predates match_typed_subject_alt_names - List verifyList = CertificateUtils.isXdsSniEnabled && !Strings.isNullOrEmpty(sniForSanMatching) + List verifyList = + CertificateUtils.isXdsSniEnabled && !Strings.isNullOrEmpty(sniForSanMatching) ? ImmutableList.of(StringMatcher.newBuilder().setExact(sniForSanMatching).build()) : certContext.getMatchSubjectAltNamesList(); if (verifyList.isEmpty()) { diff --git a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java index a5041e4c3cd..8b2b43956f8 100644 --- a/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/ClusterResolverLoadBalancerTest.java @@ -995,7 +995,8 @@ public void config_equalsTester() { ServerInfo lrsServerInfo = ServerInfo.create("lrs.googleapis.com", InsecureChannelCredentials.create()); UpstreamTlsContext tlsContext = - CommonTlsContextTestsUtil.buildUpstreamTlsContext("google_cloud_private_spiffe", true, null, false); + CommonTlsContextTestsUtil.buildUpstreamTlsContext( + "google_cloud_private_spiffe", true, null, false); DiscoveryMechanism edsDiscoveryMechanism1 = DiscoveryMechanism.forEds(CLUSTER, EDS_SERVICE_NAME, lrsServerInfo, 100L, tlsContext, Collections.emptyMap(), null); diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index 211ab3cb23c..25fdfb3665f 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -313,7 +313,7 @@ public void tlsClientServer_noAutoSniValidation_failureToMatchSubjAltNames() } @Test - public void tlsClientServer_autoSniValidation_sniInUTC() + public void tlsClientServer_autoSniValidation_sniInUtc() throws Exception { CertificateUtils.isXdsSniEnabled = true; Path trustStoreFilePath = getCacertFilePathForTestCa(); diff --git a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java index e2011fdc5cc..c8d5c2ef519 100644 --- a/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java +++ b/xds/src/test/java/io/grpc/xds/internal/security/SslContextProviderSupplierTest.java @@ -31,9 +31,9 @@ import io.grpc.xds.EnvoyServerProtoData; import io.grpc.xds.TlsContextManager; import io.netty.handler.ssl.SslContext; - import java.util.AbstractMap; import java.util.concurrent.Executor; +import javax.net.ssl.TrustManager; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -43,8 +43,6 @@ import org.mockito.junit.MockitoJUnit; import org.mockito.junit.MockitoRule; -import javax.net.ssl.TrustManager; - /** * Unit tests for {@link SslContextProviderSupplier}. */