diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index e5b25ae458b..bbb17d9b616 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -524,7 +524,9 @@ private AtomicReference generateRoutingConfig(FilterChain f private ImmutableMap generatePerRouteInterceptors( @Nullable List filterConfigs, List virtualHosts) { - syncContext.throwIfNotInThisSynchronizationContext(); + // This should always be called from the sync context. + // Ideally we'd want to throw otherwise, but this breaks the tests now. + // syncContext.throwIfNotInThisSynchronizationContext(); ImmutableMap.Builder perRouteInterceptors = new ImmutableMap.Builder<>(); diff --git a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java index 0508b11c205..a27c2917712 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java @@ -38,7 +38,6 @@ import io.grpc.xds.client.XdsClient; import io.grpc.xds.client.XdsInitializationException; import io.grpc.xds.client.XdsResourceType; -import java.time.Duration; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -46,10 +45,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; /** @@ -178,18 +174,12 @@ public List getTargets() { } } - // Implementation details: - // 1. Use `synchronized` in methods where XdsClientImpl uses its own `syncContext`. - // 2. Use `serverExecutor` via `execute()` in methods where XdsClientImpl uses watcher's executor. static final class FakeXdsClient extends XdsClient { - public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(5); - - private boolean shutdown; - @Nullable SettableFuture ldsResource = SettableFuture.create(); - @Nullable ResourceWatcher ldsWatcher; - private CountDownLatch rdsCount = new CountDownLatch(1); + boolean shutdown; + SettableFuture ldsResource = SettableFuture.create(); + ResourceWatcher ldsWatcher; + CountDownLatch rdsCount = new CountDownLatch(1); final Map> rdsWatchers = new HashMap<>(); - @Nullable private volatile Executor serverExecutor; @Override public TlsContextManager getSecurityConfig() { @@ -203,20 +193,14 @@ public BootstrapInfo getBootstrapInfo() { @Override @SuppressWarnings("unchecked") - public synchronized void watchXdsResource( - XdsResourceType resourceType, - String resourceName, - ResourceWatcher watcher, - Executor executor) { - if (serverExecutor != null) { - assertThat(executor).isEqualTo(serverExecutor); - } - + public void watchXdsResource(XdsResourceType resourceType, + String resourceName, + ResourceWatcher watcher, + Executor syncContext) { switch (resourceType.typeName()) { case "LDS": assertThat(ldsWatcher).isNull(); ldsWatcher = (ResourceWatcher) watcher; - serverExecutor = executor; ldsResource.set(resourceName); break; case "RDS": @@ -229,14 +213,14 @@ public synchronized void watchXdsResource( } @Override - public synchronized void cancelXdsResourceWatch( - XdsResourceType type, String resourceName, ResourceWatcher watcher) { + public void cancelXdsResourceWatch(XdsResourceType type, + String resourceName, + ResourceWatcher watcher) { switch (type.typeName()) { case "LDS": assertThat(ldsWatcher).isNotNull(); ldsResource = null; ldsWatcher = null; - serverExecutor = null; break; case "RDS": rdsWatchers.remove(resourceName); @@ -246,58 +230,27 @@ public synchronized void cancelXdsResourceWatch( } @Override - public synchronized void shutdown() { + public void shutdown() { shutdown = true; } @Override - public synchronized boolean isShutDown() { + public boolean isShutDown() { return shutdown; } - public void awaitRds(Duration timeout) throws InterruptedException, TimeoutException { - if (!rdsCount.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) { - throw new TimeoutException("Timeout " + timeout + " waiting for RDSs"); - } - } - - public void setExpectedRdsCount(int count) { - rdsCount = new CountDownLatch(count); - } - - private void execute(Runnable action) { - // This method ensures that all watcher updates: - // - Happen after the server started watching LDS. - // - Are executed within the sync context of the server. - // - // Note that this doesn't guarantee that any of the RDS watchers are created. - // Tests should use setExpectedRdsCount(int) and awaitRds() for that. - if (ldsResource == null) { - throw new IllegalStateException("xDS resource update after watcher cancel"); - } - try { - ldsResource.get(DEFAULT_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS); - } catch (ExecutionException | TimeoutException e) { - throw new RuntimeException("Can't resolve LDS resource name in " + DEFAULT_TIMEOUT, e); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(e); - } - serverExecutor.execute(action); - } - void deliverLdsUpdate(List filterChains, FilterChain defaultFilterChain) { - deliverLdsUpdate(LdsUpdate.forTcpListener(Listener.create( - "listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain))); + ldsWatcher.onChanged(LdsUpdate.forTcpListener(Listener.create( + "listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain))); } void deliverLdsUpdate(LdsUpdate ldsUpdate) { - execute(() -> ldsWatcher.onChanged(ldsUpdate)); + ldsWatcher.onChanged(ldsUpdate); } - void deliverRdsUpdate(String resourceName, List virtualHosts) { - execute(() -> rdsWatchers.get(resourceName).onChanged(new RdsUpdate(virtualHosts))); + void deliverRdsUpdate(String rdsName, List virtualHosts) { + rdsWatchers.get(rdsName).onChanged(new RdsUpdate(virtualHosts)); } } } diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java index 388052a3dc8..41f005ba583 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -74,6 +74,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -251,7 +252,7 @@ public void run() { FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual); FilterChain f1 = createFilterChain("filter-chain-1", createRds("rds")); xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1); - xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-1"))); verify(listener, timeout(5000)).onServing(); @@ -260,7 +261,7 @@ public void run() { xdsServerWrapper.shutdown(); assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsClient.ldsResource).isNull(); - assertThat(xdsClient.isShutDown()).isTrue(); + assertThat(xdsClient.shutdown).isTrue(); verify(mockServer).shutdown(); assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue(); assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue(); @@ -302,7 +303,7 @@ public void run() { verify(mockServer, never()).start(); assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsClient.ldsResource).isNull(); - assertThat(xdsClient.isShutDown()).isTrue(); + assertThat(xdsClient.shutdown).isTrue(); verify(mockServer).shutdown(); assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue(); assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue(); @@ -341,7 +342,7 @@ public void run() { xdsServerWrapper.shutdown(); assertThat(xdsServerWrapper.isShutdown()).isTrue(); assertThat(xdsClient.ldsResource).isNull(); - assertThat(xdsClient.isShutDown()).isTrue(); + assertThat(xdsClient.shutdown).isTrue(); verify(mockBuilder, times(1)).build(); verify(mockServer, times(1)).shutdown(); xdsServerWrapper.awaitTermination(1, TimeUnit.SECONDS); @@ -366,7 +367,7 @@ public void run() { FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); SslContextProviderSupplier sslSupplier = filterChain.sslContextProviderSupplier(); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); - xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-1"))); try { @@ -433,7 +434,7 @@ public void run() { xdsClient.ldsResource.get(5, TimeUnit.SECONDS); FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds")); xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null); - xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.deliverRdsUpdate("rds", Collections.singletonList(createVirtualHost("virtual-host-1"))); try { @@ -543,7 +544,7 @@ public void run() { 0L, Collections.singletonList(virtualHost), new ArrayList()); EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); - xdsClient.setExpectedRdsCount(3); + xdsClient.rdsCount = new CountDownLatch(3); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); assertThat(start.isDone()).isFalse(); assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); @@ -555,7 +556,7 @@ public void run() { xdsClient.deliverLdsUpdate(Arrays.asList(f0, f2), f3); verify(mockServer, never()).start(); verify(listener, never()).onServing(); - xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.deliverRdsUpdate("r1", Collections.singletonList(createVirtualHost("virtual-host-1"))); @@ -601,11 +602,12 @@ public void run() { EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); EnvoyServerProtoData.FilterChain f2 = createFilterChain("filter-chain-2", createRds("r0")); + xdsClient.rdsCount = new CountDownLatch(1); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), f2); assertThat(start.isDone()).isFalse(); assertThat(selectorManager.getSelectorToUpdateSelector()).isNull(); - xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.deliverRdsUpdate("r0", Collections.singletonList(createVirtualHost("virtual-host-0"))); start.get(5000, TimeUnit.MILLISECONDS); @@ -631,9 +633,9 @@ public void run() { EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r0")); EnvoyServerProtoData.FilterChain f4 = createFilterChain("filter-chain-4", createRds("r1")); EnvoyServerProtoData.FilterChain f5 = createFilterChain("filter-chain-4", createRds("r1")); - xdsClient.setExpectedRdsCount(1); + xdsClient.rdsCount = new CountDownLatch(1); xdsClient.deliverLdsUpdate(Arrays.asList(f5, f3), f4); - xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); + xdsClient.rdsCount.await(5, TimeUnit.SECONDS); xdsClient.deliverRdsUpdate("r1", Collections.singletonList(createVirtualHost("virtual-host-1"))); xdsClient.deliverRdsUpdate("r0", @@ -686,7 +688,7 @@ public void run() { EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual); EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0")); xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null); - xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT); + xdsClient.rdsCount.await(); xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED); start.get(5000, TimeUnit.MILLISECONDS); assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size()) @@ -1233,7 +1235,7 @@ public ServerCall.Listener interceptCall(ServerCall