diff --git a/src/main/java/org/dataloader/DataLoaderRegistry.java b/src/main/java/org/dataloader/DataLoaderRegistry.java index 9b19c29..4e9b78f 100644 --- a/src/main/java/org/dataloader/DataLoaderRegistry.java +++ b/src/main/java/org/dataloader/DataLoaderRegistry.java @@ -10,7 +10,9 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; /** @@ -126,6 +128,29 @@ public Set getKeys() { return new HashSet<>(dataLoaders.keySet()); } + /** + * This method will call {@link org.dataloader.DataLoader#dispatch()} on registered {@link org.dataloader.DataLoader}s + * repeatedly until there are no more calls to dispatch. + * @return the promise of total count of dispatched keys. + */ + public CompletableFuture dispatch() { + AtomicInteger count = new AtomicInteger(); + CompletableFuture[] futuresToDispatch = getDataLoaders().stream() + .filter(dataLoader -> dataLoader.dispatchDepth() > 0) + .map(DataLoader::dispatchWithCounts) + .map(dispatchResult -> { + count.addAndGet(dispatchResult.getKeysCount()); + return dispatchResult.getPromisedResults(); + }) + .toArray(CompletableFuture[]::new); + if (futuresToDispatch.length > 0) { + return CompletableFuture.allOf(futuresToDispatch) + .thenCompose(__ -> dispatch()) + .thenApply(count::addAndGet); + } + return CompletableFuture.completedFuture(count.get()); + } + /** * This will called {@link org.dataloader.DataLoader#dispatch()} on each of the registered * {@link org.dataloader.DataLoader}s diff --git a/src/test/java/org/dataloader/DataLoaderRegistryTest.java b/src/test/java/org/dataloader/DataLoaderRegistryTest.java index aeaf668..69f947f 100644 --- a/src/test/java/org/dataloader/DataLoaderRegistryTest.java +++ b/src/test/java/org/dataloader/DataLoaderRegistryTest.java @@ -5,6 +5,7 @@ import org.junit.Test; import java.util.concurrent.CompletableFuture; +import java.util.stream.Collectors; import static java.util.Arrays.asList; import static org.dataloader.DataLoaderFactory.newDataLoader; @@ -15,6 +16,8 @@ public class DataLoaderRegistryTest { final BatchLoader identityBatchLoader = CompletableFuture::completedFuture; + final BatchLoader incrementalBatchLoader = + v -> CompletableFuture.supplyAsync(() -> v.stream().map(i -> ++i).collect(Collectors.toList())); @Test public void registration_works() { @@ -166,6 +169,51 @@ public void dispatch_counts_are_maintained() { assertThat(dispatchDepth, equalTo(0)); } + @Test + public void composed_dispatch_counts_are_maintained() { + + DataLoaderRegistry registry = new DataLoaderRegistry(); + + DataLoader dlA = newDataLoader(incrementalBatchLoader); + DataLoader dlB = newDataLoader(incrementalBatchLoader); + DataLoader dlC = newDataLoader(incrementalBatchLoader); + + registry.register("a", dlA).register("b", dlB).register("c", dlC); + + CompletableFuture test1 = dlA.load(100) + .thenCompose(dlB::load) + .thenCompose(dlC::load); + CompletableFuture test2 = dlC.load(200) + .thenCompose(dlB::load) + .thenCompose(dlA::load); + + assertThat("Initially dispatching only top level load calls", registry.dispatchDepth(), equalTo(2)); + + CompletableFuture dispatchedKeys1 = registry.dispatch(); + + assertThat("Total count of dispatched keys in first iteration", dispatchedKeys1.join(), equalTo(6)); + assertThat("Zero dispatch depth after first iteration done", registry.dispatchDepth(), equalTo(0)); + + CompletableFuture test3 = dlA.load(100) + .thenCompose(dlB::load) + .thenCompose(dlC::load); + CompletableFuture test4 = dlC.load(200) + .thenCompose(dlB::load) + .thenCompose(dlA::load); + + assertThat("Not dispatching the same keys twice", registry.dispatchDepth(), equalTo(0)); + + CompletableFuture dispatchedKeys2 = registry.dispatch(); + + assertThat("Zero dispatched keys in second iteration", dispatchedKeys2.join(), equalTo(0)); + assertThat("Zero dispatch depth after second iteration done", registry.dispatchDepth(), equalTo(0)); + + assertThat(test1.join(), equalTo(103)); + assertThat(test2.join(), equalTo(203)); + assertThat(test3.join(), equalTo(103)); + assertThat(test4.join(), equalTo(203)); + } + @Test public void builder_works() { DataLoader dlA = newDataLoader(identityBatchLoader);