diff --git a/services/sts/src/main/java/software/amazon/awssdk/services/sts/auth/StsCredentialsProvider.java b/services/sts/src/main/java/software/amazon/awssdk/services/sts/auth/StsCredentialsProvider.java index 4ca191cd74aa..d72d83e91175 100644 --- a/services/sts/src/main/java/software/amazon/awssdk/services/sts/auth/StsCredentialsProvider.java +++ b/services/sts/src/main/java/software/amazon/awssdk/services/sts/auth/StsCredentialsProvider.java @@ -18,6 +18,7 @@ import java.time.Duration; import java.time.Instant; import java.util.Optional; +import java.util.concurrent.ScheduledThreadPoolExecutor; import java.util.function.Function; import software.amazon.awssdk.annotations.NotThreadSafe; import software.amazon.awssdk.annotations.SdkInternalApi; @@ -66,8 +67,11 @@ protected StsCredentialsProvider(BaseBuilder builder, String asyncThreadNa this.prefetchTime = Optional.ofNullable(builder.prefetchTime).orElse(DEFAULT_PREFETCH_TIME); CachedSupplier.Builder cacheBuilder = CachedSupplier.builder(this::updateSessionCredentials); - if (builder.asyncCredentialUpdateEnabled) { - cacheBuilder.prefetchStrategy(new NonBlocking(asyncThreadName)); + if (builder.asyncCredentialUpdateEnabled || builder.scheduledThreadPoolExecutor != null) { + CachedSupplier.PrefetchStrategy prefetchStrategy = builder.scheduledThreadPoolExecutor == null ? + new NonBlocking(asyncThreadName) : + new NonBlocking(builder.scheduledThreadPoolExecutor); + cacheBuilder.prefetchStrategy(prefetchStrategy); } this.sessionCache = cacheBuilder.build(); } @@ -125,6 +129,9 @@ protected abstract static class BaseBuilder, T> { private final Function providerConstructor; private Boolean asyncCredentialUpdateEnabled = false; + + private ScheduledThreadPoolExecutor scheduledThreadPoolExecutor; + private StsClient stsClient; private Duration staleTime; private Duration prefetchTime; @@ -159,6 +166,21 @@ public B asyncCredentialUpdateEnabled(Boolean asyncCredentialUpdateEnabled) { return (B) this; } + /** + * Configure whether the provider should fetch credentials asynchronously in the background using the provided + * ScheduledThreadPoolExecutor. + * + *

This is recommended for advanced uses cases where there are large numbers of credentials provider instances. The + * provider of the ScheduledThreadPoolExecutor instance is responsible for the configuration and lifecycle thereof.

+ * + *

By default this is null, resulting an internally created ScheduledThreadPoolExecutor for each credentials provider + * when async is enabled via asyncCredentialUpdateEnabled

+ */ + public B scheduledThreadPoolExecutor(ScheduledThreadPoolExecutor scheduledThreadPoolExecutor) { + this.scheduledThreadPoolExecutor = scheduledThreadPoolExecutor; + return (B) this; + } + /** * Configure the amount of time, relative to STS token expiration, that the cached credentials are considered * stale and should no longer be used. All threads will block until the value is updated. diff --git a/utils/src/main/java/software/amazon/awssdk/utils/cache/NonBlocking.java b/utils/src/main/java/software/amazon/awssdk/utils/cache/NonBlocking.java index ecf85be00041..82c0247c8db6 100644 --- a/utils/src/main/java/software/amazon/awssdk/utils/cache/NonBlocking.java +++ b/utils/src/main/java/software/amazon/awssdk/utils/cache/NonBlocking.java @@ -52,6 +52,11 @@ public class NonBlocking implements CachedSupplier.PrefetchStrategy { */ private final ScheduledExecutorService executor; + /** + * Whether 'executor' is owned (created) by this object + */ + private final boolean ownsExecutor; + /** * Create a non-blocking prefetch strategy that uses the provided value for the name of the background thread that will be * performing the update. @@ -60,10 +65,27 @@ public NonBlocking(String asyncThreadName) { this(asyncThreadName, Duration.ofMinutes(1)); } + /** + * Create a non-blocking prefetch strategy that uses the provided ScheduledExecutorService to perform the update. + */ + public NonBlocking(ScheduledExecutorService executor) { + this(Duration.ofMinutes(1), executor, false); + } + @SdkTestInternalApi NonBlocking(String asyncThreadName, Duration asyncRefreshFrequency) { - this.executor = newExecutor(asyncThreadName); + this(asyncRefreshFrequency, newExecutor(asyncThreadName), true); + } + + @SdkTestInternalApi + NonBlocking(Duration asyncRefreshFrequency, ScheduledExecutorService executor) { + this(asyncRefreshFrequency, executor, false); + } + + private NonBlocking(Duration asyncRefreshFrequency, ScheduledExecutorService executor, boolean ownsExecutor) { + this.executor = executor; this.asyncRefreshFrequency = asyncRefreshFrequency; + this.ownsExecutor = ownsExecutor; } private static ScheduledExecutorService newExecutor(String asyncThreadName) { @@ -111,6 +133,8 @@ public void prefetch(Runnable valueUpdater) { @Override public void close() { - executor.shutdown(); + if( ownsExecutor ) { + executor.shutdown(); + } } } diff --git a/utils/src/test/java/software/amazon/awssdk/utils/cache/CachedSupplierTest.java b/utils/src/test/java/software/amazon/awssdk/utils/cache/CachedSupplierTest.java index ba5153cb6696..e5ce79a30dc9 100644 --- a/utils/src/test/java/software/amazon/awssdk/utils/cache/CachedSupplierTest.java +++ b/utils/src/test/java/software/amazon/awssdk/utils/cache/CachedSupplierTest.java @@ -30,12 +30,14 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import software.amazon.awssdk.utils.ThreadFactoryBuilder; /** * Validate the functionality of {@link CachedSupplier}. @@ -48,6 +50,11 @@ public class CachedSupplierTest { */ private ExecutorService executorService; + /** + * A scheduled executor for NonBlocking use, in the case where this is externally provided. + */ + private ScheduledExecutorService scheduledExecutorService; + /** * All executions added to the {@link #executorService} since the beginning of an individual test method. */ @@ -60,6 +67,11 @@ public class CachedSupplierTest { public void setup() { executorService = Executors.newFixedThreadPool(50); allExecutions = new ArrayList<>(); + scheduledExecutorService = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder().daemonThreads(true) + .threadNamePrefix("non-blocking" + + "-creds" + + "-refresh") + .build()); } /** @@ -258,6 +270,23 @@ public void nonBlockingPrefetchStrategyRefreshesInBackground() { } } + @Test + public void nonBlockingPrefetchStrategyRefreshesInBackgroundWithProvidedExecutor() { + try (WaitingSupplier waitingSupplier = new WaitingSupplier(future(), past()); + CachedSupplier cachedSupplier = CachedSupplier.builder(waitingSupplier) + .prefetchStrategy(new NonBlocking(Duration.ofSeconds(1), + scheduledExecutorService)) + .build()) { + waitingSupplier.permits.release(2); + cachedSupplier.get(); + + // Ensure two "get"s happens even though we only made one call to the cached supplier. + waitingSupplier.waitForGetsToHaveStarted(2); + + assertThat(cachedSupplier.get()).isNotNull(); + } + } + @Test public void nonBlockingPrefetchStrategyBackgroundRefreshesHitCache() throws InterruptedException { try (WaitingSupplier waitingSupplier = new WaitingSupplier(future(), future()); @@ -274,6 +303,22 @@ public void nonBlockingPrefetchStrategyBackgroundRefreshesHitCache() throws Inte } } + @Test + public void nonBlockingPrefetchStrategyBackgroundRefreshesHitCacheWithProvidedExecutor() throws InterruptedException { + try (WaitingSupplier waitingSupplier = new WaitingSupplier(future(), future()); + CachedSupplier cachedSupplier = CachedSupplier.builder(waitingSupplier) + .prefetchStrategy(new NonBlocking(Duration.ofMillis(1), + scheduledExecutorService)) + .build()) { + waitingSupplier.permits.release(5); + cachedSupplier.get(); + + Thread.sleep(1_000); + + assertThat(waitingSupplier.permits.availablePermits()).isEqualTo(4); // Only 1 call to supplier + } + } + @Test public void nonBlockingPrefetchStrategyDoesNotRefreshUntilItIsCalled() throws InterruptedException { try (WaitingSupplier waitingSupplier = new WaitingSupplier(future(), past()); @@ -289,6 +334,21 @@ public void nonBlockingPrefetchStrategyDoesNotRefreshUntilItIsCalled() throws In } } + @Test + public void nonBlockingPrefetchStrategyDoesNotRefreshUntilItIsCalledWithProvidedExecutor() throws InterruptedException { + try (WaitingSupplier waitingSupplier = new WaitingSupplier(future(), past()); + CachedSupplier cachedSupplier = CachedSupplier.builder(waitingSupplier) + .prefetchStrategy(new NonBlocking(Duration.ofMillis(1), + scheduledExecutorService)) + .build()) { + waitingSupplier.startedGetPermits.release(); + + Thread.sleep(1_000); + + assertThat(waitingSupplier.startedGetPermits.availablePermits()).isEqualTo(1); + } + } + /** * Asynchronously perform a "get" on the provided supplier, returning the future that will be completed when the "get" * finishes.