Skip to content

Allow StsCredentialsProvider builder to take a ScheduledExecutorService #3260

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.time.Duration;
import java.time.Instant;
import java.util.Optional;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.function.Function;
import software.amazon.awssdk.annotations.NotThreadSafe;
import software.amazon.awssdk.annotations.SdkInternalApi;
Expand Down Expand Up @@ -66,8 +67,11 @@ protected StsCredentialsProvider(BaseBuilder<?, ?> builder, String asyncThreadNa
this.prefetchTime = Optional.ofNullable(builder.prefetchTime).orElse(DEFAULT_PREFETCH_TIME);

CachedSupplier.Builder<SessionCredentialsHolder> cacheBuilder = CachedSupplier.builder(this::updateSessionCredentials);
if (builder.asyncCredentialUpdateEnabled) {
cacheBuilder.prefetchStrategy(new NonBlocking(asyncThreadName));
if (builder.asyncCredentialUpdateEnabled || builder.scheduledThreadPoolExecutor != null) {
CachedSupplier.PrefetchStrategy prefetchStrategy = builder.scheduledThreadPoolExecutor == null ?
new NonBlocking(asyncThreadName) :
new NonBlocking(builder.scheduledThreadPoolExecutor);
cacheBuilder.prefetchStrategy(prefetchStrategy);
}
this.sessionCache = cacheBuilder.build();
}
Expand Down Expand Up @@ -125,6 +129,9 @@ protected abstract static class BaseBuilder<B extends BaseBuilder<B, T>, T> {
private final Function<B, T> providerConstructor;

private Boolean asyncCredentialUpdateEnabled = false;

private ScheduledThreadPoolExecutor scheduledThreadPoolExecutor;

private StsClient stsClient;
private Duration staleTime;
private Duration prefetchTime;
Expand Down Expand Up @@ -159,6 +166,21 @@ public B asyncCredentialUpdateEnabled(Boolean asyncCredentialUpdateEnabled) {
return (B) this;
}

/**
* Configure whether the provider should fetch credentials asynchronously in the background using the provided
* ScheduledThreadPoolExecutor.
*
* <p>This is recommended for advanced uses cases where there are large numbers of credentials provider instances. The
* provider of the ScheduledThreadPoolExecutor instance is responsible for the configuration and lifecycle thereof.</p>
*
* <p>By default this is null, resulting an internally created ScheduledThreadPoolExecutor for each credentials provider
* when async is enabled via asyncCredentialUpdateEnabled</p>
*/
public B scheduledThreadPoolExecutor(ScheduledThreadPoolExecutor scheduledThreadPoolExecutor) {
this.scheduledThreadPoolExecutor = scheduledThreadPoolExecutor;
return (B) this;
}

/**
* Configure the amount of time, relative to STS token expiration, that the cached credentials are considered
* stale and should no longer be used. All threads will block until the value is updated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ public class NonBlocking implements CachedSupplier.PrefetchStrategy {
*/
private final ScheduledExecutorService executor;

/**
* Whether 'executor' is owned (created) by this object
*/
private final boolean ownsExecutor;

/**
* Create a non-blocking prefetch strategy that uses the provided value for the name of the background thread that will be
* performing the update.
Expand All @@ -60,10 +65,27 @@ public NonBlocking(String asyncThreadName) {
this(asyncThreadName, Duration.ofMinutes(1));
}

/**
* Create a non-blocking prefetch strategy that uses the provided ScheduledExecutorService to perform the update.
*/
public NonBlocking(ScheduledExecutorService executor) {
this(Duration.ofMinutes(1), executor, false);
}

@SdkTestInternalApi
NonBlocking(String asyncThreadName, Duration asyncRefreshFrequency) {
this.executor = newExecutor(asyncThreadName);
this(asyncRefreshFrequency, newExecutor(asyncThreadName), true);
}

@SdkTestInternalApi
NonBlocking(Duration asyncRefreshFrequency, ScheduledExecutorService executor) {
this(asyncRefreshFrequency, executor, false);
}

private NonBlocking(Duration asyncRefreshFrequency, ScheduledExecutorService executor, boolean ownsExecutor) {
this.executor = executor;
this.asyncRefreshFrequency = asyncRefreshFrequency;
this.ownsExecutor = ownsExecutor;
}

private static ScheduledExecutorService newExecutor(String asyncThreadName) {
Expand Down Expand Up @@ -111,6 +133,8 @@ public void prefetch(Runnable valueUpdater) {

@Override
public void close() {
executor.shutdown();
if( ownsExecutor ) {
executor.shutdown();
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import software.amazon.awssdk.utils.ThreadFactoryBuilder;

/**
* Validate the functionality of {@link CachedSupplier}.
Expand All @@ -48,6 +50,11 @@ public class CachedSupplierTest {
*/
private ExecutorService executorService;

/**
* A scheduled executor for NonBlocking use, in the case where this is externally provided.
*/
private ScheduledExecutorService scheduledExecutorService;

/**
* All executions added to the {@link #executorService} since the beginning of an individual test method.
*/
Expand All @@ -60,6 +67,11 @@ public class CachedSupplierTest {
public void setup() {
executorService = Executors.newFixedThreadPool(50);
allExecutions = new ArrayList<>();
scheduledExecutorService = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder().daemonThreads(true)
.threadNamePrefix("non-blocking"
+ "-creds"
+ "-refresh")
.build());
}

/**
Expand Down Expand Up @@ -258,6 +270,23 @@ public void nonBlockingPrefetchStrategyRefreshesInBackground() {
}
}

@Test
public void nonBlockingPrefetchStrategyRefreshesInBackgroundWithProvidedExecutor() {
try (WaitingSupplier waitingSupplier = new WaitingSupplier(future(), past());
CachedSupplier<String> cachedSupplier = CachedSupplier.builder(waitingSupplier)
.prefetchStrategy(new NonBlocking(Duration.ofSeconds(1),
scheduledExecutorService))
.build()) {
waitingSupplier.permits.release(2);
cachedSupplier.get();

// Ensure two "get"s happens even though we only made one call to the cached supplier.
waitingSupplier.waitForGetsToHaveStarted(2);

assertThat(cachedSupplier.get()).isNotNull();
}
}

@Test
public void nonBlockingPrefetchStrategyBackgroundRefreshesHitCache() throws InterruptedException {
try (WaitingSupplier waitingSupplier = new WaitingSupplier(future(), future());
Expand All @@ -274,6 +303,22 @@ public void nonBlockingPrefetchStrategyBackgroundRefreshesHitCache() throws Inte
}
}

@Test
public void nonBlockingPrefetchStrategyBackgroundRefreshesHitCacheWithProvidedExecutor() throws InterruptedException {
try (WaitingSupplier waitingSupplier = new WaitingSupplier(future(), future());
CachedSupplier<String> cachedSupplier = CachedSupplier.builder(waitingSupplier)
.prefetchStrategy(new NonBlocking(Duration.ofMillis(1),
scheduledExecutorService))
.build()) {
waitingSupplier.permits.release(5);
cachedSupplier.get();

Thread.sleep(1_000);

assertThat(waitingSupplier.permits.availablePermits()).isEqualTo(4); // Only 1 call to supplier
}
}

@Test
public void nonBlockingPrefetchStrategyDoesNotRefreshUntilItIsCalled() throws InterruptedException {
try (WaitingSupplier waitingSupplier = new WaitingSupplier(future(), past());
Expand All @@ -289,6 +334,21 @@ public void nonBlockingPrefetchStrategyDoesNotRefreshUntilItIsCalled() throws In
}
}

@Test
public void nonBlockingPrefetchStrategyDoesNotRefreshUntilItIsCalledWithProvidedExecutor() throws InterruptedException {
try (WaitingSupplier waitingSupplier = new WaitingSupplier(future(), past());
CachedSupplier<String> cachedSupplier = CachedSupplier.builder(waitingSupplier)
.prefetchStrategy(new NonBlocking(Duration.ofMillis(1),
scheduledExecutorService))
.build()) {
waitingSupplier.startedGetPermits.release();

Thread.sleep(1_000);

assertThat(waitingSupplier.startedGetPermits.availablePermits()).isEqualTo(1);
}
}

/**
* Asynchronously perform a "get" on the provided supplier, returning the future that will be completed when the "get"
* finishes.
Expand Down