diff --git a/.changes/next-release/feature-S3-12a0b55.json b/.changes/next-release/feature-S3-12a0b55.json new file mode 100644 index 000000000000..0be1c7b5691e --- /dev/null +++ b/.changes/next-release/feature-S3-12a0b55.json @@ -0,0 +1,6 @@ +{ + "type": "feature", + "category": "S3", + "contributor": "", + "description": "Add support for parallel download for individual part-get for multipart GetObject in s3 async client and Transfer Manager" +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncResponseTransformer.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncResponseTransformer.java index a7abf157a628..23e369ea7e03 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncResponseTransformer.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/AsyncResponseTransformer.java @@ -381,6 +381,15 @@ interface SplitResult */ CompletableFuture resultFuture(); + /** + * Indicates if the split async response transformer supports sending individual transformer non-serially, as well as + * receiving back data from the many {@link AsyncResponseTransformer#onStream(SdkPublisher) publishers} non-serially. + * @return true if non-serial data is supported, false otherwise + */ + default Boolean parallelSplitSupported() { + return false; + } + static Builder builder() { return DefaultAsyncResponseTransformerSplitResult.builder(); } @@ -413,6 +422,20 @@ interface Builder * @return an instance of this Builder */ Builder resultFuture(CompletableFuture future); + + /** + * If the AsyncResponseTransformers returned by the {@link SplitResult#publisher()} support concurrent + * parallel streaming of multiple content body concurrently. + * @return + */ + Boolean parallelSplitSupported(); + + /** + * Sets whether the AsyncResponseTransformers returned by the {@link SplitResult#publisher()} support concurrent + * parallel streaming of multiple content body concurrently + * @return + */ + Builder parallelSplitSupported(Boolean parallelSplitSupported); } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/listener/AsyncResponseTransformerListener.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/listener/AsyncResponseTransformerListener.java index c7ee37690ca0..f1612e1b1f3e 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/listener/AsyncResponseTransformerListener.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/async/listener/AsyncResponseTransformerListener.java @@ -20,6 +20,7 @@ import org.reactivestreams.Subscriber; import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.annotations.SdkProtectedApi; +import software.amazon.awssdk.core.SplittingTransformerConfiguration; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.utils.Logger; @@ -108,6 +109,11 @@ public String name() { return delegate.name(); } + @Override + public SplitResult split(SplittingTransformerConfiguration splitConfig) { + return delegate.split(splitConfig); + } + static void invoke(Runnable runnable, String callbackName) { try { runnable.run(); diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/DefaultAsyncResponseTransformerSplitResult.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/DefaultAsyncResponseTransformerSplitResult.java index ed64b1d8eae4..0b095a34e973 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/DefaultAsyncResponseTransformerSplitResult.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/DefaultAsyncResponseTransformerSplitResult.java @@ -27,12 +27,14 @@ public final class DefaultAsyncResponseTransformerSplitResult> publisher; private final CompletableFuture future; + private final Boolean parallelSplitSupported; private DefaultAsyncResponseTransformerSplitResult(Builder builder) { this.publisher = Validate.paramNotNull( builder.publisher(), "asyncResponseTransformerPublisher"); this.future = Validate.paramNotNull( builder.resultFuture(), "future"); + this.parallelSplitSupported = Validate.getOrDefault(builder.parallelSplitSupported(), () -> false); } /** @@ -52,6 +54,11 @@ public CompletableFuture resultFuture() { return this.future; } + @Override + public Boolean parallelSplitSupported() { + return this.parallelSplitSupported; + } + @Override public AsyncResponseTransformer.SplitResult.Builder toBuilder() { return new DefaultBuilder<>(this); @@ -65,6 +72,7 @@ public static class DefaultBuilder implements AsyncResponseTransformer.SplitResult.Builder { private SdkPublisher> publisher; private CompletableFuture future; + private Boolean parallelSplitSupported; DefaultBuilder() { } @@ -72,6 +80,7 @@ public static class DefaultBuilder DefaultBuilder(DefaultAsyncResponseTransformerSplitResult split) { this.publisher = split.publisher; this.future = split.future; + this.parallelSplitSupported = split.parallelSplitSupported; } @Override @@ -92,14 +101,28 @@ public CompletableFuture resultFuture() { } @Override - public AsyncResponseTransformer.SplitResult.Builder resultFuture(CompletableFuture future) { + public AsyncResponseTransformer.SplitResult.Builder resultFuture( + CompletableFuture future) { this.future = future; return this; } + @Override + public Boolean parallelSplitSupported() { + return parallelSplitSupported; + } + + @Override + public AsyncResponseTransformer.SplitResult.Builder parallelSplitSupported( + Boolean parallelSplitSupported) { + this.parallelSplitSupported = parallelSplitSupported; + return this; + } + @Override public AsyncResponseTransformer.SplitResult build() { return new DefaultAsyncResponseTransformerSplitResult<>(this); } + } } diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/EmittingSubscription.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/EmittingSubscription.java new file mode 100644 index 000000000000..25f1a78705ca --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/EmittingSubscription.java @@ -0,0 +1,144 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.function.Supplier; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.annotations.ThreadSafe; +import software.amazon.awssdk.utils.Logger; + +/** + * Subscription which can emit {@link Subscriber#onNext(T)} signals to a subscriber, based on the demand received with the + * {@link Subscription#request(long)}. It tracks the outstandingDemand that has not yet been fulfilled and used a Supplier + * passed to it to create the object it needs to emit. + * @param the type of object to emit to the subscriber. + */ +@SdkInternalApi +@ThreadSafe +public final class EmittingSubscription implements Subscription { + private static final Logger log = Logger.loggerFor(EmittingSubscription.class); + + private Subscriber downstreamSubscriber; + private final AtomicBoolean emitting; + private final AtomicLong outstandingDemand; + private final Runnable onCancel; + private final AtomicBoolean isCancelled; + private final Supplier supplier; + + private EmittingSubscription(Builder builder) { + this.downstreamSubscriber = builder.downstreamSubscriber; + this.onCancel = builder.onCancel; + this.supplier = builder.supplier; + this.isCancelled = new AtomicBoolean(); + this.outstandingDemand = new AtomicLong(0); + this.emitting = new AtomicBoolean(); + } + + public static Builder builder() { + return new Builder<>(); + } + + @Override + public void request(long n) { + if (n <= 0) { + downstreamSubscriber.onError(new IllegalArgumentException("Amount requested must be positive")); + return; + } + long newDemand = outstandingDemand.updateAndGet(current -> { + if (Long.MAX_VALUE - current < n) { + return Long.MAX_VALUE; + } + return current + n; + }); + log.trace(() -> String.format("new outstanding demand: %s", newDemand)); + emit(); + } + + @Override + public void cancel() { + isCancelled.set(true); + downstreamSubscriber = null; + onCancel.run(); + } + + private void emit() { + do { + if (!emitting.compareAndSet(false, true)) { + return; + } + try { + if (doEmit()) { + return; + } + } finally { + emitting.compareAndSet(true, false); + } + } while (outstandingDemand.get() > 0); + } + + private boolean doEmit() { + long demand = outstandingDemand.get(); + + while (demand > 0) { + if (isCancelled.get()) { + return true; + } + if (outstandingDemand.get() > 0) { + demand = outstandingDemand.decrementAndGet(); + T value; + try { + value = supplier.get(); + } catch (Exception e) { + downstreamSubscriber.onError(e); + return true; + } + downstreamSubscriber.onNext(value); + } + } + return false; + } + + public static class Builder { + private Subscriber downstreamSubscriber; + private Runnable onCancel; + private Supplier supplier; + + public Builder downstreamSubscriber(Subscriber subscriber) { + this.downstreamSubscriber = subscriber; + return this; + } + + public Builder onCancel(Runnable onCancel) { + this.onCancel = onCancel; + return this; + } + + public Builder supplier(Supplier supplier) { + this.supplier = supplier; + return this; + } + + public EmittingSubscription build() { + return new EmittingSubscription<>(this); + } + } + + +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformer.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformer.java index 4348355fa5d8..dd915f90f293 100644 --- a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformer.java +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformer.java @@ -41,6 +41,7 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.core.FileTransformerConfiguration; import software.amazon.awssdk.core.FileTransformerConfiguration.FailureBehavior; +import software.amazon.awssdk.core.SplittingTransformerConfiguration; import software.amazon.awssdk.core.async.AsyncResponseTransformer; import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.core.exception.SdkClientException; @@ -76,6 +77,18 @@ private FileAsyncResponseTransformer(Path path, FileTransformerConfiguration fil this.position = position; } + FileTransformerConfiguration config() { + return configuration.toBuilder().build(); + } + + Path path() { + return path; + } + + long position() { + return position; + } + private static long determineFilePositionToWrite(Path path, FileTransformerConfiguration fileConfiguration) { if (fileConfiguration.fileWriteOption() == CREATE_OR_APPEND_TO_EXISTING) { try { @@ -89,7 +102,7 @@ private static long determineFilePositionToWrite(Path path, FileTransformerConfi if (fileConfiguration.fileWriteOption() == WRITE_TO_POSITION) { return Validate.getOrDefault(fileConfiguration.position(), () -> 0L); } - return 0L; + return 0L; } private AsynchronousFileChannel createChannel(Path path) throws IOException { @@ -183,6 +196,7 @@ static class FileSubscriber implements Subscriber { private final Path path; private final CompletableFuture future; private final Consumer onErrorMethod; + private final Object closeLock = new Object(); private volatile boolean writeInProgress = false; private volatile boolean closeOnLastWrite = false; @@ -228,7 +242,7 @@ public void completed(Integer result, ByteBuffer attachment) { if (byteBuffer.hasRemaining()) { performWrite(byteBuffer); } else { - synchronized (FileSubscriber.this) { + synchronized (closeLock) { writeInProgress = false; if (closeOnLastWrite) { close(); @@ -256,7 +270,7 @@ public void onError(Throwable t) { public void onComplete() { log.trace(() -> "onComplete"); // if write in progress, tell write to close on finish. - synchronized (this) { + synchronized (closeLock) { if (writeInProgress) { log.trace(() -> "writeInProgress = true, not closing"); closeOnLastWrite = true; @@ -284,4 +298,18 @@ public String toString() { return getClass() + ":" + path.toString(); } } -} \ No newline at end of file + + + @Override + public SplitResult split(SplittingTransformerConfiguration splitConfig) { + if (configuration.fileWriteOption() == CREATE_OR_APPEND_TO_EXISTING) { + return AsyncResponseTransformer.super.split(splitConfig); + } + CompletableFuture future = new CompletableFuture<>(); + return (SplitResult) SplitResult.builder() + .publisher(new FileAsyncResponseTransformerPublisher(this)) + .resultFuture(future) + .parallelSplitSupported(true) + .build(); + } +} diff --git a/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformerPublisher.java b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformerPublisher.java new file mode 100644 index 000000000000..29208894cf31 --- /dev/null +++ b/core/sdk-core/src/main/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformerPublisher.java @@ -0,0 +1,196 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import java.nio.ByteBuffer; +import java.nio.file.Path; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicLong; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.FileTransformerConfiguration; +import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.ContentRangeParser; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Pair; +import software.amazon.awssdk.utils.Validate; + +/** + * A publisher of {@link FileAsyncResponseTransformer} that uses the Content-Range header of a {@link SdkResponse} to write to the + * offset defined in the range of the Content-Range. Correspond to the {@link SplittingTransformer} for non-linear write cases. + */ +@SdkInternalApi +public class FileAsyncResponseTransformerPublisher + implements SdkPublisher> { + private static final Logger log = Logger.loggerFor(FileAsyncResponseTransformerPublisher.class); + + private final Path path; + private final FileTransformerConfiguration initialConfig; + private final long initialPosition; + private Subscriber subscriber; + private final AtomicLong transformerCount; + + + public FileAsyncResponseTransformerPublisher(FileAsyncResponseTransformer responseTransformer) { + this.path = Validate.paramNotNull(responseTransformer.path(), "path"); + Validate.isTrue(responseTransformer.config().fileWriteOption() + != FileTransformerConfiguration.FileWriteOption.CREATE_OR_APPEND_TO_EXISTING, + "CREATE_OR_APPEND_TO_EXISTING is not supported for non-serial operations"); + this.initialConfig = Validate.paramNotNull(responseTransformer.config(), "fileTransformerConfiguration"); + this.initialPosition = responseTransformer.position(); + this.transformerCount = new AtomicLong(0); + } + + @Override + public void subscribe(Subscriber> s) { + Validate.notNull(s, "Subscriber must not be null"); + this.subscriber = s; + s.onSubscribe(EmittingSubscription.>builder() + .downstreamSubscriber(s) + .onCancel(this::onCancel) + .supplier(this::createTransformer) + .build()); + } + + private AsyncResponseTransformer createTransformer() { + return new IndividualFileTransformer(); + } + + private void onCancel() { + subscriber = null; + } + + /** + * This is the AsyncResponseTransformer that will be used for each individual requests. + *

+ * We delegate to new instances of the already existing class {@link FileAsyncResponseTransformer} to perform the individual + * requests. This FileAsyncResponseTransformer will write the content of the request to the file at the offset taken from the + * Content-Range header ('x-amz-content-range'). As such, we don't need to manually manage the state of the + * AsyncResponseTransformer passed by the user, like we do for {@link SplittingTransformer}. Here, we know it is a + * FileAsyncResponseTransformer, so we can just ignore it, and instead rely on the individual FileAsyncResponseTransformer of + * every part. + *

+ * Note on retries: since we are delegating requests to {@link FileAsyncResponseTransformer}, each request made with this + * transformer will retry independently based on the retry configuration of the client it is used with. We only need to verify + * the completion state of the future of each individually + */ + private class IndividualFileTransformer implements AsyncResponseTransformer { + private AsyncResponseTransformer delegate; + private CompletableFuture future; + + @Override + public CompletableFuture prepare() { + this.future = new CompletableFuture<>(); + return this.future; + } + + @Override + public void onResponse(T response) { + Optional contentRangeList = response.sdkHttpResponse().firstMatchingHeader("x-amz-content-range"); + if (!contentRangeList.isPresent()) { + if (subscriber != null) { + IllegalStateException e = new IllegalStateException("Content range header is missing"); + handleError(e); + } + return; + } + + String contentRange = contentRangeList.get(); + Optional> contentRangePair = ContentRangeParser.range(contentRange); + if (!contentRangePair.isPresent()) { + if (subscriber != null) { + IllegalStateException e = new IllegalStateException("Could not parse content range header " + contentRange); + handleError(e); + } + return; + } + + this.delegate = getDelegateTransformer(contentRangePair.get().left()); + CompletableFuture delegateFuture = delegate.prepare(); + CompletableFutureUtils.forwardResultTo(delegateFuture, future); + CompletableFutureUtils.forwardExceptionTo(future, delegateFuture); + transformerCount.incrementAndGet(); + delegate.onResponse(response); + } + + private void handleError(Throwable e) { + subscriber.onError(e); + future.completeExceptionally(e); + } + + private AsyncResponseTransformer getDelegateTransformer(Long startAt) { + if (transformerCount.get() == 0) { + // On the first request we need to maintain the same config so + // that the file is actually created on disk if it doesn't exist (for example, if CREATE_NEW or + // CREATE_OR_REPLACE_EXISTING is used) + return AsyncResponseTransformer.toFile(path, initialConfig); + } + switch (initialConfig.fileWriteOption()) { + case CREATE_NEW: + case CREATE_OR_REPLACE_EXISTING: { + FileTransformerConfiguration newConfig = initialConfig.copy(c -> c + .fileWriteOption(FileTransformerConfiguration.FileWriteOption.WRITE_TO_POSITION) + .position(startAt)); + return AsyncResponseTransformer.toFile(path, newConfig); + } + case WRITE_TO_POSITION: { + long initialOffset = initialConfig.position(); + FileTransformerConfiguration newConfig = initialConfig.copy(c -> c + .fileWriteOption(FileTransformerConfiguration.FileWriteOption.WRITE_TO_POSITION) + .position(initialOffset + startAt)); + return AsyncResponseTransformer.toFile(path, newConfig); + } + // As per design specification, APPEND mode is not supported for non-serial operations + case CREATE_OR_APPEND_TO_EXISTING: + default: + throw new UnsupportedOperationException("Unsupported fileWriteOption: " + initialConfig.fileWriteOption()); + } + } + + @Override + public void onStream(SdkPublisher publisher) { + // should never be null as per AsyncResponseTransformer runtime contract, but we never know + if (delegate == null) { + if (future != null) { + future.completeExceptionally(new IllegalStateException("onStream called before onResponse")); + } + return; + } + delegate.onStream(publisher); + } + + @Override + public void exceptionOccurred(Throwable error) { + if (delegate != null) { + // do not call onError, because exceptionOccurred may be called multiple times due to retries, simply forward the + // error to the delegate async response transformer which will let the service call pipeline handle the error. + delegate.exceptionOccurred(error); + } else { + // If we received an error without even having a delegate, this means we have thrown an error before even + // getting a onResponse signal. We complete the prepared future, to let the + // service call pipeline handle the error + if (future != null) { + future.completeExceptionally(error); + } + } + } + } + +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformerPublisherTckTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformerPublisherTckTest.java new file mode 100644 index 000000000000..fbc0ae4662f5 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformerPublisherTckTest.java @@ -0,0 +1,60 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import com.google.common.jimfs.Configuration; +import com.google.common.jimfs.Jimfs; +import java.nio.file.FileSystem; +import java.nio.file.Path; +import org.reactivestreams.Publisher; +import org.reactivestreams.tck.PublisherVerification; +import org.reactivestreams.tck.TestEnvironment; +import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.async.SdkPublisher; + +public class FileAsyncResponseTransformerPublisherTckTest extends PublisherVerification> { + + private final FileSystem fileSystem = Jimfs.newFileSystem(Configuration.unix()); + private final Path testFile = fileSystem.getPath("/test-file.txt"); + + + public FileAsyncResponseTransformerPublisherTckTest() { + super(new TestEnvironment()); + } + + @Override + public Publisher> createPublisher(long elements) { + FileAsyncResponseTransformer art = + (FileAsyncResponseTransformer) AsyncResponseTransformer.toFile(testFile); + FileAsyncResponseTransformerPublisher publisher = + new FileAsyncResponseTransformerPublisher<>(art); + + return SdkPublisher.adapt(publisher).limit((int) elements); + } + + @Override + public Publisher> createFailedPublisher() { + return null; + } + + @Override + public long maxElementsFromPublisher() { + return Long.MAX_VALUE; + } + +} diff --git a/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformerPublisherTest.java b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformerPublisherTest.java new file mode 100644 index 000000000000..adfd4ecc3200 --- /dev/null +++ b/core/sdk-core/src/test/java/software/amazon/awssdk/core/internal/async/FileAsyncResponseTransformerPublisherTest.java @@ -0,0 +1,306 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.core.internal.async; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.assertj.core.api.Assertions.fail; +import static org.assertj.core.api.Assertions.in; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.common.jimfs.Jimfs; +import java.nio.ByteBuffer; +import java.nio.file.FileSystem; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.stream.Stream; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.core.FileTransformerConfiguration; +import software.amazon.awssdk.core.SdkResponse; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.http.SdkHttpResponse; +import software.amazon.awssdk.utils.CompletableFutureUtils; + +class FileAsyncResponseTransformerPublisherTest { + + private FileSystem fileSystem; + private Path testFile; + + @BeforeEach + void setUp() throws Exception { + fileSystem = Jimfs.newFileSystem(); + testFile = fileSystem.getPath(String.format("/test-file-%s.txt", UUID.randomUUID())); + } + + @AfterEach + void tearDown() throws Exception { + fileSystem.close(); + } + + @ParameterizedTest + @MethodSource("transformers") + void singleDemand_shouldEmitOneTransformer( + Function> transformerFunction) throws Exception { + // Given + // FileAsyncResponseTransformer initialTransformer = + // (FileAsyncResponseTransformer) AsyncResponseTransformer.toFile(testFile); + + AsyncResponseTransformer initialTransformer = transformerFunction.apply(testFile); + createFileIfNeeded(initialTransformer); + + FileAsyncResponseTransformerPublisher publisher = + new FileAsyncResponseTransformerPublisher<>((FileAsyncResponseTransformer) initialTransformer); + + CountDownLatch latch = new CountDownLatch(1); + AtomicReference> receivedTransformer = new AtomicReference<>(); + CompletableFuture future = new CompletableFuture<>(); + + // When + publisher.subscribe(new Subscriber>() { + private Subscription subscription; + + @Override + public void onSubscribe(Subscription s) { + this.subscription = s; + s.request(1); + } + + @Override + public void onNext(AsyncResponseTransformer transformer) { + receivedTransformer.set(transformer); + + // Simulate response with content-range header + SdkResponse mockResponse = createMockResponseWithRange("bytes 0-9/10"); + CompletableFuture prepareFuture = transformer.prepare(); + CompletableFutureUtils.forwardResultTo(prepareFuture, future); + transformer.onResponse(mockResponse); + + // Simulate stream data + SdkPublisher mockPublisher = createMockPublisher(); + transformer.onStream(mockPublisher); + + latch.countDown(); + } + + @Override + public void onError(Throwable t) { + fail("Unexpected error with exception: " + t.getMessage()); + } + + @Override + public void onComplete() { + latch.countDown(); + } + }); + + // Then + assertThat(latch.await(5, TimeUnit.SECONDS)).isTrue(); + assertThat(receivedTransformer.get()).isNotNull(); + assertThat(Files.exists(testFile)).isTrue(); + assertThat(future).succeedsWithin(10, TimeUnit.SECONDS); + } + + private void createFileIfNeeded(AsyncResponseTransformer initialTransformer) throws Exception { + FileTransformerConfiguration.FileWriteOption fileWriteOption = + ((FileAsyncResponseTransformer) initialTransformer).config().fileWriteOption(); + if (fileWriteOption == FileTransformerConfiguration.FileWriteOption.WRITE_TO_POSITION) { + Files.createFile(testFile); + } + } + + private SdkPublisher createMockPublisher() { + return s -> s.onSubscribe(new Subscription() { + @Override + public void request(long n) { + s.onNext(ByteBuffer.wrap("test data".getBytes())); + s.onComplete(); + } + + @Override + public void cancel() { + } + }); + } + + @ParameterizedTest + @MethodSource("transformers") + void requestManyTransformers_withResponseContainingDifferentContentRanges_shouldWriteToFileAtThoseRanges( + Function> transformerFunction + ) throws Exception { + // Given + FileAsyncResponseTransformer initialTransformer = + (FileAsyncResponseTransformer) transformerFunction.apply(testFile); + createFileIfNeeded(initialTransformer); + + FileAsyncResponseTransformerPublisher publisher = + new FileAsyncResponseTransformerPublisher<>(initialTransformer); + + int numTransformers = 8; + CountDownLatch latch = new CountDownLatch(numTransformers); + AtomicInteger transformerCount = new AtomicInteger(0); + List> futures = new ArrayList<>(); + + // When + publisher.subscribe(new Subscriber>() { + @Override + public void onSubscribe(Subscription s) { + s.request(numTransformers); + } + + @Override + public void onNext(AsyncResponseTransformer transformer) { + int index = transformerCount.getAndIncrement(); + + // Each transformer gets a different 10-byte range + long startByte = index * 10L; + long endByte = startByte + 9; + String contentRange = String.format("bytes %d-%d/80", startByte, endByte); + byte[] data = new byte[10]; + for (int i = 0; i < 10; i++) { + data[i] = (byte) ((byte) startByte + i); + } + + SdkResponse mockResponse = createMockResponseWithRange(contentRange); + CompletableFuture future = transformer.prepare(); + futures.add(future); + + transformer.onResponse(mockResponse); + transformer.onStream(createMockPublisherWithData(data)); + + latch.countDown(); + } + + @Override + public void onError(Throwable t) { + for (int i = 0; i < numTransformers; i++) { + latch.countDown(); + } + } + + @Override + public void onComplete() { + } + }); + + // Then + assertThat(latch.await(10, TimeUnit.SECONDS)).isTrue(); + assertThat(transformerCount.get()).isEqualTo(numTransformers); + + CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join(); + + assertThat(Files.exists(testFile)).isTrue(); + byte[] fileContent = Files.readAllBytes(testFile); + + int offset = + initialTransformer.config().fileWriteOption() == FileTransformerConfiguration.FileWriteOption.WRITE_TO_POSITION + ? (int) initialTransformer.position() + : 0; + assertThat(fileContent.length).isEqualTo(80 + offset); + for (int i = 0; i < numTransformers; i++) { + int startPos = i * 10; + byte[] expectedData = new byte[10]; + for (int j = 0; j < 10; j++) { + expectedData[j] = (byte) ((byte) startPos + j); + } + byte[] actualData = Arrays.copyOfRange(fileContent, startPos + offset, startPos + offset + 10); + assertThat(actualData).isEqualTo(expectedData); + } + } + + private SdkResponse createMockResponseWithRange(String contentRange) { + SdkResponse mockResponse = mock(SdkResponse.class); + SdkHttpResponse mockHttpResponse = mock(SdkHttpResponse.class); + + when(mockResponse.sdkHttpResponse()).thenReturn(mockHttpResponse); + when(mockHttpResponse.firstMatchingHeader("x-amz-content-range")).thenReturn(Optional.of(contentRange)); + + return mockResponse; + } + + private SdkPublisher createMockPublisherWithData(byte[] data) { + return s -> s.onSubscribe(new Subscription() { + @Override + public void request(long n) { + s.onNext(ByteBuffer.wrap(data)); + s.onComplete(); + } + + @Override + public void cancel() { + } + }); + } + + private static Stream>> transformers() { + return Stream.of( + AsyncResponseTransformer::toFile, + path -> AsyncResponseTransformer.toFile( + path, + FileTransformerConfiguration.builder() + .fileWriteOption(FileTransformerConfiguration.FileWriteOption.CREATE_NEW) + .failureBehavior(FileTransformerConfiguration.FailureBehavior.LEAVE) + .build()), + path -> AsyncResponseTransformer.toFile( + path, + FileTransformerConfiguration.builder() + .fileWriteOption(FileTransformerConfiguration.FileWriteOption.CREATE_OR_REPLACE_EXISTING) + .failureBehavior(FileTransformerConfiguration.FailureBehavior.LEAVE) + .build()), + path -> AsyncResponseTransformer.toFile( + path, + FileTransformerConfiguration.builder() + .fileWriteOption(FileTransformerConfiguration.FileWriteOption.WRITE_TO_POSITION) + .failureBehavior(FileTransformerConfiguration.FailureBehavior.LEAVE) + .position(10L) + .build()) + ); + } + + @Test + void createOrAppendToExisting_shouldThrowException() throws Exception { + AsyncResponseTransformer initialTransformer = AsyncResponseTransformer.toFile( + testFile, + FileTransformerConfiguration.builder() + .failureBehavior(FileTransformerConfiguration.FailureBehavior.DELETE) + .fileWriteOption(FileTransformerConfiguration.FileWriteOption.CREATE_OR_APPEND_TO_EXISTING) + .build()); + assertThatThrownBy(() -> new FileAsyncResponseTransformerPublisher<>((FileAsyncResponseTransformer) initialTransformer)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("CREATE_OR_APPEND_TO_EXISTING"); + + } + +} \ No newline at end of file diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java index de7ffc6ca949..f65cabec1ee3 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/GenericS3TransferManager.java @@ -348,10 +348,18 @@ public final Download download(DownloadRequest downl TransferProgressUpdater progressUpdater = new TransferProgressUpdater(downloadRequest, null); progressUpdater.transferInitiated(); - responseTransformer = isS3ClientMultipartEnabled() - ? progressUpdater.wrapResponseTransformerForMultipartDownload( - responseTransformer, downloadRequest.getObjectRequest()) - : progressUpdater.wrapResponseTransformer(responseTransformer); + if (isS3ClientMultipartEnabled()) { + if (responseTransformer.split(b -> b.bufferSizeInBytes(1L)).parallelSplitSupported()) { + responseTransformer = + progressUpdater.wrapForNonSerialFileDownload(responseTransformer, downloadRequest.getObjectRequest()); + } else { + responseTransformer = + progressUpdater.wrapResponseTransformerForMultipartDownload( + responseTransformer, downloadRequest.getObjectRequest()); + } + } else { + responseTransformer = progressUpdater.wrapResponseTransformer(responseTransformer); + } progressUpdater.registerCompletion(returnFuture); try { @@ -402,7 +410,7 @@ private TransferProgressUpdater doDownloadFile( try { progressUpdater.transferInitiated(); responseTransformer = isS3ClientMultipartEnabled() - ? progressUpdater.wrapResponseTransformerForMultipartDownload( + ? progressUpdater.wrapForNonSerialFileDownload( responseTransformer, downloadRequest.getObjectRequest()) : progressUpdater.wrapResponseTransformer(responseTransformer); progressUpdater.registerCompletion(returnFuture); diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/TransferProgressUpdater.java b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/TransferProgressUpdater.java index 2cab45039d97..4d78530a45ce 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/TransferProgressUpdater.java +++ b/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/TransferProgressUpdater.java @@ -22,8 +22,10 @@ import java.util.concurrent.atomic.AtomicBoolean; import org.reactivestreams.Subscriber; import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.SplittingTransformerConfiguration; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.async.SdkPublisher; import software.amazon.awssdk.core.async.listener.AsyncRequestBodyListener; import software.amazon.awssdk.core.async.listener.AsyncResponseTransformerListener; import software.amazon.awssdk.core.async.listener.PublisherListener; @@ -35,6 +37,7 @@ import software.amazon.awssdk.transfer.s3.progress.TransferListener; import software.amazon.awssdk.transfer.s3.progress.TransferProgress; import software.amazon.awssdk.transfer.s3.progress.TransferProgressSnapshot; +import software.amazon.awssdk.utils.ContentRangeParser; /** * An SDK-internal helper class that facilitates updating a {@link TransferProgress} and invoking {@link TransferListener}s. @@ -172,25 +175,85 @@ public AsyncResponseTransformer wrapRespon new BaseAsyncResponseTransformerListener() { @Override public void transformerOnResponse(GetObjectResponse response) { - // if the GetObjectRequest is a range-get, the Content-Length headers of the response needs to be used - // to update progress since the Content-Range would incorrectly upgrade progress with the whole object - // size. - if (request.range() != null) { - if (response.contentLength() != null) { - progress.updateAndGet(b -> b.totalBytes(response.contentLength()).sdkResponse(response)); - } - } else { - // if the GetObjectRequest is not a range-get, it might be a part-get. In that case, we need to parse - // the Content-Range header to get the correct totalByte amount. - ContentRangeParser - .totalBytes(response.contentRange()) - .ifPresent(totalBytes -> progress.updateAndGet(b -> b.totalBytes(totalBytes).sdkResponse(response))); - } + multipartDownloadOnResponse(request, response); } } ); } + private void multipartDownloadOnResponse(GetObjectRequest request, GetObjectResponse response) { + // if the GetObjectRequest is a range-get, the Content-Length headers of the response needs to be used + // to update progress since the Content-Range would incorrectly upgrade progress with the whole object + // size. + if (request.range() != null) { + if (response.contentLength() != null) { + progress.updateAndGet(b -> b.totalBytes(response.contentLength()).sdkResponse(response)); + } + } else { + // if the GetObjectRequest is not a range-get, it might be a part-get. In that case, we need to parse + // the Content-Range header to get the correct totalByte amount. + ContentRangeParser + .totalBytes(response.contentRange()) + .ifPresent(totalBytes -> progress.updateAndGet(b -> b.totalBytes(totalBytes).sdkResponse(response))); + } + } + + // upstream transformer + public AsyncResponseTransformer wrapForNonSerialFileDownload( + AsyncResponseTransformer responseTransformer, GetObjectRequest request) { + return new AsyncResponseTransformer() { + @Override + public CompletableFuture prepare() { + return responseTransformer.prepare(); + } + + @Override + public void onResponse(GetObjectResponse response) { + responseTransformer.onResponse(response); + } + + @Override + public void onStream(SdkPublisher publisher) { + responseTransformer.onStream(publisher); + } + + @Override + public void exceptionOccurred(Throwable error) { + responseTransformer.exceptionOccurred(error); + } + + @Override + public SplitResult split(SplittingTransformerConfiguration splitConfig) { + return responseTransformer + .split(splitConfig) + .copy(b -> b.publisher(wrapIndividualTransformer(b.publisher(), request))); + } + + @Override + public String name() { + return responseTransformer.name(); + } + }; + } + + private SdkPublisher> wrapIndividualTransformer( + SdkPublisher> publisher, GetObjectRequest request) { + // each of the individual transformer for multipart file download + return publisher.map(art -> AsyncResponseTransformerListener.wrap( + art, + new AsyncResponseTransformerListener() { + @Override + public void transformerOnResponse(GetObjectResponse response) { + multipartDownloadOnResponse(request, response); + } + + @Override + public void subscriberOnNext(ByteBuffer byteBuffer) { + incrementBytesTransferred(byteBuffer.limit()); + } + })); + } + public AsyncResponseTransformer wrapResponseTransformer( AsyncResponseTransformer responseTransformer) { return AsyncResponseTransformerListener.wrap( diff --git a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/progress/ContentRangeParserTest.java b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/progress/ContentRangeParserTest.java index 6dc7a7fc3ce8..b89b7ab3ce7f 100644 --- a/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/progress/ContentRangeParserTest.java +++ b/services-custom/s3-transfer-manager/src/test/java/software/amazon/awssdk/transfer/s3/internal/progress/ContentRangeParserTest.java @@ -17,12 +17,14 @@ import static org.assertj.core.api.Assertions.assertThat; +import java.util.Optional; import java.util.OptionalLong; import java.util.stream.Stream; -import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import software.amazon.awssdk.utils.ContentRangeParser; +import software.amazon.awssdk.utils.Pair; class ContentRangeParserTest { @@ -51,4 +53,18 @@ static Stream argumentProvider() { Arguments.of("bla bla bla", OptionalLong.empty())); } -} \ No newline at end of file + @ParameterizedTest + @MethodSource("testRange") + void testRange(String contentRange, Optional> expected) { + assertThat(ContentRangeParser.range(contentRange)).isEqualTo(expected); + } + + static Stream testRange() { + return Stream.of( + Arguments.of("bytes 0-9/10", Optional.of(Pair.of(0L, 9L))), + Arguments.of("bytes */10", Optional.empty()), + Arguments.of("bytes 12000000-17999999/30000000", Optional.of(Pair.of(12000000L, 17999999L))) + ); + } + +} diff --git a/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientFileDownloadIntegrationTest.java b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientFileDownloadIntegrationTest.java new file mode 100644 index 000000000000..10823cb5937a --- /dev/null +++ b/services/s3/src/it/java/software/amazon/awssdk/services/s3/multipart/S3MultipartClientFileDownloadIntegrationTest.java @@ -0,0 +1,136 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.multipart; + +import static org.assertj.core.api.Assertions.assertThat; +import static software.amazon.awssdk.testutils.service.S3BucketUtils.temporaryBucketName; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.security.MessageDigest; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import software.amazon.awssdk.core.FileTransformerConfiguration; +import software.amazon.awssdk.core.SdkRequest; +import software.amazon.awssdk.core.async.AsyncRequestBody; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.interceptor.Context; +import software.amazon.awssdk.core.interceptor.ExecutionAttributes; +import software.amazon.awssdk.core.interceptor.ExecutionInterceptor; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3IntegrationTestBase; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.testutils.RandomTempFile; +import software.amazon.awssdk.utils.JavaSystemSetting; +import software.amazon.awssdk.utils.Logger; + +@Timeout(value = 5, unit = TimeUnit.MINUTES) +public class S3MultipartClientFileDownloadIntegrationTest extends S3IntegrationTestBase { + private static final Logger log = Logger.loggerFor(S3MultipartClientFileDownloadIntegrationTest.class); + private static final int MIB = 1024 * 1024; + private static final String TEST_BUCKET = temporaryBucketName(S3MultipartClientFileDownloadIntegrationTest.class); + private static final String TEST_KEY = "testfile.dat"; + private static final int OBJ_SIZE = 100 * MIB; + private static final long PART_SIZE = 5 * MIB; + + private static RandomTempFile localFile; + private S3AsyncClient s3Client; + private TestInterceptor interceptor; + + @BeforeAll + public static void setup() throws Exception { + log.info(() -> "setup"); + setUp(); + log.info(() -> "create bucket"); + createBucket(TEST_BUCKET); + localFile = new RandomTempFile(TEST_KEY, OBJ_SIZE); + localFile.deleteOnExit(); + S3AsyncClient s3Client = S3AsyncClient.builder() + .multipartEnabled(true) + .multipartConfiguration(c -> c.minimumPartSizeInBytes(PART_SIZE)) + .credentialsProvider(CREDENTIALS_PROVIDER_CHAIN) + .build(); + log.info(() -> "put multipart object"); + s3Client.putObject(r -> r.bucket(TEST_BUCKET).key(TEST_KEY), AsyncRequestBody.fromFile(localFile)) + .join(); + s3.close(); + } + + @BeforeEach + public void init() { + log.info(() -> "Initializing S3MultipartClientFileDownloadIntegrationTest"); + this.interceptor = new TestInterceptor(); + this.s3Client = S3AsyncClient.builder() + .multipartEnabled(true) + .overrideConfiguration(o -> o.addExecutionInterceptor(this.interceptor)) + .credentialsProvider(CREDENTIALS_PROVIDER_CHAIN) + .build(); + } + + @Test + void download_defaultCreateNewFile_shouldSucceed() throws Exception { + Path path = tmpPath().resolve(UUID.randomUUID().toString()); + CompletableFuture future = s3Client.getObject( + req -> req.bucket(TEST_BUCKET).key(TEST_KEY), + AsyncResponseTransformer.toFile(path, FileTransformerConfiguration.defaultCreateNew())); + future.join(); + assertSameContentWithChecksum(path); + int totalParts = OBJ_SIZE / (int) PART_SIZE; + assertThat(interceptor.parts.size()).isEqualTo(totalParts); + assertThat(interceptor.parts).hasSameElementsAs(IntStream.range(1, totalParts +1).boxed().collect(Collectors.toList())); + path.toFile().delete(); + } + + private Path tmpPath() { + return Paths.get(JavaSystemSetting.TEMP_DIRECTORY.getStringValueOrThrow()); + } + + private static final class TestInterceptor implements ExecutionInterceptor { + List parts = new ArrayList<>(); + + @Override + public void beforeTransmission(Context.BeforeTransmission context, ExecutionAttributes executionAttributes) { + SdkRequest request = context.request(); + if (request instanceof GetObjectRequest) { + log.info(() -> "Received GetObjectRequest for request " + request); + GetObjectRequest getObjectRequest = (GetObjectRequest) request; + parts.add(getObjectRequest.partNumber()); + } else { + log.warn(() -> "Unexpected request type: " + request.getClass()); + } + } + } + + private void assertSameContentWithChecksum(Path path) throws Exception { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + byte[] downloadedHash = md.digest(Files.readAllBytes(path)); + md.reset(); + byte[] originalHash = md.digest(Files.readAllBytes(localFile.toPath())); + assertThat(downloadedHash).isEqualTo(originalHash); + } +} \ No newline at end of file diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java index e09a065b29be..099ec7f7fb78 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/DownloadObjectHelper.java @@ -32,10 +32,12 @@ public class DownloadObjectHelper { private final S3AsyncClient s3AsyncClient; private final long bufferSizeInBytes; + private final int maxInFlightParts; - public DownloadObjectHelper(S3AsyncClient s3AsyncClient, long bufferSizeInBytes) { + public DownloadObjectHelper(S3AsyncClient s3AsyncClient, long bufferSizeInBytes, int maxInFlightParts) { this.s3AsyncClient = s3AsyncClient; this.bufferSizeInBytes = bufferSizeInBytes; + this.maxInFlightParts = maxInFlightParts; } public CompletableFuture downloadObject( @@ -48,6 +50,26 @@ public CompletableFuture downloadObject( asyncResponseTransformer.split(SplittingTransformerConfiguration.builder() .bufferSizeInBytes(bufferSizeInBytes) .build()); + if (!split.parallelSplitSupported()) { + return downloadPartsSerially(getObjectRequest, split); + } + + return downloadPartsNonSerially(getObjectRequest, split, maxInFlightParts); + + } + + private CompletableFuture downloadPartsNonSerially( + GetObjectRequest getObjectRequest, + AsyncResponseTransformer.SplitResult split, + int maxInFlight) { + ParallelMultipartDownloaderSubscriber subscriber = new ParallelMultipartDownloaderSubscriber( + s3AsyncClient, getObjectRequest, (CompletableFuture) split.resultFuture(), maxInFlight); + split.publisher().subscribe(subscriber); + return split.resultFuture(); + } + + private CompletableFuture downloadPartsSerially(GetObjectRequest getObjectRequest, + AsyncResponseTransformer.SplitResult split) { MultipartDownloaderSubscriber subscriber = subscriber(getObjectRequest); split.publisher().subscribe(subscriber); CompletableFuture splitFuture = split.resultFuture(); diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartConfigurationResolver.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartConfigurationResolver.java index 9fc199175bda..d5a302362b26 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartConfigurationResolver.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartConfigurationResolver.java @@ -17,6 +17,7 @@ import software.amazon.awssdk.annotations.SdkInternalApi; import software.amazon.awssdk.services.s3.multipart.MultipartConfiguration; +import software.amazon.awssdk.services.s3.multipart.ParallelConfiguration; import software.amazon.awssdk.utils.Validate; /** @@ -26,9 +27,14 @@ public final class MultipartConfigurationResolver { private static final long DEFAULT_MIN_PART_SIZE = 8L * 1024 * 1024; + + // Using 50 as the default since maxConcurrency for http client is also 50 + private static final int DEFAULT_MAX_IN_FLIGHT_PARTS = 50; + private final long minimalPartSizeInBytes; private final long apiCallBufferSize; private final long thresholdInBytes; + private final int maxInFlightParts; public MultipartConfigurationResolver(MultipartConfiguration multipartConfiguration) { Validate.notNull(multipartConfiguration, "multipartConfiguration"); @@ -37,6 +43,13 @@ public MultipartConfigurationResolver(MultipartConfiguration multipartConfigurat this.apiCallBufferSize = Validate.getOrDefault(multipartConfiguration.apiCallBufferSizeInBytes(), () -> minimalPartSizeInBytes * 4); this.thresholdInBytes = Validate.getOrDefault(multipartConfiguration.thresholdInBytes(), () -> minimalPartSizeInBytes); + ParallelConfiguration parallelConfiguration = multipartConfiguration.parallelConfiguration(); + if (parallelConfiguration == null) { + this.maxInFlightParts = DEFAULT_MAX_IN_FLIGHT_PARTS; + } else { + this.maxInFlightParts = Validate.getOrDefault(multipartConfiguration.parallelConfiguration().maxInFlightParts(), + () -> DEFAULT_MAX_IN_FLIGHT_PARTS); + } } public long minimalPartSizeInBytes() { @@ -50,4 +63,8 @@ public long thresholdInBytes() { public long apiCallBufferSize() { return apiCallBufferSize; } + + public int maxInFlightParts() { + return maxInFlightParts; + } } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java index d638ca2f4b6a..499f36a6dd1b 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartS3AsyncClient.java @@ -62,9 +62,10 @@ private MultipartS3AsyncClient(S3AsyncClient delegate, MultipartConfiguration mu long minPartSizeInBytes = resolver.minimalPartSizeInBytes(); long threshold = resolver.thresholdInBytes(); long apiCallBufferSize = resolver.apiCallBufferSize(); + int maxInFlightParts = resolver.maxInFlightParts(); mpuHelper = new UploadObjectHelper(delegate, resolver); copyObjectHelper = new CopyObjectHelper(delegate, minPartSizeInBytes, threshold); - downloadObjectHelper = new DownloadObjectHelper(delegate, apiCallBufferSize); + downloadObjectHelper = new DownloadObjectHelper(delegate, apiCallBufferSize, maxInFlightParts); this.checksumEnabled = checksumEnabled; } diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelMultipartDownloaderSubscriber.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelMultipartDownloaderSubscriber.java new file mode 100644 index 000000000000..f2243a8b103d --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelMultipartDownloaderSubscriber.java @@ -0,0 +1,378 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import java.util.Map; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import software.amazon.awssdk.annotations.SdkInternalApi; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.utils.CompletableFutureUtils; +import software.amazon.awssdk.utils.Logger; +import software.amazon.awssdk.utils.Pair; + +/** + * A subscriber implementation that will download all individual parts for a multipart get-object request in parallel, + * concurrently. The amount of concurrent get-object is limited by the {@code maxInFlightParts} configuration. It receives the + * individual {@link AsyncResponseTransformer} which will be used to perform the individual part requests. These + * AsyncResponseTransformer should be able to handle receiving data in parts potentially out of order, For example, the + * AsyncResponseTransformer for part 4 might may have any of its callback called before part 1, 2 or 3 if it finishes before. This + * is a 'one-shot' class, it should NOT be reused for more than one multipart download. + */ +@SdkInternalApi +public class ParallelMultipartDownloaderSubscriber + implements Subscriber> { + private static final Logger log = Logger.loggerFor(ParallelMultipartDownloaderSubscriber.class); + + /** + * Maximum number of concurrent GetObject requests + */ + private final int maxInFlightParts; + + /** + * The s3 client used to make the individual part requests + */ + private final S3AsyncClient s3; + + /** + * The GetObjectRequest that was provided when calling s3.getObject(...). It is copied for each individual request, and the + * copy has the partNumber field updated as more parts are downloaded. + */ + private final GetObjectRequest getObjectRequest; + + /** + * The total number of completed parts. A part is considered complete once the completable future associated with its request + * completes successfully. + */ + private final AtomicInteger completedParts = new AtomicInteger(); + + /** + * The future returned to the user when calling + * {@link S3AsyncClient#getObject(GetObjectRequest, AsyncResponseTransformer) getObject}. This will be completed once the last + * part finishes. Contrary to the linear code path, the future returned to the user is handled here so that we can complete it + * once the last part writting to the file is completed. + */ + private final CompletableFuture resultFuture; + + /** + * The {@link GetObjectResponse} to be returned in the completed future to the user. It corresponds to the response of first + * part GetObject + */ + private GetObjectResponse getObjectResponse; + + /** + * The subscription received from the publisher this subscriber subscribes to. + */ + private Subscription subscription; + + /** + * This value indicates the total number of parts of the object to get. If null, it means we don't know the total amount of + * parts, either because we haven't received a response from s3 yet to set it, or the object to get is not multipart. + */ + private CompletableFuture totalPartsFuture = new CompletableFuture<>(); + + /** + * The etag of the object being downloaded. + */ + private volatile String eTag; + + /** + * Lock around calls to the subscription + */ + private final Object subscriptionLock = new Object(); + + /** + * Tracks request that are currently in flights, waiting to be completed. Once completed, future are removed from the map + */ + private final Map> inFlightRequests = new ConcurrentHashMap<>(); + + /** + * Trasck the amount of in flight requests + */ + private final AtomicInteger inFlightRequestsNum = new AtomicInteger(0); + + /** + * Pending transformers received through onNext that are waiting to be executed. + */ + private final Queue>> pendingTransformers = + new ConcurrentLinkedQueue<>(); + + /** + * Amount of demand requested but not yet fulfilled by the subscription + */ + private final AtomicInteger outstandingDemand = new AtomicInteger(0); + + /** + * Indicates whether this is the first response transformer or not. + */ + private final AtomicBoolean isFirstResponseTransformer = new AtomicBoolean(true); + + /** + * Indicates if we are currently processing pending transformer, which are waiting to be used to send requests + */ + private final AtomicBoolean processingPendingTransformers = new AtomicBoolean(false); + + /** + * The current part of the object to get + */ + private final AtomicInteger partNumber = new AtomicInteger(0); + + /** + * Tracks if one of the parts requests future completed exceptionally. If this occurs, it means all retries were + * attempted for that part, but it still failed. This is a failure state, the error should be reported back to the user + * and any more request should be ignored. + */ + private final AtomicBoolean isCompletedExceptionally = new AtomicBoolean(false); + + public ParallelMultipartDownloaderSubscriber(S3AsyncClient s3, + GetObjectRequest getObjectRequest, + CompletableFuture resultFuture, + int maxInFlightParts) { + this.s3 = s3; + this.getObjectRequest = getObjectRequest; + this.resultFuture = resultFuture; + this.maxInFlightParts = maxInFlightParts; + } + + @Override + public void onSubscribe(Subscription s) { + if (this.subscription != null) { + s.cancel(); + return; + } + this.subscription = s; + subscription.request(maxInFlightParts); + } + + @Override + public void onNext(AsyncResponseTransformer asyncResponseTransformer) { + if (asyncResponseTransformer == null) { + subscription.cancel(); + throw new NullPointerException("onNext must not be called with null asyncResponseTransformer"); + } + + log.trace(() -> "On Next - Total in flight parts: " + inFlightRequests.size() + + " - Demand : " + outstandingDemand.get() + + " - Total completed parts: " + completedParts + + " - Total pending transformers: " + pendingTransformers.size() + + " - Current in flight requests: " + inFlightRequests.keySet()); + + int currentPartNum = partNumber.incrementAndGet(); + + if (isFirstResponseTransformer.compareAndSet(true, false)) { + sendFirstRequest(asyncResponseTransformer); + } else { + pendingTransformers.offer(Pair.of(currentPartNum, asyncResponseTransformer)); + totalPartsFuture.thenAccept( + totalParts -> processingRequests(asyncResponseTransformer, currentPartNum, totalParts)); + } + } + + private void processingRequests(AsyncResponseTransformer asyncResponseTransformer, + int currentPartNum, Integer totalParts) { + + if (currentPartNum > totalParts) { + // Do not process requests above total parts. + // Since we request for maxInFlight during onSubscribe, and the object might actually have less part than maxInFlight, + // there may be situations where we received more onNext signals than the amount of GetObjectRequest required to be + // made. + return; + } + + if (inFlightRequests.size() >= maxInFlightParts) { + pendingTransformers.offer(Pair.of(currentPartNum, asyncResponseTransformer)); + return; + } + + processPendingTransformers(totalParts); + } + + private void sendNextRequest(AsyncResponseTransformer asyncResponseTransformer, + int currentPartNumber, int totalParts) { + if (inFlightRequestsNum.get() + completedParts.get() >= totalParts) { + return; + } + + GetObjectRequest request = nextRequest(currentPartNumber); + log.debug(() -> "Sending next request for part: " + currentPartNumber); + + CompletableFuture response = s3.getObject(request, asyncResponseTransformer); + + inFlightRequestsNum.incrementAndGet(); + inFlightRequests.put(currentPartNumber, response); + CompletableFutureUtils.forwardExceptionTo(resultFuture, response); + + response.whenComplete((res, e) -> { + if (e != null || isCompletedExceptionally.get()) { + // Note on retries: When this future completes exceptionally, it means we did all retries and still failed for + // that part. We need to report back the failure to the user. + handlePartError(e, currentPartNumber); + return; + } + log.debug(() -> "Completed part: " + currentPartNumber); + + inFlightRequests.remove(currentPartNumber); + inFlightRequestsNum.decrementAndGet(); + completedParts.incrementAndGet(); + + if (completedParts.get() >= totalParts) { + if (completedParts.get() > totalParts) { + resultFuture.completeExceptionally(new IllegalStateException("Total parts exceeded")); + } else { + resultFuture.complete(getObjectResponse); + } + + synchronized (subscriptionLock) { + subscription.cancel(); + } + + } else { + processPendingTransformers(res.partsCount()); + synchronized (subscriptionLock) { + subscription.request(1); + } + } + }); + } + + private void sendFirstRequest(AsyncResponseTransformer asyncResponseTransformer) { + log.debug(() -> "Sending first request"); + GetObjectRequest request = nextRequest(1); + CompletableFuture responseFuture = s3.getObject(request, asyncResponseTransformer); + + // Propagate cancellation from user + CompletableFutureUtils.forwardExceptionTo(resultFuture, responseFuture); + + responseFuture.whenComplete((res, e) -> { + if (e != null || isCompletedExceptionally.get()) { + // Note on retries: When this future completes exceptionally, it means we did all retries and still failed for + // that part. We need to report back the failure to the user. + handlePartError(e, 1); + return; + } + + log.debug(() -> "Completed part: 1"); + completedParts.incrementAndGet(); + setInitialPartCountAndEtag(res); + + if (!isMultipartObject(res)) { + return; + } + + log.debug(() -> "Multipart object detected, performing multipart download"); + getObjectResponse = res; + + processPendingTransformers(res.partsCount()); + synchronized (subscriptionLock) { + subscription.request(1); + } + }); + } + + private boolean isMultipartObject(GetObjectResponse response) { + if (response.partsCount() == null || response.partsCount() == 1) { + // Single part object detected, skip multipart and complete everything now + log.debug(() -> "Single Part object detected, skipping multipart download"); + subscription.cancel(); + resultFuture.complete(response); + return false; + } + return true; + } + + private void setInitialPartCountAndEtag(GetObjectResponse response) { + Integer partCount = response.partsCount(); + eTag = response.eTag(); + if (partCount != null) { + log.debug(() -> String.format("Total amount of parts of the object to download: %d", partCount)); + totalPartsFuture.complete(partCount); + } else { + totalPartsFuture.complete(1); + } + } + + private void handlePartError(Throwable e, int part) { + isCompletedExceptionally.set(true); + log.debug(() -> "Error on part " + part, e); + resultFuture.completeExceptionally(e); + inFlightRequests.values().forEach(future -> future.cancel(true)); + } + + private void processPendingTransformers(int totalParts) { + do { + if (!processingPendingTransformers.compareAndSet(false, true)) { + return; + } + try { + doProcessPendingTransformers(totalParts); + } finally { + processingPendingTransformers.set(false); + } + + } while (shouldProcessPendingTransformers()); + + } + + private void doProcessPendingTransformers(int totalParts) { + while (shouldProcessPendingTransformers()) { + Pair> transformer = + pendingTransformers.poll(); + sendNextRequest(transformer.right(), transformer.left(), totalParts); + } + } + + private boolean shouldProcessPendingTransformers() { + if (pendingTransformers.isEmpty()) { + return false; + } + return maxInFlightParts - inFlightRequestsNum.get() > 0; + } + + @Override + public void onError(Throwable t) { + // Signal received from the publisher this is subscribed to + // (in the case of file download, that's FileAsyncResponseTransformerPublisher) + // Failed state, something really wrong has happened, cancel everything + inFlightRequests.values().forEach(future -> future.cancel(true)); + inFlightRequests.clear(); + resultFuture.completeExceptionally(t); + } + + @Override + public void onComplete() { + // We check for completion state when we receive the GetObjectResponse for last part. + // This Subscriber is responsible for its completed state, so we do nothing here. + } + + private GetObjectRequest nextRequest(int nextPartToGet) { + return getObjectRequest.copy(req -> { + req.partNumber(nextPartToGet); + if (eTag != null) { + req.ifMatch(eTag); + } + }); + } + +} diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java index 7370a383b1cf..f1d5ff35c60c 100644 --- a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/MultipartConfiguration.java @@ -15,6 +15,7 @@ package software.amazon.awssdk.services.s3.multipart; +import java.util.function.Consumer; import software.amazon.awssdk.annotations.SdkPublicApi; import software.amazon.awssdk.core.async.AsyncRequestBody; import software.amazon.awssdk.core.async.AsyncResponseTransformer; @@ -50,11 +51,13 @@ public final class MultipartConfiguration implements ToCopyableBuilder consumer); + + /** + * Configuration specifically related to parallel multipart operations. + * @return the configuration class + */ + ParallelConfiguration parallelConfiguration(); } private static class DefaultMultipartConfigBuilder implements Builder { private Long thresholdInBytes; private Long minimumPartSizeInBytes; private Long apiCallBufferSizeInBytes; + private ParallelConfiguration parallelConfiguration; @Override public Builder thresholdInBytes(Long thresholdInBytes) { @@ -205,6 +237,24 @@ public Long minimumPartSizeInBytes() { return this.minimumPartSizeInBytes; } + @Override + public Builder parallelConfiguration(ParallelConfiguration parallelConfiguration) { + this.parallelConfiguration = parallelConfiguration; + return this; + } + + @Override + public Builder parallelConfiguration(Consumer configuration) { + ParallelConfiguration.Builder builder = ParallelConfiguration.builder(); + configuration.accept(builder); + return parallelConfiguration(builder.build()); + } + + @Override + public ParallelConfiguration parallelConfiguration() { + return this.parallelConfiguration; + } + @Override public Builder apiCallBufferSizeInBytes(Long maximumMemoryUsageInBytes) { this.apiCallBufferSizeInBytes = maximumMemoryUsageInBytes; diff --git a/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/ParallelConfiguration.java b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/ParallelConfiguration.java new file mode 100644 index 000000000000..a3816cd97fa2 --- /dev/null +++ b/services/s3/src/main/java/software/amazon/awssdk/services/s3/multipart/ParallelConfiguration.java @@ -0,0 +1,71 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.multipart; + +import software.amazon.awssdk.annotations.SdkPublicApi; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.utils.builder.CopyableBuilder; +import software.amazon.awssdk.utils.builder.ToCopyableBuilder; + +/** + * Class that holds configuration properties related to multipart operations for a {@link S3AsyncClient}, related specifically + * to non-linear, parallel operations, that is, when the {@link AsyncResponseTransformer} supports non-serial split. + */ +@SdkPublicApi +public class ParallelConfiguration implements ToCopyableBuilder { + + private final Integer maxInFlightParts; + + public ParallelConfiguration(Builder builder) { + this.maxInFlightParts = builder.maxInFlightParts; + } + + public static Builder builder() { + return new Builder(); + } + + /** + * The maximum number of concurrent GetObject the that are allowed for multipart download. + * @return The value for the maximum number of concurrent GetObject the that are allowed for multipart download. + */ + public Integer maxInFlightParts() { + return maxInFlightParts; + } + + @Override + public Builder toBuilder() { + return builder().maxInFlightParts(maxInFlightParts); + } + + public static class Builder implements CopyableBuilder { + private int maxInFlightParts; + + public Builder maxInFlightParts(int maxInFlightParts) { + this.maxInFlightParts = maxInFlightParts; + return this; + } + + public int maxInFlightParts() { + return maxInFlightParts; + } + + @Override + public ParallelConfiguration build() { + return new ParallelConfiguration(this); + } + } +} diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartConfigurationResolverTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartConfigurationResolverTest.java index 99e929c09f4e..18e4348a247a 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartConfigurationResolverTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/MultipartConfigurationResolverTest.java @@ -19,6 +19,7 @@ import org.junit.jupiter.api.Test; import software.amazon.awssdk.services.s3.multipart.MultipartConfiguration; +import software.amazon.awssdk.services.s3.multipart.ParallelConfiguration; public class MultipartConfigurationResolverTest { @@ -59,17 +60,32 @@ void resolveApiCallBufferSize_valueNotProvided_shouldComputeBasedOnPartSize() { assertThat(resolver.apiCallBufferSize()).isEqualTo(40L); } + @Test + void resolveMaxInFlightParts_valueProvidedWithBuilder_shouldHonor() { + MultipartConfiguration configuration = + MultipartConfiguration.builder() + .parallelConfiguration(p -> p.maxInFlightParts(1)) + .build(); + MultipartConfigurationResolver resolver = new MultipartConfigurationResolver(configuration); + assertThat(resolver.maxInFlightParts()).isEqualTo(1); + } + @Test void valueProvidedForAllFields_shouldHonor() { - MultipartConfiguration configuration = MultipartConfiguration.builder() - .minimumPartSizeInBytes(10L) - .thresholdInBytes(8L) - .apiCallBufferSizeInBytes(3L) - .build(); + MultipartConfiguration configuration = + MultipartConfiguration.builder() + .minimumPartSizeInBytes(10L) + .thresholdInBytes(8L) + .apiCallBufferSizeInBytes(3L) + .parallelConfiguration(ParallelConfiguration.builder() + .maxInFlightParts(1) + .build()) + .build(); MultipartConfigurationResolver resolver = new MultipartConfigurationResolver(configuration); assertThat(resolver.minimalPartSizeInBytes()).isEqualTo(10L); assertThat(resolver.thresholdInBytes()).isEqualTo(8L); assertThat(resolver.apiCallBufferSize()).isEqualTo(3L); + assertThat(resolver.maxInFlightParts()).isEqualTo(1); } @Test @@ -79,5 +95,7 @@ void noValueProvided_shouldUseDefault() { assertThat(resolver.minimalPartSizeInBytes()).isEqualTo(8L * 1024 * 1024); assertThat(resolver.thresholdInBytes()).isEqualTo(8L * 1024 * 1024); assertThat(resolver.apiCallBufferSize()).isEqualTo(8L * 1024 * 1024 * 4); + assertThat(resolver.maxInFlightParts()).isEqualTo(50); } + } diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelMultipartDownloaderSubscriberTckTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelMultipartDownloaderSubscriberTckTest.java new file mode 100644 index 000000000000..79c6b60bb077 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelMultipartDownloaderSubscriberTckTest.java @@ -0,0 +1,116 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.nio.ByteBuffer; +import java.util.concurrent.CompletableFuture; +import org.reactivestreams.Subscriber; +import org.reactivestreams.Subscription; +import org.reactivestreams.tck.SubscriberWhiteboxVerification; +import org.reactivestreams.tck.TestEnvironment; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.core.async.SdkPublisher; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; + +public class ParallelMultipartDownloaderSubscriberTckTest + extends SubscriberWhiteboxVerification> { + private S3AsyncClient s3Client; + private CompletableFuture future; + + public ParallelMultipartDownloaderSubscriberTckTest() { + super(new TestEnvironment()); + s3Client = mock(S3AsyncClient.class); + when(s3Client.getObject(any(GetObjectRequest.class), any(AsyncResponseTransformer.class))) + .thenReturn(CompletableFuture.completedFuture(GetObjectResponse.builder() + .partsCount(10) + .eTag("eTag") + .build())); + future = new CompletableFuture<>(); + } + + @Override + public Subscriber> createSubscriber( + WhiteboxSubscriberProbe> probe) { + return new ParallelMultipartDownloaderSubscriber(s3Client, GetObjectRequest.builder().build(), future, 50) { + @Override + public void onSubscribe(Subscription s) { + super.onSubscribe(s); + probe.registerOnSubscribe(new SubscriberWhiteboxVerification.SubscriberPuppet() { + + @Override + public void triggerRequest(long l) { + s.request(l); + } + + @Override + public void signalCancel() { + s.cancel(); + } + }); + } + + @Override + public void onNext(AsyncResponseTransformer item) { + super.onNext(item); + probe.registerOnNext(item); + } + + @Override + public void onError(Throwable t) { + super.onError(t); + probe.registerOnError(t); + } + + @Override + public void onComplete() { + super.onComplete(); + probe.registerOnComplete(); + } + + }; + } + + @Override + public AsyncResponseTransformer createElement(int element) { + return new AsyncResponseTransformer() { + @Override + public CompletableFuture prepare() { + return new CompletableFuture<>(); + } + + @Override + public void onResponse(GetObjectResponse response) { + // do nothing, test + } + + @Override + public void onStream(SdkPublisher publisher) { + // do nothing, test + } + + @Override + public void exceptionOccurred(Throwable error) { + // do nothing, test + } + }; + } +} \ No newline at end of file diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelMultipartDownloaderSubscriberWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelMultipartDownloaderSubscriberWiremockTest.java new file mode 100644 index 000000000000..54abbee78c64 --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/ParallelMultipartDownloaderSubscriberWiremockTest.java @@ -0,0 +1,146 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.google.common.jimfs.Configuration; +import com.google.common.jimfs.Jimfs; +import java.net.URI; +import java.nio.file.FileSystem; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.reactivestreams.Subscriber; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.FileTransformerConfiguration; +import software.amazon.awssdk.core.SplittingTransformerConfiguration; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3Configuration; +import software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; + +@WireMockTest +class ParallelMultipartDownloaderSubscriberWiremockTest { + + private final String testBucket = "test-bucket"; + private final String testKey = "test-key"; + + private S3AsyncClient s3AsyncClient; + private MultipartDownloadTestUtils utils; + private FileSystem fileSystem; + private Path testFile; + + @BeforeEach + public void init(WireMockRuntimeInfo wiremock) throws Exception { + s3AsyncClient = S3AsyncClient.builder() + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create("key", "secret"))) + .region(Region.US_WEST_2) + .endpointOverride(URI.create("http://localhost:" + wiremock.getHttpPort())) + .serviceConfiguration(S3Configuration.builder() + .pathStyleAccessEnabled(true) + .build()) + .build(); + utils = new MultipartDownloadTestUtils(testBucket, testKey, "test-etag"); + fileSystem = Jimfs.newFileSystem(Configuration.unix()); + testFile = fileSystem.getPath("/test-file.txt"); + Files.createDirectories(testFile.getParent()); + Files.createFile(testFile); + } + + @AfterEach + void tearDown() throws Exception { + fileSystem.close(); + } + + @ParameterizedTest + @ValueSource(ints = {2, 3, 4, 5, 6, 7, 8, 9, 10, 49}) + void happyPath_multipartDownload_partsLessThanMaxInFlight(int numParts) throws Exception { + int partSize = 1024; + byte[] expectedBody = utils.stubAllParts(testBucket, testKey, numParts, partSize); + + AsyncResponseTransformer transformer = + AsyncResponseTransformer.toFile(testFile, FileTransformerConfiguration.defaultCreateOrReplaceExisting()); + AsyncResponseTransformer.SplitResult split = transformer.split( + SplittingTransformerConfiguration.builder() + .bufferSizeInBytes(1024 * 32L) + .build()); + + CompletableFuture resultFuture = new CompletableFuture<>(); + Subscriber> subscriber = + new ParallelMultipartDownloaderSubscriber(s3AsyncClient, + GetObjectRequest.builder() + .bucket(testBucket) + .key(testKey) + .build(), + resultFuture, + 50); + + split.publisher().subscribe(subscriber); + GetObjectResponse getObjectResponse = resultFuture.join(); + + assertThat(Files.exists(testFile)).isTrue(); + byte[] actualBody = Files.readAllBytes(testFile); + assertThat(actualBody).isEqualTo(expectedBody); + assertThat(getObjectResponse).isNotNull(); + utils.verifyCorrectAmountOfRequestsMade(numParts); + } + + @Test + void singlePartObject_shouldCompleteWithoutMultipart() throws Exception { + int partSize = 1024; + byte[] expectedBody = utils.stubSinglePart(testBucket, testKey, partSize); + + AsyncResponseTransformer transformer = + AsyncResponseTransformer.toFile(testFile, FileTransformerConfiguration.defaultCreateOrReplaceExisting()); + AsyncResponseTransformer.SplitResult split = transformer.split( + SplittingTransformerConfiguration.builder() + .bufferSizeInBytes(1024 * 32L) + .build()); + + CompletableFuture resultFuture = new CompletableFuture<>(); + Subscriber> subscriber = + new ParallelMultipartDownloaderSubscriber(s3AsyncClient, + GetObjectRequest.builder() + .bucket(testBucket) + .key(testKey) + .build(), + resultFuture, + 50); + + split.publisher().subscribe(subscriber); + GetObjectResponse getObjectResponse = resultFuture.join(); + + assertThat(Files.exists(testFile)).isTrue(); + byte[] actualBody = Files.readAllBytes(testFile); + assertThat(actualBody).isEqualTo(expectedBody); + assertThat(getObjectResponse).isNotNull(); + utils.verifyCorrectAmountOfRequestsMade(1); + } + +} \ No newline at end of file diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java index 449f245203c9..7d4234b0ce5a 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartClientGetObjectWiremockTest.java @@ -147,8 +147,9 @@ public void errorOnThirdPart_shouldCompleteExceptionallyOnlyPartsGreaterThan } } - @ParameterizedTest - @MethodSource("partSizeAndTransformerParams") + // todo temporary, remove when support for resume is added to multipart file download + // @ParameterizedTest + // @MethodSource("partSizeAndTransformerParams") public void partCountValidationFailure_shouldThrowException( AsyncResponseTransformerTestSupplier supplier, int partSize) { @@ -302,8 +303,7 @@ private static Stream partSizeAndTransformerParams() { */ private static Stream> nonRetryableResponseTransformers() { return Stream.of(new AsyncResponseTransformerTestSupplier.InputStreamArtSupplier(), - new AsyncResponseTransformerTestSupplier.PublisherArtSupplier(), - new AsyncResponseTransformerTestSupplier.FileArtSupplier()); + new AsyncResponseTransformerTestSupplier.PublisherArtSupplier()); } private CompletableFuture mock200Response(S3AsyncClient s3Client, int runNumber, diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartFileDownloadWiremockTest.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartFileDownloadWiremockTest.java new file mode 100644 index 000000000000..5d1edb4370fb --- /dev/null +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/S3MultipartFileDownloadWiremockTest.java @@ -0,0 +1,418 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://aws.amazon.com/apache2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package software.amazon.awssdk.services.s3.internal.multipart; + +import static com.github.tomakehurst.wiremock.client.WireMock.aResponse; +import static com.github.tomakehurst.wiremock.client.WireMock.exactly; +import static com.github.tomakehurst.wiremock.client.WireMock.get; +import static com.github.tomakehurst.wiremock.client.WireMock.getRequestedFor; +import static com.github.tomakehurst.wiremock.client.WireMock.stubFor; +import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo; +import static com.github.tomakehurst.wiremock.client.WireMock.verify; +import static org.assertj.core.api.Assertions.assertThat; +import static software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils.contentRangeHeader; + +import com.github.tomakehurst.wiremock.junit5.WireMockRuntimeInfo; +import com.github.tomakehurst.wiremock.junit5.WireMockTest; +import com.google.common.jimfs.Configuration; +import com.google.common.jimfs.Jimfs; +import java.net.URI; +import java.nio.file.FileSystem; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Duration; +import java.time.temporal.ChronoUnit; +import java.util.Random; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.async.AsyncResponseTransformer; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.s3.S3AsyncClient; +import software.amazon.awssdk.services.s3.S3Configuration; +import software.amazon.awssdk.services.s3.internal.multipart.utils.MultipartDownloadTestUtils; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import software.amazon.awssdk.services.s3.model.S3Exception; + +@WireMockTest +class S3MultipartFileDownloadWiremockTest { + + private final String testBucket = "test-bucket"; + private final String testKey = "test-key"; + + private S3AsyncClient s3AsyncClient; + private MultipartDownloadTestUtils util; + private FileSystem fileSystem; + private Path testFile; + + @BeforeEach + public void init(WireMockRuntimeInfo wiremock) throws Exception { + s3AsyncClient = S3AsyncClient.builder() + .credentialsProvider(StaticCredentialsProvider.create( + AwsBasicCredentials.create("key", "secret"))) + .region(Region.US_WEST_2) + .endpointOverride(URI.create("http://localhost:" + wiremock.getHttpPort())) + .multipartEnabled(true) + .serviceConfiguration(S3Configuration.builder() + .pathStyleAccessEnabled(true) + .build()) + .build(); + util = new MultipartDownloadTestUtils(testBucket, testKey, "test-etag"); + fileSystem = Jimfs.newFileSystem(Configuration.unix()); + testFile = fileSystem.getPath("test-file.txt"); + } + + @AfterEach + void tearDown() throws Exception { + fileSystem.close(); + } + + @Test + void happyPath_singlePart() throws Exception { + int partSize = 1024; + byte[] expectedBody = util.stubSinglePart(testBucket, testKey, partSize); + + CompletableFuture response = s3AsyncClient.getObject(b -> b + .bucket(testBucket) + .key(testKey) + .build(), + AsyncResponseTransformer.toFile(testFile)); + + assertThat(response).succeedsWithin(Duration.of(10, ChronoUnit.SECONDS)); + assertThat(Files.exists(testFile)).isTrue(); + byte[] actualBody = Files.readAllBytes(testFile); + assertThat(actualBody).isEqualTo(expectedBody); + assertThat(response).isNotNull(); + util.verifyCorrectAmountOfRequestsMade(1); + } + + @Test + void happyPath_multipart() throws Exception { + int numParts = 4; + int partSize = 1024; + byte[] expectedBody = util.stubAllParts(testBucket, testKey, numParts, partSize); + + CompletableFuture response = s3AsyncClient.getObject(b -> b + .bucket(testBucket) + .key(testKey) + .build(), + AsyncResponseTransformer.toFile(testFile)); + + assertThat(response).succeedsWithin(Duration.of(10, ChronoUnit.SECONDS)); + byte[] actualBody = Files.readAllBytes(testFile); + assertThat(actualBody).isEqualTo(expectedBody); + assertThat(response).isNotNull(); + util.verifyCorrectAmountOfRequestsMade(numParts); + } + + @Test + void errorOnFirstPart_nonRetryable() { + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey))).willReturn( + aResponse() + .withStatus(403) + .withBody("AccessDeniedTest: Access denied!"))); + + CompletableFuture resp = s3AsyncClient.getObject(b -> b + .bucket(testBucket) + .key(testKey) + .build(), + AsyncResponseTransformer.toFile(testFile)); + assertThat(resp).failsWithin(Duration.of(10, ChronoUnit.SECONDS)) + .withThrowableOfType(ExecutionException.class) + .withCauseInstanceOf(S3Exception.class) + .withMessageContaining("Test: Access denied!"); + verify(exactly(1), getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey)))); + } + + @Test + void errorOnFirstPart_retryable() { + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey))) + .inScenario("retry") + .whenScenarioStateIs("Started") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal error 1")) + .willSetStateTo("retry1")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey))) + .inScenario("retry") + .whenScenarioStateIs("retry1") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal error 2")) + .willSetStateTo("retry2")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey))) + .inScenario("retry") + .whenScenarioStateIs("retry2") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal error 3")) + .willSetStateTo("retry3")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey))) + .inScenario("retry") + .whenScenarioStateIs("retry3") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal error 4"))); + + CompletableFuture resp = s3AsyncClient.getObject(b -> b + .bucket(testBucket) + .key(testKey) + .build(), + AsyncResponseTransformer.toFile(testFile)); + assertThat(resp).failsWithin(Duration.of(10, ChronoUnit.SECONDS)) + .withThrowableOfType(ExecutionException.class) + .withCauseInstanceOf(S3Exception.class) + .withMessageContaining("Internal error 4"); + verify(exactly(4), getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey)))); + } + + @Test + void errorOnMiddlePart_nonRetryable() { + util.stubForPart(testBucket, testKey, 1, 3, 1024); + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", testBucket, testKey))).willReturn( + aResponse() + .withStatus(403) + .withBody("AccessDeniedTest: Access denied!"))); + util.stubForPart(testBucket, testKey, 3, 3, 1024); + + CompletableFuture resp = s3AsyncClient.getObject(b -> b + .bucket(testBucket) + .key(testKey) + .build(), + AsyncResponseTransformer.toFile(testFile)); + + assertThat(resp).failsWithin(Duration.of(10, ChronoUnit.SECONDS)) + .withThrowableOfType(ExecutionException.class) + .withCauseInstanceOf(S3Exception.class) + .withMessageContaining("Test: Access denied!"); + verify(exactly(1), getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey)))); + verify(exactly(1), getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=2", testBucket, testKey)))); + } + + @Test + void errorOnMiddlePart_retryable() { + util.stubForPart(testBucket, testKey, 1, 3, 1024); + util.stubForPart(testBucket, testKey, 3, 3, 1024); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", testBucket, testKey))) + .inScenario("retry") + .whenScenarioStateIs("Started") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal error 1")) + .willSetStateTo("retry1")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", testBucket, testKey))) + .inScenario("retry") + .whenScenarioStateIs("retry1") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal error 2")) + .willSetStateTo("retry2")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", testBucket, testKey))) + .inScenario("retry") + .whenScenarioStateIs("retry2") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal error 3")) + .willSetStateTo("retry3") + ); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", testBucket, testKey))) + .inScenario("retry") + .whenScenarioStateIs("retry3") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal error 4"))); + + CompletableFuture resp = s3AsyncClient.getObject(b -> b + .bucket(testBucket) + .key(testKey) + .build(), + AsyncResponseTransformer.toFile(testFile)); + + assertThat(resp).failsWithin(Duration.of(10, ChronoUnit.SECONDS)) + .withThrowableOfType(ExecutionException.class) + .withCauseInstanceOf(S3Exception.class) + .withMessageContaining("Internal error 4"); + + verify(exactly(1), getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey)))); + verify(exactly(4), getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=2", testBucket, testKey)))); + } + + @Test + void errorOnFirstPart_retryable_thenSucceeds() throws Exception { + int partSize = 1024; + int totalPart = 3; + Random random = new Random(); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey))) + .inScenario("retry") + .whenScenarioStateIs("Started") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal error 1")) + .willSetStateTo("retry1")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey))) + .inScenario("retry") + .whenScenarioStateIs("retry1") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal error 2")) + .willSetStateTo("retry2")); + + byte[] part1Data = new byte[partSize]; + random.nextBytes(part1Data); + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=%d", testBucket, testKey, 1))) + .inScenario("retry") + .whenScenarioStateIs("retry2") + .willReturn(aResponse() + .withHeader("x-amz-mp-parts-count", totalPart + "") + .withHeader("x-amz-content-range", contentRangeHeader(1, totalPart, partSize)) + .withHeader("ETag", "test-etag") + .withBody(part1Data))); + + byte[] part2Data = new byte[partSize]; + random.nextBytes(part2Data); + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=%d", testBucket, testKey, 2))) + .inScenario("retry") + .whenScenarioStateIs("retry2") + .willReturn(aResponse() + .withHeader("x-amz-mp-parts-count", totalPart + "") + .withHeader("x-amz-content-range", contentRangeHeader(2, totalPart, partSize)) + .withHeader("ETag", "test-etag") + .withBody(part2Data))); + + byte[] part3Data = new byte[partSize]; + random.nextBytes(part3Data); + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=%d", testBucket, testKey, 3))) + .inScenario("retry") + .whenScenarioStateIs("retry2") + .willReturn(aResponse() + .withHeader("x-amz-mp-parts-count", totalPart + "") + .withHeader("x-amz-content-range", contentRangeHeader(3, totalPart, partSize)) + .withHeader("ETag", "test-etag") + .withBody(part3Data))); + + + } + + @Test + void errorOnMiddlePart_retryable_thenSucceeds() throws Exception { + int partSize = 1024; + int totalPart = 3; + Random random = new Random(); + byte[] part1Data = util.stubForPart(testBucket, testKey, 1, totalPart, partSize); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", testBucket, testKey))) + .inScenario("retry") + .whenScenarioStateIs("Started") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal error 1")) + .willSetStateTo("retry1")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", testBucket, testKey))) + .inScenario("retry") + .whenScenarioStateIs("retry1") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal error 2")) + .willSetStateTo("retry2")); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=2", testBucket, testKey))) + .inScenario("retry") + .whenScenarioStateIs("retry2") + .willReturn(aResponse() + .withStatus(500) + .withBody("InternalErrorInternal error 3")) + .willSetStateTo("retry3") + ); + + byte[] part2Data = new byte[partSize]; + random.nextBytes(part2Data); + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=%d", testBucket, testKey, 2))) + .inScenario("retry") + .whenScenarioStateIs("retry3") + .willReturn(aResponse() + .withHeader("x-amz-mp-parts-count", totalPart + "") + .withHeader("x-amz-content-range", contentRangeHeader(2, totalPart, partSize)) + .withHeader("ETag", "test-etag") + .withBody(part2Data))); + + byte[] part3Data = new byte[partSize]; + random.nextBytes(part3Data); + + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=%d", testBucket, testKey, 3))) + .willReturn(aResponse() + .withHeader("x-amz-mp-parts-count", totalPart + "") + .withHeader("x-amz-content-range", contentRangeHeader(3, totalPart, partSize)) + .withHeader("ETag", "test-etag") + .withBody(part3Data))); + + CompletableFuture resp = s3AsyncClient.getObject(b -> b + .bucket(testBucket) + .key(testKey) + .build(), + AsyncResponseTransformer.toFile(testFile)); + + assertThat(resp).succeedsWithin(Duration.of(10, ChronoUnit.SECONDS)); + + verify(exactly(1), getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey)))); + verify(exactly(4), getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=2", testBucket, testKey)))); + verify(exactly(1), getRequestedFor(urlEqualTo(String.format("/%s/%s?partNumber=3", testBucket, testKey)))); + + assertThat(Files.exists(testFile)).isTrue(); + byte[] actualBody = Files.readAllBytes(testFile); + byte[] expectedBody = new byte[partSize * totalPart]; + System.arraycopy(part1Data, 0, expectedBody, 0, partSize); + System.arraycopy(part2Data, 0, expectedBody, partSize, partSize); + System.arraycopy(part3Data, 0, expectedBody, partSize * 2, partSize); + assertThat(actualBody).isEqualTo(expectedBody); + + } + + @Test + void veryHighPartCount_shouldSucceed() throws Exception { + int numParts = 5000; + int partSize = 100; + + byte[] expectedBody = util.stubAllParts(testBucket, testKey, numParts, partSize); + + CompletableFuture response = s3AsyncClient.getObject(b -> b + .bucket(testBucket) + .key(testKey) + .build(), + AsyncResponseTransformer.toFile(testFile)); + + assertThat(response).succeedsWithin(Duration.of(5, ChronoUnit.MINUTES)); + response.join(); + byte[] actualBody = Files.readAllBytes(testFile); + assertThat(actualBody).isEqualTo(expectedBody); + assertThat(response).isNotNull(); + util.verifyCorrectAmountOfRequestsMade(numParts); + + } +} \ No newline at end of file diff --git a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartDownloadTestUtils.java b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartDownloadTestUtils.java index e12a7cee35a0..de7e442cea65 100644 --- a/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartDownloadTestUtils.java +++ b/services/s3/src/test/java/software/amazon/awssdk/services/s3/internal/multipart/utils/MultipartDownloadTestUtils.java @@ -59,18 +59,26 @@ public byte[] stubAllParts(String testBucket, String testKey, int amountOfPartTo return expectedBody; } - public byte[] stubForPart(String testBucket, String testKey,int part, int totalPart, int partSize) { + public byte[] stubForPart(String testBucket, String testKey, int part, int totalPart, int partSize) { byte[] body = new byte[partSize]; random.nextBytes(body); stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=%d", testBucket, testKey, part))).willReturn( aResponse() .withHeader("x-amz-mp-parts-count", totalPart + "") + .withHeader("x-amz-content-range", contentRangeHeader(part, totalPart, partSize)) .withHeader("ETag", eTag) .withHeader("Content-Length", String.valueOf(body.length)) .withBody(body))); return body; } + public static String contentRangeHeader(int part, int totalPart, int partSize) { + long start = (part - 1) * (long) partSize; + long end = start + partSize - 1; + long total = totalPart * (long) partSize; + return String.format("bytes %d-%d/%d", start, end, total); + } + public void verifyCorrectAmountOfRequestsMade(int amountOfPartToTest) { String urlTemplate = ".*partNumber=%d.*"; for (int i = 1; i <= amountOfPartToTest; i++) { @@ -94,4 +102,16 @@ public static String internalErrorBody() { public static String slowdownErrorBody() { return errorBody("SlowDown", "Please reduce your request rate."); } + + public byte[] stubSinglePart(String testBucket, String testKey, int partSize) { + byte[] body = new byte[partSize]; + random.nextBytes(body); + stubFor(get(urlEqualTo(String.format("/%s/%s?partNumber=1", testBucket, testKey))).willReturn( + aResponse() + .withHeader("x-amz-mp-parts-count", "1") + .withHeader("x-amz-content-range", String.format("bytes %d-%d/%d", 0, partSize - 1, partSize)) + .withHeader("ETag", eTag) + .withBody(body))); + return body; + } } diff --git a/test/architecture-tests/archunit_store/4195d6e3-8849-4e5a-848d-04f810577cd3 b/test/architecture-tests/archunit_store/4195d6e3-8849-4e5a-848d-04f810577cd3 index dd81179f8903..abf18b7d902b 100644 --- a/test/architecture-tests/archunit_store/4195d6e3-8849-4e5a-848d-04f810577cd3 +++ b/test/architecture-tests/archunit_store/4195d6e3-8849-4e5a-848d-04f810577cd3 @@ -44,7 +44,7 @@ Method calls method in (ReceiveSqsMessageHelper.java:132) Method calls method in (AsyncBufferingSubscriber.java:67) Method calls method in (PauseResumeHelper.java:59) -Method calls method in (ContentRangeParser.java:71) +Method calls method in (ContentRangeParser.java:71) Method calls method in (LoggingTransferListener.java:76) Method calls method in (Logger.java:205) Method calls method in (AddingTrailingDataSubscriber.java:73) diff --git a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/ContentRangeParser.java b/utils/src/main/java/software/amazon/awssdk/utils/ContentRangeParser.java similarity index 58% rename from services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/ContentRangeParser.java rename to utils/src/main/java/software/amazon/awssdk/utils/ContentRangeParser.java index 03e67c402a56..3b4d2433c91b 100644 --- a/services-custom/s3-transfer-manager/src/main/java/software/amazon/awssdk/transfer/s3/internal/progress/ContentRangeParser.java +++ b/utils/src/main/java/software/amazon/awssdk/utils/ContentRangeParser.java @@ -13,12 +13,11 @@ * permissions and limitations under the License. */ -package software.amazon.awssdk.transfer.s3.internal.progress; +package software.amazon.awssdk.utils; +import java.util.Optional; import java.util.OptionalLong; -import software.amazon.awssdk.annotations.SdkInternalApi; -import software.amazon.awssdk.utils.Logger; -import software.amazon.awssdk.utils.StringUtils; +import software.amazon.awssdk.annotations.SdkProtectedApi; /** * Parse a Content-Range header value into a total byte count. The expected format is the following:

@@ -27,7 +26,7 @@ *

* The only supported {@code } is the {@code bytes} value. */ -@SdkInternalApi +@SdkProtectedApi public final class ContentRangeParser { private static final Logger log = Logger.loggerFor(ContentRangeParser.class); @@ -72,4 +71,47 @@ public static OptionalLong totalBytes(String contentRange) { return OptionalLong.empty(); } } + + /** + * Parse the Content-Range to extract the byte range from the content. Only supports the {@code bytes} unit, any + * other unit will result in an empty OptionalLong. If byte range in unknown, which is represented by a {@code *} symbol + * in the header value, an empty OptionalLong will be returned. + * + * @param contentRange the value of the Content-Range header to be parsed. + * @return The total number of bytes in the content range or an empty optional if the contentRange is null, empty or if the + * total length is not a valid long. + */ + public static Optional> range(String contentRange) { + if (StringUtils.isEmpty(contentRange)) { + return Optional.empty(); + } + + String trimmed = contentRange.trim(); + if (!trimmed.startsWith("bytes ")) { + return Optional.empty(); + } + String withoutBytes = trimmed.substring("bytes ".length()); + if (withoutBytes.startsWith("*")) { + return Optional.empty(); + } + int hyphen = withoutBytes.indexOf('-'); + if (hyphen == -1) { + return Optional.empty(); + } + String begin = withoutBytes.substring(0, hyphen); + int slash = withoutBytes.indexOf('/'); + if (slash == -1) { + return Optional.empty(); + } + String end = withoutBytes.substring(hyphen + 1, slash); + try { + long startInt = Long.parseLong(begin); + long endInt = Long.parseLong(end); + return Optional.of(Pair.of(startInt, endInt)); + } catch (Exception e) { + log.debug(() -> "failed to parse content range", e); + return Optional.empty(); + } + } + }